mirror of
https://github.com/go-i2p/go-sam-go.git
synced 2025-06-07 09:03:18 -04:00
Add Conn() getter to BaseSession
This commit is contained in:
@ -103,6 +103,10 @@ type BaseSession struct {
|
||||
SAM SAM
|
||||
}
|
||||
|
||||
func (bs *BaseSession) Conn() net.Conn {
|
||||
return bs.conn
|
||||
}
|
||||
|
||||
func (bs *BaseSession) ID() string { return bs.id }
|
||||
func (bs *BaseSession) Keys() i2pkeys.I2PKeys { return bs.keys }
|
||||
func (bs *BaseSession) Read(b []byte) (int, error) { return bs.conn.Read(b) }
|
||||
|
@ -82,6 +82,7 @@ func (r *DatagramReader) safeCloseChannel() {
|
||||
close(r.recvChan)
|
||||
close(r.errorChan)
|
||||
}
|
||||
|
||||
func (r *DatagramReader) receiveLoop() {
|
||||
logger := log.WithField("session_id", r.session.ID())
|
||||
logger.Debug("Starting receive loop")
|
||||
|
26
raw/dial.go
26
raw/dial.go
@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-i2p/i2pkeys"
|
||||
"github.com/samber/oops"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@ -24,6 +25,19 @@ func (rs *RawSession) DialTimeout(destination string, timeout time.Duration) (ne
|
||||
|
||||
// DialContext establishes a raw connection with context support
|
||||
func (rs *RawSession) DialContext(ctx context.Context, destination string) (net.PacketConn, error) {
|
||||
// Validate session state first
|
||||
rs.mu.RLock()
|
||||
if rs.closed {
|
||||
rs.mu.RUnlock()
|
||||
return nil, oops.Errorf("session is closed")
|
||||
}
|
||||
rs.mu.RUnlock()
|
||||
|
||||
// Validate destination
|
||||
if destination == "" {
|
||||
return nil, oops.Errorf("destination cannot be empty")
|
||||
}
|
||||
|
||||
logger := log.WithFields(logrus.Fields{
|
||||
"destination": destination,
|
||||
})
|
||||
@ -36,7 +50,7 @@ func (rs *RawSession) DialContext(ctx context.Context, destination string) (net.
|
||||
writer: rs.NewWriter(),
|
||||
}
|
||||
|
||||
// Start the reader loop
|
||||
// Start the reader loop only if session is valid
|
||||
go conn.reader.receiveLoop()
|
||||
|
||||
logger.WithField("session_id", rs.ID()).Debug("Successfully created raw connection")
|
||||
@ -57,6 +71,14 @@ func (rs *RawSession) DialI2PTimeout(addr i2pkeys.I2PAddr, timeout time.Duration
|
||||
|
||||
// DialI2PContext establishes a raw connection to an I2P address with context support
|
||||
func (rs *RawSession) DialI2PContext(ctx context.Context, addr i2pkeys.I2PAddr) (net.PacketConn, error) {
|
||||
// Validate session state first
|
||||
rs.mu.RLock()
|
||||
if rs.closed {
|
||||
rs.mu.RUnlock()
|
||||
return nil, oops.Errorf("session is closed")
|
||||
}
|
||||
rs.mu.RUnlock()
|
||||
|
||||
logger := log.WithFields(logrus.Fields{
|
||||
"destination": addr.Base32(),
|
||||
})
|
||||
@ -69,7 +91,7 @@ func (rs *RawSession) DialI2PContext(ctx context.Context, addr i2pkeys.I2PAddr)
|
||||
writer: rs.NewWriter(),
|
||||
}
|
||||
|
||||
// Start the reader loop
|
||||
// Start the reader loop only if session is valid
|
||||
go conn.reader.receiveLoop()
|
||||
|
||||
logger.WithField("session_id", rs.ID()).Debug("Successfully created I2P raw connection")
|
||||
|
@ -10,66 +10,63 @@ import (
|
||||
"github.com/go-i2p/i2pkeys"
|
||||
)
|
||||
|
||||
func setupTestSession(t *testing.T) *RawSession {
|
||||
t.Helper()
|
||||
|
||||
// Skip actual I2P connection for unit tests
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
sam, err := common.NewSAM(testSAMAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SAM connection: %v", err)
|
||||
}
|
||||
|
||||
keys, err := sam.NewKeys()
|
||||
if err != nil {
|
||||
sam.Close()
|
||||
t.Fatalf("Failed to generate keys: %v", err)
|
||||
}
|
||||
|
||||
session, err := NewRawSession(sam, "test_dial_session", keys, nil)
|
||||
if err != nil {
|
||||
sam.Close()
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
// Update the test to use proper session setup
|
||||
func TestRawSession_Dial(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
destination string
|
||||
setupSession func() *RawSession
|
||||
wantErr bool
|
||||
errContains string
|
||||
name string
|
||||
destination string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid_b32_destination",
|
||||
destination: "example.b32.i2p",
|
||||
setupSession: func() *RawSession {
|
||||
sam := &common.SAM{}
|
||||
baseSession := &common.BaseSession{}
|
||||
return &RawSession{
|
||||
BaseSession: baseSession,
|
||||
sam: sam,
|
||||
options: []string{},
|
||||
closed: false,
|
||||
}
|
||||
},
|
||||
wantErr: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty_destination",
|
||||
destination: "",
|
||||
setupSession: func() *RawSession {
|
||||
sam := &common.SAM{}
|
||||
baseSession := &common.BaseSession{}
|
||||
return &RawSession{
|
||||
BaseSession: baseSession,
|
||||
sam: sam,
|
||||
options: []string{},
|
||||
closed: false,
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "destination",
|
||||
},
|
||||
{
|
||||
name: "dial_on_closed_session",
|
||||
destination: "example.b32.i2p",
|
||||
setupSession: func() *RawSession {
|
||||
sam := &common.SAM{}
|
||||
baseSession := &common.BaseSession{}
|
||||
return &RawSession{
|
||||
BaseSession: baseSession,
|
||||
sam: sam,
|
||||
options: []string{},
|
||||
closed: true,
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "closed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
session := tt.setupSession()
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
session := setupTestSession(t)
|
||||
defer session.Close()
|
||||
|
||||
conn, err := session.Dial(tt.destination)
|
||||
|
||||
@ -94,11 +91,6 @@ func TestRawSession_Dial(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify conn implements net.PacketConn
|
||||
if _, ok := conn.(net.PacketConn); !ok {
|
||||
t.Error("Dial() returned connection that doesn't implement net.PacketConn")
|
||||
}
|
||||
|
||||
// Clean up
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
|
40
raw/read.go
40
raw/read.go
@ -56,7 +56,10 @@ func (r *RawReader) Close() error {
|
||||
logger.Warn("Timeout waiting for receive loop to stop")
|
||||
}
|
||||
|
||||
// Now safe to close the receiver channels since receiveLoop has stopped
|
||||
// Fix: Close doneChan here to prevent multiple closes
|
||||
close(r.doneChan)
|
||||
|
||||
// Fix: Close receiver channels here under mutex protection
|
||||
close(r.recvChan)
|
||||
close(r.errorChan)
|
||||
|
||||
@ -71,11 +74,23 @@ func (r *RawReader) receiveLoop() {
|
||||
|
||||
// Signal completion when this loop exits
|
||||
defer func() {
|
||||
if r.doneChan != nil {
|
||||
close(r.doneChan)
|
||||
select {
|
||||
case r.doneChan <- struct{}{}:
|
||||
// Successfully signaled completion
|
||||
default:
|
||||
// Channel may be closed or blocked - that's okay
|
||||
}
|
||||
}()
|
||||
|
||||
// Check session state before starting loop
|
||||
r.session.mu.RLock()
|
||||
if r.session.closed || r.session.BaseSession == nil {
|
||||
r.session.mu.RUnlock()
|
||||
logger.Debug("Raw receive loop terminated - session invalid")
|
||||
return
|
||||
}
|
||||
r.session.mu.RUnlock()
|
||||
|
||||
for {
|
||||
// Check for closure in a non-blocking way first
|
||||
select {
|
||||
@ -114,6 +129,25 @@ func (r *RawReader) receiveLoop() {
|
||||
func (r *RawReader) receiveDatagram() (*RawDatagram, error) {
|
||||
logger := log.WithField("session_id", r.session.ID())
|
||||
|
||||
// Check if session is valid and not closed
|
||||
r.session.mu.RLock()
|
||||
if r.session.closed {
|
||||
r.session.mu.RUnlock()
|
||||
return nil, oops.Errorf("session is closed")
|
||||
}
|
||||
|
||||
// Check if BaseSession is properly initialized
|
||||
if r.session.BaseSession == nil {
|
||||
r.session.mu.RUnlock()
|
||||
return nil, oops.Errorf("session is not properly initialized")
|
||||
}
|
||||
|
||||
if r.session.BaseSession.Conn() == nil {
|
||||
r.session.mu.RUnlock()
|
||||
return nil, oops.Errorf("session connection is not available")
|
||||
}
|
||||
r.session.mu.RUnlock()
|
||||
|
||||
// Read from the session connection for incoming raw datagrams
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.session.Read(buf)
|
||||
|
61
raw/read_test.go
Normal file
61
raw/read_test.go
Normal file
@ -0,0 +1,61 @@
|
||||
package raw
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/go-i2p/go-sam-go/common"
|
||||
)
|
||||
|
||||
func TestRawReader_ConcurrentClose(t *testing.T) {
|
||||
// Test concurrent Close() calls don't panic
|
||||
session := &RawSession{
|
||||
BaseSession: &common.BaseSession{},
|
||||
closed: false,
|
||||
}
|
||||
|
||||
reader := &RawReader{
|
||||
session: session,
|
||||
recvChan: make(chan *RawDatagram, 10),
|
||||
errorChan: make(chan error, 1),
|
||||
closeChan: make(chan struct{}),
|
||||
doneChan: make(chan struct{}),
|
||||
closed: false,
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
|
||||
// Start receive loop
|
||||
go reader.receiveLoop()
|
||||
|
||||
// Simulate concurrent close attempts
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = reader.Close() // Should not panic
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify reader is properly closed
|
||||
if !reader.closed {
|
||||
t.Error("Reader should be marked as closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawReader_CloseRaceCondition(t *testing.T) {
|
||||
// Test that rapid close after start doesn't cause channel panic
|
||||
for i := 0; i < 100; i++ {
|
||||
session := &RawSession{closed: false}
|
||||
reader := session.NewReader()
|
||||
|
||||
go reader.receiveLoop()
|
||||
|
||||
// Close immediately to trigger race condition
|
||||
if err := reader.Close(); err != nil {
|
||||
t.Errorf("Close() failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user