mirror of
https://github.com/go-i2p/go-i2p.git
synced 2025-06-07 10:01:41 -04:00
Refactor noise locked read/write, move handshake out to own interface
This commit is contained in:
18
lib/transport/handshake/hanshake.go
Normal file
18
lib/transport/handshake/hanshake.go
Normal file
@ -0,0 +1,18 @@
|
||||
package handshake
|
||||
|
||||
import "github.com/flynn/noise"
|
||||
|
||||
// HandshakeState manages the Noise handshake state
|
||||
type HandshakeState interface {
|
||||
// GenerateEphemeral creates ephemeral keypair
|
||||
GenerateEphemeral() (*noise.DHKey, error)
|
||||
|
||||
// WriteMessage creates Noise message
|
||||
WriteMessage([]byte) ([]byte, *noise.CipherState, *noise.CipherState, error)
|
||||
|
||||
// HandshakeComplete returns true if handshake is complete
|
||||
HandshakeComplete() bool
|
||||
|
||||
// CompleteHandshake completes the handshake
|
||||
CompleteHandshake() error
|
||||
}
|
@ -9,7 +9,7 @@ import (
|
||||
"github.com/flynn/noise"
|
||||
)
|
||||
|
||||
type HandshakeState struct {
|
||||
type NoiseHandshakeState struct {
|
||||
mutex sync.Mutex
|
||||
ephemeral *noise.DHKey
|
||||
pattern noise.HandshakePattern
|
||||
@ -17,8 +17,8 @@ type HandshakeState struct {
|
||||
*noise.HandshakeState
|
||||
}
|
||||
|
||||
func NewHandshakeState(staticKey noise.DHKey, isInitiator bool) (*HandshakeState, error) {
|
||||
hs := &HandshakeState{
|
||||
func NewHandshakeState(staticKey noise.DHKey, isInitiator bool) (*NoiseHandshakeState, error) {
|
||||
hs := &NoiseHandshakeState{
|
||||
pattern: noise.HandshakeXK,
|
||||
}
|
||||
|
||||
@ -38,9 +38,20 @@ func NewHandshakeState(staticKey noise.DHKey, isInitiator bool) (*HandshakeState
|
||||
return hs, nil
|
||||
}
|
||||
|
||||
func (h *NoiseHandshakeState) HandshakeComplete() bool {
|
||||
return h.handshakeComplete
|
||||
}
|
||||
|
||||
func (h *NoiseHandshakeState) CompleteHandshake() error {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
h.handshakeComplete = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateEphemeral creates the ephemeral keypair that will be used in handshake
|
||||
// This needs to be separate so NTCP2 can obfuscate it
|
||||
func (h *HandshakeState) GenerateEphemeral() (*noise.DHKey, error) {
|
||||
func (h *NoiseHandshakeState) GenerateEphemeral() (*noise.DHKey, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
@ -54,21 +65,21 @@ func (h *HandshakeState) GenerateEphemeral() (*noise.DHKey, error) {
|
||||
|
||||
// SetEphemeral allows setting a potentially modified ephemeral key
|
||||
// This is needed for NTCP2's obfuscation layer
|
||||
func (h *HandshakeState) SetEphemeral(key *noise.DHKey) error {
|
||||
func (h *NoiseHandshakeState) SetEphemeral(key *noise.DHKey) error {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
h.ephemeral = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HandshakeState) WriteMessage(payload []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
|
||||
func (h *NoiseHandshakeState) WriteMessage(payload []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
return h.HandshakeState.WriteMessage(nil, payload)
|
||||
}
|
||||
|
||||
func (h *HandshakeState) ReadMessage(message []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
|
||||
func (h *NoiseHandshakeState) ReadMessage(message []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
|
@ -18,7 +18,7 @@ func (c *NoiseSession) RunIncomingHandshake() error {
|
||||
log.WithError(err).Error("Failed to compose receiver handshake message")
|
||||
return err
|
||||
}
|
||||
c.HandshakeState = &HandshakeState{
|
||||
c.HandshakeState = &NoiseHandshakeState{
|
||||
HandshakeState: state,
|
||||
}
|
||||
log.WithFields(logrus.Fields{
|
||||
@ -37,7 +37,7 @@ func (c *NoiseSession) RunIncomingHandshake() error {
|
||||
log.Debug("Handshake message written successfully")
|
||||
log.WithField("state", state).Debug("Handshake state after message write")
|
||||
log.Println(state)
|
||||
c.handshakeComplete = true
|
||||
c.CompleteHandshake()
|
||||
log.Debug("Incoming handshake completed successfully")
|
||||
return nil
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ func (c *NoiseSession) RunOutgoingHandshake() error {
|
||||
"negData_length": len(negData),
|
||||
"msg_length": len(msg),
|
||||
}).Debug("Initiator handshake message composed")
|
||||
c.HandshakeState = &HandshakeState{
|
||||
c.HandshakeState = &NoiseHandshakeState{
|
||||
HandshakeState: state,
|
||||
}
|
||||
|
||||
@ -40,7 +40,7 @@ func (c *NoiseSession) RunOutgoingHandshake() error {
|
||||
log.Debug("Handshake message written successfully")
|
||||
log.WithField("state", state).Debug("Handshake state after message write")
|
||||
log.Println(state)
|
||||
c.handshakeComplete = true
|
||||
c.CompleteHandshake()
|
||||
log.Debug("Outgoing handshake completed successfully")
|
||||
return nil
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package noise
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/samber/oops"
|
||||
@ -25,7 +26,7 @@ func (c *NoiseSession) Read(b []byte) (int, error) {
|
||||
}
|
||||
log.Debug("NoiseSession Read: retrying atomic operation")
|
||||
}
|
||||
if !c.handshakeComplete {
|
||||
if !c.HandshakeComplete() {
|
||||
log.Debug("NoiseSession Read: handshake not complete, running incoming handshake")
|
||||
if err := c.RunIncomingHandshake(); err != nil {
|
||||
log.WithError(err).Error("NoiseSession Read: failed to run incoming handshake")
|
||||
@ -34,7 +35,7 @@ func (c *NoiseSession) Read(b []byte) (int, error) {
|
||||
}
|
||||
c.Mutex.Lock()
|
||||
defer c.Mutex.Unlock()
|
||||
if !c.handshakeComplete {
|
||||
if !c.HandshakeComplete() {
|
||||
log.Error("NoiseSession Read: internal error - handshake still not complete after running")
|
||||
return 0, oops.Errorf("internal error")
|
||||
}
|
||||
@ -48,47 +49,46 @@ func (c *NoiseSession) Read(b []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (c *NoiseSession) decryptPacket(data []byte) (int, []byte, error) {
|
||||
log.WithField("data_length", len(data)).Debug("Starting packet decryption")
|
||||
log.WithField("data_length", len(data)).Debug("NoiseSession: Starting packet decryption")
|
||||
|
||||
if c.CipherState == nil {
|
||||
log.Error("Packet decryption: CipherState is nil")
|
||||
return 0, nil, oops.Errorf("CipherState is nil")
|
||||
log.Error("NoiseSession: decryptPacket - readState is nil")
|
||||
return 0, nil, oops.Errorf("readState is nil")
|
||||
}
|
||||
// Decrypt
|
||||
decryptedData, err := c.CipherState.Decrypt(nil, nil, data)
|
||||
|
||||
if len(data) < 2 {
|
||||
log.Error("NoiseSession: decryptPacket - packet too short")
|
||||
return 0, nil, oops.Errorf("packet too short")
|
||||
}
|
||||
|
||||
// Extract payload length from prefix
|
||||
payloadLen := binary.BigEndian.Uint16(data[:2])
|
||||
if len(data[2:]) < int(payloadLen) {
|
||||
log.Error("NoiseSession: decryptPacket - incomplete packet")
|
||||
return 0, nil, oops.Errorf("incomplete packet")
|
||||
}
|
||||
|
||||
// Decrypt payload
|
||||
decryptedData, err := c.CipherState.Decrypt(nil, nil, data[2:2+payloadLen])
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Packet decryption: failed to decrypt data")
|
||||
return 0, nil, err
|
||||
log.WithError(err).Error("NoiseSession: decryptPacket - failed to decrypt data")
|
||||
return 0, nil, oops.Errorf("failed to decrypt: %w", err)
|
||||
}
|
||||
|
||||
m := len(decryptedData)
|
||||
log.WithField("decrypted_length", m).Debug("Packet decryption: successfully decrypted data")
|
||||
log.WithFields(logrus.Fields{
|
||||
"encrypted_length": payloadLen,
|
||||
"decrypted_length": m,
|
||||
}).Debug("NoiseSession: decryptPacket - packet decrypted successfully")
|
||||
|
||||
return m, decryptedData, nil
|
||||
/*packet := c.InitializePacket()
|
||||
maxPayloadSize := c.maxPayloadSizeForRead(packet)
|
||||
if m > int(maxPayloadSize) {
|
||||
m = int(maxPayloadSize)
|
||||
}
|
||||
if c.CipherState != nil {
|
||||
////fmt.Println("writing encrypted packet:", m)
|
||||
packet.reserve(uint16Size + uint16Size + m + macSize)
|
||||
packet.resize(uint16Size + uint16Size + m)
|
||||
copy(packet.data[uint16Size+uint16Size:], data[:m])
|
||||
binary.BigEndian.PutUint16(packet.data[uint16Size:], uint16(m))
|
||||
//fmt.Println("encrypt size", uint16(m))
|
||||
} else {
|
||||
packet.resize(len(packet.data) + len(data))
|
||||
copy(packet.data[uint16Size:len(packet.data)], data[:m])
|
||||
binary.BigEndian.PutUint16(packet.data, uint16(len(data)))
|
||||
}
|
||||
b := c.encryptIfNeeded(packet)*/
|
||||
//c.freeBlock(packet)
|
||||
}
|
||||
|
||||
func (c *NoiseSession) readPacketLocked(data []byte) (int, error) {
|
||||
log.WithField("data_length", len(data)).Debug("Starting readPacketLocked")
|
||||
|
||||
var n int
|
||||
if len(data) == 0 { // special case to answer when everything is ok during handshake
|
||||
if len(data) == 0 { // Handle special case where data length is zero during handshake
|
||||
log.Debug("readPacketLocked: special case - reading 2 bytes during handshake")
|
||||
if _, err := c.Conn.Read(make([]byte, 2)); err != nil {
|
||||
log.WithError(err).Error("readPacketLocked: failed to read 2 bytes during handshake")
|
||||
@ -96,26 +96,18 @@ func (c *NoiseSession) readPacketLocked(data []byte) (int, error) {
|
||||
}
|
||||
}
|
||||
for len(data) > 0 {
|
||||
m, b, err := c.encryptPacket(data)
|
||||
_, b, err := c.decryptPacket(data)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("readPacketLocked: failed to encrypt packet")
|
||||
return 0, err
|
||||
}
|
||||
/*
|
||||
if n, err := c.Conn.Read(b); err != nil {
|
||||
return n, err
|
||||
} else {
|
||||
n += m
|
||||
data = data[m:]
|
||||
}
|
||||
*/
|
||||
n, err := c.Conn.Read(b)
|
||||
bytesRead, err := c.Conn.Read(b)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("bytes_read (aka n)", n).Error("readPacketLocked: failed to read from connection")
|
||||
return n, err
|
||||
log.WithError(err).WithField("bytes_read", bytesRead).Error("readPacketLocked: failed to read from connection")
|
||||
return bytesRead, err
|
||||
}
|
||||
n += m
|
||||
data = data[m:]
|
||||
n += bytesRead
|
||||
data = data[bytesRead:]
|
||||
log.WithFields(logrus.Fields{
|
||||
"bytes_read": n,
|
||||
"remaining_data": len(data),
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/go-i2p/go-i2p/lib/common/router_info"
|
||||
"github.com/go-i2p/go-i2p/lib/transport"
|
||||
"github.com/go-i2p/go-i2p/lib/transport/handshake"
|
||||
)
|
||||
|
||||
type NoiseSession struct {
|
||||
@ -20,7 +21,7 @@ type NoiseSession struct {
|
||||
*noise.CipherState
|
||||
*sync.Cond
|
||||
*NoiseTransport // The parent transport, which "Dialed" the connection to the peer with whom we established the session
|
||||
*HandshakeState
|
||||
handshake.HandshakeState
|
||||
RecvQueue *cb.Queue
|
||||
SendQueue *cb.Queue
|
||||
VerifyCallback VerifyCallbackFunc
|
||||
|
@ -31,10 +31,28 @@ type NoiseTransport struct {
|
||||
}
|
||||
|
||||
func (noopt *NoiseTransport) Compatible(routerInfo router_info.RouterInfo) bool {
|
||||
// TODO implement
|
||||
// panic("implement me")
|
||||
log.Warn("func (noopt *NoiseTransport) Compatible(routerInfo router_info.RouterInfo) is not implemented!")
|
||||
return true
|
||||
// Check if we have an existing session with this router
|
||||
_, ok := noopt.peerConnections[routerInfo.IdentHash()]
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check router addresses for Noise protocol support
|
||||
for _, addr := range routerInfo.RouterAddresses() {
|
||||
transportStyle, err := addr.TransportStyle().Data()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for Noise protocol support
|
||||
if transportStyle == NOISE_PROTOCOL_NAME {
|
||||
// A router is compatible if it has a static key
|
||||
if addr.CheckOption("s") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var exampleNoiseTransport transport.Transport = &NoiseTransport{}
|
||||
@ -123,7 +141,7 @@ func (c *NoiseTransport) getSession(routerInfo router_info.RouterInfo) (transpor
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
if session.(*NoiseSession).handshakeComplete {
|
||||
if session.(*NoiseSession).HandshakeComplete() {
|
||||
log.Debug("NoiseTransport: Handshake complete")
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ func (c *NoiseSession) Write(b []byte) (int, error) {
|
||||
}
|
||||
log.Debug("NoiseSession: Write - retrying atomic operation")
|
||||
}
|
||||
if !c.handshakeComplete {
|
||||
if !c.HandshakeComplete() {
|
||||
log.Debug("NoiseSession: Write - handshake not complete, running outgoing handshake")
|
||||
if err := c.RunOutgoingHandshake(); err != nil {
|
||||
log.WithError(err).Error("NoiseSession: Write - failed to run outgoing handshake")
|
||||
@ -35,7 +35,7 @@ func (c *NoiseSession) Write(b []byte) (int, error) {
|
||||
}
|
||||
c.Mutex.Lock()
|
||||
defer c.Mutex.Unlock()
|
||||
if !c.handshakeComplete {
|
||||
if !c.HandshakeComplete() {
|
||||
log.Error("NoiseSession: Write - internal error, handshake still not complete")
|
||||
return 0, oops.Errorf("internal error")
|
||||
}
|
||||
@ -53,48 +53,33 @@ func (c *NoiseSession) encryptPacket(data []byte) (int, []byte, error) {
|
||||
|
||||
m := len(data)
|
||||
if c.CipherState == nil {
|
||||
log.Error("NoiseSession: encryptPacket - CipherState is nil")
|
||||
return 0, nil, oops.Errorf("CipherState is nil")
|
||||
log.Error("NoiseSession: encryptPacket - writeState is nil")
|
||||
return 0, nil, oops.Errorf("writeState is nil")
|
||||
}
|
||||
|
||||
// Create length prefix first
|
||||
lengthPrefix := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(lengthPrefix, uint16(m))
|
||||
|
||||
// Encrypt the data
|
||||
encryptedData, err := c.CipherState.Encrypt(nil, nil, data)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("NoiseSession: encryptPacket - failed to encrypt data")
|
||||
return 0, nil, oops.Errorf("failed to encrypt: '%w'", err)
|
||||
return 0, nil, oops.Errorf("failed to encrypt: %w", err)
|
||||
}
|
||||
// m := len(encryptedData)
|
||||
|
||||
lengthPrefix := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(lengthPrefix, uint16(len(encryptedData)))
|
||||
// Combine length prefix and encrypted data
|
||||
packet := make([]byte, 0, len(lengthPrefix)+len(encryptedData))
|
||||
packet = append(packet, lengthPrefix...)
|
||||
packet = append(packet, encryptedData...)
|
||||
|
||||
// Append encr data to prefix
|
||||
packet := append(lengthPrefix, encryptedData...)
|
||||
log.WithFields(logrus.Fields{
|
||||
"original_length": m,
|
||||
"encrypted_length": len(encryptedData),
|
||||
"packet_length": len(packet),
|
||||
}).Debug("NoiseSession: encryptPacket - packet encrypted successfully")
|
||||
|
||||
return m, packet, nil
|
||||
/*packet := c.InitializePacket()
|
||||
maxPayloadSize := c.maxPayloadSizeForWrite(packet)
|
||||
if m > int(maxPayloadSize) {
|
||||
m = int(maxPayloadSize)
|
||||
}
|
||||
if c.CipherState != nil {
|
||||
////fmt.Println("writing encrypted packet:", m)
|
||||
packet.reserve(uint16Size + uint16Size + m + macSize)
|
||||
packet.resize(uint16Size + uint16Size + m)
|
||||
copy(packet.data[uint16Size+uint16Size:], data[:m])
|
||||
binary.BigEndian.PutUint16(packet.data[uint16Size:], uint16(m))
|
||||
//fmt.Println("encrypt size", uint16(m))
|
||||
} else {
|
||||
packet.resize(len(packet.data) + len(data))
|
||||
copy(packet.data[uint16Size:len(packet.data)], data[:m])
|
||||
binary.BigEndian.PutUint16(packet.data, uint16(len(data)))
|
||||
}
|
||||
b := c.encryptIfNeeded(packet)*/
|
||||
//c.freeBlock(packet)
|
||||
}
|
||||
|
||||
func (c *NoiseSession) writePacketLocked(data []byte) (int, error) {
|
||||
|
18
lib/transport/ntcp/session_request.go
Normal file
18
lib/transport/ntcp/session_request.go
Normal file
@ -0,0 +1,18 @@
|
||||
package ntcp
|
||||
|
||||
// SessionRequestMessage represents Message 1 of the NTCP2 handshake
|
||||
type SessionRequestMessage struct {
|
||||
ObfuscatedKey []byte // 32 bytes ephemeral key X
|
||||
Timestamp uint32 // Current time
|
||||
Options [16]byte // Options block
|
||||
Padding []byte // Random padding
|
||||
}
|
||||
|
||||
// SessionRequestBuilder handles creation of NTCP2 Message 1
|
||||
type SessionRequestBuilder interface {
|
||||
// CreateSessionRequest builds Message 1 of handshake
|
||||
CreateSessionRequest() (*SessionRequestMessage, error)
|
||||
|
||||
// ObfuscateEphemeral obfuscates ephemeral key using AES
|
||||
ObfuscateEphemeral(key []byte) ([]byte, error)
|
||||
}
|
Reference in New Issue
Block a user