Pulse/cmd/pulse-sensor-proxy/main.go
rcourtman 524f42cc28 security: complete Phase 1 sensor proxy hardening
Implements comprehensive security hardening for pulse-sensor-proxy:
- Privilege drop from root to unprivileged user (UID 995)
- Hash-chained tamper-evident audit logging with remote forwarding
- Per-UID rate limiting (0.2 QPS, burst 2) with concurrency caps
- Enhanced command validation with 10+ attack pattern tests
- Fuzz testing (7M+ executions, 0 crashes)
- SSH hardening, AppArmor/seccomp profiles, operational runbooks

All 27 Phase 1 tasks complete. Ready for production deployment.
2025-10-20 15:13:37 +00:00

1087 lines
28 KiB
Go

package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"os/signal"
"os/user"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/ssh/knownhosts"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"golang.org/x/sys/unix"
)
// Version information (set at build time with -ldflags).
// The literals below are the values used when the build does not override them.
var (
// Version is exposed as rootCmd.Version and printed by the "version" subcommand.
Version = "dev"
// BuildTime is printed by the "version" subcommand only when it is not "unknown".
BuildTime = "unknown"
// GitCommit is printed by the "version" subcommand only when it is not "unknown".
GitCommit = "unknown"
)
// Built-in defaults. runProxy falls back to these when the corresponding
// environment variable (or the --config flag) is not set.
const (
// defaultSocketPath is used when PULSE_SENSOR_PROXY_SOCKET is empty.
defaultSocketPath = "/run/pulse-sensor-proxy/pulse-sensor-proxy.sock"
// defaultSSHKeyPath is used when PULSE_SENSOR_PROXY_SSH_DIR is empty.
defaultSSHKeyPath = "/var/lib/pulse-sensor-proxy/ssh"
// defaultConfigPath is used when neither --config nor PULSE_SENSOR_PROXY_CONFIG is set.
defaultConfigPath = "/etc/pulse-sensor-proxy/config.yaml"
// defaultAuditLogPath is used when PULSE_SENSOR_PROXY_AUDIT_LOG is empty.
defaultAuditLogPath = "/var/log/pulse/sensor-proxy/audit.log"
// maxRequestBytes caps how many bytes handleConnection reads for one request.
maxRequestBytes = 16 * 1024 // 16 KiB max request size
// defaultRunAsUser is the account privileges are dropped to when PULSE_SENSOR_PROXY_USER is empty.
defaultRunAsUser = "pulse-sensor"
)
// defaultWorkDir reports the directory recorded as the proxy's working
// directory when the process's current directory cannot be determined.
func defaultWorkDir() string {
	const fallbackDir = "/var/lib/pulse-sensor-proxy"
	return fallbackDir
}
var (
// configPath holds the value of the --config persistent flag (bound in init).
// An empty value makes runProxy fall back to PULSE_SENSOR_PROXY_CONFIG and
// then to defaultConfigPath.
configPath string
)
// rootCmd is the top-level command; invoking the binary without a subcommand
// runs the proxy via runProxy.
var rootCmd = &cobra.Command{
Use: "pulse-sensor-proxy",
Short: "Pulse Sensor Proxy - Secure sensor data bridge for containerized Pulse",
Long: `Sensor monitoring proxy that keeps SSH keys on the host and exposes sensor data via unix socket`,
Version: Version,
// Positional arguments are ignored; runProxy reads its settings from the
// --config flag, the environment and built-in defaults.
Run: func(cmd *cobra.Command, args []string) {
runProxy()
},
}
// versionCmd implements the "version" subcommand. It always prints the
// version string and adds the build time and commit lines only when those
// values have been set to something other than "unknown".
var versionCmd = &cobra.Command{
Use: "version",
Short: "Print version information",
Run: func(cmd *cobra.Command, args []string) {
fmt.Printf("pulse-sensor-proxy %s\n", Version)
if BuildTime != "unknown" {
fmt.Printf("Built: %s\n", BuildTime)
}
if GitCommit != "unknown" {
fmt.Printf("Commit: %s\n", GitCommit)
}
},
}
// init wires up the CLI: it binds the --config persistent flag to configPath
// and attaches the version subcommand to the root command.
func init() {
	flags := rootCmd.PersistentFlags()
	flags.StringVar(&configPath, "config", "", "Path to configuration file (default: /etc/pulse-sensor-proxy/config.yaml)")

	rootCmd.AddCommand(versionCmd)
}
// main executes the root command. When execution fails, the error is written
// to stderr and the process exits with status 1.
func main() {
	err := rootCmd.Execute()
	if err == nil {
		return
	}
	fmt.Fprintf(os.Stderr, "Error: %v\n", err)
	os.Exit(1)
}
// userSpec describes the account the proxy drops privileges to, as resolved
// by resolveUserSpec or lookupUserFromPasswd.
type userSpec struct {
// name is the account's login name.
name string
// uid is the numeric user ID passed to setuid.
uid int
// gid is the primary group ID passed to setgid.
gid int
// groups is the group list passed to setgroups; it falls back to just gid
// when no group IDs could be resolved.
groups []int
// home is the account's home directory, exported as HOME after the drop.
home string
}
// dropPrivileges switches the process to the given user.
//
// It is a no-op (returning nil, nil) when username is empty or the effective
// UID is not 0. Otherwise it resolves the user, applies setgroups, setgid and
// setuid in that order, and updates HOME, USER and LOGNAME to match the new
// identity. The resolved spec is returned on success; any failed step is
// returned as a wrapped error naming that step.
func dropPrivileges(username string) (*userSpec, error) {
	if username == "" || os.Geteuid() != 0 {
		return nil, nil
	}

	spec, err := resolveUserSpec(username)
	if err != nil {
		return nil, err
	}
	if len(spec.groups) == 0 {
		spec.groups = []int{spec.gid}
	}

	// The steps run strictly in this order: group list, then GID, then UID.
	steps := []struct {
		label string
		apply func() error
	}{
		{"setgroups", func() error { return unix.Setgroups(spec.groups) }},
		{"setgid", func() error { return unix.Setgid(spec.gid) }},
		{"setuid", func() error { return unix.Setuid(spec.uid) }},
	}
	for _, step := range steps {
		if stepErr := step.apply(); stepErr != nil {
			return nil, fmt.Errorf("%s: %w", step.label, stepErr)
		}
	}

	// Environment updates are best-effort; failures are deliberately ignored.
	if spec.home != "" {
		_ = os.Setenv("HOME", spec.home)
	}
	if spec.name != "" {
		for _, key := range []string{"USER", "LOGNAME"} {
			_ = os.Setenv(key, spec.name)
		}
	}
	return spec, nil
}
// resolveUserSpec looks up username and returns its IDs, group list and home
// directory.
//
// The primary source is os/user.Lookup. When that lookup fails, /etc/passwd is
// parsed directly via lookupUserFromPasswd; if both fail, the returned error
// reports the lookup failure and wraps the fallback failure. A UID or GID that
// cannot be parsed as an integer is returned as an error without trying the
// fallback.
func resolveUserSpec(username string) (*userSpec, error) {
	u, lookupErr := user.Lookup(username)
	if lookupErr != nil {
		spec, fallbackErr := lookupUserFromPasswd(username)
		if fallbackErr != nil {
			return nil, fmt.Errorf("lookup user %q failed: %v (fallback: %w)", username, lookupErr, fallbackErr)
		}
		return spec, nil
	}

	uid, err := strconv.Atoi(u.Uid)
	if err != nil {
		return nil, fmt.Errorf("parse uid %q: %w", u.Uid, err)
	}
	gid, err := strconv.Atoi(u.Gid)
	if err != nil {
		return nil, fmt.Errorf("parse gid %q: %w", u.Gid, err)
	}

	// Group resolution is best-effort: lookup failures and non-numeric IDs are
	// skipped, and the primary GID is used when nothing usable remains.
	var groups []int
	if ids, groupErr := u.GroupIds(); groupErr == nil {
		for _, id := range ids {
			parsed, convErr := strconv.Atoi(id)
			if convErr != nil {
				continue
			}
			groups = append(groups, parsed)
		}
	}
	if len(groups) == 0 {
		groups = []int{gid}
	}

	spec := &userSpec{
		name:   u.Username,
		uid:    uid,
		gid:    gid,
		groups: groups,
		home:   u.HomeDir,
	}
	return spec, nil
}
// lookupUserFromPasswd finds username by scanning /etc/passwd line by line.
//
// Lines starting with "#", lines with fewer than seven colon-separated fields,
// and lines for other users are skipped. The first matching entry wins; its
// group list contains only the primary GID. An error is returned when the file
// cannot be opened or scanned, when the matching entry has a non-numeric UID
// or GID, or when no entry matches.
func lookupUserFromPasswd(username string) (*userSpec, error) {
	passwd, err := os.Open("/etc/passwd")
	if err != nil {
		return nil, fmt.Errorf("open /etc/passwd: %w", err)
	}
	defer passwd.Close()

	lines := bufio.NewScanner(passwd)
	for lines.Scan() {
		entry := lines.Text()
		if strings.HasPrefix(entry, "#") {
			continue
		}
		cols := strings.Split(entry, ":")
		if len(cols) < 7 || cols[0] != username {
			continue
		}

		uid, uidErr := strconv.Atoi(cols[2])
		if uidErr != nil {
			return nil, fmt.Errorf("parse uid %q: %w", cols[2], uidErr)
		}
		gid, gidErr := strconv.Atoi(cols[3])
		if gidErr != nil {
			return nil, fmt.Errorf("parse gid %q: %w", cols[3], gidErr)
		}

		found := &userSpec{
			name:   cols[0],
			uid:    uid,
			gid:    gid,
			groups: []int{gid},
			home:   cols[5],
		}
		return found, nil
	}

	if scanErr := lines.Err(); scanErr != nil {
		return nil, fmt.Errorf("scan /etc/passwd: %w", scanErr)
	}
	return nil, fmt.Errorf("user %q not found in /etc/passwd", username)
}
// Proxy manages the temperature monitoring proxy.
// It owns the unix socket listener and the state shared by all RPC handlers.
type Proxy struct {
// socketPath is the filesystem path of the unix socket served by Start.
socketPath string
// sshKeyPath is the directory holding the proxy's id_ed25519 keypair.
sshKeyPath string
// workDir is the process working directory, or defaultWorkDir() if it could not be determined.
workDir string
// knownHosts is the known_hosts manager created in runProxy.
knownHosts knownhosts.Manager
// listener is set by Start and closed by Stop.
listener net.Listener
// rateLimiter gates each connection per peer (see handleConnection).
rateLimiter *rateLimiter
// nodeGate serialises temperature fetches per node name.
nodeGate *nodeGate
// router maps RPC method names to their handlers.
router map[string]handlerFunc
// config is the configuration loaded at startup.
config *Config
// metrics records request counts, latency and queue depth.
metrics *ProxyMetrics
// audit is the audit logger; handleConnection treats nil as "auditing disabled".
audit *auditLogger
// The fields below are not assigned in this part of the file; presumably
// initAuthRules populates them for peer authorization — TODO confirm.
allowedPeerUIDs map[uint32]struct{}
allowedPeerGIDs map[uint32]struct{}
idMappedUIDRanges []idRange
idMappedGIDRanges []idRange
}
// RPC request types.
// These are the method names accepted in RPCRequest.Method and used as keys
// of Proxy.router.
const (
RPCEnsureClusterKeys = "ensure_cluster_keys"
RPCRegisterNodes = "register_nodes"
RPCGetTemperature = "get_temperature"
RPCGetStatus = "get_status"
RPCRequestCleanup = "request_cleanup"
)
// Privileged RPC methods that require host-level access (not accessible from containers).
// handleConnection rejects these methods when isIDMappedRoot reports true for
// the calling peer.
var privilegedMethods = map[string]bool{
RPCEnsureClusterKeys: true, // SSH key distribution
RPCRegisterNodes: true, // Node registration
RPCRequestCleanup: true, // Cleanup operations
}
// RPCRequest represents a request from Pulse.
// handleConnection decodes it from a single newline-terminated JSON line.
type RPCRequest struct {
// CorrelationID is optional; it is sanitized and echoed back in the response.
CorrelationID string `json:"correlation_id,omitempty"`
// Method selects the handler from Proxy.router.
Method string `json:"method"`
// Params carries method-specific arguments (for example "node" or "key_dir").
Params map[string]interface{} `json:"params"`
}
// RPCResponse represents a response to Pulse.
// It is written to the connection as one JSON document.
type RPCResponse struct {
// CorrelationID echoes the sanitized request correlation ID, if any.
CorrelationID string `json:"correlation_id,omitempty"`
// Success is true only when the handler returned without error.
Success bool `json:"success"`
// Data holds the handler result on success.
Data interface{} `json:"data,omitempty"`
// Error holds a human-readable failure message.
Error string `json:"error,omitempty"`
}
// handlerFunc is the signature for RPC method handlers.
// A handler receives the per-connection context, the decoded request and a
// logger already tagged with request fields; it returns the value to place in
// RPCResponse.Data, or an error whose message becomes RPCResponse.Error.
type handlerFunc func(ctx context.Context, req *RPCRequest, logger zerolog.Logger) (interface{}, error)
// runProxy is the daemon entry point invoked by rootCmd.
//
// It resolves settings (environment variables, --config flag, built-in
// defaults), loads the configuration, drops privileges, opens the audit log,
// builds the Proxy with its RPC router, starts the socket listener and the
// metrics server, and then blocks until SIGINT or SIGTERM arrives, at which
// point it shuts everything down. Every startup failure is reported through
// log.Fatal.
func runProxy() {
// Initialize logger
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
socketPath := os.Getenv("PULSE_SENSOR_PROXY_SOCKET")
if socketPath == "" {
socketPath = defaultSocketPath
}
sshKeyPath := os.Getenv("PULSE_SENSOR_PROXY_SSH_DIR")
if sshKeyPath == "" {
sshKeyPath = defaultSSHKeyPath
}
// Load configuration
// Priority: --config flag > PULSE_SENSOR_PROXY_CONFIG env > default path
cfgPath := configPath // from flag
if cfgPath == "" {
cfgPath = os.Getenv("PULSE_SENSOR_PROXY_CONFIG")
}
if cfgPath == "" {
cfgPath = defaultConfigPath
}
// The configuration is read before privileges are dropped.
cfg, err := loadConfig(cfgPath)
if err != nil {
log.Fatal().Err(err).Msg("Failed to load configuration")
}
runAsUser := os.Getenv("PULSE_SENSOR_PROXY_USER")
if runAsUser == "" {
runAsUser = defaultRunAsUser
}
// dropPrivileges returns a nil spec when no switch happened (not running as
// root), in which case nothing is logged here.
if spec, err := dropPrivileges(runAsUser); err != nil {
log.Fatal().Err(err).Str("user", runAsUser).Msg("Failed to drop privileges")
} else if spec != nil {
log.Info().
Str("user", spec.name).
Int("uid", spec.uid).
Int("gid", spec.gid).
Msg("Running as unprivileged user")
}
// Everything from here on (audit log, known_hosts, key directory, socket)
// is created with the post-drop identity.
auditPath := os.Getenv("PULSE_SENSOR_PROXY_AUDIT_LOG")
if auditPath == "" {
auditPath = defaultAuditLogPath
}
auditLogger, err := newAuditLogger(auditPath)
if err != nil {
log.Fatal().Err(err).Str("path", auditPath).Msg("Failed to initialize audit logger")
}
// NOTE(review): if log.Fatal terminates the process via os.Exit, this
// deferred Close is skipped on the later fatal paths — confirm that is acceptable.
defer auditLogger.Close()
// Initialize metrics
metrics := NewProxyMetrics(Version)
log.Info().
Str("socket", socketPath).
Str("ssh_key_dir", sshKeyPath).
Str("config_path", cfgPath).
Str("audit_log", auditPath).
Str("version", Version).
Msg("Starting pulse-sensor-proxy")
knownHostsManager, err := knownhosts.NewManager(filepath.Join(sshKeyPath, "known_hosts"))
if err != nil {
log.Fatal().Err(err).Msg("Failed to initialize known hosts manager")
}
proxy := &Proxy{
socketPath: socketPath,
sshKeyPath: sshKeyPath,
knownHosts: knownHostsManager,
rateLimiter: newRateLimiter(metrics),
nodeGate: newNodeGate(),
config: cfg,
metrics: metrics,
audit: auditLogger,
}
if wd, err := os.Getwd(); err == nil {
proxy.workDir = wd
} else {
log.Warn().Err(err).Msg("Failed to determine working directory; using default")
proxy.workDir = defaultWorkDir()
}
// Register RPC method handlers
proxy.router = map[string]handlerFunc{
RPCGetStatus: proxy.handleGetStatusV2,
RPCEnsureClusterKeys: proxy.handleEnsureClusterKeysV2,
RPCRegisterNodes: proxy.handleRegisterNodesV2,
RPCGetTemperature: proxy.handleGetTemperatureV2,
RPCRequestCleanup: proxy.handleRequestCleanup,
}
if err := proxy.initAuthRules(); err != nil {
log.Fatal().Err(err).Msg("Failed to initialize authentication rules")
}
// Start returns once the listener is up; connections are served in a
// background goroutine.
if err := proxy.Start(); err != nil {
log.Fatal().Err(err).Msg("Failed to start proxy")
}
// Start metrics server
if err := metrics.Start(cfg.MetricsAddress); err != nil {
log.Fatal().Err(err).Msg("Failed to start metrics server")
}
// Setup signal handlers
// The channel is buffered so a signal delivered before the receive below is not lost.
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
<-sigChan
log.Info().Msg("Shutting down proxy...")
proxy.Stop()
proxy.rateLimiter.shutdown()
// NOTE(review): Shutdown is given a context without a deadline — confirm
// it cannot block indefinitely.
metrics.Shutdown(context.Background())
log.Info().Msg("Proxy stopped")
}
// Start prepares the on-disk state and begins serving the unix socket.
//
// In order, it creates the SSH key directory (mode 0700), makes sure the
// keypair exists, removes any stale socket, creates the socket directory
// (mode 0755), listens on the socket and relaxes its mode to 0666. Incoming
// connections are then accepted in a background goroutine. A chmod failure is
// only logged; every other failure is returned as a wrapped error.
func (p *Proxy) Start() error {
	if err := os.MkdirAll(p.sshKeyPath, 0700); err != nil {
		return fmt.Errorf("failed to create SSH key directory: %w", err)
	}

	if err := p.ensureSSHKeypair(); err != nil {
		return fmt.Errorf("failed to ensure SSH keypair: %w", err)
	}

	// A socket file left over from a previous run would make Listen fail.
	if err := os.RemoveAll(p.socketPath); err != nil {
		return fmt.Errorf("failed to remove existing socket: %w", err)
	}

	if err := os.MkdirAll(filepath.Dir(p.socketPath), 0755); err != nil {
		return fmt.Errorf("failed to create socket directory: %w", err)
	}

	ln, err := net.Listen("unix", p.socketPath)
	if err != nil {
		return fmt.Errorf("failed to create unix socket: %w", err)
	}
	p.listener = ln

	// Set liberal socket permissions; SO_PEERCRED enforces auth
	if chmodErr := os.Chmod(p.socketPath, 0666); chmodErr != nil {
		log.Warn().Err(chmodErr).Msg("Failed to set socket permissions")
	}
	log.Info().Str("socket", p.socketPath).Msg("Unix socket ready")

	go p.acceptConnections()
	return nil
}
// Stop closes the listener and removes the socket file. It does nothing when
// the proxy was never started.
func (p *Proxy) Stop() {
	if p.listener == nil {
		return
	}
	p.listener.Close()
	os.Remove(p.socketPath)
}
// acceptConnections accepts connections on the proxy listener until the
// listener is closed, handing each connection to handleConnection in its own
// goroutine. Accept errors other than "listener closed" are logged and the
// loop keeps going.
func (p *Proxy) acceptConnections() {
	for {
		conn, err := p.listener.Accept()
		if err != nil {
			// net.ErrClosed is the documented sentinel for a closed listener;
			// matching it with errors.Is avoids depending on the error's
			// message text or concrete type.
			if errors.Is(err, net.ErrClosed) {
				return
			}
			log.Error().Err(err).Msg("Failed to accept connection")
			continue
		}
		go p.handleConnection(conn)
	}
}
// handleConnection processes a single RPC request with full validation and throttling.
//
// The pipeline is: peer credential extraction -> peer authorization -> rate
// limiting -> read one newline-terminated request (bounded by
// maxRequestBytes) -> JSON decode -> handler lookup -> privileged-method check
// -> handler execution -> response. The connection is always closed on
// return. Each rejection is audited (when p.audit is set) and answered with an
// error response. Rejections that occur after rate limiting has admitted the
// peer also call applyPenalty.
func (p *Proxy) handleConnection(conn net.Conn) {
defer conn.Close()
remoteAddr := conn.RemoteAddr().String()
// Track concurrent requests
p.metrics.queueDepth.Inc()
defer p.metrics.queueDepth.Dec()
// Start timing for latency metrics
startTime := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Skip read deadline - it interferes with write operations on unix sockets
// Context timeout provides sufficient protection against hung connections
// NOTE(review): within this function ctx is only passed to the handler; the
// ReadBytes call below is not tied to it, so a peer that connects and never
// sends a newline appears able to hold its limiter slot indefinitely — confirm.
// Extract and verify peer credentials
cred, err := extractPeerCredentials(conn)
if err != nil {
log.Warn().Err(err).Msg("Peer credentials unavailable")
if p.audit != nil {
p.audit.LogConnectionDenied("", nil, remoteAddr, "peer_credentials_unavailable")
}
p.sendErrorV2(conn, "unauthorized", "")
return
}
if err := p.authorizePeer(cred); err != nil {
log.Warn().
Err(err).
Uint32("uid", cred.uid).
Uint32("gid", cred.gid).
Msg("Peer authorization failed")
if p.audit != nil {
p.audit.LogConnectionDenied("", cred, remoteAddr, err.Error())
}
p.sendErrorV2(conn, "unauthorized", "")
return
}
if p.audit != nil {
p.audit.LogConnectionAccepted("", cred, remoteAddr)
}
// Check rate limit and concurrency
// Limiting is keyed on the peer UID only.
peer := peerID{uid: cred.uid}
releaseLimiter, limitReason, allowed := p.rateLimiter.allow(peer)
if !allowed {
log.Warn().
Uint32("uid", cred.uid).
Uint32("pid", cred.pid).
Str("reason", limitReason).
Msg("Rate limit exceeded")
if p.audit != nil {
p.audit.LogRateLimitHit("", cred, remoteAddr, limitReason)
}
p.sendErrorV2(conn, "rate limit exceeded", "")
return
}
// releaseFn is called at most once: either by the deferred function on
// return, or earlier by applyPenalty, which then clears it.
releaseFn := releaseLimiter
defer func() {
if releaseFn != nil {
releaseFn()
}
}()
// applyPenalty releases the limiter slot (if still held) and then records a
// penalty for this peer with the given reason.
applyPenalty := func(reason string) {
if releaseFn != nil {
releaseFn()
releaseFn = nil
}
p.rateLimiter.penalize(peer, reason)
}
// Read request using newline-delimited framing
limited := &io.LimitedReader{R: conn, N: maxRequestBytes}
reader := bufio.NewReader(limited)
line, err := reader.ReadBytes('\n')
if err != nil {
// NOTE(review): the limited.N <= 0 test looks like the condition that
// actually detects an oversized request (limit exhausted before a newline);
// whether ReadBytes can ever surface bufio.ErrBufferFull should be confirmed.
if errors.Is(err, bufio.ErrBufferFull) || limited.N <= 0 {
if p.audit != nil {
p.audit.LogValidationFailure("", cred, remoteAddr, "", nil, "payload_too_large")
}
p.sendErrorV2(conn, "payload too large", "")
applyPenalty("payload_too_large")
return
}
// EOF before any newline is reported as an empty request.
if errors.Is(err, io.EOF) {
if p.audit != nil {
p.audit.LogValidationFailure("", cred, remoteAddr, "", nil, "empty_request")
}
p.sendErrorV2(conn, "empty request", "")
applyPenalty("empty_request")
return
}
if p.audit != nil {
p.audit.LogValidationFailure("", cred, remoteAddr, "", nil, "read_error")
}
p.sendErrorV2(conn, "failed to read request", "")
applyPenalty("read_error")
return
}
// Trim whitespace and validate
line = bytes.TrimSpace(line)
if len(line) == 0 {
if p.audit != nil {
p.audit.LogValidationFailure("", cred, remoteAddr, "", nil, "empty_request")
}
p.sendErrorV2(conn, "empty request", "")
applyPenalty("empty_request")
return
}
// Parse JSON
var req RPCRequest
if err := json.Unmarshal(line, &req); err != nil {
if p.audit != nil {
p.audit.LogValidationFailure("", cred, remoteAddr, "", nil, "invalid_json")
}
p.sendErrorV2(conn, "invalid request format", "")
applyPenalty("invalid_json")
return
}
// Sanitize correlation ID
req.CorrelationID = sanitizeCorrelationID(req.CorrelationID)
// Create contextual logger
logger := log.With().
Str("corr_id", req.CorrelationID).
Uint32("uid", cred.uid).
Uint32("pid", cred.pid).
Str("method", req.Method).
Logger()
// Prepare response
resp := RPCResponse{
CorrelationID: req.CorrelationID,
Success: false,
}
// Find handler
handler := p.router[req.Method]
if handler == nil {
if p.audit != nil {
p.audit.LogValidationFailure(req.CorrelationID, cred, remoteAddr, req.Method, nil, "unknown_method")
}
resp.Error = "unknown method"
logger.Warn().Msg("Unknown method")
p.sendResponse(conn, resp)
applyPenalty("unknown_method")
return
}
// Check if method requires host-level privileges
if privilegedMethods[req.Method] {
// Privileged methods can only be called from host (not from containers)
if p.isIDMappedRoot(cred) {
resp.Error = "method requires host-level privileges"
log.Warn().
Str("method", req.Method).
Uint32("uid", cred.uid).
Uint32("gid", cred.gid).
Uint32("pid", cred.pid).
Str("corr_id", req.CorrelationID).
Msg("SECURITY: Container attempted to call privileged method - access denied")
if p.audit != nil {
p.audit.LogValidationFailure(req.CorrelationID, cred, remoteAddr, req.Method, nil, "privileged_method_denied")
}
p.sendResponse(conn, resp)
p.metrics.rpcRequests.WithLabelValues(req.Method, "unauthorized").Inc()
applyPenalty("privileged_method_denied")
return
}
}
if p.audit != nil {
p.audit.LogCommandStart(req.CorrelationID, cred, remoteAddr, "", req.Method, nil)
}
// Execute handler
result, err := handler(ctx, &req, logger)
// duration is measured from connection start, so it includes the
// credential checks and request parsing above; it is what the audit log receives.
duration := time.Since(startTime)
if err != nil {
if p.audit != nil {
p.audit.LogCommandResult(req.CorrelationID, cred, remoteAddr, "", req.Method, nil, 1, duration, "", "", err)
}
// The handler's error text is returned to the client verbatim.
resp.Error = err.Error()
logger.Warn().Err(err).Msg("Handler failed")
// Clear read deadline and set write deadline for error response
// (sendResponse performs the same two deadline calls again.)
conn.SetReadDeadline(time.Time{})
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
p.sendResponse(conn, resp)
// Record failed request
p.metrics.rpcRequests.WithLabelValues(req.Method, "error").Inc()
p.metrics.rpcLatency.WithLabelValues(req.Method).Observe(time.Since(startTime).Seconds())
return
}
// Success
resp.Success = true
resp.Data = result
if p.audit != nil {
p.audit.LogCommandResult(req.CorrelationID, cred, remoteAddr, "", req.Method, nil, 0, duration, "", "", nil)
}
logger.Info().Msg("Request completed")
// Clear read deadline and set write deadline for response
conn.SetReadDeadline(time.Time{})
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
p.sendResponse(conn, resp)
// Record successful request
// The latency metric is observed after the response is written, so it also
// covers the write.
p.metrics.rpcRequests.WithLabelValues(req.Method, "success").Inc()
p.metrics.rpcLatency.WithLabelValues(req.Method).Observe(time.Since(startTime).Seconds())
}
// sendError sends an error response (legacy function).
//
// It writes a failed RPCResponse carrying message to conn as JSON. Unlike
// sendErrorV2 it sets no correlation ID and touches no deadlines. A failure to
// encode or write the response is logged instead of being silently dropped.
func (p *Proxy) sendError(conn net.Conn, message string) {
	resp := RPCResponse{
		Success: false,
		Error:   message,
	}
	if err := json.NewEncoder(conn).Encode(resp); err != nil {
		log.Error().Err(err).Msg("Failed to encode RPC response")
	}
}
// sendErrorV2 sends an error response with correlation ID.
//
// It builds a failed RPCResponse and delegates the write to sendResponse, so
// error replies follow the same path as every other reply: the read deadline
// is cleared, a 10-second write deadline is set, and deadline or encoding
// failures are logged rather than ignored.
func (p *Proxy) sendErrorV2(conn net.Conn, message, correlationID string) {
	p.sendResponse(conn, RPCResponse{
		CorrelationID: correlationID,
		Success:       false,
		Error:         message,
	})
}
// sendResponse writes resp to conn as a single JSON document.
//
// Before writing it clears any read deadline and gives the write ten seconds
// to complete. Problems with either deadline are logged as warnings, and an
// encoding or write failure is logged as an error; nothing is returned to the
// caller.
func (p *Proxy) sendResponse(conn net.Conn, resp RPCResponse) {
	var noDeadline time.Time
	if err := conn.SetReadDeadline(noDeadline); err != nil {
		log.Warn().Err(err).Msg("Failed to clear read deadline")
	}

	writeBy := time.Now().Add(10 * time.Second)
	if err := conn.SetWriteDeadline(writeBy); err != nil {
		log.Warn().Err(err).Msg("Failed to set write deadline")
	}

	if err := json.NewEncoder(conn).Encode(resp); err != nil {
		log.Error().Err(err).Msg("Failed to encode RPC response")
	}
}
// handleGetStatus is the legacy status handler. It reports the proxy version,
// the contents of the id_ed25519.pub file in the SSH key directory, and that
// directory's path. If the public key cannot be read, a failed response
// describing the read error is returned.
func (p *Proxy) handleGetStatus(req RPCRequest) RPCResponse {
	pubKey, err := os.ReadFile(filepath.Join(p.sshKeyPath, "id_ed25519.pub"))
	if err != nil {
		return RPCResponse{
			Success: false,
			Error:   fmt.Sprintf("failed to read public key: %v", err),
		}
	}

	status := map[string]interface{}{
		"version":    Version,
		"public_key": string(pubKey),
		"ssh_dir":    p.sshKeyPath,
	}
	return RPCResponse{Success: true, Data: status}
}
// ensureSSHKeypair generates SSH keypair if it doesn't exist.
//
// The keypair is considered present only when both id_ed25519 and
// id_ed25519.pub exist in the SSH key directory. Otherwise a new ed25519 key
// with an empty passphrase is generated with ssh-keygen. The error returned on
// failure wraps the command error and includes the command output.
func (p *Proxy) ensureSSHKeypair() error {
	privKeyPath := filepath.Join(p.sshKeyPath, "id_ed25519")
	pubKeyPath := filepath.Join(p.sshKeyPath, "id_ed25519.pub")

	// Check if keypair already exists
	if _, err := os.Stat(privKeyPath); err == nil {
		if _, err := os.Stat(pubKeyPath); err == nil {
			log.Info().Msg("SSH keypair already exists")
			return nil
		}
	}

	log.Info().Msg("Generating new SSH keypair")

	// The key directory can come from PULSE_SENSOR_PROXY_SSH_DIR, so the path
	// may contain spaces or shell metacharacters. Single-quote it (escaping any
	// embedded single quote) so it always reaches ssh-keygen as one argument.
	// NOTE(review): this relies on execCommand passing the string to a shell,
	// which the existing -N '' and -C '...' quoting already assumes — confirm.
	quotedKeyPath := "'" + strings.ReplaceAll(privKeyPath, "'", `'\''`) + "'"

	// Generate ed25519 keypair using ssh-keygen
	cmd := fmt.Sprintf("ssh-keygen -t ed25519 -f %s -N '' -C 'pulse-sensor-proxy'", quotedKeyPath)
	if output, err := execCommand(cmd); err != nil {
		return fmt.Errorf("failed to generate SSH keypair: %w (output: %s)", err, output)
	}

	log.Info().Str("path", privKeyPath).Msg("SSH keypair generated")
	return nil
}
// handleEnsureClusterKeys discovers cluster nodes and pushes SSH keys.
//
// This is the legacy handler. It fails when the proxy is not running on a
// Proxmox host or when node discovery fails. Otherwise it attempts to push the
// proxy's SSH key to each discovered node and returns a per-node result map
// together with success and total counts; the response itself is successful
// even if individual nodes fail.
//
// Node names are checked with validateNodeName before any SSH operation, as
// the other handlers in this file do. A node that fails validation is recorded
// as failed and skipped.
func (p *Proxy) handleEnsureClusterKeys(req RPCRequest) RPCResponse {
	// Check if we're on a Proxmox host
	if !isProxmoxHost() {
		return RPCResponse{
			Success: false,
			Error:   "not running on Proxmox host - cannot discover cluster",
		}
	}

	// Discover cluster nodes
	nodes, err := discoverClusterNodes()
	if err != nil {
		return RPCResponse{
			Success: false,
			Error:   fmt.Sprintf("failed to discover cluster: %v", err),
		}
	}
	log.Info().Strs("nodes", nodes).Msg("Discovered cluster nodes")

	// Push SSH key to each node
	results := make(map[string]interface{}, len(nodes))
	successCount := 0
	for _, node := range nodes {
		// Never hand an unvalidated name to the SSH layer.
		if err := validateNodeName(node); err != nil {
			log.Warn().Str("node", node).Msg("Invalid node name format")
			results[node] = map[string]interface{}{
				"success": false,
				"error":   "invalid node name",
			}
			continue
		}

		log.Info().Str("node", node).Msg("Pushing SSH key to node")
		if err := p.pushSSHKey(node); err != nil {
			log.Error().Err(err).Str("node", node).Msg("Failed to push SSH key")
			results[node] = map[string]interface{}{
				"success": false,
				"error":   err.Error(),
			}
			continue
		}

		log.Info().Str("node", node).Msg("SSH key pushed successfully")
		results[node] = map[string]interface{}{
			"success": true,
		}
		successCount++
	}

	return RPCResponse{
		Success: true,
		Data: map[string]interface{}{
			"nodes":         nodes,
			"results":       results,
			"success_count": successCount,
			"total_count":   len(nodes),
		},
	}
}
// handleRegisterNodes is the legacy node-registration handler. It discovers
// the cluster nodes and reports, for each validly named node, whether an SSH
// connection test succeeded. Names are whitespace-trimmed first; names that
// fail validation are logged and left out of the result. The request fails
// when not running on a Proxmox host or when discovery fails.
func (p *Proxy) handleRegisterNodes(req RPCRequest) RPCResponse {
	if !isProxmoxHost() {
		return RPCResponse{
			Success: false,
			Error:   "not running on Proxmox host",
		}
	}

	discovered, err := discoverClusterNodes()
	if err != nil {
		return RPCResponse{
			Success: false,
			Error:   fmt.Sprintf("failed to discover nodes: %v", err),
		}
	}

	nodeStatus := make([]map[string]interface{}, 0, len(discovered))
	for _, raw := range discovered {
		// Validate node name to prevent SSH command injection
		node := strings.TrimSpace(raw)
		if validateNodeName(node) != nil {
			log.Warn().Str("node", node).Msg("Invalid node name format from cluster discovery")
			continue
		}

		entry := map[string]interface{}{"name": node}
		if sshErr := p.testSSHConnection(node); sshErr == nil {
			entry["ssh_ready"] = true
		} else {
			entry["ssh_ready"] = false
			entry["error"] = sshErr.Error()
		}
		nodeStatus = append(nodeStatus, entry)
	}

	payload := map[string]interface{}{"nodes": nodeStatus}
	return RPCResponse{Success: true, Data: payload}
}
// handleGetTemperature is the legacy temperature handler. It expects a string
// "node" parameter, trims and validates it, fetches that node's temperature
// data over SSH, and returns the node name with the data. Each failure mode
// (missing parameter, wrong type, invalid name, fetch error) yields a failed
// response with its own message.
func (p *Proxy) handleGetTemperature(req RPCRequest) RPCResponse {
	reject := func(msg string) RPCResponse {
		return RPCResponse{Success: false, Error: msg}
	}

	raw, present := req.Params["node"]
	if !present {
		return reject("missing 'node' parameter")
	}
	name, isString := raw.(string)
	if !isString {
		return reject("'node' parameter must be a string")
	}

	// Validate node name to prevent SSH command injection
	name = strings.TrimSpace(name)
	if validateNodeName(name) != nil {
		return reject("invalid node name format")
	}

	tempData, err := p.getTemperatureViaSSH(name)
	if err != nil {
		return reject(fmt.Sprintf("failed to get temperatures: %v", err))
	}

	return RPCResponse{
		Success: true,
		Data: map[string]interface{}{
			"node":        name,
			"temperature": tempData,
		},
	}
}
// New V2 handlers with context and structured logging
// handleGetStatusV2 serves the get_status RPC. It returns the proxy version,
// the contents of id_ed25519.pub from the SSH key directory, and that
// directory's path. A failure to read the public key is returned as a wrapped
// error. ctx and req are not consulted.
func (p *Proxy) handleGetStatusV2(ctx context.Context, req *RPCRequest, logger zerolog.Logger) (interface{}, error) {
	pubKey, err := os.ReadFile(filepath.Join(p.sshKeyPath, "id_ed25519.pub"))
	if err != nil {
		return nil, fmt.Errorf("failed to read public key: %w", err)
	}

	logger.Info().Msg("Status request served")

	status := map[string]interface{}{
		"version":    Version,
		"public_key": string(pubKey),
		"ssh_dir":    p.sshKeyPath,
	}
	return status, nil
}
// handleEnsureClusterKeysV2 serves the ensure_cluster_keys RPC.
//
// It requires a Proxmox host, discovers the cluster nodes and pushes the SSH
// key to every node whose name passes validation. Keys are taken from the
// proxy's SSH key directory unless the request carries a non-empty string
// "key_dir" parameter (used for key rotation). The result holds the node list,
// a per-node outcome map, and success/total counts. Per-node failures are
// recorded in the map and do not fail the call; only a non-Proxmox host or a
// discovery failure returns an error.
func (p *Proxy) handleEnsureClusterKeysV2(ctx context.Context, req *RPCRequest, logger zerolog.Logger) (interface{}, error) {
	if !isProxmoxHost() {
		return nil, fmt.Errorf("not running on Proxmox host - cannot discover cluster")
	}

	// Optional key_dir override; anything other than a non-empty string is ignored.
	keyDir := p.sshKeyPath
	if custom, isString := req.Params["key_dir"].(string); isString && custom != "" {
		keyDir = custom
		logger.Info().Str("key_dir", keyDir).Msg("Using custom key directory for rotation")
	}

	nodes, err := discoverClusterNodes()
	if err != nil {
		return nil, fmt.Errorf("failed to discover cluster: %w", err)
	}
	logger.Info().Strs("nodes", nodes).Msg("Discovered cluster nodes")

	failed := func(reason string) map[string]interface{} {
		return map[string]interface{}{
			"success": false,
			"error":   reason,
		}
	}

	outcomes := make(map[string]interface{})
	pushed := 0
	for _, node := range nodes {
		if validateNodeName(node) != nil {
			logger.Warn().Str("node", node).Msg("Invalid node name format")
			outcomes[node] = failed("invalid node name")
			continue
		}

		logger.Info().Str("node", node).Str("key_dir", keyDir).Msg("Pushing SSH key to node")
		if pushErr := p.pushSSHKeyFrom(node, keyDir); pushErr != nil {
			logger.Error().Err(pushErr).Str("node", node).Msg("Failed to push SSH key")
			outcomes[node] = failed(pushErr.Error())
			continue
		}

		logger.Info().Str("node", node).Msg("SSH key pushed successfully")
		outcomes[node] = map[string]interface{}{
			"success": true,
		}
		pushed++
	}

	return map[string]interface{}{
		"nodes":         nodes,
		"results":       outcomes,
		"success_count": pushed,
		"total_count":   len(nodes),
	}, nil
}
// handleRegisterNodesV2 serves the register_nodes RPC. It requires a Proxmox
// host, discovers the cluster nodes and returns one status entry per node:
// entries with an invalid name are marked not ready with an "invalid node
// name" error, the rest carry the outcome of an SSH connection test. Only a
// non-Proxmox host or a discovery failure produces an error.
func (p *Proxy) handleRegisterNodesV2(ctx context.Context, req *RPCRequest, logger zerolog.Logger) (interface{}, error) {
	if !isProxmoxHost() {
		return nil, fmt.Errorf("not running on Proxmox host")
	}

	nodes, err := discoverClusterNodes()
	if err != nil {
		return nil, fmt.Errorf("failed to discover nodes: %w", err)
	}

	nodeStatus := make([]map[string]interface{}, 0, len(nodes))
	for _, node := range nodes {
		entry := map[string]interface{}{"name": node}

		if validateNodeName(node) != nil {
			entry["ssh_ready"] = false
			entry["error"] = "invalid node name"
		} else if sshErr := p.testSSHConnection(node); sshErr != nil {
			entry["ssh_ready"] = false
			entry["error"] = sshErr.Error()
		} else {
			entry["ssh_ready"] = true
		}

		nodeStatus = append(nodeStatus, entry)
	}

	logger.Info().Int("node_count", len(nodeStatus)).Msg("Node discovery completed")

	payload := map[string]interface{}{"nodes": nodeStatus}
	return payload, nil
}
// handleGetTemperatureV2 serves the get_temperature RPC. It expects a string
// "node" parameter, trims and validates it, takes the per-node gate so only
// one fetch per node runs at a time, and then reads the node's temperature
// data over SSH. The result carries the node name and the data. A missing,
// non-string or invalid node parameter, or a fetch failure, is returned as an
// error.
func (p *Proxy) handleGetTemperatureV2(ctx context.Context, req *RPCRequest, logger zerolog.Logger) (interface{}, error) {
	raw, present := req.Params["node"]
	if !present {
		return nil, fmt.Errorf("missing 'node' parameter")
	}
	name, isString := raw.(string)
	if !isString {
		return nil, fmt.Errorf("'node' parameter must be a string")
	}

	name = strings.TrimSpace(name)
	if validateNodeName(name) != nil {
		logger.Warn().Str("node", name).Msg("Invalid node name format")
		return nil, fmt.Errorf("invalid node name")
	}

	// Hold the per-node gate for the duration of the SSH fetch.
	release := p.nodeGate.acquire(name)
	defer release()

	logger.Debug().Str("node", name).Msg("Fetching temperature via SSH")
	tempData, err := p.getTemperatureViaSSH(name)
	if err != nil {
		logger.Warn().Err(err).Str("node", name).Msg("Failed to get temperatures")
		return nil, fmt.Errorf("failed to get temperatures: %w", err)
	}
	logger.Info().Str("node", name).Msg("Temperature data fetched successfully")

	reading := map[string]interface{}{
		"node":        name,
		"temperature": tempData,
	}
	return reading, nil
}