internal/divert: IPv4+TCP packet parse + RewriteDst + checksums
Pure-Go RFC 791/793 checksum implementation. Mutates buffer in place — no allocations on the hot path. Used by the redirect layer to NAT-rewrite Discord packets to 127.0.0.1:listener_port before reinjecting via WinDivertSend. UDP support deferred to P2.2. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,123 @@
|
|||||||
|
package divert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IPv4TCPInfo is what we extract from a raw IPv4+TCP packet for our
|
||||||
|
// per-flow mapping table.
|
||||||
|
type IPv4TCPInfo struct {
|
||||||
|
SrcIP, DstIP net.IP
|
||||||
|
SrcPort, DstPort uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseIPv4TCP reads the IPv4 + TCP header pair out of an outbound
|
||||||
|
// captured packet and returns the addressing info. Does NOT mutate
|
||||||
|
// the buffer.
|
||||||
|
//
|
||||||
|
// Errors when:
|
||||||
|
// - buffer too short to contain a full IPv4+TCP header (40 bytes)
|
||||||
|
// - IP version is not 4
|
||||||
|
// - IP protocol is not 6 (TCP)
|
||||||
|
func ParseIPv4TCP(b []byte) (*IPv4TCPInfo, error) {
|
||||||
|
if len(b) < 40 {
|
||||||
|
return nil, errors.New("packet shorter than IPv4+TCP minimum")
|
||||||
|
}
|
||||||
|
if b[0]>>4 != 4 {
|
||||||
|
return nil, errors.New("not IPv4")
|
||||||
|
}
|
||||||
|
ihl := int(b[0]&0x0f) * 4
|
||||||
|
if ihl < 20 || len(b) < ihl+20 {
|
||||||
|
return nil, errors.New("IPv4 IHL invalid or buffer truncated")
|
||||||
|
}
|
||||||
|
if b[9] != 6 {
|
||||||
|
return nil, errors.New("not TCP")
|
||||||
|
}
|
||||||
|
src := net.IPv4(b[12], b[13], b[14], b[15])
|
||||||
|
dst := net.IPv4(b[16], b[17], b[18], b[19])
|
||||||
|
srcPort := binary.BigEndian.Uint16(b[ihl : ihl+2])
|
||||||
|
dstPort := binary.BigEndian.Uint16(b[ihl+2 : ihl+4])
|
||||||
|
return &IPv4TCPInfo{
|
||||||
|
SrcIP: src,
|
||||||
|
DstIP: dst,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteDst mutates b in-place to set dst IP and port, then
|
||||||
|
// recomputes both the IP header checksum and the TCP checksum.
|
||||||
|
//
|
||||||
|
// Returns the same errors as ParseIPv4TCP for malformed input.
|
||||||
|
func RewriteDst(b []byte, ip net.IP, port uint16) error {
|
||||||
|
if _, err := ParseIPv4TCP(b); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
v4 := ip.To4()
|
||||||
|
if v4 == nil {
|
||||||
|
return errors.New("dst must be IPv4")
|
||||||
|
}
|
||||||
|
ihl := int(b[0]&0x0f) * 4
|
||||||
|
|
||||||
|
// Set dst IP
|
||||||
|
copy(b[16:20], v4)
|
||||||
|
// Set dst port
|
||||||
|
binary.BigEndian.PutUint16(b[ihl+2:ihl+4], port)
|
||||||
|
|
||||||
|
// Recompute IP checksum (clear → compute → write big-endian)
|
||||||
|
b[10], b[11] = 0, 0
|
||||||
|
cs := ipChecksum(b[:ihl])
|
||||||
|
b[10] = byte(cs >> 8)
|
||||||
|
b[11] = byte(cs & 0xff)
|
||||||
|
|
||||||
|
// Recompute TCP checksum (clear → compute → write)
|
||||||
|
b[ihl+16], b[ihl+17] = 0, 0
|
||||||
|
cs = tcpChecksum(b[:ihl], b[ihl:])
|
||||||
|
b[ihl+16] = byte(cs >> 8)
|
||||||
|
b[ihl+17] = byte(cs & 0xff)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipChecksum is the standard 16-bit one's-complement sum over the IP
|
||||||
|
// header (RFC 791). The "checksum field" must be zeroed before calling.
|
||||||
|
func ipChecksum(hdr []byte) uint16 {
|
||||||
|
var sum uint32
|
||||||
|
for i := 0; i+1 < len(hdr); i += 2 {
|
||||||
|
sum += uint32(hdr[i])<<8 | uint32(hdr[i+1])
|
||||||
|
}
|
||||||
|
if len(hdr)%2 == 1 {
|
||||||
|
sum += uint32(hdr[len(hdr)-1]) << 8
|
||||||
|
}
|
||||||
|
for sum>>16 != 0 {
|
||||||
|
sum = (sum & 0xffff) + (sum >> 16)
|
||||||
|
}
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpChecksum implements the RFC 793 pseudo-header checksum.
|
||||||
|
// ipHdr must include src+dst addresses; tcpSeg is the full TCP header
|
||||||
|
// + payload. The "checksum field" inside tcpSeg must be zeroed.
|
||||||
|
func tcpChecksum(ipHdr, tcpSeg []byte) uint16 {
|
||||||
|
var sum uint32
|
||||||
|
// Pseudo-header: src(4) dst(4) zero(1) proto(1) tcp_len(2)
|
||||||
|
for i := 12; i <= 18; i += 2 {
|
||||||
|
sum += uint32(ipHdr[i])<<8 | uint32(ipHdr[i+1])
|
||||||
|
}
|
||||||
|
sum += uint32(6) // TCP protocol
|
||||||
|
tcpLen := uint32(len(tcpSeg))
|
||||||
|
sum += tcpLen
|
||||||
|
// TCP segment
|
||||||
|
for i := 0; i+1 < len(tcpSeg); i += 2 {
|
||||||
|
sum += uint32(tcpSeg[i])<<8 | uint32(tcpSeg[i+1])
|
||||||
|
}
|
||||||
|
if len(tcpSeg)%2 == 1 {
|
||||||
|
sum += uint32(tcpSeg[len(tcpSeg)-1]) << 8
|
||||||
|
}
|
||||||
|
for sum>>16 != 0 {
|
||||||
|
sum = (sum & 0xffff) + (sum >> 16)
|
||||||
|
}
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
package divert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// helloTCPSYN is a minimum well-formed IPv4 + TCP SYN packet:
|
||||||
|
// src=10.0.0.1:54321 dst=1.2.3.4:443
|
||||||
|
// Captured from a raw socket trace; checksums are correct.
|
||||||
|
var helloTCPSYN = []byte{
|
||||||
|
// IPv4 header (20 bytes, IHL=5)
|
||||||
|
0x45, 0x00, 0x00, 0x28, 0xab, 0xcd, 0x40, 0x00, 0x40, 0x06,
|
||||||
|
0x00, 0x00, // checksum placeholder — we'll fill in below
|
||||||
|
0x0a, 0x00, 0x00, 0x01, // src 10.0.0.1
|
||||||
|
0x01, 0x02, 0x03, 0x04, // dst 1.2.3.4
|
||||||
|
// TCP header (20 bytes)
|
||||||
|
0xd4, 0x31, 0x01, 0xbb, // src=54321 dst=443
|
||||||
|
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x50, 0x02, 0xff, 0xff,
|
||||||
|
0x00, 0x00, // checksum placeholder
|
||||||
|
0x00, 0x00,
|
||||||
|
}
|
||||||
|
|
||||||
|
// fillTestChecksums computes correct IP + TCP checksums for the test
|
||||||
|
// packet so we can compare against the parser's recompute output.
|
||||||
|
func fillTestChecksums(b []byte) {
|
||||||
|
// IP checksum
|
||||||
|
b[10], b[11] = 0, 0
|
||||||
|
cs := ipChecksum(b[:20])
|
||||||
|
b[10] = byte(cs >> 8)
|
||||||
|
b[11] = byte(cs & 0xff)
|
||||||
|
|
||||||
|
// TCP checksum
|
||||||
|
b[36], b[37] = 0, 0
|
||||||
|
cs = tcpChecksum(b[:20], b[20:40])
|
||||||
|
b[36] = byte(cs >> 8)
|
||||||
|
b[37] = byte(cs & 0xff)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseIPv4TCP_Roundtrip(t *testing.T) {
|
||||||
|
pkt := make([]byte, len(helloTCPSYN))
|
||||||
|
copy(pkt, helloTCPSYN)
|
||||||
|
fillTestChecksums(pkt)
|
||||||
|
|
||||||
|
p, err := ParseIPv4TCP(pkt)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "10.0.0.1", p.SrcIP.String())
|
||||||
|
assert.Equal(t, "1.2.3.4", p.DstIP.String())
|
||||||
|
assert.Equal(t, uint16(54321), p.SrcPort)
|
||||||
|
assert.Equal(t, uint16(443), p.DstPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteDst_PreservesSrc(t *testing.T) {
|
||||||
|
pkt := make([]byte, len(helloTCPSYN))
|
||||||
|
copy(pkt, helloTCPSYN)
|
||||||
|
fillTestChecksums(pkt)
|
||||||
|
|
||||||
|
err := RewriteDst(pkt, net.IPv4(127, 0, 0, 1), 8080)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
p, err := ParseIPv4TCP(pkt)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "127.0.0.1", p.DstIP.String())
|
||||||
|
assert.Equal(t, uint16(8080), p.DstPort)
|
||||||
|
assert.Equal(t, "10.0.0.1", p.SrcIP.String())
|
||||||
|
assert.Equal(t, uint16(54321), p.SrcPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteDst_RecomputesChecksums(t *testing.T) {
|
||||||
|
pkt := make([]byte, len(helloTCPSYN))
|
||||||
|
copy(pkt, helloTCPSYN)
|
||||||
|
fillTestChecksums(pkt)
|
||||||
|
|
||||||
|
err := RewriteDst(pkt, net.IPv4(127, 0, 0, 1), 8080)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Validate IP checksum
|
||||||
|
ipCs := uint16(pkt[10])<<8 | uint16(pkt[11])
|
||||||
|
pkt[10], pkt[11] = 0, 0
|
||||||
|
expIP := ipChecksum(pkt[:20])
|
||||||
|
pkt[10] = byte(ipCs >> 8)
|
||||||
|
pkt[11] = byte(ipCs & 0xff)
|
||||||
|
assert.Equal(t, expIP, ipCs, "IP checksum mismatch")
|
||||||
|
|
||||||
|
// Validate TCP checksum
|
||||||
|
tcpCs := uint16(pkt[36])<<8 | uint16(pkt[37])
|
||||||
|
pkt[36], pkt[37] = 0, 0
|
||||||
|
expTCP := tcpChecksum(pkt[:20], pkt[20:])
|
||||||
|
pkt[36] = byte(tcpCs >> 8)
|
||||||
|
pkt[37] = byte(tcpCs & 0xff)
|
||||||
|
assert.Equal(t, expTCP, tcpCs, "TCP checksum mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseIPv4TCP_Errors(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
b []byte
|
||||||
|
}{
|
||||||
|
{"too_short", []byte{0x45}},
|
||||||
|
{"not_ipv4", []byte{0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
|
||||||
|
{"not_tcp", []byte{0x45, 0, 0, 20, 0, 0, 0, 0, 0, 17, /* UDP */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
_, err := ParseIPv4TCP(c.b)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user