package socks5

import (
	"context"
	"errors"
	"io"
	"net"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// fakeProxy is a minimal SOCKS5 server for tests. It speaks just enough of
// RFC 1928 (plus RFC 1929 username/password auth when useAuth is set) to
// accept one CONNECT request per connection, checks that the requested
// destination matches target, and then relays bytes between the client and
// that destination.
type fakeProxy struct {
	addr     string // listen address handed to Dial as Config.ProxyAddr
	target   string // required CONNECT destination ("host:port"); "" allows any
	useAuth  bool   // demand username/password authentication
	login    string // expected username when useAuth is set
	password string // expected password when useAuth is set
}

// startFakeProxy starts a fakeProxy on an ephemeral loopback port. The
// listener is closed when the test finishes; connections that are still
// being served at that point are not waited for.
func startFakeProxy(t *testing.T, target string, useAuth bool, login, password string) *fakeProxy {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	t.Cleanup(func() { ln.Close() })

	p := &fakeProxy{
		addr:     ln.Addr().String(),
		target:   target,
		useAuth:  useAuth,
		login:    login,
		password: password,
	}
	go func() {
		for {
			c, err := ln.Accept()
			if err != nil {
				return // typically the listener was closed by t.Cleanup
			}
			go p.handle(c)
		}
	}()
	return p
}

// handle serves one client connection. It runs the method negotiation, the
// CONNECT request and then the relay. Any I/O or protocol error simply drops
// the connection, which the client under test observes as a failure.
func (p *fakeProxy) handle(c net.Conn) {
	defer c.Close()

	if err := p.negotiate(c); err != nil {
		return
	}
	dest, rep, err := p.readConnect(c)
	if err != nil {
		return
	}
	if rep == 0x00 && p.target != "" && dest != p.target {
		rep = 0x02 // connection not allowed by ruleset: client asked for the wrong address
	}
	if rep != 0x00 {
		_ = p.reply(c, rep) // best effort: the connection is dropped next
		return
	}

	// Dial before replying so the client is never told "succeeded" for a
	// destination the proxy could not reach.
	target, err := net.Dial("tcp", dest)
	if err != nil {
		_ = p.reply(c, 0x05) // any dial failure is reported as "connection refused"
		return
	}
	defer target.Close()
	if err := p.reply(c, 0x00); err != nil {
		return
	}
	p.relay(c, target)
}

// negotiate reads the client greeting and selects "no authentication" or
// "username/password" depending on p.useAuth. In the latter case it also
// runs the RFC 1929 sub-negotiation and compares the credentials with
// p.login/p.password, returning an error on mismatch.
func (p *fakeProxy) negotiate(c net.Conn) error {
	// Greeting: VER NMETHODS METHODS...
	var ver [1]byte
	if _, err := io.ReadFull(c, ver[:]); err != nil {
		return err
	}
	if _, err := p.readPrefixed(c); err != nil {
		return err
	}
	if !p.useAuth {
		_, err := c.Write([]byte{0x05, 0x00})
		return err
	}
	if _, err := c.Write([]byte{0x05, 0x02}); err != nil {
		return err
	}

	// Sub-negotiation: VER ULEN UNAME PLEN PASSWD
	if _, err := io.ReadFull(c, ver[:]); err != nil {
		return err
	}
	login, err := p.readPrefixed(c)
	if err != nil {
		return err
	}
	pwd, err := p.readPrefixed(c)
	if err != nil {
		return err
	}
	if string(login) != p.login || string(pwd) != p.password {
		_, _ = c.Write([]byte{0x01, 0x01}) // best effort: the connection is dropped next
		return errors.New("fake proxy: bad credentials")
	}
	_, err = c.Write([]byte{0x01, 0x00})
	return err
}

// readConnect parses a request (VER CMD RSV ATYP DST.ADDR DST.PORT) and
// returns the destination as "host:port". A non-zero reply code means the
// request cannot be served and should be answered with that code; a
// non-nil error means the connection failed mid-read.
func (p *fakeProxy) readConnect(c net.Conn) (string, byte, error) {
	var hdr [4]byte
	if _, err := io.ReadFull(c, hdr[:]); err != nil {
		return "", 0, err
	}
	if hdr[1] != 0x01 {
		return "", 0x07, nil // command not supported: only CONNECT is implemented
	}

	var host string
	switch hdr[3] {
	case 0x01: // IPv4
		var ip [net.IPv4len]byte
		if _, err := io.ReadFull(c, ip[:]); err != nil {
			return "", 0, err
		}
		host = net.IP(ip[:]).String()
	case 0x03: // domain name, length-prefixed
		name, err := p.readPrefixed(c)
		if err != nil {
			return "", 0, err
		}
		host = string(name)
	case 0x04: // IPv6
		var ip [net.IPv6len]byte
		if _, err := io.ReadFull(c, ip[:]); err != nil {
			return "", 0, err
		}
		host = net.IP(ip[:]).String()
	default:
		return "", 0x08, nil // address type not supported
	}

	var port [2]byte
	if _, err := io.ReadFull(c, port[:]); err != nil {
		return "", 0, err
	}
	portNum := int(port[0])<<8 | int(port[1]) // network byte order
	return net.JoinHostPort(host, strconv.Itoa(portNum)), 0x00, nil
}

// readPrefixed reads a one-byte length followed by that many bytes. SOCKS5
// uses this encoding for METHODS, UNAME, PASSWD and domain names. Each call
// returns a freshly allocated slice.
func (p *fakeProxy) readPrefixed(r io.Reader) ([]byte, error) {
	var n [1]byte
	if _, err := io.ReadFull(r, n[:]); err != nil {
		return nil, err
	}
	b := make([]byte, n[0])
	if _, err := io.ReadFull(r, b); err != nil {
		return nil, err
	}
	return b, nil
}

// reply sends a CONNECT reply carrying rep and an all-zero IPv4
// BND.ADDR/BND.PORT.
func (p *fakeProxy) reply(c net.Conn, rep byte) error {
	_, err := c.Write([]byte{0x05, rep, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
	return err
}

// relay copies bytes in both directions until both have finished. When one
// direction reaches EOF, the write half of its destination is closed (when
// the conn supports CloseWrite), so the peer sees EOF instead of waiting for
// the other side to hang up.
func (p *fakeProxy) relay(client, target net.Conn) {
	var wg sync.WaitGroup
	pipe := func(dst, src net.Conn) {
		defer wg.Done()
		_, _ = io.Copy(dst, src) // errors just end this direction
		if cw, ok := dst.(interface{ CloseWrite() error }); ok {
			_ = cw.CloseWrite()
		}
	}
	wg.Add(2)
	go pipe(target, client)
	go pipe(client, target)
	wg.Wait()
}

// startFakeTarget starts a loopback listener that accepts a single
// connection, writes payload to it and closes it. It returns the listener
// address as "host:port", plus the host and numeric port that Dial expects.
func startFakeTarget(t *testing.T, payload string) (addr, host string, port uint16) {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	t.Cleanup(func() { ln.Close() })

	go func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		_, _ = c.Write([]byte(payload))
	}()

	addr = ln.Addr().String()
	h, portStr, err := net.SplitHostPort(addr)
	require.NoError(t, err)
	portU, err := atoiU16(portStr)
	require.NoError(t, err)
	return addr, h, portU
}

func TestDial_NoAuth_HappyPath(t *testing.T) {
	addr, host, port := startFakeTarget(t, "hello")
	p := startFakeProxy(t, addr, false, "", "")

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	conn, err := Dial(ctx, Config{
		ProxyAddr: p.addr,
	}, host, port)
	require.NoError(t, err)
	defer conn.Close()

	buf := make([]byte, len("hello"))
	_, err = io.ReadFull(conn, buf)
	require.NoError(t, err)
	assert.Equal(t, "hello", string(buf))
}

func TestDial_WithAuth_HappyPath(t *testing.T) {
	addr, host, port := startFakeTarget(t, "auth-ok")
	p := startFakeProxy(t, addr, true, "user", "pass")

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	conn, err := Dial(ctx, Config{
		ProxyAddr: p.addr,
		UseAuth:   true,
		Login:     "user",
		Password:  "pass",
	}, host, port)
	require.NoError(t, err)
	defer conn.Close()

	buf := make([]byte, len("auth-ok"))
	_, err = io.ReadFull(conn, buf)
	require.NoError(t, err)
	assert.Equal(t, "auth-ok", string(buf))
}

func TestDial_BadAuth(t *testing.T) {
	addr, host, port := startFakeTarget(t, "")
	p := startFakeProxy(t, addr, true, "user", "pass")

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	_, err := Dial(ctx, Config{
		ProxyAddr: p.addr,
		UseAuth:   true,
		Login:     "wrong",
		Password:  "wrong",
	}, host, port)
	require.Error(t, err)
}

// atoiU16 parses a decimal TCP port. Empty input, non-digit characters and
// values above 65535 are rejected with a *net.AddrError.
func atoiU16(s string) (uint16, error) {
	n, err := strconv.ParseUint(s, 10, 16)
	if err != nil {
		return 0, &net.AddrError{Err: "invalid port", Addr: s}
	}
	return uint16(n), nil
}