Pulse/internal/notifications/webhook_client_test.go

312 lines
10 KiB
Go

package notifications
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// createTestNotificationManager creates a NotificationManager with localhost allowlisted
// for testing purposes (since httptest servers run on localhost).
func createTestNotificationManager(t *testing.T) *NotificationManager {
t.Helper()
nm := &NotificationManager{
lastNotified: make(map[string]notificationRecord),
webhookRateLimits: make(map[string]*webhookRateLimit),
}
// Allowlist localhost for test servers
if err := nm.UpdateAllowedPrivateCIDRs("127.0.0.0/8"); err != nil {
t.Fatalf("failed to allowlist localhost: %v", err)
}
return nm
}
// TestSecureWebhookClientRedirectLimit verifies that the client stops following
// redirects after exceeding WebhookMaxRedirects.
func TestSecureWebhookClientRedirectLimit(t *testing.T) {
redirectCount := 0
// Create a server that always redirects to itself
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
redirectCount++
http.Redirect(w, r, r.URL.String()+"x", http.StatusFound)
}))
defer server.Close()
nm := createTestNotificationManager(t)
client := nm.createSecureWebhookClient(WebhookTimeout)
_, err := client.Get(server.URL)
if err == nil {
t.Fatal("expected error from too many redirects, got nil")
}
if !strings.Contains(err.Error(), "stopped after") {
t.Errorf("expected 'stopped after' error, got: %v", err)
}
// CheckRedirect is called when len(via) >= WebhookMaxRedirects.
// via includes previous requests, so with WebhookMaxRedirects=3:
// - Initial request (not in via yet)
// - Redirect 1: via has 1 entry (initial) - allowed
// - Redirect 2: via has 2 entries - allowed
// - Redirect 3: via has 3 entries - blocked (len(via) >= 3)
// Total server hits: initial + 2 successful redirects = 3
expectedRequests := WebhookMaxRedirects
if redirectCount != expectedRequests {
t.Errorf("expected %d requests, got %d", expectedRequests, redirectCount)
}
}
// TestSecureWebhookClientBlocksUnsafeRedirect verifies that redirects to
// localhost/private networks are blocked when not in allowlist.
func TestSecureWebhookClientBlocksUnsafeRedirect(t *testing.T) {
// Create a server that redirects to localhost (allowlist 127.0.0.1 only, not 127.0.0.2)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Redirect to a different localhost IP that's not in our allowlist
http.Redirect(w, r, "http://127.0.0.2:8080/evil", http.StatusFound)
}))
defer server.Close()
// Create manager with narrow allowlist (only the test server's IP)
nm := &NotificationManager{
lastNotified: make(map[string]notificationRecord),
webhookRateLimits: make(map[string]*webhookRateLimit),
}
// Allowlist only 127.0.0.1/32, so 127.0.0.2 will be blocked
if err := nm.UpdateAllowedPrivateCIDRs("127.0.0.1/32"); err != nil {
t.Fatalf("failed to set allowlist: %v", err)
}
client := nm.createSecureWebhookClient(WebhookTimeout)
_, err := client.Get(server.URL)
if err == nil {
t.Fatal("expected error from unsafe redirect, got nil")
}
if !strings.Contains(err.Error(), "private IP") {
t.Errorf("expected private IP validation error, got: %v", err)
}
}
// TestSecureWebhookClientBlocksPrivateNetworkRedirect verifies that redirects to
// private network IPs (10.x, 192.168.x, 172.16-31.x) are blocked.
func TestSecureWebhookClientBlocksPrivateNetworkRedirect(t *testing.T) {
privateIPs := []string{
"http://10.0.0.1/webhook",
"http://192.168.1.1/webhook",
"http://172.16.0.1/webhook",
}
for _, privateURL := range privateIPs {
t.Run(privateURL, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, privateURL, http.StatusFound)
}))
defer server.Close()
// Allowlist localhost for test server, but not private network ranges
nm := &NotificationManager{
lastNotified: make(map[string]notificationRecord),
webhookRateLimits: make(map[string]*webhookRateLimit),
}
if err := nm.UpdateAllowedPrivateCIDRs("127.0.0.0/8"); err != nil {
t.Fatalf("failed to set allowlist: %v", err)
}
client := nm.createSecureWebhookClient(WebhookTimeout)
_, err := client.Get(server.URL)
if err == nil {
t.Fatalf("expected error from redirect to %s, got nil", privateURL)
}
if !strings.Contains(err.Error(), "private IP") {
t.Errorf("expected private IP validation error for %s, got: %v", privateURL, err)
}
})
}
}
// TestSecureWebhookClientAllowsValidRedirects verifies that valid redirects
// to safe URLs are followed successfully.
func TestSecureWebhookClientAllowsValidRedirects(t *testing.T) {
// Create a chain of servers for redirect testing
// Final server returns 200 OK
finalServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
defer finalServer.Close()
// Middle server redirects to final
middleServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, finalServer.URL, http.StatusFound)
}))
defer middleServer.Close()
// First server redirects to middle
firstServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, middleServer.URL, http.StatusFound)
}))
defer firstServer.Close()
nm := createTestNotificationManager(t)
client := nm.createSecureWebhookClient(WebhookTimeout)
resp, err := client.Get(firstServer.URL)
if err != nil {
t.Fatalf("expected successful redirect chain, got error: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
// TestSecureWebhookClientBlocksLinkLocalRedirect verifies that redirects to
// link-local addresses (169.254.x.x) are blocked.
func TestSecureWebhookClientBlocksLinkLocalRedirect(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "http://169.254.169.254/latest/meta-data/", http.StatusFound)
}))
defer server.Close()
nm := createTestNotificationManager(t)
client := nm.createSecureWebhookClient(WebhookTimeout)
_, err := client.Get(server.URL)
if err == nil {
t.Fatal("expected error from redirect to link-local/metadata address, got nil")
}
if !strings.Contains(err.Error(), "link-local") {
t.Errorf("expected link-local validation error, got: %v", err)
}
}
func TestSecureWebhookClientDialContextBlocksPrivateIP(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
nm := &NotificationManager{
lastNotified: make(map[string]notificationRecord),
webhookRateLimits: make(map[string]*webhookRateLimit),
}
client := nm.createSecureWebhookClient(WebhookTimeout)
_, err := client.Get(server.URL)
if err == nil {
t.Fatal("expected error from blocked private IP, got nil")
}
if !strings.Contains(err.Error(), "blocked private IP") {
t.Fatalf("expected blocked private IP error, got: %v", err)
}
}
func TestSecureWebhookClientDialContextBlocksHostnameWithoutAllowlist(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Rewrite URL to use hostname to exercise DNS resolution path.
hostedURL := strings.Replace(server.URL, "127.0.0.1", "localhost", 1)
nm := &NotificationManager{
lastNotified: make(map[string]notificationRecord),
webhookRateLimits: make(map[string]*webhookRateLimit),
}
client := nm.createSecureWebhookClient(WebhookTimeout)
_, err := client.Get(hostedURL)
if err == nil {
t.Fatal("expected error from blocked hostname, got nil")
}
if !strings.Contains(err.Error(), "resolves to blocked private IPs") {
t.Fatalf("expected blocked hostname error, got: %v", err)
}
}
func TestSecureWebhookClientDialContextAllowsHostnameWithAllowlist(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Rewrite URL to use hostname to exercise DNS resolution path.
hostedURL := strings.Replace(server.URL, "127.0.0.1", "localhost", 1)
nm := &NotificationManager{
lastNotified: make(map[string]notificationRecord),
webhookRateLimits: make(map[string]*webhookRateLimit),
}
if err := nm.UpdateAllowedPrivateCIDRs("127.0.0.1/32"); err != nil {
t.Fatalf("failed to set allowlist: %v", err)
}
client := nm.createSecureWebhookClient(WebhookTimeout)
resp, err := client.Get(hostedURL)
if err != nil {
t.Fatalf("expected request to succeed with allowlist, got: %v", err)
}
resp.Body.Close()
}
func TestSecureWebhookClientDialContextAllowsPublicIP(t *testing.T) {
nm := &NotificationManager{
lastNotified: make(map[string]notificationRecord),
webhookRateLimits: make(map[string]*webhookRateLimit),
}
client := nm.createSecureWebhookClient(WebhookTimeout)
transport, ok := client.Transport.(*http.Transport)
if !ok || transport == nil {
t.Fatalf("expected transport to be *http.Transport")
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
// Use TEST-NET-3 address to exercise non-private IP branch.
_, err := transport.DialContext(ctx, "tcp", "203.0.113.1:80")
if err == nil {
t.Fatalf("expected dial to fail due to canceled context")
}
}
func TestSecureWebhookClientDialContextRejectsBadAddress(t *testing.T) {
nm := &NotificationManager{
lastNotified: make(map[string]notificationRecord),
webhookRateLimits: make(map[string]*webhookRateLimit),
}
client := nm.createSecureWebhookClient(WebhookTimeout)
transport, ok := client.Transport.(*http.Transport)
if !ok || transport == nil {
t.Fatalf("expected transport to be *http.Transport")
}
_, err := transport.DialContext(context.Background(), "tcp", "badaddress")
if err == nil {
t.Fatalf("expected dial to fail for invalid address")
}
}
func TestSecureWebhookClientDialContextLookupFailure(t *testing.T) {
nm := &NotificationManager{
lastNotified: make(map[string]notificationRecord),
webhookRateLimits: make(map[string]*webhookRateLimit),
}
client := nm.createSecureWebhookClient(WebhookTimeout)
transport, ok := client.Transport.(*http.Transport)
if !ok || transport == nil {
t.Fatalf("expected transport to be *http.Transport")
}
_, err := transport.DialContext(context.Background(), "tcp", "bad host:80")
if err == nil {
t.Fatalf("expected lookup failure for invalid hostname")
}
}