Pulse/internal/ai/service_tools_test.go
2026-02-04 13:54:32 +00:00

174 lines
4.7 KiB
Go

package ai
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
"testing"
)
func TestFetchURL(t *testing.T) {
os.Setenv("PULSE_AI_ALLOW_LOOPBACK", "true")
defer os.Unsetenv("PULSE_AI_ALLOW_LOOPBACK")
// Start a local test server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "Hello, world")
}))
defer ts.Close()
svc := NewService(nil, nil)
ctx := context.Background()
// Test successful fetch
result, err := svc.fetchURL(ctx, ts.URL)
if err != nil {
t.Fatalf("fetchURL failed: %v", err)
}
if !containsString(result, "Hello, world") {
t.Errorf("Expected 'Hello, world' in result, got: %s", result)
}
// Test blocked host (localhost)
_, err = svc.fetchURL(ctx, "http://localhost:8080")
if err == nil || !containsString(err.Error(), "blocked") {
t.Errorf("Expected blocked host error, got: %v", err)
}
// Test invalid URL
_, err = svc.fetchURL(ctx, "not-a-url")
if err == nil {
t.Error("Expected error for invalid URL")
}
// Test scheme check
_, err = svc.fetchURL(ctx, "ftp://example.com")
if err == nil || !containsString(err.Error(), "only http/https") {
t.Errorf("Expected scheme error, got: %v", err)
}
}
// TestParseAndValidateFetchURL is a table test for parseAndValidateFetchURL.
// Plain http/https URLs (surrounding whitespace included) are accepted. Empty
// input, loopback hosts, non-HTTP schemes, embedded credentials, fragments and
// a missing host are rejected with an error containing errSub.
func TestParseAndValidateFetchURL(t *testing.T) {
	// The loopback cases expect "blocked". If validation consults the same
	// environment switches as isBlockedFetchIP (NOTE(review): presumably it
	// does for literal IPs; confirm), a developer shell that sets them would
	// make this test fail. Clear both for the duration of the test; t.Setenv
	// records the original values and restores them during cleanup.
	for _, key := range []string{"PULSE_AI_ALLOW_LOOPBACK", "PULSE_AI_ALLOW_PRIVATE_IPS"} {
		t.Setenv(key, "")
		if err := os.Unsetenv(key); err != nil {
			t.Fatalf("unsetting %s: %v", key, err)
		}
	}

	ctx := context.Background()
	tests := []struct {
		url     string
		wantErr bool
		errSub  string // substring the error must contain; "" means any error
	}{
		{"http://example.com", false, ""},
		{"https://example.com/path", false, ""},
		{" http://example.com ", false, ""},
		{"", true, "url is required"},
		{"http://localhost", true, "blocked"},
		{"http://localhost.", true, "blocked"},
		{"http://127.0.0.1", true, "blocked"},
		{"http://[::1]", true, "blocked"},
		{"ftp://example.com", true, "only http/https"},
		{"http://user:pass@example.com", true, "credentials"},
		{"http://example.com/#frag", true, "fragments"},
		{"http://", true, "host"},
	}
	for _, tt := range tests {
		t.Run(tt.url, func(t *testing.T) {
			_, err := parseAndValidateFetchURL(ctx, tt.url)
			// A mismatch stops the subtest here, so err is non-nil below
			// whenever tt.wantErr is true.
			if (err != nil) != tt.wantErr {
				t.Fatalf("parseAndValidateFetchURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr)
			}
			if tt.wantErr && tt.errSub != "" && !containsString(err.Error(), tt.errSub) {
				t.Errorf("parseAndValidateFetchURL(%q) error %v does not contain %q", tt.url, err, tt.errSub)
			}
		})
	}
}
// TestIsBlockedFetchIP checks the SSRF address filter. By default it must
// block loopback, unspecified, link-local, private and multicast addresses
// (and a nil IP) while allowing a public address. Each environment override
// must then unblock only its own address class.
func TestIsBlockedFetchIP(t *testing.T) {
	// The default-policy table below is only meaningful when neither override
	// is set, so clear both. t.Setenv records the original values and restores
	// them during cleanup; os.Unsetenv then removes the variable entirely for
	// the duration of the test.
	for _, key := range []string{"PULSE_AI_ALLOW_PRIVATE_IPS", "PULSE_AI_ALLOW_LOOPBACK"} {
		t.Setenv(key, "")
		if err := os.Unsetenv(key); err != nil {
			t.Fatalf("unsetting %s: %v", key, err)
		}
	}

	tests := []struct {
		ip      string
		blocked bool
	}{
		{"127.0.0.1", true},
		{"::1", true},
		{"0.0.0.0", true},
		{"169.254.1.1", true},
		{"192.168.1.1", true}, // Private IPs are blocked by default for security (SSRF prevention)
		{"10.0.0.1", true},    // Private range 10.x.x.x blocked
		{"172.16.0.1", true},  // Private range 172.16.x.x blocked
		{"8.8.8.8", false},    // Global is allowed
		{"224.0.0.1", true},   // Multicast
	}
	for _, tt := range tests {
		if got := isBlockedFetchIP(net.ParseIP(tt.ip)); got != tt.blocked {
			t.Errorf("isBlockedFetchIP(%s) = %v, want %v", tt.ip, got, tt.blocked)
		}
	}
	if !isBlockedFetchIP(nil) {
		t.Error("nil IP should be blocked")
	}

	// Each override runs in its own subtest, so one switch cannot mask a
	// failure of the other; the subtest's t.Setenv is undone when it returns.
	t.Run("allow private IPs", func(t *testing.T) {
		t.Setenv("PULSE_AI_ALLOW_PRIVATE_IPS", "true")
		if isBlockedFetchIP(net.ParseIP("192.168.1.1")) {
			t.Error("Private IP should be allowed when PULSE_AI_ALLOW_PRIVATE_IPS=true")
		}
	})
	t.Run("allow loopback", func(t *testing.T) {
		t.Setenv("PULSE_AI_ALLOW_LOOPBACK", "true")
		if isBlockedFetchIP(net.ParseIP("127.0.0.1")) {
			t.Error("Loopback IP should be allowed when PULSE_AI_ALLOW_LOOPBACK=true")
		}
	})
}
func TestFetchURL_SizeLimit(t *testing.T) {
os.Setenv("PULSE_AI_ALLOW_LOOPBACK", "true")
defer os.Unsetenv("PULSE_AI_ALLOW_LOOPBACK")
// Server that returns 100KB of data
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data := make([]byte, 100*1024)
for i := range data {
data[i] = 'a'
}
w.Write(data)
}))
defer ts.Close()
svc := NewService(nil, nil)
result, err := svc.fetchURL(context.Background(), ts.URL)
if err != nil {
t.Fatalf("fetchURL failed: %v", err)
}
if !containsString(result, "truncated at 64KB") {
t.Error("Expected result to be truncated")
}
}
func TestFetchURL_RedirectLimit(t *testing.T) {
os.Setenv("PULSE_AI_ALLOW_LOOPBACK", "true")
defer os.Unsetenv("PULSE_AI_ALLOW_LOOPBACK")
var ts *httptest.Server
ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, ts.URL, http.StatusFound)
}))
defer ts.Close()
svc := NewService(nil, nil)
_, err := svc.fetchURL(context.Background(), ts.URL)
if err == nil || !containsString(err.Error(), "too many redirects") {
t.Errorf("Expected redirect limit error, got: %v", err)
}
}