diff --git a/.gitignore b/.gitignore index 6f72f89..0c120b6 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,5 @@ go.work.sum # env file .env +./*keys +./data-dir* diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..1f8a68b --- /dev/null +++ b/handler.go @@ -0,0 +1,68 @@ +package meta + +import ( + "net" + "time" +) + +// handleListener runs in a separate goroutine for each added listener +// and forwards accepted connections to the connCh channel. +func (ml *MetaListener) handleListener(id string, listener net.Listener) { + defer func() { + log.Printf("Listener goroutine for %s exiting", id) + ml.listenerWg.Done() + }() + + for { + // First check if the MetaListener is closed + select { + case <-ml.closeCh: + log.Printf("MetaListener closed, stopping %s listener", id) + return + default: + } + + // Set a deadline for Accept to prevent blocking indefinitely + if deadline, ok := listener.(interface{ SetDeadline(time.Time) error }); ok { + deadline.SetDeadline(time.Now().Add(1 * time.Second)) + } + + conn, err := listener.Accept() + if err != nil { + // Check if this is a timeout error (which we expect due to our deadline) + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue + } + + // Check if this is any other temporary error + if netErr, ok := err.(net.Error); ok && netErr.Temporary() { + log.Printf("Temporary error in %s listener: %v, retrying in 100ms", id, err) + time.Sleep(100 * time.Millisecond) + continue + } + + log.Printf("Permanent error in %s listener: %v, stopping", id, err) + ml.mu.Lock() + delete(ml.listeners, id) + ml.mu.Unlock() + return + } + + // If we reach here, we have a valid connection + log.Printf("Listener %s accepted connection from %s", id, conn.RemoteAddr()) + + // Try to forward the connection, but don't block indefinitely + select { + case ml.connCh <- ConnResult{Conn: conn, src: id}: + log.Printf("Connection from %s successfully forwarded via %s", conn.RemoteAddr(), id) + case <-ml.closeCh: + log.Printf("MetaListener closing while forwarding connection, closing connection") + conn.Close() + return + case <-time.After(5 * time.Second): + // If we can't forward within 5 seconds, something is seriously wrong + log.Printf("WARNING: Connection forwarding timed out, closing connection from %s", conn.RemoteAddr()) + conn.Close() + } + } +} diff --git a/listener.go b/listener.go new file mode 100644 index 0000000..ca06a63 --- /dev/null +++ b/listener.go @@ -0,0 +1,89 @@ +package meta + +import ( + "fmt" + "net" +) + +// Accept implements the net.Listener Accept method. +// It returns the next connection from any of the managed listeners. +func (ml *MetaListener) Accept() (net.Conn, error) { + // Check if already closed before entering the select loop + ml.mu.RLock() + if ml.isClosed { + ml.mu.RUnlock() + return nil, ErrListenerClosed + } + ml.mu.RUnlock() + + for { + select { + case result, ok := <-ml.connCh: + if !ok { + return nil, ErrListenerClosed + } + // Access RemoteAddr() directly on the connection + return result, nil + case <-ml.closeCh: + // Double-check the closed state under lock to ensure consistency + closed := ml.isClosed + if closed { + return nil, ErrListenerClosed + } + continue + } + } +} + +// Close implements the net.Listener Close method. +// It closes all managed listeners and releases resources. +func (ml *MetaListener) Close() error { + ml.mu.Lock() + + if ml.isClosed { + ml.mu.Unlock() + return nil + } + + log.Printf("Closing MetaListener with %d listeners", len(ml.listeners)) + ml.isClosed = true + + // Signal all goroutines to stop + close(ml.closeCh) + + // Close all listeners + var errs []error + for id, listener := range ml.listeners { + if err := listener.Close(); err != nil { + log.Printf("Error closing %s listener: %v", id, err) + errs = append(errs, err) + } + } + + ml.mu.Unlock() + + // Wait for all listener goroutines to exit + ml.listenerWg.Wait() + log.Printf("All listener goroutines have exited") + + // Return combined errors if any + if len(errs) > 0 { + return fmt.Errorf("errors closing listeners: %v", errs) + } + + return nil +} + +// Addr implements the net.Listener Addr method. +// It returns a MetaAddr representing all managed listeners. +func (ml *MetaListener) Addr() net.Addr { + ml.mu.RLock() + defer ml.mu.RUnlock() + + addresses := make([]net.Addr, 0, len(ml.listeners)) + for _, listener := range ml.listeners { + addresses = append(addresses, listener.Addr()) + } + + return &MetaAddr{addresses: addresses} +} diff --git a/metaaddr.go b/metaaddr.go new file mode 100644 index 0000000..7536464 --- /dev/null +++ b/metaaddr.go @@ -0,0 +1,31 @@ +package meta + +import "net" + +// MetaAddr implements the net.Addr interface for a meta listener. +type MetaAddr struct { + addresses []net.Addr +} + +// Network returns the name of the network. +func (ma *MetaAddr) Network() string { + return "meta" +} + +// String returns a string representation of all managed addresses. +func (ma *MetaAddr) String() string { + if len(ma.addresses) == 0 { + return "meta(empty)" + } + + result := "meta(" + for i, addr := range ma.addresses { + if i > 0 { + result += ", " + } + result += addr.String() + } + result += ")" + + return result +} diff --git a/metalistener.go b/metalistener.go index 477c358..f81753b 100644 --- a/metalistener.go +++ b/metalistener.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "sync" - "time" "github.com/samber/oops" ) @@ -96,152 +95,6 @@ func (ml *MetaListener) RemoveListener(id string) error { return err } -// handleListener runs in a separate goroutine for each added listener -// and forwards accepted connections to the connCh channel. -func (ml *MetaListener) handleListener(id string, listener net.Listener) { - defer func() { - log.Printf("Listener goroutine for %s exiting", id) - ml.listenerWg.Done() - }() - - for { - // First check if the MetaListener is closed - select { - case <-ml.closeCh: - log.Printf("MetaListener closed, stopping %s listener", id) - return - default: - } - - // Set a deadline for Accept to prevent blocking indefinitely - if deadline, ok := listener.(interface{ SetDeadline(time.Time) error }); ok { - deadline.SetDeadline(time.Now().Add(1 * time.Second)) - } - - conn, err := listener.Accept() - if err != nil { - // Check if this is a timeout error (which we expect due to our deadline) - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - continue - } - - // Check if this is any other temporary error - if netErr, ok := err.(net.Error); ok && netErr.Temporary() { - log.Printf("Temporary error in %s listener: %v, retrying in 100ms", id, err) - time.Sleep(100 * time.Millisecond) - continue - } - - log.Printf("Permanent error in %s listener: %v, stopping", id, err) - ml.mu.Lock() - delete(ml.listeners, id) - ml.mu.Unlock() - return - } - - // If we reach here, we have a valid connection - log.Printf("Listener %s accepted connection from %s", id, conn.RemoteAddr()) - - // Try to forward the connection, but don't block indefinitely - select { - case ml.connCh <- ConnResult{Conn: conn, src: id}: - log.Printf("Connection from %s successfully forwarded via %s", conn.RemoteAddr(), id) - case <-ml.closeCh: - log.Printf("MetaListener closing while forwarding connection, closing connection") - conn.Close() - return - case <-time.After(5 * time.Second): - // If we can't forward within 5 seconds, something is seriously wrong - log.Printf("WARNING: Connection forwarding timed out, closing connection from %s", conn.RemoteAddr()) - conn.Close() - } - } -} - -// Accept implements the net.Listener Accept method. -// It waits for and returns the next connection from any of the managed listeners. -func (ml *MetaListener) Accept() (net.Conn, error) { - for { - ml.mu.RLock() - if ml.isClosed { - ml.mu.RUnlock() - return nil, ErrListenerClosed - } - - if len(ml.listeners) == 0 { - ml.mu.RUnlock() - return nil, ErrNoListeners - } - ml.mu.RUnlock() - - // Wait for either a connection or close signal - select { - case result, ok := <-ml.connCh: - if !ok { - return nil, ErrListenerClosed - } - log.Printf("Accept returning connection from %s via %s", - result.RemoteAddr(), result.src) - return result.Conn, nil - case <-ml.closeCh: - return nil, ErrListenerClosed - } - } -} - -// Close implements the net.Listener Close method. -// It closes all managed listeners and releases resources. -func (ml *MetaListener) Close() error { - ml.mu.Lock() - - if ml.isClosed { - ml.mu.Unlock() - return nil - } - - log.Printf("Closing MetaListener with %d listeners", len(ml.listeners)) - ml.isClosed = true - - // Signal all goroutines to stop - close(ml.closeCh) - - // Close all listeners - var errs []error - for id, listener := range ml.listeners { - if err := listener.Close(); err != nil { - log.Printf("Error closing %s listener: %v", id, err) - errs = append(errs, err) - } - } - - ml.mu.Unlock() - - // Wait for all listener goroutines to exit - ml.listenerWg.Wait() - log.Printf("All listener goroutines have exited") - - // Return combined errors if any - if len(errs) > 0 { - return fmt.Errorf("errors closing listeners: %v", errs) - } - - return nil -} - -// Addr implements the net.Listener Addr method. -// It returns a MetaAddr representing all managed listeners. -func (ml *MetaListener) Addr() net.Addr { - ml.mu.RLock() - defer ml.mu.RUnlock() - - addresses := make([]net.Addr, 0, len(ml.listeners)) - for _, listener := range ml.listeners { - addresses = append(addresses, listener.Addr()) - } - - return &MetaAddr{addresses: addresses} -} - // ListenerIDs returns the IDs of all active listeners. func (ml *MetaListener) ListenerIDs() []string { ml.mu.RLock() @@ -280,31 +133,3 @@ func (ml *MetaListener) WaitForShutdown(ctx context.Context) error { return ctx.Err() } } - -// MetaAddr implements the net.Addr interface for a meta listener. -type MetaAddr struct { - addresses []net.Addr -} - -// Network returns the name of the network. -func (ma *MetaAddr) Network() string { - return "meta" -} - -// String returns a string representation of all managed addresses. -func (ma *MetaAddr) String() string { - if len(ma.addresses) == 0 { - return "meta(empty)" - } - - result := "meta(" - for i, addr := range ma.addresses { - if i > 0 { - result += ", " - } - result += addr.String() - } - result += ")" - - return result -}