package checker import ( "context" "errors" "io" "net" "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // newFakeSocks5Server starts a TCP listener on 127.0.0.1:0. On the first // accepted connection it reads up to 1024 bytes (enough for any of our // primitives' fixed-length frames in a single Write), then writes // scriptedReply, then closes the connection. The listener is closed by // t.Cleanup. func newFakeSocks5Server(t *testing.T, scriptedReply []byte) (addr string) { t.Helper() ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "listen") done := make(chan struct{}) t.Cleanup(func() { _ = ln.Close() <-done }) go func() { defer close(done) conn, err := ln.Accept() if err != nil { return } defer conn.Close() _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) buf := make([]byte, 1024) _, _ = conn.Read(buf) if len(scriptedReply) > 0 { _, _ = conn.Write(scriptedReply) } }() return ln.Addr().String() } // dial connects to addr and registers t.Cleanup to close the conn. func dial(t *testing.T, addr string) net.Conn { t.Helper() conn, err := net.DialTimeout("tcp", addr, 1*time.Second) require.NoError(t, err, "dial") t.Cleanup(func() { _ = conn.Close() }) return conn } func ctxShort(t *testing.T) context.Context { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) t.Cleanup(cancel) return ctx } func TestSocks5Greeting(t *testing.T) { t.Run("happy_no_auth", func(t *testing.T) { addr := newFakeSocks5Server(t, []byte{0x05, 0x00}) conn := dial(t, addr) method, raw, err := socks5Greeting(ctxShort(t), conn, false) require.NoError(t, err) assert.Equal(t, byte(0x00), method) assert.Equal(t, []byte{0x05, 0x00}, raw) }) t.Run("happy_userpass_selected", func(t *testing.T) { addr := newFakeSocks5Server(t, []byte{0x05, 0x02}) conn := dial(t, addr) method, raw, err := socks5Greeting(ctxShort(t), conn, true) require.NoError(t, err) assert.Equal(t, byte(0x02), method) assert.Equal(t, []byte{0x05, 0x02}, raw) }) t.Run("happy_no_auth_when_offered_both", func(t *testing.T) { addr := newFakeSocks5Server(t, []byte{0x05, 0x00}) conn := dial(t, addr) method, raw, err := socks5Greeting(ctxShort(t), conn, true) require.NoError(t, err) assert.Equal(t, byte(0x00), method) assert.Equal(t, []byte{0x05, 0x00}, raw) }) t.Run("rejected_all_auth", func(t *testing.T) { addr := newFakeSocks5Server(t, []byte{0x05, 0xFF}) conn := dial(t, addr) _, raw, err := socks5Greeting(ctxShort(t), conn, true) require.Error(t, err) assert.True(t, errors.Is(err, ErrSocks5RejectedAllAuth), "expected ErrSocks5RejectedAllAuth in chain, got: %v", err) assert.Equal(t, []byte{0x05, 0xFF}, raw) }) t.Run("bad_version", func(t *testing.T) { addr := newFakeSocks5Server(t, []byte{0x04, 0x00}) conn := dial(t, addr) _, raw, err := socks5Greeting(ctxShort(t), conn, false) require.Error(t, err) assert.True(t, errors.Is(err, ErrSocks5BadVersion), "expected ErrSocks5BadVersion in chain, got: %v", err) assert.Equal(t, []byte{0x04, 0x00}, raw) }) t.Run("short_read", func(t *testing.T) { addr := newFakeSocks5Server(t, []byte{0x05}) conn := dial(t, addr) _, _, err := socks5Greeting(ctxShort(t), conn, false) require.Error(t, err) assert.True(t, errors.Is(err, ErrShortReply), "expected ErrShortReply in chain, got: %v", err) }) t.Run("garbage_http_response", func(t *testing.T) { addr := newFakeSocks5Server(t, []byte("HTTP/1.1 200 OK\r\n")) conn := dial(t, addr) _, raw, err := socks5Greeting(ctxShort(t), conn, false) require.Error(t, err) assert.True(t, errors.Is(err, ErrSocks5BadVersion), "expected ErrSocks5BadVersion, got: %v", err) // First two bytes "HT" = 0x48 0x54 assert.Equal(t, []byte{'H', 'T'}, raw) }) } func TestSocks5Auth(t *testing.T) { t.Run("happy", func(t *testing.T) { addr := newFakeSocks5Server(t, []byte{0x01, 0x00}) conn := dial(t, addr) raw, err := socks5Auth(ctxShort(t), conn, "user", "pass") require.NoError(t, err) assert.Equal(t, []byte{0x01, 0x00}, raw) }) t.Run("rejected", func(t *testing.T) { addr := newFakeSocks5Server(t, []byte{0x01, 0x01}) conn := dial(t, addr) raw, err := socks5Auth(ctxShort(t), conn, "user", "pass") require.Error(t, err) assert.True(t, errors.Is(err, ErrAuthRejected), "expected ErrAuthRejected, got: %v", err) assert.Equal(t, []byte{0x01, 0x01}, raw) }) t.Run("short_read", func(t *testing.T) { addr := newFakeSocks5Server(t, []byte{0x01}) conn := dial(t, addr) _, err := socks5Auth(ctxShort(t), conn, "user", "pass") require.Error(t, err) assert.True(t, errors.Is(err, ErrShortReply), "expected ErrShortReply, got: %v", err) }) t.Run("bad_subneg_version", func(t *testing.T) { addr := newFakeSocks5Server(t, []byte{0x02, 0x00}) conn := dial(t, addr) _, err := socks5Auth(ctxShort(t), conn, "user", "pass") require.Error(t, err) assert.Contains(t, err.Error(), "auth subneg version", "want subneg version mention, got: %v", err) }) t.Run("login_too_long", func(t *testing.T) { // 300 chars, no I/O should occur conn := &noopConn{} long := strings.Repeat("a", 300) _, err := socks5Auth(context.Background(), conn, long, "pass") require.Error(t, err) assert.True(t, errors.Is(err, ErrCredentialTooLong), "expected ErrCredentialTooLong, got: %v", err) assert.False(t, conn.touched, "no I/O should occur for over-long credential") }) } func TestSocks5Connect(t *testing.T) { t.Run("happy", func(t *testing.T) { // 05 00 00 01 00000000 0000 reply := []byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} addr := newFakeSocks5Server(t, reply) conn := dial(t, addr) raw, err := socks5Connect(ctxShort(t), conn, "example.com", 443) require.NoError(t, err) assert.Equal(t, reply, raw) }) t.Run("rep_connection_refused", func(t *testing.T) { reply := []byte{0x05, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} addr := newFakeSocks5Server(t, reply) conn := dial(t, addr) raw, err := socks5Connect(ctxShort(t), conn, "example.com", 443) require.Error(t, err) assert.True(t, errors.Is(err, ErrSocks5Reply{Code: 0x05}), "expected ErrSocks5Reply{Code:5}, got: %v", err) assert.Equal(t, reply, raw) }) t.Run("rep_cmd_not_supported", func(t *testing.T) { reply := []byte{0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} addr := newFakeSocks5Server(t, reply) conn := dial(t, addr) raw, err := socks5Connect(ctxShort(t), conn, "example.com", 443) require.Error(t, err) assert.True(t, errors.Is(err, ErrSocks5Reply{Code: 0x07}), "expected ErrSocks5Reply{Code:7}, got: %v", err) assert.Equal(t, reply, raw) // And it should NOT match other codes: assert.False(t, errors.Is(err, ErrSocks5Reply{Code: 0x05})) }) t.Run("short_read", func(t *testing.T) { reply := []byte{0x05, 0x00, 0x00, 0x01, 0x00} addr := newFakeSocks5Server(t, reply) conn := dial(t, addr) _, err := socks5Connect(ctxShort(t), conn, "example.com", 443) require.Error(t, err) assert.True(t, errors.Is(err, ErrShortReply), "expected ErrShortReply, got: %v", err) }) t.Run("bad_version", func(t *testing.T) { reply := []byte{0x04, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} addr := newFakeSocks5Server(t, reply) conn := dial(t, addr) _, err := socks5Connect(ctxShort(t), conn, "example.com", 443) require.Error(t, err) assert.True(t, errors.Is(err, ErrSocks5BadVersion), "expected ErrSocks5BadVersion, got: %v", err) }) t.Run("host_too_long", func(t *testing.T) { conn := &noopConn{} long := strings.Repeat("h", 300) _, err := socks5Connect(context.Background(), conn, long, 443) require.Error(t, err) assert.True(t, errors.Is(err, ErrHostTooLong), "expected ErrHostTooLong, got: %v", err) assert.False(t, conn.touched, "no I/O should occur for over-long host") }) } func TestSocks5UDPAssociate(t *testing.T) { t.Run("happy_ipv4", func(t *testing.T) { // 05 00 00 01 7F000001 0539 -> 127.0.0.1:1337 reply := []byte{0x05, 0x00, 0x00, 0x01, 0x7F, 0x00, 0x00, 0x01, 0x05, 0x39} addr := newFakeSocks5Server(t, reply) conn := dial(t, addr) relay, raw, err := socks5UDPAssociate(ctxShort(t), conn) require.NoError(t, err) require.NotNil(t, relay) assert.True(t, relay.IP.Equal(net.IPv4(127, 0, 0, 1)), "ip=%s", relay.IP) assert.Equal(t, 1337, relay.Port) assert.Equal(t, reply, raw) }) t.Run("rep_cmd_not_supported", func(t *testing.T) { reply := []byte{0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} addr := newFakeSocks5Server(t, reply) conn := dial(t, addr) relay, raw, err := socks5UDPAssociate(ctxShort(t), conn) require.Error(t, err) assert.Nil(t, relay) assert.True(t, errors.Is(err, ErrSocks5Reply{Code: 0x07}), "expected ErrSocks5Reply{Code:7}, got: %v", err) assert.Equal(t, reply, raw) }) t.Run("atyp_ipv6_unsupported", func(t *testing.T) { // REP=0x00 (success), ATYP=0x04 (IPv6) — unsupported by us. We // only read 10 bytes total so the trailing IPv6 bytes are // implicitly ignored on the wire. reply := []byte{0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} addr := newFakeSocks5Server(t, reply) conn := dial(t, addr) relay, raw, err := socks5UDPAssociate(ctxShort(t), conn) require.Error(t, err) assert.Nil(t, relay) assert.True(t, errors.Is(err, ErrUnsupportedRelayATYP), "expected ErrUnsupportedRelayATYP, got: %v", err) assert.Equal(t, reply, raw) }) t.Run("short_read", func(t *testing.T) { reply := []byte{0x05, 0x00, 0x00} addr := newFakeSocks5Server(t, reply) conn := dial(t, addr) _, _, err := socks5UDPAssociate(ctxShort(t), conn) require.Error(t, err) assert.True(t, errors.Is(err, ErrShortReply), "expected ErrShortReply, got: %v", err) }) } // TestSocks5GreetingCtxCancel verifies that a cancelled ctx surfaces // context.Canceled in the error chain even if the underlying I/O fails // with a deadline-style error. func TestSocks5GreetingCtxCancel(t *testing.T) { // Server that accepts but never replies — read will hang until ctx // deadline triggers SetDeadline-induced timeout. ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) t.Cleanup(func() { _ = ln.Close() }) accepted := make(chan struct{}) go func() { defer close(accepted) conn, err := ln.Accept() if err != nil { return } // Hold the connection open without writing anything. t.Cleanup(func() { _ = conn.Close() }) <-accepted // intentionally blocks; actually we close immediately on test end }() conn, err := net.DialTimeout("tcp", ln.Addr().String(), 1*time.Second) require.NoError(t, err) t.Cleanup(func() { _ = conn.Close() }) ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() _, _, err = socks5Greeting(ctx, conn, false) require.Error(t, err) assert.True(t, errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled), "expected ctx error in chain, got: %v", err) } // noopConn is a minimal net.Conn that records whether any I/O was // attempted. Used to assert that pre-I/O validation rejects oversized // inputs without ever touching the wire. type noopConn struct { touched bool } func (c *noopConn) Read(b []byte) (int, error) { c.touched = true; return 0, io.EOF } func (c *noopConn) Write(b []byte) (int, error) { c.touched = true; return len(b), nil } func (c *noopConn) Close() error { return nil } func (c *noopConn) LocalAddr() net.Addr { return &net.TCPAddr{} } func (c *noopConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } func (c *noopConn) SetDeadline(time.Time) error { return nil } func (c *noopConn) SetReadDeadline(time.Time) error { return nil } func (c *noopConn) SetWriteDeadline(time.Time) error { return nil }