Mirror of https://github.com/rcourtman/Pulse.git
Synced 2026-05-07 17:19:57 +00:00
812 lines · 16 KiB · Go
package api
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
)
|
|
|
|
func TestSanitizeOIDCReturnTo(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
raw string
|
|
want string
|
|
}{
|
|
// Valid paths
|
|
{
|
|
name: "simple path",
|
|
raw: "/dashboard",
|
|
want: "/dashboard",
|
|
},
|
|
{
|
|
name: "root path",
|
|
raw: "/",
|
|
want: "/",
|
|
},
|
|
{
|
|
name: "nested path",
|
|
raw: "/settings/alerts",
|
|
want: "/settings/alerts",
|
|
},
|
|
{
|
|
name: "path with query params",
|
|
raw: "/page?foo=bar",
|
|
want: "/page?foo=bar",
|
|
},
|
|
{
|
|
name: "path with fragment",
|
|
raw: "/page#section",
|
|
want: "/page#section",
|
|
},
|
|
|
|
// Invalid - empty or whitespace
|
|
{
|
|
name: "empty string",
|
|
raw: "",
|
|
want: "",
|
|
},
|
|
{
|
|
name: "whitespace only",
|
|
raw: " ",
|
|
want: "",
|
|
},
|
|
|
|
// Invalid - doesn't start with /
|
|
{
|
|
name: "no leading slash",
|
|
raw: "dashboard",
|
|
want: "",
|
|
},
|
|
{
|
|
name: "http URL",
|
|
raw: "http://evil.com",
|
|
want: "",
|
|
},
|
|
{
|
|
name: "https URL",
|
|
raw: "https://evil.com",
|
|
want: "",
|
|
},
|
|
|
|
// Invalid - protocol-relative URL (double slash)
|
|
{
|
|
name: "protocol relative URL",
|
|
raw: "//evil.com",
|
|
want: "",
|
|
},
|
|
{
|
|
name: "protocol relative with path",
|
|
raw: "//evil.com/path",
|
|
want: "",
|
|
},
|
|
{
|
|
name: "backslash relative path",
|
|
raw: "/\\evil.com/path",
|
|
want: "",
|
|
},
|
|
|
|
// Whitespace handling
|
|
{
|
|
name: "leading whitespace",
|
|
raw: " /dashboard",
|
|
want: "/dashboard",
|
|
},
|
|
{
|
|
name: "trailing whitespace",
|
|
raw: "/dashboard ",
|
|
want: "/dashboard",
|
|
},
|
|
{
|
|
name: "both whitespace",
|
|
raw: " /dashboard ",
|
|
want: "/dashboard",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result := sanitizeOIDCReturnTo(tc.raw)
|
|
if result != tc.want {
|
|
t.Errorf("sanitizeOIDCReturnTo(%q) = %q, want %q", tc.raw, result, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBuildLocalRedirectTarget(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
returnTo string
|
|
queryParams map[string]string
|
|
want string
|
|
}{
|
|
{
|
|
name: "rejects absolute URL",
|
|
returnTo: "https://evil.example.com/pwn",
|
|
queryParams: map[string]string{"oidc": "error"},
|
|
want: "/?oidc=error",
|
|
},
|
|
{
|
|
name: "preserves query and fragment",
|
|
returnTo: "/login?foo=bar#section",
|
|
queryParams: map[string]string{"oidc": "success"},
|
|
want: "/login?foo=bar&oidc=success#section",
|
|
},
|
|
{
|
|
name: "drops empty key and value",
|
|
returnTo: "/",
|
|
queryParams: map[string]string{"": "ignored", "oidc": "", "state": "ok"},
|
|
want: "/?state=ok",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result := buildLocalRedirectTarget(tc.returnTo, tc.queryParams)
|
|
if result != tc.want {
|
|
t.Fatalf("buildLocalRedirectTarget(%q, %#v) = %q, want %q", tc.returnTo, tc.queryParams, result, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRedirectOIDCErrorRejectsAbsoluteReturnTo(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
router := &Router{config: &config.Config{}}
|
|
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.redirectOIDCError(rec, req, "https://evil.example.com/pwn", "bad")
|
|
|
|
if rec.Code != http.StatusFound {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
|
}
|
|
if loc := rec.Header().Get("Location"); loc != "/?oidc=error&oidc_error=bad" {
|
|
t.Fatalf("unexpected redirect location %q", loc)
|
|
}
|
|
}
|
|
|
|
func TestAddQueryParam(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
path string
|
|
key string
|
|
value string
|
|
want string
|
|
}{
|
|
// Basic cases
|
|
{
|
|
name: "add to simple path",
|
|
path: "/dashboard",
|
|
key: "foo",
|
|
value: "bar",
|
|
want: "/dashboard?foo=bar",
|
|
},
|
|
{
|
|
name: "add to root",
|
|
path: "/",
|
|
key: "key",
|
|
value: "value",
|
|
want: "/?key=value",
|
|
},
|
|
|
|
// Existing query params
|
|
{
|
|
name: "add to path with existing param",
|
|
path: "/page?existing=param",
|
|
key: "new",
|
|
value: "value",
|
|
want: "/page?existing=param&new=value",
|
|
},
|
|
{
|
|
name: "replace existing param",
|
|
path: "/page?key=old",
|
|
key: "key",
|
|
value: "new",
|
|
want: "/page?key=new",
|
|
},
|
|
|
|
// Empty path
|
|
{
|
|
name: "empty path becomes root",
|
|
path: "",
|
|
key: "foo",
|
|
value: "bar",
|
|
want: "/?foo=bar",
|
|
},
|
|
|
|
// URL encoding
|
|
{
|
|
name: "value with spaces",
|
|
path: "/page",
|
|
key: "message",
|
|
value: "hello world",
|
|
want: "/page?message=hello+world",
|
|
},
|
|
{
|
|
name: "value with special chars",
|
|
path: "/page",
|
|
key: "data",
|
|
value: "a=b&c=d",
|
|
want: "/page?data=a%3Db%26c%3Dd",
|
|
},
|
|
|
|
// Fragment handling
|
|
{
|
|
name: "path with fragment",
|
|
path: "/page#section",
|
|
key: "foo",
|
|
value: "bar",
|
|
want: "/page?foo=bar#section",
|
|
},
|
|
|
|
// Invalid URL (control character causes parse error)
|
|
{
|
|
name: "path with control character returns unchanged",
|
|
path: "/page\x00invalid",
|
|
key: "foo",
|
|
value: "bar",
|
|
want: "/page\x00invalid",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result := addQueryParam(tc.path, tc.key, tc.value)
|
|
if result != tc.want {
|
|
t.Errorf("addQueryParam(%q, %q, %q) = %q, want %q", tc.path, tc.key, tc.value, result, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtractStringClaim(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
claims map[string]any
|
|
key string
|
|
want string
|
|
}{
|
|
// String value
|
|
{
|
|
name: "string value",
|
|
claims: map[string]any{"email": "user@example.com"},
|
|
key: "email",
|
|
want: "user@example.com",
|
|
},
|
|
{
|
|
name: "string with whitespace",
|
|
claims: map[string]any{"name": " John Doe "},
|
|
key: "name",
|
|
want: "John Doe",
|
|
},
|
|
|
|
// String slice - returns first element
|
|
{
|
|
name: "string slice returns first",
|
|
claims: map[string]any{"groups": []string{"admin", "users"}},
|
|
key: "groups",
|
|
want: "admin",
|
|
},
|
|
{
|
|
name: "empty string slice",
|
|
claims: map[string]any{"groups": []string{}},
|
|
key: "groups",
|
|
want: "",
|
|
},
|
|
{
|
|
name: "string slice with whitespace",
|
|
claims: map[string]any{"groups": []string{" admin "}},
|
|
key: "groups",
|
|
want: "admin",
|
|
},
|
|
|
|
// Interface slice - returns first string
|
|
{
|
|
name: "interface slice with strings",
|
|
claims: map[string]any{"roles": []interface{}{"admin", "user"}},
|
|
key: "roles",
|
|
want: "admin",
|
|
},
|
|
{
|
|
name: "interface slice with mixed types",
|
|
claims: map[string]any{"data": []interface{}{123, "value", true}},
|
|
key: "data",
|
|
want: "value",
|
|
},
|
|
{
|
|
name: "interface slice with no strings",
|
|
claims: map[string]any{"nums": []interface{}{1, 2, 3}},
|
|
key: "nums",
|
|
want: "",
|
|
},
|
|
|
|
// Missing or empty key
|
|
{
|
|
name: "key not in claims",
|
|
claims: map[string]any{"other": "value"},
|
|
key: "email",
|
|
want: "",
|
|
},
|
|
{
|
|
name: "empty key",
|
|
claims: map[string]any{"email": "user@example.com"},
|
|
key: "",
|
|
want: "",
|
|
},
|
|
{
|
|
name: "nil claims",
|
|
claims: nil,
|
|
key: "email",
|
|
want: "",
|
|
},
|
|
|
|
// Unsupported types
|
|
{
|
|
name: "integer value",
|
|
claims: map[string]any{"count": 42},
|
|
key: "count",
|
|
want: "",
|
|
},
|
|
{
|
|
name: "boolean value",
|
|
claims: map[string]any{"active": true},
|
|
key: "active",
|
|
want: "",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result := extractStringClaim(tc.claims, tc.key)
|
|
if result != tc.want {
|
|
t.Errorf("extractStringClaim(%v, %q) = %q, want %q", tc.claims, tc.key, result, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtractStringSliceClaim(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
claims map[string]any
|
|
key string
|
|
want []string
|
|
}{
|
|
// String slice
|
|
{
|
|
name: "string slice",
|
|
claims: map[string]any{"groups": []string{"admin", "users", "devops"}},
|
|
key: "groups",
|
|
want: []string{"admin", "users", "devops"},
|
|
},
|
|
{
|
|
name: "empty string slice",
|
|
claims: map[string]any{"groups": []string{}},
|
|
key: "groups",
|
|
want: []string{},
|
|
},
|
|
|
|
// Interface slice
|
|
{
|
|
name: "interface slice with strings",
|
|
claims: map[string]any{"roles": []interface{}{"admin", "user"}},
|
|
key: "roles",
|
|
want: []string{"admin", "user"},
|
|
},
|
|
{
|
|
name: "interface slice with mixed types filters non-strings",
|
|
claims: map[string]any{"data": []interface{}{"str1", 123, "str2", true}},
|
|
key: "data",
|
|
want: []string{"str1", "str2"},
|
|
},
|
|
{
|
|
name: "interface slice with no strings",
|
|
claims: map[string]any{"nums": []interface{}{1, 2, 3}},
|
|
key: "nums",
|
|
want: []string{},
|
|
},
|
|
|
|
// String value (comma/space separated)
|
|
{
|
|
name: "comma separated string",
|
|
claims: map[string]any{"groups": "admin,users,devops"},
|
|
key: "groups",
|
|
want: []string{"admin", "users", "devops"},
|
|
},
|
|
{
|
|
name: "space separated string",
|
|
claims: map[string]any{"groups": "admin users devops"},
|
|
key: "groups",
|
|
want: []string{"admin", "users", "devops"},
|
|
},
|
|
{
|
|
name: "mixed separator string",
|
|
claims: map[string]any{"groups": "admin, users devops"},
|
|
key: "groups",
|
|
want: []string{"admin", "users", "devops"},
|
|
},
|
|
|
|
// Missing or empty key
|
|
{
|
|
name: "key not in claims",
|
|
claims: map[string]any{"other": "value"},
|
|
key: "groups",
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "empty key",
|
|
claims: map[string]any{"groups": []string{"admin"}},
|
|
key: "",
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "nil claims",
|
|
claims: nil,
|
|
key: "groups",
|
|
want: nil,
|
|
},
|
|
|
|
// Unsupported types
|
|
{
|
|
name: "integer value",
|
|
claims: map[string]any{"count": 42},
|
|
key: "count",
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "boolean value",
|
|
claims: map[string]any{"active": true},
|
|
key: "active",
|
|
want: nil,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result := extractStringSliceClaim(tc.claims, tc.key)
|
|
|
|
// Compare slices
|
|
if tc.want == nil {
|
|
if result != nil {
|
|
t.Errorf("extractStringSliceClaim(%v, %q) = %v, want nil", tc.claims, tc.key, result)
|
|
}
|
|
return
|
|
}
|
|
|
|
if len(result) != len(tc.want) {
|
|
t.Errorf("extractStringSliceClaim(%v, %q) = %v (len %d), want %v (len %d)",
|
|
tc.claims, tc.key, result, len(result), tc.want, len(tc.want))
|
|
return
|
|
}
|
|
|
|
for i, v := range result {
|
|
if v != tc.want[i] {
|
|
t.Errorf("extractStringSliceClaim(%v, %q)[%d] = %q, want %q",
|
|
tc.claims, tc.key, i, v, tc.want[i])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMatchesValue(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
candidate string
|
|
allowed []string
|
|
want bool
|
|
}{
|
|
// Matches
|
|
{
|
|
name: "exact match",
|
|
candidate: "admin",
|
|
allowed: []string{"admin", "user"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "case insensitive match",
|
|
candidate: "ADMIN",
|
|
allowed: []string{"admin", "user"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "candidate with whitespace",
|
|
candidate: " admin ",
|
|
allowed: []string{"admin"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "allowed with whitespace",
|
|
candidate: "admin",
|
|
allowed: []string{" admin "},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "mixed case both sides",
|
|
candidate: "AdMiN",
|
|
allowed: []string{"aDmIn"},
|
|
want: true,
|
|
},
|
|
|
|
// No match
|
|
{
|
|
name: "no match",
|
|
candidate: "guest",
|
|
allowed: []string{"admin", "user"},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "empty candidate",
|
|
candidate: "",
|
|
allowed: []string{"admin"},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "whitespace candidate",
|
|
candidate: " ",
|
|
allowed: []string{"admin"},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "empty allowed list",
|
|
candidate: "admin",
|
|
allowed: []string{},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "nil allowed list",
|
|
candidate: "admin",
|
|
allowed: nil,
|
|
want: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result := matchesValue(tc.candidate, tc.allowed)
|
|
if result != tc.want {
|
|
t.Errorf("matchesValue(%q, %v) = %v, want %v", tc.candidate, tc.allowed, result, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMatchesDomain(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
email string
|
|
allowed []string
|
|
want bool
|
|
}{
|
|
// Matches
|
|
{
|
|
name: "exact domain match",
|
|
email: "user@example.com",
|
|
allowed: []string{"example.com"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "domain with @ prefix in allowed",
|
|
email: "user@example.com",
|
|
allowed: []string{"@example.com"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "case insensitive email",
|
|
email: "USER@EXAMPLE.COM",
|
|
allowed: []string{"example.com"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "case insensitive allowed",
|
|
email: "user@example.com",
|
|
allowed: []string{"EXAMPLE.COM"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "multiple allowed domains",
|
|
email: "user@company.org",
|
|
allowed: []string{"example.com", "company.org", "test.net"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "email with whitespace",
|
|
email: " user@example.com ",
|
|
allowed: []string{"example.com"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "allowed with whitespace",
|
|
email: "user@example.com",
|
|
allowed: []string{" example.com "},
|
|
want: true,
|
|
},
|
|
|
|
// No match
|
|
{
|
|
name: "different domain",
|
|
email: "user@other.com",
|
|
allowed: []string{"example.com"},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "subdomain not matched",
|
|
email: "user@sub.example.com",
|
|
allowed: []string{"example.com"},
|
|
want: false,
|
|
},
|
|
|
|
// Invalid emails
|
|
{
|
|
name: "empty email",
|
|
email: "",
|
|
allowed: []string{"example.com"},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "whitespace email",
|
|
email: " ",
|
|
allowed: []string{"example.com"},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "no @ in email",
|
|
email: "userexample.com",
|
|
allowed: []string{"example.com"},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "@ at end",
|
|
email: "user@",
|
|
allowed: []string{"example.com"},
|
|
want: false,
|
|
},
|
|
|
|
// Empty allowed
|
|
{
|
|
name: "empty allowed list",
|
|
email: "user@example.com",
|
|
allowed: []string{},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "nil allowed list",
|
|
email: "user@example.com",
|
|
allowed: nil,
|
|
want: false,
|
|
},
|
|
{
|
|
name: "allowed with empty strings",
|
|
email: "user@example.com",
|
|
allowed: []string{"", " ", "@"},
|
|
want: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result := matchesDomain(tc.email, tc.allowed)
|
|
if result != tc.want {
|
|
t.Errorf("matchesDomain(%q, %v) = %v, want %v", tc.email, tc.allowed, result, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIntersects(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
values []string
|
|
allowed []string
|
|
want bool
|
|
}{
|
|
// Intersects
|
|
{
|
|
name: "single common element",
|
|
values: []string{"admin", "user"},
|
|
allowed: []string{"admin", "guest"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "multiple common elements",
|
|
values: []string{"admin", "user", "devops"},
|
|
allowed: []string{"admin", "user", "guest"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "case insensitive",
|
|
values: []string{"ADMIN"},
|
|
allowed: []string{"admin"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "with whitespace",
|
|
values: []string{" admin "},
|
|
allowed: []string{"admin"},
|
|
want: true,
|
|
},
|
|
|
|
// No intersection
|
|
{
|
|
name: "no common elements",
|
|
values: []string{"admin", "user"},
|
|
allowed: []string{"guest", "viewer"},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "empty values",
|
|
values: []string{},
|
|
allowed: []string{"admin"},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "nil values",
|
|
values: nil,
|
|
allowed: []string{"admin"},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "empty allowed",
|
|
values: []string{"admin"},
|
|
allowed: []string{},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "nil allowed",
|
|
values: []string{"admin"},
|
|
allowed: nil,
|
|
want: false,
|
|
},
|
|
{
|
|
name: "both empty",
|
|
values: []string{},
|
|
allowed: []string{},
|
|
want: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result := intersects(tc.values, tc.allowed)
|
|
if result != tc.want {
|
|
t.Errorf("intersects(%v, %v) = %v, want %v", tc.values, tc.allowed, result, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|