Add Conn() getter to BaseSession

This commit is contained in:
eyedeekay
2025-06-01 17:56:23 -04:00
parent 12559c0335
commit 27312d3e94
6 changed files with 167 additions and 53 deletions

View File

@ -103,6 +103,10 @@ type BaseSession struct {
SAM SAM
}
func (bs *BaseSession) Conn() net.Conn {
return bs.conn
}
func (bs *BaseSession) ID() string { return bs.id }
func (bs *BaseSession) Keys() i2pkeys.I2PKeys { return bs.keys }
func (bs *BaseSession) Read(b []byte) (int, error) { return bs.conn.Read(b) }

View File

@ -82,6 +82,7 @@ func (r *DatagramReader) safeCloseChannel() {
close(r.recvChan)
close(r.errorChan)
}
func (r *DatagramReader) receiveLoop() {
logger := log.WithField("session_id", r.session.ID())
logger.Debug("Starting receive loop")

View File

@ -7,6 +7,7 @@ import (
"time"
"github.com/go-i2p/i2pkeys"
"github.com/samber/oops"
"github.com/sirupsen/logrus"
)
@ -24,6 +25,19 @@ func (rs *RawSession) DialTimeout(destination string, timeout time.Duration) (ne
// DialContext establishes a raw connection with context support
func (rs *RawSession) DialContext(ctx context.Context, destination string) (net.PacketConn, error) {
// Validate session state first
rs.mu.RLock()
if rs.closed {
rs.mu.RUnlock()
return nil, oops.Errorf("session is closed")
}
rs.mu.RUnlock()
// Validate destination
if destination == "" {
return nil, oops.Errorf("destination cannot be empty")
}
logger := log.WithFields(logrus.Fields{
"destination": destination,
})
@ -36,7 +50,7 @@ func (rs *RawSession) DialContext(ctx context.Context, destination string) (net.
writer: rs.NewWriter(),
}
// Start the reader loop
// Start the reader loop only if session is valid
go conn.reader.receiveLoop()
logger.WithField("session_id", rs.ID()).Debug("Successfully created raw connection")
@ -57,6 +71,14 @@ func (rs *RawSession) DialI2PTimeout(addr i2pkeys.I2PAddr, timeout time.Duration
// DialI2PContext establishes a raw connection to an I2P address with context support
func (rs *RawSession) DialI2PContext(ctx context.Context, addr i2pkeys.I2PAddr) (net.PacketConn, error) {
// Validate session state first
rs.mu.RLock()
if rs.closed {
rs.mu.RUnlock()
return nil, oops.Errorf("session is closed")
}
rs.mu.RUnlock()
logger := log.WithFields(logrus.Fields{
"destination": addr.Base32(),
})
@ -69,7 +91,7 @@ func (rs *RawSession) DialI2PContext(ctx context.Context, addr i2pkeys.I2PAddr)
writer: rs.NewWriter(),
}
// Start the reader loop
// Start the reader loop only if session is valid
go conn.reader.receiveLoop()
logger.WithField("session_id", rs.ID()).Debug("Successfully created I2P raw connection")

View File

@ -10,66 +10,63 @@ import (
"github.com/go-i2p/i2pkeys"
)
func setupTestSession(t *testing.T) *RawSession {
t.Helper()
// Skip actual I2P connection for unit tests
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
sam, err := common.NewSAM(testSAMAddr)
if err != nil {
t.Fatalf("Failed to create SAM connection: %v", err)
}
keys, err := sam.NewKeys()
if err != nil {
sam.Close()
t.Fatalf("Failed to generate keys: %v", err)
}
session, err := NewRawSession(sam, "test_dial_session", keys, nil)
if err != nil {
sam.Close()
t.Fatalf("Failed to create session: %v", err)
}
return session
}
// Update the test to use proper session setup
func TestRawSession_Dial(t *testing.T) {
tests := []struct {
name string
destination string
setupSession func() *RawSession
wantErr bool
errContains string
name string
destination string
wantErr bool
errContains string
}{
{
name: "valid_b32_destination",
destination: "example.b32.i2p",
setupSession: func() *RawSession {
sam := &common.SAM{}
baseSession := &common.BaseSession{}
return &RawSession{
BaseSession: baseSession,
sam: sam,
options: []string{},
closed: false,
}
},
wantErr: false,
wantErr: false,
},
{
name: "empty_destination",
destination: "",
setupSession: func() *RawSession {
sam := &common.SAM{}
baseSession := &common.BaseSession{}
return &RawSession{
BaseSession: baseSession,
sam: sam,
options: []string{},
closed: false,
}
},
wantErr: true,
errContains: "destination",
},
{
name: "dial_on_closed_session",
destination: "example.b32.i2p",
setupSession: func() *RawSession {
sam := &common.SAM{}
baseSession := &common.BaseSession{}
return &RawSession{
BaseSession: baseSession,
sam: sam,
options: []string{},
closed: true,
}
},
wantErr: true,
errContains: "closed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
session := tt.setupSession()
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
session := setupTestSession(t)
defer session.Close()
conn, err := session.Dial(tt.destination)
@ -94,11 +91,6 @@ func TestRawSession_Dial(t *testing.T) {
return
}
// Verify conn implements net.PacketConn
if _, ok := conn.(net.PacketConn); !ok {
t.Error("Dial() returned connection that doesn't implement net.PacketConn")
}
// Clean up
if conn != nil {
_ = conn.Close()

View File

@ -56,7 +56,10 @@ func (r *RawReader) Close() error {
logger.Warn("Timeout waiting for receive loop to stop")
}
// Now safe to close the receiver channels since receiveLoop has stopped
// Fix: Close doneChan here to prevent multiple closes
close(r.doneChan)
// Fix: Close receiver channels here under mutex protection
close(r.recvChan)
close(r.errorChan)
@ -71,11 +74,23 @@ func (r *RawReader) receiveLoop() {
// Signal completion when this loop exits
defer func() {
if r.doneChan != nil {
close(r.doneChan)
select {
case r.doneChan <- struct{}{}:
// Successfully signaled completion
default:
// Channel may be closed or blocked - that's okay
}
}()
// Check session state before starting loop
r.session.mu.RLock()
if r.session.closed || r.session.BaseSession == nil {
r.session.mu.RUnlock()
logger.Debug("Raw receive loop terminated - session invalid")
return
}
r.session.mu.RUnlock()
for {
// Check for closure in a non-blocking way first
select {
@ -114,6 +129,25 @@ func (r *RawReader) receiveLoop() {
func (r *RawReader) receiveDatagram() (*RawDatagram, error) {
logger := log.WithField("session_id", r.session.ID())
// Check if session is valid and not closed
r.session.mu.RLock()
if r.session.closed {
r.session.mu.RUnlock()
return nil, oops.Errorf("session is closed")
}
// Check if BaseSession is properly initialized
if r.session.BaseSession == nil {
r.session.mu.RUnlock()
return nil, oops.Errorf("session is not properly initialized")
}
if r.session.BaseSession.Conn() == nil {
r.session.mu.RUnlock()
return nil, oops.Errorf("session connection is not available")
}
r.session.mu.RUnlock()
// Read from the session connection for incoming raw datagrams
buf := make([]byte, 4096)
n, err := r.session.Read(buf)

61
raw/read_test.go Normal file
View File

@ -0,0 +1,61 @@
package raw
import (
"sync"
"testing"
"github.com/go-i2p/go-sam-go/common"
)
func TestRawReader_ConcurrentClose(t *testing.T) {
// Test concurrent Close() calls don't panic
session := &RawSession{
BaseSession: &common.BaseSession{},
closed: false,
}
reader := &RawReader{
session: session,
recvChan: make(chan *RawDatagram, 10),
errorChan: make(chan error, 1),
closeChan: make(chan struct{}),
doneChan: make(chan struct{}),
closed: false,
mu: sync.RWMutex{},
}
// Start receive loop
go reader.receiveLoop()
// Simulate concurrent close attempts
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = reader.Close() // Should not panic
}()
}
wg.Wait()
// Verify reader is properly closed
if !reader.closed {
t.Error("Reader should be marked as closed")
}
}
func TestRawReader_CloseRaceCondition(t *testing.T) {
// Test that rapid close after start doesn't cause channel panic
for i := 0; i < 100; i++ {
session := &RawSession{closed: false}
reader := session.NewReader()
go reader.receiveLoop()
// Close immediately to trigger race condition
if err := reader.Close(); err != nil {
t.Errorf("Close() failed: %v", err)
}
}
}