mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-28 11:30:15 +00:00
312 lines
10 KiB
Go
312 lines
10 KiB
Go
package notifications
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
// createTestNotificationManager creates a NotificationManager with localhost allowlisted
|
|
// for testing purposes (since httptest servers run on localhost).
|
|
func createTestNotificationManager(t *testing.T) *NotificationManager {
|
|
t.Helper()
|
|
nm := &NotificationManager{
|
|
lastNotified: make(map[string]notificationRecord),
|
|
webhookRateLimits: make(map[string]*webhookRateLimit),
|
|
}
|
|
// Allowlist localhost for test servers
|
|
if err := nm.UpdateAllowedPrivateCIDRs("127.0.0.0/8"); err != nil {
|
|
t.Fatalf("failed to allowlist localhost: %v", err)
|
|
}
|
|
return nm
|
|
}
|
|
|
|
// TestSecureWebhookClientRedirectLimit verifies that the client stops following
|
|
// redirects after exceeding WebhookMaxRedirects.
|
|
func TestSecureWebhookClientRedirectLimit(t *testing.T) {
|
|
redirectCount := 0
|
|
// Create a server that always redirects to itself
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
redirectCount++
|
|
http.Redirect(w, r, r.URL.String()+"x", http.StatusFound)
|
|
}))
|
|
defer server.Close()
|
|
|
|
nm := createTestNotificationManager(t)
|
|
|
|
client := nm.createSecureWebhookClient(WebhookTimeout)
|
|
_, err := client.Get(server.URL)
|
|
|
|
if err == nil {
|
|
t.Fatal("expected error from too many redirects, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "stopped after") {
|
|
t.Errorf("expected 'stopped after' error, got: %v", err)
|
|
}
|
|
// CheckRedirect is called when len(via) >= WebhookMaxRedirects.
|
|
// via includes previous requests, so with WebhookMaxRedirects=3:
|
|
// - Initial request (not in via yet)
|
|
// - Redirect 1: via has 1 entry (initial) - allowed
|
|
// - Redirect 2: via has 2 entries - allowed
|
|
// - Redirect 3: via has 3 entries - blocked (len(via) >= 3)
|
|
// Total server hits: initial + 2 successful redirects = 3
|
|
expectedRequests := WebhookMaxRedirects
|
|
if redirectCount != expectedRequests {
|
|
t.Errorf("expected %d requests, got %d", expectedRequests, redirectCount)
|
|
}
|
|
}
|
|
|
|
// TestSecureWebhookClientBlocksUnsafeRedirect verifies that redirects to
|
|
// localhost/private networks are blocked when not in allowlist.
|
|
func TestSecureWebhookClientBlocksUnsafeRedirect(t *testing.T) {
|
|
// Create a server that redirects to localhost (allowlist 127.0.0.1 only, not 127.0.0.2)
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Redirect to a different localhost IP that's not in our allowlist
|
|
http.Redirect(w, r, "http://127.0.0.2:8080/evil", http.StatusFound)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create manager with narrow allowlist (only the test server's IP)
|
|
nm := &NotificationManager{
|
|
lastNotified: make(map[string]notificationRecord),
|
|
webhookRateLimits: make(map[string]*webhookRateLimit),
|
|
}
|
|
// Allowlist only 127.0.0.1/32, so 127.0.0.2 will be blocked
|
|
if err := nm.UpdateAllowedPrivateCIDRs("127.0.0.1/32"); err != nil {
|
|
t.Fatalf("failed to set allowlist: %v", err)
|
|
}
|
|
|
|
client := nm.createSecureWebhookClient(WebhookTimeout)
|
|
_, err := client.Get(server.URL)
|
|
|
|
if err == nil {
|
|
t.Fatal("expected error from unsafe redirect, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "private IP") {
|
|
t.Errorf("expected private IP validation error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestSecureWebhookClientBlocksPrivateNetworkRedirect verifies that redirects to
|
|
// private network IPs (10.x, 192.168.x, 172.16-31.x) are blocked.
|
|
func TestSecureWebhookClientBlocksPrivateNetworkRedirect(t *testing.T) {
|
|
privateIPs := []string{
|
|
"http://10.0.0.1/webhook",
|
|
"http://192.168.1.1/webhook",
|
|
"http://172.16.0.1/webhook",
|
|
}
|
|
|
|
for _, privateURL := range privateIPs {
|
|
t.Run(privateURL, func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Redirect(w, r, privateURL, http.StatusFound)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Allowlist localhost for test server, but not private network ranges
|
|
nm := &NotificationManager{
|
|
lastNotified: make(map[string]notificationRecord),
|
|
webhookRateLimits: make(map[string]*webhookRateLimit),
|
|
}
|
|
if err := nm.UpdateAllowedPrivateCIDRs("127.0.0.0/8"); err != nil {
|
|
t.Fatalf("failed to set allowlist: %v", err)
|
|
}
|
|
|
|
client := nm.createSecureWebhookClient(WebhookTimeout)
|
|
_, err := client.Get(server.URL)
|
|
|
|
if err == nil {
|
|
t.Fatalf("expected error from redirect to %s, got nil", privateURL)
|
|
}
|
|
if !strings.Contains(err.Error(), "private IP") {
|
|
t.Errorf("expected private IP validation error for %s, got: %v", privateURL, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestSecureWebhookClientAllowsValidRedirects verifies that valid redirects
|
|
// to safe URLs are followed successfully.
|
|
func TestSecureWebhookClientAllowsValidRedirects(t *testing.T) {
|
|
// Create a chain of servers for redirect testing
|
|
// Final server returns 200 OK
|
|
finalServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("success"))
|
|
}))
|
|
defer finalServer.Close()
|
|
|
|
// Middle server redirects to final
|
|
middleServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Redirect(w, r, finalServer.URL, http.StatusFound)
|
|
}))
|
|
defer middleServer.Close()
|
|
|
|
// First server redirects to middle
|
|
firstServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Redirect(w, r, middleServer.URL, http.StatusFound)
|
|
}))
|
|
defer firstServer.Close()
|
|
|
|
nm := createTestNotificationManager(t)
|
|
|
|
client := nm.createSecureWebhookClient(WebhookTimeout)
|
|
resp, err := client.Get(firstServer.URL)
|
|
|
|
if err != nil {
|
|
t.Fatalf("expected successful redirect chain, got error: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
// TestSecureWebhookClientBlocksLinkLocalRedirect verifies that redirects to
|
|
// link-local addresses (169.254.x.x) are blocked.
|
|
func TestSecureWebhookClientBlocksLinkLocalRedirect(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Redirect(w, r, "http://169.254.169.254/latest/meta-data/", http.StatusFound)
|
|
}))
|
|
defer server.Close()
|
|
|
|
nm := createTestNotificationManager(t)
|
|
|
|
client := nm.createSecureWebhookClient(WebhookTimeout)
|
|
_, err := client.Get(server.URL)
|
|
|
|
if err == nil {
|
|
t.Fatal("expected error from redirect to link-local/metadata address, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "link-local") {
|
|
t.Errorf("expected link-local validation error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSecureWebhookClientDialContextBlocksPrivateIP(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer server.Close()
|
|
|
|
nm := &NotificationManager{
|
|
lastNotified: make(map[string]notificationRecord),
|
|
webhookRateLimits: make(map[string]*webhookRateLimit),
|
|
}
|
|
client := nm.createSecureWebhookClient(WebhookTimeout)
|
|
|
|
_, err := client.Get(server.URL)
|
|
if err == nil {
|
|
t.Fatal("expected error from blocked private IP, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "blocked private IP") {
|
|
t.Fatalf("expected blocked private IP error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSecureWebhookClientDialContextBlocksHostnameWithoutAllowlist(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Rewrite URL to use hostname to exercise DNS resolution path.
|
|
hostedURL := strings.Replace(server.URL, "127.0.0.1", "localhost", 1)
|
|
|
|
nm := &NotificationManager{
|
|
lastNotified: make(map[string]notificationRecord),
|
|
webhookRateLimits: make(map[string]*webhookRateLimit),
|
|
}
|
|
client := nm.createSecureWebhookClient(WebhookTimeout)
|
|
|
|
_, err := client.Get(hostedURL)
|
|
if err == nil {
|
|
t.Fatal("expected error from blocked hostname, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "resolves to blocked private IPs") {
|
|
t.Fatalf("expected blocked hostname error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSecureWebhookClientDialContextAllowsHostnameWithAllowlist(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Rewrite URL to use hostname to exercise DNS resolution path.
|
|
hostedURL := strings.Replace(server.URL, "127.0.0.1", "localhost", 1)
|
|
|
|
nm := &NotificationManager{
|
|
lastNotified: make(map[string]notificationRecord),
|
|
webhookRateLimits: make(map[string]*webhookRateLimit),
|
|
}
|
|
if err := nm.UpdateAllowedPrivateCIDRs("127.0.0.1/32"); err != nil {
|
|
t.Fatalf("failed to set allowlist: %v", err)
|
|
}
|
|
|
|
client := nm.createSecureWebhookClient(WebhookTimeout)
|
|
resp, err := client.Get(hostedURL)
|
|
if err != nil {
|
|
t.Fatalf("expected request to succeed with allowlist, got: %v", err)
|
|
}
|
|
resp.Body.Close()
|
|
}
|
|
|
|
func TestSecureWebhookClientDialContextAllowsPublicIP(t *testing.T) {
|
|
nm := &NotificationManager{
|
|
lastNotified: make(map[string]notificationRecord),
|
|
webhookRateLimits: make(map[string]*webhookRateLimit),
|
|
}
|
|
client := nm.createSecureWebhookClient(WebhookTimeout)
|
|
transport, ok := client.Transport.(*http.Transport)
|
|
if !ok || transport == nil {
|
|
t.Fatalf("expected transport to be *http.Transport")
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
// Use TEST-NET-3 address to exercise non-private IP branch.
|
|
_, err := transport.DialContext(ctx, "tcp", "203.0.113.1:80")
|
|
if err == nil {
|
|
t.Fatalf("expected dial to fail due to canceled context")
|
|
}
|
|
}
|
|
|
|
func TestSecureWebhookClientDialContextRejectsBadAddress(t *testing.T) {
|
|
nm := &NotificationManager{
|
|
lastNotified: make(map[string]notificationRecord),
|
|
webhookRateLimits: make(map[string]*webhookRateLimit),
|
|
}
|
|
client := nm.createSecureWebhookClient(WebhookTimeout)
|
|
transport, ok := client.Transport.(*http.Transport)
|
|
if !ok || transport == nil {
|
|
t.Fatalf("expected transport to be *http.Transport")
|
|
}
|
|
|
|
_, err := transport.DialContext(context.Background(), "tcp", "badaddress")
|
|
if err == nil {
|
|
t.Fatalf("expected dial to fail for invalid address")
|
|
}
|
|
}
|
|
|
|
func TestSecureWebhookClientDialContextLookupFailure(t *testing.T) {
|
|
nm := &NotificationManager{
|
|
lastNotified: make(map[string]notificationRecord),
|
|
webhookRateLimits: make(map[string]*webhookRateLimit),
|
|
}
|
|
client := nm.createSecureWebhookClient(WebhookTimeout)
|
|
transport, ok := client.Transport.(*http.Transport)
|
|
if !ok || transport == nil {
|
|
t.Fatalf("expected transport to be *http.Transport")
|
|
}
|
|
|
|
_, err := transport.DialContext(context.Background(), "tcp", "bad host:80")
|
|
if err == nil {
|
|
t.Fatalf("expected lookup failure for invalid hostname")
|
|
}
|
|
}
|