package main

import (
	"context"
	"net"
	"strings"
	"testing"
	"time"
)

// TestSanitizeCorrelationID checks that a well-formed UUID passes through
// unchanged, while a malformed or empty correlation ID is replaced, and that
// two replacements are not identical.
func TestSanitizeCorrelationID(t *testing.T) {
	valid := sanitizeCorrelationID("550e8400-e29b-41d4-a716-446655440000")
	if valid != "550e8400-e29b-41d4-a716-446655440000" {
		t.Fatalf("expected valid UUID to pass through, got %s", valid)
	}

	invalid := sanitizeCorrelationID("not-a-uuid")
	if invalid == "not-a-uuid" {
		t.Fatalf("expected invalid UUID to be replaced")
	}

	empty := sanitizeCorrelationID("")
	if empty == "" {
		t.Fatalf("expected empty string to be replaced")
	}

	// Two independent replacements colliding would suggest a constant
	// fallback rather than a freshly generated ID.
	if invalid == empty {
		t.Fatalf("expected regenerated UUIDs to differ")
	}
}

// TestValidateNodeName is a table test over validateNodeName covering
// hostnames, IPv4, IPv6 (bare and bracketed), and a range of inputs that must
// be rejected: zone identifiers, ports, unbalanced brackets, whitespace,
// overlong names, path-like strings, and shell metacharacters.
func TestValidateNodeName(t *testing.T) {
	cases := []struct {
		name    string
		wantErr bool
		desc    string
	}{
		{name: "node-1", wantErr: false, desc: "alphanumeric"},
		{name: "example.com", wantErr: false, desc: "dns hostname"},
		{name: "1.2.3.4", wantErr: false, desc: "ipv4"},
		{name: "2001:db8::1", wantErr: false, desc: "ipv6 compressed"},
		{name: "[2001:db8::10]", wantErr: false, desc: "ipv6 bracketed"},
		{name: "::1", wantErr: false, desc: "ipv6 loopback"},
		{name: "::", wantErr: false, desc: "ipv6 unspecified"},
		{name: "::ffff:192.0.2.1", wantErr: false, desc: "ipv4-mapped ipv6 dual stack"},
		{name: "[::1]", wantErr: false, desc: "ipv6 loopback bracketed"},
		{name: "fe80::1%eth0", wantErr: true, desc: "ipv6 zone identifier"},
		{name: "[fe80::1%eth0]", wantErr: true, desc: "ipv6 zone identifier bracketed"},
		{name: "[2001:db8::1]:22", wantErr: true, desc: "ipv6 with port suffix"},
		{name: "[2001:db8::1", wantErr: true, desc: "missing closing bracket"},
		{name: "2001:db8::1]", wantErr: true, desc: "missing opening bracket"},
		{name: "bad host", wantErr: true, desc: "whitespace disallowed"},
		{name: "-leadinghyphen", wantErr: true, desc: "leading hyphen disallowed"},
		{name: "example.com:22", wantErr: true, desc: "dns name with port"},
		{name: "", wantErr: true, desc: "empty string"},
		{name: "example_com", wantErr: false, desc: "underscore"},
		{name: "NODE123", wantErr: false, desc: "uppercase"},
		{name: strings.Repeat("a", 64), wantErr: false, desc: "64 chars"},
		{name: strings.Repeat("a", 65), wantErr: true, desc: "65 chars"},
		{name: "senso\u200Brs", wantErr: true, desc: "zero-width space"},
		{name: "node\\name", wantErr: true, desc: "backslash"},
		{name: "/etc/passwd", wantErr: true, desc: "absolute path"},
		{name: "node\x00", wantErr: true, desc: "null byte"},
		{name: "example.com;rm", wantErr: true, desc: "semicolon"},
		{name: "node$(rm)", wantErr: true, desc: "subshell"},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy for the closure on Go versions before 1.22
		name := tc.desc
		if name == "" {
			name = tc.name
		}
		t.Run(name, func(t *testing.T) {
			err := validateNodeName(tc.name)
			if tc.wantErr && err == nil {
				t.Fatalf("expected error validating %q", tc.name)
			}
			if !tc.wantErr && err != nil {
				t.Fatalf("unexpected error for %q: %v", tc.name, err)
			}
		})
	}
}

