Pulse/pkg/discovery/discovery_test.go
rcourtman 7c00055047 feat: unify and improve Proxmox discovery/scanning architecture
Replaced inconsistent per-product detection logic with a unified probe
architecture using confidence scoring and product-specific matchers.

Key improvements:
- PBS detection now inspects TLS certs and auth headers (401/403) and
  probes PBS-specific endpoints (/api2/json/status, /config/datastore),
  fixing false negatives for self-signed and auth-protected servers
- PMG detection uses header analysis first, then conditional endpoint
  probing, working consistently regardless of port
- Single unified probeProxmoxService() replaces separate checkPort8006()
  and checkServer() code paths, eliminating duplication
- Confidence scoring (0.0-1.0+) with evidence tracking for debugging
- Consolidated hostname resolution and version handling

Technical changes:
- Added ProxmoxProbeResult with structured evidence and scoring
- Added product matchers: applyPVEHeuristics, applyPMGHeuristics,
  applyPBSHeuristics
- Removed legacy methods: checkPort8006, checkServer, isPMGServer,
  detectProductFromEndpoint, and duplicate hostname helpers
- Updated all tests to use new unified probe architecture
- Added probe_test_helpers.go for test access to internal methods

All tests passing. Fixes PBS detection issues and improves consistency
across PVE/PMG/PBS discovery.
2025-10-21 13:09:41 +00:00

429 lines
11 KiB
Go

