package divert import ( "encoding/binary" "errors" "net" ) // IPv4TCPInfo is what we extract from a raw IPv4+TCP packet for our // per-flow mapping table. type IPv4TCPInfo struct { SrcIP, DstIP net.IP SrcPort, DstPort uint16 } // ParseIPv4TCP reads the IPv4 + TCP header pair out of an outbound // captured packet and returns the addressing info. Does NOT mutate // the buffer. // // Errors when: // - buffer too short to contain a full IPv4+TCP header (40 bytes) // - IP version is not 4 // - IP protocol is not 6 (TCP) func ParseIPv4TCP(b []byte) (*IPv4TCPInfo, error) { if len(b) < 40 { return nil, errors.New("packet shorter than IPv4+TCP minimum") } if b[0]>>4 != 4 { return nil, errors.New("not IPv4") } ihl := int(b[0]&0x0f) * 4 if ihl < 20 || len(b) < ihl+20 { return nil, errors.New("IPv4 IHL invalid or buffer truncated") } if b[9] != 6 { return nil, errors.New("not TCP") } src := net.IPv4(b[12], b[13], b[14], b[15]) dst := net.IPv4(b[16], b[17], b[18], b[19]) srcPort := binary.BigEndian.Uint16(b[ihl : ihl+2]) dstPort := binary.BigEndian.Uint16(b[ihl+2 : ihl+4]) return &IPv4TCPInfo{ SrcIP: src, DstIP: dst, SrcPort: srcPort, DstPort: dstPort, }, nil } // RewriteDst mutates b in-place to set dst IP and port, then // recomputes both the IP header checksum and the TCP checksum. // // Returns the same errors as ParseIPv4TCP for malformed input. func RewriteDst(b []byte, ip net.IP, port uint16) error { if _, err := ParseIPv4TCP(b); err != nil { return err } v4 := ip.To4() if v4 == nil { return errors.New("dst must be IPv4") } ihl := int(b[0]&0x0f) * 4 // Set dst IP copy(b[16:20], v4) // Set dst port binary.BigEndian.PutUint16(b[ihl+2:ihl+4], port) // Recompute IP checksum (clear → compute → write big-endian) b[10], b[11] = 0, 0 cs := ipChecksum(b[:ihl]) b[10] = byte(cs >> 8) b[11] = byte(cs & 0xff) // Recompute TCP checksum (clear → compute → write) b[ihl+16], b[ihl+17] = 0, 0 cs = tcpChecksum(b[:ihl], b[ihl:]) b[ihl+16] = byte(cs >> 8) b[ihl+17] = byte(cs & 0xff) return nil } // ipChecksum is the standard 16-bit one's-complement sum over the IP // header (RFC 791). The "checksum field" must be zeroed before calling. func ipChecksum(hdr []byte) uint16 { var sum uint32 for i := 0; i+1 < len(hdr); i += 2 { sum += uint32(hdr[i])<<8 | uint32(hdr[i+1]) } if len(hdr)%2 == 1 { sum += uint32(hdr[len(hdr)-1]) << 8 } for sum>>16 != 0 { sum = (sum & 0xffff) + (sum >> 16) } return ^uint16(sum) } // tcpChecksum implements the RFC 793 pseudo-header checksum. // ipHdr must include src+dst addresses; tcpSeg is the full TCP header // + payload. The "checksum field" inside tcpSeg must be zeroed. func tcpChecksum(ipHdr, tcpSeg []byte) uint16 { var sum uint32 // Pseudo-header: src(4) dst(4) zero(1) proto(1) tcp_len(2) for i := 12; i <= 18; i += 2 { sum += uint32(ipHdr[i])<<8 | uint32(ipHdr[i+1]) } sum += uint32(6) // TCP protocol tcpLen := uint32(len(tcpSeg)) sum += tcpLen // TCP segment for i := 0; i+1 < len(tcpSeg); i += 2 { sum += uint32(tcpSeg[i])<<8 | uint32(tcpSeg[i+1]) } if len(tcpSeg)%2 == 1 { sum += uint32(tcpSeg[len(tcpSeg)-1]) << 8 } for sum>>16 != 0 { sum = (sum & 0xffff) + (sum >> 16) } return ^uint16(sum) }