// stubResolver is a canned resolver for nodeValidator tests. LookupIP returns
// err when it is non-nil, otherwise the fixed ips slice, regardless of host.
type stubResolver struct {
	ips []net.IP
	err error
}

// LookupIP ignores ctx and host and returns the configured result.
func (s stubResolver) LookupIP(ctx context.Context, host string) ([]net.IP, error) {
	if s.err != nil {
		return nil, s.err
	}
	return s.ips, nil
}

// TestNodeValidatorAllowlistHost checks that, with an allow-list configured,
// a listed hostname is accepted and an unlisted one is rejected.
func TestNodeValidatorAllowlistHost(t *testing.T) {
	v := &nodeValidator{
		allowHosts:   map[string]struct{}{"node-1": {}},
		hasAllowlist: true,
		resolver:     stubResolver{},
	}

	if err := v.Validate(context.Background(), "node-1"); err != nil {
		t.Fatalf("expected node-1 to be permitted, got error: %v", err)
	}
	if err := v.Validate(context.Background(), "node-2"); err == nil {
		t.Fatalf("expected node-2 to be rejected without allow-list entry")
	}
}

// TestNodeValidatorAllowlistCIDRWithLookup checks that a hostname is accepted
// when the resolver maps it to an address inside an allow-listed CIDR.
func TestNodeValidatorAllowlistCIDRWithLookup(t *testing.T) {
	_, network, err := net.ParseCIDR("10.0.0.0/24")
	if err != nil {
		t.Fatalf("parsing test CIDR: %v", err)
	}
	v := &nodeValidator{
		allowHosts:   make(map[string]struct{}),
		allowCIDRs:   []*net.IPNet{network},
		hasAllowlist: true,
		resolver: stubResolver{
			ips: []net.IP{net.ParseIP("10.0.0.5")},
		},
	}

	if err := v.Validate(context.Background(), "worker.local"); err != nil {
		t.Fatalf("expected worker.local to resolve into allowed CIDR: %v", err)
	}
}

// TestNodeValidatorClusterCaching drives a fake clock to check that cluster
// membership is fetched once, served from cache while within the TTL, and
// re-fetched after the TTL has elapsed.
func TestNodeValidatorClusterCaching(t *testing.T) {
	current := time.Now()
	fetches := 0
	v := &nodeValidator{
		clusterEnabled: true,
		clusterFetcher: func() ([]string, error) {
			fetches++
			return []string{"10.0.0.9"}, nil
		},
		cacheTTL: nodeValidatorCacheTTL,
		clock:    func() time.Time { return current },
	}

	if err := v.Validate(context.Background(), "10.0.0.9"); err != nil {
		t.Fatalf("expected node to be allowed via cluster membership: %v", err)
	}
	if fetches != 1 {
		t.Fatalf("expected initial cluster fetch, got %d", fetches)
	}

	// Step forward by half the TTL instead of a fixed duration, so the
	// cache-hit expectation holds regardless of the configured TTL value.
	current = current.Add(nodeValidatorCacheTTL / 2)
	if err := v.Validate(context.Background(), "10.0.0.9"); err != nil {
		t.Fatalf("expected cached cluster membership to allow node: %v", err)
	}
	if fetches != 1 {
		t.Fatalf("expected cache hit to avoid new fetch, got %d fetches", fetches)
	}

	current = current.Add(nodeValidatorCacheTTL + time.Second)
	if err := v.Validate(context.Background(), "10.0.0.9"); err != nil {
		t.Fatalf("expected refreshed cluster membership to allow node: %v", err)
	}
	if fetches != 2 {
		t.Fatalf("expected cache expiry to trigger new fetch, got %d", fetches)
	}
}

// TestNodeValidatorClusterResolvesHostIPs checks that hostnames returned as
// cluster members are resolved, so a node identified by one of the resolved
// IPs is accepted.
func TestNodeValidatorClusterResolvesHostIPs(t *testing.T) {
	v := &nodeValidator{
		clusterEnabled: true,
		clusterFetcher: func() ([]string, error) {
			return []string{"worker.local"}, nil
		},
		resolver: stubResolver{
			ips: []net.IP{net.ParseIP("10.0.0.5")},
		},
	}

	if err := v.Validate(context.Background(), "10.0.0.5"); err != nil {
		t.Fatalf("expected cluster hostname resolution to permit node: %v", err)
	}
}

