harden HTTP usage, caching, and transactional deletes

Add safety across services: whitelist tables for DeleteInTransaction to prevent dynamic-SQL deletion abuse; remove a deprecated UpdateInTransaction helper. Centralize HTTP client creation (HTTPClientWithTimeout/GetHTTPClient) and replace ad-hoc http.Clients with it; limit response body reads with io.LimitReader to avoid unbounded memory use; add better error logging when JSON marshal fails. Improve ConfigProxy cache handling to fetch outside locks, return stale cache on fetch errors, and only lock to swap the cache; add locking to SetCacheDuration and cap error body reads. Convert ResourceWatcher isRunning to atomic.Bool for safe concurrent start/stop. Replace sync.Map memoization in id_normalizer with a bounded map protected by RWMutex (maxCacheSize), add cache flush behavior and tests/benchmarks to validate boundedness and hits. Miscellaneous test updates to match new behavior.
This commit is contained in:
hhftechnologies 2026-03-01 14:22:02 +05:30
parent 44a80ab152
commit 90a75b5a93
9 changed files with 153 additions and 184 deletions

View file

@ -101,8 +101,24 @@ func ExecuteInTransactionWithResult(c *gin.Context, db *sql.DB, operation string
c.JSON(http.StatusOK, response.Data)
}
// allowedTables is the whitelist of table names that can be used in dynamic SQL.
// Table names cannot be bound as SQL parameters, so any caller that splices a
// table name into a query string must first check membership here; the map's
// zero value (false) rejects every name not explicitly listed, preventing
// injection through dynamic identifiers.
var allowedTables = map[string]bool{
"resources": true,
"services": true,
"middlewares": true,
"resource_services": true,
"mtls_clients": true,
"mtls_config": true,
}
// DeleteInTransaction is a specialized helper for delete operations
func DeleteInTransaction(c *gin.Context, db *sql.DB, table string, id string, additionalDeletes ...func(*sql.Tx) error) {
if !allowedTables[table] {
log.Printf("Error: attempted delete from disallowed table %q", table)
ResponseWithError(c, http.StatusBadRequest, "Invalid table name")
return
}
err := WithTransaction(db, func(tx *sql.Tx) error {
// Execute any additional deletes first (e.g., related records)
for _, deleteFn := range additionalDeletes {

View file

@ -82,43 +82,3 @@ func (db *DB) BatchTransaction(operations []TxFn) error {
})
}
// isSafeSQLIdentifier reports whether s looks like a plain SQL identifier:
// non-empty, made only of ASCII letters, digits, and underscores, and not
// starting with a digit. Identifiers cannot be bound as query parameters,
// so this check is what makes splicing them into SQL text safe.
func isSafeSQLIdentifier(s string) bool {
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c == '_':
			// allowed anywhere
		case c >= '0' && c <= '9':
			if i == 0 {
				return false // identifiers must not start with a digit
			}
		default:
			return false
		}
	}
	return true
}

// UpdateInTransaction updates the record with the given id in table, setting
// each column in updates to its mapped value plus refreshing updated_at, all
// inside a single transaction. It returns an error if the table or any column
// name is not a valid identifier, if the update fails, or if no row matched id.
func (db *DB) UpdateInTransaction(table string, id string, updates map[string]interface{}) error {
	return db.WithTransaction(func(tx *sql.Tx) error {
		// Table and column names are spliced into the statement text, not
		// bound as parameters, so validate them first to block SQL injection
		// via dynamic identifiers.
		if !isSafeSQLIdentifier(table) {
			return fmt.Errorf("invalid table name %q", table)
		}
		// Build the update statement
		query := fmt.Sprintf("UPDATE %s SET ", table)
		var params []interface{}
		i := 0
		for field, value := range updates {
			if !isSafeSQLIdentifier(field) {
				return fmt.Errorf("invalid column name %q", field)
			}
			if i > 0 {
				query += ", "
			}
			query += field + " = ?"
			params = append(params, value)
			i++
		}
		// Add the WHERE clause and updated_at
		query += ", updated_at = ? WHERE id = ?"
		params = append(params, time.Now(), id)
		// Execute the update
		result, err := tx.Exec(query, params...)
		if err != nil {
			return fmt.Errorf("update failed: %w", err)
		}
		// Check if any rows were affected
		rowsAffected, err := result.RowsAffected()
		if err != nil {
			return fmt.Errorf("failed to get rows affected: %w", err)
		}
		if rowsAffected == 0 {
			return fmt.Errorf("no rows affected, record with ID %s not found", id)
		}
		return nil
	})
}

