From cfafbfca4ebcdca1e1f740aaf95222f5b0788967 Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 15 Oct 2020 11:29:47 +0200 Subject: [PATCH] Improve trace logging --- firewall/dns.go | 13 ++--- firewall/interception.go | 14 ++--- firewall/master.go | 8 +-- intel/entity.go | 80 +++++++++++++-------------- intel/geoip/lookup.go | 3 - nameserver/nameserver.go | 12 +++- nameserver/response.go | 11 ++-- network/connection.go | 22 +++----- profile/endpoints/endpoint-any.go | 8 ++- profile/endpoints/endpoint-asn.go | 5 +- profile/endpoints/endpoint-country.go | 5 +- profile/endpoints/endpoint-domain.go | 1 + profile/endpoints/endpoint-ip.go | 3 +- profile/endpoints/endpoint-iprange.go | 3 +- profile/endpoints/endpoint-lists.go | 5 +- profile/endpoints/endpoint-scopes.go | 3 +- profile/endpoints/endpoint.go | 3 +- profile/endpoints/endpoints.go | 5 +- profile/endpoints/endpoints_test.go | 3 +- profile/profile-layered.go | 23 ++++---- 20 files changed, 121 insertions(+), 109 deletions(-) diff --git a/firewall/dns.go b/firewall/dns.go index acdf5ca6..24096ae6 100644 --- a/firewall/dns.go +++ b/firewall/dns.go @@ -170,31 +170,28 @@ func DecideOnResolvedDNS( updateIPsAndCNAMEs(q, rrCache, conn) - if mayBlockCNAMEs(conn) { + if mayBlockCNAMEs(ctx, conn) { return nil } - // TODO: Gate17 integration - // tunnelInfo, err := AssignTunnelIP(fqdn) - return updatedRR } -func mayBlockCNAMEs(conn *network.Connection) bool { +func mayBlockCNAMEs(ctx context.Context, conn *network.Connection) bool { // if we have CNAMEs and the profile is configured to filter them // we need to re-check the lists and endpoints here if conn.Process().Profile().FilterCNAMEs() { conn.Entity.ResetLists() - conn.Entity.EnableCNAMECheck(true) + conn.Entity.EnableCNAMECheck(ctx, true) - result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity) + result, reason := conn.Process().Profile().MatchEndpoint(ctx, conn.Entity) if result == endpoints.Denied { conn.BlockWithContext(reason.String(), reason.Context()) return true } if result == endpoints.NoMatch { - result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity) + result, reason = conn.Process().Profile().MatchFilterLists(ctx, conn.Entity) if result == endpoints.Denied { conn.BlockWithContext(reason.String(), reason.Context()) return true diff --git a/firewall/interception.go b/firewall/interception.go index a75eca91..a2146e97 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -87,13 +87,13 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { switch meta.Protocol { case packet.ICMP: // Always permit ICMP. - log.Debugf("accepting ICMP: %s", pkt) + log.Debugf("filter: fast-track accepting ICMP: %s", pkt) _ = pkt.PermanentAccept() return true case packet.ICMPv6: // Always permit ICMPv6. - log.Debugf("accepting ICMPv6: %s", pkt) + log.Debugf("filter: fast-track accepting ICMPv6: %s", pkt) _ = pkt.PermanentAccept() return true @@ -116,7 +116,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { } // Log and permit. - log.Debugf("accepting DHCP: %s", pkt) + log.Debugf("filter: fast-track accepting DHCP: %s", pkt) _ = pkt.PermanentAccept() return true @@ -141,7 +141,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { // Only allow to own IPs. dstIsMe, err := netenv.IsMyIP(meta.Dst) if err != nil { - log.Warningf("filter: failed to check if IP is local: %s", err) + log.Warningf("filter: failed to check if IP %s is local: %s", meta.Dst, err) } if !dstIsMe { return false @@ -150,9 +150,9 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { // Log and permit. switch meta.DstPort { case 53: - log.Debugf("accepting local dns: %s", pkt) + log.Debugf("filter: fast-track accepting local dns: %s", pkt) case apiPort: - log.Debugf("accepting api connection: %s", pkt) + log.Debugf("filter: fast-track accepting api connection: %s", pkt) default: return false } @@ -165,7 +165,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { } func initialHandler(conn *network.Connection, pkt packet.Packet) { - log.Tracer(pkt.Ctx()).Trace("filter: [initial handler]") + log.Tracer(pkt.Ctx()).Trace("filter: handing over to connection-based handler") // check for internal firewall bypass ps := getPortStatusAndMarkUsed(pkt.Info().LocalPort()) diff --git a/firewall/master.go b/firewall/master.go index 07c7fb22..a06f8fa3 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -134,7 +134,7 @@ func checkProfileExists(_ context.Context, conn *network.Connection, _ packet.Pa return false } -func checkEndpointLists(_ context.Context, conn *network.Connection, _ packet.Packet) bool { +func checkEndpointLists(ctx context.Context, conn *network.Connection, _ packet.Packet) bool { var result endpoints.EPResult var reason endpoints.Reason @@ -143,9 +143,9 @@ func checkEndpointLists(_ context.Context, conn *network.Connection, _ packet.Pa // check endpoints list if conn.Inbound { - result, reason = p.MatchServiceEndpoint(conn.Entity) + result, reason = p.MatchServiceEndpoint(ctx, conn.Entity) } else { - result, reason = p.MatchEndpoint(conn.Entity) + result, reason = p.MatchEndpoint(ctx, conn.Entity) } switch result { case endpoints.Denied: @@ -271,7 +271,7 @@ func checkFilterLists(ctx context.Context, conn *network.Connection, pkt packet. // apply privacy filter lists p := conn.Process().Profile() - result, reason := p.MatchFilterLists(conn.Entity) + result, reason := p.MatchFilterLists(ctx, conn.Entity) switch result { case endpoints.Denied: conn.DenyWithContext(reason.String(), reason.Context()) diff --git a/intel/entity.go b/intel/entity.go index d6abeb66..4ca3765e 100644 --- a/intel/entity.go +++ b/intel/entity.go @@ -89,9 +89,9 @@ func (e *Entity) Init() *Entity { } // FetchData fetches additional information, meant to be called before persisting an entity record. -func (e *Entity) FetchData() { - e.getLocation() - e.getLists() +func (e *Entity) FetchData(ctx context.Context) { + e.getLocation(ctx) + e.getLists(ctx) } // ResetLists resets the current list data and forces @@ -119,18 +119,18 @@ func (e *Entity) ResetLists() { // ResolveSubDomainLists enables or disables list lookups for // sub-domains. -func (e *Entity) ResolveSubDomainLists(enabled bool) { +func (e *Entity) ResolveSubDomainLists(ctx context.Context, enabled bool) { if e.domainListLoaded { - log.Warningf("intel/filterlists: tried to change sub-domain resolving for %s but lists are already fetched", e.Domain) + log.Tracer(ctx).Warningf("intel/filterlists: tried to change sub-domain resolving for %s but lists are already fetched", e.Domain) } e.resolveSubDomainLists = enabled } // EnableCNAMECheck enalbes or disables list lookups for // entity CNAMEs. -func (e *Entity) EnableCNAMECheck(enabled bool) { +func (e *Entity) EnableCNAMECheck(ctx context.Context, enabled bool) { if e.domainListLoaded { - log.Warningf("intel/filterlists: tried to change CNAME resolving for %s but lists are already fetched", e.Domain) + log.Tracer(ctx).Warningf("intel/filterlists: tried to change CNAME resolving for %s but lists are already fetched", e.Domain) } e.checkCNAMEs = enabled } @@ -148,7 +148,7 @@ func (e *Entity) EnableReverseResolving() { e.reverseResolveEnabled = true } -func (e *Entity) reverseResolve() { +func (e *Entity) reverseResolve(ctx context.Context) { e.reverseResolveOnce.Do(func() { // check if we should resolve if !e.reverseResolveEnabled { @@ -165,9 +165,9 @@ func (e *Entity) reverseResolve() { return } // TODO: security level - domain, err := reverseResolver(context.TODO(), e.IP.String(), status.SecurityLevelNormal) + domain, err := reverseResolver(ctx, e.IP.String(), status.SecurityLevelNormal) if err != nil { - log.Warningf("intel: failed to resolve IP %s: %s", e.IP, err) + log.Tracer(ctx).Warningf("intel: failed to resolve IP %s: %s", e.IP, err) return } e.Domain = domain @@ -194,7 +194,7 @@ func (e *Entity) GetIP() (net.IP, bool) { // Location -func (e *Entity) getLocation() { +func (e *Entity) getLocation(ctx context.Context) { e.fetchLocationOnce.Do(func() { // need IP! if e.IP == nil { @@ -204,7 +204,7 @@ func (e *Entity) getLocation() { // get location data loc, err := geoip.GetLocation(e.IP) if err != nil { - log.Warningf("intel: failed to get location data for %s: %s", e.IP, err) + log.Tracer(ctx).Warningf("intel: failed to get location data for %s: %s", e.IP, err) return } e.location = loc @@ -214,8 +214,8 @@ func (e *Entity) getLocation() { } // GetLocation returns the raw location data and whether it is set. -func (e *Entity) GetLocation() (*geoip.Location, bool) { - e.getLocation() +func (e *Entity) GetLocation(ctx context.Context) (*geoip.Location, bool) { + e.getLocation(ctx) if e.location == nil { return nil, false @@ -224,8 +224,8 @@ func (e *Entity) GetLocation() (*geoip.Location, bool) { } // GetCountry returns the two letter ISO country code and whether it is set. -func (e *Entity) GetCountry() (string, bool) { - e.getLocation() +func (e *Entity) GetCountry(ctx context.Context) (string, bool) { + e.getLocation(ctx) if e.Country == "" { return "", false @@ -234,8 +234,8 @@ func (e *Entity) GetCountry() (string, bool) { } // GetASN returns the AS number and whether it is set. -func (e *Entity) GetASN() (uint, bool) { - e.getLocation() +func (e *Entity) GetASN(ctx context.Context) (uint, bool) { + e.getLocation(ctx) if e.ASN == 0 { return 0, false @@ -244,11 +244,11 @@ func (e *Entity) GetASN() (uint, bool) { } // Lists -func (e *Entity) getLists() { - e.getDomainLists() - e.getASNLists() - e.getIPLists() - e.getCountryLists() +func (e *Entity) getLists(ctx context.Context) { + e.getDomainLists(ctx) + e.getASNLists(ctx) + e.getIPLists(ctx) + e.getCountryLists(ctx) } func (e *Entity) mergeList(key string, list []string) { @@ -263,7 +263,7 @@ func (e *Entity) mergeList(key string, list []string) { e.ListOccurences[key] = mergeStringList(e.ListOccurences[key], list) } -func (e *Entity) getDomainLists() { +func (e *Entity) getDomainLists(ctx context.Context) { if e.domainListLoaded { return } @@ -277,7 +277,7 @@ func (e *Entity) getDomainLists() { var domainsToInspect = []string{domain} if e.checkCNAMEs { - log.Tracef("intel: CNAME filtering enabled, checking %v too", e.CNAME) + log.Tracer(ctx).Tracef("intel: CNAME filtering enabled, checking %v too", e.CNAME) domainsToInspect = append(domainsToInspect, e.CNAME...) } @@ -294,10 +294,10 @@ func (e *Entity) getDomainLists() { domains = makeDistinct(domains) for _, d := range domains { - log.Tracef("intel: loading domain list for %s", d) + log.Tracer(ctx).Tracef("intel: loading domain list for %s", d) list, err := filterlists.LookupDomain(d) if err != nil { - log.Errorf("intel: failed to get domain blocklists for %s: %s", d, err) + log.Tracer(ctx).Errorf("intel: failed to get domain blocklists for %s: %s", d, err) e.loadDomainListOnce = sync.Once{} return } @@ -334,22 +334,22 @@ func splitDomain(domain string) []string { return domains } -func (e *Entity) getASNLists() { +func (e *Entity) getASNLists(ctx context.Context) { if e.asnListLoaded { return } - asn, ok := e.GetASN() + asn, ok := e.GetASN(ctx) if !ok { return } - log.Tracef("intel: loading ASN list for %d", asn) + log.Tracer(ctx).Tracef("intel: loading ASN list for %d", asn) e.loadAsnListOnce.Do(func() { asnStr := fmt.Sprintf("%d", asn) list, err := filterlists.LookupASNString(asnStr) if err != nil { - log.Errorf("intel: failed to get ASN blocklist for %d: %s", asn, err) + log.Tracer(ctx).Errorf("intel: failed to get ASN blocklist for %d: %s", asn, err) e.loadAsnListOnce = sync.Once{} return } @@ -359,21 +359,21 @@ func (e *Entity) getASNLists() { }) } -func (e *Entity) getCountryLists() { +func (e *Entity) getCountryLists(ctx context.Context) { if e.countryListLoaded { return } - country, ok := e.GetCountry() + country, ok := e.GetCountry(ctx) if !ok { return } - log.Tracef("intel: loading country list for %s", country) + log.Tracer(ctx).Tracef("intel: loading country list for %s", country) e.loadCoutryListOnce.Do(func() { list, err := filterlists.LookupCountry(country) if err != nil { - log.Errorf("intel: failed to load country blocklist for %s: %s", country, err) + log.Tracer(ctx).Errorf("intel: failed to load country blocklist for %s: %s", country, err) e.loadCoutryListOnce = sync.Once{} return } @@ -383,7 +383,7 @@ func (e *Entity) getCountryLists() { }) } -func (e *Entity) getIPLists() { +func (e *Entity) getIPLists(ctx context.Context) { if e.ipListLoaded { return } @@ -402,12 +402,12 @@ func (e *Entity) getIPLists() { return } - log.Tracef("intel: loading IP list for %s", ip) + log.Tracer(ctx).Tracef("intel: loading IP list for %s", ip) e.loadIPListOnce.Do(func() { list, err := filterlists.LookupIP(ip) if err != nil { - log.Errorf("intel: failed to get IP blocklist for %s: %s", ip.String(), err) + log.Tracer(ctx).Errorf("intel: failed to get IP blocklist for %s: %s", ip.String(), err) e.loadIPListOnce = sync.Once{} return } @@ -418,8 +418,8 @@ func (e *Entity) getIPLists() { // LoadLists searches all filterlists for all occurrences of // this entity. -func (e *Entity) LoadLists() bool { - e.getLists() +func (e *Entity) LoadLists(ctx context.Context) bool { + e.getLists(ctx) return e.ListOccurences != nil } diff --git a/intel/geoip/lookup.go b/intel/geoip/lookup.go index 73c60d05..b4c7f320 100644 --- a/intel/geoip/lookup.go +++ b/intel/geoip/lookup.go @@ -4,7 +4,6 @@ import ( "net" "github.com/oschwald/maxminddb-golang" - "github.com/safing/portbase/log" ) func getReader(ip net.IP) *maxminddb.Reader { @@ -49,7 +48,5 @@ func GetLocation(ip net.IP) (record *Location, err error) { return nil, err } - log.Tracef("geoip: record: %+v", record) - return record, nil } diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 0d780d87..293e7ef6 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -76,12 +76,15 @@ func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) { return handleRequest(ctx, w, query) }) if err != nil { - log.Warningf("intel: failed to handle dns request: %s", err) + log.Warningf("nameserver: failed to handle dns request: %s", err) } } func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) error { //nolint:gocognit // TODO // Only process first question, that's how everyone does it. + if len(request.Question) == 0 { + return errors.New("missing question") + } originalQuestion := request.Question[0] // Check if we are handling a non-standard query name. @@ -116,7 +119,12 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) // Setup quick reply function. reply := func(responder nsutil.Responder, rrProviders ...nsutil.RRProvider) error { - return sendResponse(ctx, w, request, responder, rrProviders...) + err := sendResponse(ctx, w, request, responder, rrProviders...) + // Log error here instead of returning it in order to keep the context. + if err != nil { + tracer.Errorf("nameserver: %s", err) + } + return nil } // Return with server failure if offline. diff --git a/nameserver/response.go b/nameserver/response.go index 4171b36f..18e9f75d 100644 --- a/nameserver/response.go +++ b/nameserver/response.go @@ -36,19 +36,19 @@ func sendResponse( } // Write reply. - if err := writeDNSResponse(w, reply); err != nil { - return fmt.Errorf("nameserver: failed to send response: %w", err) + if err := writeDNSResponse(ctx, w, reply); err != nil { + return fmt.Errorf("failed to send response: %w", err) } return nil } -func writeDNSResponse(w dns.ResponseWriter, m *dns.Msg) (err error) { +func writeDNSResponse(ctx context.Context, w dns.ResponseWriter, m *dns.Msg) (err error) { defer func() { // recover from panic if panicErr := recover(); panicErr != nil { err = fmt.Errorf("panic: %s", panicErr) - log.Warningf("nameserver: panic caused by this msg: %#v", m) + log.Tracer(ctx).Debugf("nameserver: panic caused by this msg: %#v", m) } }() @@ -56,10 +56,11 @@ func writeDNSResponse(w dns.ResponseWriter, m *dns.Msg) (err error) { if err != nil { // If we receive an error we might have exceeded the message size with all // our extra information records. Retry again without the extra section. + log.Tracer(ctx).Tracef("nameserver: retrying to write dns message without extra section, error was: %s", err) m.Extra = nil noExtraErr := w.WriteMsg(m) if noExtraErr == nil { - log.Warningf("nameserver: failed to write dns message with extra section: %s", err) + return fmt.Errorf("failed to write dns message without extra section: %w", err) } } return diff --git a/network/connection.go b/network/connection.go index a43f4478..18f32dc2 100644 --- a/network/connection.go +++ b/network/connection.go @@ -80,7 +80,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri }, ) if err != nil { - log.Debugf("network: failed to find process of dns request for %s: %s", fqdn, err) + log.Tracer(ctx).Debugf("network: failed to find process of dns request for %s: %s", fqdn, err) proc = process.GetUnidentifiedProcess(ctx) } @@ -103,7 +103,7 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { // get Process proc, inbound, err := process.GetProcessByConnection(pkt.Ctx(), pkt.Info()) if err != nil { - log.Debugf("network: failed to find process of packet %s: %s", pkt, err) + log.Tracer(pkt.Ctx()).Debugf("network: failed to find process of packet %s: %s", pkt, err) proc = process.GetUnidentifiedProcess(pkt.Ctx()) } @@ -203,9 +203,7 @@ func GetConnection(id string) (*Connection, bool) { // AcceptWithContext accepts the connection. func (conn *Connection) AcceptWithContext(reason string, ctx interface{}) { - if conn.SetVerdict(VerdictAccept, reason, ctx) { - log.Infof("filter: granting connection %s, %s", conn, conn.Reason) - } else { + if !conn.SetVerdict(VerdictAccept, reason, ctx) { log.Warningf("filter: tried to accept %s, but current verdict is %s", conn, conn.Verdict) } } @@ -217,9 +215,7 @@ func (conn *Connection) Accept(reason string) { // BlockWithContext blocks the connection. func (conn *Connection) BlockWithContext(reason string, ctx interface{}) { - if conn.SetVerdict(VerdictBlock, reason, ctx) { - log.Infof("filter: blocking connection %s, %s", conn, conn.Reason) - } else { + if !conn.SetVerdict(VerdictBlock, reason, ctx) { log.Warningf("filter: tried to block %s, but current verdict is %s", conn, conn.Verdict) } } @@ -231,9 +227,7 @@ func (conn *Connection) Block(reason string) { // DropWithContext drops the connection. func (conn *Connection) DropWithContext(reason string, ctx interface{}) { - if conn.SetVerdict(VerdictDrop, reason, ctx) { - log.Infof("filter: dropping connection %s, %s", conn, conn.Reason) - } else { + if !conn.SetVerdict(VerdictDrop, reason, ctx) { log.Warningf("filter: tried to drop %s, but current verdict is %s", conn, conn.Verdict) } } @@ -259,9 +253,7 @@ func (conn *Connection) Deny(reason string) { // FailedWithContext marks the connection with VerdictFailed and stores the reason. func (conn *Connection) FailedWithContext(reason string, ctx interface{}) { - if conn.SetVerdict(VerdictFailed, reason, ctx) { - log.Infof("filter: dropping connection %s because of an internal error: %s", conn, reason) - } else { + if !conn.SetVerdict(VerdictFailed, reason, ctx) { log.Warningf("filter: tried to drop %s due to error but current verdict is %s", conn, conn.Verdict) } } @@ -401,6 +393,8 @@ func (conn *Connection) packetHandler() { } else { defaultFirewallHandler(conn, pkt) } + // log verdict + log.Tracer(pkt.Ctx()).Infof("filter: connection %s %s: %s", conn, conn.Verdict.Verb(), conn.Reason) conn.Unlock() // save does not touch any changing data // must not be locked, will deadlock with cleaner functions diff --git a/profile/endpoints/endpoint-any.go b/profile/endpoints/endpoint-any.go index 8e8deb98..14960489 100644 --- a/profile/endpoints/endpoint-any.go +++ b/profile/endpoints/endpoint-any.go @@ -1,6 +1,10 @@ package endpoints -import "github.com/safing/portmaster/intel" +import ( + "context" + + "github.com/safing/portmaster/intel" +) // EndpointAny matches anything. type EndpointAny struct { @@ -8,7 +12,7 @@ type EndpointAny struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointAny) Matches(entity *intel.Entity) (EPResult, Reason) { +func (ep *EndpointAny) Matches(_ context.Context, entity *intel.Entity) (EPResult, Reason) { return ep.match(ep, entity, "*", "matches") } diff --git a/profile/endpoints/endpoint-asn.go b/profile/endpoints/endpoint-asn.go index 6713d199..6c11f1a6 100644 --- a/profile/endpoints/endpoint-asn.go +++ b/profile/endpoints/endpoint-asn.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "fmt" "regexp" "strconv" @@ -20,8 +21,8 @@ type EndpointASN struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointASN) Matches(entity *intel.Entity) (EPResult, Reason) { - asn, ok := entity.GetASN() +func (ep *EndpointASN) Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) { + asn, ok := entity.GetASN(ctx) if !ok { return Undeterminable, nil } diff --git a/profile/endpoints/endpoint-country.go b/profile/endpoints/endpoint-country.go index 85449cf5..d47d6e1f 100644 --- a/profile/endpoints/endpoint-country.go +++ b/profile/endpoints/endpoint-country.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "regexp" "strings" @@ -19,8 +20,8 @@ type EndpointCountry struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointCountry) Matches(entity *intel.Entity) (EPResult, Reason) { - country, ok := entity.GetCountry() +func (ep *EndpointCountry) Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) { + country, ok := entity.GetCountry(ctx) if !ok { return Undeterminable, nil } diff --git a/profile/endpoints/endpoint-domain.go b/profile/endpoints/endpoint-domain.go index b02fc9a5..6c43a7de 100644 --- a/profile/endpoints/endpoint-domain.go +++ b/profile/endpoints/endpoint-domain.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "regexp" "strings" diff --git a/profile/endpoints/endpoint-ip.go b/profile/endpoints/endpoint-ip.go index 43ea47f7..08110247 100644 --- a/profile/endpoints/endpoint-ip.go +++ b/profile/endpoints/endpoint-ip.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "net" "github.com/safing/portmaster/intel" @@ -14,7 +15,7 @@ type EndpointIP struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointIP) Matches(entity *intel.Entity) (EPResult, Reason) { +func (ep *EndpointIP) Matches(_ context.Context, entity *intel.Entity) (EPResult, Reason) { if entity.IP == nil { return Undeterminable, nil } diff --git a/profile/endpoints/endpoint-iprange.go b/profile/endpoints/endpoint-iprange.go index bc0d22fe..d4be35db 100644 --- a/profile/endpoints/endpoint-iprange.go +++ b/profile/endpoints/endpoint-iprange.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "net" "github.com/safing/portmaster/intel" @@ -14,7 +15,7 @@ type EndpointIPRange struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointIPRange) Matches(entity *intel.Entity) (EPResult, Reason) { +func (ep *EndpointIPRange) Matches(_ context.Context, entity *intel.Entity) (EPResult, Reason) { if entity.IP == nil { return Undeterminable, nil } diff --git a/profile/endpoints/endpoint-lists.go b/profile/endpoints/endpoint-lists.go index 27ec8b00..618c66d9 100644 --- a/profile/endpoints/endpoint-lists.go +++ b/profile/endpoints/endpoint-lists.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "strings" "github.com/safing/portmaster/intel" @@ -15,8 +16,8 @@ type EndpointLists struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointLists) Matches(entity *intel.Entity) (EPResult, Reason) { - if !entity.LoadLists() { +func (ep *EndpointLists) Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) { + if !entity.LoadLists(ctx) { return Undeterminable, nil } diff --git a/profile/endpoints/endpoint-scopes.go b/profile/endpoints/endpoint-scopes.go index ea22126d..6f1c2f27 100644 --- a/profile/endpoints/endpoint-scopes.go +++ b/profile/endpoints/endpoint-scopes.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "strings" "github.com/safing/portmaster/network/netutils" @@ -30,7 +31,7 @@ type EndpointScope struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) { +func (ep *EndpointScope) Matches(_ context.Context, entity *intel.Entity) (EPResult, Reason) { if entity.IP == nil { return Undeterminable, nil } diff --git a/profile/endpoints/endpoint.go b/profile/endpoints/endpoint.go index 2e0a4e85..260d0c8e 100644 --- a/profile/endpoints/endpoint.go +++ b/profile/endpoints/endpoint.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "fmt" "strconv" "strings" @@ -11,7 +12,7 @@ import ( // Endpoint describes an Endpoint Matcher type Endpoint interface { - Matches(entity *intel.Entity) (EPResult, Reason) + Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) String() string } diff --git a/profile/endpoints/endpoints.go b/profile/endpoints/endpoints.go index f74edd24..76c3d8ba 100644 --- a/profile/endpoints/endpoints.go +++ b/profile/endpoints/endpoints.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "fmt" "strings" @@ -63,10 +64,10 @@ func (e Endpoints) IsSet() bool { } // Match checks whether the given entity matches any of the endpoint definitions in the list. -func (e Endpoints) Match(entity *intel.Entity) (result EPResult, reason Reason) { +func (e Endpoints) Match(ctx context.Context, entity *intel.Entity) (result EPResult, reason Reason) { for _, entry := range e { if entry != nil { - if result, reason = entry.Matches(entity); result != NoMatch { + if result, reason = entry.Matches(ctx, entity); result != NoMatch { return } } diff --git a/profile/endpoints/endpoints_test.go b/profile/endpoints/endpoints_test.go index 531544d8..c50aca90 100644 --- a/profile/endpoints/endpoints_test.go +++ b/profile/endpoints/endpoints_test.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "net" "runtime" "testing" @@ -16,7 +17,7 @@ func TestMain(m *testing.M) { } func testEndpointMatch(t *testing.T, ep Endpoint, entity *intel.Entity, expectedResult EPResult) { - result, _ := ep.Matches(entity) + result, _ := ep.Matches(context.TODO(), entity) if result != expectedResult { t.Errorf( "line %d: unexpected result for endpoint %s and entity %+v: result=%s, expected=%s", diff --git a/profile/profile-layered.go b/profile/profile-layered.go index 9a5e0d50..e292edfb 100644 --- a/profile/profile-layered.go +++ b/profile/profile-layered.go @@ -1,6 +1,7 @@ package profile import ( + "context" "sync" "sync/atomic" @@ -221,10 +222,10 @@ func (lp *LayeredProfile) DefaultAction() uint8 { } // MatchEndpoint checks if the given endpoint matches an entry in any of the profiles. -func (lp *LayeredProfile) MatchEndpoint(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { +func (lp *LayeredProfile) MatchEndpoint(ctx context.Context, entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { for _, layer := range lp.layers { if layer.endpoints.IsSet() { - result, reason := layer.endpoints.Match(entity) + result, reason := layer.endpoints.Match(ctx, entity) if endpoints.IsDecision(result) { return result, reason } @@ -233,16 +234,16 @@ func (lp *LayeredProfile) MatchEndpoint(entity *intel.Entity) (endpoints.EPResul cfgLock.RLock() defer cfgLock.RUnlock() - return cfgEndpoints.Match(entity) + return cfgEndpoints.Match(ctx, entity) } // MatchServiceEndpoint checks if the given endpoint of an inbound connection matches an entry in any of the profiles. -func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { +func (lp *LayeredProfile) MatchServiceEndpoint(ctx context.Context, entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { entity.EnableReverseResolving() for _, layer := range lp.layers { if layer.serviceEndpoints.IsSet() { - result, reason := layer.serviceEndpoints.Match(entity) + result, reason := layer.serviceEndpoints.Match(ctx, entity) if endpoints.IsDecision(result) { return result, reason } @@ -251,19 +252,19 @@ func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (endpoints. cfgLock.RLock() defer cfgLock.RUnlock() - return cfgServiceEndpoints.Match(entity) + return cfgServiceEndpoints.Match(ctx, entity) } // MatchFilterLists matches the entity against the set of filter // lists. -func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { - entity.ResolveSubDomainLists(lp.FilterSubDomains()) - entity.EnableCNAMECheck(lp.FilterCNAMEs()) +func (lp *LayeredProfile) MatchFilterLists(ctx context.Context, entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { + entity.ResolveSubDomainLists(ctx, lp.FilterSubDomains()) + entity.EnableCNAMECheck(ctx, lp.FilterCNAMEs()) for _, layer := range lp.layers { // search for the first layer that has filterListIDs set if len(layer.filterListIDs) > 0 { - entity.LoadLists() + entity.LoadLists(ctx) if entity.MatchLists(layer.filterListIDs) { return endpoints.Denied, entity.ListBlockReason() @@ -276,7 +277,7 @@ func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPRe cfgLock.RLock() defer cfgLock.RUnlock() if len(cfgFilterLists) > 0 { - entity.LoadLists() + entity.LoadLists(ctx) if entity.MatchLists(cfgFilterLists) { return endpoints.Denied, entity.ListBlockReason()