diff --git a/Gopkg.lock b/Gopkg.lock index 0b8474c5..34c5ca1c 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -49,6 +49,17 @@ revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" version = "v1.1.1" +[[projects]] + branch = "master" + digest = "1:c8098f53cd182561cfb128c9a5ba70e41ad2364b763f33f05c6bd54003ae6495" + name = "github.com/florianl/go-nfqueue" + packages = [ + ".", + "internal/unix", + ] + pruneopts = "" + revision = "a2f196e98ab0ffdcb8b5252e7cbba98e45dea204" + [[projects]] digest = "1:b6581f9180e0f2d5549280d71819ab951db9d511478c87daca95669589d505c0" name = "github.com/go-ole/go-ole" @@ -140,12 +151,12 @@ version = "v1.1.0" [[projects]] - branch = "master" - digest = "1:9d781ead5ca35ef02cdf0dc516b239cb387fe73207b0dd01760f7d4a825f4cd3" + digest = "1:508f444b8e00a569a40899aaf5740348b44c305d36f36d4f002b277677deef95" name = "github.com/miekg/dns" packages = ["."] pruneopts = "" - revision = "da812eed45cba1ce4c978e746039483064b8f92d" + revision = "10e0aeedbee54849adab780611454192a9980443" + version = "v1.1.33" [[projects]] digest = "1:3282ac9a9ddf5c2c0eda96693364d34fe0f8d10a0748259082a5c9fbd3e1f7e4" @@ -368,7 +379,6 @@ "github.com/google/renameio", "github.com/hashicorp/go-multierror", "github.com/hashicorp/go-version", - "github.com/mdlayher/netlink", "github.com/miekg/dns", "github.com/oschwald/maxminddb-golang", "github.com/shirou/gopsutil/process", diff --git a/Gopkg.toml b/Gopkg.toml index 1a26fbfb..b658d391 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -26,10 +26,6 @@ ignored = ["github.com/safing/portbase/*", "github.com/safing/spn/*"] -[[constraint]] - name = "github.com/miekg/dns" - branch = "master" # switch back to semver releases when https://github.com/miekg/dns/pull/1110 is released - [[constraint]] name = "github.com/florianl/go-nfqueue" branch = "master" # switch back once we migrate to go.mod diff --git a/firewall/dns.go b/firewall/dns.go index b3b20a39..24096ae6 100644 --- a/firewall/dns.go +++ b/firewall/dns.go @@ -121,9 +121,9 @@ func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache) *res err, ) } - } else if rrCache.TTL > time.Now().Add(10*time.Second).Unix() { + } else if rrCache.Expires > time.Now().Add(10*time.Second).Unix() { // Set a low TTL of 10 seconds if TTL is higher than that. - rrCache.TTL = time.Now().Add(10 * time.Second).Unix() + rrCache.Expires = time.Now().Add(10 * time.Second).Unix() err := rrCache.Save() if err != nil { log.Debugf( @@ -170,31 +170,28 @@ func DecideOnResolvedDNS( updateIPsAndCNAMEs(q, rrCache, conn) - if mayBlockCNAMEs(conn) { + if mayBlockCNAMEs(ctx, conn) { return nil } - // TODO: Gate17 integration - // tunnelInfo, err := AssignTunnelIP(fqdn) - return updatedRR } -func mayBlockCNAMEs(conn *network.Connection) bool { +func mayBlockCNAMEs(ctx context.Context, conn *network.Connection) bool { // if we have CNAMEs and the profile is configured to filter them // we need to re-check the lists and endpoints here if conn.Process().Profile().FilterCNAMEs() { conn.Entity.ResetLists() - conn.Entity.EnableCNAMECheck(true) + conn.Entity.EnableCNAMECheck(ctx, true) - result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity) + result, reason := conn.Process().Profile().MatchEndpoint(ctx, conn.Entity) if result == endpoints.Denied { conn.BlockWithContext(reason.String(), reason.Context()) return true } if result == endpoints.NoMatch { - result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity) + result, reason = conn.Process().Profile().MatchFilterLists(ctx, conn.Entity) if result == endpoints.Denied { conn.BlockWithContext(reason.String(), reason.Context()) return true @@ -205,10 +202,19 @@ func mayBlockCNAMEs(conn *network.Connection) bool { return false } +// updateIPsAndCNAMEs saves all the IP->Name mappings to the cache database and +// updates the CNAMEs in the Connection's Entity. func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) { - // save IP addresses to IPInfo + // Get profileID for scoping IPInfo. + var profileID string + proc := conn.Process() + if proc != nil { + profileID = proc.LocalProfileKey + } + + // Collect IPs and CNAMEs. cnames := make(map[string]string) - ips := make(map[string]struct{}) + ips := make([]net.IP, 0, len(rrCache.Answer)) for _, rr := range append(rrCache.Answer, rrCache.Extra...) { switch v := rr.(type) { @@ -216,19 +222,27 @@ func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *netw cnames[v.Hdr.Name] = v.Target case *dns.A: - ips[v.A.String()] = struct{}{} + ips = append(ips, v.A) case *dns.AAAA: - ips[v.AAAA.String()] = struct{}{} + ips = append(ips, v.AAAA) } } - for ip := range ips { - record := resolver.ResolvedDomain{ - Domain: q.FQDN, + // Package IPs and CNAMEs into IPInfo structs. + for _, ip := range ips { + // Never save domain attributions for localhost IPs. + if netutils.ClassifyIP(ip) == netutils.HostLocal { + continue } - // resolve all CNAMEs in the correct order. + // Create new record for this IP. + record := resolver.ResolvedDomain{ + Domain: q.FQDN, + Expires: rrCache.Expires, + } + + // Resolve all CNAMEs in the correct order and add the to the record. var domain = q.FQDN for { nextDomain, isCNAME := cnames[domain] @@ -240,31 +254,30 @@ func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *netw domain = nextDomain } - // update the entity to include the cnames + // Update the entity to include the CNAMEs of the query response. conn.Entity.CNAME = record.CNAMEs - // get the existing IP info or create a new one - var save bool - info, err := resolver.GetIPInfo(ip) + // Check if there is an existing record for this DNS response. + // Else create a new one. + ipString := ip.String() + info, err := resolver.GetIPInfo(profileID, ipString) if err != nil { if err != database.ErrNotFound { log.Errorf("nameserver: failed to search for IP info record: %s", err) } info = &resolver.IPInfo{ - IP: ip, + IP: ipString, + ProfileID: profileID, } - save = true } - // and the new resolved domain record and save - if new := info.AddDomain(record); new { - save = true - } - if save { - if err := info.Save(); err != nil { - log.Errorf("nameserver: failed to save IP info record: %s", err) - } + // Add the new record to the resolved domains for this IP and scope. + info.AddDomain(record) + + // Save if the record is new or has been updated. + if err := info.Save(); err != nil { + log.Errorf("nameserver: failed to save IP info record: %s", err) } } } diff --git a/firewall/interception.go b/firewall/interception.go index a75eca91..a2146e97 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -87,13 +87,13 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { switch meta.Protocol { case packet.ICMP: // Always permit ICMP. - log.Debugf("accepting ICMP: %s", pkt) + log.Debugf("filter: fast-track accepting ICMP: %s", pkt) _ = pkt.PermanentAccept() return true case packet.ICMPv6: // Always permit ICMPv6. - log.Debugf("accepting ICMPv6: %s", pkt) + log.Debugf("filter: fast-track accepting ICMPv6: %s", pkt) _ = pkt.PermanentAccept() return true @@ -116,7 +116,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { } // Log and permit. - log.Debugf("accepting DHCP: %s", pkt) + log.Debugf("filter: fast-track accepting DHCP: %s", pkt) _ = pkt.PermanentAccept() return true @@ -141,7 +141,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { // Only allow to own IPs. dstIsMe, err := netenv.IsMyIP(meta.Dst) if err != nil { - log.Warningf("filter: failed to check if IP is local: %s", err) + log.Warningf("filter: failed to check if IP %s is local: %s", meta.Dst, err) } if !dstIsMe { return false @@ -150,9 +150,9 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { // Log and permit. switch meta.DstPort { case 53: - log.Debugf("accepting local dns: %s", pkt) + log.Debugf("filter: fast-track accepting local dns: %s", pkt) case apiPort: - log.Debugf("accepting api connection: %s", pkt) + log.Debugf("filter: fast-track accepting api connection: %s", pkt) default: return false } @@ -165,7 +165,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { } func initialHandler(conn *network.Connection, pkt packet.Packet) { - log.Tracer(pkt.Ctx()).Trace("filter: [initial handler]") + log.Tracer(pkt.Ctx()).Trace("filter: handing over to connection-based handler") // check for internal firewall bypass ps := getPortStatusAndMarkUsed(pkt.Info().LocalPort()) diff --git a/firewall/interception/windowskext/handler.go b/firewall/interception/windowskext/handler.go index 97238324..dec83cbc 100644 --- a/firewall/interception/windowskext/handler.go +++ b/firewall/interception/windowskext/handler.go @@ -4,6 +4,7 @@ package windowskext import ( "encoding/binary" + "errors" "net" "github.com/tevino/abool" @@ -45,6 +46,11 @@ func Handler(packets chan packet.Packet) { packetInfo, err := RecvVerdictRequest() if err != nil { + // Check if we are done with processing. + if errors.Is(err, ErrKextNotReady) { + return + } + log.Warningf("failed to get packet from windows kext: %s", err) continue } diff --git a/firewall/master.go b/firewall/master.go index 07c7fb22..a06f8fa3 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -134,7 +134,7 @@ func checkProfileExists(_ context.Context, conn *network.Connection, _ packet.Pa return false } -func checkEndpointLists(_ context.Context, conn *network.Connection, _ packet.Packet) bool { +func checkEndpointLists(ctx context.Context, conn *network.Connection, _ packet.Packet) bool { var result endpoints.EPResult var reason endpoints.Reason @@ -143,9 +143,9 @@ func checkEndpointLists(_ context.Context, conn *network.Connection, _ packet.Pa // check endpoints list if conn.Inbound { - result, reason = p.MatchServiceEndpoint(conn.Entity) + result, reason = p.MatchServiceEndpoint(ctx, conn.Entity) } else { - result, reason = p.MatchEndpoint(conn.Entity) + result, reason = p.MatchEndpoint(ctx, conn.Entity) } switch result { case endpoints.Denied: @@ -271,7 +271,7 @@ func checkFilterLists(ctx context.Context, conn *network.Connection, pkt packet. // apply privacy filter lists p := conn.Process().Profile() - result, reason := p.MatchFilterLists(conn.Entity) + result, reason := p.MatchFilterLists(ctx, conn.Entity) switch result { case endpoints.Denied: conn.DenyWithContext(reason.String(), reason.Context()) diff --git a/intel/entity.go b/intel/entity.go index d6abeb66..16d83707 100644 --- a/intel/entity.go +++ b/intel/entity.go @@ -37,12 +37,19 @@ type Entity struct { // Protocol is the protcol number used by the connection. Protocol uint8 - // Port is the destination port of the connection + // Port is the remote port of the connection Port uint16 + // dstPort is the destination port of the connection + dstPort uint16 + // Domain is the target domain of the connection. Domain string + // ReverseDomain is the domain the IP address points to. This is only + // resolved and populated when needed. + ReverseDomain string + // CNAME is a list of domain names that have been // resolved for Domain. CNAME []string @@ -88,10 +95,20 @@ func (e *Entity) Init() *Entity { return e } +// SetDstPort sets the destination port. +func (e *Entity) SetDstPort(dstPort uint16) { + e.dstPort = dstPort +} + +// DstPort returns the destination port. +func (e *Entity) DstPort() uint16 { + return e.dstPort +} + // FetchData fetches additional information, meant to be called before persisting an entity record. -func (e *Entity) FetchData() { - e.getLocation() - e.getLists() +func (e *Entity) FetchData(ctx context.Context) { + e.getLocation(ctx) + e.getLists(ctx) } // ResetLists resets the current list data and forces @@ -119,18 +136,18 @@ func (e *Entity) ResetLists() { // ResolveSubDomainLists enables or disables list lookups for // sub-domains. -func (e *Entity) ResolveSubDomainLists(enabled bool) { +func (e *Entity) ResolveSubDomainLists(ctx context.Context, enabled bool) { if e.domainListLoaded { - log.Warningf("intel/filterlists: tried to change sub-domain resolving for %s but lists are already fetched", e.Domain) + log.Tracer(ctx).Warningf("intel/filterlists: tried to change sub-domain resolving for %s but lists are already fetched", e.Domain) } e.resolveSubDomainLists = enabled } // EnableCNAMECheck enalbes or disables list lookups for // entity CNAMEs. -func (e *Entity) EnableCNAMECheck(enabled bool) { +func (e *Entity) EnableCNAMECheck(ctx context.Context, enabled bool) { if e.domainListLoaded { - log.Warningf("intel/filterlists: tried to change CNAME resolving for %s but lists are already fetched", e.Domain) + log.Tracer(ctx).Warningf("intel/filterlists: tried to change CNAME resolving for %s but lists are already fetched", e.Domain) } e.checkCNAMEs = enabled } @@ -148,13 +165,8 @@ func (e *Entity) EnableReverseResolving() { e.reverseResolveEnabled = true } -func (e *Entity) reverseResolve() { +func (e *Entity) reverseResolve(ctx context.Context) { e.reverseResolveOnce.Do(func() { - // check if we should resolve - if !e.reverseResolveEnabled { - return - } - // need IP! if e.IP == nil { return @@ -165,18 +177,25 @@ func (e *Entity) reverseResolve() { return } // TODO: security level - domain, err := reverseResolver(context.TODO(), e.IP.String(), status.SecurityLevelNormal) + domain, err := reverseResolver(ctx, e.IP.String(), status.SecurityLevelNormal) if err != nil { - log.Warningf("intel: failed to resolve IP %s: %s", e.IP, err) + log.Tracer(ctx).Warningf("intel: failed to resolve IP %s: %s", e.IP, err) return } - e.Domain = domain + e.ReverseDomain = domain }) } // GetDomain returns the domain and whether it is set. -func (e *Entity) GetDomain() (string, bool) { - e.reverseResolve() +func (e *Entity) GetDomain(ctx context.Context, mayUseReverseDomain bool) (string, bool) { + if mayUseReverseDomain && e.reverseResolveEnabled { + e.reverseResolve(ctx) + + if e.ReverseDomain == "" { + return "", false + } + return e.ReverseDomain, true + } if e.Domain == "" { return "", false @@ -194,7 +213,7 @@ func (e *Entity) GetIP() (net.IP, bool) { // Location -func (e *Entity) getLocation() { +func (e *Entity) getLocation(ctx context.Context) { e.fetchLocationOnce.Do(func() { // need IP! if e.IP == nil { @@ -204,7 +223,7 @@ func (e *Entity) getLocation() { // get location data loc, err := geoip.GetLocation(e.IP) if err != nil { - log.Warningf("intel: failed to get location data for %s: %s", e.IP, err) + log.Tracer(ctx).Warningf("intel: failed to get location data for %s: %s", e.IP, err) return } e.location = loc @@ -214,8 +233,8 @@ func (e *Entity) getLocation() { } // GetLocation returns the raw location data and whether it is set. -func (e *Entity) GetLocation() (*geoip.Location, bool) { - e.getLocation() +func (e *Entity) GetLocation(ctx context.Context) (*geoip.Location, bool) { + e.getLocation(ctx) if e.location == nil { return nil, false @@ -224,8 +243,8 @@ func (e *Entity) GetLocation() (*geoip.Location, bool) { } // GetCountry returns the two letter ISO country code and whether it is set. -func (e *Entity) GetCountry() (string, bool) { - e.getLocation() +func (e *Entity) GetCountry(ctx context.Context) (string, bool) { + e.getLocation(ctx) if e.Country == "" { return "", false @@ -234,8 +253,8 @@ func (e *Entity) GetCountry() (string, bool) { } // GetASN returns the AS number and whether it is set. -func (e *Entity) GetASN() (uint, bool) { - e.getLocation() +func (e *Entity) GetASN(ctx context.Context) (uint, bool) { + e.getLocation(ctx) if e.ASN == 0 { return 0, false @@ -244,11 +263,11 @@ func (e *Entity) GetASN() (uint, bool) { } // Lists -func (e *Entity) getLists() { - e.getDomainLists() - e.getASNLists() - e.getIPLists() - e.getCountryLists() +func (e *Entity) getLists(ctx context.Context) { + e.getDomainLists(ctx) + e.getASNLists(ctx) + e.getIPLists(ctx) + e.getCountryLists(ctx) } func (e *Entity) mergeList(key string, list []string) { @@ -263,12 +282,12 @@ func (e *Entity) mergeList(key string, list []string) { e.ListOccurences[key] = mergeStringList(e.ListOccurences[key], list) } -func (e *Entity) getDomainLists() { +func (e *Entity) getDomainLists(ctx context.Context) { if e.domainListLoaded { return } - domain, ok := e.GetDomain() + domain, ok := e.GetDomain(ctx, false /* mayUseReverseDomain */) if !ok { return } @@ -277,7 +296,7 @@ func (e *Entity) getDomainLists() { var domainsToInspect = []string{domain} if e.checkCNAMEs { - log.Tracef("intel: CNAME filtering enabled, checking %v too", e.CNAME) + log.Tracer(ctx).Tracef("intel: CNAME filtering enabled, checking %v too", e.CNAME) domainsToInspect = append(domainsToInspect, e.CNAME...) } @@ -294,10 +313,10 @@ func (e *Entity) getDomainLists() { domains = makeDistinct(domains) for _, d := range domains { - log.Tracef("intel: loading domain list for %s", d) + log.Tracer(ctx).Tracef("intel: loading domain list for %s", d) list, err := filterlists.LookupDomain(d) if err != nil { - log.Errorf("intel: failed to get domain blocklists for %s: %s", d, err) + log.Tracer(ctx).Errorf("intel: failed to get domain blocklists for %s: %s", d, err) e.loadDomainListOnce = sync.Once{} return } @@ -334,22 +353,22 @@ func splitDomain(domain string) []string { return domains } -func (e *Entity) getASNLists() { +func (e *Entity) getASNLists(ctx context.Context) { if e.asnListLoaded { return } - asn, ok := e.GetASN() + asn, ok := e.GetASN(ctx) if !ok { return } - log.Tracef("intel: loading ASN list for %d", asn) + log.Tracer(ctx).Tracef("intel: loading ASN list for %d", asn) e.loadAsnListOnce.Do(func() { asnStr := fmt.Sprintf("%d", asn) list, err := filterlists.LookupASNString(asnStr) if err != nil { - log.Errorf("intel: failed to get ASN blocklist for %d: %s", asn, err) + log.Tracer(ctx).Errorf("intel: failed to get ASN blocklist for %d: %s", asn, err) e.loadAsnListOnce = sync.Once{} return } @@ -359,21 +378,21 @@ func (e *Entity) getASNLists() { }) } -func (e *Entity) getCountryLists() { +func (e *Entity) getCountryLists(ctx context.Context) { if e.countryListLoaded { return } - country, ok := e.GetCountry() + country, ok := e.GetCountry(ctx) if !ok { return } - log.Tracef("intel: loading country list for %s", country) + log.Tracer(ctx).Tracef("intel: loading country list for %s", country) e.loadCoutryListOnce.Do(func() { list, err := filterlists.LookupCountry(country) if err != nil { - log.Errorf("intel: failed to load country blocklist for %s: %s", country, err) + log.Tracer(ctx).Errorf("intel: failed to load country blocklist for %s: %s", country, err) e.loadCoutryListOnce = sync.Once{} return } @@ -383,7 +402,7 @@ func (e *Entity) getCountryLists() { }) } -func (e *Entity) getIPLists() { +func (e *Entity) getIPLists(ctx context.Context) { if e.ipListLoaded { return } @@ -402,12 +421,12 @@ func (e *Entity) getIPLists() { return } - log.Tracef("intel: loading IP list for %s", ip) + log.Tracer(ctx).Tracef("intel: loading IP list for %s", ip) e.loadIPListOnce.Do(func() { list, err := filterlists.LookupIP(ip) if err != nil { - log.Errorf("intel: failed to get IP blocklist for %s: %s", ip.String(), err) + log.Tracer(ctx).Errorf("intel: failed to get IP blocklist for %s: %s", ip.String(), err) e.loadIPListOnce = sync.Once{} return } @@ -418,8 +437,8 @@ func (e *Entity) getIPLists() { // LoadLists searches all filterlists for all occurrences of // this entity. -func (e *Entity) LoadLists() bool { - e.getLists() +func (e *Entity) LoadLists(ctx context.Context) bool { + e.getLists(ctx) return e.ListOccurences != nil } diff --git a/intel/geoip/lookup.go b/intel/geoip/lookup.go index 73c60d05..b4c7f320 100644 --- a/intel/geoip/lookup.go +++ b/intel/geoip/lookup.go @@ -4,7 +4,6 @@ import ( "net" "github.com/oschwald/maxminddb-golang" - "github.com/safing/portbase/log" ) func getReader(ip net.IP) *maxminddb.Reader { @@ -49,7 +48,5 @@ func GetLocation(ip net.IP) (record *Location, err error) { return nil, err } - log.Tracef("geoip: record: %+v", record) - return record, nil } diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index b1375a25..ac46c6d2 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -76,12 +76,15 @@ func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) { return handleRequest(ctx, w, query) }) if err != nil { - log.Warningf("intel: failed to handle dns request: %s", err) + log.Warningf("nameserver: failed to handle dns request: %s", err) } } func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) error { //nolint:gocognit // TODO // Only process first question, that's how everyone does it. + if len(request.Question) == 0 { + return errors.New("missing question") + } originalQuestion := request.Question[0] // Check if we are handling a non-standard query name. @@ -116,7 +119,12 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) // Setup quick reply function. reply := func(responder nsutil.Responder, rrProviders ...nsutil.RRProvider) error { - return sendResponse(ctx, w, request, responder, rrProviders...) + err := sendResponse(ctx, w, request, responder, rrProviders...) + // Log error here instead of returning it in order to keep the context. + if err != nil { + tracer.Errorf("nameserver: %s", err) + } + return nil } // Return with server failure if offline. diff --git a/nameserver/nsutil/nsutil.go b/nameserver/nsutil/nsutil.go index 7a8b730b..0a6f103d 100644 --- a/nameserver/nsutil/nsutil.go +++ b/nameserver/nsutil/nsutil.go @@ -10,6 +10,11 @@ import ( "github.com/safing/portbase/log" ) +var ( + // ErrNilRR is returned when a parsed RR is nil. + ErrNilRR = errors.New("is nil") +) + // Responder defines the interface that any block/deny reason interface // may implement to support sending custom DNS responses for a given reason. // That is, if a reason context implements the Responder interface the @@ -39,8 +44,9 @@ func (rf ResponderFunc) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns return rf(ctx, request) } -// ZeroIP is a ResponderFunc than replies with either 0.0.0.0 or :: for -// each A or AAAA question respectively. +// ZeroIP is a ResponderFunc than replies with either 0.0.0.0 or :: for each A +// or AAAA question respectively. If there is no A or AAAA question, it +// defaults to replying with NXDomain. func ZeroIP(msgs ...string) ResponderFunc { return func(ctx context.Context, request *dns.Msg) *dns.Msg { reply := new(dns.Msg) @@ -52,15 +58,16 @@ func ZeroIP(msgs ...string) ResponderFunc { switch question.Qtype { case dns.TypeA: - rr, err = dns.NewRR(question.Name + " 0 IN A 0.0.0.0") + rr, err = dns.NewRR(question.Name + " 0 IN A 0.0.0.0") case dns.TypeAAAA: - rr, err = dns.NewRR(question.Name + " 0 IN AAAA ::") + rr, err = dns.NewRR(question.Name + " 0 IN AAAA ::") } - if err != nil { + switch { + case err != nil: log.Tracer(ctx).Errorf("nameserver: failed to create zero-ip response for %s: %s", question.Name, err) hasErr = true - } else { + case rr != nil: reply.Answer = append(reply.Answer, rr) } } @@ -81,6 +88,7 @@ func ZeroIP(msgs ...string) ResponderFunc { } // Localhost is a ResponderFunc than replies with localhost IP addresses. +// If there is no A or AAAA question, it defaults to replying with NXDomain. func Localhost(msgs ...string) ResponderFunc { return func(ctx context.Context, request *dns.Msg) *dns.Msg { reply := new(dns.Msg) @@ -97,10 +105,11 @@ func Localhost(msgs ...string) ResponderFunc { rr, err = dns.NewRR("localhost. 0 IN AAAA ::1") } - if err != nil { + switch { + case err != nil: log.Tracer(ctx).Errorf("nameserver: failed to create localhost response for %s: %s", question.Name, err) hasErr = true - } else { + case rr != nil: reply.Answer = append(reply.Answer, rr) } } @@ -159,7 +168,7 @@ func MakeMessageRecord(level log.Severity, msg string) (dns.RR, error) { //nolin return nil, err } if rr == nil { - return nil, errors.New("record is nil") + return nil, ErrNilRR } return rr, nil } diff --git a/nameserver/only/nameserver.go b/nameserver/only/nameserver.go deleted file mode 100644 index 6d5cb5cb..00000000 --- a/nameserver/only/nameserver.go +++ /dev/null @@ -1,236 +0,0 @@ -package only - -import ( - "context" - "net" - "strings" - - "github.com/safing/portbase/database" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portmaster/netenv" - "github.com/safing/portmaster/network/netutils" - "github.com/safing/portmaster/resolver" - - "github.com/miekg/dns" -) - -var ( - module *modules.Module - dnsServer *dns.Server - mtDNSRequest = "dns request" - - listenAddress = "127.0.0.1:53" - ipv4Localhost = net.IPv4(127, 0, 0, 1) - localhostRRs []dns.RR -) - -func init() { - module = modules.Register("nameserver", initLocalhostRRs, start, stop, "core", "resolver", "network", "netenv") -} - -func initLocalhostRRs() error { - localhostIPv4, err := dns.NewRR("localhost. 17 IN A 127.0.0.1") - if err != nil { - return err - } - - localhostIPv6, err := dns.NewRR("localhost. 17 IN AAAA ::1") - if err != nil { - return err - } - - localhostRRs = []dns.RR{localhostIPv4, localhostIPv6} - return nil -} - -func start() error { - dnsServer = &dns.Server{Addr: listenAddress, Net: "udp"} - dns.HandleFunc(".", handleRequestAsMicroTask) - - module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error { - err := dnsServer.ListenAndServe() - if err != nil { - // check if we are shutting down - if module.IsStopping() { - return nil - } - } - return err - }) - - return nil -} - -func stop() error { - if dnsServer != nil { - return dnsServer.Shutdown() - } - return nil -} - -func returnNXDomain(w dns.ResponseWriter, query *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(query, dns.RcodeNameError) - _ = w.WriteMsg(m) -} - -func returnServerFailure(w dns.ResponseWriter, query *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(query, dns.RcodeServerFailure) - _ = w.WriteMsg(m) -} - -func handleRequestAsMicroTask(w dns.ResponseWriter, query *dns.Msg) { - err := module.RunMicroTask(&mtDNSRequest, func(ctx context.Context) error { - return handleRequest(ctx, w, query) - }) - if err != nil { - log.Warningf("nameserver: failed to handle dns request: %s", err) - } -} - -func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) error { - // return with server failure if offline - if netenv.GetOnlineStatus() == netenv.StatusOffline { - returnServerFailure(w, query) - return nil - } - - // only process first question, that's how everyone does it. - question := query.Question[0] - q := &resolver.Query{ - FQDN: question.Name, - QType: dns.Type(question.Qtype), - } - - // check class - if question.Qclass != dns.ClassINET { - // we only serve IN records, return nxdomain - returnNXDomain(w, query) - return nil - } - - // handle request for localhost - if strings.HasSuffix(q.FQDN, "localhost.") { - m := new(dns.Msg) - m.SetReply(query) - m.Answer = localhostRRs - _ = w.WriteMsg(m) - return nil - } - - // get addresses - remoteAddr, ok := w.RemoteAddr().(*net.UDPAddr) - if !ok { - log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", q.FQDN, q.QType) - return nil - } - if !remoteAddr.IP.Equal(ipv4Localhost) { - // if request is not coming from 127.0.0.1, check if it's really local - - localAddr, ok := w.RemoteAddr().(*net.UDPAddr) - if !ok { - log.Warningf("nameserver: could not get local address of request for %s%s, ignoring", q.FQDN, q.QType) - return nil - } - - // ignore external request - if !remoteAddr.IP.Equal(localAddr.IP) { - log.Warningf("nameserver: external request for %s%s, ignoring", q.FQDN, q.QType) - return nil - } - } - - // check if valid domain name - if !netutils.IsValidFqdn(q.FQDN) { - log.Debugf("nameserver: domain name %s is invalid, returning nxdomain", q.FQDN) - returnNXDomain(w, query) - return nil - } - - // start tracer - ctx, tracer := log.AddTracer(ctx) - tracer.Tracef("nameserver: handling new request for %s%s from %s:%d", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port) - - // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain - - // get intel and RRs - rrCache, err := resolver.Resolve(ctx, q) - if err != nil { - // TODO: analyze nxdomain requests, malware could be trying DGA-domains - tracer.Warningf("nameserver: request for %s%s: %s", q.FQDN, q.QType, err) - returnNXDomain(w, query) - return nil - } - - // save IP addresses to IPInfo - cnames := make(map[string]string) - ips := make(map[string]struct{}) - - for _, rr := range append(rrCache.Answer, rrCache.Extra...) { - switch v := rr.(type) { - case *dns.CNAME: - cnames[v.Hdr.Name] = v.Target - - case *dns.A: - ips[v.A.String()] = struct{}{} - - case *dns.AAAA: - ips[v.AAAA.String()] = struct{}{} - } - } - - for ip := range ips { - record := resolver.ResolvedDomain{ - Domain: q.FQDN, - } - - // resolve all CNAMEs in the correct order. - var domain = q.FQDN - for { - nextDomain, isCNAME := cnames[domain] - if !isCNAME { - break - } - - record.CNAMEs = append(record.CNAMEs, nextDomain) - domain = nextDomain - } - - // get the existing IP info or create a new one - var save bool - info, err := resolver.GetIPInfo(ip) - if err != nil { - if err != database.ErrNotFound { - log.Errorf("nameserver: failed to search for IP info record: %s", err) - } - - info = &resolver.IPInfo{ - IP: ip, - } - save = true - } - - // and the new resolved domain record and save - if new := info.AddDomain(record); new { - save = true - } - if save { - if err := info.Save(); err != nil { - log.Errorf("nameserver: failed to save IP info record: %s", err) - } - } - } - - // reply to query - m := new(dns.Msg) - m.SetReply(query) - m.Answer = rrCache.Answer - m.Ns = rrCache.Ns - m.Extra = rrCache.Extra - _ = w.WriteMsg(m) - tracer.Debugf("nameserver: returning response %s%s", q.FQDN, q.QType) - - return nil -} diff --git a/nameserver/response.go b/nameserver/response.go index 4171b36f..18e9f75d 100644 --- a/nameserver/response.go +++ b/nameserver/response.go @@ -36,19 +36,19 @@ func sendResponse( } // Write reply. - if err := writeDNSResponse(w, reply); err != nil { - return fmt.Errorf("nameserver: failed to send response: %w", err) + if err := writeDNSResponse(ctx, w, reply); err != nil { + return fmt.Errorf("failed to send response: %w", err) } return nil } -func writeDNSResponse(w dns.ResponseWriter, m *dns.Msg) (err error) { +func writeDNSResponse(ctx context.Context, w dns.ResponseWriter, m *dns.Msg) (err error) { defer func() { // recover from panic if panicErr := recover(); panicErr != nil { err = fmt.Errorf("panic: %s", panicErr) - log.Warningf("nameserver: panic caused by this msg: %#v", m) + log.Tracer(ctx).Debugf("nameserver: panic caused by this msg: %#v", m) } }() @@ -56,10 +56,11 @@ func writeDNSResponse(w dns.ResponseWriter, m *dns.Msg) (err error) { if err != nil { // If we receive an error we might have exceeded the message size with all // our extra information records. Retry again without the extra section. + log.Tracer(ctx).Tracef("nameserver: retrying to write dns message without extra section, error was: %s", err) m.Extra = nil noExtraErr := w.WriteMsg(m) if noExtraErr == nil { - log.Warningf("nameserver: failed to write dns message with extra section: %s", err) + return fmt.Errorf("failed to write dns message without extra section: %w", err) } } return diff --git a/netenv/location_windows.go b/netenv/location_windows.go index 8234d4ed..92d34291 100644 --- a/netenv/location_windows.go +++ b/netenv/location_windows.go @@ -9,6 +9,8 @@ import ( "unsafe" ) +// Windows specific constants for the WSAIoctl interface. +//nolint:golint,stylecheck const ( SIO_RCVALL = syscall.IOC_IN | syscall.IOC_VENDOR | 1 diff --git a/network/connection.go b/network/connection.go index 2474437a..1a5d63c2 100644 --- a/network/connection.go +++ b/network/connection.go @@ -172,7 +172,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri }, ) if err != nil { - log.Debugf("network: failed to find process of dns request for %s: %s", fqdn, err) + log.Tracer(ctx).Debugf("network: failed to find process of dns request for %s: %s", fqdn, err) proc = process.GetUnidentifiedProcess(ctx) } @@ -196,7 +196,7 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { // get Process proc, inbound, err := process.GetProcessByConnection(pkt.Ctx(), pkt.Info()) if err != nil { - log.Debugf("network: failed to find process of packet %s: %s", pkt, err) + log.Tracer(pkt.Ctx()).Debugf("network: failed to find process of packet %s: %s", pkt, err) proc = process.GetUnidentifiedProcess(pkt.Ctx()) } @@ -224,6 +224,7 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { Protocol: uint8(pkt.Info().Protocol), Port: pkt.Info().SrcPort, } + entity.SetDstPort(pkt.Info().DstPort) } else { @@ -233,11 +234,12 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { Protocol: uint8(pkt.Info().Protocol), Port: pkt.Info().DstPort, } + entity.SetDstPort(entity.Port) // check if we can find a domain for that IP - ipinfo, err := resolver.GetIPInfo(pkt.Info().Dst.String()) + ipinfo, err := resolver.GetIPInfo(proc.LocalProfileKey, pkt.Info().Dst.String()) if err == nil { - lastResolvedDomain := ipinfo.ResolvedDomains.MostRecentDomain() + lastResolvedDomain := ipinfo.MostRecentDomain() if lastResolvedDomain != nil { scope = lastResolvedDomain.Domain entity.Domain = lastResolvedDomain.Domain @@ -299,9 +301,7 @@ func GetConnection(id string) (*Connection, bool) { // AcceptWithContext accepts the connection. func (conn *Connection) AcceptWithContext(reason string, ctx interface{}) { - if conn.SetVerdict(VerdictAccept, reason, ctx) { - log.Infof("filter: granting connection %s, %s", conn, conn.Reason) - } else { + if !conn.SetVerdict(VerdictAccept, reason, ctx) { log.Warningf("filter: tried to accept %s, but current verdict is %s", conn, conn.Verdict) } } @@ -313,9 +313,7 @@ func (conn *Connection) Accept(reason string) { // BlockWithContext blocks the connection. func (conn *Connection) BlockWithContext(reason string, ctx interface{}) { - if conn.SetVerdict(VerdictBlock, reason, ctx) { - log.Infof("filter: blocking connection %s, %s", conn, conn.Reason) - } else { + if !conn.SetVerdict(VerdictBlock, reason, ctx) { log.Warningf("filter: tried to block %s, but current verdict is %s", conn, conn.Verdict) } } @@ -327,9 +325,7 @@ func (conn *Connection) Block(reason string) { // DropWithContext drops the connection. func (conn *Connection) DropWithContext(reason string, ctx interface{}) { - if conn.SetVerdict(VerdictDrop, reason, ctx) { - log.Infof("filter: dropping connection %s, %s", conn, conn.Reason) - } else { + if !conn.SetVerdict(VerdictDrop, reason, ctx) { log.Warningf("filter: tried to drop %s, but current verdict is %s", conn, conn.Verdict) } } @@ -355,9 +351,7 @@ func (conn *Connection) Deny(reason string) { // FailedWithContext marks the connection with VerdictFailed and stores the reason. func (conn *Connection) FailedWithContext(reason string, ctx interface{}) { - if conn.SetVerdict(VerdictFailed, reason, ctx) { - log.Infof("filter: dropping connection %s because of an internal error: %s", conn, reason) - } else { + if !conn.SetVerdict(VerdictFailed, reason, ctx) { log.Warningf("filter: tried to drop %s due to error but current verdict is %s", conn, conn.Verdict) } } @@ -495,6 +489,9 @@ func (conn *Connection) packetHandler() { } else { defaultFirewallHandler(conn, pkt) } + // log verdict + log.Tracer(pkt.Ctx()).Infof("filter: connection %s %s: %s", conn, conn.Verdict.Verb(), conn.Reason) + // save does not touch any changing data // must not be locked, will deadlock with cleaner functions if conn.saveWhenFinished { diff --git a/profile/database.go b/profile/database.go index 82888c9d..4775bc82 100644 --- a/profile/database.go +++ b/profile/database.go @@ -22,8 +22,8 @@ const ( var ( profileDB = database.NewInterface(&database.Options{ - Local: true, - Internal: true, + Local: true, + Internal: true, }) ) diff --git a/profile/endpoints/endpoint-any.go b/profile/endpoints/endpoint-any.go index 8e8deb98..14960489 100644 --- a/profile/endpoints/endpoint-any.go +++ b/profile/endpoints/endpoint-any.go @@ -1,6 +1,10 @@ package endpoints -import "github.com/safing/portmaster/intel" +import ( + "context" + + "github.com/safing/portmaster/intel" +) // EndpointAny matches anything. type EndpointAny struct { @@ -8,7 +12,7 @@ type EndpointAny struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointAny) Matches(entity *intel.Entity) (EPResult, Reason) { +func (ep *EndpointAny) Matches(_ context.Context, entity *intel.Entity) (EPResult, Reason) { return ep.match(ep, entity, "*", "matches") } diff --git a/profile/endpoints/endpoint-asn.go b/profile/endpoints/endpoint-asn.go index 6713d199..6c11f1a6 100644 --- a/profile/endpoints/endpoint-asn.go +++ b/profile/endpoints/endpoint-asn.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "fmt" "regexp" "strconv" @@ -20,8 +21,8 @@ type EndpointASN struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointASN) Matches(entity *intel.Entity) (EPResult, Reason) { - asn, ok := entity.GetASN() +func (ep *EndpointASN) Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) { + asn, ok := entity.GetASN(ctx) if !ok { return Undeterminable, nil } diff --git a/profile/endpoints/endpoint-country.go b/profile/endpoints/endpoint-country.go index 85449cf5..d47d6e1f 100644 --- a/profile/endpoints/endpoint-country.go +++ b/profile/endpoints/endpoint-country.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "regexp" "strings" @@ -19,8 +20,8 @@ type EndpointCountry struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointCountry) Matches(entity *intel.Entity) (EPResult, Reason) { - country, ok := entity.GetCountry() +func (ep *EndpointCountry) Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) { + country, ok := entity.GetCountry(ctx) if !ok { return Undeterminable, nil } diff --git a/profile/endpoints/endpoint-domain.go b/profile/endpoints/endpoint-domain.go index b02fc9a5..491b038d 100644 --- a/profile/endpoints/endpoint-domain.go +++ b/profile/endpoints/endpoint-domain.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "regexp" "strings" @@ -62,19 +63,20 @@ func (ep *EndpointDomain) check(entity *intel.Entity, domain string) (EPResult, } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointDomain) Matches(entity *intel.Entity) (EPResult, Reason) { - if entity.Domain == "" { +func (ep *EndpointDomain) Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) { + domain, ok := entity.GetDomain(ctx, true /* mayUseReverseDomain */) + if !ok { return NoMatch, nil } - result, reason := ep.check(entity, entity.Domain) + result, reason := ep.check(entity, domain) if result != NoMatch { return result, reason } if entity.CNAMECheckEnabled() { - for _, domain := range entity.CNAME { - result, reason = ep.check(entity, domain) + for _, cname := range entity.CNAME { + result, reason = ep.check(entity, cname) if result == Denied { return result, reason } diff --git a/profile/endpoints/endpoint-ip.go b/profile/endpoints/endpoint-ip.go index 43ea47f7..08110247 100644 --- a/profile/endpoints/endpoint-ip.go +++ b/profile/endpoints/endpoint-ip.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "net" "github.com/safing/portmaster/intel" @@ -14,7 +15,7 @@ type EndpointIP struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointIP) Matches(entity *intel.Entity) (EPResult, Reason) { +func (ep *EndpointIP) Matches(_ context.Context, entity *intel.Entity) (EPResult, Reason) { if entity.IP == nil { return Undeterminable, nil } diff --git a/profile/endpoints/endpoint-iprange.go b/profile/endpoints/endpoint-iprange.go index bc0d22fe..d4be35db 100644 --- a/profile/endpoints/endpoint-iprange.go +++ b/profile/endpoints/endpoint-iprange.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "net" "github.com/safing/portmaster/intel" @@ -14,7 +15,7 @@ type EndpointIPRange struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointIPRange) Matches(entity *intel.Entity) (EPResult, Reason) { +func (ep *EndpointIPRange) Matches(_ context.Context, entity *intel.Entity) (EPResult, Reason) { if entity.IP == nil { return Undeterminable, nil } diff --git a/profile/endpoints/endpoint-lists.go b/profile/endpoints/endpoint-lists.go index 27ec8b00..618c66d9 100644 --- a/profile/endpoints/endpoint-lists.go +++ b/profile/endpoints/endpoint-lists.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "strings" "github.com/safing/portmaster/intel" @@ -15,8 +16,8 @@ type EndpointLists struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointLists) Matches(entity *intel.Entity) (EPResult, Reason) { - if !entity.LoadLists() { +func (ep *EndpointLists) Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) { + if !entity.LoadLists(ctx) { return Undeterminable, nil } diff --git a/profile/endpoints/endpoint-scopes.go b/profile/endpoints/endpoint-scopes.go index ea22126d..6f1c2f27 100644 --- a/profile/endpoints/endpoint-scopes.go +++ b/profile/endpoints/endpoint-scopes.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "strings" "github.com/safing/portmaster/network/netutils" @@ -30,7 +31,7 @@ type EndpointScope struct { } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) { +func (ep *EndpointScope) Matches(_ context.Context, entity *intel.Entity) (EPResult, Reason) { if entity.IP == nil { return Undeterminable, nil } diff --git a/profile/endpoints/endpoint.go b/profile/endpoints/endpoint.go index 2e0a4e85..929a70eb 100644 --- a/profile/endpoints/endpoint.go +++ b/profile/endpoints/endpoint.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "fmt" "strconv" "strings" @@ -11,7 +12,7 @@ import ( // Endpoint describes an Endpoint Matcher type Endpoint interface { - Matches(entity *intel.Entity) (EPResult, Reason) + Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) String() string } @@ -69,11 +70,11 @@ func (ep *EndpointBase) matchesPPP(entity *intel.Entity) (result EPResult) { // only check if port is defined if ep.StartPort > 0 { // if port is unknown, return Undeterminable - if entity.Port == 0 { + if entity.DstPort() == 0 { return Undeterminable } // if port does not match, return NoMatch - if entity.Port < ep.StartPort || entity.Port > ep.EndPort { + if entity.DstPort() < ep.StartPort || entity.DstPort() > ep.EndPort { return NoMatch } } @@ -219,10 +220,6 @@ func parseEndpoint(value string) (endpoint Endpoint, err error) { //nolint:gocog if endpoint, err = parseTypeIPRange(fields); endpoint != nil || err != nil { return } - // domain - if endpoint, err = parseTypeDomain(fields); endpoint != nil || err != nil { - return - } // country if endpoint, err = parseTypeCountry(fields); endpoint != nil || err != nil { return @@ -239,6 +236,10 @@ func parseEndpoint(value string) (endpoint Endpoint, err error) { //nolint:gocog if endpoint, err = parseTypeList(fields); endpoint != nil || err != nil { return } + // domain + if endpoint, err = parseTypeDomain(fields); endpoint != nil || err != nil { + return + } return nil, fmt.Errorf(`unknown endpoint definition: "%s"`, value) } diff --git a/profile/endpoints/endpoints.go b/profile/endpoints/endpoints.go index f74edd24..76c3d8ba 100644 --- a/profile/endpoints/endpoints.go +++ b/profile/endpoints/endpoints.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "fmt" "strings" @@ -63,10 +64,10 @@ func (e Endpoints) IsSet() bool { } // Match checks whether the given entity matches any of the endpoint definitions in the list. -func (e Endpoints) Match(entity *intel.Entity) (result EPResult, reason Reason) { +func (e Endpoints) Match(ctx context.Context, entity *intel.Entity) (result EPResult, reason Reason) { for _, entry := range e { if entry != nil { - if result, reason = entry.Matches(entity); result != NoMatch { + if result, reason = entry.Matches(ctx, entity); result != NoMatch { return } } diff --git a/profile/endpoints/endpoints_test.go b/profile/endpoints/endpoints_test.go index 531544d8..9b28e131 100644 --- a/profile/endpoints/endpoints_test.go +++ b/profile/endpoints/endpoints_test.go @@ -1,6 +1,7 @@ package endpoints import ( + "context" "net" "runtime" "testing" @@ -16,7 +17,9 @@ func TestMain(m *testing.M) { } func testEndpointMatch(t *testing.T, ep Endpoint, entity *intel.Entity, expectedResult EPResult) { - result, _ := ep.Matches(entity) + entity.SetDstPort(entity.Port) + + result, _ := ep.Matches(context.TODO(), entity) if result != expectedResult { t.Errorf( "line %d: unexpected result for endpoint %s and entity %+v: result=%s, expected=%s", diff --git a/profile/profile-layered.go b/profile/profile-layered.go index 9a5e0d50..e292edfb 100644 --- a/profile/profile-layered.go +++ b/profile/profile-layered.go @@ -1,6 +1,7 @@ package profile import ( + "context" "sync" "sync/atomic" @@ -221,10 +222,10 @@ func (lp *LayeredProfile) DefaultAction() uint8 { } // MatchEndpoint checks if the given endpoint matches an entry in any of the profiles. -func (lp *LayeredProfile) MatchEndpoint(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { +func (lp *LayeredProfile) MatchEndpoint(ctx context.Context, entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { for _, layer := range lp.layers { if layer.endpoints.IsSet() { - result, reason := layer.endpoints.Match(entity) + result, reason := layer.endpoints.Match(ctx, entity) if endpoints.IsDecision(result) { return result, reason } @@ -233,16 +234,16 @@ func (lp *LayeredProfile) MatchEndpoint(entity *intel.Entity) (endpoints.EPResul cfgLock.RLock() defer cfgLock.RUnlock() - return cfgEndpoints.Match(entity) + return cfgEndpoints.Match(ctx, entity) } // MatchServiceEndpoint checks if the given endpoint of an inbound connection matches an entry in any of the profiles. -func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { +func (lp *LayeredProfile) MatchServiceEndpoint(ctx context.Context, entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { entity.EnableReverseResolving() for _, layer := range lp.layers { if layer.serviceEndpoints.IsSet() { - result, reason := layer.serviceEndpoints.Match(entity) + result, reason := layer.serviceEndpoints.Match(ctx, entity) if endpoints.IsDecision(result) { return result, reason } @@ -251,19 +252,19 @@ func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (endpoints. cfgLock.RLock() defer cfgLock.RUnlock() - return cfgServiceEndpoints.Match(entity) + return cfgServiceEndpoints.Match(ctx, entity) } // MatchFilterLists matches the entity against the set of filter // lists. -func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { - entity.ResolveSubDomainLists(lp.FilterSubDomains()) - entity.EnableCNAMECheck(lp.FilterCNAMEs()) +func (lp *LayeredProfile) MatchFilterLists(ctx context.Context, entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) { + entity.ResolveSubDomainLists(ctx, lp.FilterSubDomains()) + entity.EnableCNAMECheck(ctx, lp.FilterCNAMEs()) for _, layer := range lp.layers { // search for the first layer that has filterListIDs set if len(layer.filterListIDs) > 0 { - entity.LoadLists() + entity.LoadLists(ctx) if entity.MatchLists(layer.filterListIDs) { return endpoints.Denied, entity.ListBlockReason() @@ -276,7 +277,7 @@ func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPRe cfgLock.RLock() defer cfgLock.RUnlock() if len(cfgFilterLists) > 0 { - entity.LoadLists() + entity.LoadLists(ctx) if entity.MatchLists(cfgFilterLists) { return endpoints.Denied, entity.ListBlockReason() diff --git a/resolver/ipinfo.go b/resolver/ipinfo.go index 0ecf9766..7217e306 100644 --- a/resolver/ipinfo.go +++ b/resolver/ipinfo.go @@ -7,12 +7,26 @@ import ( "github.com/safing/portbase/database" "github.com/safing/portbase/database/record" - "github.com/safing/portbase/utils" +) + +const ( + // IPInfoProfileScopeGlobal is the profile scope used for unscoped IPInfo entries. + IPInfoProfileScopeGlobal = "global" ) var ( ipInfoDatabase = database.NewInterface(&database.Options{ - AlwaysSetRelativateExpiry: 86400, // 24 hours + Local: true, + Internal: true, + + // Cache entries because new/updated entries will often be queries soon + // after inserted. + CacheSize: 256, + + // We only use the cache database here, so we can delay and batch all our + // writes. Also, no one else accesses these records, so we are fine using + // this. + DelayCachedWrites: "cache", }) ) @@ -25,6 +39,11 @@ type ResolvedDomain struct { // CNAMEs is a list of CNAMEs that have been resolved for // Domain. CNAMEs []string + + // Expires holds the timestamp when this entry expires. + // This does not mean that the entry may not be used anymore afterwards, + // but that this is used to calcuate the TTL of the database record. + Expires int64 } // String returns a string representation of ResolvedDomain including @@ -54,29 +73,16 @@ func (rds ResolvedDomains) String() string { return strings.Join(domains, " or ") } -// MostRecentDomain returns the most recent domain. -func (rds ResolvedDomains) MostRecentDomain() *ResolvedDomain { - if len(rds) == 0 { - return nil - } - // TODO(ppacher): we could also do that by using ResolvedAt() - mostRecent := rds[len(rds)-1] - return &mostRecent -} - // IPInfo represents various information about an IP. type IPInfo struct { record.Base sync.Mutex - // IP holds the acutal IP address. + // IP holds the actual IP address. IP string - // Domains holds a list of domains that have been - // resolved to IP. This field is deprecated and should - // be removed. - // DEPRECATED: remove with alpha. - Domains []string `json:"Domains,omitempty"` + // ProfileID is used to scope this entry to a process group. + ProfileID string // ResolvedDomain is a slice of domains that // have been requested by various applications @@ -84,35 +90,43 @@ type IPInfo struct { ResolvedDomains ResolvedDomains } -// AddDomain adds a new resolved domain to ipi. -func (ipi *IPInfo) AddDomain(resolved ResolvedDomain) bool { - for idx, d := range ipi.ResolvedDomains { - if d.Domain == resolved.Domain { - if utils.StringSliceEqual(d.CNAMEs, resolved.CNAMEs) { - return false - } +// AddDomain adds a new resolved domain to IPInfo. +func (info *IPInfo) AddDomain(resolved ResolvedDomain) { + info.Lock() + defer info.Unlock() - // we have a different CNAME chain now, remove the previous - // entry and add it at the end. - ipi.ResolvedDomains = append(ipi.ResolvedDomains[:idx], ipi.ResolvedDomains[idx+1:]...) - ipi.ResolvedDomains = append(ipi.ResolvedDomains, resolved) - return true + // Delete old for the same domain. + for idx, d := range info.ResolvedDomains { + if d.Domain == resolved.Domain { + info.ResolvedDomains = append(info.ResolvedDomains[:idx], info.ResolvedDomains[idx+1:]...) + break } } - ipi.ResolvedDomains = append(ipi.ResolvedDomains, resolved) - return true + // Add new entry to the end. + info.ResolvedDomains = append(info.ResolvedDomains, resolved) } -func makeIPInfoKey(ip string) string { - return fmt.Sprintf("cache:intel/ipInfo/%s", ip) +// MostRecentDomain returns the most recent domain. +func (info *IPInfo) MostRecentDomain() *ResolvedDomain { + info.Lock() + defer info.Unlock() + + if len(info.ResolvedDomains) == 0 { + return nil + } + + mostRecent := info.ResolvedDomains[len(info.ResolvedDomains)-1] + return &mostRecent +} + +func makeIPInfoKey(profileID, ip string) string { + return fmt.Sprintf("cache:intel/ipInfo/%s/%s", profileID, ip) } // GetIPInfo gets an IPInfo record from the database. -func GetIPInfo(ip string) (*IPInfo, error) { - key := makeIPInfoKey(ip) - - r, err := ipInfoDatabase.Get(key) +func GetIPInfo(profileID, ip string) (*IPInfo, error) { + r, err := ipInfoDatabase.Get(makeIPInfoKey(profileID, ip)) if err != nil { return nil, err } @@ -126,18 +140,6 @@ func GetIPInfo(ip string) (*IPInfo, error) { return nil, err } - // Legacy support, - // DEPRECATED: remove with alpha - if len(new.Domains) > 0 && len(new.ResolvedDomains) == 0 { - for _, d := range new.Domains { - new.ResolvedDomains = append(new.ResolvedDomains, ResolvedDomain{ - Domain: d, - // rest is empty... - }) - } - new.Domains = nil // clean up so we remove it from the database - } - return new, nil } @@ -150,27 +152,38 @@ func GetIPInfo(ip string) (*IPInfo, error) { } // Save saves the IPInfo record to the database. -func (ipi *IPInfo) Save() error { - ipi.Lock() - if !ipi.KeyIsSet() { - ipi.SetKey(makeIPInfoKey(ipi.IP)) - } - ipi.Unlock() +func (info *IPInfo) Save() error { + info.Lock() - // Legacy support - // Ensure we don't write new Domain fields into the - // database. - // DEPRECATED: remove with alpha - if len(ipi.Domains) > 0 { - ipi.Domains = nil + // Set database key if not yet set already. + if !info.KeyIsSet() { + // Default to global scope if scope is unset. + if info.ProfileID == "" { + info.ProfileID = IPInfoProfileScopeGlobal + } + info.SetKey(makeIPInfoKey(info.ProfileID, info.IP)) } - return ipInfoDatabase.Put(ipi) + // Calculate and set cache expiry. + var expires int64 = 86400 // Minimum TTL of one day. + for _, rd := range info.ResolvedDomains { + if rd.Expires > expires { + expires = rd.Expires + } + } + info.UpdateMeta() + expires += 3600 // Add one hour to expiry as a buffer. + info.Meta().SetAbsoluteExpiry(expires) + + info.Unlock() + + return ipInfoDatabase.Put(info) } // FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or " -func (ipi *IPInfo) String() string { - ipi.Lock() - defer ipi.Unlock() - return fmt.Sprintf("", info.Key(), info.IP, info.ResolvedDomains.String()) } diff --git a/resolver/ipinfo_test.go b/resolver/ipinfo_test.go index 02385244..759d0fed 100644 --- a/resolver/ipinfo_test.go +++ b/resolver/ipinfo_test.go @@ -15,7 +15,7 @@ func TestIPInfo(t *testing.T) { CNAMEs: []string{"example.com"}, } - ipi := &IPInfo{ + info := &IPInfo{ IP: "1.2.3.4", ResolvedDomains: ResolvedDomains{ example, @@ -27,22 +27,18 @@ func TestIPInfo(t *testing.T) { Domain: "sub2.example.com", CNAMEs: []string{"sub1.example.com", "example.com"}, } - added := ipi.AddDomain(sub2Example) - - assert.True(t, added) - assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, ipi.ResolvedDomains) + info.AddDomain(sub2Example) + assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, info.ResolvedDomains) // try again, should do nothing now - added = ipi.AddDomain(sub2Example) - assert.False(t, added) - assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, ipi.ResolvedDomains) + info.AddDomain(sub2Example) + assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, info.ResolvedDomains) subOverWrite := ResolvedDomain{ Domain: "sub1.example.com", CNAMEs: []string{}, // now without CNAMEs } - added = ipi.AddDomain(subOverWrite) - assert.True(t, added) - assert.Equal(t, ResolvedDomains{example, sub2Example, subOverWrite}, ipi.ResolvedDomains) + info.AddDomain(subOverWrite) + assert.Equal(t, ResolvedDomains{example, sub2Example, subOverWrite}, info.ResolvedDomains) } diff --git a/resolver/main.go b/resolver/main.go index c586d0df..cc0ebe62 100644 --- a/resolver/main.go +++ b/resolver/main.go @@ -93,6 +93,9 @@ func start() error { listenToMDNS, ) + module.StartServiceWorker("name record delayed cache writer", 0, recordDatabase.DelayedCacheWriter) + module.StartServiceWorker("ip info delayed cache writer", 0, ipInfoDatabase.DelayedCacheWriter) + return nil } diff --git a/resolver/namerecord.go b/resolver/namerecord.go index 45deae30..d0faf14c 100644 --- a/resolver/namerecord.go +++ b/resolver/namerecord.go @@ -12,10 +12,24 @@ import ( "github.com/safing/portbase/log" ) +const ( + // databaseOvertime defines how much longer than the TTL name records are + // cached in the database. + databaseOvertime = 86400 * 14 // two weeks +) + var ( recordDatabase = database.NewInterface(&database.Options{ - AlwaysSetRelativateExpiry: 2592000, // 30 days - CacheSize: 256, + Local: true, + Internal: true, + + // Cache entries because application often resolve domains multiple times. + CacheSize: 256, + + // We only use the cache database here, so we can delay and batch all our + // writes. Also, no one else accesses these records, so we are fine using + // this. + DelayCachedWrites: "cache", }) nameRecordsKeyPrefix = "cache:intel/nameRecord/" @@ -32,7 +46,7 @@ type NameRecord struct { Answer []string Ns []string Extra []string - TTL int64 + Expires int64 Server string ServerScope int8 @@ -84,6 +98,9 @@ func (rec *NameRecord) Save() error { } rec.SetKey(makeNameRecordKey(rec.Domain, rec.Question)) + rec.UpdateMeta() + rec.Meta().SetAbsoluteExpiry(rec.Expires + databaseOvertime) + return recordDatabase.PutNew(rec) } diff --git a/resolver/resolve.go b/resolver/resolve.go index b917968a..22f2fe23 100644 --- a/resolver/resolve.go +++ b/resolver/resolve.go @@ -220,19 +220,19 @@ func checkCache(ctx context.Context, q *Query) *RRCache { log.Tracer(ctx).Tracef( "resolver: cache for %s will expire in %s, refreshing async now", q.ID(), - time.Until(time.Unix(rrCache.TTL, 0)).Round(time.Second), + time.Until(time.Unix(rrCache.Expires, 0)).Round(time.Second), ) // resolve async - module.StartWorker("resolve async", func(ctx context.Context) error { - ctx, tracer := log.AddTracer(ctx) + module.StartWorker("resolve async", func(asyncCtx context.Context) error { + tracingCtx, tracer := log.AddTracer(asyncCtx) defer tracer.Submit() - tracer.Debugf("resolver: resolving %s async", q.ID()) - _, err := resolveAndCache(ctx, q, nil) + tracer.Tracef("resolver: resolving %s async", q.ID()) + _, err := resolveAndCache(tracingCtx, q, nil) if err != nil { tracer.Warningf("resolver: async query for %s failed: %s", q.ID(), err) } else { - tracer.Debugf("resolver: async query for %s succeeded", q.ID()) + tracer.Infof("resolver: async query for %s succeeded", q.ID()) } return nil }) @@ -242,7 +242,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache { log.Tracer(ctx).Tracef( "resolver: using cached RR (expires in %s)", - time.Until(time.Unix(rrCache.TTL, 0)).Round(time.Second), + time.Until(time.Unix(rrCache.Expires, 0)).Round(time.Second), ) return rrCache } diff --git a/resolver/rrcache.go b/resolver/rrcache.go index 1eee7b4a..6be85659 100644 --- a/resolver/rrcache.go +++ b/resolver/rrcache.go @@ -25,10 +25,10 @@ type RRCache struct { RCode int // Response Content - Answer []dns.RR - Ns []dns.RR - Extra []dns.RR - TTL int64 + Answer []dns.RR + Ns []dns.RR + Extra []dns.RR + Expires int64 // Source Information Server string @@ -54,12 +54,12 @@ func (rrCache *RRCache) ID() string { // Expired returns whether the record has expired. func (rrCache *RRCache) Expired() bool { - return rrCache.TTL <= time.Now().Unix() + return rrCache.Expires <= time.Now().Unix() } // ExpiresSoon returns whether the record will expire soon and should already be refreshed. func (rrCache *RRCache) ExpiresSoon() bool { - return rrCache.TTL <= time.Now().Unix()+refreshTTL + return rrCache.Expires <= time.Now().Unix()+refreshTTL } // Clean sets all TTLs to 17 and sets cache expiry with specified minimum. @@ -99,7 +99,7 @@ func (rrCache *RRCache) Clean(minExpires uint32) { } // log.Tracef("lowest TTL is %d", lowestTTL) - rrCache.TTL = time.Now().Unix() + int64(lowestTTL) + rrCache.Expires = time.Now().Unix() + int64(lowestTTL) } // ExportAllARecords return of a list of all A and AAAA IP addresses. @@ -131,7 +131,7 @@ func (rrCache *RRCache) ToNameRecord() *NameRecord { Domain: rrCache.Domain, Question: rrCache.Question.String(), RCode: rrCache.RCode, - TTL: rrCache.TTL, + Expires: rrCache.Expires, Server: rrCache.Server, ServerScope: rrCache.ServerScope, ServerInfo: rrCache.ServerInfo, @@ -188,7 +188,7 @@ func GetRRCache(domain string, question dns.Type) (*RRCache, error) { } rrCache.RCode = nameRecord.RCode - rrCache.TTL = nameRecord.TTL + rrCache.Expires = nameRecord.Expires for _, entry := range nameRecord.Answer { rrCache.Answer = parseRR(rrCache.Answer, entry) } @@ -249,10 +249,10 @@ func (rrCache *RRCache) ShallowCopy() *RRCache { Question: rrCache.Question, RCode: rrCache.RCode, - Answer: rrCache.Answer, - Ns: rrCache.Ns, - Extra: rrCache.Extra, - TTL: rrCache.TTL, + Answer: rrCache.Answer, + Ns: rrCache.Ns, + Extra: rrCache.Extra, + Expires: rrCache.Expires, Server: rrCache.Server, ServerScope: rrCache.ServerScope, @@ -310,9 +310,9 @@ func (rrCache *RRCache) GetExtraRRs(ctx context.Context, query *dns.Msg) (extra // Add expiry and cache information. if rrCache.Expired() { - extra = addExtra(ctx, extra, fmt.Sprintf("record expired since %s", time.Since(time.Unix(rrCache.TTL, 0)).Round(time.Second))) + extra = addExtra(ctx, extra, fmt.Sprintf("record expired since %s", time.Since(time.Unix(rrCache.Expires, 0)).Round(time.Second))) } else { - extra = addExtra(ctx, extra, fmt.Sprintf("record valid for %s", time.Until(time.Unix(rrCache.TTL, 0)).Round(time.Second))) + extra = addExtra(ctx, extra, fmt.Sprintf("record valid for %s", time.Until(time.Unix(rrCache.Expires, 0)).Round(time.Second))) } if rrCache.RequestingNew { extra = addExtra(ctx, extra, "async request to refresh the cache has been started")