package divert

import (
	"encoding/binary"
	"net"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// helloTCPSYN is a minimal well-formed IPv4 + TCP SYN packet from
// 10.0.0.1:54321 to 1.2.3.4:443 (40 bytes, no TCP payload).
//
// Both checksum fields are zero placeholders. Tests must not use or modify
// this shared slice directly; call newSYNPacket to get a private copy with
// the checksums filled in by fillTestChecksums.
var helloTCPSYN = []byte{
	// IPv4 header (20 bytes, IHL=5)
	0x45, 0x00, 0x00, 0x28, // version 4, IHL 5, total length 40
	0xab, 0xcd, 0x40, 0x00, // ID 0xabcd, DF set, fragment offset 0
	0x40, 0x06, 0x00, 0x00, // TTL 64, protocol 6 (TCP), checksum placeholder
	0x0a, 0x00, 0x00, 0x01, // src 10.0.0.1
	0x01, 0x02, 0x03, 0x04, // dst 1.2.3.4
	// TCP header (20 bytes, data offset 5)
	0xd4, 0x31, 0x01, 0xbb, // src port 54321, dst port 443
	0x00, 0x00, 0x00, 0x01, // sequence number 1
	0x00, 0x00, 0x00, 0x00, // acknowledgement number 0
	0x50, 0x02, 0xff, 0xff, // data offset 5, flags SYN, window 65535
	0x00, 0x00, 0x00, 0x00, // checksum placeholder, urgent pointer 0
}

// fillTestChecksums writes the IPv4 header checksum and the TCP checksum of
// the IPv4/TCP packet b in place, using the package's ipChecksum and
// tcpChecksum helpers. The IP header length comes from the IHL field, and the
// TCP segment is everything after the IP header.
func fillTestChecksums(b []byte) {
	ihl := int(b[0]&0x0f) * 4
	ip, seg := b[:ihl], b[ihl:]

	// Each checksum is computed with its own field zeroed.
	binary.BigEndian.PutUint16(ip[10:], 0)
	binary.BigEndian.PutUint16(ip[10:], ipChecksum(ip))

	binary.BigEndian.PutUint16(seg[16:], 0)
	binary.BigEndian.PutUint16(seg[16:], tcpChecksum(ip, seg))
}

// newSYNPacket returns a private copy of helloTCPSYN with valid checksums.
// Each test can then mutate its packet without touching the shared fixture.
func newSYNPacket() []byte {
	pkt := make([]byte, len(helloTCPSYN))
	copy(pkt, helloTCPSYN)
	fillTestChecksums(pkt)
	return pkt
}

// assertChecksum checks that the big-endian checksum stored at hdr[off:off+2]
// equals compute() evaluated with that field temporarily zeroed. The stored
// value is restored before the assertion, so hdr is left unchanged.
func assertChecksum(t *testing.T, hdr []byte, off int, compute func() uint16, msg string) {
	t.Helper()
	stored := binary.BigEndian.Uint16(hdr[off:])
	binary.BigEndian.PutUint16(hdr[off:], 0)
	want := compute()
	binary.BigEndian.PutUint16(hdr[off:], stored)
	assert.Equal(t, want, stored, msg)
}

// onesComplementSum16 returns the RFC 1071 16-bit one's-complement sum of b.
// An odd trailing byte is padded with zero. It is written independently of
// the package helpers so checksum tests have a second oracle.
func onesComplementSum16(b []byte) uint16 {
	var sum uint32
	for len(b) >= 2 {
		sum += uint32(binary.BigEndian.Uint16(b))
		b = b[2:]
	}
	if len(b) == 1 {
		sum += uint32(b[0]) << 8
	}
	for sum > 0xffff {
		sum = sum&0xffff + sum>>16
	}
	return uint16(sum)
}

// tcpPseudoSum returns the one's-complement sum of the TCP pseudo-header
// (src IP, dst IP, zero, protocol, segment length) taken from the IPv4 header
// ip, followed by the TCP segment seg.
func tcpPseudoSum(ip, seg []byte) uint16 {
	buf := make([]byte, 12, 12+len(seg))
	copy(buf, ip[12:20])
	buf[9] = ip[9]
	binary.BigEndian.PutUint16(buf[10:], uint16(len(seg)))
	return onesComplementSum16(append(buf, seg...))
}

func TestParseIPv4TCP_Roundtrip(t *testing.T) {
	p, err := ParseIPv4TCP(newSYNPacket())
	require.NoError(t, err)

	assert.Equal(t, "10.0.0.1", p.SrcIP.String())
	assert.Equal(t, "1.2.3.4", p.DstIP.String())
	assert.Equal(t, uint16(54321), p.SrcPort)
	assert.Equal(t, uint16(443), p.DstPort)
}

func TestRewriteDst_PreservesSrc(t *testing.T) {
	pkt := newSYNPacket()
	require.NoError(t, RewriteDst(pkt, net.IPv4(127, 0, 0, 1), 8080))

	p, err := ParseIPv4TCP(pkt)
	require.NoError(t, err)

	assert.Equal(t, "127.0.0.1", p.DstIP.String())
	assert.Equal(t, uint16(8080), p.DstPort)
	assert.Equal(t, "10.0.0.1", p.SrcIP.String())
	assert.Equal(t, uint16(54321), p.SrcPort)
}

// TestRewriteDst_RecomputesChecksums checks both checksums after a rewrite.
// They are compared against the package helpers and also verified with an
// independent RFC 1071 sum, so a bug shared by RewriteDst and the helpers
// cannot pass unnoticed.
func TestRewriteDst_RecomputesChecksums(t *testing.T) {
	pkt := newSYNPacket()
	require.NoError(t, RewriteDst(pkt, net.IPv4(127, 0, 0, 1), 8080))

	ihl := int(pkt[0]&0x0f) * 4
	ip, seg := pkt[:ihl], pkt[ihl:]

	assertChecksum(t, ip, 10, func() uint16 { return ipChecksum(ip) }, "IP checksum mismatch")
	assertChecksum(t, seg, 16, func() uint16 { return tcpChecksum(ip, seg) }, "TCP checksum mismatch")

	// With a correct checksum in place, the one's-complement sum over the
	// covered bytes (checksum field included) is all ones.
	assert.Equal(t, uint16(0xffff), onesComplementSum16(ip), "IP header does not verify")
	assert.Equal(t, uint16(0xffff), tcpPseudoSum(ip, seg), "TCP segment does not verify")
}

func TestParseIPv4TCP_Errors(t *testing.T) {
	cases := []struct {
		name string
		b    []byte
	}{
		{"empty", nil},
		{"too_short", []byte{0x45}},
		// Version nibble 6, otherwise a zeroed 20-byte header.
		{"not_ipv4", append([]byte{0x60}, make([]byte, 19)...)},
		// IPv4, IHL 5, total length 20, protocol 17 (UDP).
		{"not_tcp", []byte{
			0x45, 0, 0, 20,
			0, 0, 0, 0,
			0, 17, 0, 0,
			0, 0, 0, 0,
			0, 0, 0, 0,
		}},
		// Valid IPv4 header announcing TCP, but no TCP header follows.
		{"tcp_header_missing", append([]byte(nil), helloTCPSYN[:20]...)},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			_, err := ParseIPv4TCP(c.b)
			assert.Error(t, err)
		})
	}
}