Pin discovery HTTP probes to captured TLS peers

This commit is contained in:
rcourtman 2026-03-29 19:32:56 +01:00
parent ba1f9ac9f3
commit 74b78ebd2f
3 changed files with 112 additions and 16 deletions

View file

@ -2,7 +2,10 @@ package discovery
import (
"context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@ -142,7 +145,6 @@ func NewScannerWithProfile(profile *envdetect.EnvironmentProfile) *Scanner {
clonedProfile.Policy = policy
transport := &http.Transport{
TLSClientConfig: tlsutil.PeerCertificateCaptureTLSConfig(),
MaxIdleConns: 100,
MaxConnsPerHost: max(policy.MaxConcurrent, 10),
}
@ -159,6 +161,28 @@ func NewScannerWithProfile(profile *envdetect.EnvironmentProfile) *Scanner {
}
}
func fingerprintFromCertificates(certs []*x509.Certificate) string {
if len(certs) == 0 || certs[0] == nil || len(certs[0].Raw) == 0 {
return ""
}
fingerprint := sha256.Sum256(certs[0].Raw)
return hex.EncodeToString(fingerprint[:])
}
func (s *Scanner) httpClientForTLSState(state *tls.ConnectionState) *http.Client {
if state == nil {
return s.httpClient
}
fingerprint := fingerprintFromCertificates(state.PeerCertificates)
if fingerprint == "" {
return s.httpClient
}
return tlsutil.CreateHTTPClientWithTimeout(true, fingerprint, s.policy.HTTPTimeout)
}
// ServerCallback is called when a server is discovered
type ServerCallback func(server DiscoveredServer, phase string)
@ -927,7 +951,9 @@ func (s *Scanner) probeProxmoxService(ctx context.Context, ip string, port int)
return result
}
versionFinding, version, release := s.probeVersionEndpoint(ctx, address)
httpClient := s.httpClientForTLSState(tlsState)
versionFinding, version, release := s.probeVersionEndpoint(ctx, httpClient, address)
result.recordEndpoint(versionFinding)
result.VersionStatus = versionFinding.Status
result.VersionError = versionFinding.Error
@ -935,7 +961,7 @@ func (s *Scanner) probeProxmoxService(ctx context.Context, ip string, port int)
result.Release = release
result.Headers = cloneHeader(versionFinding.Headers)
s.applyProductMatchers(ctx, address, result)
s.applyProductMatchers(ctx, httpClient, address, result)
if strings.TrimSpace(result.Version) == "" {
result.Version = "Unknown"
@ -955,11 +981,11 @@ func (s *Scanner) probeProxmoxService(ctx context.Context, ip string, port int)
return result
}
func (s *Scanner) applyProductMatchers(ctx context.Context, address string, result *ProxmoxProbeResult) {
func (s *Scanner) applyProductMatchers(ctx context.Context, httpClient *http.Client, address string, result *ProxmoxProbeResult) {
applySharedHeuristics(result)
applyPVEHeuristics(result)
s.applyPMGHeuristics(ctx, address, result)
s.applyPBSHeuristics(ctx, address, result)
s.applyPMGHeuristics(ctx, httpClient, address, result)
s.applyPBSHeuristics(ctx, httpClient, address, result)
}
func applySharedHeuristics(result *ProxmoxProbeResult) {
@ -1042,7 +1068,7 @@ func applyPVEHeuristics(result *ProxmoxProbeResult) {
}
}
func (s *Scanner) applyPMGHeuristics(ctx context.Context, address string, result *ProxmoxProbeResult) {
func (s *Scanner) applyPMGHeuristics(ctx context.Context, httpClient *http.Client, address string, result *ProxmoxProbeResult) {
versionFinding, _ := result.endpointFinding("api2/json/version")
hasPMGSignal := false
@ -1081,7 +1107,7 @@ func (s *Scanner) applyPMGHeuristics(ctx context.Context, address string, result
for _, endpoint := range pmgEndpoints {
if _, ok := result.EndpointFindings[endpoint.Path]; !ok {
finding := s.probeAPIEndpoint(ctx, address, endpoint.Path)
finding := s.probeAPIEndpoint(ctx, httpClient, address, endpoint.Path)
result.recordEndpoint(finding)
}
@ -1101,7 +1127,7 @@ func (s *Scanner) applyPMGHeuristics(ctx context.Context, address string, result
}
}
func (s *Scanner) applyPBSHeuristics(ctx context.Context, address string, result *ProxmoxProbeResult) {
func (s *Scanner) applyPBSHeuristics(ctx context.Context, httpClient *http.Client, address string, result *ProxmoxProbeResult) {
if result.Port != 8007 {
return
}
@ -1156,7 +1182,7 @@ func (s *Scanner) applyPBSHeuristics(ctx context.Context, address string, result
for _, endpoint := range pbsEndpoints {
if _, ok := result.EndpointFindings[endpoint.Path]; !ok {
finding := s.probeAPIEndpoint(ctx, address, endpoint.Path)
finding := s.probeAPIEndpoint(ctx, httpClient, address, endpoint.Path)
result.recordEndpoint(finding)
}
@ -1193,12 +1219,19 @@ func defaultProductsForPort(port int) []string {
func (s *Scanner) fetchNodeHostname(ctx context.Context, ip string, port int) string {
address := net.JoinHostPort(ip, strconv.Itoa(port))
tlsState, reachable, _ := s.performTLSProbe(ctx, address)
if !reachable {
return ""
}
httpClient := s.httpClientForTLSState(tlsState)
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("https://%s/api2/json/nodes", address), nil)
if err != nil {
return ""
}
resp, err := s.httpClient.Do(req)
resp, err := httpClient.Do(req)
if err != nil {
return ""
}
@ -1247,7 +1280,7 @@ func (s *Scanner) performTLSProbe(ctx context.Context, address string) (*tls.Con
return nil, true, err
}
func (s *Scanner) probeVersionEndpoint(ctx context.Context, address string) (EndpointProbeFinding, string, string) {
func (s *Scanner) probeVersionEndpoint(ctx context.Context, httpClient *http.Client, address string) (EndpointProbeFinding, string, string) {
const endpoint = "api2/json/version"
finding := EndpointProbeFinding{Endpoint: endpoint}
@ -1257,7 +1290,7 @@ func (s *Scanner) probeVersionEndpoint(ctx context.Context, address string) (End
return finding, "", ""
}
resp, err := s.httpClient.Do(req)
resp, err := httpClient.Do(req)
if err != nil {
finding.Error = err
return finding, "", ""
@ -1292,7 +1325,7 @@ func (s *Scanner) probeVersionEndpoint(ctx context.Context, address string) (End
return finding, payload.Data.Version, payload.Data.Release
}
func (s *Scanner) probeAPIEndpoint(ctx context.Context, address, endpoint string) EndpointProbeFinding {
func (s *Scanner) probeAPIEndpoint(ctx context.Context, httpClient *http.Client, address, endpoint string) EndpointProbeFinding {
finding := EndpointProbeFinding{Endpoint: endpoint}
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("https://%s/%s", address, endpoint), nil)
@ -1301,7 +1334,7 @@ func (s *Scanner) probeAPIEndpoint(ctx context.Context, address, endpoint string
return finding
}
resp, err := s.httpClient.Do(req)
resp, err := httpClient.Do(req)
if err != nil {
finding.Error = err
return finding

View file

@ -309,6 +309,69 @@ func TestNewScannerWithProfileAcceptsSelfSignedProxmoxProbe(t *testing.T) {
}
}
func TestNewScannerWithProfile_UsesSecureSharedHTTPClient(t *testing.T) {
t.Parallel()
profile := &envdetect.EnvironmentProfile{
Policy: envdetect.DefaultScanPolicy(),
Metadata: map[string]string{},
}
scanner := NewScannerWithProfile(profile)
transport, ok := scanner.httpClient.Transport.(*http.Transport)
if !ok {
t.Fatalf("expected *http.Transport, got %T", scanner.httpClient.Transport)
}
if transport.TLSClientConfig != nil {
t.Fatal("expected shared scanner client to use default secure TLS verification")
}
}
func TestScannerHTTPClientForTLSState_UsesFingerprintPinning(t *testing.T) {
t.Parallel()
ts := httptest.NewTLSServer(http.NotFoundHandler())
defer ts.Close()
leaf, err := x509.ParseCertificate(ts.TLS.Certificates[0].Certificate[0])
if err != nil {
t.Fatalf("ParseCertificate: %v", err)
}
scanner := NewScannerWithProfile(&envdetect.EnvironmentProfile{
Policy: envdetect.DefaultScanPolicy(),
Metadata: map[string]string{},
})
if got := scanner.httpClientForTLSState(nil); got != scanner.httpClient {
t.Fatal("expected nil TLS state to reuse shared scanner client")
}
client := scanner.httpClientForTLSState(&tls.ConnectionState{
PeerCertificates: []*x509.Certificate{leaf},
})
if client == scanner.httpClient {
t.Fatal("expected pinned TLS state to allocate a dedicated client")
}
transport, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("expected *http.Transport, got %T", client.Transport)
}
if transport.TLSClientConfig == nil || transport.TLSClientConfig.VerifyPeerCertificate == nil {
t.Fatal("expected pinned client to install fingerprint verification")
}
if err := transport.TLSClientConfig.VerifyPeerCertificate([][]byte{leaf.Raw}, nil); err != nil {
t.Fatalf("expected fingerprint verifier to accept captured leaf certificate: %v", err)
}
mismatchedRaw := append([]byte(nil), leaf.Raw...)
mismatchedRaw[len(mismatchedRaw)-1] ^= 0xff
if err := transport.TLSClientConfig.VerifyPeerCertificate([][]byte{mismatchedRaw}, nil); err == nil {
t.Fatal("expected fingerprint verifier to reject a different certificate")
}
}
func startTLSServerOn(t *testing.T, addr string, handler http.Handler) *httptest.Server {
t.Helper()

View file

@ -13,5 +13,5 @@ func (s *Scanner) ProbeProxmoxService(ctx context.Context, ip string, port int)
}
func (s *Scanner) ProbeAPIEndpoint(ctx context.Context, address, endpoint string) EndpointProbeFinding {
return s.probeAPIEndpoint(ctx, address, endpoint)
return s.probeAPIEndpoint(ctx, s.httpClient, address, endpoint)
}