diff --git a/internal/socks5/client.go b/internal/socks5/client.go new file mode 100644 index 0000000..ff7f465 --- /dev/null +++ b/internal/socks5/client.go @@ -0,0 +1,117 @@ +package socks5 + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "time" +) + +// Config carries connection-time SOCKS5 settings. +type Config struct { + ProxyAddr string // "host:port" + UseAuth bool + Login string + Password string +} + +// Dial opens a TCP connection to the SOCKS5 proxy, runs the greeting, +// optionally authenticates with username/password (RFC 1929), and +// issues a CONNECT to host:port (sent as ATYP=03 domain so the proxy +// resolves on its side). Returns the established net.Conn ready for +// bidirectional traffic. +// +// The given ctx bounds dial + handshake; once Dial returns, the conn +// has its own deadline-free I/O state. +func Dial(ctx context.Context, cfg Config, host string, port uint16) (net.Conn, error) { + d := net.Dialer{} + conn, err := d.DialContext(ctx, "tcp", cfg.ProxyAddr) + if err != nil { + return nil, fmt.Errorf("dial proxy: %w", err) + } + if dl, ok := ctx.Deadline(); ok { + conn.SetDeadline(dl) + } + if err := handshake(conn, cfg, host, port); err != nil { + conn.Close() + return nil, err + } + conn.SetDeadline(time.Time{}) + return conn, nil +} + +func handshake(conn net.Conn, cfg Config, host string, port uint16) error { + // Greeting + if cfg.UseAuth { + if _, err := conn.Write([]byte{0x05, 0x02, 0x00, 0x02}); err != nil { + return fmt.Errorf("greet write: %w", err) + } + } else { + if _, err := conn.Write([]byte{0x05, 0x01, 0x00}); err != nil { + return fmt.Errorf("greet write: %w", err) + } + } + var rep [2]byte + if _, err := io.ReadFull(conn, rep[:]); err != nil { + return fmt.Errorf("greet read: %w", err) + } + if rep[0] != 0x05 { + return fmt.Errorf("greet: server version %#x is not SOCKS5", rep[0]) + } + if rep[1] == 0xff { + return errors.New("greet: proxy rejected all offered auth methods") + } + method := rep[1] + + // Auth subneg + if method == 0x02 { + if !cfg.UseAuth { + return errors.New("proxy requires auth but Config.UseAuth is false") + } + if len(cfg.Login) > 255 || len(cfg.Password) > 255 { + return errors.New("login or password too long") + } + buf := make([]byte, 0, 3+len(cfg.Login)+len(cfg.Password)) + buf = append(buf, 0x01, byte(len(cfg.Login))) + buf = append(buf, []byte(cfg.Login)...) + buf = append(buf, byte(len(cfg.Password))) + buf = append(buf, []byte(cfg.Password)...) + if _, err := conn.Write(buf); err != nil { + return fmt.Errorf("auth write: %w", err) + } + if _, err := io.ReadFull(conn, rep[:]); err != nil { + return fmt.Errorf("auth read: %w", err) + } + if rep[1] != 0x00 { + return errors.New("auth: invalid login or password") + } + } + + // CONNECT + if len(host) > 255 { + return errors.New("host too long") + } + req := make([]byte, 0, 7+len(host)) + req = append(req, 0x05, 0x01, 0x00, 0x03, byte(len(host))) + req = append(req, []byte(host)...) + pBuf := make([]byte, 2) + binary.BigEndian.PutUint16(pBuf, port) + req = append(req, pBuf...) + if _, err := conn.Write(req); err != nil { + return fmt.Errorf("connect write: %w", err) + } + var creply [10]byte + if _, err := io.ReadFull(conn, creply[:]); err != nil { + return fmt.Errorf("connect read: %w", err) + } + if creply[0] != 0x05 { + return fmt.Errorf("connect: server version %#x is not SOCKS5", creply[0]) + } + if creply[1] != 0x00 { + return fmt.Errorf("connect: REP=%#02x", creply[1]) + } + return nil +} diff --git a/internal/socks5/client_test.go b/internal/socks5/client_test.go new file mode 100644 index 0000000..d942695 --- /dev/null +++ b/internal/socks5/client_test.go @@ -0,0 +1,199 @@ +package socks5 + +import ( + "context" + "io" + "net" + "strconv" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeProxy is a minimal SOCKS5 server that accepts greet+CONNECT +// (and optional auth) and then splices the connection to a target +// listener supplied by the test. +type fakeProxy struct { + addr string + target string + useAuth bool + login string + password string +} + +func startFakeProxy(t *testing.T, target string, useAuth bool, login, password string) *fakeProxy { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + p := &fakeProxy{ + addr: ln.Addr().String(), + target: target, + useAuth: useAuth, login: login, password: password, + } + + go func() { + for { + c, err := ln.Accept() + if err != nil { + return + } + go p.handle(c) + } + }() + return p +} + +func (p *fakeProxy) handle(c net.Conn) { + defer c.Close() + buf := make([]byte, 256) + + // Greeting: 05 N method... + io.ReadFull(c, buf[:2]) + nmethods := int(buf[1]) + io.ReadFull(c, buf[:nmethods]) + if p.useAuth { + c.Write([]byte{0x05, 0x02}) + io.ReadFull(c, buf[:2]) + ulen := int(buf[1]) + io.ReadFull(c, buf[:ulen]) + login := string(buf[:ulen]) + io.ReadFull(c, buf[:1]) + plen := int(buf[0]) + io.ReadFull(c, buf[:plen]) + pwd := string(buf[:plen]) + if login != p.login || pwd != p.password { + c.Write([]byte{0x01, 0x01}) + return + } + c.Write([]byte{0x01, 0x00}) + } else { + c.Write([]byte{0x05, 0x00}) + } + + // CONNECT request: 05 01 00 ATYP ... + io.ReadFull(c, buf[:4]) + atyp := buf[3] + var host string + switch atyp { + case 1: + io.ReadFull(c, buf[:4]) + host = net.IPv4(buf[0], buf[1], buf[2], buf[3]).String() + case 3: + io.ReadFull(c, buf[:1]) + hlen := int(buf[0]) + io.ReadFull(c, buf[:hlen]) + host = string(buf[:hlen]) + } + io.ReadFull(c, buf[:2]) + port := int(buf[0])<<8 | int(buf[1]) + + // Reply REP=0 + c.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + + // Splice to target + target, err := net.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return + } + defer target.Close() + var wg sync.WaitGroup + wg.Add(2) + go func() { defer wg.Done(); io.Copy(target, c) }() + go func() { defer wg.Done(); io.Copy(c, target) }() + wg.Wait() +} + +func TestDial_NoAuth_HappyPath(t *testing.T) { + target, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer target.Close() + go func() { + c, err := target.Accept() + if err != nil { + return + } + defer c.Close() + c.Write([]byte("hello")) + }() + + p := startFakeProxy(t, target.Addr().String(), false, "", "") + + host, port, _ := net.SplitHostPort(target.Addr().String()) + portU, _ := atoiU16(port) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + conn, err := Dial(ctx, Config{ + ProxyAddr: p.addr, + }, host, portU) + require.NoError(t, err) + defer conn.Close() + + buf := make([]byte, 5) + io.ReadFull(conn, buf) + assert.Equal(t, "hello", string(buf)) +} + +func TestDial_WithAuth_HappyPath(t *testing.T) { + target, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer target.Close() + go func() { c, _ := target.Accept(); if c != nil { c.Write([]byte("auth-ok")); c.Close() } }() + + p := startFakeProxy(t, target.Addr().String(), true, "user", "pass") + host, port, _ := net.SplitHostPort(target.Addr().String()) + portU, _ := atoiU16(port) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + conn, err := Dial(ctx, Config{ + ProxyAddr: p.addr, + UseAuth: true, + Login: "user", + Password: "pass", + }, host, portU) + require.NoError(t, err) + defer conn.Close() + + buf := make([]byte, 7) + io.ReadFull(conn, buf) + assert.Equal(t, "auth-ok", string(buf)) +} + +func TestDial_BadAuth(t *testing.T) { + target, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer target.Close() + + p := startFakeProxy(t, target.Addr().String(), true, "user", "pass") + host, port, _ := net.SplitHostPort(target.Addr().String()) + portU, _ := atoiU16(port) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, err = Dial(ctx, Config{ + ProxyAddr: p.addr, + UseAuth: true, + Login: "wrong", + Password: "wrong", + }, host, portU) + require.Error(t, err) +} + +func atoiU16(s string) (uint16, error) { + var n int + for _, c := range s { + if c < '0' || c > '9' { + return 0, &net.AddrError{Err: "invalid port", Addr: s} + } + n = n*10 + int(c-'0') + } + return uint16(n), nil +}