- Properly invode proxy.OnReqest

They apparently do not stack and each consequitive
call overwrites the one that was called before
- Block localhost the right way by utilizing
`goproxy.IsLocalHost` method, thather than regexp
- Block private networks as well
- Add flags deny-localhost and deny-private-networks
- fmt
This commit is contained in:
ugrentquest
2025-02-27 09:37:52 -08:00
parent cb6afc337b
commit bfdb819a11

View File

@ -4,18 +4,18 @@ import (
"context"
"flag"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"regexp"
"github.com/elazarl/goproxy"
"github.com/go-i2p/onramp"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"golang.org/x/time/rate"
"github.com/prometheus/client_golang/prometheus"
)
type Config struct {
@ -25,6 +25,8 @@ type Config struct {
BurstLimit int
LogLevel string
MetricsAddr string
DenyPrivateNetworks bool
DenyLocalhost bool
}
type Proxy struct {
@ -55,6 +57,7 @@ func NewProxy(config *Config) (*Proxy, error) {
// Initialize goproxy
proxy := goproxy.NewProxyHttpServer()
proxy.Verbose = logger.Level == logrus.DebugLevel
proxy.Logger = logger
// Initialize I2P connection
garlic, err := onramp.NewGarlic(config.TunnelName, config.SAMAddress, onramp.OPT_DEFAULTS)
@ -82,7 +85,20 @@ func NewProxy(config *Config) (*Proxy, error) {
}
func (p *Proxy) setupMiddleware() {
// Add request logging
var isPrivate goproxy.ReqConditionFunc = func(req *http.Request, _ *goproxy.ProxyCtx) bool {
h := req.URL.Hostname()
if ip := net.ParseIP(h); ip != nil {
return ip.IsPrivate()
}
// In case of IPv6 without a port number Hostname() sometimes returns the invalid value.
if ip := net.ParseIP(req.URL.Host); ip != nil {
return ip.IsPrivate()
}
return false
}
p.proxy.OnRequest().DoFunc(func(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
start := time.Now()
@ -109,16 +125,23 @@ func (p *Proxy) setupMiddleware() {
"remote": req.RemoteAddr,
}).Info("incoming request")
return req, nil
})
// Block common sensitive endpoints
p.proxy.OnRequest(goproxy.ReqHostMatches(regexp.MustCompile(`^(localhost|127\.0\.0\.1)`))).DoFunc(
func(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
// deny localhost
if p.config.DenyLocalhost && goproxy.IsLocalHost(req, ctx) {
return nil, goproxy.NewResponse(req,
goproxy.ContentTypeText, http.StatusForbidden,
"Access to local addresses is forbidden")
}
//deny private destinations
if p.config.DenyPrivateNetworks && isPrivate(req, ctx) {
return nil, goproxy.NewResponse(req,
goproxy.ContentTypeText, http.StatusForbidden,
"Access to private (RFC 1918) networks is forbidden")
}
return req, nil
})
}
func (p *Proxy) Run(ctx context.Context) error {
@ -162,6 +185,8 @@ func main() {
flag.IntVar(&config.BurstLimit, "burst", 20, "Maximum burst size")
flag.StringVar(&config.LogLevel, "log-level", "info", "Log level (debug, info, warn, error)")
flag.StringVar(&config.MetricsAddr, "metrics", ":2112", "Metrics server address")
flag.BoolVar(&config.DenyLocalhost, "deny-localhost", true, "Deny requests to localhost")
flag.BoolVar(&config.DenyPrivateNetworks, "deny-private-networks", true, "Deny requests to private (RFC 1918) networks")
flag.Parse()
logger := logrus.New()