diff --git a/go.mod b/go.mod index a9ed4d2..6d69e42 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23 require ( github.com/minio/selfupdate v0.6.0 github.com/spf13/cobra v1.10.2 + github.com/stretchr/testify v1.11.1 github.com/wailsapp/wails/v2 v2.12.0 golang.org/x/mod v0.23.0 golang.org/x/sys v0.30.0 @@ -14,6 +15,7 @@ require ( aead.dev/minisign v0.2.0 // indirect git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3 // indirect github.com/bep/debounce v1.2.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/google/uuid v1.6.0 // indirect @@ -30,6 +32,7 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/samber/lo v1.49.1 // indirect github.com/spf13/pflag v1.0.9 // indirect @@ -41,4 +44,5 @@ require ( golang.org/x/crypto v0.33.0 // indirect golang.org/x/net v0.35.0 // indirect golang.org/x/text v0.22.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index a39c51c..cb59edd 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e h1:Q3+PugElBCf4PFpxhErSzU3/PY5sFL5Z6rfv4AbGAck= github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e/go.mod h1:alcuEEnZsY1WQsagKhZDsoPCRoOijYqhZvPwLG0kzVs= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaafY= github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= @@ -43,6 +45,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU= github.com/minio/selfupdate v0.6.0/go.mod h1:bO02GTIPCMQFTEvE5h4DjYB58bCoZ35XLeBf0buTDdM= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -59,8 +63,8 @@ github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tkrajina/go-reflector v0.5.8 h1:yPADHrwmUbMq4RGEyaOUpz2H90sRsETNVpjzo3DLVQQ= github.com/tkrajina/go-reflector v0.5.8/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= @@ -106,5 +110,7 @@ golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/checker/socks5.go b/internal/checker/socks5.go new file mode 100644 index 0000000..3319548 --- /dev/null +++ b/internal/checker/socks5.go @@ -0,0 +1,219 @@ +package checker + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "time" +) + +// Sentinel errors returned by the SOCKS5 primitives. +var ( + ErrSocks5BadVersion = errors.New("socks5: server returned wrong version") + ErrSocks5RejectedAllAuth = errors.New("socks5: server rejected all offered auth methods (0xFF)") + ErrAuthRejected = errors.New("socks5: user/pass authentication rejected") + ErrCredentialTooLong = errors.New("socks5: login or password longer than 255 bytes") + ErrHostTooLong = errors.New("socks5: target hostname longer than 255 bytes") + ErrUnsupportedRelayATYP = errors.New("socks5: udp associate replied with non-IPv4 ATYP") + ErrShortReply = errors.New("socks5: short server reply") +) + +// ErrSocks5Reply wraps a non-zero REP code so callers can react to specific +// SOCKS5 reply codes (e.g. REP=0x07 = command not supported, REP=0x05 = +// connection refused). +type ErrSocks5Reply struct{ Code byte } + +// Error implements the error interface. +func (e ErrSocks5Reply) Error() string { + return fmt.Sprintf("socks5: server replied with non-zero REP code 0x%02X", e.Code) +} + +// Is reports whether target matches this reply error by Code. +func (e ErrSocks5Reply) Is(target error) bool { + t, ok := target.(ErrSocks5Reply) + if !ok { + return false + } + return t.Code == e.Code +} + +// applyDeadline applies the deadline from ctx (if any) to conn. Returns a +// function to clear the deadline. +func applyDeadline(ctx context.Context, conn net.Conn) { + if dl, ok := ctx.Deadline(); ok { + _ = conn.SetDeadline(dl) + } else { + _ = conn.SetDeadline(time.Time{}) + } +} + +// joinCtxErr wraps err with ctx.Err() if ctx has been cancelled or expired, +// so that callers see context.Canceled / context.DeadlineExceeded in the +// error chain even when the underlying I/O reported a deadline-based error. +func joinCtxErr(ctx context.Context, err error) error { + if err == nil { + return nil + } + if cerr := ctx.Err(); cerr != nil { + return errors.Join(err, cerr) + } + return err +} + +// socks5Greeting performs the RFC 1928 client greeting on conn. +// useAuth=true sends "05 02 00 02" (offer no-auth and user/pass); +// useAuth=false sends "05 01 00" (offer no-auth only). +func socks5Greeting(ctx context.Context, conn net.Conn, useAuth bool) (method byte, rawReply []byte, err error) { + applyDeadline(ctx, conn) + + var greet []byte + if useAuth { + greet = []byte{0x05, 0x02, 0x00, 0x02} + } else { + greet = []byte{0x05, 0x01, 0x00} + } + + if _, werr := conn.Write(greet); werr != nil { + return 0, nil, joinCtxErr(ctx, fmt.Errorf("socks5 greeting: write: %w", werr)) + } + + reply := make([]byte, 2) + n, rerr := io.ReadFull(conn, reply) + if rerr != nil { + partial := reply[:n] + if errors.Is(rerr, io.ErrUnexpectedEOF) || errors.Is(rerr, io.EOF) { + return 0, partial, joinCtxErr(ctx, fmt.Errorf("socks5 greeting: %w (raw=%x)", ErrShortReply, partial)) + } + return 0, partial, joinCtxErr(ctx, fmt.Errorf("socks5 greeting: read: %w (raw=%x)", rerr, partial)) + } + + if reply[0] != 0x05 { + return 0, reply, fmt.Errorf("socks5 greeting: %w (raw=%x)", ErrSocks5BadVersion, reply) + } + if reply[1] == 0xFF { + return reply[1], reply, fmt.Errorf("socks5 greeting: %w (raw=%x)", ErrSocks5RejectedAllAuth, reply) + } + return reply[1], reply, nil +} + +// socks5Auth performs RFC 1929 user/pass sub-negotiation on conn, +// after greeting selected method 0x02. +func socks5Auth(ctx context.Context, conn net.Conn, login, password string) (rawReply []byte, err error) { + if len(login) > 255 || len(password) > 255 { + return nil, ErrCredentialTooLong + } + + applyDeadline(ctx, conn) + + buf := make([]byte, 0, 3+len(login)+len(password)) + buf = append(buf, 0x01) // VER + buf = append(buf, byte(len(login))) // ULEN + buf = append(buf, []byte(login)...) // UNAME + buf = append(buf, byte(len(password))) + buf = append(buf, []byte(password)...) + + if _, werr := conn.Write(buf); werr != nil { + return nil, joinCtxErr(ctx, fmt.Errorf("socks5 auth: write: %w", werr)) + } + + reply := make([]byte, 2) + n, rerr := io.ReadFull(conn, reply) + if rerr != nil { + partial := reply[:n] + if errors.Is(rerr, io.ErrUnexpectedEOF) || errors.Is(rerr, io.EOF) { + return partial, joinCtxErr(ctx, fmt.Errorf("socks5 auth: %w (raw=%x)", ErrShortReply, partial)) + } + return partial, joinCtxErr(ctx, fmt.Errorf("socks5 auth: read: %w (raw=%x)", rerr, partial)) + } + + if reply[0] != 0x01 { + return reply, fmt.Errorf("socks5 auth: auth subneg version mismatch: got 0x%02X want 0x01 (raw=%x)", reply[0], reply) + } + if reply[1] != 0x00 { + return reply, fmt.Errorf("socks5 auth: %w (raw=%x)", ErrAuthRejected, reply) + } + return reply, nil +} + +// socks5Connect performs SOCKS5 CONNECT (CMD=01) to host:port using +// ATYP=03 (domain name). +func socks5Connect(ctx context.Context, conn net.Conn, host string, port uint16) (rawReply []byte, err error) { + if len(host) > 255 { + return nil, ErrHostTooLong + } + + applyDeadline(ctx, conn) + + // VER=05 CMD=01 RSV=00 ATYP=03 LEN host port + req := make([]byte, 0, 7+len(host)) + req = append(req, 0x05, 0x01, 0x00, 0x03) + req = append(req, byte(len(host))) + req = append(req, []byte(host)...) + var portBuf [2]byte + binary.BigEndian.PutUint16(portBuf[:], port) + req = append(req, portBuf[:]...) + + if _, werr := conn.Write(req); werr != nil { + return nil, joinCtxErr(ctx, fmt.Errorf("socks5 connect: write: %w", werr)) + } + + // We always read 10 bytes (assuming ATYP=01 IPv4 reply, the most + // common case from real proxies). Parsing variable-length BND is + // out of scope for the diagnostic. + reply := make([]byte, 10) + n, rerr := io.ReadFull(conn, reply) + if rerr != nil { + partial := reply[:n] + if errors.Is(rerr, io.ErrUnexpectedEOF) || errors.Is(rerr, io.EOF) { + return partial, joinCtxErr(ctx, fmt.Errorf("socks5 connect: %w (raw=%x)", ErrShortReply, partial)) + } + return partial, joinCtxErr(ctx, fmt.Errorf("socks5 connect: read: %w (raw=%x)", rerr, partial)) + } + + if reply[0] != 0x05 { + return reply, fmt.Errorf("socks5 connect: %w (raw=%x)", ErrSocks5BadVersion, reply) + } + if reply[1] != 0x00 { + return reply, fmt.Errorf("socks5 connect: %w (raw=%x)", ErrSocks5Reply{Code: reply[1]}, reply) + } + return reply, nil +} + +// socks5UDPAssociate performs SOCKS5 UDP ASSOCIATE (CMD=03) on conn. +func socks5UDPAssociate(ctx context.Context, conn net.Conn) (relay *net.UDPAddr, rawReply []byte, err error) { + applyDeadline(ctx, conn) + + // VER=05 CMD=03 RSV=00 ATYP=01 DST.ADDR=0.0.0.0 DST.PORT=0 + req := []byte{0x05, 0x03, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + if _, werr := conn.Write(req); werr != nil { + return nil, nil, joinCtxErr(ctx, fmt.Errorf("socks5 udp-associate: write: %w", werr)) + } + + reply := make([]byte, 10) + n, rerr := io.ReadFull(conn, reply) + if rerr != nil { + partial := reply[:n] + if errors.Is(rerr, io.ErrUnexpectedEOF) || errors.Is(rerr, io.EOF) { + return nil, partial, joinCtxErr(ctx, fmt.Errorf("socks5 udp-associate: %w (raw=%x)", ErrShortReply, partial)) + } + return nil, partial, joinCtxErr(ctx, fmt.Errorf("socks5 udp-associate: read: %w (raw=%x)", rerr, partial)) + } + + if reply[0] != 0x05 { + return nil, reply, fmt.Errorf("socks5 udp-associate: %w (raw=%x)", ErrSocks5BadVersion, reply) + } + if reply[1] != 0x00 { + return nil, reply, fmt.Errorf("socks5 udp-associate: %w (raw=%x)", ErrSocks5Reply{Code: reply[1]}, reply) + } + if reply[3] != 0x01 { + return nil, reply, fmt.Errorf("socks5 udp-associate: %w (atyp=0x%02X raw=%x)", ErrUnsupportedRelayATYP, reply[3], reply) + } + + ip := net.IPv4(reply[4], reply[5], reply[6], reply[7]) + port := binary.BigEndian.Uint16(reply[8:10]) + relay = &net.UDPAddr{IP: ip, Port: int(port)} + return relay, reply, nil +} diff --git a/internal/checker/socks5_test.go b/internal/checker/socks5_test.go new file mode 100644 index 0000000..375cb4c --- /dev/null +++ b/internal/checker/socks5_test.go @@ -0,0 +1,357 @@ +package checker + +import ( + "context" + "errors" + "io" + "net" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newFakeSocks5Server starts a TCP listener on 127.0.0.1:0. On the first +// accepted connection it reads up to 1024 bytes (enough for any of our +// primitives' fixed-length frames in a single Write), then writes +// scriptedReply, then closes the connection. The listener is closed by +// t.Cleanup. +func newFakeSocks5Server(t *testing.T, scriptedReply []byte) (addr string) { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "listen") + + done := make(chan struct{}) + t.Cleanup(func() { + _ = ln.Close() + <-done + }) + + go func() { + defer close(done) + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + buf := make([]byte, 1024) + _, _ = conn.Read(buf) + if len(scriptedReply) > 0 { + _, _ = conn.Write(scriptedReply) + } + }() + + return ln.Addr().String() +} + +// dial connects to addr and registers t.Cleanup to close the conn. +func dial(t *testing.T, addr string) net.Conn { + t.Helper() + conn, err := net.DialTimeout("tcp", addr, 1*time.Second) + require.NoError(t, err, "dial") + t.Cleanup(func() { _ = conn.Close() }) + return conn +} + +func ctxShort(t *testing.T) context.Context { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + t.Cleanup(cancel) + return ctx +} + +func TestSocks5Greeting(t *testing.T) { + t.Run("happy_no_auth", func(t *testing.T) { + addr := newFakeSocks5Server(t, []byte{0x05, 0x00}) + conn := dial(t, addr) + + method, raw, err := socks5Greeting(ctxShort(t), conn, false) + require.NoError(t, err) + assert.Equal(t, byte(0x00), method) + assert.Equal(t, []byte{0x05, 0x00}, raw) + }) + + t.Run("happy_userpass_selected", func(t *testing.T) { + addr := newFakeSocks5Server(t, []byte{0x05, 0x02}) + conn := dial(t, addr) + + method, raw, err := socks5Greeting(ctxShort(t), conn, true) + require.NoError(t, err) + assert.Equal(t, byte(0x02), method) + assert.Equal(t, []byte{0x05, 0x02}, raw) + }) + + t.Run("happy_no_auth_when_offered_both", func(t *testing.T) { + addr := newFakeSocks5Server(t, []byte{0x05, 0x00}) + conn := dial(t, addr) + + method, raw, err := socks5Greeting(ctxShort(t), conn, true) + require.NoError(t, err) + assert.Equal(t, byte(0x00), method) + assert.Equal(t, []byte{0x05, 0x00}, raw) + }) + + t.Run("rejected_all_auth", func(t *testing.T) { + addr := newFakeSocks5Server(t, []byte{0x05, 0xFF}) + conn := dial(t, addr) + + _, raw, err := socks5Greeting(ctxShort(t), conn, true) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrSocks5RejectedAllAuth), "expected ErrSocks5RejectedAllAuth in chain, got: %v", err) + assert.Equal(t, []byte{0x05, 0xFF}, raw) + }) + + t.Run("bad_version", func(t *testing.T) { + addr := newFakeSocks5Server(t, []byte{0x04, 0x00}) + conn := dial(t, addr) + + _, raw, err := socks5Greeting(ctxShort(t), conn, false) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrSocks5BadVersion), "expected ErrSocks5BadVersion in chain, got: %v", err) + assert.Equal(t, []byte{0x04, 0x00}, raw) + }) + + t.Run("short_read", func(t *testing.T) { + addr := newFakeSocks5Server(t, []byte{0x05}) + conn := dial(t, addr) + + _, _, err := socks5Greeting(ctxShort(t), conn, false) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrShortReply), "expected ErrShortReply in chain, got: %v", err) + }) + + t.Run("garbage_http_response", func(t *testing.T) { + addr := newFakeSocks5Server(t, []byte("HTTP/1.1 200 OK\r\n")) + conn := dial(t, addr) + + _, raw, err := socks5Greeting(ctxShort(t), conn, false) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrSocks5BadVersion), "expected ErrSocks5BadVersion, got: %v", err) + // First two bytes "HT" = 0x48 0x54 + assert.Equal(t, []byte{'H', 'T'}, raw) + }) +} + +func TestSocks5Auth(t *testing.T) { + t.Run("happy", func(t *testing.T) { + addr := newFakeSocks5Server(t, []byte{0x01, 0x00}) + conn := dial(t, addr) + + raw, err := socks5Auth(ctxShort(t), conn, "user", "pass") + require.NoError(t, err) + assert.Equal(t, []byte{0x01, 0x00}, raw) + }) + + t.Run("rejected", func(t *testing.T) { + addr := newFakeSocks5Server(t, []byte{0x01, 0x01}) + conn := dial(t, addr) + + raw, err := socks5Auth(ctxShort(t), conn, "user", "pass") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrAuthRejected), "expected ErrAuthRejected, got: %v", err) + assert.Equal(t, []byte{0x01, 0x01}, raw) + }) + + t.Run("short_read", func(t *testing.T) { + addr := newFakeSocks5Server(t, []byte{0x01}) + conn := dial(t, addr) + + _, err := socks5Auth(ctxShort(t), conn, "user", "pass") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrShortReply), "expected ErrShortReply, got: %v", err) + }) + + t.Run("bad_subneg_version", func(t *testing.T) { + addr := newFakeSocks5Server(t, []byte{0x02, 0x00}) + conn := dial(t, addr) + + _, err := socks5Auth(ctxShort(t), conn, "user", "pass") + require.Error(t, err) + assert.Contains(t, err.Error(), "auth subneg version", "want subneg version mention, got: %v", err) + }) + + t.Run("login_too_long", func(t *testing.T) { + // 300 chars, no I/O should occur + conn := &noopConn{} + long := strings.Repeat("a", 300) + _, err := socks5Auth(context.Background(), conn, long, "pass") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrCredentialTooLong), "expected ErrCredentialTooLong, got: %v", err) + assert.False(t, conn.touched, "no I/O should occur for over-long credential") + }) +} + +func TestSocks5Connect(t *testing.T) { + t.Run("happy", func(t *testing.T) { + // 05 00 00 01 00000000 0000 + reply := []byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + addr := newFakeSocks5Server(t, reply) + conn := dial(t, addr) + + raw, err := socks5Connect(ctxShort(t), conn, "example.com", 443) + require.NoError(t, err) + assert.Equal(t, reply, raw) + }) + + t.Run("rep_connection_refused", func(t *testing.T) { + reply := []byte{0x05, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + addr := newFakeSocks5Server(t, reply) + conn := dial(t, addr) + + raw, err := socks5Connect(ctxShort(t), conn, "example.com", 443) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrSocks5Reply{Code: 0x05}), "expected ErrSocks5Reply{Code:5}, got: %v", err) + assert.Equal(t, reply, raw) + }) + + t.Run("rep_cmd_not_supported", func(t *testing.T) { + reply := []byte{0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + addr := newFakeSocks5Server(t, reply) + conn := dial(t, addr) + + raw, err := socks5Connect(ctxShort(t), conn, "example.com", 443) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrSocks5Reply{Code: 0x07}), "expected ErrSocks5Reply{Code:7}, got: %v", err) + assert.Equal(t, reply, raw) + // And it should NOT match other codes: + assert.False(t, errors.Is(err, ErrSocks5Reply{Code: 0x05})) + }) + + t.Run("short_read", func(t *testing.T) { + reply := []byte{0x05, 0x00, 0x00, 0x01, 0x00} + addr := newFakeSocks5Server(t, reply) + conn := dial(t, addr) + + _, err := socks5Connect(ctxShort(t), conn, "example.com", 443) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrShortReply), "expected ErrShortReply, got: %v", err) + }) + + t.Run("bad_version", func(t *testing.T) { + reply := []byte{0x04, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + addr := newFakeSocks5Server(t, reply) + conn := dial(t, addr) + + _, err := socks5Connect(ctxShort(t), conn, "example.com", 443) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrSocks5BadVersion), "expected ErrSocks5BadVersion, got: %v", err) + }) + + t.Run("host_too_long", func(t *testing.T) { + conn := &noopConn{} + long := strings.Repeat("h", 300) + _, err := socks5Connect(context.Background(), conn, long, 443) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrHostTooLong), "expected ErrHostTooLong, got: %v", err) + assert.False(t, conn.touched, "no I/O should occur for over-long host") + }) +} + +func TestSocks5UDPAssociate(t *testing.T) { + t.Run("happy_ipv4", func(t *testing.T) { + // 05 00 00 01 7F000001 0539 -> 127.0.0.1:1337 + reply := []byte{0x05, 0x00, 0x00, 0x01, 0x7F, 0x00, 0x00, 0x01, 0x05, 0x39} + addr := newFakeSocks5Server(t, reply) + conn := dial(t, addr) + + relay, raw, err := socks5UDPAssociate(ctxShort(t), conn) + require.NoError(t, err) + require.NotNil(t, relay) + assert.True(t, relay.IP.Equal(net.IPv4(127, 0, 0, 1)), "ip=%s", relay.IP) + assert.Equal(t, 1337, relay.Port) + assert.Equal(t, reply, raw) + }) + + t.Run("rep_cmd_not_supported", func(t *testing.T) { + reply := []byte{0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + addr := newFakeSocks5Server(t, reply) + conn := dial(t, addr) + + relay, raw, err := socks5UDPAssociate(ctxShort(t), conn) + require.Error(t, err) + assert.Nil(t, relay) + assert.True(t, errors.Is(err, ErrSocks5Reply{Code: 0x07}), "expected ErrSocks5Reply{Code:7}, got: %v", err) + assert.Equal(t, reply, raw) + }) + + t.Run("atyp_ipv6_unsupported", func(t *testing.T) { + // REP=0x00 (success), ATYP=0x04 (IPv6) — unsupported by us. We + // only read 10 bytes total so the trailing IPv6 bytes are + // implicitly ignored on the wire. + reply := []byte{0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + addr := newFakeSocks5Server(t, reply) + conn := dial(t, addr) + + relay, raw, err := socks5UDPAssociate(ctxShort(t), conn) + require.Error(t, err) + assert.Nil(t, relay) + assert.True(t, errors.Is(err, ErrUnsupportedRelayATYP), "expected ErrUnsupportedRelayATYP, got: %v", err) + assert.Equal(t, reply, raw) + }) + + t.Run("short_read", func(t *testing.T) { + reply := []byte{0x05, 0x00, 0x00} + addr := newFakeSocks5Server(t, reply) + conn := dial(t, addr) + + _, _, err := socks5UDPAssociate(ctxShort(t), conn) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrShortReply), "expected ErrShortReply, got: %v", err) + }) +} + +// TestSocks5GreetingCtxCancel verifies that a cancelled ctx surfaces +// context.Canceled in the error chain even if the underlying I/O fails +// with a deadline-style error. +func TestSocks5GreetingCtxCancel(t *testing.T) { + // Server that accepts but never replies — read will hang until ctx + // deadline triggers SetDeadline-induced timeout. + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { _ = ln.Close() }) + + accepted := make(chan struct{}) + go func() { + defer close(accepted) + conn, err := ln.Accept() + if err != nil { + return + } + // Hold the connection open without writing anything. + t.Cleanup(func() { _ = conn.Close() }) + <-accepted // intentionally blocks; actually we close immediately on test end + }() + + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 1*time.Second) + require.NoError(t, err) + t.Cleanup(func() { _ = conn.Close() }) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, _, err = socks5Greeting(ctx, conn, false) + require.Error(t, err) + assert.True(t, + errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled), + "expected ctx error in chain, got: %v", err) +} + +// noopConn is a minimal net.Conn that records whether any I/O was +// attempted. Used to assert that pre-I/O validation rejects oversized +// inputs without ever touching the wire. +type noopConn struct { + touched bool +} + +func (c *noopConn) Read(b []byte) (int, error) { c.touched = true; return 0, io.EOF } +func (c *noopConn) Write(b []byte) (int, error) { c.touched = true; return len(b), nil } +func (c *noopConn) Close() error { return nil } +func (c *noopConn) LocalAddr() net.Addr { return &net.TCPAddr{} } +func (c *noopConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } +func (c *noopConn) SetDeadline(time.Time) error { return nil } +func (c *noopConn) SetReadDeadline(time.Time) error { return nil } +func (c *noopConn) SetWriteDeadline(time.Time) error { return nil }