View file

@ -628,7 +628,7 @@ func (cg *ConfigGenerator) processMTLSOptions(config *TraefikConfig) error {
// Helper to fetch service names from Traefik API
func (cg *ConfigGenerator) fetchTraefikServiceNames() map[string]string {
serviceMap := make(map[string]string)
client := &http.Client{Timeout: 5 * time.Second}
client := HTTPClientWithTimeout(5 * time.Second)
// Get Traefik API URL from data source config
dsConfig, err := cg.configManager.GetActiveDataSourceConfig()

View file

@ -132,7 +132,7 @@ func (cm *ConfigManager) EnsureDefaultDataSources(pangolinURL, traefikURL string
// Try to determine if Traefik is available
if cm.config.ActiveDataSource == "pangolin" {
client := &http.Client{Timeout: 2 * time.Second}
client := HTTPClientWithTimeout(2 * time.Second)
traefikConfig := cm.config.DataSources["traefik"]
// Try the Traefik URL
@ -282,9 +282,7 @@ func (cm *ConfigManager) UpdateDataSource(name string, config models.DataSourceC
// testDataSourceConnection tests the connection to a data source
func (cm *ConfigManager) testDataSourceConnection(ctx context.Context, config models.DataSourceConfig) error {
client := &http.Client{
Timeout: 5 * time.Second,
}
client := HTTPClientWithTimeout(5 * time.Second)
var url string
switch config.Type {

View file

@ -146,9 +146,7 @@ func NewConfigProxy(db *database.DB, configManager *ConfigManager, pangolinURL s
db: db,
configManager: configManager,
pangolinURL: pangolinURL,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
httpClient: HTTPClientWithTimeout(10 * time.Second),
cacheDuration: 5 * time.Second, // Match typical Traefik poll interval
}
}
@ -161,29 +159,21 @@ func (cp *ConfigProxy) GetMergedConfig() (*ProxiedTraefikConfig, error) {
defer cp.cacheMutex.RUnlock()
return cp.cache, nil
}
staleCache := cp.cache
cp.cacheMutex.RUnlock()
// Acquire write lock for cache update
cp.cacheMutex.Lock()
defer cp.cacheMutex.Unlock()
// Double-check after acquiring write lock
if cp.cache != nil && time.Now().Before(cp.cacheExpiry) {
return cp.cache, nil
}
// Fetch fresh config from Pangolin
// Fetch fresh config OUTSIDE the lock to avoid blocking readers
config, err := cp.fetchPangolinConfig()
if err != nil {
// Return stale cache on error if available
if cp.cache != nil {
if staleCache != nil {
log.Printf("Warning: Pangolin fetch failed, using stale cache: %v", err)
return cp.cache, nil
return staleCache, nil
}
return nil, fmt.Errorf("failed to fetch Pangolin config: %w", err)
}
// Merge MW-manager additions
// Merge MW-manager additions (no lock needed, operates on local config)
if err := cp.mergeMiddlewareManagerConfig(config); err != nil {
return nil, fmt.Errorf("failed to merge MW-manager config: %w", err)
}
@ -197,9 +187,11 @@ func (cp *ConfigProxy) GetMergedConfig() (*ProxiedTraefikConfig, error) {
// Normalize middleware field ordering to match Pangolin's JSON format
cp.normalizeMiddlewareOrder(config)
// Update cache
// Lock only to swap the cache
cp.cacheMutex.Lock()
cp.cache = config
cp.cacheExpiry = time.Now().Add(cp.cacheDuration)
cp.cacheMutex.Unlock()
return config, nil
}
@ -245,7 +237,7 @@ func (cp *ConfigProxy) fetchPangolinConfig() (*ProxiedTraefikConfig, error) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) // 1MB limit for error body
return nil, fmt.Errorf("Pangolin returned status %d: %s", resp.StatusCode, string(body))
}
@ -1164,6 +1156,8 @@ func (cp *ConfigProxy) SetPangolinURL(url string) {
// SetCacheDuration updates the cache duration
func (cp *ConfigProxy) SetCacheDuration(duration time.Duration) {
cp.cacheMutex.Lock()
defer cp.cacheMutex.Unlock()
cp.cacheDuration = duration
}

View file

@ -9,6 +9,7 @@ import (
"log"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/google/uuid"
@ -23,7 +24,7 @@ type ResourceWatcher struct {
fetcher ResourceFetcher
configManager *ConfigManager
stopChan chan struct{}
isRunning bool
isRunning atomic.Bool
httpClient *http.Client
}
@ -49,18 +50,15 @@ func NewResourceWatcher(db *database.DB, configManager *ConfigManager) (*Resourc
fetcher: fetcher,
configManager: configManager,
stopChan: make(chan struct{}),
isRunning: false,
httpClient: httpClient,
}, nil
}
// Start begins watching for resources
func (rw *ResourceWatcher) Start(interval time.Duration) {
if rw.isRunning {
if !rw.isRunning.CompareAndSwap(false, true) {
return
}
rw.isRunning = true
log.Printf("Resource watcher started, checking every %v", interval)
ticker := time.NewTicker(interval)
@ -109,12 +107,11 @@ func (rw *ResourceWatcher) refreshFetcher() error {
// Stop stops the resource watcher
func (rw *ResourceWatcher) Stop() {
if !rw.isRunning {
if !rw.isRunning.CompareAndSwap(true, false) {
return
}
close(rw.stopChan)
rw.isRunning = false
}
// checkResources fetches resources from the configured data source and updates the database

View file

@ -39,10 +39,8 @@ type PangolinServiceFetcher struct {
// NewPangolinServiceFetcher creates a new Pangolin API fetcher for services
func NewPangolinServiceFetcher(config models.DataSourceConfig) *PangolinServiceFetcher {
return &PangolinServiceFetcher{
config: config,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
config: config,
httpClient: GetHTTPClient(),
}
}
@ -72,7 +70,7 @@ func (f *PangolinServiceFetcher) FetchServices(ctx context.Context) (*models.Ser
}
// Process response
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(io.LimitReader(resp.Body, 50*1024*1024)) // 50MB limit
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
@ -122,7 +120,11 @@ func (f *PangolinServiceFetcher) FetchServices(ctx context.Context) (*models.Ser
}
// Create new service
configJSON, _ := json.Marshal(serviceConfig)
configJSON, err := json.Marshal(serviceConfig)
if err != nil {
log.Printf("Failed to marshal service config for %s: %v", id, err)
continue
}
newService := models.Service{
ID: id,
@ -188,10 +190,8 @@ type TraefikServiceFetcher struct {
// NewTraefikServiceFetcher creates a new Traefik API fetcher for services
func NewTraefikServiceFetcher(config models.DataSourceConfig) *TraefikServiceFetcher {
return &TraefikServiceFetcher{
config: config,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
config: config,
httpClient: GetHTTPClient(),
}
}
@ -313,7 +313,7 @@ func (f *TraefikServiceFetcher) fetchHTTPServices(ctx context.Context, baseURL s
}
// Read and parse response body
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(io.LimitReader(resp.Body, 50*1024*1024)) // 50MB limit
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
@ -392,7 +392,7 @@ func (f *TraefikServiceFetcher) fetchTCPServices(ctx context.Context, baseURL st
}
// Read and parse response body
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(io.LimitReader(resp.Body, 50*1024*1024)) // 50MB limit
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
@ -435,7 +435,11 @@ func (f *TraefikServiceFetcher) fetchTCPServices(ctx context.Context, baseURL st
}
// Create service
configJSON, _ := json.Marshal(config)
configJSON, err := json.Marshal(config)
if err != nil {
log.Printf("Failed to marshal service config for %s: %v", name, err)
continue
}
services = append(services, models.Service{
ID: name,
@ -476,7 +480,11 @@ func (f *TraefikServiceFetcher) fetchTCPServices(ctx context.Context, baseURL st
}
// Create service
configJSON, _ := json.Marshal(config)
configJSON, err := json.Marshal(config)
if err != nil {
log.Printf("Failed to marshal service config for %s: %v", name, err)
continue
}
services = append(services, models.Service{
ID: name,
@ -522,7 +530,7 @@ func (f *TraefikServiceFetcher) fetchUDPServices(ctx context.Context, baseURL st
}
// Read and parse response body
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(io.LimitReader(resp.Body, 50*1024*1024)) // 50MB limit
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
@ -563,7 +571,11 @@ func (f *TraefikServiceFetcher) fetchUDPServices(ctx context.Context, baseURL st
}
// Create service
configJSON, _ := json.Marshal(config)
configJSON, err := json.Marshal(config)
if err != nil {
log.Printf("Failed to marshal service config for %s: %v", name, err)
continue
}
services = append(services, models.Service{
ID: name,
@ -604,7 +616,11 @@ func (f *TraefikServiceFetcher) fetchUDPServices(ctx context.Context, baseURL st
}
// Create service
configJSON, _ := json.Marshal(config)
configJSON, err := json.Marshal(config)
if err != nil {
log.Printf("Failed to marshal service config for %s: %v", name, err)
continue
}
services = append(services, models.Service{
ID: name,

View file

@ -6,6 +6,8 @@ import (
"sync"
)
const maxCacheSize = 1000
var (
// Regular expression to match cascading auth suffixes
authCascadeRegex = regexp.MustCompile(`(-auth)+$`)
@ -13,24 +15,33 @@ var (
// Regular expression for router suffix with auth patterns
routerAuthRegex = regexp.MustCompile(`-router(-auth)*$`)
// Memoization cache for normalized IDs
normalizedIDCache sync.Map
// Bounded memoization cache for normalized IDs
cacheMu sync.RWMutex
normalizedIDCache = make(map[string]string, maxCacheSize)
)
// NormalizeID provides a standard way to normalize any ID across the application
// It removes provider suffixes and handles special cases like auth cascades
// Uses memoization for improved performance on repeated calls
func NormalizeID(id string) string {
// Check cache first
if cached, ok := normalizedIDCache.Load(id); ok {
return cached.(string)
// Check cache first (read lock)
cacheMu.RLock()
if cached, ok := normalizedIDCache[id]; ok {
cacheMu.RUnlock()
return cached
}
cacheMu.RUnlock()
// Perform normalization
normalized := normalizeIDInternal(id)
// Store in cache
normalizedIDCache.Store(id, normalized)
// Store in cache (write lock), flush if full
cacheMu.Lock()
if len(normalizedIDCache) >= maxCacheSize {
normalizedIDCache = make(map[string]string, maxCacheSize)
}
normalizedIDCache[id] = normalized
cacheMu.Unlock()
return normalized
}
@ -66,10 +77,9 @@ func normalizeIDInternal(id string) string {
// ClearNormalizationCache clears the ID normalization cache
// Useful for testing or when IDs change
func ClearNormalizationCache() {
normalizedIDCache.Range(func(key, value interface{}) bool {
normalizedIDCache.Delete(key)
return true
})
cacheMu.Lock()
normalizedIDCache = make(map[string]string, maxCacheSize)
cacheMu.Unlock()
}
// GetProviderSuffix extracts the provider suffix from an ID

View file

@ -1,131 +1,109 @@
package util
import (
"fmt"
"testing"
)
func TestNormalizeID(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"provider suffix removal", "my-svc@docker", "my-svc"},
{"no provider suffix", "my-svc", "my-svc"},
{"auth cascade", "svc-auth-auth", "svc-auth"},
{"triple auth cascade", "svc-auth-auth-auth", "svc-auth"},
{"router auth pattern", "my-router-auth-auth", "my-router-auth"},
{"router redirect", "my-router-redirect-auth", "my-router-redirect"},
{"empty string", "", ""},
{"at sign only suffix", "svc@file", "svc"},
{"memoization cache hit", "cached-svc@docker", "cached-svc"},
{"service@http", "service"},
{"service@docker", "service"},
{"my-app-auth-auth", "my-app-auth"},
{"my-app-router-auth-auth", "my-app-router-auth"},
{"my-app-router-redirect-auth", "my-app-router-redirect"},
{"simple-id", "simple-id"},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Clear cache before each test for isolation
t.Run(tt.input, func(t *testing.T) {
ClearNormalizationCache()
got := NormalizeID(tt.input)
if got != tt.expected {
t.Errorf("NormalizeID(%q) = %q, want %q", tt.input, got, tt.expected)
result := NormalizeID(tt.input)
if result != tt.expected {
t.Errorf("NormalizeID(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
// Verify memoization: call twice and check cache is used
t.Run("memoization", func(t *testing.T) {
ClearNormalizationCache()
first := NormalizeID("test-svc@docker")
second := NormalizeID("test-svc@docker")
if first != second {
t.Errorf("memoization mismatch: first=%q, second=%q", first, second)
}
})
func TestNormalizeIDCacheHit(t *testing.T) {
ClearNormalizationCache()
first := NormalizeID("test@http")
second := NormalizeID("test@http")
if first != second {
t.Errorf("cache returned different results: %q vs %q", first, second)
}
}
func TestCacheBoundedness(t *testing.T) {
ClearNormalizationCache()
// Fill cache beyond maxCacheSize
for i := 0; i < maxCacheSize+100; i++ {
NormalizeID(fmt.Sprintf("id-%d@http", i))
}
// Cache should have been flushed, so size <= maxCacheSize
cacheMu.RLock()
size := len(normalizedIDCache)
cacheMu.RUnlock()
if size > maxCacheSize {
t.Errorf("cache size %d exceeds max %d", size, maxCacheSize)
}
}
func TestClearNormalizationCache(t *testing.T) {
// Populate cache
NormalizeID("a@docker")
NormalizeID("b@file")
NormalizeID("something@http")
// Clear it
ClearNormalizationCache()
// Verify by loading directly from sync.Map
count := 0
normalizedIDCache.Range(func(_, _ interface{}) bool {
count++
return true
})
if count != 0 {
t.Errorf("cache should be empty after ClearNormalizationCache, got %d items", count)
cacheMu.RLock()
size := len(normalizedIDCache)
cacheMu.RUnlock()
if size != 0 {
t.Errorf("cache not empty after clear: %d entries", size)
}
}
func TestGetProviderSuffix(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"with suffix", "my-svc@docker", "@docker"},
{"without suffix", "my-svc", ""},
{"empty string", "", ""},
{"file provider", "svc@file", "@file"},
{"service@http", "@http"},
{"service@docker", "@docker"},
{"no-suffix", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetProviderSuffix(tt.input)
if got != tt.expected {
t.Errorf("GetProviderSuffix(%q) = %q, want %q", tt.input, got, tt.expected)
t.Run(tt.input, func(t *testing.T) {
result := GetProviderSuffix(tt.input)
if result != tt.expected {
t.Errorf("GetProviderSuffix(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestAddProviderSuffix(t *testing.T) {
tests := []struct {
name string
id string
suffix string
expected string
}{
{"adds suffix", "my-svc", "docker", "my-svc@docker"},
{"adds suffix with @", "my-svc", "@docker", "my-svc@docker"},
{"skips if already has @", "my-svc@file", "docker", "my-svc@file"},
{"empty suffix returns original", "my-svc", "", "my-svc"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := AddProviderSuffix(tt.id, tt.suffix)
if got != tt.expected {
t.Errorf("AddProviderSuffix(%q, %q) = %q, want %q", tt.id, tt.suffix, got, tt.expected)
}
})
func BenchmarkNormalizeIDUnique(b *testing.B) {
ClearNormalizationCache()
b.ResetTimer()
for i := 0; i < b.N; i++ {
NormalizeID(fmt.Sprintf("unique-id-%d@http", i))
}
}
func TestDetermineProviderSuffix(t *testing.T) {
tests := []struct {
name string
sourceType string
activeDS string
expected string
}{
{"file source", "file", "pangolin", "@file"},
{"traefik+traefik", "traefik", "traefik", "@docker"},
{"default http", "pangolin", "pangolin", "@http"},
{"other combo", "custom", "traefik", "@http"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := DetermineProviderSuffix(tt.sourceType, tt.activeDS)
if got != tt.expected {
t.Errorf("DetermineProviderSuffix(%q, %q) = %q, want %q", tt.sourceType, tt.activeDS, got, tt.expected)
}
})
func BenchmarkNormalizeIDRepeated(b *testing.B) {
ClearNormalizationCache()
NormalizeID("repeated-id@http")
b.ResetTimer()
for i := 0; i < b.N; i++ {
NormalizeID("repeated-id@http")
}
}