package discovery
import (
"context"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/rcourtman/pulse-go-rewrite/pkg/discovery/envdetect"
)
// newTestScanner builds a Scanner that issues requests through client.
// It starts from envdetect.DefaultScanPolicy with the dial timeout lowered to
// one second. The attached environment profile carries that same policy and
// an empty metadata map.
func newTestScanner(client *http.Client) *Scanner {
	p := envdetect.DefaultScanPolicy()
	p.DialTimeout = 1 * time.Second

	return &Scanner{
		policy: p,
		profile: &envdetect.EnvironmentProfile{
			Policy:   p,
			Metadata: make(map[string]string),
		},
		httpClient: client,
	}
}
// TestInferTypeFromMetadata table-tests inferTypeFromMetadata against header,
// realm and cookie strings for each Proxmox product. It also covers a set of
// strings with no Proxmox markers and a compact "ProxmoxMailGateway" spelling.
func TestInferTypeFromMetadata(t *testing.T) {
	t.Parallel()

	type metadataCase struct {
		label    string
		input    []string
		expected string
	}

	cases := []metadataCase{
		{"detects PMG from auth header", []string{`PMGAuth realm="Proxmox Mail Gateway"`, "pmgproxy/4.0"}, "pmg"},
		{"detects PVE from realm string", []string{`PVEAuth realm="Proxmox Virtual Environment"`, "pve-api-daemon/3.0"}, "pve"},
		{"detects PBS from cookie", []string{"PBS", "PBSCookie=abc123", `PBSAuth realm="Proxmox Backup Server"`}, "pbs"},
		{"returns empty when no markers", []string{"Custom Certificate", "Example Corp"}, ""},
		{"tolerates compact strings", []string{"ProxmoxMailGateway"}, "pmg"},
	}

	for _, c := range cases {
		c := c // per-iteration copy for closures on pre-1.22 toolchains
		t.Run(c.label, func(t *testing.T) {
			got := inferTypeFromMetadata(c.input...)
			if got == c.expected {
				return
			}
			t.Fatalf("inferTypeFromMetadata(%v) = %q, want %q", c.input, got, c.expected)
		})
	}
}
// TestInferTypeFromCertificate checks two cases. A peer certificate whose
// subject names Proxmox Mail Gateway must be classified as "pmg". A
// connection state without peer certificates must yield an empty result.
func TestInferTypeFromCertificate(t *testing.T) {
	t.Parallel()

	pmgCert := &x509.Certificate{
		Subject: pkix.Name{
			CommonName:         "Proxmox Mail Gateway",
			Organization:       []string{"Proxmox"},
			OrganizationalUnit: []string{"PMG"},
		},
	}
	withCert := tls.ConnectionState{PeerCertificates: []*x509.Certificate{pmgCert}}

	const wantPMG = "pmg"
	if got := inferTypeFromCertificate(withCert); got != wantPMG {
		t.Fatalf("inferTypeFromCertificate returned %q, want %q", got, wantPMG)
	}

	var noCerts tls.ConnectionState
	if got := inferTypeFromCertificate(noCerts); got != "" {
		t.Fatalf("expected empty result for missing certificates, got %q", got)
	}
}
// TestDetectProductFromEndpoint exercises Scanner.ProbeAPIEndpoint against a
// TLS test server. The name predates the removal of detectProductFromEndpoint.
//
// The server tags the statistics/mail path as Proxmox Mail Gateway and the
// version path as Proxmox Backup Server; the test expects the matching
// product guesses. An unknown path must produce an empty guess with a 404
// status. It also checks that the probe actually reached the server.
func TestDetectProductFromEndpoint(t *testing.T) {
	t.Parallel()

	// The handler runs on server goroutines while the test goroutine reads
	// the recorded paths afterwards, so access must be synchronized to stay
	// clean under the race detector.
	var (
		mu           sync.Mutex
		requestPaths []string
	)
	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		requestPaths = append(requestPaths, r.URL.Path)
		mu.Unlock()

		switch {
		case strings.Contains(r.URL.Path, "statistics/mail"):
			w.Header().Set("Proxmox-Product", "Proxmox Mail Gateway")
			w.WriteHeader(http.StatusOK)
		case strings.Contains(r.URL.Path, "api2/json/version"):
			w.Header().Set("Proxmox-Product", "Proxmox Backup Server")
			w.WriteHeader(http.StatusOK)
		case strings.Contains(r.URL.Path, "mail/queue"):
			w.WriteHeader(http.StatusOK)
		default:
			http.NotFound(w, r)
		}
	}))
	defer ts.Close()

	scanner := newTestScanner(ts.Client())
	address := strings.TrimPrefix(ts.URL, "https://")
	ctx := context.Background()

	finding := scanner.ProbeAPIEndpoint(ctx, address, "api2/json/statistics/mail")
	if finding.ProductGuess != ProductPMG {
		t.Fatalf("ProbeAPIEndpoint returned %q, want %q", finding.ProductGuess, ProductPMG)
	}

	versionFinding := scanner.ProbeAPIEndpoint(ctx, address, "api2/json/version")
	if versionFinding.ProductGuess != ProductPBS {
		t.Fatalf("ProbeAPIEndpoint returned %q, want %q", versionFinding.ProductGuess, ProductPBS)
	}

	unknownFinding := scanner.ProbeAPIEndpoint(ctx, address, "api2/json/unknown/path")
	if unknownFinding.ProductGuess != "" || unknownFinding.Status != http.StatusNotFound {
		t.Fatalf("expected empty result for unknown endpoint, got %+v", unknownFinding)
	}

	mu.Lock()
	requestCount := len(requestPaths)
	mu.Unlock()
	if requestCount == 0 {
		t.Fatalf("expected ProbeAPIEndpoint to perform requests")
	}
}
// TestIsPMGServer checks ProbeProxmoxService against two servers.
//
// The first sends PMG Proxmox-Product and WWW-Authenticate headers on every
// response and answers 200 on the statistics/mail and version paths; it must
// be reported as a positive PMG match. The second answers 404 everywhere and
// must not be reported as positive.
func TestIsPMGServer(t *testing.T) {
	t.Parallel()

	// hostPort splits a test server's listener address into the host string
	// and integer port that ProbeProxmoxService takes.
	hostPort := func(srv *httptest.Server) (string, int) {
		t.Helper()
		host, portStr, err := net.SplitHostPort(srv.Listener.Addr().String())
		if err != nil {
			t.Fatalf("SplitHostPort: %v", err)
		}
		port, err := strconv.Atoi(portStr)
		if err != nil {
			t.Fatalf("strconv.Atoi: %v", err)
		}
		return host, port
	}

	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Proxmox-Product", "Proxmox Mail Gateway")
		w.Header().Set("WWW-Authenticate", `PMGAuth realm="Proxmox Mail Gateway"`)
		if strings.Contains(r.URL.Path, "statistics/mail") ||
			strings.Contains(r.URL.Path, "api2/json/version") {
			w.WriteHeader(http.StatusOK)
			return
		}
		http.NotFound(w, r)
	}))
	defer ts.Close()

	scanner := newTestScanner(ts.Client())
	host, port := hostPort(ts)

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	probe := scanner.ProbeProxmoxService(ctx, host, port)
	if probe == nil || !probe.Positive || probe.PrimaryProduct != ProductPMG {
		t.Fatalf("expected PMG detection to succeed, got %+v", probe)
	}

	tsNoMatch := httptest.NewTLSServer(http.NotFoundHandler())
	defer tsNoMatch.Close()
	scanner.httpClient = tsNoMatch.Client()
	host, port = hostPort(tsNoMatch)

	// Give the negative probe its own deadline. Reusing the first probe's
	// context could let it expire first, making the assertion below pass
	// without the probe ever inspecting the server.
	noMatchCtx, noMatchCancel := context.WithTimeout(context.Background(), time.Second)
	defer noMatchCancel()
	probe = scanner.ProbeProxmoxService(noMatchCtx, host, port)
	if probe != nil && probe.Positive {
		t.Fatalf("expected PMG detection to fail for endpoints without markers")
	}
}
// TestCheckServerRetrievesVersion probes a server that serves only
// /api2/json/version. That endpoint returns a JSON version payload and sets a
// PBSCookie. The probe must be positive, classify the server as PBS, and
// report the version and release from the body.
func TestCheckServerRetrievesVersion(t *testing.T) {
	t.Parallel()

	const responseVersion = `{"data":{"version":"2.4.1","release":"1"}}`
	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/api2/json/version" {
			w.Header().Set("Content-Type", "application/json")
			http.SetCookie(w, &http.Cookie{Name: "PBSCookie", Value: "abc"})
			w.WriteHeader(http.StatusOK)
			// A failed write would surface as a probe failure on the client side.
			_, _ = w.Write([]byte(responseVersion))
			return
		}
		http.NotFound(w, r)
	}))
	defer ts.Close()

	host, portStr, err := net.SplitHostPort(ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("SplitHostPort: %v", err)
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		t.Fatalf("strconv.Atoi: %v", err)
	}

	scanner := newTestScanner(ts.Client())
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	probe := scanner.ProbeProxmoxService(ctx, host, port)
	if probe == nil || !probe.Positive {
		// The probe may be non-nil but negative; print it rather than
		// claiming it was nil.
		t.Fatalf("expected a positive ProbeProxmoxService result, got %+v", probe)
	}
	if probe.PrimaryProduct != ProductPBS {
		t.Fatalf("expected product pbs, got %q", probe.PrimaryProduct)
	}
	if probe.Version != "2.4.1" {
		t.Fatalf("expected version 2.4.1, got %q", probe.Version)
	}
	if probe.Release != "1" {
		t.Fatalf("expected release 1, got %q", probe.Release)
	}
}
// startTLSServerOn starts a TLS httptest server for handler on the specific
// address addr (for example "127.0.0.1:8006"). It is meant for tests that need
// a fixed port. If addr cannot be bound, the test is skipped rather than
// failed. The server is closed automatically via t.Cleanup.
func startTLSServerOn(t *testing.T, addr string, handler http.Handler) *httptest.Server {
	t.Helper()

	// Bind before creating the server so a busy port skips the test without
	// leaving any half-built server state behind.
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		t.Skipf("port %s unavailable: %v", addr, err)
	}

	srv := httptest.NewUnstartedServer(handler)
	// NewUnstartedServer has already opened its own ephemeral listener. Close
	// it before swapping in ours, or it leaks for the rest of the test run.
	// The close error is irrelevant: that listener is never used.
	_ = srv.Listener.Close()
	srv.Listener = ln
	srv.StartTLS()
	t.Cleanup(srv.Close)
	return srv
}
// TestCheckServerHandlesUnauthorized runs a server on the fixed address
// 127.0.0.1:9008 that answers every request with 401 and a PVEAuth challenge.
// ProbeProxmoxService must still report it as a positive PVE match, with the
// version left as "Unknown".
func TestCheckServerHandlesUnauthorized(t *testing.T) {
	t.Parallel()

	const (
		host = "127.0.0.1"
		port = 9008
	)

	startTLSServerOn(t, "127.0.0.1:9008", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("WWW-Authenticate", `PVEAuth realm="Proxmox Virtual Environment"`)
		w.WriteHeader(http.StatusUnauthorized)
	}))

	// The server presents the httptest self-signed certificate, so skip
	// verification.
	transport := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
	scanner := newTestScanner(&http.Client{Transport: transport, Timeout: 500 * time.Millisecond})

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	probe := scanner.ProbeProxmoxService(ctx, host, port)
	switch {
	case probe == nil || !probe.Positive:
		t.Fatalf("expected server discovery despite unauthorized response: %+v", probe)
	case probe.PrimaryProduct != ProductPVE:
		t.Fatalf("expected product pve, got %q", probe.PrimaryProduct)
	case probe.Version != "Unknown":
		t.Fatalf("expected version Unknown, got %q", probe.Version)
	}
}
// TestDiscoverServersWithCallback scans 127.0.0.0/29 with three listeners on
// 127.0.0.1:
//   - :8006, a TLS server reporting Proxmox Virtual Environment 8.1;
//   - :8007, a TLS server reporting Proxmox Backup Server 2.4.1;
//   - :9009, a plain TCP listener that drops every connection.
//
// It expects exactly the PVE and PBS servers in the result, with their
// versions, and at least two callback invocations. Busy fixed ports skip the
// test rather than fail it.
func TestDiscoverServersWithCallback(t *testing.T) {
	t.Parallel()
	const subnet = "127.0.0.0/29"

	// The TCP-only listener accepts and immediately closes connections,
	// standing in for an open port that never completes a TLS handshake.
	noTLSListener, err := net.Listen("tcp", "127.0.0.1:9009")
	if err != nil {
		// Match startTLSServerOn: an occupied fixed port is an environment
		// problem, not a test failure.
		t.Skipf("port 127.0.0.1:9009 unavailable: %v", err)
	}
	acceptDone := make(chan struct{})
	go func() {
		defer close(acceptDone)
		for {
			conn, err := noTLSListener.Accept()
			if err != nil {
				return // listener closed during cleanup
			}
			_ = conn.Close()
		}
	}()
	t.Cleanup(func() {
		_ = noTLSListener.Close()
		<-acceptDone // don't let the accept goroutine outlive the test
	})

	pveHandler := http.NewServeMux()
	pveHandler.HandleFunc("/api2/json/version", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("Proxmox-Product", "Proxmox Virtual Environment")
		_ = json.NewEncoder(w).Encode(map[string]any{
			"data": map[string]string{
				"version": "8.1",
				"release": "1",
			},
		})
	})
	pbsHandler := http.NewServeMux()
	pbsHandler.HandleFunc("/api2/json/version", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("Proxmox-Product", "Proxmox Backup Server")
		_ = json.NewEncoder(w).Encode(map[string]any{
			"data": map[string]string{
				"version": "2.4.1",
				"release": "2",
			},
		})
	})
	startTLSServerOn(t, "127.0.0.1:8006", pveHandler)
	startTLSServerOn(t, "127.0.0.1:8007", pbsHandler)

	scanner := newTestScanner(&http.Client{
		Timeout:   500 * time.Millisecond,
		Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
	})
	scanner.policy.MaxConcurrent = 4
	scanner.policy.DialTimeout = 200 * time.Millisecond
	scanner.policy.HTTPTimeout = 500 * time.Millisecond
	if scanner.profile != nil {
		// Keep the profile's copy of the policy in sync with the overrides above.
		scanner.profile.Policy = scanner.policy
	}

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	// Probe the TCP-only port directly. It accepts connections but closes
	// them at once, so it must not be reported as a Proxmox service.
	if probe := scanner.ProbeProxmoxService(ctx, "127.0.0.1", 9009); probe != nil && probe.Positive {
		t.Fatalf("expected ProbeProxmoxService to ignore TCP-only host, got %+v", probe)
	}

	// The callback may be invoked from scanner goroutines, so guard the slice.
	var mu sync.Mutex
	var callbacks []DiscoveredServer
	result, err := scanner.DiscoverServersWithCallback(ctx, subnet, func(server DiscoveredServer, _ string) {
		mu.Lock()
		callbacks = append(callbacks, server)
		mu.Unlock()
	})
	if err != nil {
		t.Fatalf("DiscoverServersWithCallback returned error: %v", err)
	}
	if result == nil {
		t.Fatalf("DiscoverServersWithCallback returned nil result without error")
	}
	if len(result.Servers) != 2 {
		t.Fatalf("expected 2 servers, got %d: %+v", len(result.Servers), result.Servers)
	}

	seen := make(map[string]DiscoveredServer, len(result.Servers))
	for _, server := range result.Servers {
		seen[server.Type] = server
	}

	pve, ok := seen["pve"]
	if !ok {
		t.Fatalf("expected to discover pve server")
	}
	if pve.Version != "8.1" {
		t.Fatalf("expected pve version 8.1, got %q", pve.Version)
	}

	pbs, ok := seen["pbs"]
	if !ok {
		t.Fatalf("expected to discover pbs server")
	}
	if pbs.Version != "2.4.1" {
		t.Fatalf("expected pbs version 2.4.1, got %q", pbs.Version)
	}

	mu.Lock()
	callbackCount := len(callbacks)
	mu.Unlock()
	if callbackCount < 2 {
		t.Fatalf("expected callbacks for both servers, got %d", callbackCount)
	}
}
// TestDiscoverServersCancelledContext starts a scan with an already-cancelled
// context. DiscoverServersWithCallback must return a non-nil error together
// with a non-nil, empty result.
func TestDiscoverServersCancelledContext(t *testing.T) {
	t.Parallel()

	s := NewScanner()

	// Cancel before scanning begins.
	ctx, stop := context.WithCancel(context.Background())
	stop()

	res, err := s.DiscoverServersWithCallback(ctx, "127.0.0.1/32", nil)
	switch {
	case err == nil:
		t.Fatalf("expected context error, got nil")
	case res == nil:
		t.Fatalf("expected result object even on cancellation")
	case len(res.Servers) != 0:
		t.Fatalf("expected no servers on cancelled context")
	}
}