// Pulse/internal/api/oidc_handlers_test.go
// 2026-04-01 16:00:04 +01:00
//
// 812 lines
// 16 KiB
// Go

package api
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
)
// TestSanitizeOIDCReturnTo runs sanitizeOIDCReturnTo over a table of raw
// return-to values. Local paths are expected back (with surrounding
// whitespace removed); every other input is expected to produce "".
func TestSanitizeOIDCReturnTo(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name string
		raw  string
		want string
	}

	cases := []testCase{
		// Accepted: local paths, optionally carrying a query or fragment.
		{name: "simple path", raw: "/dashboard", want: "/dashboard"},
		{name: "root path", raw: "/", want: "/"},
		{name: "nested path", raw: "/settings/alerts", want: "/settings/alerts"},
		{name: "path with query params", raw: "/page?foo=bar", want: "/page?foo=bar"},
		{name: "path with fragment", raw: "/page#section", want: "/page#section"},

		// Rejected: nothing left once whitespace is ignored.
		{name: "empty string", raw: "", want: ""},
		{name: "whitespace only", raw: " ", want: ""},

		// Rejected: input does not begin with a slash.
		{name: "no leading slash", raw: "dashboard", want: ""},
		{name: "http URL", raw: "http://evil.com", want: ""},
		{name: "https URL", raw: "https://evil.com", want: ""},

		// Rejected: protocol-relative forms ("//host" and "/\host").
		{name: "protocol relative URL", raw: "//evil.com", want: ""},
		{name: "protocol relative with path", raw: "//evil.com/path", want: ""},
		{name: "backslash relative path", raw: "/\\evil.com/path", want: ""},

		// Accepted after trimming surrounding whitespace.
		{name: "leading whitespace", raw: " /dashboard", want: "/dashboard"},
		{name: "trailing whitespace", raw: "/dashboard ", want: "/dashboard"},
		{name: "both whitespace", raw: " /dashboard ", want: "/dashboard"},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := sanitizeOIDCReturnTo(tt.raw); got != tt.want {
				t.Errorf("sanitizeOIDCReturnTo(%q) = %q, want %q", tt.raw, got, tt.want)
			}
		})
	}
}
// TestBuildLocalRedirectTarget checks the redirect string produced by
// buildLocalRedirectTarget for a return-to value plus extra query
// parameters. A mismatch aborts the subtest immediately.
func TestBuildLocalRedirectTarget(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name        string
		returnTo    string
		queryParams map[string]string
		want        string
	}

	cases := []testCase{
		{
			// An absolute URL is expected to be replaced by the root path.
			name:        "rejects absolute URL",
			returnTo:    "https://evil.example.com/pwn",
			queryParams: map[string]string{"oidc": "error"},
			want:        "/?oidc=error",
		},
		{
			// The new parameter lands after the existing query, before the fragment.
			name:        "preserves query and fragment",
			returnTo:    "/login?foo=bar#section",
			queryParams: map[string]string{"oidc": "success"},
			want:        "/login?foo=bar&oidc=success#section",
		},
		{
			// Entries with an empty key or an empty value must not appear.
			name:        "drops empty key and value",
			returnTo:    "/",
			queryParams: map[string]string{"": "ignored", "oidc": "", "state": "ok"},
			want:        "/?state=ok",
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := buildLocalRedirectTarget(tt.returnTo, tt.queryParams); got != tt.want {
				t.Fatalf("buildLocalRedirectTarget(%q, %#v) = %q, want %q", tt.returnTo, tt.queryParams, got, tt.want)
			}
		})
	}
}
// TestRedirectOIDCErrorRejectsAbsoluteReturnTo calls redirectOIDCError with
// an absolute return-to URL and the error code "bad". It requires a 302
// response whose Location is the root path carrying the error parameters,
// not the supplied URL.
func TestRedirectOIDCErrorRejectsAbsoluteReturnTo(t *testing.T) {
	t.Parallel()

	const wantLocation = "/?oidc=error&oidc_error=bad"

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/api/oidc/callback", nil)

	r := &Router{config: &config.Config{}}
	r.redirectOIDCError(recorder, request, "https://evil.example.com/pwn", "bad")

	if got, want := recorder.Code, http.StatusFound; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}

	gotLocation := recorder.Header().Get("Location")
	if gotLocation != wantLocation {
		t.Fatalf("unexpected redirect location %q", gotLocation)
	}
}
// TestAddQueryParam checks the string returned by addQueryParam when a
// key/value pair is added to a path. The table covers plain paths, existing
// query strings, an empty path, values that need escaping, fragments, and a
// path containing a NUL byte.
func TestAddQueryParam(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name  string
		path  string
		key   string
		value string
		want  string
	}

	cases := []testCase{
		// Paths with no query string yet.
		{name: "add to simple path", path: "/dashboard", key: "foo", value: "bar", want: "/dashboard?foo=bar"},
		{name: "add to root", path: "/", key: "key", value: "value", want: "/?key=value"},

		// Paths that already carry a query string.
		{name: "add to path with existing param", path: "/page?existing=param", key: "new", value: "value", want: "/page?existing=param&new=value"},
		{name: "replace existing param", path: "/page?key=old", key: "key", value: "new", want: "/page?key=new"},

		// An empty path is expected to come back rooted at "/".
		{name: "empty path becomes root", path: "", key: "foo", value: "bar", want: "/?foo=bar"},

		// Values that must be escaped in the query string.
		{name: "value with spaces", path: "/page", key: "message", value: "hello world", want: "/page?message=hello+world"},
		{name: "value with special chars", path: "/page", key: "data", value: "a=b&c=d", want: "/page?data=a%3Db%26c%3Dd"},

		// The fragment must stay after the query string.
		{name: "path with fragment", path: "/page#section", key: "foo", value: "bar", want: "/page?foo=bar#section"},

		// A path with a control character is expected back untouched.
		{name: "path with control character returns unchanged", path: "/page\x00invalid", key: "foo", value: "bar", want: "/page\x00invalid"},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := addQueryParam(tt.path, tt.key, tt.value); got != tt.want {
				t.Errorf("addQueryParam(%q, %q, %q) = %q, want %q", tt.path, tt.key, tt.value, got, tt.want)
			}
		})
	}
}
// TestExtractStringClaim checks the single string that extractStringClaim
// returns for a key in a claims map. The table covers string values, string
// and interface slices, missing or empty keys, a nil map, and non-string
// scalar values.
func TestExtractStringClaim(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name   string
		claims map[string]any
		key    string
		want   string
	}

	cases := []testCase{
		// Plain string claims; surrounding whitespace is expected to be removed.
		{name: "string value", claims: map[string]any{"email": "user@example.com"}, key: "email", want: "user@example.com"},
		{name: "string with whitespace", claims: map[string]any{"name": " John Doe "}, key: "name", want: "John Doe"},

		// []string claims: the first element is expected.
		{name: "string slice returns first", claims: map[string]any{"groups": []string{"admin", "users"}}, key: "groups", want: "admin"},
		{name: "empty string slice", claims: map[string]any{"groups": []string{}}, key: "groups", want: ""},
		{name: "string slice with whitespace", claims: map[string]any{"groups": []string{" admin "}}, key: "groups", want: "admin"},

		// []any claims: the first element that is a string is expected.
		{name: "interface slice with strings", claims: map[string]any{"roles": []any{"admin", "user"}}, key: "roles", want: "admin"},
		{name: "interface slice with mixed types", claims: map[string]any{"data": []any{123, "value", true}}, key: "data", want: "value"},
		{name: "interface slice with no strings", claims: map[string]any{"nums": []any{1, 2, 3}}, key: "nums", want: ""},

		// Lookups that cannot find anything.
		{name: "key not in claims", claims: map[string]any{"other": "value"}, key: "email", want: ""},
		{name: "empty key", claims: map[string]any{"email": "user@example.com"}, key: "", want: ""},
		{name: "nil claims", claims: nil, key: "email", want: ""},

		// Value types that are expected to yield "".
		{name: "integer value", claims: map[string]any{"count": 42}, key: "count", want: ""},
		{name: "boolean value", claims: map[string]any{"active": true}, key: "active", want: ""},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := extractStringClaim(tt.claims, tt.key); got != tt.want {
				t.Errorf("extractStringClaim(%v, %q) = %q, want %q", tt.claims, tt.key, got, tt.want)
			}
		})
	}
}
// TestExtractStringSliceClaim checks the slice that extractStringSliceClaim
// returns for a key in a claims map.
//
// A nil want requires a nil result. A non-nil want is compared by length and
// then element by element, so an empty want is satisfied by any zero-length
// result, including nil.
func TestExtractStringSliceClaim(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name   string
		claims map[string]any
		key    string
		want   []string
	}

	cases := []testCase{
		// []string claims.
		{name: "string slice", claims: map[string]any{"groups": []string{"admin", "users", "devops"}}, key: "groups", want: []string{"admin", "users", "devops"}},
		{name: "empty string slice", claims: map[string]any{"groups": []string{}}, key: "groups", want: []string{}},

		// []any claims: only the string elements are expected, in order.
		{name: "interface slice with strings", claims: map[string]any{"roles": []any{"admin", "user"}}, key: "roles", want: []string{"admin", "user"}},
		{name: "interface slice with mixed types filters non-strings", claims: map[string]any{"data": []any{"str1", 123, "str2", true}}, key: "data", want: []string{"str1", "str2"}},
		{name: "interface slice with no strings", claims: map[string]any{"nums": []any{1, 2, 3}}, key: "nums", want: []string{}},

		// A single string claim split on commas and/or spaces.
		{name: "comma separated string", claims: map[string]any{"groups": "admin,users,devops"}, key: "groups", want: []string{"admin", "users", "devops"}},
		{name: "space separated string", claims: map[string]any{"groups": "admin users devops"}, key: "groups", want: []string{"admin", "users", "devops"}},
		{name: "mixed separator string", claims: map[string]any{"groups": "admin, users devops"}, key: "groups", want: []string{"admin", "users", "devops"}},

		// Lookups that cannot find anything.
		{name: "key not in claims", claims: map[string]any{"other": "value"}, key: "groups", want: nil},
		{name: "empty key", claims: map[string]any{"groups": []string{"admin"}}, key: "", want: nil},
		{name: "nil claims", claims: nil, key: "groups", want: nil},

		// Value types that are expected to yield nil.
		{name: "integer value", claims: map[string]any{"count": 42}, key: "count", want: nil},
		{name: "boolean value", claims: map[string]any{"active": true}, key: "active", want: nil},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			got := extractStringSliceClaim(tt.claims, tt.key)

			switch {
			case tt.want == nil:
				// A nil expectation demands a nil result, not just an empty one.
				if got != nil {
					t.Errorf("extractStringSliceClaim(%v, %q) = %v, want nil", tt.claims, tt.key, got)
				}
			case len(got) != len(tt.want):
				t.Errorf("extractStringSliceClaim(%v, %q) = %v (len %d), want %v (len %d)",
					tt.claims, tt.key, got, len(got), tt.want, len(tt.want))
			default:
				// Lengths match; report each differing position.
				for i := range got {
					if got[i] != tt.want[i] {
						t.Errorf("extractStringSliceClaim(%v, %q)[%d] = %q, want %q",
							tt.claims, tt.key, i, got[i], tt.want[i])
					}
				}
			}
		})
	}
}
// TestMatchesValue checks whether matchesValue reports a candidate string as
// present in an allow-list. The expectations ignore case and surrounding
// whitespace on both sides, and blank candidates or empty lists never match.
func TestMatchesValue(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name      string
		candidate string
		allowed   []string
		want      bool
	}

	cases := []testCase{
		// Expected to match.
		{name: "exact match", candidate: "admin", allowed: []string{"admin", "user"}, want: true},
		{name: "case insensitive match", candidate: "ADMIN", allowed: []string{"admin", "user"}, want: true},
		{name: "candidate with whitespace", candidate: " admin ", allowed: []string{"admin"}, want: true},
		{name: "allowed with whitespace", candidate: "admin", allowed: []string{" admin "}, want: true},
		{name: "mixed case both sides", candidate: "AdMiN", allowed: []string{"aDmIn"}, want: true},

		// Expected not to match.
		{name: "no match", candidate: "guest", allowed: []string{"admin", "user"}, want: false},
		{name: "empty candidate", candidate: "", allowed: []string{"admin"}, want: false},
		{name: "whitespace candidate", candidate: " ", allowed: []string{"admin"}, want: false},
		{name: "empty allowed list", candidate: "admin", allowed: []string{}, want: false},
		{name: "nil allowed list", candidate: "admin", allowed: nil, want: false},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := matchesValue(tt.candidate, tt.allowed); got != tt.want {
				t.Errorf("matchesValue(%q, %v) = %v, want %v", tt.candidate, tt.allowed, got, tt.want)
			}
		})
	}
}
// TestMatchesDomain checks whether matchesDomain reports an email address's
// domain as present in an allow-list. The expectations ignore case and
// surrounding whitespace, accept an optional leading "@" on allowed entries,
// and never match subdomains, malformed addresses, or blank allowed entries.
func TestMatchesDomain(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name    string
		email   string
		allowed []string
		want    bool
	}

	cases := []testCase{
		// Expected to match.
		{name: "exact domain match", email: "user@example.com", allowed: []string{"example.com"}, want: true},
		{name: "domain with @ prefix in allowed", email: "user@example.com", allowed: []string{"@example.com"}, want: true},
		{name: "case insensitive email", email: "USER@EXAMPLE.COM", allowed: []string{"example.com"}, want: true},
		{name: "case insensitive allowed", email: "user@example.com", allowed: []string{"EXAMPLE.COM"}, want: true},
		{name: "multiple allowed domains", email: "user@company.org", allowed: []string{"example.com", "company.org", "test.net"}, want: true},
		{name: "email with whitespace", email: " user@example.com ", allowed: []string{"example.com"}, want: true},
		{name: "allowed with whitespace", email: "user@example.com", allowed: []string{" example.com "}, want: true},

		// Well-formed addresses whose domain is not allowed.
		{name: "different domain", email: "user@other.com", allowed: []string{"example.com"}, want: false},
		{name: "subdomain not matched", email: "user@sub.example.com", allowed: []string{"example.com"}, want: false},

		// Addresses with no usable domain part.
		{name: "empty email", email: "", allowed: []string{"example.com"}, want: false},
		{name: "whitespace email", email: " ", allowed: []string{"example.com"}, want: false},
		{name: "no @ in email", email: "userexample.com", allowed: []string{"example.com"}, want: false},
		{name: "@ at end", email: "user@", allowed: []string{"example.com"}, want: false},

		// Allow-lists with nothing usable in them.
		{name: "empty allowed list", email: "user@example.com", allowed: []string{}, want: false},
		{name: "nil allowed list", email: "user@example.com", allowed: nil, want: false},
		{name: "allowed with empty strings", email: "user@example.com", allowed: []string{"", " ", "@"}, want: false},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := matchesDomain(tt.email, tt.allowed); got != tt.want {
				t.Errorf("matchesDomain(%q, %v) = %v, want %v", tt.email, tt.allowed, got, tt.want)
			}
		})
	}
}
// TestIntersects checks whether intersects reports that two string lists
// share at least one entry. The expectations ignore case and surrounding
// whitespace, and an empty or nil list on either side never intersects.
func TestIntersects(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name    string
		values  []string
		allowed []string
		want    bool
	}

	cases := []testCase{
		// Lists that share at least one entry.
		{name: "single common element", values: []string{"admin", "user"}, allowed: []string{"admin", "guest"}, want: true},
		{name: "multiple common elements", values: []string{"admin", "user", "devops"}, allowed: []string{"admin", "user", "guest"}, want: true},
		{name: "case insensitive", values: []string{"ADMIN"}, allowed: []string{"admin"}, want: true},
		{name: "with whitespace", values: []string{" admin "}, allowed: []string{"admin"}, want: true},

		// Lists with nothing in common, including empty and nil inputs.
		{name: "no common elements", values: []string{"admin", "user"}, allowed: []string{"guest", "viewer"}, want: false},
		{name: "empty values", values: []string{}, allowed: []string{"admin"}, want: false},
		{name: "nil values", values: nil, allowed: []string{"admin"}, want: false},
		{name: "empty allowed", values: []string{"admin"}, allowed: []string{}, want: false},
		{name: "nil allowed", values: []string{"admin"}, allowed: nil, want: false},
		{name: "both empty", values: []string{}, allowed: []string{}, want: false},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := intersects(tt.values, tt.allowed); got != tt.want {
				t.Errorf("intersects(%v, %v) = %v, want %v", tt.values, tt.allowed, got, tt.want)
			}
		})
	}
}