mirror of
https://github.com/hhftechnology/middleware-manager.git
synced 2026-04-28 03:29:42 +00:00
harden HTTP usage, caching, and transactional deletes
add safety across services: whitelist tables for DeleteInTransaction to prevent dynamic-SQL deletion abuse; remove a deprecated UpdateInTransaction helper. centralize HTTP client creation (HTTPClientWithTimeout/GetHTTPClient) and replace ad-hoc http.Clients with it; limit response body reads with io.LimitReader to avoid unbounded memory use; add better error logging when JSON marshal fails. Improve ConfigProxy cache handling to fetch outside locks, return stale cache on fetch errors, and only lock to swap the cache; add locking to SetCacheDuration and cap error body reads. convert ResourceWatcher isRunning to atomic.Bool for safe concurrent start/stop. Replace sync.Map memoization in id_normalizer with a bounded map protected by RWMutex (maxCacheSize), add cache flush behavior and tests/benchmarks to validate boundedness and hits. Miscellaneous test updates to match new behavior.
This commit is contained in:
parent
44a80ab152
commit
90a75b5a93
9 changed files with 153 additions and 184 deletions
|
|
@ -101,8 +101,24 @@ func ExecuteInTransactionWithResult(c *gin.Context, db *sql.DB, operation string
|
|||
c.JSON(http.StatusOK, response.Data)
|
||||
}
|
||||
|
||||
// allowedTables is the whitelist of table names that can be used in dynamic SQL.
|
||||
var allowedTables = map[string]bool{
|
||||
"resources": true,
|
||||
"services": true,
|
||||
"middlewares": true,
|
||||
"resource_services": true,
|
||||
"mtls_clients": true,
|
||||
"mtls_config": true,
|
||||
}
|
||||
|
||||
// DeleteInTransaction is a specialized helper for delete operations
|
||||
func DeleteInTransaction(c *gin.Context, db *sql.DB, table string, id string, additionalDeletes ...func(*sql.Tx) error) {
|
||||
if !allowedTables[table] {
|
||||
log.Printf("Error: attempted delete from disallowed table %q", table)
|
||||
ResponseWithError(c, http.StatusBadRequest, "Invalid table name")
|
||||
return
|
||||
}
|
||||
|
||||
err := WithTransaction(db, func(tx *sql.Tx) error {
|
||||
// Execute any additional deletes first (e.g., related records)
|
||||
for _, deleteFn := range additionalDeletes {
|
||||
|
|
|
|||
|
|
@ -82,43 +82,3 @@ func (db *DB) BatchTransaction(operations []TxFn) error {
|
|||
})
|
||||
}
|
||||
|
||||
// UpdateInTransaction updates a record in a transaction
|
||||
func (db *DB) UpdateInTransaction(table string, id string, updates map[string]interface{}) error {
|
||||
return db.WithTransaction(func(tx *sql.Tx) error {
|
||||
// Build the update statement
|
||||
query := fmt.Sprintf("UPDATE %s SET ", table)
|
||||
var params []interface{}
|
||||
|
||||
i := 0
|
||||
for field, value := range updates {
|
||||
if i > 0 {
|
||||
query += ", "
|
||||
}
|
||||
query += field + " = ?"
|
||||
params = append(params, value)
|
||||
i++
|
||||
}
|
||||
|
||||
// Add the WHERE clause and updated_at
|
||||
query += ", updated_at = ? WHERE id = ?"
|
||||
params = append(params, time.Now(), id)
|
||||
|
||||
// Execute the update
|
||||
result, err := tx.Exec(query, params...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
|
||||
// Check if any rows were affected
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("no rows affected, record with ID %s not found", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
|
@ -628,7 +628,7 @@ func (cg *ConfigGenerator) processMTLSOptions(config *TraefikConfig) error {
|
|||
// Helper to fetch service names from Traefik API
|
||||
func (cg *ConfigGenerator) fetchTraefikServiceNames() map[string]string {
|
||||
serviceMap := make(map[string]string)
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
client := HTTPClientWithTimeout(5 * time.Second)
|
||||
|
||||
// Get Traefik API URL from data source config
|
||||
dsConfig, err := cg.configManager.GetActiveDataSourceConfig()
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ func (cm *ConfigManager) EnsureDefaultDataSources(pangolinURL, traefikURL string
|
|||
|
||||
// Try to determine if Traefik is available
|
||||
if cm.config.ActiveDataSource == "pangolin" {
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
client := HTTPClientWithTimeout(2 * time.Second)
|
||||
traefikConfig := cm.config.DataSources["traefik"]
|
||||
|
||||
// Try the Traefik URL
|
||||
|
|
@ -282,9 +282,7 @@ func (cm *ConfigManager) UpdateDataSource(name string, config models.DataSourceC
|
|||
|
||||
// testDataSourceConnection tests the connection to a data source
|
||||
func (cm *ConfigManager) testDataSourceConnection(ctx context.Context, config models.DataSourceConfig) error {
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
client := HTTPClientWithTimeout(5 * time.Second)
|
||||
|
||||
var url string
|
||||
switch config.Type {
|
||||
|
|
|
|||
|
|
@ -146,9 +146,7 @@ func NewConfigProxy(db *database.DB, configManager *ConfigManager, pangolinURL s
|
|||
db: db,
|
||||
configManager: configManager,
|
||||
pangolinURL: pangolinURL,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
httpClient: HTTPClientWithTimeout(10 * time.Second),
|
||||
cacheDuration: 5 * time.Second, // Match typical Traefik poll interval
|
||||
}
|
||||
}
|
||||
|
|
@ -161,29 +159,21 @@ func (cp *ConfigProxy) GetMergedConfig() (*ProxiedTraefikConfig, error) {
|
|||
defer cp.cacheMutex.RUnlock()
|
||||
return cp.cache, nil
|
||||
}
|
||||
staleCache := cp.cache
|
||||
cp.cacheMutex.RUnlock()
|
||||
|
||||
// Acquire write lock for cache update
|
||||
cp.cacheMutex.Lock()
|
||||
defer cp.cacheMutex.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if cp.cache != nil && time.Now().Before(cp.cacheExpiry) {
|
||||
return cp.cache, nil
|
||||
}
|
||||
|
||||
// Fetch fresh config from Pangolin
|
||||
// Fetch fresh config OUTSIDE the lock to avoid blocking readers
|
||||
config, err := cp.fetchPangolinConfig()
|
||||
if err != nil {
|
||||
// Return stale cache on error if available
|
||||
if cp.cache != nil {
|
||||
if staleCache != nil {
|
||||
log.Printf("Warning: Pangolin fetch failed, using stale cache: %v", err)
|
||||
return cp.cache, nil
|
||||
return staleCache, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to fetch Pangolin config: %w", err)
|
||||
}
|
||||
|
||||
// Merge MW-manager additions
|
||||
// Merge MW-manager additions (no lock needed, operates on local config)
|
||||
if err := cp.mergeMiddlewareManagerConfig(config); err != nil {
|
||||
return nil, fmt.Errorf("failed to merge MW-manager config: %w", err)
|
||||
}
|
||||
|
|
@ -197,9 +187,11 @@ func (cp *ConfigProxy) GetMergedConfig() (*ProxiedTraefikConfig, error) {
|
|||
// Normalize middleware field ordering to match Pangolin's JSON format
|
||||
cp.normalizeMiddlewareOrder(config)
|
||||
|
||||
// Update cache
|
||||
// Lock only to swap the cache
|
||||
cp.cacheMutex.Lock()
|
||||
cp.cache = config
|
||||
cp.cacheExpiry = time.Now().Add(cp.cacheDuration)
|
||||
cp.cacheMutex.Unlock()
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
|
@ -245,7 +237,7 @@ func (cp *ConfigProxy) fetchPangolinConfig() (*ProxiedTraefikConfig, error) {
|
|||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) // 1MB limit for error body
|
||||
return nil, fmt.Errorf("Pangolin returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
|
|
@ -1164,6 +1156,8 @@ func (cp *ConfigProxy) SetPangolinURL(url string) {
|
|||
|
||||
// SetCacheDuration updates the cache duration
|
||||
func (cp *ConfigProxy) SetCacheDuration(duration time.Duration) {
|
||||
cp.cacheMutex.Lock()
|
||||
defer cp.cacheMutex.Unlock()
|
||||
cp.cacheDuration = duration
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import (
|
|||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
|
@ -23,7 +24,7 @@ type ResourceWatcher struct {
|
|||
fetcher ResourceFetcher
|
||||
configManager *ConfigManager
|
||||
stopChan chan struct{}
|
||||
isRunning bool
|
||||
isRunning atomic.Bool
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
|
|
@ -49,18 +50,15 @@ func NewResourceWatcher(db *database.DB, configManager *ConfigManager) (*Resourc
|
|||
fetcher: fetcher,
|
||||
configManager: configManager,
|
||||
stopChan: make(chan struct{}),
|
||||
isRunning: false,
|
||||
httpClient: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins watching for resources
|
||||
func (rw *ResourceWatcher) Start(interval time.Duration) {
|
||||
if rw.isRunning {
|
||||
if !rw.isRunning.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
rw.isRunning = true
|
||||
log.Printf("Resource watcher started, checking every %v", interval)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
|
|
@ -109,12 +107,11 @@ func (rw *ResourceWatcher) refreshFetcher() error {
|
|||
|
||||
// Stop stops the resource watcher
|
||||
func (rw *ResourceWatcher) Stop() {
|
||||
if !rw.isRunning {
|
||||
if !rw.isRunning.CompareAndSwap(true, false) {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
close(rw.stopChan)
|
||||
rw.isRunning = false
|
||||
}
|
||||
|
||||
// checkResources fetches resources from the configured data source and updates the database
|
||||
|
|
|
|||
|
|
@ -39,10 +39,8 @@ type PangolinServiceFetcher struct {
|
|||
// NewPangolinServiceFetcher creates a new Pangolin API fetcher for services
|
||||
func NewPangolinServiceFetcher(config models.DataSourceConfig) *PangolinServiceFetcher {
|
||||
return &PangolinServiceFetcher{
|
||||
config: config,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
config: config,
|
||||
httpClient: GetHTTPClient(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -72,7 +70,7 @@ func (f *PangolinServiceFetcher) FetchServices(ctx context.Context) (*models.Ser
|
|||
}
|
||||
|
||||
// Process response
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 50*1024*1024)) // 50MB limit
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
|
@ -122,7 +120,11 @@ func (f *PangolinServiceFetcher) FetchServices(ctx context.Context) (*models.Ser
|
|||
}
|
||||
|
||||
// Create new service
|
||||
configJSON, _ := json.Marshal(serviceConfig)
|
||||
configJSON, err := json.Marshal(serviceConfig)
|
||||
if err != nil {
|
||||
log.Printf("Failed to marshal service config for %s: %v", id, err)
|
||||
continue
|
||||
}
|
||||
|
||||
newService := models.Service{
|
||||
ID: id,
|
||||
|
|
@ -188,10 +190,8 @@ type TraefikServiceFetcher struct {
|
|||
// NewTraefikServiceFetcher creates a new Traefik API fetcher for services
|
||||
func NewTraefikServiceFetcher(config models.DataSourceConfig) *TraefikServiceFetcher {
|
||||
return &TraefikServiceFetcher{
|
||||
config: config,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
config: config,
|
||||
httpClient: GetHTTPClient(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -313,7 +313,7 @@ func (f *TraefikServiceFetcher) fetchHTTPServices(ctx context.Context, baseURL s
|
|||
}
|
||||
|
||||
// Read and parse response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 50*1024*1024)) // 50MB limit
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
|
@ -392,7 +392,7 @@ func (f *TraefikServiceFetcher) fetchTCPServices(ctx context.Context, baseURL st
|
|||
}
|
||||
|
||||
// Read and parse response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 50*1024*1024)) // 50MB limit
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
|
@ -435,7 +435,11 @@ func (f *TraefikServiceFetcher) fetchTCPServices(ctx context.Context, baseURL st
|
|||
}
|
||||
|
||||
// Create service
|
||||
configJSON, _ := json.Marshal(config)
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
log.Printf("Failed to marshal service config for %s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
services = append(services, models.Service{
|
||||
ID: name,
|
||||
|
|
@ -476,7 +480,11 @@ func (f *TraefikServiceFetcher) fetchTCPServices(ctx context.Context, baseURL st
|
|||
}
|
||||
|
||||
// Create service
|
||||
configJSON, _ := json.Marshal(config)
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
log.Printf("Failed to marshal service config for %s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
services = append(services, models.Service{
|
||||
ID: name,
|
||||
|
|
@ -522,7 +530,7 @@ func (f *TraefikServiceFetcher) fetchUDPServices(ctx context.Context, baseURL st
|
|||
}
|
||||
|
||||
// Read and parse response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 50*1024*1024)) // 50MB limit
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
|
@ -563,7 +571,11 @@ func (f *TraefikServiceFetcher) fetchUDPServices(ctx context.Context, baseURL st
|
|||
}
|
||||
|
||||
// Create service
|
||||
configJSON, _ := json.Marshal(config)
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
log.Printf("Failed to marshal service config for %s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
services = append(services, models.Service{
|
||||
ID: name,
|
||||
|
|
@ -604,7 +616,11 @@ func (f *TraefikServiceFetcher) fetchUDPServices(ctx context.Context, baseURL st
|
|||
}
|
||||
|
||||
// Create service
|
||||
configJSON, _ := json.Marshal(config)
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
log.Printf("Failed to marshal service config for %s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
services = append(services, models.Service{
|
||||
ID: name,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import (
|
|||
"sync"
|
||||
)
|
||||
|
||||
const maxCacheSize = 1000
|
||||
|
||||
var (
|
||||
// Regular expression to match cascading auth suffixes
|
||||
authCascadeRegex = regexp.MustCompile(`(-auth)+$`)
|
||||
|
|
@ -13,24 +15,33 @@ var (
|
|||
// Regular expression for router suffix with auth patterns
|
||||
routerAuthRegex = regexp.MustCompile(`-router(-auth)*$`)
|
||||
|
||||
// Memoization cache for normalized IDs
|
||||
normalizedIDCache sync.Map
|
||||
// Bounded memoization cache for normalized IDs
|
||||
cacheMu sync.RWMutex
|
||||
normalizedIDCache = make(map[string]string, maxCacheSize)
|
||||
)
|
||||
|
||||
// NormalizeID provides a standard way to normalize any ID across the application
|
||||
// It removes provider suffixes and handles special cases like auth cascades
|
||||
// Uses memoization for improved performance on repeated calls
|
||||
func NormalizeID(id string) string {
|
||||
// Check cache first
|
||||
if cached, ok := normalizedIDCache.Load(id); ok {
|
||||
return cached.(string)
|
||||
// Check cache first (read lock)
|
||||
cacheMu.RLock()
|
||||
if cached, ok := normalizedIDCache[id]; ok {
|
||||
cacheMu.RUnlock()
|
||||
return cached
|
||||
}
|
||||
cacheMu.RUnlock()
|
||||
|
||||
// Perform normalization
|
||||
normalized := normalizeIDInternal(id)
|
||||
|
||||
// Store in cache
|
||||
normalizedIDCache.Store(id, normalized)
|
||||
// Store in cache (write lock), flush if full
|
||||
cacheMu.Lock()
|
||||
if len(normalizedIDCache) >= maxCacheSize {
|
||||
normalizedIDCache = make(map[string]string, maxCacheSize)
|
||||
}
|
||||
normalizedIDCache[id] = normalized
|
||||
cacheMu.Unlock()
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
|
@ -66,10 +77,9 @@ func normalizeIDInternal(id string) string {
|
|||
// ClearNormalizationCache clears the ID normalization cache
|
||||
// Useful for testing or when IDs change
|
||||
func ClearNormalizationCache() {
|
||||
normalizedIDCache.Range(func(key, value interface{}) bool {
|
||||
normalizedIDCache.Delete(key)
|
||||
return true
|
||||
})
|
||||
cacheMu.Lock()
|
||||
normalizedIDCache = make(map[string]string, maxCacheSize)
|
||||
cacheMu.Unlock()
|
||||
}
|
||||
|
||||
// GetProviderSuffix extracts the provider suffix from an ID
|
||||
|
|
|
|||
|
|
@ -1,131 +1,109 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"provider suffix removal", "my-svc@docker", "my-svc"},
|
||||
{"no provider suffix", "my-svc", "my-svc"},
|
||||
{"auth cascade", "svc-auth-auth", "svc-auth"},
|
||||
{"triple auth cascade", "svc-auth-auth-auth", "svc-auth"},
|
||||
{"router auth pattern", "my-router-auth-auth", "my-router-auth"},
|
||||
{"router redirect", "my-router-redirect-auth", "my-router-redirect"},
|
||||
{"empty string", "", ""},
|
||||
{"at sign only suffix", "svc@file", "svc"},
|
||||
{"memoization cache hit", "cached-svc@docker", "cached-svc"},
|
||||
{"service@http", "service"},
|
||||
{"service@docker", "service"},
|
||||
{"my-app-auth-auth", "my-app-auth"},
|
||||
{"my-app-router-auth-auth", "my-app-router-auth"},
|
||||
{"my-app-router-redirect-auth", "my-app-router-redirect"},
|
||||
{"simple-id", "simple-id"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Clear cache before each test for isolation
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
ClearNormalizationCache()
|
||||
got := NormalizeID(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("NormalizeID(%q) = %q, want %q", tt.input, got, tt.expected)
|
||||
result := NormalizeID(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("NormalizeID(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Verify memoization: call twice and check cache is used
|
||||
t.Run("memoization", func(t *testing.T) {
|
||||
ClearNormalizationCache()
|
||||
first := NormalizeID("test-svc@docker")
|
||||
second := NormalizeID("test-svc@docker")
|
||||
if first != second {
|
||||
t.Errorf("memoization mismatch: first=%q, second=%q", first, second)
|
||||
}
|
||||
})
|
||||
func TestNormalizeIDCacheHit(t *testing.T) {
|
||||
ClearNormalizationCache()
|
||||
first := NormalizeID("test@http")
|
||||
second := NormalizeID("test@http")
|
||||
if first != second {
|
||||
t.Errorf("cache returned different results: %q vs %q", first, second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheBoundedness(t *testing.T) {
|
||||
ClearNormalizationCache()
|
||||
|
||||
// Fill cache beyond maxCacheSize
|
||||
for i := 0; i < maxCacheSize+100; i++ {
|
||||
NormalizeID(fmt.Sprintf("id-%d@http", i))
|
||||
}
|
||||
|
||||
// Cache should have been flushed, so size <= maxCacheSize
|
||||
cacheMu.RLock()
|
||||
size := len(normalizedIDCache)
|
||||
cacheMu.RUnlock()
|
||||
|
||||
if size > maxCacheSize {
|
||||
t.Errorf("cache size %d exceeds max %d", size, maxCacheSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearNormalizationCache(t *testing.T) {
|
||||
// Populate cache
|
||||
NormalizeID("a@docker")
|
||||
NormalizeID("b@file")
|
||||
NormalizeID("something@http")
|
||||
|
||||
// Clear it
|
||||
ClearNormalizationCache()
|
||||
|
||||
// Verify by loading directly from sync.Map
|
||||
count := 0
|
||||
normalizedIDCache.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count != 0 {
|
||||
t.Errorf("cache should be empty after ClearNormalizationCache, got %d items", count)
|
||||
cacheMu.RLock()
|
||||
size := len(normalizedIDCache)
|
||||
cacheMu.RUnlock()
|
||||
|
||||
if size != 0 {
|
||||
t.Errorf("cache not empty after clear: %d entries", size)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProviderSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"with suffix", "my-svc@docker", "@docker"},
|
||||
{"without suffix", "my-svc", ""},
|
||||
{"empty string", "", ""},
|
||||
{"file provider", "svc@file", "@file"},
|
||||
{"service@http", "@http"},
|
||||
{"service@docker", "@docker"},
|
||||
{"no-suffix", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetProviderSuffix(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("GetProviderSuffix(%q) = %q, want %q", tt.input, got, tt.expected)
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := GetProviderSuffix(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetProviderSuffix(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddProviderSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
suffix string
|
||||
expected string
|
||||
}{
|
||||
{"adds suffix", "my-svc", "docker", "my-svc@docker"},
|
||||
{"adds suffix with @", "my-svc", "@docker", "my-svc@docker"},
|
||||
{"skips if already has @", "my-svc@file", "docker", "my-svc@file"},
|
||||
{"empty suffix returns original", "my-svc", "", "my-svc"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := AddProviderSuffix(tt.id, tt.suffix)
|
||||
if got != tt.expected {
|
||||
t.Errorf("AddProviderSuffix(%q, %q) = %q, want %q", tt.id, tt.suffix, got, tt.expected)
|
||||
}
|
||||
})
|
||||
func BenchmarkNormalizeIDUnique(b *testing.B) {
|
||||
ClearNormalizationCache()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
NormalizeID(fmt.Sprintf("unique-id-%d@http", i))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineProviderSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sourceType string
|
||||
activeDS string
|
||||
expected string
|
||||
}{
|
||||
{"file source", "file", "pangolin", "@file"},
|
||||
{"traefik+traefik", "traefik", "traefik", "@docker"},
|
||||
{"default http", "pangolin", "pangolin", "@http"},
|
||||
{"other combo", "custom", "traefik", "@http"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := DetermineProviderSuffix(tt.sourceType, tt.activeDS)
|
||||
if got != tt.expected {
|
||||
t.Errorf("DetermineProviderSuffix(%q, %q) = %q, want %q", tt.sourceType, tt.activeDS, got, tt.expected)
|
||||
}
|
||||
})
|
||||
func BenchmarkNormalizeIDRepeated(b *testing.B) {
|
||||
ClearNormalizationCache()
|
||||
NormalizeID("repeated-id@http")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
NormalizeID("repeated-id@http")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue