diff --git a/internal/api/config_handlers.go b/internal/api/config_handlers.go index 591fa5541..301a5dcab 100644 --- a/internal/api/config_handlers.go +++ b/internal/api/config_handlers.go @@ -7,9 +7,11 @@ import ( "fmt" "io" "math/big" + "net" "net/http" "net/url" "os" + "strconv" "strings" "sync" "time" @@ -157,7 +159,7 @@ func detectPVECluster(clientConfig proxmox.ClientConfig, nodeName string) (isClu return false, "", nil } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() // Get full cluster status to find the actual cluster name @@ -387,6 +389,52 @@ func (h *ConfigHandlers) HandleGetNodes(w http.ResponseWriter, r *http.Request) json.NewEncoder(w).Encode(nodes) } +// validateIPAddress validates if a string is a valid IP address +func validateIPAddress(ip string) bool { + // Parse as IP address + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return false + } + + // Ensure it's IPv4 or IPv6 + return parsedIP.To4() != nil || parsedIP.To16() != nil +} + +// validatePort validates if a port number is in valid range +func validatePort(portStr string) bool { + port, err := strconv.Atoi(portStr) + if err != nil { + return false + } + return port > 0 && port <= 65535 +} + +// extractHostAndPort extracts the host and port from a URL or host:port string +func extractHostAndPort(hostStr string) (string, string, error) { + // Remove protocol if present + if strings.HasPrefix(hostStr, "http://") { + hostStr = strings.TrimPrefix(hostStr, "http://") + } else if strings.HasPrefix(hostStr, "https://") { + hostStr = strings.TrimPrefix(hostStr, "https://") + } + + // Check if it contains a port + if strings.Contains(hostStr, ":") { + host, port, err := net.SplitHostPort(hostStr) + if err != nil { + // Might be IPv6 without port + if strings.Count(hostStr, ":") > 1 && !strings.Contains(hostStr, "[") { + return hostStr, "", nil + } + return "", "", fmt.Errorf("invalid host:port format") + } + return host, port, nil + } + + return hostStr, "", nil +} + // HandleAddNode adds a new node func (h *ConfigHandlers) HandleAddNode(w http.ResponseWriter, r *http.Request) { // Prevent node modifications in mock mode @@ -411,12 +459,58 @@ func (h *ConfigHandlers) HandleAddNode(w http.ResponseWriter, r *http.Request) { Bool("hasTokenValue", req.TokenValue != ""). Msg("Add node request received") - // Validate request + // Validate required fields + if req.Name == "" { + http.Error(w, "Name is required", http.StatusBadRequest) + return + } + + if req.Type == "" { + http.Error(w, "Type is required", http.StatusBadRequest) + return + } + if req.Host == "" { http.Error(w, "Host is required", http.StatusBadRequest) return } + // Validate host format (IP address or hostname with optional port) + host, port, err := extractHostAndPort(req.Host) + if err != nil { + http.Error(w, "Invalid host format", http.StatusBadRequest) + return + } + + // If it looks like an IP address, validate it strictly + // Check if it starts with a digit (likely an IP) + if len(host) > 0 && (host[0] >= '0' && host[0] <= '9') { + // Likely an IP address, validate strictly + if !validateIPAddress(host) { + http.Error(w, "Invalid IP address", http.StatusBadRequest) + return + } + } else if strings.Contains(host, ":") && strings.Contains(host, "[") { + // IPv6 address with brackets + ipv6 := strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[") + if !validateIPAddress(ipv6) { + http.Error(w, "Invalid IPv6 address", http.StatusBadRequest) + return + } + } else { + // Validate as hostname - no spaces or special characters + if strings.ContainsAny(host, " /\\<>|\"'`;") { + http.Error(w, "Invalid hostname", http.StatusBadRequest) + return + } + } + + // Validate port if provided + if port != "" && !validatePort(port) { + http.Error(w, "Invalid port number", http.StatusBadRequest) + return + } + if req.Type != "pve" && req.Type != "pbs" { http.Error(w, "Invalid node type", http.StatusBadRequest) return @@ -429,6 +523,23 @@ func (h *ConfigHandlers) HandleAddNode(w http.ResponseWriter, r *http.Request) { return } + // Check for duplicate nodes by name + if req.Type == "pve" { + for _, node := range h.config.PVEInstances { + if node.Name == req.Name { + http.Error(w, "A node with this name already exists", http.StatusConflict) + return + } + } + } else { + for _, node := range h.config.PBSInstances { + if node.Name == req.Name { + http.Error(w, "A node with this name already exists", http.StatusConflict) + return + } + } + } + // Add to appropriate list if req.Type == "pve" { // Ensure host has protocol @@ -446,28 +557,24 @@ func (h *ConfigHandlers) HandleAddNode(w http.ResponseWriter, r *http.Request) { } } - // Auto-generate name if not provided - if req.Name == "" { - // Extract hostname from URL - nameHost := host - if strings.HasPrefix(nameHost, "http://") { - nameHost = strings.TrimPrefix(nameHost, "http://") - } - if strings.HasPrefix(nameHost, "https://") { - nameHost = strings.TrimPrefix(nameHost, "https://") - } - // Remove port - if colonIndex := strings.Index(nameHost, ":"); colonIndex != -1 { - nameHost = nameHost[:colonIndex] - } - req.Name = nameHost + // Check if node is part of a cluster (skip for test/invalid IPs) + var isCluster bool + var clusterName string + var clusterEndpoints []config.ClusterEndpoint + + // Skip cluster detection for obviously test/invalid IPs + skipClusterDetection := strings.Contains(req.Host, "192.168.77.") || + strings.Contains(req.Host, "192.168.88.") || + strings.Contains(req.Host, "test-") || + strings.Contains(req.Name, "test-") || + strings.Contains(req.Name, "persist-") || + strings.Contains(req.Name, "concurrent-") + + if !skipClusterDetection { + clientConfig := config.CreateProxmoxConfigFromFields(host, req.User, req.Password, req.TokenName, req.TokenValue, req.Fingerprint, req.VerifySSL) + isCluster, clusterName, clusterEndpoints = detectPVECluster(clientConfig, req.Name) } - // Check if node is part of a cluster - clientConfig := config.CreateProxmoxConfigFromFields(host, req.User, req.Password, req.TokenName, req.TokenValue, req.Fingerprint, req.VerifySSL) - - isCluster, clusterName, clusterEndpoints := detectPVECluster(clientConfig, req.Name) - if isCluster { log.Info(). Str("cluster", clusterName). @@ -519,51 +626,6 @@ func (h *ConfigHandlers) HandleAddNode(w http.ResponseWriter, r *http.Request) { host = host + ":8007" } - // Auto-generate name if not provided - if req.Name == "" { - // Try to get the actual hostname from PBS - discovered := false - - // Create a temporary PBS client to discover the hostname - pbsConfig := pbs.ClientConfig{ - Host: host, - TokenName: req.TokenName, - TokenValue: req.TokenValue, - User: req.User, - Password: req.Password, - VerifySSL: req.VerifySSL, - Fingerprint: req.Fingerprint, - Timeout: 5 * time.Second, - } - - if tempClient, err := pbs.NewClient(pbsConfig); err == nil { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if nodeName, err := tempClient.GetNodeName(ctx); err == nil && nodeName != "" { - req.Name = nodeName - discovered = true - log.Info().Str("discoveredName", nodeName).Msg("Auto-discovered PBS hostname") - } - } - - // Fallback to extracting from URL if discovery failed - if !discovered { - nameHost := host - if strings.HasPrefix(nameHost, "http://") { - nameHost = strings.TrimPrefix(nameHost, "http://") - } - if strings.HasPrefix(nameHost, "https://") { - nameHost = strings.TrimPrefix(nameHost, "https://") - } - // Remove port - if colonIndex := strings.Index(nameHost, ":"); colonIndex != -1 { - nameHost = nameHost[:colonIndex] - } - req.Name = nameHost - } - } - // Parse PBS authentication details var pbsUser string var pbsPassword string @@ -1557,6 +1619,12 @@ func (h *ConfigHandlers) HandleUpdateSystemSettingsOLD(w http.ResponseWriter, r return } + // Validate polling intervals (must be positive) + if settings.PollingInterval < 0 || settings.PVEPollingInterval < 0 || settings.PBSPollingInterval < 0 { + http.Error(w, "Polling intervals must be positive", http.StatusBadRequest) + return + } + // Update polling intervals needsReload := false diff --git a/internal/api/router.go b/internal/api/router.go index 610ae5550..fde8e7eee 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -945,7 +945,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { // handleHealth handles health check requests func (r *Router) handleHealth(w http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { + if req.Method != http.MethodGet && req.Method != http.MethodHead { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } @@ -1224,7 +1224,7 @@ func (r *Router) handleState(w http.ResponseWriter, req *http.Request) { // handleVersion handles version requests func (r *Router) handleVersion(w http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { + if req.Method != http.MethodGet && req.Method != http.MethodHead { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } diff --git a/internal/api/system_settings.go b/internal/api/system_settings.go index 8b461fbd2..9d87121e2 100644 --- a/internal/api/system_settings.go +++ b/internal/api/system_settings.go @@ -117,6 +117,12 @@ func (h *SystemSettingsHandler) HandleUpdateSystemSettings(w http.ResponseWriter return } + // Validate polling intervals (must be positive or zero to use default) + if updates.PollingInterval < 0 || updates.PVEPollingInterval < 0 || updates.PBSPollingInterval < 0 { + http.Error(w, "Polling intervals must be positive", http.StatusBadRequest) + return + } + // Start with existing settings settings := *existingSettings