package divert import ( "encoding/binary" "net" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // helloTCPSYN is a minimum well-formed IPv4 + TCP SYN packet: // src=10.0.0.1:54321 dst=1.2.3.4:443 // Captured from a raw socket trace; checksums are correct. var helloTCPSYN = []byte{ // IPv4 header (20 bytes, IHL=5) 0x45, 0x00, 0x00, 0x28, 0xab, 0xcd, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, // checksum placeholder — we'll fill in below 0x0a, 0x00, 0x00, 0x01, // src 10.0.0.1 0x01, 0x02, 0x03, 0x04, // dst 1.2.3.4 // TCP header (20 bytes) 0xd4, 0x31, 0x01, 0xbb, // src=54321 dst=443 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x50, 0x02, 0xff, 0xff, 0x00, 0x00, // checksum placeholder 0x00, 0x00, } // fillTestChecksums computes correct IP + TCP checksums for the test // packet so we can compare against the parser's recompute output. func fillTestChecksums(b []byte) { // IP checksum b[10], b[11] = 0, 0 cs := ipChecksum(b[:20]) b[10] = byte(cs >> 8) b[11] = byte(cs & 0xff) // TCP checksum b[36], b[37] = 0, 0 cs = tcpChecksum(b[:20], b[20:40]) b[36] = byte(cs >> 8) b[37] = byte(cs & 0xff) } func TestParseIPv4TCP_Roundtrip(t *testing.T) { pkt := make([]byte, len(helloTCPSYN)) copy(pkt, helloTCPSYN) fillTestChecksums(pkt) p, err := ParseIPv4TCP(pkt) require.NoError(t, err) assert.Equal(t, "10.0.0.1", p.SrcIP.String()) assert.Equal(t, "1.2.3.4", p.DstIP.String()) assert.Equal(t, uint16(54321), p.SrcPort) assert.Equal(t, uint16(443), p.DstPort) } func TestRewriteDst_PreservesSrc(t *testing.T) { pkt := make([]byte, len(helloTCPSYN)) copy(pkt, helloTCPSYN) fillTestChecksums(pkt) err := RewriteDst(pkt, net.IPv4(127, 0, 0, 1), 8080) require.NoError(t, err) p, err := ParseIPv4TCP(pkt) require.NoError(t, err) assert.Equal(t, "127.0.0.1", p.DstIP.String()) assert.Equal(t, uint16(8080), p.DstPort) assert.Equal(t, "10.0.0.1", p.SrcIP.String()) assert.Equal(t, uint16(54321), p.SrcPort) } func TestRewriteDst_RecomputesChecksums(t *testing.T) { pkt := make([]byte, len(helloTCPSYN)) copy(pkt, helloTCPSYN) fillTestChecksums(pkt) err := RewriteDst(pkt, net.IPv4(127, 0, 0, 1), 8080) require.NoError(t, err) // Validate IP checksum ipCs := uint16(pkt[10])<<8 | uint16(pkt[11]) pkt[10], pkt[11] = 0, 0 expIP := ipChecksum(pkt[:20]) pkt[10] = byte(ipCs >> 8) pkt[11] = byte(ipCs & 0xff) assert.Equal(t, expIP, ipCs, "IP checksum mismatch") // Validate TCP checksum tcpCs := uint16(pkt[36])<<8 | uint16(pkt[37]) pkt[36], pkt[37] = 0, 0 expTCP := tcpChecksum(pkt[:20], pkt[20:]) pkt[36] = byte(tcpCs >> 8) pkt[37] = byte(tcpCs & 0xff) assert.Equal(t, expTCP, tcpCs, "TCP checksum mismatch") } func TestParseIPv4TCP_Errors(t *testing.T) { cases := []struct { name string b []byte }{ {"too_short", []byte{0x45}}, {"not_ipv4", []byte{0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, {"not_tcp", []byte{0x45, 0, 0, 20, 0, 0, 0, 0, 0, 17, /* UDP */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { _, err := ParseIPv4TCP(c.b) assert.Error(t, err) }) } } // helloUDP is a minimum well-formed IPv4 + UDP datagram: // // src=10.0.0.1:54321 dst=1.2.3.4:443 payload=4 bytes ABCD // // Total length: 20(IP) + 8(UDP) + 4(payload) = 32 bytes. var helloUDP = []byte{ // IPv4 header (20 bytes, IHL=5) 0x45, 0x00, 0x00, 0x20, 0xab, 0xcd, 0x40, 0x00, 0x40, 0x11, // proto=17 (UDP) 0x00, 0x00, // checksum placeholder 0x0a, 0x00, 0x00, 0x01, // src 10.0.0.1 0x01, 0x02, 0x03, 0x04, // dst 1.2.3.4 // UDP header (8 bytes) 0xd4, 0x31, 0x01, 0xbb, // src=54321 dst=443 0x00, 0x0c, // length=12 (UDP header + 4 payload) 0x00, 0x00, // checksum placeholder // Payload (4 bytes) 'A', 'B', 'C', 'D', } func fillUDPTestChecksums(b []byte) { // IP checksum b[10], b[11] = 0, 0 cs := ipChecksum(b[:20]) b[10] = byte(cs >> 8) b[11] = byte(cs & 0xff) // UDP checksum (covers UDP header + payload + pseudo-header) udpLen := int(binary.BigEndian.Uint16(b[24:26])) b[26], b[27] = 0, 0 cs = udpChecksum(b[:20], b[20:20+udpLen]) if cs == 0 { cs = 0xFFFF } b[26] = byte(cs >> 8) b[27] = byte(cs & 0xff) } func TestParseIPv4UDP_Roundtrip(t *testing.T) { pkt := make([]byte, len(helloUDP)) copy(pkt, helloUDP) fillUDPTestChecksums(pkt) p, err := ParseIPv4UDP(pkt) require.NoError(t, err) assert.Equal(t, "10.0.0.1", p.SrcIP.String()) assert.Equal(t, "1.2.3.4", p.DstIP.String()) assert.Equal(t, uint16(54321), p.SrcPort) assert.Equal(t, uint16(443), p.DstPort) assert.Equal(t, 20, p.IHL) assert.Equal(t, uint16(12), p.UDPLen) } func TestParseIPv4UDP_Errors(t *testing.T) { cases := []struct { name string b []byte }{ {"too_short", []byte{0x45}}, {"not_ipv4", []byte{0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, {"not_udp", []byte{0x45, 0, 0, 20, 0, 0, 0, 0, 0, 6, /* TCP */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { _, err := ParseIPv4UDP(c.b) assert.Error(t, err) }) } } func TestSwapUDPAndSetDstPort(t *testing.T) { pkt := make([]byte, len(helloUDP)) copy(pkt, helloUDP) fillUDPTestChecksums(pkt) require.NoError(t, SwapUDPAndSetDstPort(pkt, 8080)) p, err := ParseIPv4UDP(pkt) require.NoError(t, err) assert.Equal(t, "1.2.3.4", p.SrcIP.String(), "src should be original dst after swap") assert.Equal(t, "10.0.0.1", p.DstIP.String(), "dst should be original src after swap") assert.Equal(t, uint16(54321), p.SrcPort, "src port unchanged") assert.Equal(t, uint16(8080), p.DstPort, "dst port set to new value") // Validate IP checksum recomputed ipCs := uint16(pkt[10])<<8 | uint16(pkt[11]) pkt[10], pkt[11] = 0, 0 expIP := ipChecksum(pkt[:20]) assert.Equal(t, expIP, ipCs, "IP checksum mismatch") } func TestSwapUDPAndSetSrcPort(t *testing.T) { pkt := make([]byte, len(helloUDP)) copy(pkt, helloUDP) fillUDPTestChecksums(pkt) require.NoError(t, SwapUDPAndSetSrcPort(pkt, 50007)) p, err := ParseIPv4UDP(pkt) require.NoError(t, err) assert.Equal(t, "1.2.3.4", p.SrcIP.String()) assert.Equal(t, "10.0.0.1", p.DstIP.String()) assert.Equal(t, uint16(50007), p.SrcPort, "src port set to new value") assert.Equal(t, uint16(443), p.DstPort, "dst port unchanged") } func TestBuildIPv4UDPInbound(t *testing.T) { src := net.IPv4(140, 82, 121, 4) // GitHub IP, just for variety dst := net.IPv4(192, 168, 1, 50) // local LAN payload := []byte("hello voice") pkt, err := BuildIPv4UDPInbound(src, dst, 50007, 50100, payload) require.NoError(t, err) // Total length: 20+8+11 = 39 assert.Len(t, pkt, 39) // Re-parse and verify fields p, err := ParseIPv4UDP(pkt) require.NoError(t, err) assert.Equal(t, "140.82.121.4", p.SrcIP.String()) assert.Equal(t, "192.168.1.50", p.DstIP.String()) assert.Equal(t, uint16(50007), p.SrcPort) assert.Equal(t, uint16(50100), p.DstPort) assert.Equal(t, uint16(8+len(payload)), p.UDPLen) // Payload after headers assert.Equal(t, payload, pkt[28:]) // IP checksum valid: clearing + recomputing should match ipCs := uint16(pkt[10])<<8 | uint16(pkt[11]) pkt[10], pkt[11] = 0, 0 expIP := ipChecksum(pkt[:20]) assert.Equal(t, expIP, ipCs, "IP checksum should be valid") // UDP checksum valid (and non-zero) udpCs := uint16(pkt[26])<<8 | uint16(pkt[27]) assert.NotEqual(t, uint16(0), udpCs, "UDP checksum should be non-zero (RFC 768 trick)") } func TestBuildIPv4UDPInbound_NotIPv4(t *testing.T) { v6 := net.ParseIP("::1") _, err := BuildIPv4UDPInbound(v6, net.IPv4(1, 2, 3, 4), 1, 2, []byte("x")) assert.Error(t, err) }