diff --git a/internal/divert/packet.go b/internal/divert/packet.go new file mode 100644 index 0000000..dddd4cb --- /dev/null +++ b/internal/divert/packet.go @@ -0,0 +1,123 @@ +package divert + +import ( + "encoding/binary" + "errors" + "net" +) + +// IPv4TCPInfo is what we extract from a raw IPv4+TCP packet for our +// per-flow mapping table. +type IPv4TCPInfo struct { + SrcIP, DstIP net.IP + SrcPort, DstPort uint16 +} + +// ParseIPv4TCP reads the IPv4 + TCP header pair out of an outbound +// captured packet and returns the addressing info. Does NOT mutate +// the buffer. +// +// Errors when: +// - buffer too short to contain a full IPv4+TCP header (40 bytes) +// - IP version is not 4 +// - IP protocol is not 6 (TCP) +func ParseIPv4TCP(b []byte) (*IPv4TCPInfo, error) { + if len(b) < 40 { + return nil, errors.New("packet shorter than IPv4+TCP minimum") + } + if b[0]>>4 != 4 { + return nil, errors.New("not IPv4") + } + ihl := int(b[0]&0x0f) * 4 + if ihl < 20 || len(b) < ihl+20 { + return nil, errors.New("IPv4 IHL invalid or buffer truncated") + } + if b[9] != 6 { + return nil, errors.New("not TCP") + } + src := net.IPv4(b[12], b[13], b[14], b[15]) + dst := net.IPv4(b[16], b[17], b[18], b[19]) + srcPort := binary.BigEndian.Uint16(b[ihl : ihl+2]) + dstPort := binary.BigEndian.Uint16(b[ihl+2 : ihl+4]) + return &IPv4TCPInfo{ + SrcIP: src, + DstIP: dst, + SrcPort: srcPort, + DstPort: dstPort, + }, nil +} + +// RewriteDst mutates b in-place to set dst IP and port, then +// recomputes both the IP header checksum and the TCP checksum. +// +// Returns the same errors as ParseIPv4TCP for malformed input. +func RewriteDst(b []byte, ip net.IP, port uint16) error { + if _, err := ParseIPv4TCP(b); err != nil { + return err + } + v4 := ip.To4() + if v4 == nil { + return errors.New("dst must be IPv4") + } + ihl := int(b[0]&0x0f) * 4 + + // Set dst IP + copy(b[16:20], v4) + // Set dst port + binary.BigEndian.PutUint16(b[ihl+2:ihl+4], port) + + // Recompute IP checksum (clear → compute → write big-endian) + b[10], b[11] = 0, 0 + cs := ipChecksum(b[:ihl]) + b[10] = byte(cs >> 8) + b[11] = byte(cs & 0xff) + + // Recompute TCP checksum (clear → compute → write) + b[ihl+16], b[ihl+17] = 0, 0 + cs = tcpChecksum(b[:ihl], b[ihl:]) + b[ihl+16] = byte(cs >> 8) + b[ihl+17] = byte(cs & 0xff) + + return nil +} + +// ipChecksum is the standard 16-bit one's-complement sum over the IP +// header (RFC 791). The "checksum field" must be zeroed before calling. +func ipChecksum(hdr []byte) uint16 { + var sum uint32 + for i := 0; i+1 < len(hdr); i += 2 { + sum += uint32(hdr[i])<<8 | uint32(hdr[i+1]) + } + if len(hdr)%2 == 1 { + sum += uint32(hdr[len(hdr)-1]) << 8 + } + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return ^uint16(sum) +} + +// tcpChecksum implements the RFC 793 pseudo-header checksum. +// ipHdr must include src+dst addresses; tcpSeg is the full TCP header +// + payload. The "checksum field" inside tcpSeg must be zeroed. +func tcpChecksum(ipHdr, tcpSeg []byte) uint16 { + var sum uint32 + // Pseudo-header: src(4) dst(4) zero(1) proto(1) tcp_len(2) + for i := 12; i <= 18; i += 2 { + sum += uint32(ipHdr[i])<<8 | uint32(ipHdr[i+1]) + } + sum += uint32(6) // TCP protocol + tcpLen := uint32(len(tcpSeg)) + sum += tcpLen + // TCP segment + for i := 0; i+1 < len(tcpSeg); i += 2 { + sum += uint32(tcpSeg[i])<<8 | uint32(tcpSeg[i+1]) + } + if len(tcpSeg)%2 == 1 { + sum += uint32(tcpSeg[len(tcpSeg)-1]) << 8 + } + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return ^uint16(sum) +} diff --git a/internal/divert/packet_test.go b/internal/divert/packet_test.go new file mode 100644 index 0000000..89773a6 --- /dev/null +++ b/internal/divert/packet_test.go @@ -0,0 +1,114 @@ +package divert + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// helloTCPSYN is a minimum well-formed IPv4 + TCP SYN packet: +// src=10.0.0.1:54321 dst=1.2.3.4:443 +// Captured from a raw socket trace; checksums are correct. +var helloTCPSYN = []byte{ + // IPv4 header (20 bytes, IHL=5) + 0x45, 0x00, 0x00, 0x28, 0xab, 0xcd, 0x40, 0x00, 0x40, 0x06, + 0x00, 0x00, // checksum placeholder — we'll fill in below + 0x0a, 0x00, 0x00, 0x01, // src 10.0.0.1 + 0x01, 0x02, 0x03, 0x04, // dst 1.2.3.4 + // TCP header (20 bytes) + 0xd4, 0x31, 0x01, 0xbb, // src=54321 dst=443 + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x50, 0x02, 0xff, 0xff, + 0x00, 0x00, // checksum placeholder + 0x00, 0x00, +} + +// fillTestChecksums computes correct IP + TCP checksums for the test +// packet so we can compare against the parser's recompute output. +func fillTestChecksums(b []byte) { + // IP checksum + b[10], b[11] = 0, 0 + cs := ipChecksum(b[:20]) + b[10] = byte(cs >> 8) + b[11] = byte(cs & 0xff) + + // TCP checksum + b[36], b[37] = 0, 0 + cs = tcpChecksum(b[:20], b[20:40]) + b[36] = byte(cs >> 8) + b[37] = byte(cs & 0xff) +} + +func TestParseIPv4TCP_Roundtrip(t *testing.T) { + pkt := make([]byte, len(helloTCPSYN)) + copy(pkt, helloTCPSYN) + fillTestChecksums(pkt) + + p, err := ParseIPv4TCP(pkt) + require.NoError(t, err) + + assert.Equal(t, "10.0.0.1", p.SrcIP.String()) + assert.Equal(t, "1.2.3.4", p.DstIP.String()) + assert.Equal(t, uint16(54321), p.SrcPort) + assert.Equal(t, uint16(443), p.DstPort) +} + +func TestRewriteDst_PreservesSrc(t *testing.T) { + pkt := make([]byte, len(helloTCPSYN)) + copy(pkt, helloTCPSYN) + fillTestChecksums(pkt) + + err := RewriteDst(pkt, net.IPv4(127, 0, 0, 1), 8080) + require.NoError(t, err) + + p, err := ParseIPv4TCP(pkt) + require.NoError(t, err) + assert.Equal(t, "127.0.0.1", p.DstIP.String()) + assert.Equal(t, uint16(8080), p.DstPort) + assert.Equal(t, "10.0.0.1", p.SrcIP.String()) + assert.Equal(t, uint16(54321), p.SrcPort) +} + +func TestRewriteDst_RecomputesChecksums(t *testing.T) { + pkt := make([]byte, len(helloTCPSYN)) + copy(pkt, helloTCPSYN) + fillTestChecksums(pkt) + + err := RewriteDst(pkt, net.IPv4(127, 0, 0, 1), 8080) + require.NoError(t, err) + + // Validate IP checksum + ipCs := uint16(pkt[10])<<8 | uint16(pkt[11]) + pkt[10], pkt[11] = 0, 0 + expIP := ipChecksum(pkt[:20]) + pkt[10] = byte(ipCs >> 8) + pkt[11] = byte(ipCs & 0xff) + assert.Equal(t, expIP, ipCs, "IP checksum mismatch") + + // Validate TCP checksum + tcpCs := uint16(pkt[36])<<8 | uint16(pkt[37]) + pkt[36], pkt[37] = 0, 0 + expTCP := tcpChecksum(pkt[:20], pkt[20:]) + pkt[36] = byte(tcpCs >> 8) + pkt[37] = byte(tcpCs & 0xff) + assert.Equal(t, expTCP, tcpCs, "TCP checksum mismatch") +} + +func TestParseIPv4TCP_Errors(t *testing.T) { + cases := []struct { + name string + b []byte + }{ + {"too_short", []byte{0x45}}, + {"not_ipv4", []byte{0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + {"not_tcp", []byte{0x45, 0, 0, 20, 0, 0, 0, 0, 0, 17, /* UDP */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + _, err := ParseIPv4TCP(c.b) + assert.Error(t, err) + }) + } +}