Daniel Hååvi 2025-04-07 14:26:56 +02:00 committed by GitHub
commit ebd913d357
5 changed files with 88 additions and 14 deletions


@@ -0,0 +1,74 @@
package utils
import (
"sync"
"sync/atomic"
"time"
)
// CallLimiter2 bundles concurrent calls and optionally limits how fast a function is called.
type CallLimiter2 struct {
pause time.Duration
slot atomic.Int64
slotWait sync.RWMutex
executing atomic.Bool
lastExec time.Time
}
// NewCallLimiter2 returns a new call limiter.
// Set minPause to zero to disable the minimum pause between calls.
func NewCallLimiter2(minPause time.Duration) *CallLimiter2 {
return &CallLimiter2{
pause: minPause,
}
}
// Do executes the given function.
// All concurrent calls to Do are bundled and return when f() finishes.
// Waits until the minimum pause is over before executing f() again.
func (l *CallLimiter2) Do(f func()) {
// Get ticket number.
slot := l.slot.Load()
// Check if we can execute.
if l.executing.CompareAndSwap(false, true) {
// Make others wait.
l.slotWait.Lock()
defer l.slotWait.Unlock()
// Execute and return.
l.waitAndExec(f)
return
}
// Wait for the executing call to finish, then check if our slot is done.
for l.slot.Load() == slot {
time.Sleep(100 * time.Microsecond)
l.slotWait.RLock()
l.slotWait.RUnlock() //nolint:staticcheck
}
}
func (l *CallLimiter2) waitAndExec(f func()) {
defer func() {
// Update last exec time.
l.lastExec = time.Now().UTC()
// Enable next execution first.
l.executing.Store(false)
// Move to next slot afterwards to prevent wait loops.
l.slot.Add(1)
}()
// Wait for the minimum duration between executions.
if l.pause > 0 {
sinceLastExec := time.Since(l.lastExec)
if sinceLastExec < l.pause {
time.Sleep(l.pause - sinceLastExec)
}
}
// Execute.
f()
}
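For orientation, here is a minimal usage sketch of the new limiter. It is not part of the commit, and the import path is an assumption based on the repository layout. Concurrent callers of Do are bundled: one goroutine runs the passed function per slot, the others return as soon as that run finishes, and consecutive runs are spaced by at least the configured minimum pause.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/safing/portmaster/base/utils" // assumed import path
)

func main() {
	limiter := utils.NewCallLimiter2(10 * time.Millisecond)

	var execs atomic.Int32
	var wg sync.WaitGroup

	// Start many concurrent callers; the limiter bundles them into very few executions.
	wg.Add(50)
	for range 50 {
		go func() {
			defer wg.Done()
			limiter.Do(func() {
				execs.Add(1)
			})
		}()
	}
	wg.Wait()

	// Expect a small number of executions here, not 50.
	fmt.Println("executions:", execs.Load())
}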


@@ -13,7 +13,7 @@ func TestCallLimiter(t *testing.T) {
t.Parallel()
pause := 10 * time.Millisecond
- oa := NewCallLimiter(pause)
+ oa := NewCallLimiter2(pause)
executed := abool.New()
var testWg sync.WaitGroup
@@ -41,14 +41,14 @@ func TestCallLimiter(t *testing.T) {
executed.UnSet() // reset check
}
- // Wait for pause to reset.
- time.Sleep(pause)
+ // Wait for 2x pause to reset.
+ time.Sleep(2 * pause)
// Continuous use with re-execution.
// Choose values so that about 10 executions are expected
var execs uint32
- testWg.Add(200)
- for range 200 {
+ testWg.Add(100)
+ for range 100 {
go func() {
oa.Do(func() {
atomic.AddUint32(&execs, 1)
@@ -69,8 +69,8 @@ func TestCallLimiter(t *testing.T) {
t.Errorf("unexpected high exec count: %d", execs)
}
- // Wait for pause to reset.
- time.Sleep(pause)
+ // Wait for 2x pause to reset.
+ time.Sleep(2 * pause)
// Check if the limiter correctly handles panics.
testWg.Add(100)


@@ -19,7 +19,7 @@ var (
// pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs.
pidsByUser = make(map[int][]int)
pidsByUserLock sync.RWMutex
- fetchPidsByUser = utils.NewCallLimiter(10 * time.Millisecond)
+ fetchPidsByUser = utils.NewCallLimiter2(10 * time.Millisecond)
)
// getPidsByUser returns the cached PIDs for the given UID.


@@ -25,7 +25,7 @@ type tcpTable struct {
// lastUpdateAt stores the time when the tables were last updated as unix nanoseconds.
lastUpdateAt atomic.Int64
- fetchLimiter *utils.CallLimiter
+ fetchLimiter *utils.CallLimiter2
fetchTable func() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error)
dualStack *tcpTable
@@ -34,13 +34,13 @@ type tcpTable struct {
var (
tcp6Table = &tcpTable{
version: 6,
- fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
+ fetchLimiter: utils.NewCallLimiter2(minDurationBetweenTableUpdates),
fetchTable: getTCP6Table,
}
tcp4Table = &tcpTable{
version: 4,
- fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
+ fetchLimiter: utils.NewCallLimiter2(minDurationBetweenTableUpdates),
fetchTable: getTCP4Table,
}
)
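For context on the table changes: the fetch limiter is presumably what coalesces concurrent refreshes of these tables, so fetchTable runs at most once per minDurationBetweenTableUpdates no matter how many goroutines ask. A rough sketch of that pattern follows; updateTable, its error handling, and the state merge are illustrative and not taken from this diff.

// Sketch only: updateTable is hypothetical; fetchLimiter, fetchTable and
// lastUpdateAt are the fields shown in the hunks above.
func (table *tcpTable) updateTable() {
	table.fetchLimiter.Do(func() {
		connections, listeners, err := table.fetchTable()
		if err != nil {
			return // a real implementation would log or propagate the error
		}

		// Merge the fresh entries into shared table state (elided here) so that
		// bundled callers, whose Do() returns without this closure running,
		// still see the updated data.
		_, _ = connections, listeners

		table.lastUpdateAt.Store(time.Now().UnixNano())
	})
}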


@@ -24,7 +24,7 @@ type udpTable struct {
// lastUpdateAt stores the time when the tables were last updated as unix nanoseconds.
lastUpdateAt atomic.Int64
- fetchLimiter *utils.CallLimiter
+ fetchLimiter *utils.CallLimiter2
fetchTable func() (binds []*socket.BindInfo, err error)
states map[string]map[string]*udpState
@@ -52,14 +52,14 @@ const (
var (
udp6Table = &udpTable{
version: 6,
- fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
+ fetchLimiter: utils.NewCallLimiter2(minDurationBetweenTableUpdates),
fetchTable: getUDP6Table,
states: make(map[string]map[string]*udpState),
}
udp4Table = &udpTable{
version: 4,
- fetchLimiter: utils.NewCallLimiter(minDurationBetweenTableUpdates),
+ fetchLimiter: utils.NewCallLimiter2(minDurationBetweenTableUpdates),
fetchTable: getUDP4Table,
states: make(map[string]map[string]*udpState),
}