mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-28 11:30:15 +00:00
174 lines
4.7 KiB
Go
174 lines
4.7 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"testing"
|
|
)
|
|
|
|
func TestFetchURL(t *testing.T) {
|
|
os.Setenv("PULSE_AI_ALLOW_LOOPBACK", "true")
|
|
defer os.Unsetenv("PULSE_AI_ALLOW_LOOPBACK")
|
|
|
|
// Start a local test server
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprint(w, "Hello, world")
|
|
}))
|
|
defer ts.Close()
|
|
|
|
svc := NewService(nil, nil)
|
|
ctx := context.Background()
|
|
|
|
// Test successful fetch
|
|
result, err := svc.fetchURL(ctx, ts.URL)
|
|
if err != nil {
|
|
t.Fatalf("fetchURL failed: %v", err)
|
|
}
|
|
if !containsString(result, "Hello, world") {
|
|
t.Errorf("Expected 'Hello, world' in result, got: %s", result)
|
|
}
|
|
|
|
// Test blocked host (localhost)
|
|
_, err = svc.fetchURL(ctx, "http://localhost:8080")
|
|
if err == nil || !containsString(err.Error(), "blocked") {
|
|
t.Errorf("Expected blocked host error, got: %v", err)
|
|
}
|
|
|
|
// Test invalid URL
|
|
_, err = svc.fetchURL(ctx, "not-a-url")
|
|
if err == nil {
|
|
t.Error("Expected error for invalid URL")
|
|
}
|
|
|
|
// Test scheme check
|
|
_, err = svc.fetchURL(ctx, "ftp://example.com")
|
|
if err == nil || !containsString(err.Error(), "only http/https") {
|
|
t.Errorf("Expected scheme error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestParseAndValidateFetchURL(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
tests := []struct {
|
|
url string
|
|
wantErr bool
|
|
errSub string
|
|
}{
|
|
{"http://example.com", false, ""},
|
|
{"https://example.com/path", false, ""},
|
|
{" http://example.com ", false, ""},
|
|
{"", true, "url is required"},
|
|
{"http://localhost", true, "blocked"},
|
|
{"http://localhost.", true, "blocked"},
|
|
{"http://127.0.0.1", true, "blocked"},
|
|
{"http://[::1]", true, "blocked"},
|
|
{"ftp://example.com", true, "only http/https"},
|
|
{"http://user:pass@example.com", true, "credentials"},
|
|
{"http://example.com/#frag", true, "fragments"},
|
|
{"http://", true, "host"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.url, func(t *testing.T) {
|
|
_, err := parseAndValidateFetchURL(ctx, tt.url)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Fatalf("parseAndValidateFetchURL() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
if tt.wantErr && tt.errSub != "" && !containsString(err.Error(), tt.errSub) {
|
|
t.Errorf("error %v does not contain %q", err, tt.errSub)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsBlockedFetchIP(t *testing.T) {
|
|
tests := []struct {
|
|
ip string
|
|
blocked bool
|
|
}{
|
|
{"127.0.0.1", true},
|
|
{"::1", true},
|
|
{"0.0.0.0", true},
|
|
{"169.254.1.1", true},
|
|
{"192.168.1.1", true}, // Private IPs are blocked by default for security (SSRF prevention)
|
|
{"10.0.0.1", true}, // Private range 10.x.x.x blocked
|
|
{"172.16.0.1", true}, // Private range 172.16.x.x blocked
|
|
{"8.8.8.8", false}, // Global is allowed
|
|
{"224.0.0.1", true}, // Multicast
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
ip := net.ParseIP(tt.ip)
|
|
if got := isBlockedFetchIP(ip); got != tt.blocked {
|
|
t.Errorf("isBlockedFetchIP(%s) = %v, want %v", tt.ip, got, tt.blocked)
|
|
}
|
|
}
|
|
|
|
if !isBlockedFetchIP(nil) {
|
|
t.Error("nil IP should be blocked")
|
|
}
|
|
|
|
// Test that private IPs can be allowed via environment variable
|
|
os.Setenv("PULSE_AI_ALLOW_PRIVATE_IPS", "true")
|
|
defer os.Unsetenv("PULSE_AI_ALLOW_PRIVATE_IPS")
|
|
|
|
privateIP := net.ParseIP("192.168.1.1")
|
|
if isBlockedFetchIP(privateIP) {
|
|
t.Error("Private IP should be allowed when PULSE_AI_ALLOW_PRIVATE_IPS=true")
|
|
}
|
|
|
|
// Test that loopback IPs can be allowed via environment variable
|
|
os.Setenv("PULSE_AI_ALLOW_LOOPBACK", "true")
|
|
defer os.Unsetenv("PULSE_AI_ALLOW_LOOPBACK")
|
|
|
|
loopbackIP := net.ParseIP("127.0.0.1")
|
|
if isBlockedFetchIP(loopbackIP) {
|
|
t.Error("Loopback IP should be allowed when PULSE_AI_ALLOW_LOOPBACK=true")
|
|
}
|
|
}
|
|
|
|
func TestFetchURL_SizeLimit(t *testing.T) {
|
|
os.Setenv("PULSE_AI_ALLOW_LOOPBACK", "true")
|
|
defer os.Unsetenv("PULSE_AI_ALLOW_LOOPBACK")
|
|
|
|
// Server that returns 100KB of data
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
data := make([]byte, 100*1024)
|
|
for i := range data {
|
|
data[i] = 'a'
|
|
}
|
|
w.Write(data)
|
|
}))
|
|
defer ts.Close()
|
|
|
|
svc := NewService(nil, nil)
|
|
result, err := svc.fetchURL(context.Background(), ts.URL)
|
|
if err != nil {
|
|
t.Fatalf("fetchURL failed: %v", err)
|
|
}
|
|
if !containsString(result, "truncated at 64KB") {
|
|
t.Error("Expected result to be truncated")
|
|
}
|
|
}
|
|
|
|
func TestFetchURL_RedirectLimit(t *testing.T) {
|
|
os.Setenv("PULSE_AI_ALLOW_LOOPBACK", "true")
|
|
defer os.Unsetenv("PULSE_AI_ALLOW_LOOPBACK")
|
|
|
|
var ts *httptest.Server
|
|
ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Redirect(w, r, ts.URL, http.StatusFound)
|
|
}))
|
|
defer ts.Close()
|
|
|
|
svc := NewService(nil, nil)
|
|
_, err := svc.fetchURL(context.Background(), ts.URL)
|
|
if err == nil || !containsString(err.Error(), "too many redirects") {
|
|
t.Errorf("Expected redirect limit error, got: %v", err)
|
|
}
|
|
}
|