// TestNodeValidatorStrictNoSources checks that strict mode with no
// allow-list and no cluster source rejects nodes.
func TestNodeValidatorStrictNoSources(t *testing.T) {
	v := &nodeValidator{
		strict: true,
	}

	if err := v.Validate(context.Background(), "node-1"); err == nil {
		t.Fatalf("expected strict mode without sources to reject nodes")
	}
}

// TestStripNodeDelimiters checks that a single pair of enclosing square
// brackets is removed, while unbalanced, empty, or absent brackets leave the
// input untouched.
func TestStripNodeDelimiters(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{name: "ipv6 with brackets", input: "[2001:db8::1]", expected: "2001:db8::1"},
		{name: "ipv6 loopback bracketed", input: "[::1]", expected: "::1"},
		{name: "ipv4 no brackets", input: "192.168.1.1", expected: "192.168.1.1"},
		{name: "hostname no brackets", input: "node-1.example.com", expected: "node-1.example.com"},
		{name: "empty string", input: "", expected: ""},
		{name: "only opening bracket", input: "[2001:db8::1", expected: "[2001:db8::1"},
		{name: "only closing bracket", input: "2001:db8::1]", expected: "2001:db8::1]"},
		{name: "brackets with single char", input: "[a]", expected: "a"},
		{name: "empty brackets", input: "[]", expected: "[]"},
		{name: "single bracket char", input: "[", expected: "["},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			result := stripNodeDelimiters(tc.input)
			if result != tc.expected {
				t.Errorf("stripNodeDelimiters(%q) = %q, want %q", tc.input, result, tc.expected)
			}
		})
	}
}

// TestParseNodeIP checks that IPv4/IPv6 literals (optionally bracketed or
// surrounded by whitespace) parse to the expected address, and that
// hostnames and malformed addresses yield nil.
func TestParseNodeIP(t *testing.T) {
	tests := []struct {
		name       string
		input      string
		expectNil  bool
		expectedIP string
	}{
		{name: "ipv4", input: "192.168.1.1", expectNil: false, expectedIP: "192.168.1.1"},
		{name: "ipv4 with whitespace", input: " 10.0.0.1 ", expectNil: false, expectedIP: "10.0.0.1"},
		{name: "ipv6", input: "2001:db8::1", expectNil: false, expectedIP: "2001:db8::1"},
		{name: "ipv6 bracketed", input: "[2001:db8::1]", expectNil: false, expectedIP: "2001:db8::1"},
		{name: "ipv6 loopback", input: "::1", expectNil: false, expectedIP: "::1"},
		{name: "ipv6 loopback bracketed", input: "[::1]", expectNil: false, expectedIP: "::1"},
		{name: "hostname", input: "node-1.example.com", expectNil: true},
		{name: "empty string", input: "", expectNil: true},
		{name: "invalid ip", input: "999.999.999.999", expectNil: true},
		{name: "partial ipv4", input: "192.168.1", expectNil: true},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			result := parseNodeIP(tc.input)
			switch {
			case tc.expectNil:
				if result != nil {
					t.Errorf("parseNodeIP(%q) = %v, want nil", tc.input, result)
				}
			case result == nil:
				t.Errorf("parseNodeIP(%q) = nil, want %s", tc.input, tc.expectedIP)
			case result.String() != tc.expectedIP:
				t.Errorf("parseNodeIP(%q) = %v, want %s", tc.input, result, tc.expectedIP)
			}
		})
	}
}

// TestNormalizeAllowlistEntry checks that allow-list entries are trimmed,
// IP literals are canonicalized (brackets removed, IPv6 compressed), and
// hostnames are lower-cased.
func TestNormalizeAllowlistEntry(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{name: "ipv4", input: "192.168.1.1", expected: "192.168.1.1"},
		{name: "ipv4 with whitespace", input: " 10.0.0.1 ", expected: "10.0.0.1"},
		{name: "ipv6", input: "2001:db8::1", expected: "2001:db8::1"},
		{name: "ipv6 bracketed", input: "[2001:db8::1]", expected: "2001:db8::1"},
		{name: "hostname lowercase", input: "node-1.example.com", expected: "node-1.example.com"},
		{name: "hostname uppercase normalized", input: "NODE-1.EXAMPLE.COM", expected: "node-1.example.com"},
		{name: "hostname mixed case", input: "Node-1.Example.Com", expected: "node-1.example.com"},
		{name: "empty string", input: "", expected: ""},
		{name: "whitespace only", input: "   ", expected: ""},
		{name: "ipv6 loopback", input: "::1", expected: "::1"},
		{name: "ipv6 full form normalized", input: "2001:0db8:0000:0000:0000:0000:0000:0001", expected: "2001:db8::1"},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			result := normalizeAllowlistEntry(tc.input)
			if result != tc.expected {
				t.Errorf("normalizeAllowlistEntry(%q) = %q, want %q", tc.input, result, tc.expected)
			}
		})
	}
}

