diff --git a/firewall/interception.go b/firewall/interception.go index 6c6f7566..592e6c0a 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -59,7 +59,7 @@ const ( func init() { // TODO: Move interception module to own package (dir). - interceptionModule = modules.Register("interception", interceptionPrep, interceptionStart, interceptionStop, "base", "updates", "network", "notifications", "profiles", "captain") + interceptionModule = modules.Register("interception", interceptionPrep, interceptionStart, interceptionStop, "base", "updates", "network", "notifications", "profiles") network.SetDefaultFirewallHandler(defaultHandler) } @@ -72,12 +72,12 @@ func interceptionPrep() error { configChangeEvent, "firewall config change event", func(ctx context.Context, _ interface{}) error { - resetAllConnections() + resetPersistentVerdicts() return nil }, ) if err != nil { - _ = fmt.Errorf("failed registering event hook: %w", err) + log.Errorf("interception: failed registering event hook: %s", err) } // Reset connections every time profile changes @@ -86,12 +86,12 @@ func interceptionPrep() error { profileConfigChangeEvent, "firewall profile change event", func(ctx context.Context, _ interface{}) error { - resetAllConnections() + resetPersistentVerdicts() return nil }, ) if err != nil { - _ = fmt.Errorf("failed registering event hook: %w", err) + log.Errorf("failed registering event hook: %s", err) } // Reset connections when spn is connected @@ -101,12 +101,12 @@ func interceptionPrep() error { onSPNConnectEvent, "firewall spn connect event", func(ctx context.Context, _ interface{}) error { - resetAllConnections() + resetPersistentVerdicts() return nil }, ) if err != nil { - _ = fmt.Errorf("failed registering event hook: %w", err) + log.Errorf("failed registering event hook: %s", err) } if err := registerConfig(); err != nil { @@ -116,23 +116,18 @@ func interceptionPrep() error { return prepAPIAuth() } -func resetAllConnections() { +func resetPersistentVerdicts() { // Resetting will force all the connection to be evaluated by the firewall again // this will set new verdicts if configuration was update or spn has been disabled or enabled log.Info("interception: resetting all connections") - err := interception.ResetAllConnections() - if err != nil { - log.Errorf("failed to reset all connections: %q", err) - } // reset all connection firewall handlers. This will tell the master to rerun the firewall checks - for _, id := range network.GetAllIDs() { - conn, err := getConnectionByID(id) - if err != nil { - continue - } + for _, conn := range network.GetAllConnections() { + conn.Lock() + isSPNConnection := captain.IsExcepted(conn.Entity.IP) && conn.Process().Pid == ownPID - if !captain.IsExcepted(conn.Entity.IP) { + // mark all non SPN connections to be processed by the firewall + if !isSPNConnection { conn.SetFirewallHandler(initialHandler) // Don't keep the previous tunneled value conn.Tunneled = false @@ -141,6 +136,12 @@ func resetAllConnections() { conn.Entity.ResetLists() } } + conn.Unlock() + } + + err := interception.ResetVerdictOfAllConnections() + if err != nil { + log.Errorf("interception: failed to reset connections verdict: %s", err) } } @@ -494,7 +495,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) { // Check if connection should be tunneled. checkTunneling(pkt.Ctx(), conn, pkt) - updateVerdictBasedOnPreviousState(conn) + finalizeVerdict(conn) switch { case conn.Inspecting: @@ -581,7 +582,7 @@ func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.V } } -func updateVerdictBasedOnPreviousState(conn *network.Connection) { +func finalizeVerdict(conn *network.Connection) { // previously accepted or tunneled connections may need to be blocked if conn.Verdict.Current == network.VerdictAccept { switch { diff --git a/firewall/interception/interception_linux.go b/firewall/interception/interception_linux.go index 2223890e..6fe38edf 100644 --- a/firewall/interception/interception_linux.go +++ b/firewall/interception/interception_linux.go @@ -15,7 +15,7 @@ func stop() error { return StopNfqueueInterception() } -// ResetAllConnections resets all connections so they are forced to go thought the firewall again. -func ResetAllConnections() error { +// ResetVerdictOfAllConnections resets all connections so they are forced to go thought the firewall again. +func ResetVerdictOfAllConnections() error { return nfq.DeleteAllMarkedConnection() } diff --git a/firewall/interception/interception_windows.go b/firewall/interception/interception_windows.go index 150d24d8..382869be 100644 --- a/firewall/interception/interception_windows.go +++ b/firewall/interception/interception_windows.go @@ -39,7 +39,7 @@ func stop() error { return windowskext.Stop() } -// ResetAllConnections resets all connections so they are forced to go thought the firewall again -func ResetAllConnections() error { +// ResetVerdictOfAllConnections resets all connections so they are forced to go thought the firewall again. +func ResetVerdictOfAllConnections() error { return windowskext.ClearCache() } diff --git a/firewall/interception/nfq/conntrack.go b/firewall/interception/nfq/conntrack.go index 9066eef4..e3b9be02 100644 --- a/firewall/interception/nfq/conntrack.go +++ b/firewall/interception/nfq/conntrack.go @@ -6,6 +6,8 @@ import ( "encoding/binary" ct "github.com/florianl/go-conntrack" + "github.com/safing/portbase/log" + "github.com/safing/portmaster/netenv" ) // DeleteAllMarkedConnection deletes all marked entries from the conntrack table. @@ -17,37 +19,42 @@ func DeleteAllMarkedConnection() error { defer func() { _ = nfct.Close() }() // Delete all ipv4 marked connections - connections := getAllMarkedConnections(nfct, ct.IPv4) - for _, connection := range connections { - _ = nfct.Delete(ct.Conntrack, ct.IPv4, connection) - } + deleteMarkedConnections(nfct, ct.IPv4) - // Delete all ipv6 marked connections - connections = getAllMarkedConnections(nfct, ct.IPv6) - for _, connection := range connections { - _ = nfct.Delete(ct.Conntrack, ct.IPv6, connection) + if netenv.IPv6Enabled() { + // Delete all ipv6 marked connections + deleteMarkedConnections(nfct, ct.IPv6) } return nil } -func getAllMarkedConnections(nfct *ct.Nfct, f ct.Family) []ct.Con { +func deleteMarkedConnections(nfct *ct.Nfct, f ct.Family) { // initialize variables permanentFlags := [...]uint32{MarkAccept, MarkBlock, MarkDrop, MarkAcceptAlways, MarkBlockAlways, MarkDropAlways, MarkRerouteNS, MarkRerouteSPN} filter := ct.FilterAttr{} filter.MarkMask = []byte{0xFF, 0xFF, 0xFF, 0xFF} filter.Mark = []byte{0x00, 0x00, 0x00, 0x00} // 4 zeros starting value - connections := make([]ct.Con, 0) // get all connections from the specified family (ipv4 or ipv6) for _, mark := range permanentFlags { binary.BigEndian.PutUint32(filter.Mark, mark) // Little endian is in reverse not sure why. BigEndian makes it in correct order. currentConnections, err := nfct.Query(ct.Conntrack, f, filter) if err != nil { + log.Warningf("nfq: error on conntrack query: %s", err) continue } - connections = append(connections, currentConnections...) - } - return connections + numberOfErrors := 0 + for _, connection := range currentConnections { + err = nfct.Delete(ct.Conntrack, ct.IPv4, connection) + if err != nil { + numberOfErrors++ + } + } + + if numberOfErrors > 0 { + log.Warningf("nfq: failed to delete %d conntrack entries last error is: %s", numberOfErrors, err) + } + } } diff --git a/firewall/interception/windowskext/kext.go b/firewall/interception/windowskext/kext.go index 5b7c972c..7e96614c 100644 --- a/firewall/interception/windowskext/kext.go +++ b/firewall/interception/windowskext/kext.go @@ -95,8 +95,7 @@ func Init(dllPath, driverPath string) error { new.clearCache, err = new.dll.FindProc("PortmasterClearCache") if err != nil { // the loaded dll is an old version - log.Errorf("could not find proc PortmasterClearCache in dll: %s", err) - log.Warning("are you using the latest kext version?") + log.Errorf("could not find proc PortmasterClearCache (v0.x.x+) in dll: %s", err) } // initialize dll/kext diff --git a/network/connection.go b/network/connection.go index 0624949b..68fdfa18 100644 --- a/network/connection.go +++ b/network/connection.go @@ -446,9 +446,9 @@ func GetConnection(id string) (*Connection, bool) { return conns.get(id) } -// GetAllIDs Get all connection IDs. -func GetAllIDs() []string { - return append(conns.keys(), dnsConns.keys()...) +// GetAllConnections Gets all connection. +func GetAllConnections() []*Connection { + return append(conns.list(), dnsConns.list()...) } // SetLocalIP sets the local IP address together with its network scope. The @@ -524,14 +524,17 @@ func (conn *Connection) Failed(reason, reasonOptionKey string) { func (conn *Connection) SetVerdict(newVerdict Verdict, reason, reasonOptionKey string, reasonCtx interface{}) (ok bool) { conn.SetVerdictDirectly(newVerdict) - conn.Reason.Msg = reason - conn.Reason.Context = reasonCtx + // Only set if it matches the user verdict. For a consistent reason + if newVerdict == conn.Verdict.User { + conn.Reason.Msg = reason + conn.Reason.Context = reasonCtx - conn.Reason.OptionKey = "" - conn.Reason.Profile = "" - if reasonOptionKey != "" && conn.Process() != nil { - conn.Reason.OptionKey = reasonOptionKey - conn.Reason.Profile = conn.Process().Profile().GetProfileSource(conn.Reason.OptionKey) + conn.Reason.OptionKey = "" + conn.Reason.Profile = "" + if reasonOptionKey != "" && conn.Process() != nil { + conn.Reason.OptionKey = reasonOptionKey + conn.Reason.Profile = conn.Process().Profile().GetProfileSource(conn.Reason.OptionKey) + } } return true diff --git a/network/connection_store.go b/network/connection_store.go index 937f97b0..69a43a5a 100644 --- a/network/connection_store.go +++ b/network/connection_store.go @@ -48,18 +48,15 @@ func (cs *connectionStore) clone() map[string]*Connection { return m } -func (cs *connectionStore) keys() []string { +func (cs *connectionStore) list() []*Connection { cs.rw.RLock() defer cs.rw.RUnlock() - keys := make([]string, len(cs.items)) - i := 0 - for key := range cs.items { - keys[i] = key - i++ + l := []*Connection{} + for _, conn := range cs.items { + l = append(l, conn) } - - return keys + return l } func (cs *connectionStore) len() int { //nolint:unused // TODO: Clean up if still unused.