diff --git a/firewall/dns.go b/firewall/dns.go index 32f63b70..953cfebe 100644 --- a/firewall/dns.go +++ b/firewall/dns.go @@ -16,7 +16,7 @@ import ( "github.com/safing/portmaster/resolver" ) -func filterDNSSection(entries []dns.RR, p *profile.LayeredProfile, scope int8) ([]dns.RR, []string, int, string) { +func filterDNSSection(entries []dns.RR, p *profile.LayeredProfile, resolverScope netutils.IPScope, sysResolver bool) ([]dns.RR, []string, int, string) { goodEntries := make([]dns.RR, 0, len(entries)) filteredRecords := make([]string, 0, len(entries)) @@ -38,16 +38,16 @@ func filterDNSSection(entries []dns.RR, p *profile.LayeredProfile, scope int8) ( goodEntries = append(goodEntries, rr) continue } - classification := netutils.ClassifyIP(ip) + ipScope := netutils.GetIPScope(ip) if p.RemoveOutOfScopeDNS() { switch { - case classification == netutils.HostLocal: + case ipScope.IsLocalhost(): // No DNS should return localhost addresses filteredRecords = append(filteredRecords, rr.String()) interveningOptionKey = profile.CfgOptionRemoveOutOfScopeDNSKey continue - case scope == netutils.Global && (classification == netutils.SiteLocal || classification == netutils.LinkLocal): + case resolverScope.IsGlobal() && ipScope.IsLAN() && !sysResolver: // No global DNS should return LAN addresses filteredRecords = append(filteredRecords, rr.String()) interveningOptionKey = profile.CfgOptionRemoveOutOfScopeDNSKey @@ -55,18 +55,18 @@ func filterDNSSection(entries []dns.RR, p *profile.LayeredProfile, scope int8) ( } } - if p.RemoveBlockedDNS() { + if p.RemoveBlockedDNS() && !sysResolver { // filter by flags switch { - case p.BlockScopeInternet() && classification == netutils.Global: + case p.BlockScopeInternet() && ipScope.IsGlobal(): filteredRecords = append(filteredRecords, rr.String()) interveningOptionKey = profile.CfgOptionBlockScopeInternetKey continue - case p.BlockScopeLAN() && (classification == netutils.SiteLocal || classification == netutils.LinkLocal): + case p.BlockScopeLAN() && ipScope.IsLAN(): filteredRecords = append(filteredRecords, rr.String()) interveningOptionKey = profile.CfgOptionBlockScopeLANKey continue - case p.BlockScopeLocal() && classification == netutils.HostLocal: + case p.BlockScopeLocal() && ipScope.IsLocalhost(): filteredRecords = append(filteredRecords, rr.String()) interveningOptionKey = profile.CfgOptionBlockScopeLocalKey continue @@ -83,7 +83,7 @@ func filterDNSSection(entries []dns.RR, p *profile.LayeredProfile, scope int8) ( return goodEntries, filteredRecords, allowedAddressRecords, interveningOptionKey } -func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache) *resolver.RRCache { +func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache, sysResolver bool) *resolver.RRCache { p := conn.Process().Profile() // do not modify own queries @@ -104,11 +104,11 @@ func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache) *res var validIPs int var interveningOptionKey string - rrCache.Answer, filteredRecords, validIPs, interveningOptionKey = filterDNSSection(rrCache.Answer, p, rrCache.ServerScope) + rrCache.Answer, filteredRecords, validIPs, interveningOptionKey = filterDNSSection(rrCache.Answer, p, rrCache.Resolver.IPScope, sysResolver) rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...) // we don't count the valid IPs in the extra section - rrCache.Extra, filteredRecords, _, _ = filterDNSSection(rrCache.Extra, p, rrCache.ServerScope) + rrCache.Extra, filteredRecords, _, _ = filterDNSSection(rrCache.Extra, p, rrCache.Resolver.IPScope, sysResolver) rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...) if len(rrCache.FilteredEntries) > 0 { @@ -160,8 +160,9 @@ func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache) *res return rrCache } -// DecideOnResolvedDNS filters a dns response according to the application profile and settings. -func DecideOnResolvedDNS( +// FilterResolvedDNS filters a dns response according to the application +// profile and settings. +func FilterResolvedDNS( ctx context.Context, conn *network.Connection, q *resolver.Query, @@ -174,14 +175,15 @@ func DecideOnResolvedDNS( return rrCache } - updatedRR := filterDNSResponse(conn, rrCache) + // Only filter criticial things if request comes from the system resolver. + sysResolver := conn.Process().IsSystemResolver() + + updatedRR := filterDNSResponse(conn, rrCache, sysResolver) if updatedRR == nil { return nil } - updateIPsAndCNAMEs(q, rrCache, conn) - - if mayBlockCNAMEs(ctx, conn) { + if !sysResolver && mayBlockCNAMEs(ctx, conn) { return nil } @@ -213,14 +215,23 @@ func mayBlockCNAMEs(ctx context.Context, conn *network.Connection) bool { return false } -// updateIPsAndCNAMEs saves all the IP->Name mappings to the cache database and +// UpdateIPsAndCNAMEs saves all the IP->Name mappings to the cache database and // updates the CNAMEs in the Connection's Entity. -func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) { +func UpdateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) { + // Sanity check input, as this is called from defer. + if q == nil || rrCache == nil { + return + } + // Get profileID for scoping IPInfo. var profileID string - proc := conn.Process() - if proc != nil { - profileID = proc.LocalProfileKey + localProfile := conn.Process().Profile().LocalProfile() + switch localProfile.ID { + case profile.UnidentifiedProfileID, + profile.SystemResolverProfileID: + profileID = resolver.IPInfoProfileScopeGlobal + default: + profileID = localProfile.ID } // Collect IPs and CNAMEs. @@ -249,8 +260,9 @@ func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *netw // Create new record for this IP. record := resolver.ResolvedDomain{ - Domain: q.FQDN, - Expires: rrCache.Expires, + Domain: q.FQDN, + Expires: rrCache.Expires, + Resolver: rrCache.Resolver, } // Resolve all CNAMEs in the correct order and add the to the record. diff --git a/firewall/interception.go b/firewall/interception.go index a342b13a..30e16b7e 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -286,7 +286,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) { // TODO: add implementation for forced tunneling if pkt.IsOutbound() && captain.ClientReady() && - netutils.IPIsGlobal(conn.Entity.IP) && + conn.Entity.IPScope.IsGlobal() && conn.Verdict == network.VerdictAccept { // try to tunnel err := sluice.AwaitRequest(pkt.Info(), conn.Entity.Domain) diff --git a/firewall/interception/interception_default.go b/firewall/interception/interception_default.go index 35e07f17..0ba5e1c3 100644 --- a/firewall/interception/interception_default.go +++ b/firewall/interception/interception_default.go @@ -9,7 +9,7 @@ import ( // start starts the interception. func start(_ chan packet.Packet) error { - log.Info("interception: this platform has no support for packet interception - a lot of functionality will be broken") + log.Critical("interception: this platform has no support for packet interception - a lot of functionality will be broken") return nil } diff --git a/firewall/interception/interception_windows.go b/firewall/interception/interception_windows.go index 62f61930..a123d2f6 100644 --- a/firewall/interception/interception_windows.go +++ b/firewall/interception/interception_windows.go @@ -33,7 +33,7 @@ func start(ch chan packet.Packet) error { } go windowskext.Handler(ch) - go handleWindowsDNSCache() + go checkWindowsDNSCache() return nil } @@ -43,37 +43,21 @@ func stop() error { return windowskext.Stop() } -func handleWindowsDNSCache() { - - err := osdetail.StopService("dnscache") +func checkWindowsDNSCache() { + status, err := osdetail.GetServiceStatus("dnscache") if err != nil { - // cannot stop dnscache, try disabling - if err == osdetail.ErrServiceNotStoppable { - err := osdetail.DisableDNSCache() - if err != nil { - log.Warningf("firewall/interception: failed to disable Windows Service \"DNS Client\" (dnscache) for better interception: %s", err) - notifyDisableDNSCache() - } - notifyRebootRequired() - return - } - - // error while stopping service - log.Warningf("firewall/interception: failed to stop Windows Service \"DNS Client\" (dnscache) for better interception: %s", err) - notifyDisableDNSCache() + log.Warningf("firewall/interception: failed to check status of Windows DNS-Client: %s", err) } - // log that service is stopped - log.Info("firewall/interception: Windows Service \"DNS Client\" (dnscache) is stopped for better interception") - -} - -func notifyDisableDNSCache() { - (¬ifications.Notification{ - EventID: "interception:windows-disable-dns-cache", - Message: "The Portmaster needs the Windows Service \"DNS Client\" (dnscache) to be disabled for best effectiveness.", - Type: notifications.Warning, - }).Save() + if status == osdetail.StatusStopped { + err := osdetail.EnableDNSCache() + if err != nil { + log.Warningf("firewall/interception: failed to enable Windows Service \"DNS Client\" (dnscache): %s", err) + } else { + log.Warningf("firewall/interception: successfully enabled the dnscache") + notifyRebootRequired() + } + } } func notifyRebootRequired() { diff --git a/firewall/master.go b/firewall/master.go index 6e92b7af..efe515d9 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -39,7 +39,7 @@ const noReasonOptionKey = "" type deciderFn func(context.Context, *network.Connection, packet.Packet) bool -var deciders = []deciderFn{ +var defaultDeciders = []deciderFn{ checkPortmasterConnection, checkSelfCommunication, checkConnectionType, @@ -53,6 +53,11 @@ var deciders = []deciderFn{ checkAutoPermitRelated, } +var dnsFromSystemResolverDeciders = []deciderFn{ + checkConnectivityDomain, + checkBypassPrevention, +} + // DecideOnConnection makes a decision about a connection. // When called, the connection and profile is already locked. func DecideOnConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) { @@ -79,8 +84,21 @@ func DecideOnConnection(ctx context.Context, conn *network.Connection, pkt packe } } + // DNS request from the system resolver require a special decision process, + // because the original requesting process is not known. Here, we only check + // global-only and the most important per-app aspects. The resulting + // connection is then blocked when the original requesting process is known. + if conn.Type == network.DNSRequest && conn.Process().IsSystemResolver() { + // Run all deciders and return if they came to a conclusion. + done, _ := runDeciders(ctx, dnsFromSystemResolverDeciders, conn, pkt) + if !done { + conn.Accept("permitting system resolver dns request", noReasonOptionKey) + } + return + } + // Run all deciders and return if they came to a conclusion. - done, defaultAction := runDeciders(ctx, conn, pkt) + done, defaultAction := runDeciders(ctx, defaultDeciders, conn, pkt) if done { return } @@ -96,7 +114,7 @@ func DecideOnConnection(ctx context.Context, conn *network.Connection, pkt packe } } -func runDeciders(ctx context.Context, conn *network.Connection, pkt packet.Packet) (done bool, defaultAction uint8) { +func runDeciders(ctx context.Context, selectedDeciders []deciderFn, conn *network.Connection, pkt packet.Packet) (done bool, defaultAction uint8) { layeredProfile := conn.Process().Profile() // Read-lock the all the profiles. @@ -104,7 +122,7 @@ func runDeciders(ctx context.Context, conn *network.Connection, pkt packet.Packe defer layeredProfile.UnlockForUsage() // Go though all deciders, return if one sets an action. - for _, decider := range deciders { + for _, decider := range selectedDeciders { if decider(ctx, conn, pkt) { return true, profile.DefaultActionNotSet } @@ -248,39 +266,58 @@ func checkConnectivityDomain(_ context.Context, conn *network.Connection, _ pack func checkConnectionScope(_ context.Context, conn *network.Connection, _ packet.Packet) bool { p := conn.Process().Profile() - // check scopes - if conn.Entity.IP != nil { - classification := netutils.ClassifyIP(conn.Entity.IP) - - switch classification { - case netutils.Global, netutils.GlobalMulticast: - if p.BlockScopeInternet() { - conn.Deny("Internet access blocked", profile.CfgOptionBlockScopeInternetKey) // Block Outbound / Drop Inbound - return true - } - case netutils.SiteLocal, netutils.LinkLocal, netutils.LocalMulticast: - if p.BlockScopeLAN() { - conn.Block("LAN access blocked", profile.CfgOptionBlockScopeLANKey) // Block Outbound / Drop Inbound - return true - } - case netutils.HostLocal: - if p.BlockScopeLocal() { - conn.Block("Localhost access blocked", profile.CfgOptionBlockScopeLocalKey) // Block Outbound / Drop Inbound - return true - } - default: // netutils.Invalid - conn.Deny("invalid IP", noReasonOptionKey) // Block Outbound / Drop Inbound - return true - } - } else if conn.Entity.Domain != "" { - // This is a DNS Request. + // If we are handling a DNS request, check if we can immediately block it. + if conn.Type == network.DNSRequest { // DNS is expected to resolve to LAN or Internet addresses. // Localhost queries are immediately responded to by the nameserver. if p.BlockScopeInternet() && p.BlockScopeLAN() { conn.Block("Internet and LAN access blocked", profile.CfgOptionBlockScopeInternetKey) return true } + + return false } + + // Check if the network scope is permitted. + switch conn.Entity.IPScope { + case netutils.Global, netutils.GlobalMulticast: + if p.BlockScopeInternet() { + conn.Deny("Internet access blocked", profile.CfgOptionBlockScopeInternetKey) // Block Outbound / Drop Inbound + return true + } + case netutils.SiteLocal, netutils.LinkLocal, netutils.LocalMulticast: + if p.BlockScopeLAN() { + conn.Block("LAN access blocked", profile.CfgOptionBlockScopeLANKey) // Block Outbound / Drop Inbound + return true + } + case netutils.HostLocal: + if p.BlockScopeLocal() { + conn.Block("Localhost access blocked", profile.CfgOptionBlockScopeLocalKey) // Block Outbound / Drop Inbound + return true + } + default: // netutils.Unknown and netutils.Invalid + conn.Deny("invalid IP", noReasonOptionKey) // Block Outbound / Drop Inbound + return true + } + + // If the IP address was resolved, check the scope of the resolver. + switch { + case p.RemoveOutOfScopeDNS(): + // Out of scope checking is not active. + case conn.Resolver == nil: + // IP address of connection was not resolved. + case conn.Resolver.IPScope.IsGlobal() && + (conn.Entity.IPScope.IsLAN() || conn.Entity.IPScope.IsLocalhost()): + // Block global resolvers from returning LAN/Localhost IPs. + conn.Block("DNS server horizon violation: global DNS server returned local IP address", profile.CfgOptionRemoveOutOfScopeDNSKey) + return true + case conn.Resolver.IPScope.IsLAN() && + conn.Entity.IPScope.IsLocalhost(): + // Block LAN resolvers from returning Localhost IPs. + conn.Block("DNS server horizon violation: LAN DNS server returned localhost IP address", profile.CfgOptionRemoveOutOfScopeDNSKey) + return true + } + return false } diff --git a/intel/entity.go b/intel/entity.go index c98126f0..d03cc3b6 100644 --- a/intel/entity.go +++ b/intel/entity.go @@ -58,6 +58,9 @@ type Entity struct { // set, IP has been resolved by following all CNAMEs. IP net.IP + // IPScope holds the network scope of the IP. + IPScope netutils.IPScope + // Country holds the country the IP address (ASN) is // located in. Country string @@ -65,6 +68,9 @@ type Entity struct { // ASN holds the autonomous system number of the IP. ASN uint + // ASOrg holds the owner's name of the autonomous system. + ASOrg string + location *geoip.Location // BlockedByLists holds list source IDs that @@ -95,6 +101,12 @@ func (e *Entity) Init() *Entity { return e } +// SetIP sets the IP address together with its network scope. +func (e *Entity) SetIP(ip net.IP) { + e.IP = ip + e.IPScope = netutils.GetIPScope(ip) +} + // SetDstPort sets the destination port. func (e *Entity) SetDstPort(dstPort uint16) { e.dstPort = dstPort @@ -229,6 +241,7 @@ func (e *Entity) getLocation(ctx context.Context) { e.location = loc e.Country = loc.Country.ISOCode e.ASN = loc.AutonomousSystemNumber + e.ASOrg = loc.AutonomousSystemOrganization }) } @@ -422,7 +435,7 @@ func (e *Entity) getIPLists(ctx context.Context) { } // only load lists for IP addresses that are classified as global. - if netutils.ClassifyIP(ip) != netutils.Global { + if !e.IPScope.IsGlobal() { return } diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 3a476f8a..6abf26bf 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -3,6 +3,7 @@ package nameserver import ( "context" "errors" + "fmt" "net" "strings" "time" @@ -57,6 +58,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) log.Warningf("nameserver: failed to get remote address of request for %s%s, ignoring", q.FQDN, q.QType) return nil } + // log.Errorf("DEBUG: nameserver: handling new request for %s from %s:%d", q.ID(), remoteAddr.IP, remoteAddr.Port) // Start context tracer for context-aware logging. ctx, tracer := log.AddTracer(ctx) @@ -100,18 +102,27 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) // Authenticate request - only requests from the local host, but with any of its IPs, are allowed. local, err := netenv.IsMyIP(remoteAddr.IP) if err != nil { - tracer.Warningf("nameserver: failed to check if request for %s%s is local: %s", q.FQDN, q.QType, err) + tracer.Warningf("nameserver: failed to check if request for %s is local: %s", q.ID(), err) return nil // Do no reply, drop request immediately. } + // Create connection ID for dns request. + connID := fmt.Sprintf( + "%s-%d-#%d-%s", + remoteAddr.IP, + remoteAddr.Port, + request.Id, + q.ID(), + ) + // Get connection for this request. This identifies the process behind the request. var conn *network.Connection switch { case local: - conn = network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP, uint16(remoteAddr.Port)) + conn = network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, connID, remoteAddr.IP, uint16(remoteAddr.Port)) case networkServiceMode(): - conn, err = network.NewConnectionFromExternalDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP) + conn, err = network.NewConnectionFromExternalDNSRequest(ctx, q.FQDN, nil, connID, remoteAddr.IP) if err != nil { tracer.Warningf("nameserver: failed to get host/profile for request for %s%s: %s", q.FQDN, q.QType, err) return nil // Do no reply, drop request immediately. @@ -124,20 +135,24 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) conn.Lock() defer conn.Unlock() + // Create reference for the rrCache. + var rrCache *resolver.RRCache + // Once we decided on the connection we might need to save it to the database, // so we defer that check for now. defer func() { switch conn.Verdict { // We immediately save blocked, dropped or failed verdicts so // they pop up in the UI. - case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed: + case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel: conn.Save() // For undecided or accepted connections we don't save them yet, because // that will happen later anyway. - case network.VerdictUndecided, network.VerdictAccept, - network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel: - return + case network.VerdictUndecided, network.VerdictAccept: + // Save the request as open, as we don't know if there will be a connection or not. + network.SaveOpenDNSRequest(conn) + firewall.UpdateIPsAndCNAMEs(q, rrCache, conn) default: tracer.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.Verdict, conn) @@ -153,17 +168,19 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) // IP address in which case we "accept" it, but let the firewall handle // the resolving as it wishes. if responder, ok := conn.Reason.Context.(nsutil.Responder); ok { - // Save the request as open, as we don't know if there will be a connection or not. - network.SaveOpenDNSRequest(conn) - tracer.Infof("nameserver: handing over request for %s to special filter responder: %s", q.ID(), conn.Reason.Msg) return reply(responder) } - // Check if there is Verdict to act upon. + // Check if there is a Verdict to act upon. switch conn.Verdict { case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed: - tracer.Infof("nameserver: request for %s from %s %s", q.ID(), conn.Process(), conn.Verdict.Verb()) + tracer.Infof( + "nameserver: returning %s response for %s to %s", + conn.Verdict.Verb(), + q.ID(), + conn.Process(), + ) return reply(conn, conn) } @@ -171,7 +188,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) q.SecurityLevel = conn.Process().Profile().SecurityLevel() // Resolve request. - rrCache, err := resolver.Resolve(ctx, q) + rrCache, err = resolver.Resolve(ctx, q) if err != nil { // React to special errors. switch { @@ -203,13 +220,10 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) } tracer.Trace("nameserver: deciding on resolved dns") - rrCache = firewall.DecideOnResolvedDNS(ctx, conn, q, rrCache) + rrCache = firewall.FilterResolvedDNS(ctx, conn, q, rrCache) if rrCache == nil { // Check again if there is a responder from the firewall. if responder, ok := conn.Reason.Context.(nsutil.Responder); ok { - // Save the request as open, as we don't know if there will be a connection or not. - network.SaveOpenDNSRequest(conn) - tracer.Infof("nameserver: handing over request for %s to filter responder: %s", q.ID(), conn.Reason.Msg) return reply(responder) } @@ -227,9 +241,6 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) } } - // Save dns request as open. - defer network.SaveOpenDNSRequest(conn) - // Revert back to non-standard question format, if we had to convert. if nonStandardQuestionFormat { rrCache.ReplaceAnswerNames(originalQuestion.Name) diff --git a/netenv/addresses.go b/netenv/addresses.go index 48e9d380..f0edb84b 100644 --- a/netenv/addresses.go +++ b/netenv/addresses.go @@ -38,12 +38,12 @@ func GetAssignedGlobalAddresses() (ipv4 []net.IP, ipv6 []net.IP, err error) { return nil, nil, err } for _, ip4 := range allv4 { - if netutils.IPIsGlobal(ip4) { + if netutils.GetIPScope(ip4).IsGlobal() { ipv4 = append(ipv4, ip4) } } for _, ip6 := range allv6 { - if netutils.IPIsGlobal(ip6) { + if netutils.GetIPScope(ip6).IsGlobal() { ipv6 = append(ipv6, ip6) } } @@ -59,7 +59,7 @@ var ( // Broadcast or multicast addresses will never match, even if valid in in use. func IsMyIP(ip net.IP) (yes bool, err error) { // Check for IPs that don't need extra checks. - switch netutils.ClassifyIP(ip) { + switch netutils.GetIPScope(ip) { case netutils.HostLocal: return true, nil case netutils.LocalMulticast, netutils.GlobalMulticast: diff --git a/netenv/location.go b/netenv/location.go index 8fc67339..5d38bbee 100644 --- a/netenv/location.go +++ b/netenv/location.go @@ -130,7 +130,7 @@ next: } // If we received something from a global IP address, we have succeeded and can return immediately. - if netutils.IPIsGlobal(addr.IP) { + if netutils.GetIPScope(addr.IP).IsGlobal() { return addr.IP, nil } diff --git a/netenv/online-status.go b/netenv/online-status.go index 44d5a0ff..cbc5f26f 100644 --- a/netenv/online-status.go +++ b/netenv/online-status.go @@ -356,7 +356,7 @@ func checkOnlineStatus(ctx context.Context) { } else { var lan bool for _, ip := range ipv4 { - switch netutils.ClassifyIP(ip) { + switch netutils.GetIPScope(ip) { case netutils.SiteLocal: lan = true case netutils.Global: @@ -366,7 +366,7 @@ func checkOnlineStatus(ctx context.Context) { } } for _, ip := range ipv6 { - switch netutils.ClassifyIP(ip) { + switch netutils.GetIPScope(ip) { case netutils.SiteLocal, netutils.Global: // IPv6 global addresses are also used in local networks lan = true diff --git a/network/api.go b/network/api.go index f9b2c66c..659d77b3 100644 --- a/network/api.go +++ b/network/api.go @@ -11,6 +11,8 @@ import ( "github.com/safing/portbase/api" "github.com/safing/portbase/database/query" "github.com/safing/portbase/utils/debug" + "github.com/safing/portmaster/network/state" + "github.com/safing/portmaster/process" "github.com/safing/portmaster/status" ) @@ -45,6 +47,18 @@ func registerAPIEndpoints() error { return err } + if err := api.RegisterEndpoint(api.Endpoint{ + Path: "debug/network/state", + Read: api.PermitUser, + StructFunc: func(ar *api.Request) (i interface{}, err error) { + return state.GetInfo(), nil + }, + Name: "Get Network State Table Data", + Description: "Returns the current network state tables from the OS.", + }); err != nil { + return err + } + return nil } @@ -156,28 +170,30 @@ func AddNetworkDebugData(di *debug.Info, profile, where string) { func buildNetworkDebugInfoData(debugConns []*Connection) string { // Sort - sort.Sort(connectionsByStarted(debugConns)) + sort.Sort(connectionsByGroup(debugConns)) // Format lines var buf strings.Builder - currentBinaryPath := "__" + currentPID := process.UndefinedProcessID for _, conn := range debugConns { conn.Lock() // Add process infomration if it differs from previous connection. - if currentBinaryPath != conn.ProcessContext.BinaryPath { - if currentBinaryPath != "__" { + if currentPID != conn.ProcessContext.PID { + if currentPID != process.UndefinedProcessID { buf.WriteString("\n\n\n") } - buf.WriteString("ProcessName: " + conn.ProcessContext.ProcessName) - buf.WriteString("\nProfileName: " + conn.ProcessContext.ProfileName) - buf.WriteString("\nBinaryPath: " + conn.ProcessContext.BinaryPath) + buf.WriteString("ProfileName: " + conn.ProcessContext.ProfileName) buf.WriteString("\nProfile: " + conn.ProcessContext.Profile) buf.WriteString("\nSource: " + conn.ProcessContext.Source) + buf.WriteString("\nProcessName: " + conn.ProcessContext.ProcessName) + buf.WriteString("\nBinaryPath: " + conn.ProcessContext.BinaryPath) + buf.WriteString("\nCmdLine: " + conn.ProcessContext.CmdLine) + buf.WriteString("\nPID: " + strconv.Itoa(conn.ProcessContext.PID)) buf.WriteString("\n") - // Set current path in order to not print the process information again. - currentBinaryPath = conn.ProcessContext.BinaryPath + // Set current PID in order to not print the process information again. + currentPID = conn.ProcessContext.PID } // Add connection. @@ -192,7 +208,7 @@ func buildNetworkDebugInfoData(debugConns []*Connection) string { func (conn *Connection) debugInfoLine() string { var connectionData string - if conn.ID != "" { + if conn.Type == IPConnection { // Format IP/Port pair for connections. connectionData = fmt.Sprintf( "% 15s:%- 5s %s % 15s:%- 5s", @@ -272,13 +288,28 @@ func (conn *Connection) fmtReasonProfileComponent() string { return conn.Reason.Profile } -type connectionsByStarted []*Connection +type connectionsByGroup []*Connection -func (a connectionsByStarted) Len() int { return len(a) } -func (a connectionsByStarted) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a connectionsByStarted) Less(i, j int) bool { +func (a connectionsByGroup) Len() int { return len(a) } +func (a connectionsByGroup) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a connectionsByGroup) Less(i, j int) bool { + // Sort by: + + // 1. Profile ID + if a[i].ProcessContext.Profile != a[j].ProcessContext.Profile { + return a[i].ProcessContext.Profile < a[j].ProcessContext.Profile + } + + // 2. Process Binary if a[i].ProcessContext.BinaryPath != a[j].ProcessContext.BinaryPath { return a[i].ProcessContext.BinaryPath < a[j].ProcessContext.BinaryPath } + + // 3. Process ID + if a[i].ProcessContext.PID != a[j].ProcessContext.PID { + return a[i].ProcessContext.PID < a[j].ProcessContext.PID + } + + // 4. Started return a[i].Started < a[j].Started } diff --git a/network/connection.go b/network/connection.go index 750de2a6..d9172150 100644 --- a/network/connection.go +++ b/network/connection.go @@ -29,10 +29,12 @@ type FirewallHandler func(conn *Connection, pkt packet.Packet) type ProcessContext struct { // ProcessName is the name of the process. ProcessName string - //ProfileName is the name of the profile. + // ProfileName is the name of the profile. ProfileName string // BinaryPath is the path to the process binary. BinaryPath string + // CmdLine holds the execution parameters. + CmdLine string // PID is the process identifier. PID int // Profile is the ID of the main profile that @@ -42,21 +44,37 @@ type ProcessContext struct { Source string } +type ConnectionType int8 + +const ( + Undefined ConnectionType = iota + IPConnection + DNSRequest + // ProxyRequest +) + // Connection describes a distinct physical network connection // identified by the IP/Port pair. type Connection struct { //nolint:maligned // TODO: fix alignment record.Base sync.Mutex - // ID may hold unique connection id. It is only set for non-DNS - // request connections and is considered immutable after a - // connection object has been created. + // ID holds a unique request/connection id and is considered immutable after + // creation. ID string + // Type defines the connection type. + Type ConnectionType + // External defines if the connection represents an external request or + // connection. + External bool // Scope defines the scope of a connection. For DNS requests, the // scope is always set to the domain name. For direct packet // connections the scope consists of the involved network environment // and the packet direction. Once a connection object is created, // Scope is considered immutable. + // Deprecated: This field holds duplicate information, which is accessible + // clearer through other attributes. Please use conn.Type, conn.Inbound + // and conn.Entity.Domain instead. Scope string // IPVersion is set to the packet IP version. It is not set (0) for // connections created from a DNS request. @@ -74,6 +92,8 @@ type Connection struct { //nolint:maligned // TODO: fix alignment // set for connections created from DNS requests. LocalIP is // considered immutable once a connection object has been created. LocalIP net.IP + // LocalIPScope holds the network scope of the local IP. + LocalIPScope netutils.IPScope // LocalPort holds the local port of the connection. It is not // set for connections created from DNS requests. LocalPort is // considered immutable once a connection object has been created. @@ -83,6 +103,9 @@ type Connection struct { //nolint:maligned // TODO: fix alignment // be added to it during the livetime of a connection. Access to // entity must be guarded by the connection lock. Entity *intel.Entity + // Resolver holds information about the resolver used to resolve + // Entity.Domain. + Resolver *resolver.ResolverInfo // Verdict is the final decision that has been made for a connection. // The verdict may change so any access to it must be guarded by the // connection lock. @@ -171,8 +194,9 @@ type Reason struct { func getProcessContext(ctx context.Context, proc *process.Process) ProcessContext { // Gather process information. pCtx := ProcessContext{ - BinaryPath: proc.Path, ProcessName: proc.Name, + BinaryPath: proc.Path, + CmdLine: proc.CmdLine, PID: proc.Pid, } @@ -191,7 +215,7 @@ func getProcessContext(ctx context.Context, proc *process.Process) ProcessContex } // NewConnectionFromDNSRequest returns a new connection based on the given dns request. -func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, localIP net.IP, localPort uint16) *Connection { +func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, connID string, localIP net.IP, localPort uint16) *Connection { // Determine IP version. ipVersion := packet.IPv6 if localIP.To4() != nil { @@ -218,6 +242,8 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri timestamp := time.Now().Unix() dnsConn := &Connection{ + ID: connID, + Type: DNSRequest, Scope: fqdn, Entity: &intel.Entity{ Domain: fqdn, @@ -234,10 +260,15 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri dnsConn.Internal = localProfile.Internal } + // Always mark dns queries from the system resolver as internal. + if proc.IsSystemResolver() { + dnsConn.Internal = true + } + return dnsConn } -func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cnames []string, remoteIP net.IP) (*Connection, error) { +func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cnames []string, connID string, remoteIP net.IP) (*Connection, error) { remoteHost, err := process.GetNetworkHost(ctx, remoteIP) if err != nil { return nil, err @@ -245,7 +276,10 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname timestamp := time.Now().Unix() dnsConn := &Connection{ - Scope: fqdn, + ID: connID, + Type: DNSRequest, + External: true, + Scope: fqdn, Entity: &intel.Entity{ Domain: fqdn, CNAME: cnames, @@ -275,11 +309,19 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { var scope string var entity *intel.Entity + var resolverInfo *resolver.ResolverInfo if inbound { // inbound connection - switch netutils.ClassifyIP(pkt.Info().Src) { + entity = &intel.Entity{ + Protocol: uint8(pkt.Info().Protocol), + Port: pkt.Info().SrcPort, + } + entity.SetIP(pkt.Info().Src) + entity.SetDstPort(pkt.Info().DstPort) + + switch entity.IPScope { case netutils.HostLocal: scope = IncomingHost case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: @@ -292,31 +334,30 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { default: scope = IncomingInvalid } - entity = &intel.Entity{ - IP: pkt.Info().Src, - Protocol: uint8(pkt.Info().Protocol), - Port: pkt.Info().SrcPort, - } - entity.SetDstPort(pkt.Info().DstPort) } else { // outbound connection entity = &intel.Entity{ - IP: pkt.Info().Dst, Protocol: uint8(pkt.Info().Protocol), Port: pkt.Info().DstPort, } + entity.SetIP(pkt.Info().Dst) entity.SetDstPort(entity.Port) // check if we can find a domain for that IP - ipinfo, err := resolver.GetIPInfo(proc.LocalProfileKey, pkt.Info().Dst.String()) + ipinfo, err := resolver.GetIPInfo(proc.Profile().LocalProfile().ID, pkt.Info().Dst.String()) + if err != nil { + // Try again with the global scope, in case DNS went through the system resolver. + ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().Dst.String()) + } if err == nil { lastResolvedDomain := ipinfo.MostRecentDomain() if lastResolvedDomain != nil { scope = lastResolvedDomain.Domain entity.Domain = lastResolvedDomain.Domain entity.CNAME = lastResolvedDomain.CNAMEs + resolverInfo = lastResolvedDomain.Resolver removeOpenDNSRequest(proc.Pid, lastResolvedDomain.Domain) } } @@ -331,7 +372,7 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { if scope == "" { // outbound direct (possibly P2P) connection - switch netutils.ClassifyIP(pkt.Info().Dst) { + switch entity.IPScope { case netutils.HostLocal: scope = PeerHost case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: @@ -351,21 +392,24 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { // Create new connection object. newConn := &Connection{ ID: pkt.GetConnectionID(), + Type: IPConnection, Scope: scope, IPVersion: pkt.Info().Version, Inbound: inbound, // local endpoint IPProtocol: pkt.Info().Protocol, - LocalIP: pkt.Info().LocalIP(), LocalPort: pkt.Info().LocalPort(), ProcessContext: getProcessContext(pkt.Ctx(), proc), process: proc, // remote endpoint Entity: entity, + // resolver used to resolve dns request + Resolver: resolverInfo, // meta Started: time.Now().Unix(), ProfileRevisionCounter: proc.Profile().RevisionCnt(), } + newConn.SetLocalIP(pkt.Info().LocalIP()) // Inherit internal status of profile. if localProfile := proc.Profile().LocalProfile(); localProfile != nil { @@ -380,6 +424,13 @@ func GetConnection(id string) (*Connection, bool) { return conns.get(id) } +// SetLocalIP sets the local IP address together with its network scope. The +// connection is not locked for this. +func (conn *Connection) SetLocalIP(ip net.IP) { + conn.LocalIP = ip + conn.LocalIPScope = netutils.GetIPScope(ip) +} + // AcceptWithContext accepts the connection. func (conn *Connection) AcceptWithContext(reason, reasonOptionKey string, ctx interface{}) { if !conn.SetVerdict(VerdictAccept, reason, reasonOptionKey, ctx) { @@ -477,14 +528,11 @@ func (conn *Connection) Save() { conn.UpdateMeta() if !conn.KeyIsSet() { - // A connection without an ID has been created from - // a DNS request rather than a packet. Choose the correct - // connection store here. - if conn.ID == "" { - conn.SetKey(fmt.Sprintf("network:tree/%d/%s", conn.process.Pid, conn.Scope)) + if conn.Type == DNSRequest { + conn.SetKey(makeKey(conn.process.Pid, "dns", conn.ID)) dnsConns.add(conn) } else { - conn.SetKey(fmt.Sprintf("network:tree/%d/%s/%s", conn.process.Pid, conn.Scope, conn.ID)) + conn.SetKey(makeKey(conn.process.Pid, "ip", conn.ID)) conns.add(conn) } } @@ -500,10 +548,10 @@ func (conn *Connection) delete() { // A connection without an ID has been created from // a DNS request rather than a packet. Choose the correct // connection store here. - if conn.ID == "" { - dnsConns.delete(conn) - } else { + if conn.Type == IPConnection { conns.delete(conn) + } else { + dnsConns.delete(conn) } conn.Meta().Delete() diff --git a/network/connection_store.go b/network/connection_store.go index 18b753ea..a3c633cb 100644 --- a/network/connection_store.go +++ b/network/connection_store.go @@ -1,7 +1,6 @@ package network import ( - "strconv" "sync" ) @@ -16,25 +15,18 @@ func newConnectionStore() *connectionStore { } } -func (cs *connectionStore) getID(conn *Connection) string { - if conn.ID != "" { - return conn.ID - } - return strconv.Itoa(conn.process.Pid) + "/" + conn.Scope -} - func (cs *connectionStore) add(conn *Connection) { cs.rw.Lock() defer cs.rw.Unlock() - cs.items[cs.getID(conn)] = conn + cs.items[conn.ID] = conn } func (cs *connectionStore) delete(conn *Connection) { cs.rw.Lock() defer cs.rw.Unlock() - delete(cs.items, cs.getID(conn)) + delete(cs.items, conn.ID) } func (cs *connectionStore) get(id string) (*Connection, bool) { diff --git a/network/database.go b/network/database.go index 3a13163b..6d9a337d 100644 --- a/network/database.go +++ b/network/database.go @@ -1,11 +1,10 @@ package network import ( + "fmt" "strconv" "strings" - "github.com/safing/portmaster/network/state" - "github.com/safing/portbase/database" "github.com/safing/portbase/database/iterator" "github.com/safing/portbase/database/query" @@ -27,37 +26,86 @@ type StorageInterface struct { storage.InjectBase } -// Get returns a database record. -func (s *StorageInterface) Get(key string) (record.Record, error) { +// Database prefixes: +// Processes: network:tree/ +// DNS Requests: network:tree//dns/ +// IP Connections: network:tree//ip/ - splitted := strings.Split(key, "/") - switch splitted[0] { //nolint:gocritic // TODO: implement full key space - case "tree": - switch len(splitted) { - case 2: - pid, err := strconv.Atoi(splitted[1]) - if err == nil { - proc, ok := process.GetProcessFromStorage(pid) - if ok { - return proc, nil - } - } - case 3: - if r, ok := dnsConns.get(splitted[1] + "/" + splitted[2]); ok { - return r, nil - } - case 4: - if r, ok := conns.get(splitted[3]); ok { - return r, nil +func makeKey(pid int, scope, id string) string { + if scope == "" { + return "network:tree/" + strconv.Itoa(pid) + } + return fmt.Sprintf("network:tree/%d/%s/%s", pid, scope, id) +} + +func parseDBKey(key string) (pid int, scope, id string, ok bool) { + // Split into segments. + segments := strings.Split(key, "/") + // Check for valid prefix. + if !strings.HasPrefix("tree", segments[0]) { + return 0, "", "", false + } + + // Keys have 2 or 4 segments. + switch len(segments) { + case 4: + id = segments[3] + + fallthrough + case 3: + scope = segments[2] + // Sanity check. + switch scope { + case "dns", "ip", "": + // Parsed id matches possible values. + // The empty string is for matching a trailing slash for in query prefix. + // TODO: For queries, also prefixes of these values are valid. + default: + // Unknown scope. + return 0, "", "", false + } + + fallthrough + case 2: + var err error + if segments[1] == "" { + pid = process.UndefinedProcessID + } else { + pid, err = strconv.Atoi(segments[1]) + if err != nil { + return 0, "", "", false } } - case "system": - if len(splitted) >= 2 { - switch splitted[1] { - case "state": - return state.GetInfo(), nil - default: - } + + return pid, scope, id, true + case 1: + // This is a valid query prefix, but not process ID was given. + return process.UndefinedProcessID, "", "", true + default: + return 0, "", "", false + } +} + +// Get returns a database record. +func (s *StorageInterface) Get(key string) (record.Record, error) { + // Parse key and check if valid. + pid, scope, id, ok := parseDBKey(strings.TrimPrefix(key, "network:")) + if !ok || pid == process.UndefinedProcessID { + return nil, storage.ErrNotFound + } + + switch scope { + case "dns": + if r, ok := dnsConns.get(id); ok { + return r, nil + } + case "ip": + if r, ok := conns.get(id); ok { + return r, nil + } + case "": + if proc, ok := process.GetProcessFromStorage(pid); ok { + return proc, nil } } @@ -74,9 +122,13 @@ func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterato } func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { - slashes := strings.Count(q.DatabaseKeyPrefix(), "/") + pid, scope, _, ok := parseDBKey(q.DatabaseKeyPrefix()) + if !ok { + it.Finish(nil) + return + } - if slashes <= 1 { + if pid == process.UndefinedProcessID { // processes for _, proc := range process.All() { proc.Lock() @@ -87,7 +139,7 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { } } - if slashes <= 2 { + if scope == "" || scope == "dns" { // dns scopes only for _, dnsConn := range dnsConns.clone() { dnsConn.Lock() @@ -98,7 +150,7 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { } } - if slashes <= 3 { + if scope == "" || scope == "ip" { // connections for _, conn := range conns.clone() { conn.Lock() diff --git a/network/metrics.go b/network/metrics.go index d8972bda..047ad19c 100644 --- a/network/metrics.go +++ b/network/metrics.go @@ -132,7 +132,7 @@ func (conn *Connection) addToMetrics() { } // Only count successful connections, not DNS requests. - if conn.ID == "" { + if conn.Type == DNSRequest { return } diff --git a/network/netutils/ip.go b/network/netutils/ip.go index 8e40447c..0b058087 100644 --- a/network/netutils/ip.go +++ b/network/netutils/ip.go @@ -2,19 +2,29 @@ package netutils import "net" -// IP classifications +// IPScope is the scope of the IP address. +type IPScope int8 + +// Defined IP Scopes. const ( - HostLocal int8 = iota + Invalid IPScope = iota - 1 + Undefined + HostLocal LinkLocal SiteLocal Global LocalMulticast GlobalMulticast - Invalid int8 = -1 ) -// ClassifyIP returns the classification for the given IP address. -func ClassifyIP(ip net.IP) int8 { //nolint:gocognit +// ClassifyIP returns the network scope of the given IP address. +// Deprecated: Please use the new GetIPScope instead. +func ClassifyIP(ip net.IP) IPScope { + return GetIPScope(ip) +} + +// GetIPScope returns the network scope of the given IP address. +func GetIPScope(ip net.IP) IPScope { //nolint:gocognit if ip4 := ip.To4(); ip4 != nil { // IPv4 switch { @@ -76,32 +86,27 @@ func ClassifyIP(ip net.IP) int8 { //nolint:gocognit return Invalid } -// IPIsLocalhost returns whether the IP refers to the host itself. -func IPIsLocalhost(ip net.IP) bool { - return ClassifyIP(ip) == HostLocal +// IsLocalhost returns whether the IP refers to the host itself. +func (scope IPScope) IsLocalhost() bool { + return scope == HostLocal } -// IPIsLAN returns true if the given IP is a site-local or link-local address. -func IPIsLAN(ip net.IP) bool { - switch ClassifyIP(ip) { - case SiteLocal, LinkLocal: +// IsLAN returns true if the scope is site-local or link-local. +func (scope IPScope) IsLAN() bool { + switch scope { + case SiteLocal, LinkLocal, LocalMulticast: return true default: return false } } -// IPIsGlobal returns true if the given IP is a global address. -func IPIsGlobal(ip net.IP) bool { - return ClassifyIP(ip) == Global -} - -// IPIsLinkLocal returns true if the given IP is a link-local address. -func IPIsLinkLocal(ip net.IP) bool { - return ClassifyIP(ip) == LinkLocal -} - -// IPIsSiteLocal returns true if the given IP is a site-local address. -func IPIsSiteLocal(ip net.IP) bool { - return ClassifyIP(ip) == SiteLocal +// IsGlobal returns true if the scope is global. +func (scope IPScope) IsGlobal() bool { + switch scope { + case Global, GlobalMulticast: + return true + default: + return false + } } diff --git a/network/netutils/ip_test.go b/network/netutils/ip_test.go index 1d40f301..02ef2051 100644 --- a/network/netutils/ip_test.go +++ b/network/netutils/ip_test.go @@ -5,26 +5,30 @@ import ( "testing" ) -func TestIPClassification(t *testing.T) { - testClassification(t, net.IPv4(71, 87, 113, 211), Global) - testClassification(t, net.IPv4(127, 0, 0, 1), HostLocal) - testClassification(t, net.IPv4(127, 255, 255, 1), HostLocal) - testClassification(t, net.IPv4(192, 168, 172, 24), SiteLocal) - testClassification(t, net.IPv4(172, 15, 1, 1), Global) - testClassification(t, net.IPv4(172, 16, 1, 1), SiteLocal) - testClassification(t, net.IPv4(172, 31, 1, 1), SiteLocal) - testClassification(t, net.IPv4(172, 32, 1, 1), Global) +func TestIPScope(t *testing.T) { + testScope(t, net.IPv4(71, 87, 113, 211), Global) + testScope(t, net.IPv4(127, 0, 0, 1), HostLocal) + testScope(t, net.IPv4(127, 255, 255, 1), HostLocal) + testScope(t, net.IPv4(192, 168, 172, 24), SiteLocal) + testScope(t, net.IPv4(172, 15, 1, 1), Global) + testScope(t, net.IPv4(172, 16, 1, 1), SiteLocal) + testScope(t, net.IPv4(172, 31, 1, 1), SiteLocal) + testScope(t, net.IPv4(172, 32, 1, 1), Global) } -func testClassification(t *testing.T, ip net.IP, expectedClassification int8) { - c := ClassifyIP(ip) - if c != expectedClassification { - t.Errorf("%s is %s, expected %s", ip, classificationString(c), classificationString(expectedClassification)) +func testScope(t *testing.T, ip net.IP, expectedScope IPScope) { + c := GetIPScope(ip) + if c != expectedScope { + t.Errorf("%s is %s, expected %s", ip, scopeName(c), scopeName(expectedScope)) } } -func classificationString(c int8) string { +func scopeName(c IPScope) string { switch c { + case Invalid: + return "invalid" + case Undefined: + return "undefined" case HostLocal: return "hostLocal" case LinkLocal: @@ -37,9 +41,7 @@ func classificationString(c int8) string { return "localMulticast" case GlobalMulticast: return "globalMulticast" - case Invalid: - return "invalid" default: - return "unknown" + return "undefined" } } diff --git a/network/state/lookup.go b/network/state/lookup.go index 1d2f11ad..fdd8c3d1 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -151,7 +151,7 @@ func (table *udpTable) lookup(pktInfo *packet.Info, fast bool) ( // attribute an incoming broadcast/multicast packet to the wrong process if // there are multiple processes listening on the same local port, but // binding to different addresses. This highly unusual for clients. - isInboundMulticast := pktInfo.Inbound && netutils.ClassifyIP(pktInfo.LocalIP()) == netutils.LocalMulticast + isInboundMulticast := pktInfo.Inbound && netutils.GetIPScope(pktInfo.LocalIP()) == netutils.LocalMulticast // Search for the socket until found. for i := 1; i <= lookupRetries; i++ { diff --git a/process/config.go b/process/config.go index 19f239ce..3e91aa55 100644 --- a/process/config.go +++ b/process/config.go @@ -17,7 +17,7 @@ func registerConfiguration() error { err := config.Register(&config.Option{ Name: "Process Detection", Key: CfgOptionEnableProcessDetectionKey, - Description: "This option enables the attribution of network traffic to processes. This should always be enabled, and effectively disables app settings if disabled.", + Description: "This option enables the attribution of network traffic to processes. Without it, app settings are effectively disabled.", OptType: config.OptTypeBool, ExpertiseLevel: config.ExpertiseLevelDeveloper, DefaultValue: true, diff --git a/process/process.go b/process/process.go index 619e832a..72b2c2b3 100644 --- a/process/process.go +++ b/process/process.go @@ -44,6 +44,10 @@ type Process struct { CmdLine string FirstArg string + // SpecialDetail holds special information, the meaning of which can change + // based on any of the previous attributes. + SpecialDetail string + LocalProfileKey string profile *profile.LayeredProfile @@ -65,6 +69,24 @@ func (p *Process) Profile() *profile.LayeredProfile { return p.profile } +// IsSystemResolver is a shortcut to check if the process is or belongs to the +// system resolver and needs special handling. +func (p *Process) IsSystemResolver() bool { + // Check if process exists. + if p == nil { + return false + } + + // Check if local profile exists. + localProfile := p.profile.LocalProfile() + if localProfile == nil { + return false + } + + // Check ID. + return localProfile.ID == profile.SystemResolverProfileID +} + // GetLastSeen returns the unix timestamp when the process was last seen. func (p *Process) GetLastSeen() int64 { p.Lock() diff --git a/process/process_windows.go b/process/process_windows.go index c202bcb9..c0f722c2 100644 --- a/process/process_windows.go +++ b/process/process_windows.go @@ -18,6 +18,7 @@ func (p *Process) specialOSInit() { switch err { case nil: p.Name += fmt.Sprintf(" (%s)", svcNames) + p.SpecialDetail = svcNames case osdetail.ErrServiceNotFound: log.Tracef("process: failed to get service name for svchost.exe (pid %d): %s", p.Pid, err) default: diff --git a/process/profile.go b/process/profile.go index 792c11f8..8f64b91e 100644 --- a/process/profile.go +++ b/process/profile.go @@ -3,6 +3,7 @@ package process import ( "context" "os" + "runtime" "strings" "github.com/safing/portbase/log" @@ -54,6 +55,22 @@ func (p *Process) GetProfile(ctx context.Context) (changed bool, err error) { // sure that we won't kill any of our own things. } } + // Check if this is the system resolver. + switch runtime.GOOS { + case "windows": + if (p.Path == `C:\Windows\System32\svchost.exe` || p.Path == `C:\Windows\system32\svchost.exe`) && + (strings.Contains(p.SpecialDetail, "Dnscache") || strings.Contains(p.CmdLine, "-k NetworkService")) { + profileID = profile.SystemResolverProfileID + } + case "linux": + switch p.Path { + case "/lib/systemd/systemd-resolved", + "/usr/lib/systemd/systemd-resolved", + "/lib64/systemd/systemd-resolved", + "/usr/lib64/systemd/systemd-resolved": + profileID = profile.SystemResolverProfileID + } + } } // Get the (linked) local profile. diff --git a/process/special.go b/process/special.go index 24e3aa8b..8ec47b4c 100644 --- a/process/special.go +++ b/process/special.go @@ -14,6 +14,10 @@ const ( // attributed to a PID for any reason. UnidentifiedProcessID = -1 + // UndefinedProcessID is not used by any (virtual) process and signifies that + // the PID is unset. + UndefinedProcessID = -2 + // NetworkHostProcessID is the PID used for requests served to the network. NetworkHostProcessID = -255 ) diff --git a/profile/config.go b/profile/config.go index 5a0d37bb..9b36062a 100644 --- a/profile/config.go +++ b/profile/config.go @@ -434,7 +434,7 @@ The lists are automatically updated every hour using incremental updates. err = config.Register(&config.Option{ Name: "Enforce Global/Private Split-View", Key: CfgOptionRemoveOutOfScopeDNSKey, - Description: "Reject private IP addresses (RFC1918 et al.) from public DNS responses.", + Description: "Reject private IP addresses (RFC1918 et al.) from public DNS responses. If the system resolver is in use, the resulting connection will be blocked instead of the DNS request.", OptType: config.OptTypeInt, ExpertiseLevel: config.ExpertiseLevelDeveloper, DefaultValue: status.SecurityLevelsAll, @@ -455,7 +455,7 @@ The lists are automatically updated every hour using incremental updates. err = config.Register(&config.Option{ Name: "Reject Blocked IPs", Key: CfgOptionRemoveBlockedDNSKey, - Description: "Reject blocked IP addresses directly from the DNS response instead of handing them over to the app and blocking a resulting connection.", + Description: "Reject blocked IP addresses directly from the DNS response instead of handing them over to the app and blocking a resulting connection. This settings does not affect privacy and only takes effect when the system resolver is not in use.", OptType: config.OptTypeInt, ExpertiseLevel: config.ExpertiseLevelDeveloper, DefaultValue: status.SecurityLevelsAll, @@ -491,6 +491,7 @@ The lists are automatically updated every hour using incremental updates. return err } cfgOptionDomainHeuristics = config.Concurrent.GetAsInt(CfgOptionDomainHeuristicsKey, int64(status.SecurityLevelsAll)) + cfgIntOptions[CfgOptionDomainHeuristicsKey] = cfgOptionDomainHeuristics // Bypass prevention err = config.Register(&config.Option{ @@ -499,7 +500,9 @@ The lists are automatically updated every hour using incremental updates. Description: `Prevent apps from bypassing the privacy filter. Current Features: - Disable Firefox' internal DNS-over-HTTPs resolver -- Block direct access to public DNS resolvers`, +- Block direct access to public DNS resolvers + +Please note that if you are using the system resolver, bypass attempts might be additionally blocked there too.`, OptType: config.OptTypeInt, ExpertiseLevel: config.ExpertiseLevelUser, ReleaseLevel: config.ReleaseLevelBeta, diff --git a/profile/endpoints/endpoint-scopes.go b/profile/endpoints/endpoint-scopes.go index 6f1c2f27..c6f05529 100644 --- a/profile/endpoints/endpoint-scopes.go +++ b/profile/endpoints/endpoint-scopes.go @@ -36,9 +36,8 @@ func (ep *EndpointScope) Matches(_ context.Context, entity *intel.Entity) (EPRes return Undeterminable, nil } - classification := netutils.ClassifyIP(entity.IP) var scope uint8 - switch classification { + switch entity.IPScope { case netutils.HostLocal: scope = scopeLocalhost case netutils.LinkLocal: diff --git a/profile/special.go b/profile/special.go index 754cdc67..252cd1c4 100644 --- a/profile/special.go +++ b/profile/special.go @@ -11,6 +11,11 @@ const ( // SystemProfileName is the name used for the system/kernel. SystemProfileName = "Operating System" + // SystemResolverProfileID is the profile ID used for the system's DNS resolver. + SystemResolverProfileID = "_system-resolver" + // SystemResolverProfileName is the name used for the system's DNS resolver. + SystemResolverProfileName = "System DNS Client" + // PortmasterProfileID is the profile ID used for the Portmaster Core itself. PortmasterProfileID = "_portmaster" // PortmasterProfileName is the name used for the Portmaster Core itself. @@ -35,6 +40,8 @@ func updateSpecialProfileMetadata(profile *Profile, binaryPath string) (ok, chan newProfileName = UnidentifiedProfileName case SystemProfileID: newProfileName = SystemProfileName + case SystemResolverProfileID: + newProfileName = SystemResolverProfileName case PortmasterProfileID: newProfileName = PortmasterProfileName case PortmasterAppProfileID: @@ -68,6 +75,9 @@ func getSpecialProfile(profileID, linkedPath string) *Profile { case SystemProfileID: return New(SourceLocal, SystemProfileID, linkedPath, nil) + case SystemResolverProfileID: + return New(SourceLocal, SystemResolverProfileID, linkedPath, nil) + case PortmasterProfileID: profile := New(SourceLocal, PortmasterProfileID, linkedPath, nil) profile.Internal = true diff --git a/resolver/ipinfo.go b/resolver/ipinfo.go index 7217e306..7c5f3d5f 100644 --- a/resolver/ipinfo.go +++ b/resolver/ipinfo.go @@ -40,6 +40,10 @@ type ResolvedDomain struct { // Domain. CNAMEs []string + // Resolver holds basic information about the resolver that provided this + // information. + Resolver *ResolverInfo + // Expires holds the timestamp when this entry expires. // This does not mean that the entry may not be used anymore afterwards, // but that this is used to calcuate the TTL of the database record. diff --git a/resolver/namerecord.go b/resolver/namerecord.go index 42a33323..68d6055f 100644 --- a/resolver/namerecord.go +++ b/resolver/namerecord.go @@ -49,9 +49,20 @@ type NameRecord struct { Extra []string Expires int64 - Server string - ServerScope int8 - ServerInfo string + Resolver *ResolverInfo +} + +// IsValid returns whether the NameRecord is valid and may be used. Otherwise, +// it should be disregarded. +func (nameRecord *NameRecord) IsValid() bool { + switch { + case nameRecord.Resolver == nil || nameRecord.Resolver.Type == "": + // Changed in v0.6.7: Introduced Resolver *ResolverInfo + return false + default: + // Up to date! + return true + } } func makeNameRecordKey(domain string, question string) string { @@ -67,7 +78,7 @@ func GetNameRecord(domain, question string) (*NameRecord, error) { return nil, err } - // unwrap + // Unwrap record if it's wrapped. if r.IsWrapped() { // only allocate a new struct, if we need it new := &NameRecord{} @@ -75,14 +86,24 @@ func GetNameRecord(domain, question string) (*NameRecord, error) { if err != nil { return nil, err } + // Check if the record is valid. + if !new.IsValid() { + return nil, errors.New("record is invalid (outdated format)") + } + return new, nil } - // or adjust type + // Or just adjust the type. new, ok := r.(*NameRecord) if !ok { return nil, fmt.Errorf("record not of type *NameRecord, but %T", r) } + // Check if the record is valid. + if !new.IsValid() { + return nil, errors.New("record is invalid (outdated format)") + } + return new, nil } diff --git a/resolver/resolve.go b/resolver/resolve.go index 11ed1590..31ceeb74 100644 --- a/resolver/resolve.go +++ b/resolver/resolve.go @@ -175,9 +175,9 @@ func checkCache(ctx context.Context, q *Query) *RRCache { } // Get the resolver that the rrCache was resolved with. - resolver := getActiveResolverByIDWithLocking(rrCache.Server) + resolver := getActiveResolverByIDWithLocking(rrCache.Resolver.ID()) if resolver == nil { - log.Tracer(ctx).Debugf("resolver: ignoring RRCache %s%s because source server %s has been removed", q.FQDN, q.QType.String(), rrCache.Server) + log.Tracer(ctx).Debugf("resolver: ignoring RRCache %s%s because source server %q has been removed", q.FQDN, q.QType.String(), rrCache.Resolver.ID()) return nil } @@ -361,11 +361,11 @@ resolveLoop: continue case errors.Is(err, ErrTimeout): resolver.Conn.ReportFailure() - log.Tracer(ctx).Debugf("resolver: query to %s timed out", resolver.GetName()) + log.Tracer(ctx).Debugf("resolver: query to %s timed out", resolver.Info.ID()) continue default: resolver.Conn.ReportFailure() - log.Tracer(ctx).Debugf("resolver: query to %s failed: %s", resolver.GetName(), err) + log.Tracer(ctx).Debugf("resolver: query to %s failed: %s", resolver.Info.ID(), err) continue } } diff --git a/resolver/resolver-env.go b/resolver/resolver-env.go index 2d4ad230..30786063 100644 --- a/resolver/resolver-env.go +++ b/resolver/resolver-env.go @@ -20,12 +20,13 @@ const ( var ( envResolver = &Resolver{ - Server: ServerSourceEnv, - ServerType: ServerTypeEnv, - ServerIPScope: netutils.SiteLocal, - ServerInfo: "Portmaster environment", - Source: ServerSourceEnv, - Conn: &envResolverConn{}, + ConfigURL: ServerSourceEnv, + Info: &ResolverInfo{ + Type: ServerTypeEnv, + Source: ServerSourceEnv, + IPScope: netutils.SiteLocal, + }, + Conn: &envResolverConn{}, } envResolvers = []*Resolver{envResolver} @@ -109,14 +110,12 @@ func (er *envResolverConn) makeRRCache(q *Query, answers []dns.RR) *RRCache { q.NoCaching = true return &RRCache{ - Domain: q.FQDN, - Question: q.QType, - RCode: dns.RcodeSuccess, - Answer: answers, - Extra: []dns.RR{internalSpecialUseComment}, // Always add comment about this TLD. - Server: envResolver.Server, - ServerScope: envResolver.ServerIPScope, - ServerInfo: envResolver.ServerInfo, + Domain: q.FQDN, + Question: q.QType, + RCode: dns.RcodeSuccess, + Answer: answers, + Extra: []dns.RR{internalSpecialUseComment}, // Always add comment about this TLD. + Resolver: envResolver.Info.Copy(), } } diff --git a/resolver/resolver-mdns.go b/resolver/resolver-mdns.go index 166cebfb..cd28eb5f 100644 --- a/resolver/resolver-mdns.go +++ b/resolver/resolver-mdns.go @@ -31,12 +31,13 @@ var ( questionsLock sync.Mutex mDNSResolver = &Resolver{ - Server: ServerSourceMDNS, - ServerType: ServerTypeDNS, - ServerIPScope: netutils.SiteLocal, - ServerInfo: "mDNS resolver", - Source: ServerSourceMDNS, - Conn: &mDNSResolverConn{}, + ConfigURL: ServerSourceMDNS, + Info: &ResolverInfo{ + Type: ServerTypeMDNS, + Source: ServerSourceMDNS, + IPScope: netutils.SiteLocal, + }, + Conn: &mDNSResolverConn{}, } mDNSResolvers = []*Resolver{mDNSResolver} ) @@ -200,12 +201,10 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error { // create new and do not append if err != nil || rrCache.Modified < time.Now().Add(-2*time.Second).Unix() || rrCache.Expired() { rrCache = &RRCache{ - Domain: question.Name, - Question: dns.Type(question.Qtype), - RCode: dns.RcodeSuccess, - Server: mDNSResolver.Server, - ServerScope: mDNSResolver.ServerIPScope, - ServerInfo: mDNSResolver.ServerInfo, + Domain: question.Name, + Question: dns.Type(question.Qtype), + RCode: dns.RcodeSuccess, + Resolver: mDNSResolver.Info.Copy(), } } } @@ -302,13 +301,11 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error { continue } rrCache = &RRCache{ - Domain: v.Header().Name, - Question: dns.Type(v.Header().Class), - RCode: dns.RcodeSuccess, - Answer: []dns.RR{v}, - Server: mDNSResolver.Server, - ServerScope: mDNSResolver.ServerIPScope, - ServerInfo: mDNSResolver.ServerInfo, + Domain: v.Header().Name, + Question: dns.Type(v.Header().Class), + RCode: dns.RcodeSuccess, + Answer: []dns.RR{v}, + Resolver: mDNSResolver.Info.Copy(), } rrCache.Clean(minMDnsTTL) err := rrCache.Save() @@ -423,12 +420,10 @@ func queryMulticastDNS(ctx context.Context, q *Query) (*RRCache, error) { // Respond with NXDomain. return &RRCache{ - Domain: q.FQDN, - Question: q.QType, - RCode: dns.RcodeNameError, - Server: mDNSResolver.Server, - ServerScope: mDNSResolver.ServerIPScope, - ServerInfo: mDNSResolver.ServerInfo, + Domain: q.FQDN, + Question: q.QType, + RCode: dns.RcodeNameError, + Resolver: mDNSResolver.Info.Copy(), }, nil } diff --git a/resolver/resolver-plain.go b/resolver/resolver-plain.go index 3892ab91..01862a7e 100644 --- a/resolver/resolver-plain.go +++ b/resolver/resolver-plain.go @@ -72,22 +72,20 @@ func (pr *PlainResolver) Query(ctx context.Context, q *Query) (*RRCache, error) // check if blocked if pr.resolver.IsBlockedUpstream(reply) { - return nil, &BlockedUpstreamError{pr.resolver.GetName()} + return nil, &BlockedUpstreamError{pr.resolver.Info.DescriptiveName()} } // hint network environment at successful connection netenv.ReportSuccessfulConnection() newRecord := &RRCache{ - Domain: q.FQDN, - Question: q.QType, - RCode: reply.Rcode, - Answer: reply.Answer, - Ns: reply.Ns, - Extra: reply.Extra, - Server: pr.resolver.Server, - ServerScope: pr.resolver.ServerIPScope, - ServerInfo: pr.resolver.ServerInfo, + Domain: q.FQDN, + Question: q.QType, + RCode: reply.Rcode, + Answer: reply.Answer, + Ns: reply.Ns, + Extra: reply.Extra, + Resolver: pr.resolver.Info.Copy(), } // TODO: check if reply.Answer is valid diff --git a/resolver/resolver-tcp.go b/resolver/resolver-tcp.go index a5e1ed30..49680516 100644 --- a/resolver/resolver-tcp.go +++ b/resolver/resolver-tcp.go @@ -49,15 +49,13 @@ type InFlightQuery struct { // MakeCacheRecord creates an RCache record from a reply. func (ifq *InFlightQuery) MakeCacheRecord(reply *dns.Msg) *RRCache { return &RRCache{ - Domain: ifq.Query.FQDN, - Question: ifq.Query.QType, - RCode: reply.Rcode, - Answer: reply.Answer, - Ns: reply.Ns, - Extra: reply.Extra, - Server: ifq.Resolver.Server, - ServerScope: ifq.Resolver.ServerIPScope, - ServerInfo: ifq.Resolver.ServerInfo, + Domain: ifq.Query.FQDN, + Question: ifq.Query.QType, + RCode: reply.Rcode, + Answer: reply.Answer, + Ns: reply.Ns, + Extra: reply.Extra, + Resolver: ifq.Resolver.Info.Copy(), } } @@ -172,7 +170,7 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { } if tr.resolver.IsBlockedUpstream(reply) { - return nil, &BlockedUpstreamError{tr.resolver.GetName()} + return nil, &BlockedUpstreamError{tr.resolver.Info.DescriptiveName()} } return inFlight.MakeCacheRecord(reply), nil @@ -189,7 +187,7 @@ func (tr *TCPResolver) checkClientStatus() { select { case tr.clientHeartbeat <- struct{}{}: case <-time.After(heartbeatTimeout): - log.Warningf("resolver: heartbeat failed for %s dns client, stopping", tr.resolver.GetName()) + log.Warningf("resolver: heartbeat failed for %s dns client, stopping", tr.resolver.Info.DescriptiveName()) stopClient() } } @@ -299,7 +297,7 @@ func (mgr *tcpResolverConnMgr) waitForWork(clientCtx context.Context) (proceed b select { case mgr.tr.queries <- inFlight.Msg: default: - log.Warningf("resolver: failed to re-inject abandoned query to %s", mgr.tr.resolver.GetName()) + log.Warningf("resolver: failed to re-inject abandoned query to %s", mgr.tr.resolver.Info.DescriptiveName()) } } // in-flight queries that match the connection instance ID are not changed. They are already in the queue. @@ -317,7 +315,7 @@ func (mgr *tcpResolverConnMgr) waitForWork(clientCtx context.Context) (proceed b select { case mgr.tr.queries <- msg: case <-time.After(2 * time.Second): - log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.GetName()) + log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.Info.DescriptiveName()) } return nil }) @@ -343,7 +341,7 @@ func (mgr *tcpResolverConnMgr) establishConnection() ( var err error conn, err = mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress) if err != nil { - log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.GetName(), mgr.tr.resolver.ServerAddress) + log.Debugf("resolver: failed to connect to %s", mgr.tr.resolver.Info.DescriptiveName()) return nil, nil, nil, nil } connCtx, cancelConnCtx = context.WithCancel(context.Background()) @@ -356,9 +354,8 @@ func (mgr *tcpResolverConnMgr) establishConnection() ( // Log that a connection to the resolver was established. log.Debugf( - "resolver: connected to %s (%s) with %d queries waiting", - mgr.tr.resolver.GetName(), - conn.RemoteAddr(), + "resolver: connected to %s with %d queries waiting", + mgr.tr.resolver.Info.DescriptiveName(), waitingQueries, ) @@ -434,7 +431,7 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context activeQueries := len(mgr.tr.inFlightQueries) mgr.tr.Unlock() if activeQueries == 0 { - log.Debugf("resolver: recycling conn to %s (%s)", mgr.tr.resolver.GetName(), conn.RemoteAddr()) + log.Debugf("resolver: recycling conn to %s", mgr.tr.resolver.Info.DescriptiveName()) return true } } @@ -454,9 +451,8 @@ func (mgr *tcpResolverConnMgr) handleQueryResponse(conn *dns.Conn, msg *dns.Msg) if !ok { log.Debugf( - "resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v", - mgr.tr.resolver.GetName(), - conn.RemoteAddr(), + "resolver: received possibly unsolicited reply from %s: txid=%d q=%+v", + mgr.tr.resolver.Info.DescriptiveName(), msg.Id, msg.Question, ) @@ -519,24 +515,21 @@ func (mgr *tcpResolverConnMgr) logConnectionError(err error, conn *dns.Conn, con switch { case errors.Is(err, io.EOF): log.Debugf( - "resolver: connection to %s (%s) was closed with %d in-flight queries", - mgr.tr.resolver.GetName(), - conn.RemoteAddr(), + "resolver: connection to %s was closed with %d in-flight queries", + mgr.tr.resolver.Info.DescriptiveName(), inFlightQueries, ) case reading: log.Warningf( - "resolver: read error from %s (%s) with %d in-flight queries: %s", - mgr.tr.resolver.GetName(), - conn.RemoteAddr(), + "resolver: read error from %s with %d in-flight queries: %s", + mgr.tr.resolver.Info.DescriptiveName(), inFlightQueries, err, ) default: log.Warningf( - "resolver: write error to %s (%s) with %d in-flight queries: %s", - mgr.tr.resolver.GetName(), - conn.RemoteAddr(), + "resolver: write error to %s with %d in-flight queries: %s", + mgr.tr.resolver.Info.DescriptiveName(), inFlightQueries, err, ) diff --git a/resolver/resolver.go b/resolver/resolver.go index ad7f5741..8b523148 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -2,21 +2,24 @@ package resolver import ( "context" + "fmt" "net" "sync" "time" "github.com/miekg/dns" "github.com/safing/portmaster/netenv" + "github.com/safing/portmaster/network/netutils" ) // DNS Resolver Attributes const ( - ServerTypeDNS = "dns" - ServerTypeTCP = "tcp" - ServerTypeDoT = "dot" - ServerTypeDoH = "doh" - ServerTypeEnv = "env" + ServerTypeDNS = "dns" + ServerTypeTCP = "tcp" + ServerTypeDoT = "dot" + ServerTypeDoH = "doh" + ServerTypeMDNS = "mdns" + ServerTypeEnv = "env" ServerSourceConfigured = "config" ServerSourceOperatingSystem = "system" @@ -39,14 +42,13 @@ type Resolver struct { // - `empty`: NXDomain result, but without any other record in any section // - `refused`: Request was refused // - `zeroip`: Answer only contains zeroip - Server string + ConfigURL string - // Source describes from where the resolver configuration originated. - Source string + // Info holds the parsed configuration. + Info *ResolverInfo - // Name is the name of the resolver as passed via - // ?name=. - Name string + // ServerAddress holds the resolver address for easier use. + ServerAddress string // UpstreamBlockDetection defines the detection type // to identifier upstream DNS query blocking. @@ -57,14 +59,6 @@ type Resolver struct { // - disabled UpstreamBlockDetection string - // Parsed config - ServerType string - ServerAddress string - ServerIP net.IP - ServerIPScope int8 - ServerPort uint16 - ServerInfo string - // Special Options VerifyDomain string Search []string @@ -73,25 +67,111 @@ type Resolver struct { Conn ResolverConn `json:"-"` } +// ResolverInfo is a subset of resolver attributes that is attached to answers +// from that server in order to use it later for decision making. It must not +// be changed by anyone after creation and initialization is complete. +type ResolverInfo struct { + // Name describes the name given to the resolver. The name is configured in the config URL using the name parameter. + Name string + + // Type describes the type of the resolver. + // Possible values include dns, tcp, dot, doh, mdns, env. + Type string + + // Source describes where the resolver configuration came from. + // Possible values include config, system, mdns, env. + Source string + + // IP is the IP address of the resolver + IP net.IP + + // IPScope is the network scope of the IP address. + IPScope netutils.IPScope + + // Port is the udp/tcp port of the resolver. + Port uint16 + + // id holds a unique ID for this resolver. + id string + idGen sync.Once +} + +// ID returns the unique ID of the resolver. +func (info *ResolverInfo) ID() string { + // Generate the ID the first time. + info.idGen.Do(func() { + switch info.Type { + case ServerTypeMDNS: + info.id = ServerTypeMDNS + case ServerTypeEnv: + info.id = ServerTypeEnv + default: + info.id = fmt.Sprintf( + "%s://%s:%d#%s", + info.Type, + info.IP, + info.Port, + info.Source, + ) + } + }) + + return info.id +} + +// DescriptiveName returns a human readable, but also detailed representation +// of the resolver. +func (info *ResolverInfo) DescriptiveName() string { + switch { + case info.Type == ServerTypeMDNS: + return "MDNS" + case info.Type == ServerTypeEnv: + return "Portmaster Environment" + case info.Name != "": + return fmt.Sprintf( + "%s (%s)", + info.Name, + info.ID(), + ) + default: + return fmt.Sprintf( + "%s (%s)", + info.IP.String(), + info.ID(), + ) + } +} + +// Copy returns a full copy of the ResolverInfo. +func (info *ResolverInfo) Copy() *ResolverInfo { + // Force idGen to run before we copy. + _ = info.ID() + + // Copy manually in order to not copy the mutex. + cp := &ResolverInfo{ + Name: info.Name, + Type: info.Type, + Source: info.Source, + IP: info.IP, + IPScope: info.IPScope, + Port: info.Port, + id: info.id, + } + // Trigger idGen.Do(), as the ID is already generated. + cp.idGen.Do(func() {}) + + return cp +} + // IsBlockedUpstream returns true if the request has been blocked // upstream. func (resolver *Resolver) IsBlockedUpstream(answer *dns.Msg) bool { return isBlockedUpstream(resolver, answer) } -// GetName returns the name of the server. If no name -// is configured the server address is returned. -func (resolver *Resolver) GetName() string { - if resolver.Name != "" { - return resolver.Name - } - - return resolver.Server -} - // String returns the URL representation of the resolver. func (resolver *Resolver) String() string { - return resolver.GetName() + return resolver.Info.DescriptiveName() } // ResolverConn is an interface to implement different types of query backends. diff --git a/resolver/resolver_test.go b/resolver/resolver_test.go index d8ab14e4..5070a248 100644 --- a/resolver/resolver_test.go +++ b/resolver/resolver_test.go @@ -52,7 +52,7 @@ func TestSingleResolving(t *testing.T) { if err != nil { t.Fatal(err) } - t.Logf("running bulk query test with resolver %s", resolver.Server) + t.Logf("running bulk query test with resolver %s", resolver.Info.DescriptiveName()) started := time.Now() @@ -83,7 +83,7 @@ func TestBulkResolving(t *testing.T) { if err != nil { t.Fatal(err) } - t.Logf("running bulk query test with resolver %s", resolver.Server) + t.Logf("running bulk query test with resolver %s", resolver.Info.DescriptiveName()) started := time.Now() diff --git a/resolver/resolvers.go b/resolver/resolvers.go index 2baeb2de..40608faa 100644 --- a/resolver/resolvers.go +++ b/resolver/resolvers.go @@ -5,6 +5,7 @@ import ( "net" "net/url" "sort" + "strconv" "strings" "sync" @@ -61,7 +62,7 @@ func formatIPAndPort(ip net.IP, port uint16) string { } func resolverConnFactory(resolver *Resolver) ResolverConn { - switch resolver.ServerType { + switch resolver.Info.Type { case ServerTypeTCP: return NewTCPResolver(resolver) case ServerTypeDoT: @@ -82,26 +83,36 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { switch u.Scheme { case ServerTypeDNS, ServerTypeDoT, ServerTypeTCP: default: - return nil, false, fmt.Errorf("invalid DNS resolver scheme %q", u.Scheme) + return nil, false, fmt.Errorf("DNS resolver scheme %q invalid", u.Scheme) } ip := net.ParseIP(u.Hostname()) if ip == nil { - return nil, false, fmt.Errorf("invalid resolver IP") + return nil, false, fmt.Errorf("resolver IP %q invalid", u.Hostname()) } // Add default port for scheme if it is missing. - if u.Port() == "" { - switch u.Scheme { - case ServerTypeDNS, ServerTypeTCP: - u.Host += ":53" - case ServerTypeDoT: - u.Host += ":853" + var port uint16 + hostPort := u.Port() + switch { + case hostPort != "": + parsedPort, err := strconv.ParseUint(hostPort, 10, 16) + if err != nil { + return nil, false, fmt.Errorf("resolver port %q invalid", u.Port()) } + port = uint16(parsedPort) + case u.Scheme == ServerTypeDNS, u.Scheme == ServerTypeTCP: + port = 53 + case u.Scheme == ServerTypeDoH: + port = 443 + case u.Scheme == ServerTypeDoT: + port = 853 + default: + return nil, false, fmt.Errorf("missing port in %q", u.Host) } - scope := netutils.ClassifyIP(ip) - if scope == netutils.HostLocal { + scope := netutils.GetIPScope(ip) + if scope.IsLocalhost() { return nil, true, nil // skip } @@ -127,24 +138,20 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { } new := &Resolver{ - Server: resolverURL, - ServerType: u.Scheme, - ServerAddress: u.Host, - ServerIP: ip, - ServerIPScope: scope, - Source: source, + ConfigURL: resolverURL, + Info: &ResolverInfo{ + Name: query.Get("name"), + Type: u.Scheme, + Source: source, + IP: ip, + IPScope: scope, + Port: port, + }, + ServerAddress: net.JoinHostPort(ip.String(), strconv.Itoa(int(port))), VerifyDomain: verifyDomain, - Name: query.Get("name"), UpstreamBlockDetection: blockType, } - u.RawQuery = "" // Remove options from parsed URL - if new.Name != "" { - new.ServerInfo = fmt.Sprintf("%s (%s, from %s)", new.Name, u, source) - } else { - new.ServerInfo = fmt.Sprintf("%s (from %s)", u, source) - } - new.Conn = resolverConnFactory(new) return new, false, nil } @@ -195,7 +202,7 @@ func getSystemResolvers() (resolvers []*Resolver) { continue } - if netutils.IPIsLAN(nameserver.IP) { + if resolver.Info.IPScope.IsLAN() { configureSearchDomains(resolver, nameserver.Search) } @@ -244,16 +251,16 @@ func loadResolvers() { activeResolvers = make(map[string]*Resolver) // add for _, resolver := range newResolvers { - activeResolvers[resolver.Server] = resolver + activeResolvers[resolver.Info.ID()] = resolver } - activeResolvers[mDNSResolver.Server] = mDNSResolver - activeResolvers[envResolver.Server] = envResolver + activeResolvers[mDNSResolver.Info.ID()] = mDNSResolver + activeResolvers[envResolver.Info.ID()] = envResolver // log global resolvers if len(globalResolvers) > 0 { log.Trace("resolver: loaded global resolvers:") for _, resolver := range globalResolvers { - log.Tracef("resolver: %s", resolver.Server) + log.Tracef("resolver: %s", resolver.ConfigURL) } } else { log.Warning("resolver: no global resolvers loaded") @@ -263,7 +270,7 @@ func loadResolvers() { if len(localResolvers) > 0 { log.Trace("resolver: loaded local resolvers:") for _, resolver := range localResolvers { - log.Tracef("resolver: %s", resolver.Server) + log.Tracef("resolver: %s", resolver.ConfigURL) } } else { log.Info("resolver: no local resolvers loaded") @@ -273,7 +280,7 @@ func loadResolvers() { if len(systemResolvers) > 0 { log.Trace("resolver: loaded system/network-assigned resolvers:") for _, resolver := range systemResolvers { - log.Tracef("resolver: %s", resolver.Server) + log.Tracef("resolver: %s", resolver.ConfigURL) } } else { log.Info("resolver: no system/network-assigned resolvers loaded") @@ -285,7 +292,7 @@ func loadResolvers() { for _, scope := range localScopes { var scopeServers []string for _, resolver := range scope.Resolvers { - scopeServers = append(scopeServers, resolver.Server) + scopeServers = append(scopeServers, resolver.ConfigURL) } log.Tracef("resolver: %s: %s", scope.Domain, strings.Join(scopeServers, ", ")) } @@ -306,11 +313,11 @@ func setScopedResolvers(resolvers []*Resolver) { localScopes = make([]*Scope, 0) for _, resolver := range resolvers { - if resolver.ServerIP != nil && netutils.IPIsLAN(resolver.ServerIP) { + if resolver.Info.IPScope.IsLAN() { localResolvers = append(localResolvers, resolver) } - if resolver.Source == ServerSourceOperatingSystem { + if resolver.Info.Source == ServerSourceOperatingSystem { systemResolvers = append(systemResolvers, resolver) } diff --git a/resolver/rrcache.go b/resolver/rrcache.go index 79118145..8c901131 100644 --- a/resolver/rrcache.go +++ b/resolver/rrcache.go @@ -29,10 +29,8 @@ type RRCache struct { Extra []dns.RR Expires int64 - // Source Information - Server string - ServerScope int8 - ServerInfo string + // Resolver Information + Resolver *ResolverInfo // Metadata about the request and handling ServedFromCache bool @@ -133,13 +131,11 @@ func (rrCache *RRCache) ExportAllARecords() (ips []net.IP) { // ToNameRecord converts the RRCache to a NameRecord for cleaner persistence. func (rrCache *RRCache) ToNameRecord() *NameRecord { new := &NameRecord{ - Domain: rrCache.Domain, - Question: rrCache.Question.String(), - RCode: rrCache.RCode, - Expires: rrCache.Expires, - Server: rrCache.Server, - ServerScope: rrCache.ServerScope, - ServerInfo: rrCache.ServerInfo, + Domain: rrCache.Domain, + Question: rrCache.Question.String(), + RCode: rrCache.RCode, + Expires: rrCache.Expires, + Resolver: rrCache.Resolver, } // stringify RR entries @@ -204,9 +200,7 @@ func GetRRCache(domain string, question dns.Type) (*RRCache, error) { rrCache.Extra = parseRR(rrCache.Extra, entry) } - rrCache.Server = nameRecord.Server - rrCache.ServerScope = nameRecord.ServerScope - rrCache.ServerInfo = nameRecord.ServerInfo + rrCache.Resolver = nameRecord.Resolver rrCache.ServedFromCache = true rrCache.Modified = nameRecord.Meta().Modified return rrCache, nil @@ -259,9 +253,7 @@ func (rrCache *RRCache) ShallowCopy() *RRCache { Extra: rrCache.Extra, Expires: rrCache.Expires, - Server: rrCache.Server, - ServerScope: rrCache.ServerScope, - ServerInfo: rrCache.ServerInfo, + Resolver: rrCache.Resolver, ServedFromCache: rrCache.ServedFromCache, RequestingNew: rrCache.RequestingNew, @@ -302,9 +294,9 @@ func (rrCache *RRCache) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns func (rrCache *RRCache) GetExtraRRs(ctx context.Context, query *dns.Msg) (extra []dns.RR) { // Add cache status and source of data. if rrCache.ServedFromCache { - extra = addExtra(ctx, extra, "served from cache, resolved by "+rrCache.ServerInfo) + extra = addExtra(ctx, extra, "served from cache, resolved by "+rrCache.Resolver.DescriptiveName()) } else { - extra = addExtra(ctx, extra, "freshly resolved by "+rrCache.ServerInfo) + extra = addExtra(ctx, extra, "freshly resolved by "+rrCache.Resolver.DescriptiveName()) } // Add expiry and cache information. diff --git a/resolver/scopes.go b/resolver/scopes.go index a7772186..54819358 100644 --- a/resolver/scopes.go +++ b/resolver/scopes.go @@ -158,13 +158,13 @@ addNextResolver: for _, resolver := range addResolvers { // check for compliance if err := resolver.checkCompliance(ctx, q); err != nil { - log.Tracer(ctx).Tracef("skipping non-compliant resolver %s: %s", resolver.GetName(), err) + log.Tracer(ctx).Tracef("skipping non-compliant resolver %s: %s", resolver.Info.DescriptiveName(), err) continue } // deduplicate for _, selectedResolver := range selected { - if selectedResolver.Server == resolver.Server { + if selectedResolver.Info.ID() == resolver.Info.ID() { continue addNextResolver } } @@ -208,7 +208,7 @@ func (q *Query) checkCompliance() error { func (resolver *Resolver) checkCompliance(_ context.Context, q *Query) error { if noInsecureProtocols(q.SecurityLevel) { - switch resolver.ServerType { + switch resolver.Info.Type { case ServerTypeDNS: return errInsecureProtocol case ServerTypeTCP: @@ -218,20 +218,20 @@ func (resolver *Resolver) checkCompliance(_ context.Context, q *Query) error { case ServerTypeDoH: // compliant case ServerTypeEnv: - // compliant (data is sources from local network only and is highly limited) + // compliant (data is sourced from local network only and is highly limited) default: return errInsecureProtocol } } if noAssignedNameservers(q.SecurityLevel) { - if resolver.Source == ServerSourceOperatingSystem { + if resolver.Info.Source == ServerSourceOperatingSystem { return errAssignedServer } } if noMulticastDNS(q.SecurityLevel) { - if resolver.Source == ServerSourceMDNS { + if resolver.Info.Source == ServerSourceMDNS { return errMulticastDNS } }