// TestIPAllowed checks ipAllowed's matching: an IP is allowed when its
// string form is a key in hosts or it falls inside any of cidrs, and a nil
// IP is never allowed.
func TestIPAllowed(t *testing.T) {
	// mustCIDR parses a test CIDR literal, failing the test on a typo rather
	// than silently yielding a nil network.
	mustCIDR := func(s string) *net.IPNet {
		t.Helper()
		_, network, err := net.ParseCIDR(s)
		if err != nil {
			t.Fatalf("parsing test CIDR %q: %v", s, err)
		}
		return network
	}

	// Setup test CIDRs
	cidr10 := mustCIDR("10.0.0.0/8")
	cidr192 := mustCIDR("192.168.1.0/24")
	cidr172 := mustCIDR("172.16.0.0/12")

	tests := []struct {
		name     string
		ip       net.IP
		hosts    map[string]struct{}
		cidrs    []*net.IPNet
		expected bool
	}{
		{
			name:     "nil ip",
			ip:       nil,
			hosts:    map[string]struct{}{"10.0.0.1": {}},
			cidrs:    nil,
			expected: false,
		},
		{
			name:     "ip in hosts map",
			ip:       net.ParseIP("192.168.1.100"),
			hosts:    map[string]struct{}{"192.168.1.100": {}},
			cidrs:    nil,
			expected: true,
		},
		{
			name:     "ip not in hosts map",
			ip:       net.ParseIP("192.168.1.100"),
			hosts:    map[string]struct{}{"192.168.1.200": {}},
			cidrs:    nil,
			expected: false,
		},
		{
			name:     "ip in cidr range",
			ip:       net.ParseIP("10.1.2.3"),
			hosts:    nil,
			cidrs:    []*net.IPNet{cidr10},
			expected: true,
		},
		{
			name:     "ip not in cidr range",
			ip:       net.ParseIP("11.0.0.1"),
			hosts:    nil,
			cidrs:    []*net.IPNet{cidr10},
			expected: false,
		},
		{
			name:     "ip in second cidr",
			ip:       net.ParseIP("192.168.1.50"),
			hosts:    nil,
			cidrs:    []*net.IPNet{cidr10, cidr192},
			expected: true,
		},
		{
			name:     "ip matches hosts not cidrs",
			ip:       net.ParseIP("8.8.8.8"),
			hosts:    map[string]struct{}{"8.8.8.8": {}},
			cidrs:    []*net.IPNet{cidr10},
			expected: true,
		},
		{
			name:     "ip matches cidrs not hosts",
			ip:       net.ParseIP("172.20.0.1"),
			hosts:    map[string]struct{}{"8.8.8.8": {}},
			cidrs:    []*net.IPNet{cidr172},
			expected: true,
		},
		{
			name:     "empty hosts and cidrs",
			ip:       net.ParseIP("192.168.1.1"),
			hosts:    nil,
			cidrs:    nil,
			expected: false,
		},
		{
			name:     "ipv6 in hosts",
			ip:       net.ParseIP("2001:db8::1"),
			hosts:    map[string]struct{}{"2001:db8::1": {}},
			cidrs:    nil,
			expected: true,
		},
		{
			name:     "nil hosts map with ip match in cidr",
			ip:       net.ParseIP("10.255.255.255"),
			hosts:    nil,
			cidrs:    []*net.IPNet{cidr10},
			expected: true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			result := ipAllowed(tc.ip, tc.hosts, tc.cidrs)
			if result != tc.expected {
				t.Errorf("ipAllowed(%v, hosts, cidrs) = %v, want %v", tc.ip, result, tc.expected)
			}
		})
	}
}