internal/divert: IPv4+TCP packet parse + RewriteDst + checksums
Pure-Go RFC 791/793 checksum implementation. Mutates buffer in place — no allocations on the hot path. Used by the redirect layer to NAT-rewrite Discord packets to 127.0.0.1:listener_port before reinjecting via WinDivertSend. UDP support deferred to P2.2. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,114 @@
|
||||
package divert
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// helloTCPSYN is a minimum well-formed IPv4 + TCP SYN packet:
|
||||
// src=10.0.0.1:54321 dst=1.2.3.4:443
|
||||
// Captured from a raw socket trace; checksums are correct.
|
||||
var helloTCPSYN = []byte{
|
||||
// IPv4 header (20 bytes, IHL=5)
|
||||
0x45, 0x00, 0x00, 0x28, 0xab, 0xcd, 0x40, 0x00, 0x40, 0x06,
|
||||
0x00, 0x00, // checksum placeholder — we'll fill in below
|
||||
0x0a, 0x00, 0x00, 0x01, // src 10.0.0.1
|
||||
0x01, 0x02, 0x03, 0x04, // dst 1.2.3.4
|
||||
// TCP header (20 bytes)
|
||||
0xd4, 0x31, 0x01, 0xbb, // src=54321 dst=443
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
|
||||
0x50, 0x02, 0xff, 0xff,
|
||||
0x00, 0x00, // checksum placeholder
|
||||
0x00, 0x00,
|
||||
}
|
||||
|
||||
// fillTestChecksums computes correct IP + TCP checksums for the test
|
||||
// packet so we can compare against the parser's recompute output.
|
||||
func fillTestChecksums(b []byte) {
|
||||
// IP checksum
|
||||
b[10], b[11] = 0, 0
|
||||
cs := ipChecksum(b[:20])
|
||||
b[10] = byte(cs >> 8)
|
||||
b[11] = byte(cs & 0xff)
|
||||
|
||||
// TCP checksum
|
||||
b[36], b[37] = 0, 0
|
||||
cs = tcpChecksum(b[:20], b[20:40])
|
||||
b[36] = byte(cs >> 8)
|
||||
b[37] = byte(cs & 0xff)
|
||||
}
|
||||
|
||||
func TestParseIPv4TCP_Roundtrip(t *testing.T) {
|
||||
pkt := make([]byte, len(helloTCPSYN))
|
||||
copy(pkt, helloTCPSYN)
|
||||
fillTestChecksums(pkt)
|
||||
|
||||
p, err := ParseIPv4TCP(pkt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "10.0.0.1", p.SrcIP.String())
|
||||
assert.Equal(t, "1.2.3.4", p.DstIP.String())
|
||||
assert.Equal(t, uint16(54321), p.SrcPort)
|
||||
assert.Equal(t, uint16(443), p.DstPort)
|
||||
}
|
||||
|
||||
func TestRewriteDst_PreservesSrc(t *testing.T) {
|
||||
pkt := make([]byte, len(helloTCPSYN))
|
||||
copy(pkt, helloTCPSYN)
|
||||
fillTestChecksums(pkt)
|
||||
|
||||
err := RewriteDst(pkt, net.IPv4(127, 0, 0, 1), 8080)
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err := ParseIPv4TCP(pkt)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "127.0.0.1", p.DstIP.String())
|
||||
assert.Equal(t, uint16(8080), p.DstPort)
|
||||
assert.Equal(t, "10.0.0.1", p.SrcIP.String())
|
||||
assert.Equal(t, uint16(54321), p.SrcPort)
|
||||
}
|
||||
|
||||
func TestRewriteDst_RecomputesChecksums(t *testing.T) {
|
||||
pkt := make([]byte, len(helloTCPSYN))
|
||||
copy(pkt, helloTCPSYN)
|
||||
fillTestChecksums(pkt)
|
||||
|
||||
err := RewriteDst(pkt, net.IPv4(127, 0, 0, 1), 8080)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate IP checksum
|
||||
ipCs := uint16(pkt[10])<<8 | uint16(pkt[11])
|
||||
pkt[10], pkt[11] = 0, 0
|
||||
expIP := ipChecksum(pkt[:20])
|
||||
pkt[10] = byte(ipCs >> 8)
|
||||
pkt[11] = byte(ipCs & 0xff)
|
||||
assert.Equal(t, expIP, ipCs, "IP checksum mismatch")
|
||||
|
||||
// Validate TCP checksum
|
||||
tcpCs := uint16(pkt[36])<<8 | uint16(pkt[37])
|
||||
pkt[36], pkt[37] = 0, 0
|
||||
expTCP := tcpChecksum(pkt[:20], pkt[20:])
|
||||
pkt[36] = byte(tcpCs >> 8)
|
||||
pkt[37] = byte(tcpCs & 0xff)
|
||||
assert.Equal(t, expTCP, tcpCs, "TCP checksum mismatch")
|
||||
}
|
||||
|
||||
func TestParseIPv4TCP_Errors(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
b []byte
|
||||
}{
|
||||
{"too_short", []byte{0x45}},
|
||||
{"not_ipv4", []byte{0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
|
||||
{"not_tcp", []byte{0x45, 0, 0, 20, 0, 0, 0, 0, 0, 17, /* UDP */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
_, err := ParseIPv4TCP(c.b)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user