Refactor noise locked read/write, move handshake out to own interface

This commit is contained in:
eyedeekay
2025-03-19 22:28:14 -04:00
parent fa0a42855c
commit f2b6d4bc01
9 changed files with 133 additions and 90 deletions

View File

@ -0,0 +1,18 @@
package handshake
import "github.com/flynn/noise"
// HandshakeState manages the Noise handshake state
type HandshakeState interface {
// GenerateEphemeral creates ephemeral keypair
GenerateEphemeral() (*noise.DHKey, error)
// WriteMessage creates Noise message
WriteMessage([]byte) ([]byte, *noise.CipherState, *noise.CipherState, error)
// HandshakeComplete returns true if handshake is complete
HandshakeComplete() bool
// CompleteHandshake completes the handshake
CompleteHandshake() error
}

View File

@ -9,7 +9,7 @@ import (
"github.com/flynn/noise"
)
type HandshakeState struct {
type NoiseHandshakeState struct {
mutex sync.Mutex
ephemeral *noise.DHKey
pattern noise.HandshakePattern
@ -17,8 +17,8 @@ type HandshakeState struct {
*noise.HandshakeState
}
func NewHandshakeState(staticKey noise.DHKey, isInitiator bool) (*HandshakeState, error) {
hs := &HandshakeState{
func NewHandshakeState(staticKey noise.DHKey, isInitiator bool) (*NoiseHandshakeState, error) {
hs := &NoiseHandshakeState{
pattern: noise.HandshakeXK,
}
@ -38,9 +38,20 @@ func NewHandshakeState(staticKey noise.DHKey, isInitiator bool) (*HandshakeState
return hs, nil
}
func (h *NoiseHandshakeState) HandshakeComplete() bool {
return h.handshakeComplete
}
func (h *NoiseHandshakeState) CompleteHandshake() error {
h.mutex.Lock()
defer h.mutex.Unlock()
h.handshakeComplete = true
return nil
}
// GenerateEphemeral creates the ephemeral keypair that will be used in handshake
// This needs to be separate so NTCP2 can obfuscate it
func (h *HandshakeState) GenerateEphemeral() (*noise.DHKey, error) {
func (h *NoiseHandshakeState) GenerateEphemeral() (*noise.DHKey, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
@ -54,21 +65,21 @@ func (h *HandshakeState) GenerateEphemeral() (*noise.DHKey, error) {
// SetEphemeral allows setting a potentially modified ephemeral key
// This is needed for NTCP2's obfuscation layer
func (h *HandshakeState) SetEphemeral(key *noise.DHKey) error {
func (h *NoiseHandshakeState) SetEphemeral(key *noise.DHKey) error {
h.mutex.Lock()
defer h.mutex.Unlock()
h.ephemeral = key
return nil
}
func (h *HandshakeState) WriteMessage(payload []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
func (h *NoiseHandshakeState) WriteMessage(payload []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
return h.HandshakeState.WriteMessage(nil, payload)
}
func (h *HandshakeState) ReadMessage(message []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
func (h *NoiseHandshakeState) ReadMessage(message []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
h.mutex.Lock()
defer h.mutex.Unlock()

View File

@ -18,7 +18,7 @@ func (c *NoiseSession) RunIncomingHandshake() error {
log.WithError(err).Error("Failed to compose receiver handshake message")
return err
}
c.HandshakeState = &HandshakeState{
c.HandshakeState = &NoiseHandshakeState{
HandshakeState: state,
}
log.WithFields(logrus.Fields{
@ -37,7 +37,7 @@ func (c *NoiseSession) RunIncomingHandshake() error {
log.Debug("Handshake message written successfully")
log.WithField("state", state).Debug("Handshake state after message write")
log.Println(state)
c.handshakeComplete = true
c.CompleteHandshake()
log.Debug("Incoming handshake completed successfully")
return nil
}

View File

@ -23,7 +23,7 @@ func (c *NoiseSession) RunOutgoingHandshake() error {
"negData_length": len(negData),
"msg_length": len(msg),
}).Debug("Initiator handshake message composed")
c.HandshakeState = &HandshakeState{
c.HandshakeState = &NoiseHandshakeState{
HandshakeState: state,
}
@ -40,7 +40,7 @@ func (c *NoiseSession) RunOutgoingHandshake() error {
log.Debug("Handshake message written successfully")
log.WithField("state", state).Debug("Handshake state after message write")
log.Println(state)
c.handshakeComplete = true
c.CompleteHandshake()
log.Debug("Outgoing handshake completed successfully")
return nil
}

View File

@ -1,6 +1,7 @@
package noise
import (
"encoding/binary"
"sync/atomic"
"github.com/samber/oops"
@ -25,7 +26,7 @@ func (c *NoiseSession) Read(b []byte) (int, error) {
}
log.Debug("NoiseSession Read: retrying atomic operation")
}
if !c.handshakeComplete {
if !c.HandshakeComplete() {
log.Debug("NoiseSession Read: handshake not complete, running incoming handshake")
if err := c.RunIncomingHandshake(); err != nil {
log.WithError(err).Error("NoiseSession Read: failed to run incoming handshake")
@ -34,7 +35,7 @@ func (c *NoiseSession) Read(b []byte) (int, error) {
}
c.Mutex.Lock()
defer c.Mutex.Unlock()
if !c.handshakeComplete {
if !c.HandshakeComplete() {
log.Error("NoiseSession Read: internal error - handshake still not complete after running")
return 0, oops.Errorf("internal error")
}
@ -48,47 +49,46 @@ func (c *NoiseSession) Read(b []byte) (int, error) {
}
func (c *NoiseSession) decryptPacket(data []byte) (int, []byte, error) {
log.WithField("data_length", len(data)).Debug("Starting packet decryption")
log.WithField("data_length", len(data)).Debug("NoiseSession: Starting packet decryption")
if c.CipherState == nil {
log.Error("Packet decryption: CipherState is nil")
return 0, nil, oops.Errorf("CipherState is nil")
log.Error("NoiseSession: decryptPacket - readState is nil")
return 0, nil, oops.Errorf("readState is nil")
}
// Decrypt
decryptedData, err := c.CipherState.Decrypt(nil, nil, data)
if len(data) < 2 {
log.Error("NoiseSession: decryptPacket - packet too short")
return 0, nil, oops.Errorf("packet too short")
}
// Extract payload length from prefix
payloadLen := binary.BigEndian.Uint16(data[:2])
if len(data[2:]) < int(payloadLen) {
log.Error("NoiseSession: decryptPacket - incomplete packet")
return 0, nil, oops.Errorf("incomplete packet")
}
// Decrypt payload
decryptedData, err := c.CipherState.Decrypt(nil, nil, data[2:2+payloadLen])
if err != nil {
log.WithError(err).Error("Packet decryption: failed to decrypt data")
return 0, nil, err
log.WithError(err).Error("NoiseSession: decryptPacket - failed to decrypt data")
return 0, nil, oops.Errorf("failed to decrypt: %w", err)
}
m := len(decryptedData)
log.WithField("decrypted_length", m).Debug("Packet decryption: successfully decrypted data")
log.WithFields(logrus.Fields{
"encrypted_length": payloadLen,
"decrypted_length": m,
}).Debug("NoiseSession: decryptPacket - packet decrypted successfully")
return m, decryptedData, nil
/*packet := c.InitializePacket()
maxPayloadSize := c.maxPayloadSizeForRead(packet)
if m > int(maxPayloadSize) {
m = int(maxPayloadSize)
}
if c.CipherState != nil {
////fmt.Println("writing encrypted packet:", m)
packet.reserve(uint16Size + uint16Size + m + macSize)
packet.resize(uint16Size + uint16Size + m)
copy(packet.data[uint16Size+uint16Size:], data[:m])
binary.BigEndian.PutUint16(packet.data[uint16Size:], uint16(m))
//fmt.Println("encrypt size", uint16(m))
} else {
packet.resize(len(packet.data) + len(data))
copy(packet.data[uint16Size:len(packet.data)], data[:m])
binary.BigEndian.PutUint16(packet.data, uint16(len(data)))
}
b := c.encryptIfNeeded(packet)*/
//c.freeBlock(packet)
}
func (c *NoiseSession) readPacketLocked(data []byte) (int, error) {
log.WithField("data_length", len(data)).Debug("Starting readPacketLocked")
var n int
if len(data) == 0 { // special case to answer when everything is ok during handshake
if len(data) == 0 { // Handle special case where data length is zero during handshake
log.Debug("readPacketLocked: special case - reading 2 bytes during handshake")
if _, err := c.Conn.Read(make([]byte, 2)); err != nil {
log.WithError(err).Error("readPacketLocked: failed to read 2 bytes during handshake")
@ -96,26 +96,18 @@ func (c *NoiseSession) readPacketLocked(data []byte) (int, error) {
}
}
for len(data) > 0 {
m, b, err := c.encryptPacket(data)
_, b, err := c.decryptPacket(data)
if err != nil {
log.WithError(err).Error("readPacketLocked: failed to encrypt packet")
return 0, err
}
/*
if n, err := c.Conn.Read(b); err != nil {
return n, err
} else {
n += m
data = data[m:]
}
*/
n, err := c.Conn.Read(b)
bytesRead, err := c.Conn.Read(b)
if err != nil {
log.WithError(err).WithField("bytes_read (aka n)", n).Error("readPacketLocked: failed to read from connection")
return n, err
log.WithError(err).WithField("bytes_read", bytesRead).Error("readPacketLocked: failed to read from connection")
return bytesRead, err
}
n += m
data = data[m:]
n += bytesRead
data = data[bytesRead:]
log.WithFields(logrus.Fields{
"bytes_read": n,
"remaining_data": len(data),

View File

@ -13,6 +13,7 @@ import (
"github.com/go-i2p/go-i2p/lib/common/router_info"
"github.com/go-i2p/go-i2p/lib/transport"
"github.com/go-i2p/go-i2p/lib/transport/handshake"
)
type NoiseSession struct {
@ -20,7 +21,7 @@ type NoiseSession struct {
*noise.CipherState
*sync.Cond
*NoiseTransport // The parent transport, which "Dialed" the connection to the peer with whom we established the session
*HandshakeState
handshake.HandshakeState
RecvQueue *cb.Queue
SendQueue *cb.Queue
VerifyCallback VerifyCallbackFunc

View File

@ -31,10 +31,28 @@ type NoiseTransport struct {
}
func (noopt *NoiseTransport) Compatible(routerInfo router_info.RouterInfo) bool {
// TODO implement
// panic("implement me")
log.Warn("func (noopt *NoiseTransport) Compatible(routerInfo router_info.RouterInfo) is not implemented!")
return true
// Check if we have an existing session with this router
_, ok := noopt.peerConnections[routerInfo.IdentHash()]
if ok {
return true
}
// Check router addresses for Noise protocol support
for _, addr := range routerInfo.RouterAddresses() {
transportStyle, err := addr.TransportStyle().Data()
if err != nil {
continue
}
// Check for Noise protocol support
if transportStyle == NOISE_PROTOCOL_NAME {
// A router is compatible if it has a static key
if addr.CheckOption("s") {
return true
}
}
}
return false
}
var exampleNoiseTransport transport.Transport = &NoiseTransport{}
@ -123,7 +141,7 @@ func (c *NoiseTransport) getSession(routerInfo router_info.RouterInfo) (transpor
return nil, err
}
for {
if session.(*NoiseSession).handshakeComplete {
if session.(*NoiseSession).HandshakeComplete() {
log.Debug("NoiseTransport: Handshake complete")
return nil, nil
}

View File

@ -26,7 +26,7 @@ func (c *NoiseSession) Write(b []byte) (int, error) {
}
log.Debug("NoiseSession: Write - retrying atomic operation")
}
if !c.handshakeComplete {
if !c.HandshakeComplete() {
log.Debug("NoiseSession: Write - handshake not complete, running outgoing handshake")
if err := c.RunOutgoingHandshake(); err != nil {
log.WithError(err).Error("NoiseSession: Write - failed to run outgoing handshake")
@ -35,7 +35,7 @@ func (c *NoiseSession) Write(b []byte) (int, error) {
}
c.Mutex.Lock()
defer c.Mutex.Unlock()
if !c.handshakeComplete {
if !c.HandshakeComplete() {
log.Error("NoiseSession: Write - internal error, handshake still not complete")
return 0, oops.Errorf("internal error")
}
@ -53,48 +53,33 @@ func (c *NoiseSession) encryptPacket(data []byte) (int, []byte, error) {
m := len(data)
if c.CipherState == nil {
log.Error("NoiseSession: encryptPacket - CipherState is nil")
return 0, nil, oops.Errorf("CipherState is nil")
log.Error("NoiseSession: encryptPacket - writeState is nil")
return 0, nil, oops.Errorf("writeState is nil")
}
// Create length prefix first
lengthPrefix := make([]byte, 2)
binary.BigEndian.PutUint16(lengthPrefix, uint16(m))
// Encrypt the data
encryptedData, err := c.CipherState.Encrypt(nil, nil, data)
if err != nil {
log.WithError(err).Error("NoiseSession: encryptPacket - failed to encrypt data")
return 0, nil, oops.Errorf("failed to encrypt: '%w'", err)
return 0, nil, oops.Errorf("failed to encrypt: %w", err)
}
// m := len(encryptedData)
lengthPrefix := make([]byte, 2)
binary.BigEndian.PutUint16(lengthPrefix, uint16(len(encryptedData)))
// Combine length prefix and encrypted data
packet := make([]byte, 0, len(lengthPrefix)+len(encryptedData))
packet = append(packet, lengthPrefix...)
packet = append(packet, encryptedData...)
// Append encr data to prefix
packet := append(lengthPrefix, encryptedData...)
log.WithFields(logrus.Fields{
"original_length": m,
"encrypted_length": len(encryptedData),
"packet_length": len(packet),
}).Debug("NoiseSession: encryptPacket - packet encrypted successfully")
return m, packet, nil
/*packet := c.InitializePacket()
maxPayloadSize := c.maxPayloadSizeForWrite(packet)
if m > int(maxPayloadSize) {
m = int(maxPayloadSize)
}
if c.CipherState != nil {
////fmt.Println("writing encrypted packet:", m)
packet.reserve(uint16Size + uint16Size + m + macSize)
packet.resize(uint16Size + uint16Size + m)
copy(packet.data[uint16Size+uint16Size:], data[:m])
binary.BigEndian.PutUint16(packet.data[uint16Size:], uint16(m))
//fmt.Println("encrypt size", uint16(m))
} else {
packet.resize(len(packet.data) + len(data))
copy(packet.data[uint16Size:len(packet.data)], data[:m])
binary.BigEndian.PutUint16(packet.data, uint16(len(data)))
}
b := c.encryptIfNeeded(packet)*/
//c.freeBlock(packet)
}
func (c *NoiseSession) writePacketLocked(data []byte) (int, error) {

View File

@ -0,0 +1,18 @@
package ntcp
// SessionRequestMessage represents Message 1 of the NTCP2 handshake
type SessionRequestMessage struct {
ObfuscatedKey []byte // 32 bytes ephemeral key X
Timestamp uint32 // Current time
Options [16]byte // Options block
Padding []byte // Random padding
}
// SessionRequestBuilder handles creation of NTCP2 Message 1
type SessionRequestBuilder interface {
// CreateSessionRequest builds Message 1 of handshake
CreateSessionRequest() (*SessionRequestMessage, error)
// ObfuscateEphemeral obfuscates ephemeral key using AES
ObfuscateEphemeral(key []byte) ([]byte, error)
}