Refactoring

This commit is contained in:
Vladimir Stoilov 2022-09-20 14:58:43 +02:00 committed by Daniel
parent ddfa3722be
commit ecce16ee78

View file

@ -70,7 +70,7 @@ func interceptionPrep() error {
err := interceptionModule.RegisterEventHook( err := interceptionModule.RegisterEventHook(
"config", "config",
configChangeEvent, configChangeEvent,
"firewall config change event", "reset connection verdicts",
func(ctx context.Context, _ interface{}) error { func(ctx context.Context, _ interface{}) error {
resetPersistentVerdicts() resetPersistentVerdicts()
return nil return nil
@ -84,7 +84,7 @@ func interceptionPrep() error {
err = interceptionModule.RegisterEventHook( err = interceptionModule.RegisterEventHook(
"profiles", "profiles",
profileConfigChangeEvent, profileConfigChangeEvent,
"firewall profile change event", "reset connection verdicts",
func(ctx context.Context, _ interface{}) error { func(ctx context.Context, _ interface{}) error {
resetPersistentVerdicts() resetPersistentVerdicts()
return nil return nil
@ -99,7 +99,7 @@ func interceptionPrep() error {
err = interceptionModule.RegisterEventHook( err = interceptionModule.RegisterEventHook(
"captain", "captain",
onSPNConnectEvent, onSPNConnectEvent,
"firewall spn connect event", "reset connection verdicts",
func(ctx context.Context, _ interface{}) error { func(ctx context.Context, _ interface{}) error {
resetPersistentVerdicts() resetPersistentVerdicts()
return nil return nil
@ -118,30 +118,30 @@ func interceptionPrep() error {
func resetPersistentVerdicts() { func resetPersistentVerdicts() {
// Resetting will force all the connection to be evaluated by the firewall again // Resetting will force all the connection to be evaluated by the firewall again
// this will set new verdicts if configuration was update or spn has been disabled or enabled // this will set new verdicts if configuration was update or spn has been disabled or enabled.
log.Info("interception: resetting all connections") log.Info("interception: reevaluating all connection trough the firewall")
// reset all connection firewall handlers. This will tell the master to rerun the firewall checks // reset all connection firewall handlers. This will tell the master to rerun the firewall checks.
for _, conn := range network.GetAllConnections() { for _, conn := range network.GetAllConnections() {
conn.Lock()
isSPNConnection := captain.IsExcepted(conn.Entity.IP) && conn.Process().Pid == ownPID isSPNConnection := captain.IsExcepted(conn.Entity.IP) && conn.Process().Pid == ownPID
// mark all non SPN connections to be processed by the firewall // mark all non SPN connections to be processed by the firewall.
if !isSPNConnection { if !isSPNConnection {
conn.Lock()
conn.SetFirewallHandler(initialHandler) conn.SetFirewallHandler(initialHandler)
// Don't keep the previous tunneled value // Don't keep the previous tunneled value.
conn.Tunneled = false conn.Tunneled = false
// Reset entity if it exists. // Reset entity if it exists.
if conn.Entity != nil { if conn.Entity != nil {
conn.Entity.ResetLists() conn.Entity.ResetLists()
} }
conn.Unlock()
} }
conn.Unlock()
} }
err := interception.ResetVerdictOfAllConnections() err := interception.ResetVerdictOfAllConnections()
if err != nil { if err != nil {
log.Errorf("interception: failed to reset connections verdict: %s", err) log.Errorf("interception: failed to remove persistent verdicts: %s", err)
} }
} }
@ -177,8 +177,6 @@ func SetNameserverIPMatcher(fn func(ip net.IP) bool) error {
} }
func handlePacket(ctx context.Context, pkt packet.Packet) { func handlePacket(ctx context.Context, pkt packet.Packet) {
// log.Errorf("DEBUG: firewall: handling packet %s", pkt)
// Record metrics. // Record metrics.
startTime := time.Now() startTime := time.Now()
defer packetHandlingHistogram.UpdateDuration(startTime) defer packetHandlingHistogram.UpdateDuration(startTime)
@ -222,7 +220,9 @@ func getConnection(pkt packet.Packet) (*network.Connection, error) {
// Else create new one from the packet. // Else create new one from the packet.
conn = network.NewConnectionFromFirstPacket(pkt) conn = network.NewConnectionFromFirstPacket(pkt)
conn.Lock()
conn.SetFirewallHandler(initialHandler) conn.SetFirewallHandler(initialHandler)
conn.Unlock()
created = true created = true
return conn, nil return conn, nil
}) })
@ -248,27 +248,6 @@ func getConnection(pkt packet.Packet) (*network.Connection, error) {
return conn, nil return conn, nil
} }
func getConnectionByID(id string) (*network.Connection, error) {
// Create or get connection in single inflight lock in order to prevent duplicates.
connPtr, _, _ := getConnectionSingleInflight.Do(id, func() (interface{}, error) {
// First, check for an existing connection.
conn, ok := network.GetConnection(id)
if ok {
return conn, nil
}
// Else return nil
return nil, nil
})
if connPtr == nil {
return nil, errors.New("connection does not exist")
}
connection := connPtr.(*network.Connection) //nolint:forcetypeassert // Can only be a *network.Connection.
return connection, nil
}
// fastTrackedPermit quickly permits certain network critical or internal connections. // fastTrackedPermit quickly permits certain network critical or internal connections.
func fastTrackedPermit(pkt packet.Packet) (handled bool) { func fastTrackedPermit(pkt packet.Packet) (handled bool) {
meta := pkt.Info() meta := pkt.Info()