This commit is contained in:
Vladimir Stoilov 2022-09-20 11:23:21 +02:00 committed by Daniel
parent b4e2687884
commit ddfa3722be
7 changed files with 64 additions and 57 deletions

View file

@ -59,7 +59,7 @@ const (
func init() { func init() {
// TODO: Move interception module to own package (dir). // TODO: Move interception module to own package (dir).
interceptionModule = modules.Register("interception", interceptionPrep, interceptionStart, interceptionStop, "base", "updates", "network", "notifications", "profiles", "captain") interceptionModule = modules.Register("interception", interceptionPrep, interceptionStart, interceptionStop, "base", "updates", "network", "notifications", "profiles")
network.SetDefaultFirewallHandler(defaultHandler) network.SetDefaultFirewallHandler(defaultHandler)
} }
@ -72,12 +72,12 @@ func interceptionPrep() error {
configChangeEvent, configChangeEvent,
"firewall config change event", "firewall config change event",
func(ctx context.Context, _ interface{}) error { func(ctx context.Context, _ interface{}) error {
resetAllConnections() resetPersistentVerdicts()
return nil return nil
}, },
) )
if err != nil { if err != nil {
_ = fmt.Errorf("failed registering event hook: %w", err) log.Errorf("interception: failed registering event hook: %s", err)
} }
// Reset connections every time profile changes // Reset connections every time profile changes
@ -86,12 +86,12 @@ func interceptionPrep() error {
profileConfigChangeEvent, profileConfigChangeEvent,
"firewall profile change event", "firewall profile change event",
func(ctx context.Context, _ interface{}) error { func(ctx context.Context, _ interface{}) error {
resetAllConnections() resetPersistentVerdicts()
return nil return nil
}, },
) )
if err != nil { if err != nil {
_ = fmt.Errorf("failed registering event hook: %w", err) log.Errorf("failed registering event hook: %s", err)
} }
// Reset connections when spn is connected // Reset connections when spn is connected
@ -101,12 +101,12 @@ func interceptionPrep() error {
onSPNConnectEvent, onSPNConnectEvent,
"firewall spn connect event", "firewall spn connect event",
func(ctx context.Context, _ interface{}) error { func(ctx context.Context, _ interface{}) error {
resetAllConnections() resetPersistentVerdicts()
return nil return nil
}, },
) )
if err != nil { if err != nil {
_ = fmt.Errorf("failed registering event hook: %w", err) log.Errorf("failed registering event hook: %s", err)
} }
if err := registerConfig(); err != nil { if err := registerConfig(); err != nil {
@ -116,23 +116,18 @@ func interceptionPrep() error {
return prepAPIAuth() return prepAPIAuth()
} }
func resetAllConnections() { func resetPersistentVerdicts() {
// Resetting will force all the connection to be evaluated by the firewall again // Resetting will force all the connection to be evaluated by the firewall again
// this will set new verdicts if configuration was update or spn has been disabled or enabled // this will set new verdicts if configuration was update or spn has been disabled or enabled
log.Info("interception: resetting all connections") log.Info("interception: resetting all connections")
err := interception.ResetAllConnections()
if err != nil {
log.Errorf("failed to reset all connections: %q", err)
}
// reset all connection firewall handlers. This will tell the master to rerun the firewall checks // reset all connection firewall handlers. This will tell the master to rerun the firewall checks
for _, id := range network.GetAllIDs() { for _, conn := range network.GetAllConnections() {
conn, err := getConnectionByID(id) conn.Lock()
if err != nil { isSPNConnection := captain.IsExcepted(conn.Entity.IP) && conn.Process().Pid == ownPID
continue
}
if !captain.IsExcepted(conn.Entity.IP) { // mark all non SPN connections to be processed by the firewall
if !isSPNConnection {
conn.SetFirewallHandler(initialHandler) conn.SetFirewallHandler(initialHandler)
// Don't keep the previous tunneled value // Don't keep the previous tunneled value
conn.Tunneled = false conn.Tunneled = false
@ -141,6 +136,12 @@ func resetAllConnections() {
conn.Entity.ResetLists() conn.Entity.ResetLists()
} }
} }
conn.Unlock()
}
err := interception.ResetVerdictOfAllConnections()
if err != nil {
log.Errorf("interception: failed to reset connections verdict: %s", err)
} }
} }
@ -494,7 +495,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
// Check if connection should be tunneled. // Check if connection should be tunneled.
checkTunneling(pkt.Ctx(), conn, pkt) checkTunneling(pkt.Ctx(), conn, pkt)
updateVerdictBasedOnPreviousState(conn) finalizeVerdict(conn)
switch { switch {
case conn.Inspecting: case conn.Inspecting:
@ -581,7 +582,7 @@ func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.V
} }
} }
func updateVerdictBasedOnPreviousState(conn *network.Connection) { func finalizeVerdict(conn *network.Connection) {
// previously accepted or tunneled connections may need to be blocked // previously accepted or tunneled connections may need to be blocked
if conn.Verdict.Current == network.VerdictAccept { if conn.Verdict.Current == network.VerdictAccept {
switch { switch {

View file

@ -15,7 +15,7 @@ func stop() error {
return StopNfqueueInterception() return StopNfqueueInterception()
} }
// ResetAllConnections resets all connections so they are forced to go thought the firewall again. // ResetVerdictOfAllConnections resets all connections so they are forced to go thought the firewall again.
func ResetAllConnections() error { func ResetVerdictOfAllConnections() error {
return nfq.DeleteAllMarkedConnection() return nfq.DeleteAllMarkedConnection()
} }

View file

@ -39,7 +39,7 @@ func stop() error {
return windowskext.Stop() return windowskext.Stop()
} }
// ResetAllConnections resets all connections so they are forced to go thought the firewall again // ResetVerdictOfAllConnections resets all connections so they are forced to go thought the firewall again.
func ResetAllConnections() error { func ResetVerdictOfAllConnections() error {
return windowskext.ClearCache() return windowskext.ClearCache()
} }

View file

@ -6,6 +6,8 @@ import (
"encoding/binary" "encoding/binary"
ct "github.com/florianl/go-conntrack" ct "github.com/florianl/go-conntrack"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/netenv"
) )
// DeleteAllMarkedConnection deletes all marked entries from the conntrack table. // DeleteAllMarkedConnection deletes all marked entries from the conntrack table.
@ -17,37 +19,42 @@ func DeleteAllMarkedConnection() error {
defer func() { _ = nfct.Close() }() defer func() { _ = nfct.Close() }()
// Delete all ipv4 marked connections // Delete all ipv4 marked connections
connections := getAllMarkedConnections(nfct, ct.IPv4) deleteMarkedConnections(nfct, ct.IPv4)
for _, connection := range connections {
_ = nfct.Delete(ct.Conntrack, ct.IPv4, connection)
}
if netenv.IPv6Enabled() {
// Delete all ipv6 marked connections // Delete all ipv6 marked connections
connections = getAllMarkedConnections(nfct, ct.IPv6) deleteMarkedConnections(nfct, ct.IPv6)
for _, connection := range connections {
_ = nfct.Delete(ct.Conntrack, ct.IPv6, connection)
} }
return nil return nil
} }
func getAllMarkedConnections(nfct *ct.Nfct, f ct.Family) []ct.Con { func deleteMarkedConnections(nfct *ct.Nfct, f ct.Family) {
// initialize variables // initialize variables
permanentFlags := [...]uint32{MarkAccept, MarkBlock, MarkDrop, MarkAcceptAlways, MarkBlockAlways, MarkDropAlways, MarkRerouteNS, MarkRerouteSPN} permanentFlags := [...]uint32{MarkAccept, MarkBlock, MarkDrop, MarkAcceptAlways, MarkBlockAlways, MarkDropAlways, MarkRerouteNS, MarkRerouteSPN}
filter := ct.FilterAttr{} filter := ct.FilterAttr{}
filter.MarkMask = []byte{0xFF, 0xFF, 0xFF, 0xFF} filter.MarkMask = []byte{0xFF, 0xFF, 0xFF, 0xFF}
filter.Mark = []byte{0x00, 0x00, 0x00, 0x00} // 4 zeros starting value filter.Mark = []byte{0x00, 0x00, 0x00, 0x00} // 4 zeros starting value
connections := make([]ct.Con, 0)
// get all connections from the specified family (ipv4 or ipv6) // get all connections from the specified family (ipv4 or ipv6)
for _, mark := range permanentFlags { for _, mark := range permanentFlags {
binary.BigEndian.PutUint32(filter.Mark, mark) // Little endian is in reverse not sure why. BigEndian makes it in correct order. binary.BigEndian.PutUint32(filter.Mark, mark) // Little endian is in reverse not sure why. BigEndian makes it in correct order.
currentConnections, err := nfct.Query(ct.Conntrack, f, filter) currentConnections, err := nfct.Query(ct.Conntrack, f, filter)
if err != nil { if err != nil {
log.Warningf("nfq: error on conntrack query: %s", err)
continue continue
} }
connections = append(connections, currentConnections...)
numberOfErrors := 0
for _, connection := range currentConnections {
err = nfct.Delete(ct.Conntrack, ct.IPv4, connection)
if err != nil {
numberOfErrors++
}
} }
return connections if numberOfErrors > 0 {
log.Warningf("nfq: failed to delete %d conntrack entries last error is: %s", numberOfErrors, err)
}
}
} }

View file

@ -95,8 +95,7 @@ func Init(dllPath, driverPath string) error {
new.clearCache, err = new.dll.FindProc("PortmasterClearCache") new.clearCache, err = new.dll.FindProc("PortmasterClearCache")
if err != nil { if err != nil {
// the loaded dll is an old version // the loaded dll is an old version
log.Errorf("could not find proc PortmasterClearCache in dll: %s", err) log.Errorf("could not find proc PortmasterClearCache (v0.x.x+) in dll: %s", err)
log.Warning("are you using the latest kext version?")
} }
// initialize dll/kext // initialize dll/kext

View file

@ -446,9 +446,9 @@ func GetConnection(id string) (*Connection, bool) {
return conns.get(id) return conns.get(id)
} }
// GetAllIDs Get all connection IDs. // GetAllConnections Gets all connection.
func GetAllIDs() []string { func GetAllConnections() []*Connection {
return append(conns.keys(), dnsConns.keys()...) return append(conns.list(), dnsConns.list()...)
} }
// SetLocalIP sets the local IP address together with its network scope. The // SetLocalIP sets the local IP address together with its network scope. The
@ -524,6 +524,8 @@ func (conn *Connection) Failed(reason, reasonOptionKey string) {
func (conn *Connection) SetVerdict(newVerdict Verdict, reason, reasonOptionKey string, reasonCtx interface{}) (ok bool) { func (conn *Connection) SetVerdict(newVerdict Verdict, reason, reasonOptionKey string, reasonCtx interface{}) (ok bool) {
conn.SetVerdictDirectly(newVerdict) conn.SetVerdictDirectly(newVerdict)
// Only set if it matches the user verdict. For a consistent reason
if newVerdict == conn.Verdict.User {
conn.Reason.Msg = reason conn.Reason.Msg = reason
conn.Reason.Context = reasonCtx conn.Reason.Context = reasonCtx
@ -533,6 +535,7 @@ func (conn *Connection) SetVerdict(newVerdict Verdict, reason, reasonOptionKey s
conn.Reason.OptionKey = reasonOptionKey conn.Reason.OptionKey = reasonOptionKey
conn.Reason.Profile = conn.Process().Profile().GetProfileSource(conn.Reason.OptionKey) conn.Reason.Profile = conn.Process().Profile().GetProfileSource(conn.Reason.OptionKey)
} }
}
return true return true
} }

View file

@ -48,18 +48,15 @@ func (cs *connectionStore) clone() map[string]*Connection {
return m return m
} }
func (cs *connectionStore) keys() []string { func (cs *connectionStore) list() []*Connection {
cs.rw.RLock() cs.rw.RLock()
defer cs.rw.RUnlock() defer cs.rw.RUnlock()
keys := make([]string, len(cs.items)) l := []*Connection{}
i := 0 for _, conn := range cs.items {
for key := range cs.items { l = append(l, conn)
keys[i] = key
i++
} }
return l
return keys
} }
func (cs *connectionStore) len() int { //nolint:unused // TODO: Clean up if still unused. func (cs *connectionStore) len() int { //nolint:unused // TODO: Clean up if still unused.