Pulse/internal/api/oidc_handlers_test.go

871 lines
18 KiB
Go

package api
import (
"crypto/tls"
"net/http"
"testing"
)
func TestSanitizeOIDCReturnTo(t *testing.T) {
t.Parallel()
tests := []struct {
name string
raw string
want string
}{
// Valid paths
{
name: "simple path",
raw: "/dashboard",
want: "/dashboard",
},
{
name: "root path",
raw: "/",
want: "/",
},
{
name: "nested path",
raw: "/settings/alerts",
want: "/settings/alerts",
},
{
name: "path with query params",
raw: "/page?foo=bar",
want: "/page?foo=bar",
},
{
name: "path with fragment",
raw: "/page#section",
want: "/page#section",
},
// Invalid - empty or whitespace
{
name: "empty string",
raw: "",
want: "",
},
{
name: "whitespace only",
raw: " ",
want: "",
},
// Invalid - doesn't start with /
{
name: "no leading slash",
raw: "dashboard",
want: "",
},
{
name: "http URL",
raw: "http://evil.com",
want: "",
},
{
name: "https URL",
raw: "https://evil.com",
want: "",
},
// Invalid - protocol-relative URL (double slash)
{
name: "protocol relative URL",
raw: "//evil.com",
want: "",
},
{
name: "protocol relative with path",
raw: "//evil.com/path",
want: "",
},
// Whitespace handling
{
name: "leading whitespace",
raw: " /dashboard",
want: "/dashboard",
},
{
name: "trailing whitespace",
raw: "/dashboard ",
want: "/dashboard",
},
{
name: "both whitespace",
raw: " /dashboard ",
want: "/dashboard",
},
{
name: "invalid slash backslash",
raw: "/\\evil",
want: "",
},
{
name: "invalid triple slash",
raw: "///evil",
want: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := sanitizeOIDCReturnTo(tc.raw)
if result != tc.want {
t.Errorf("sanitizeOIDCReturnTo(%q) = %q, want %q", tc.raw, result, tc.want)
}
})
}
}
func TestAddQueryParam(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
key string
value string
want string
}{
// Basic cases
{
name: "add to simple path",
path: "/dashboard",
key: "foo",
value: "bar",
want: "/dashboard?foo=bar",
},
{
name: "add to root",
path: "/",
key: "key",
value: "value",
want: "/?key=value",
},
// Existing query params
{
name: "add to path with existing param",
path: "/page?existing=param",
key: "new",
value: "value",
want: "/page?existing=param&new=value",
},
{
name: "replace existing param",
path: "/page?key=old",
key: "key",
value: "new",
want: "/page?key=new",
},
// Empty path
{
name: "empty path becomes root",
path: "",
key: "foo",
value: "bar",
want: "/?foo=bar",
},
// URL encoding
{
name: "value with spaces",
path: "/page",
key: "message",
value: "hello world",
want: "/page?message=hello+world",
},
{
name: "value with special chars",
path: "/page",
key: "data",
value: "a=b&c=d",
want: "/page?data=a%3Db%26c%3Dd",
},
// Fragment handling
{
name: "path with fragment",
path: "/page#section",
key: "foo",
value: "bar",
want: "/page?foo=bar#section",
},
// Invalid URL (control character causes parse error)
{
name: "path with control character returns unchanged",
path: "/page\x00invalid",
key: "foo",
value: "bar",
want: "/page\x00invalid",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := addQueryParam(tc.path, tc.key, tc.value)
if result != tc.want {
t.Errorf("addQueryParam(%q, %q, %q) = %q, want %q", tc.path, tc.key, tc.value, result, tc.want)
}
})
}
}
func TestExtractStringClaim(t *testing.T) {
t.Parallel()
tests := []struct {
name string
claims map[string]any
key string
want string
}{
// String value
{
name: "string value",
claims: map[string]any{"email": "user@example.com"},
key: "email",
want: "user@example.com",
},
{
name: "string with whitespace",
claims: map[string]any{"name": " John Doe "},
key: "name",
want: "John Doe",
},
// String slice - returns first element
{
name: "string slice returns first",
claims: map[string]any{"groups": []string{"admin", "users"}},
key: "groups",
want: "admin",
},
{
name: "empty string slice",
claims: map[string]any{"groups": []string{}},
key: "groups",
want: "",
},
{
name: "string slice with whitespace",
claims: map[string]any{"groups": []string{" admin "}},
key: "groups",
want: "admin",
},
// Interface slice - returns first string
{
name: "interface slice with strings",
claims: map[string]any{"roles": []interface{}{"admin", "user"}},
key: "roles",
want: "admin",
},
{
name: "interface slice with mixed types",
claims: map[string]any{"data": []interface{}{123, "value", true}},
key: "data",
want: "value",
},
{
name: "interface slice with no strings",
claims: map[string]any{"nums": []interface{}{1, 2, 3}},
key: "nums",
want: "",
},
// Missing or empty key
{
name: "key not in claims",
claims: map[string]any{"other": "value"},
key: "email",
want: "",
},
{
name: "empty key",
claims: map[string]any{"email": "user@example.com"},
key: "",
want: "",
},
{
name: "nil claims",
claims: nil,
key: "email",
want: "",
},
// Unsupported types
{
name: "integer value",
claims: map[string]any{"count": 42},
key: "count",
want: "",
},
{
name: "boolean value",
claims: map[string]any{"active": true},
key: "active",
want: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := extractStringClaim(tc.claims, tc.key)
if result != tc.want {
t.Errorf("extractStringClaim(%v, %q) = %q, want %q", tc.claims, tc.key, result, tc.want)
}
})
}
}
func TestExtractStringSliceClaim(t *testing.T) {
t.Parallel()
tests := []struct {
name string
claims map[string]any
key string
want []string
}{
// String slice
{
name: "string slice",
claims: map[string]any{"groups": []string{"admin", "users", "devops"}},
key: "groups",
want: []string{"admin", "users", "devops"},
},
{
name: "empty string slice",
claims: map[string]any{"groups": []string{}},
key: "groups",
want: []string{},
},
// Interface slice
{
name: "interface slice with strings",
claims: map[string]any{"roles": []interface{}{"admin", "user"}},
key: "roles",
want: []string{"admin", "user"},
},
{
name: "interface slice with mixed types filters non-strings",
claims: map[string]any{"data": []interface{}{"str1", 123, "str2", true}},
key: "data",
want: []string{"str1", "str2"},
},
{
name: "interface slice with no strings",
claims: map[string]any{"nums": []interface{}{1, 2, 3}},
key: "nums",
want: []string{},
},
// String value (comma/space separated)
{
name: "comma separated string",
claims: map[string]any{"groups": "admin,users,devops"},
key: "groups",
want: []string{"admin", "users", "devops"},
},
{
name: "space separated string",
claims: map[string]any{"groups": "admin users devops"},
key: "groups",
want: []string{"admin", "users", "devops"},
},
{
name: "mixed separator string",
claims: map[string]any{"groups": "admin, users devops"},
key: "groups",
want: []string{"admin", "users", "devops"},
},
// Missing or empty key
{
name: "key not in claims",
claims: map[string]any{"other": "value"},
key: "groups",
want: nil,
},
{
name: "empty key",
claims: map[string]any{"groups": []string{"admin"}},
key: "",
want: nil,
},
{
name: "nil claims",
claims: nil,
key: "groups",
want: nil,
},
// Unsupported types
{
name: "integer value",
claims: map[string]any{"count": 42},
key: "count",
want: nil,
},
{
name: "boolean value",
claims: map[string]any{"active": true},
key: "active",
want: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := extractStringSliceClaim(tc.claims, tc.key)
// Compare slices
if tc.want == nil {
if result != nil {
t.Errorf("extractStringSliceClaim(%v, %q) = %v, want nil", tc.claims, tc.key, result)
}
return
}
if len(result) != len(tc.want) {
t.Errorf("extractStringSliceClaim(%v, %q) = %v (len %d), want %v (len %d)",
tc.claims, tc.key, result, len(result), tc.want, len(tc.want))
return
}
for i, v := range result {
if v != tc.want[i] {
t.Errorf("extractStringSliceClaim(%v, %q)[%d] = %q, want %q",
tc.claims, tc.key, i, v, tc.want[i])
}
}
})
}
}
func TestMatchesValue(t *testing.T) {
t.Parallel()
tests := []struct {
name string
candidate string
allowed []string
want bool
}{
// Matches
{
name: "exact match",
candidate: "admin",
allowed: []string{"admin", "user"},
want: true,
},
{
name: "case insensitive match",
candidate: "ADMIN",
allowed: []string{"admin", "user"},
want: true,
},
{
name: "candidate with whitespace",
candidate: " admin ",
allowed: []string{"admin"},
want: true,
},
{
name: "allowed with whitespace",
candidate: "admin",
allowed: []string{" admin "},
want: true,
},
{
name: "mixed case both sides",
candidate: "AdMiN",
allowed: []string{"aDmIn"},
want: true,
},
// No match
{
name: "no match",
candidate: "guest",
allowed: []string{"admin", "user"},
want: false,
},
{
name: "empty candidate",
candidate: "",
allowed: []string{"admin"},
want: false,
},
{
name: "whitespace candidate",
candidate: " ",
allowed: []string{"admin"},
want: false,
},
{
name: "empty allowed list",
candidate: "admin",
allowed: []string{},
want: false,
},
{
name: "nil allowed list",
candidate: "admin",
allowed: nil,
want: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := matchesValue(tc.candidate, tc.allowed)
if result != tc.want {
t.Errorf("matchesValue(%q, %v) = %v, want %v", tc.candidate, tc.allowed, result, tc.want)
}
})
}
}
func TestMatchesDomain(t *testing.T) {
t.Parallel()
tests := []struct {
name string
email string
allowed []string
want bool
}{
// Matches
{
name: "exact domain match",
email: "user@example.com",
allowed: []string{"example.com"},
want: true,
},
{
name: "domain with @ prefix in allowed",
email: "user@example.com",
allowed: []string{"@example.com"},
want: true,
},
{
name: "case insensitive email",
email: "USER@EXAMPLE.COM",
allowed: []string{"example.com"},
want: true,
},
{
name: "case insensitive allowed",
email: "user@example.com",
allowed: []string{"EXAMPLE.COM"},
want: true,
},
{
name: "multiple allowed domains",
email: "user@company.org",
allowed: []string{"example.com", "company.org", "test.net"},
want: true,
},
{
name: "email with whitespace",
email: " user@example.com ",
allowed: []string{"example.com"},
want: true,
},
{
name: "allowed with whitespace",
email: "user@example.com",
allowed: []string{" example.com "},
want: true,
},
// No match
{
name: "different domain",
email: "user@other.com",
allowed: []string{"example.com"},
want: false,
},
{
name: "subdomain not matched",
email: "user@sub.example.com",
allowed: []string{"example.com"},
want: false,
},
// Invalid emails
{
name: "empty email",
email: "",
allowed: []string{"example.com"},
want: false,
},
{
name: "whitespace email",
email: " ",
allowed: []string{"example.com"},
want: false,
},
{
name: "no @ in email",
email: "userexample.com",
allowed: []string{"example.com"},
want: false,
},
{
name: "@ at end",
email: "user@",
allowed: []string{"example.com"},
want: false,
},
// Empty allowed
{
name: "empty allowed list",
email: "user@example.com",
allowed: []string{},
want: false,
},
{
name: "nil allowed list",
email: "user@example.com",
allowed: nil,
want: false,
},
{
name: "allowed with empty strings",
email: "user@example.com",
allowed: []string{"", " ", "@"},
want: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := matchesDomain(tc.email, tc.allowed)
if result != tc.want {
t.Errorf("matchesDomain(%q, %v) = %v, want %v", tc.email, tc.allowed, result, tc.want)
}
})
}
}
func TestIntersects(t *testing.T) {
t.Parallel()
tests := []struct {
name string
values []string
allowed []string
want bool
}{
// Intersects
{
name: "single common element",
values: []string{"admin", "user"},
allowed: []string{"admin", "guest"},
want: true,
},
{
name: "multiple common elements",
values: []string{"admin", "user", "devops"},
allowed: []string{"admin", "user", "guest"},
want: true,
},
{
name: "case insensitive",
values: []string{"ADMIN"},
allowed: []string{"admin"},
want: true,
},
{
name: "with whitespace",
values: []string{" admin "},
allowed: []string{"admin"},
want: true,
},
// No intersection
{
name: "no common elements",
values: []string{"admin", "user"},
allowed: []string{"guest", "viewer"},
want: false,
},
{
name: "empty values",
values: []string{},
allowed: []string{"admin"},
want: false,
},
{
name: "nil values",
values: nil,
allowed: []string{"admin"},
want: false,
},
{
name: "empty allowed",
values: []string{"admin"},
allowed: []string{},
want: false,
},
{
name: "nil allowed",
values: []string{"admin"},
allowed: nil,
want: false,
},
{
name: "both empty",
values: []string{},
allowed: []string{},
want: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := intersects(tc.values, tc.allowed)
if result != tc.want {
t.Errorf("intersects(%v, %v) = %v, want %v", tc.values, tc.allowed, result, tc.want)
}
})
}
}
func TestBuildRedirectURL(t *testing.T) {
t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", "127.0.0.1/32")
resetTrustedProxyConfig()
tests := []struct {
name string
configuredURL string
host string
tls bool
xForwardProto string
xForwardHost string
wantContains string
}{
// Configured URL takes precedence
{
name: "configured URL used",
configuredURL: "https://configured.example.com/auth/callback",
host: "request.example.com",
tls: false,
wantContains: "https://configured.example.com/auth/callback",
},
{
name: "configured URL with whitespace",
configuredURL: " https://configured.example.com/auth/callback ",
host: "request.example.com",
tls: false,
wantContains: "https://configured.example.com/auth/callback",
},
// Build from request - no proxies
{
name: "http request",
configuredURL: "",
host: "example.com:7655",
tls: false,
wantContains: "http://example.com:7655",
},
{
name: "https request with TLS",
configuredURL: "",
host: "example.com:7655",
tls: true,
wantContains: "https://example.com:7655",
},
// Build from request - with proxies
{
name: "X-Forwarded-Proto overrides scheme",
configuredURL: "",
host: "example.com",
tls: false,
xForwardProto: "https",
wantContains: "https://example.com",
},
{
name: "X-Forwarded-Host overrides host",
configuredURL: "",
host: "internal.local",
tls: false,
xForwardHost: "public.example.com",
wantContains: "http://public.example.com",
},
{
name: "both X-Forwarded headers",
configuredURL: "",
host: "internal.local",
tls: false,
xForwardProto: "https",
xForwardHost: "public.example.com",
wantContains: "https://public.example.com",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", "http://"+tc.host+"/", nil)
req.Host = tc.host
req.RemoteAddr = "127.0.0.1:12345"
if tc.tls {
req.TLS = &tls.ConnectionState{}
}
if tc.xForwardProto != "" {
req.Header.Set("X-Forwarded-Proto", tc.xForwardProto)
}
if tc.xForwardHost != "" {
req.Header.Set("X-Forwarded-Host", tc.xForwardHost)
}
result := buildRedirectURL(req, tc.configuredURL)
if tc.wantContains != "" {
if result != tc.wantContains && !stringContains(result, tc.wantContains) {
t.Errorf("buildRedirectURL() = %q, want to contain %q", result, tc.wantContains)
}
}
})
}
}
// stringContains checks if s contains substr (helper for tests)
func stringContains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
(len(s) > 0 && len(substr) > 0 && findSubstrInStr(s, substr)))
}
func findSubstrInStr(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}