mirror of
https://github.com/safing/portmaster
synced 2025-09-01 18:19:12 +00:00
Merge branch 'develop' into feature/ui-revamp
This commit is contained in:
commit
2ccf8c635a
33 changed files with 383 additions and 514 deletions
18
Gopkg.lock
generated
18
Gopkg.lock
generated
|
@ -49,6 +49,17 @@
|
|||
revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73"
|
||||
version = "v1.1.1"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:c8098f53cd182561cfb128c9a5ba70e41ad2364b763f33f05c6bd54003ae6495"
|
||||
name = "github.com/florianl/go-nfqueue"
|
||||
packages = [
|
||||
".",
|
||||
"internal/unix",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "a2f196e98ab0ffdcb8b5252e7cbba98e45dea204"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:b6581f9180e0f2d5549280d71819ab951db9d511478c87daca95669589d505c0"
|
||||
name = "github.com/go-ole/go-ole"
|
||||
|
@ -140,12 +151,12 @@
|
|||
version = "v1.1.0"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:9d781ead5ca35ef02cdf0dc516b239cb387fe73207b0dd01760f7d4a825f4cd3"
|
||||
digest = "1:508f444b8e00a569a40899aaf5740348b44c305d36f36d4f002b277677deef95"
|
||||
name = "github.com/miekg/dns"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "da812eed45cba1ce4c978e746039483064b8f92d"
|
||||
revision = "10e0aeedbee54849adab780611454192a9980443"
|
||||
version = "v1.1.33"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:3282ac9a9ddf5c2c0eda96693364d34fe0f8d10a0748259082a5c9fbd3e1f7e4"
|
||||
|
@ -368,7 +379,6 @@
|
|||
"github.com/google/renameio",
|
||||
"github.com/hashicorp/go-multierror",
|
||||
"github.com/hashicorp/go-version",
|
||||
"github.com/mdlayher/netlink",
|
||||
"github.com/miekg/dns",
|
||||
"github.com/oschwald/maxminddb-golang",
|
||||
"github.com/shirou/gopsutil/process",
|
||||
|
|
|
@ -26,10 +26,6 @@
|
|||
|
||||
ignored = ["github.com/safing/portbase/*", "github.com/safing/spn/*"]
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/miekg/dns"
|
||||
branch = "master" # switch back to semver releases when https://github.com/miekg/dns/pull/1110 is released
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/florianl/go-nfqueue"
|
||||
branch = "master" # switch back once we migrate to go.mod
|
||||
|
|
|
@ -121,9 +121,9 @@ func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache) *res
|
|||
err,
|
||||
)
|
||||
}
|
||||
} else if rrCache.TTL > time.Now().Add(10*time.Second).Unix() {
|
||||
} else if rrCache.Expires > time.Now().Add(10*time.Second).Unix() {
|
||||
// Set a low TTL of 10 seconds if TTL is higher than that.
|
||||
rrCache.TTL = time.Now().Add(10 * time.Second).Unix()
|
||||
rrCache.Expires = time.Now().Add(10 * time.Second).Unix()
|
||||
err := rrCache.Save()
|
||||
if err != nil {
|
||||
log.Debugf(
|
||||
|
@ -170,31 +170,28 @@ func DecideOnResolvedDNS(
|
|||
|
||||
updateIPsAndCNAMEs(q, rrCache, conn)
|
||||
|
||||
if mayBlockCNAMEs(conn) {
|
||||
if mayBlockCNAMEs(ctx, conn) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Gate17 integration
|
||||
// tunnelInfo, err := AssignTunnelIP(fqdn)
|
||||
|
||||
return updatedRR
|
||||
}
|
||||
|
||||
func mayBlockCNAMEs(conn *network.Connection) bool {
|
||||
func mayBlockCNAMEs(ctx context.Context, conn *network.Connection) bool {
|
||||
// if we have CNAMEs and the profile is configured to filter them
|
||||
// we need to re-check the lists and endpoints here
|
||||
if conn.Process().Profile().FilterCNAMEs() {
|
||||
conn.Entity.ResetLists()
|
||||
conn.Entity.EnableCNAMECheck(true)
|
||||
conn.Entity.EnableCNAMECheck(ctx, true)
|
||||
|
||||
result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity)
|
||||
result, reason := conn.Process().Profile().MatchEndpoint(ctx, conn.Entity)
|
||||
if result == endpoints.Denied {
|
||||
conn.BlockWithContext(reason.String(), reason.Context())
|
||||
return true
|
||||
}
|
||||
|
||||
if result == endpoints.NoMatch {
|
||||
result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity)
|
||||
result, reason = conn.Process().Profile().MatchFilterLists(ctx, conn.Entity)
|
||||
if result == endpoints.Denied {
|
||||
conn.BlockWithContext(reason.String(), reason.Context())
|
||||
return true
|
||||
|
@ -205,10 +202,19 @@ func mayBlockCNAMEs(conn *network.Connection) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// updateIPsAndCNAMEs saves all the IP->Name mappings to the cache database and
|
||||
// updates the CNAMEs in the Connection's Entity.
|
||||
func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *network.Connection) {
|
||||
// save IP addresses to IPInfo
|
||||
// Get profileID for scoping IPInfo.
|
||||
var profileID string
|
||||
proc := conn.Process()
|
||||
if proc != nil {
|
||||
profileID = proc.LocalProfileKey
|
||||
}
|
||||
|
||||
// Collect IPs and CNAMEs.
|
||||
cnames := make(map[string]string)
|
||||
ips := make(map[string]struct{})
|
||||
ips := make([]net.IP, 0, len(rrCache.Answer))
|
||||
|
||||
for _, rr := range append(rrCache.Answer, rrCache.Extra...) {
|
||||
switch v := rr.(type) {
|
||||
|
@ -216,19 +222,27 @@ func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *netw
|
|||
cnames[v.Hdr.Name] = v.Target
|
||||
|
||||
case *dns.A:
|
||||
ips[v.A.String()] = struct{}{}
|
||||
ips = append(ips, v.A)
|
||||
|
||||
case *dns.AAAA:
|
||||
ips[v.AAAA.String()] = struct{}{}
|
||||
ips = append(ips, v.AAAA)
|
||||
}
|
||||
}
|
||||
|
||||
for ip := range ips {
|
||||
record := resolver.ResolvedDomain{
|
||||
Domain: q.FQDN,
|
||||
// Package IPs and CNAMEs into IPInfo structs.
|
||||
for _, ip := range ips {
|
||||
// Never save domain attributions for localhost IPs.
|
||||
if netutils.ClassifyIP(ip) == netutils.HostLocal {
|
||||
continue
|
||||
}
|
||||
|
||||
// resolve all CNAMEs in the correct order.
|
||||
// Create new record for this IP.
|
||||
record := resolver.ResolvedDomain{
|
||||
Domain: q.FQDN,
|
||||
Expires: rrCache.Expires,
|
||||
}
|
||||
|
||||
// Resolve all CNAMEs in the correct order and add the to the record.
|
||||
var domain = q.FQDN
|
||||
for {
|
||||
nextDomain, isCNAME := cnames[domain]
|
||||
|
@ -240,31 +254,30 @@ func updateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *netw
|
|||
domain = nextDomain
|
||||
}
|
||||
|
||||
// update the entity to include the cnames
|
||||
// Update the entity to include the CNAMEs of the query response.
|
||||
conn.Entity.CNAME = record.CNAMEs
|
||||
|
||||
// get the existing IP info or create a new one
|
||||
var save bool
|
||||
info, err := resolver.GetIPInfo(ip)
|
||||
// Check if there is an existing record for this DNS response.
|
||||
// Else create a new one.
|
||||
ipString := ip.String()
|
||||
info, err := resolver.GetIPInfo(profileID, ipString)
|
||||
if err != nil {
|
||||
if err != database.ErrNotFound {
|
||||
log.Errorf("nameserver: failed to search for IP info record: %s", err)
|
||||
}
|
||||
|
||||
info = &resolver.IPInfo{
|
||||
IP: ip,
|
||||
IP: ipString,
|
||||
ProfileID: profileID,
|
||||
}
|
||||
save = true
|
||||
}
|
||||
|
||||
// and the new resolved domain record and save
|
||||
if new := info.AddDomain(record); new {
|
||||
save = true
|
||||
}
|
||||
if save {
|
||||
if err := info.Save(); err != nil {
|
||||
log.Errorf("nameserver: failed to save IP info record: %s", err)
|
||||
}
|
||||
// Add the new record to the resolved domains for this IP and scope.
|
||||
info.AddDomain(record)
|
||||
|
||||
// Save if the record is new or has been updated.
|
||||
if err := info.Save(); err != nil {
|
||||
log.Errorf("nameserver: failed to save IP info record: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -87,13 +87,13 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
|
|||
switch meta.Protocol {
|
||||
case packet.ICMP:
|
||||
// Always permit ICMP.
|
||||
log.Debugf("accepting ICMP: %s", pkt)
|
||||
log.Debugf("filter: fast-track accepting ICMP: %s", pkt)
|
||||
_ = pkt.PermanentAccept()
|
||||
return true
|
||||
|
||||
case packet.ICMPv6:
|
||||
// Always permit ICMPv6.
|
||||
log.Debugf("accepting ICMPv6: %s", pkt)
|
||||
log.Debugf("filter: fast-track accepting ICMPv6: %s", pkt)
|
||||
_ = pkt.PermanentAccept()
|
||||
return true
|
||||
|
||||
|
@ -116,7 +116,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
|
|||
}
|
||||
|
||||
// Log and permit.
|
||||
log.Debugf("accepting DHCP: %s", pkt)
|
||||
log.Debugf("filter: fast-track accepting DHCP: %s", pkt)
|
||||
_ = pkt.PermanentAccept()
|
||||
return true
|
||||
|
||||
|
@ -141,7 +141,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
|
|||
// Only allow to own IPs.
|
||||
dstIsMe, err := netenv.IsMyIP(meta.Dst)
|
||||
if err != nil {
|
||||
log.Warningf("filter: failed to check if IP is local: %s", err)
|
||||
log.Warningf("filter: failed to check if IP %s is local: %s", meta.Dst, err)
|
||||
}
|
||||
if !dstIsMe {
|
||||
return false
|
||||
|
@ -150,9 +150,9 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
|
|||
// Log and permit.
|
||||
switch meta.DstPort {
|
||||
case 53:
|
||||
log.Debugf("accepting local dns: %s", pkt)
|
||||
log.Debugf("filter: fast-track accepting local dns: %s", pkt)
|
||||
case apiPort:
|
||||
log.Debugf("accepting api connection: %s", pkt)
|
||||
log.Debugf("filter: fast-track accepting api connection: %s", pkt)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
@ -165,7 +165,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
|
|||
}
|
||||
|
||||
func initialHandler(conn *network.Connection, pkt packet.Packet) {
|
||||
log.Tracer(pkt.Ctx()).Trace("filter: [initial handler]")
|
||||
log.Tracer(pkt.Ctx()).Trace("filter: handing over to connection-based handler")
|
||||
|
||||
// check for internal firewall bypass
|
||||
ps := getPortStatusAndMarkUsed(pkt.Info().LocalPort())
|
||||
|
|
|
@ -4,6 +4,7 @@ package windowskext
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
@ -45,6 +46,11 @@ func Handler(packets chan packet.Packet) {
|
|||
|
||||
packetInfo, err := RecvVerdictRequest()
|
||||
if err != nil {
|
||||
// Check if we are done with processing.
|
||||
if errors.Is(err, ErrKextNotReady) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Warningf("failed to get packet from windows kext: %s", err)
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -134,7 +134,7 @@ func checkProfileExists(_ context.Context, conn *network.Connection, _ packet.Pa
|
|||
return false
|
||||
}
|
||||
|
||||
func checkEndpointLists(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
|
||||
func checkEndpointLists(ctx context.Context, conn *network.Connection, _ packet.Packet) bool {
|
||||
var result endpoints.EPResult
|
||||
var reason endpoints.Reason
|
||||
|
||||
|
@ -143,9 +143,9 @@ func checkEndpointLists(_ context.Context, conn *network.Connection, _ packet.Pa
|
|||
|
||||
// check endpoints list
|
||||
if conn.Inbound {
|
||||
result, reason = p.MatchServiceEndpoint(conn.Entity)
|
||||
result, reason = p.MatchServiceEndpoint(ctx, conn.Entity)
|
||||
} else {
|
||||
result, reason = p.MatchEndpoint(conn.Entity)
|
||||
result, reason = p.MatchEndpoint(ctx, conn.Entity)
|
||||
}
|
||||
switch result {
|
||||
case endpoints.Denied:
|
||||
|
@ -271,7 +271,7 @@ func checkFilterLists(ctx context.Context, conn *network.Connection, pkt packet.
|
|||
// apply privacy filter lists
|
||||
p := conn.Process().Profile()
|
||||
|
||||
result, reason := p.MatchFilterLists(conn.Entity)
|
||||
result, reason := p.MatchFilterLists(ctx, conn.Entity)
|
||||
switch result {
|
||||
case endpoints.Denied:
|
||||
conn.DenyWithContext(reason.String(), reason.Context())
|
||||
|
|
119
intel/entity.go
119
intel/entity.go
|
@ -37,12 +37,19 @@ type Entity struct {
|
|||
// Protocol is the protcol number used by the connection.
|
||||
Protocol uint8
|
||||
|
||||
// Port is the destination port of the connection
|
||||
// Port is the remote port of the connection
|
||||
Port uint16
|
||||
|
||||
// dstPort is the destination port of the connection
|
||||
dstPort uint16
|
||||
|
||||
// Domain is the target domain of the connection.
|
||||
Domain string
|
||||
|
||||
// ReverseDomain is the domain the IP address points to. This is only
|
||||
// resolved and populated when needed.
|
||||
ReverseDomain string
|
||||
|
||||
// CNAME is a list of domain names that have been
|
||||
// resolved for Domain.
|
||||
CNAME []string
|
||||
|
@ -88,10 +95,20 @@ func (e *Entity) Init() *Entity {
|
|||
return e
|
||||
}
|
||||
|
||||
// SetDstPort sets the destination port.
|
||||
func (e *Entity) SetDstPort(dstPort uint16) {
|
||||
e.dstPort = dstPort
|
||||
}
|
||||
|
||||
// DstPort returns the destination port.
|
||||
func (e *Entity) DstPort() uint16 {
|
||||
return e.dstPort
|
||||
}
|
||||
|
||||
// FetchData fetches additional information, meant to be called before persisting an entity record.
|
||||
func (e *Entity) FetchData() {
|
||||
e.getLocation()
|
||||
e.getLists()
|
||||
func (e *Entity) FetchData(ctx context.Context) {
|
||||
e.getLocation(ctx)
|
||||
e.getLists(ctx)
|
||||
}
|
||||
|
||||
// ResetLists resets the current list data and forces
|
||||
|
@ -119,18 +136,18 @@ func (e *Entity) ResetLists() {
|
|||
|
||||
// ResolveSubDomainLists enables or disables list lookups for
|
||||
// sub-domains.
|
||||
func (e *Entity) ResolveSubDomainLists(enabled bool) {
|
||||
func (e *Entity) ResolveSubDomainLists(ctx context.Context, enabled bool) {
|
||||
if e.domainListLoaded {
|
||||
log.Warningf("intel/filterlists: tried to change sub-domain resolving for %s but lists are already fetched", e.Domain)
|
||||
log.Tracer(ctx).Warningf("intel/filterlists: tried to change sub-domain resolving for %s but lists are already fetched", e.Domain)
|
||||
}
|
||||
e.resolveSubDomainLists = enabled
|
||||
}
|
||||
|
||||
// EnableCNAMECheck enalbes or disables list lookups for
|
||||
// entity CNAMEs.
|
||||
func (e *Entity) EnableCNAMECheck(enabled bool) {
|
||||
func (e *Entity) EnableCNAMECheck(ctx context.Context, enabled bool) {
|
||||
if e.domainListLoaded {
|
||||
log.Warningf("intel/filterlists: tried to change CNAME resolving for %s but lists are already fetched", e.Domain)
|
||||
log.Tracer(ctx).Warningf("intel/filterlists: tried to change CNAME resolving for %s but lists are already fetched", e.Domain)
|
||||
}
|
||||
e.checkCNAMEs = enabled
|
||||
}
|
||||
|
@ -148,13 +165,8 @@ func (e *Entity) EnableReverseResolving() {
|
|||
e.reverseResolveEnabled = true
|
||||
}
|
||||
|
||||
func (e *Entity) reverseResolve() {
|
||||
func (e *Entity) reverseResolve(ctx context.Context) {
|
||||
e.reverseResolveOnce.Do(func() {
|
||||
// check if we should resolve
|
||||
if !e.reverseResolveEnabled {
|
||||
return
|
||||
}
|
||||
|
||||
// need IP!
|
||||
if e.IP == nil {
|
||||
return
|
||||
|
@ -165,18 +177,25 @@ func (e *Entity) reverseResolve() {
|
|||
return
|
||||
}
|
||||
// TODO: security level
|
||||
domain, err := reverseResolver(context.TODO(), e.IP.String(), status.SecurityLevelNormal)
|
||||
domain, err := reverseResolver(ctx, e.IP.String(), status.SecurityLevelNormal)
|
||||
if err != nil {
|
||||
log.Warningf("intel: failed to resolve IP %s: %s", e.IP, err)
|
||||
log.Tracer(ctx).Warningf("intel: failed to resolve IP %s: %s", e.IP, err)
|
||||
return
|
||||
}
|
||||
e.Domain = domain
|
||||
e.ReverseDomain = domain
|
||||
})
|
||||
}
|
||||
|
||||
// GetDomain returns the domain and whether it is set.
|
||||
func (e *Entity) GetDomain() (string, bool) {
|
||||
e.reverseResolve()
|
||||
func (e *Entity) GetDomain(ctx context.Context, mayUseReverseDomain bool) (string, bool) {
|
||||
if mayUseReverseDomain && e.reverseResolveEnabled {
|
||||
e.reverseResolve(ctx)
|
||||
|
||||
if e.ReverseDomain == "" {
|
||||
return "", false
|
||||
}
|
||||
return e.ReverseDomain, true
|
||||
}
|
||||
|
||||
if e.Domain == "" {
|
||||
return "", false
|
||||
|
@ -194,7 +213,7 @@ func (e *Entity) GetIP() (net.IP, bool) {
|
|||
|
||||
// Location
|
||||
|
||||
func (e *Entity) getLocation() {
|
||||
func (e *Entity) getLocation(ctx context.Context) {
|
||||
e.fetchLocationOnce.Do(func() {
|
||||
// need IP!
|
||||
if e.IP == nil {
|
||||
|
@ -204,7 +223,7 @@ func (e *Entity) getLocation() {
|
|||
// get location data
|
||||
loc, err := geoip.GetLocation(e.IP)
|
||||
if err != nil {
|
||||
log.Warningf("intel: failed to get location data for %s: %s", e.IP, err)
|
||||
log.Tracer(ctx).Warningf("intel: failed to get location data for %s: %s", e.IP, err)
|
||||
return
|
||||
}
|
||||
e.location = loc
|
||||
|
@ -214,8 +233,8 @@ func (e *Entity) getLocation() {
|
|||
}
|
||||
|
||||
// GetLocation returns the raw location data and whether it is set.
|
||||
func (e *Entity) GetLocation() (*geoip.Location, bool) {
|
||||
e.getLocation()
|
||||
func (e *Entity) GetLocation(ctx context.Context) (*geoip.Location, bool) {
|
||||
e.getLocation(ctx)
|
||||
|
||||
if e.location == nil {
|
||||
return nil, false
|
||||
|
@ -224,8 +243,8 @@ func (e *Entity) GetLocation() (*geoip.Location, bool) {
|
|||
}
|
||||
|
||||
// GetCountry returns the two letter ISO country code and whether it is set.
|
||||
func (e *Entity) GetCountry() (string, bool) {
|
||||
e.getLocation()
|
||||
func (e *Entity) GetCountry(ctx context.Context) (string, bool) {
|
||||
e.getLocation(ctx)
|
||||
|
||||
if e.Country == "" {
|
||||
return "", false
|
||||
|
@ -234,8 +253,8 @@ func (e *Entity) GetCountry() (string, bool) {
|
|||
}
|
||||
|
||||
// GetASN returns the AS number and whether it is set.
|
||||
func (e *Entity) GetASN() (uint, bool) {
|
||||
e.getLocation()
|
||||
func (e *Entity) GetASN(ctx context.Context) (uint, bool) {
|
||||
e.getLocation(ctx)
|
||||
|
||||
if e.ASN == 0 {
|
||||
return 0, false
|
||||
|
@ -244,11 +263,11 @@ func (e *Entity) GetASN() (uint, bool) {
|
|||
}
|
||||
|
||||
// Lists
|
||||
func (e *Entity) getLists() {
|
||||
e.getDomainLists()
|
||||
e.getASNLists()
|
||||
e.getIPLists()
|
||||
e.getCountryLists()
|
||||
func (e *Entity) getLists(ctx context.Context) {
|
||||
e.getDomainLists(ctx)
|
||||
e.getASNLists(ctx)
|
||||
e.getIPLists(ctx)
|
||||
e.getCountryLists(ctx)
|
||||
}
|
||||
|
||||
func (e *Entity) mergeList(key string, list []string) {
|
||||
|
@ -263,12 +282,12 @@ func (e *Entity) mergeList(key string, list []string) {
|
|||
e.ListOccurences[key] = mergeStringList(e.ListOccurences[key], list)
|
||||
}
|
||||
|
||||
func (e *Entity) getDomainLists() {
|
||||
func (e *Entity) getDomainLists(ctx context.Context) {
|
||||
if e.domainListLoaded {
|
||||
return
|
||||
}
|
||||
|
||||
domain, ok := e.GetDomain()
|
||||
domain, ok := e.GetDomain(ctx, false /* mayUseReverseDomain */)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
@ -277,7 +296,7 @@ func (e *Entity) getDomainLists() {
|
|||
var domainsToInspect = []string{domain}
|
||||
|
||||
if e.checkCNAMEs {
|
||||
log.Tracef("intel: CNAME filtering enabled, checking %v too", e.CNAME)
|
||||
log.Tracer(ctx).Tracef("intel: CNAME filtering enabled, checking %v too", e.CNAME)
|
||||
domainsToInspect = append(domainsToInspect, e.CNAME...)
|
||||
}
|
||||
|
||||
|
@ -294,10 +313,10 @@ func (e *Entity) getDomainLists() {
|
|||
domains = makeDistinct(domains)
|
||||
|
||||
for _, d := range domains {
|
||||
log.Tracef("intel: loading domain list for %s", d)
|
||||
log.Tracer(ctx).Tracef("intel: loading domain list for %s", d)
|
||||
list, err := filterlists.LookupDomain(d)
|
||||
if err != nil {
|
||||
log.Errorf("intel: failed to get domain blocklists for %s: %s", d, err)
|
||||
log.Tracer(ctx).Errorf("intel: failed to get domain blocklists for %s: %s", d, err)
|
||||
e.loadDomainListOnce = sync.Once{}
|
||||
return
|
||||
}
|
||||
|
@ -334,22 +353,22 @@ func splitDomain(domain string) []string {
|
|||
return domains
|
||||
}
|
||||
|
||||
func (e *Entity) getASNLists() {
|
||||
func (e *Entity) getASNLists(ctx context.Context) {
|
||||
if e.asnListLoaded {
|
||||
return
|
||||
}
|
||||
|
||||
asn, ok := e.GetASN()
|
||||
asn, ok := e.GetASN(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("intel: loading ASN list for %d", asn)
|
||||
log.Tracer(ctx).Tracef("intel: loading ASN list for %d", asn)
|
||||
e.loadAsnListOnce.Do(func() {
|
||||
asnStr := fmt.Sprintf("%d", asn)
|
||||
list, err := filterlists.LookupASNString(asnStr)
|
||||
if err != nil {
|
||||
log.Errorf("intel: failed to get ASN blocklist for %d: %s", asn, err)
|
||||
log.Tracer(ctx).Errorf("intel: failed to get ASN blocklist for %d: %s", asn, err)
|
||||
e.loadAsnListOnce = sync.Once{}
|
||||
return
|
||||
}
|
||||
|
@ -359,21 +378,21 @@ func (e *Entity) getASNLists() {
|
|||
})
|
||||
}
|
||||
|
||||
func (e *Entity) getCountryLists() {
|
||||
func (e *Entity) getCountryLists(ctx context.Context) {
|
||||
if e.countryListLoaded {
|
||||
return
|
||||
}
|
||||
|
||||
country, ok := e.GetCountry()
|
||||
country, ok := e.GetCountry(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("intel: loading country list for %s", country)
|
||||
log.Tracer(ctx).Tracef("intel: loading country list for %s", country)
|
||||
e.loadCoutryListOnce.Do(func() {
|
||||
list, err := filterlists.LookupCountry(country)
|
||||
if err != nil {
|
||||
log.Errorf("intel: failed to load country blocklist for %s: %s", country, err)
|
||||
log.Tracer(ctx).Errorf("intel: failed to load country blocklist for %s: %s", country, err)
|
||||
e.loadCoutryListOnce = sync.Once{}
|
||||
return
|
||||
}
|
||||
|
@ -383,7 +402,7 @@ func (e *Entity) getCountryLists() {
|
|||
})
|
||||
}
|
||||
|
||||
func (e *Entity) getIPLists() {
|
||||
func (e *Entity) getIPLists(ctx context.Context) {
|
||||
if e.ipListLoaded {
|
||||
return
|
||||
}
|
||||
|
@ -402,12 +421,12 @@ func (e *Entity) getIPLists() {
|
|||
return
|
||||
}
|
||||
|
||||
log.Tracef("intel: loading IP list for %s", ip)
|
||||
log.Tracer(ctx).Tracef("intel: loading IP list for %s", ip)
|
||||
e.loadIPListOnce.Do(func() {
|
||||
list, err := filterlists.LookupIP(ip)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("intel: failed to get IP blocklist for %s: %s", ip.String(), err)
|
||||
log.Tracer(ctx).Errorf("intel: failed to get IP blocklist for %s: %s", ip.String(), err)
|
||||
e.loadIPListOnce = sync.Once{}
|
||||
return
|
||||
}
|
||||
|
@ -418,8 +437,8 @@ func (e *Entity) getIPLists() {
|
|||
|
||||
// LoadLists searches all filterlists for all occurrences of
|
||||
// this entity.
|
||||
func (e *Entity) LoadLists() bool {
|
||||
e.getLists()
|
||||
func (e *Entity) LoadLists(ctx context.Context) bool {
|
||||
e.getLists(ctx)
|
||||
|
||||
return e.ListOccurences != nil
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"net"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
func getReader(ip net.IP) *maxminddb.Reader {
|
||||
|
@ -49,7 +48,5 @@ func GetLocation(ip net.IP) (record *Location, err error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
log.Tracef("geoip: record: %+v", record)
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
|
|
@ -76,12 +76,15 @@ func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) {
|
|||
return handleRequest(ctx, w, query)
|
||||
})
|
||||
if err != nil {
|
||||
log.Warningf("intel: failed to handle dns request: %s", err)
|
||||
log.Warningf("nameserver: failed to handle dns request: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) error { //nolint:gocognit // TODO
|
||||
// Only process first question, that's how everyone does it.
|
||||
if len(request.Question) == 0 {
|
||||
return errors.New("missing question")
|
||||
}
|
||||
originalQuestion := request.Question[0]
|
||||
|
||||
// Check if we are handling a non-standard query name.
|
||||
|
@ -116,7 +119,12 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
|
|||
|
||||
// Setup quick reply function.
|
||||
reply := func(responder nsutil.Responder, rrProviders ...nsutil.RRProvider) error {
|
||||
return sendResponse(ctx, w, request, responder, rrProviders...)
|
||||
err := sendResponse(ctx, w, request, responder, rrProviders...)
|
||||
// Log error here instead of returning it in order to keep the context.
|
||||
if err != nil {
|
||||
tracer.Errorf("nameserver: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return with server failure if offline.
|
||||
|
|
|
@ -10,6 +10,11 @@ import (
|
|||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNilRR is returned when a parsed RR is nil.
|
||||
ErrNilRR = errors.New("is nil")
|
||||
)
|
||||
|
||||
// Responder defines the interface that any block/deny reason interface
|
||||
// may implement to support sending custom DNS responses for a given reason.
|
||||
// That is, if a reason context implements the Responder interface the
|
||||
|
@ -39,8 +44,9 @@ func (rf ResponderFunc) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns
|
|||
return rf(ctx, request)
|
||||
}
|
||||
|
||||
// ZeroIP is a ResponderFunc than replies with either 0.0.0.0 or :: for
|
||||
// each A or AAAA question respectively.
|
||||
// ZeroIP is a ResponderFunc than replies with either 0.0.0.0 or :: for each A
|
||||
// or AAAA question respectively. If there is no A or AAAA question, it
|
||||
// defaults to replying with NXDomain.
|
||||
func ZeroIP(msgs ...string) ResponderFunc {
|
||||
return func(ctx context.Context, request *dns.Msg) *dns.Msg {
|
||||
reply := new(dns.Msg)
|
||||
|
@ -52,15 +58,16 @@ func ZeroIP(msgs ...string) ResponderFunc {
|
|||
|
||||
switch question.Qtype {
|
||||
case dns.TypeA:
|
||||
rr, err = dns.NewRR(question.Name + " 0 IN A 0.0.0.0")
|
||||
rr, err = dns.NewRR(question.Name + " 0 IN A 0.0.0.0")
|
||||
case dns.TypeAAAA:
|
||||
rr, err = dns.NewRR(question.Name + " 0 IN AAAA ::")
|
||||
rr, err = dns.NewRR(question.Name + " 0 IN AAAA ::")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Tracer(ctx).Errorf("nameserver: failed to create zero-ip response for %s: %s", question.Name, err)
|
||||
hasErr = true
|
||||
} else {
|
||||
case rr != nil:
|
||||
reply.Answer = append(reply.Answer, rr)
|
||||
}
|
||||
}
|
||||
|
@ -81,6 +88,7 @@ func ZeroIP(msgs ...string) ResponderFunc {
|
|||
}
|
||||
|
||||
// Localhost is a ResponderFunc than replies with localhost IP addresses.
|
||||
// If there is no A or AAAA question, it defaults to replying with NXDomain.
|
||||
func Localhost(msgs ...string) ResponderFunc {
|
||||
return func(ctx context.Context, request *dns.Msg) *dns.Msg {
|
||||
reply := new(dns.Msg)
|
||||
|
@ -97,10 +105,11 @@ func Localhost(msgs ...string) ResponderFunc {
|
|||
rr, err = dns.NewRR("localhost. 0 IN AAAA ::1")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Tracer(ctx).Errorf("nameserver: failed to create localhost response for %s: %s", question.Name, err)
|
||||
hasErr = true
|
||||
} else {
|
||||
case rr != nil:
|
||||
reply.Answer = append(reply.Answer, rr)
|
||||
}
|
||||
}
|
||||
|
@ -159,7 +168,7 @@ func MakeMessageRecord(level log.Severity, msg string) (dns.RR, error) { //nolin
|
|||
return nil, err
|
||||
}
|
||||
if rr == nil {
|
||||
return nil, errors.New("record is nil")
|
||||
return nil, ErrNilRR
|
||||
}
|
||||
return rr, nil
|
||||
}
|
||||
|
|
|
@ -1,236 +0,0 @@
|
|||
package only
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portmaster/netenv"
|
||||
"github.com/safing/portmaster/network/netutils"
|
||||
"github.com/safing/portmaster/resolver"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
dnsServer *dns.Server
|
||||
mtDNSRequest = "dns request"
|
||||
|
||||
listenAddress = "127.0.0.1:53"
|
||||
ipv4Localhost = net.IPv4(127, 0, 0, 1)
|
||||
localhostRRs []dns.RR
|
||||
)
|
||||
|
||||
func init() {
|
||||
module = modules.Register("nameserver", initLocalhostRRs, start, stop, "core", "resolver", "network", "netenv")
|
||||
}
|
||||
|
||||
func initLocalhostRRs() error {
|
||||
localhostIPv4, err := dns.NewRR("localhost. 17 IN A 127.0.0.1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
localhostIPv6, err := dns.NewRR("localhost. 17 IN AAAA ::1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
localhostRRs = []dns.RR{localhostIPv4, localhostIPv6}
|
||||
return nil
|
||||
}
|
||||
|
||||
func start() error {
|
||||
dnsServer = &dns.Server{Addr: listenAddress, Net: "udp"}
|
||||
dns.HandleFunc(".", handleRequestAsMicroTask)
|
||||
|
||||
module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error {
|
||||
err := dnsServer.ListenAndServe()
|
||||
if err != nil {
|
||||
// check if we are shutting down
|
||||
if module.IsStopping() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
if dnsServer != nil {
|
||||
return dnsServer.Shutdown()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func returnNXDomain(w dns.ResponseWriter, query *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(query, dns.RcodeNameError)
|
||||
_ = w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func returnServerFailure(w dns.ResponseWriter, query *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(query, dns.RcodeServerFailure)
|
||||
_ = w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func handleRequestAsMicroTask(w dns.ResponseWriter, query *dns.Msg) {
|
||||
err := module.RunMicroTask(&mtDNSRequest, func(ctx context.Context) error {
|
||||
return handleRequest(ctx, w, query)
|
||||
})
|
||||
if err != nil {
|
||||
log.Warningf("nameserver: failed to handle dns request: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) error {
|
||||
// return with server failure if offline
|
||||
if netenv.GetOnlineStatus() == netenv.StatusOffline {
|
||||
returnServerFailure(w, query)
|
||||
return nil
|
||||
}
|
||||
|
||||
// only process first question, that's how everyone does it.
|
||||
question := query.Question[0]
|
||||
q := &resolver.Query{
|
||||
FQDN: question.Name,
|
||||
QType: dns.Type(question.Qtype),
|
||||
}
|
||||
|
||||
// check class
|
||||
if question.Qclass != dns.ClassINET {
|
||||
// we only serve IN records, return nxdomain
|
||||
returnNXDomain(w, query)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handle request for localhost
|
||||
if strings.HasSuffix(q.FQDN, "localhost.") {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(query)
|
||||
m.Answer = localhostRRs
|
||||
_ = w.WriteMsg(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
// get addresses
|
||||
remoteAddr, ok := w.RemoteAddr().(*net.UDPAddr)
|
||||
if !ok {
|
||||
log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", q.FQDN, q.QType)
|
||||
return nil
|
||||
}
|
||||
if !remoteAddr.IP.Equal(ipv4Localhost) {
|
||||
// if request is not coming from 127.0.0.1, check if it's really local
|
||||
|
||||
localAddr, ok := w.RemoteAddr().(*net.UDPAddr)
|
||||
if !ok {
|
||||
log.Warningf("nameserver: could not get local address of request for %s%s, ignoring", q.FQDN, q.QType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ignore external request
|
||||
if !remoteAddr.IP.Equal(localAddr.IP) {
|
||||
log.Warningf("nameserver: external request for %s%s, ignoring", q.FQDN, q.QType)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// check if valid domain name
|
||||
if !netutils.IsValidFqdn(q.FQDN) {
|
||||
log.Debugf("nameserver: domain name %s is invalid, returning nxdomain", q.FQDN)
|
||||
returnNXDomain(w, query)
|
||||
return nil
|
||||
}
|
||||
|
||||
// start tracer
|
||||
ctx, tracer := log.AddTracer(ctx)
|
||||
tracer.Tracef("nameserver: handling new request for %s%s from %s:%d", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port)
|
||||
|
||||
// TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain
|
||||
|
||||
// get intel and RRs
|
||||
rrCache, err := resolver.Resolve(ctx, q)
|
||||
if err != nil {
|
||||
// TODO: analyze nxdomain requests, malware could be trying DGA-domains
|
||||
tracer.Warningf("nameserver: request for %s%s: %s", q.FQDN, q.QType, err)
|
||||
returnNXDomain(w, query)
|
||||
return nil
|
||||
}
|
||||
|
||||
// save IP addresses to IPInfo
|
||||
cnames := make(map[string]string)
|
||||
ips := make(map[string]struct{})
|
||||
|
||||
for _, rr := range append(rrCache.Answer, rrCache.Extra...) {
|
||||
switch v := rr.(type) {
|
||||
case *dns.CNAME:
|
||||
cnames[v.Hdr.Name] = v.Target
|
||||
|
||||
case *dns.A:
|
||||
ips[v.A.String()] = struct{}{}
|
||||
|
||||
case *dns.AAAA:
|
||||
ips[v.AAAA.String()] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
for ip := range ips {
|
||||
record := resolver.ResolvedDomain{
|
||||
Domain: q.FQDN,
|
||||
}
|
||||
|
||||
// resolve all CNAMEs in the correct order.
|
||||
var domain = q.FQDN
|
||||
for {
|
||||
nextDomain, isCNAME := cnames[domain]
|
||||
if !isCNAME {
|
||||
break
|
||||
}
|
||||
|
||||
record.CNAMEs = append(record.CNAMEs, nextDomain)
|
||||
domain = nextDomain
|
||||
}
|
||||
|
||||
// get the existing IP info or create a new one
|
||||
var save bool
|
||||
info, err := resolver.GetIPInfo(ip)
|
||||
if err != nil {
|
||||
if err != database.ErrNotFound {
|
||||
log.Errorf("nameserver: failed to search for IP info record: %s", err)
|
||||
}
|
||||
|
||||
info = &resolver.IPInfo{
|
||||
IP: ip,
|
||||
}
|
||||
save = true
|
||||
}
|
||||
|
||||
// and the new resolved domain record and save
|
||||
if new := info.AddDomain(record); new {
|
||||
save = true
|
||||
}
|
||||
if save {
|
||||
if err := info.Save(); err != nil {
|
||||
log.Errorf("nameserver: failed to save IP info record: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reply to query
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(query)
|
||||
m.Answer = rrCache.Answer
|
||||
m.Ns = rrCache.Ns
|
||||
m.Extra = rrCache.Extra
|
||||
_ = w.WriteMsg(m)
|
||||
tracer.Debugf("nameserver: returning response %s%s", q.FQDN, q.QType)
|
||||
|
||||
return nil
|
||||
}
|
|
@ -36,19 +36,19 @@ func sendResponse(
|
|||
}
|
||||
|
||||
// Write reply.
|
||||
if err := writeDNSResponse(w, reply); err != nil {
|
||||
return fmt.Errorf("nameserver: failed to send response: %w", err)
|
||||
if err := writeDNSResponse(ctx, w, reply); err != nil {
|
||||
return fmt.Errorf("failed to send response: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeDNSResponse(w dns.ResponseWriter, m *dns.Msg) (err error) {
|
||||
func writeDNSResponse(ctx context.Context, w dns.ResponseWriter, m *dns.Msg) (err error) {
|
||||
defer func() {
|
||||
// recover from panic
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = fmt.Errorf("panic: %s", panicErr)
|
||||
log.Warningf("nameserver: panic caused by this msg: %#v", m)
|
||||
log.Tracer(ctx).Debugf("nameserver: panic caused by this msg: %#v", m)
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -56,10 +56,11 @@ func writeDNSResponse(w dns.ResponseWriter, m *dns.Msg) (err error) {
|
|||
if err != nil {
|
||||
// If we receive an error we might have exceeded the message size with all
|
||||
// our extra information records. Retry again without the extra section.
|
||||
log.Tracer(ctx).Tracef("nameserver: retrying to write dns message without extra section, error was: %s", err)
|
||||
m.Extra = nil
|
||||
noExtraErr := w.WriteMsg(m)
|
||||
if noExtraErr == nil {
|
||||
log.Warningf("nameserver: failed to write dns message with extra section: %s", err)
|
||||
return fmt.Errorf("failed to write dns message without extra section: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"unsafe"
|
||||
)
|
||||
|
||||
// Windows specific constants for the WSAIoctl interface.
|
||||
//nolint:golint,stylecheck
|
||||
const (
|
||||
SIO_RCVALL = syscall.IOC_IN | syscall.IOC_VENDOR | 1
|
||||
|
||||
|
|
|
@ -172,7 +172,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
|
|||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Debugf("network: failed to find process of dns request for %s: %s", fqdn, err)
|
||||
log.Tracer(ctx).Debugf("network: failed to find process of dns request for %s: %s", fqdn, err)
|
||||
proc = process.GetUnidentifiedProcess(ctx)
|
||||
}
|
||||
|
||||
|
@ -196,7 +196,7 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
|
|||
// get Process
|
||||
proc, inbound, err := process.GetProcessByConnection(pkt.Ctx(), pkt.Info())
|
||||
if err != nil {
|
||||
log.Debugf("network: failed to find process of packet %s: %s", pkt, err)
|
||||
log.Tracer(pkt.Ctx()).Debugf("network: failed to find process of packet %s: %s", pkt, err)
|
||||
proc = process.GetUnidentifiedProcess(pkt.Ctx())
|
||||
}
|
||||
|
||||
|
@ -224,6 +224,7 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
|
|||
Protocol: uint8(pkt.Info().Protocol),
|
||||
Port: pkt.Info().SrcPort,
|
||||
}
|
||||
entity.SetDstPort(pkt.Info().DstPort)
|
||||
|
||||
} else {
|
||||
|
||||
|
@ -233,11 +234,12 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
|
|||
Protocol: uint8(pkt.Info().Protocol),
|
||||
Port: pkt.Info().DstPort,
|
||||
}
|
||||
entity.SetDstPort(entity.Port)
|
||||
|
||||
// check if we can find a domain for that IP
|
||||
ipinfo, err := resolver.GetIPInfo(pkt.Info().Dst.String())
|
||||
ipinfo, err := resolver.GetIPInfo(proc.LocalProfileKey, pkt.Info().Dst.String())
|
||||
if err == nil {
|
||||
lastResolvedDomain := ipinfo.ResolvedDomains.MostRecentDomain()
|
||||
lastResolvedDomain := ipinfo.MostRecentDomain()
|
||||
if lastResolvedDomain != nil {
|
||||
scope = lastResolvedDomain.Domain
|
||||
entity.Domain = lastResolvedDomain.Domain
|
||||
|
@ -299,9 +301,7 @@ func GetConnection(id string) (*Connection, bool) {
|
|||
|
||||
// AcceptWithContext accepts the connection.
|
||||
func (conn *Connection) AcceptWithContext(reason string, ctx interface{}) {
|
||||
if conn.SetVerdict(VerdictAccept, reason, ctx) {
|
||||
log.Infof("filter: granting connection %s, %s", conn, conn.Reason)
|
||||
} else {
|
||||
if !conn.SetVerdict(VerdictAccept, reason, ctx) {
|
||||
log.Warningf("filter: tried to accept %s, but current verdict is %s", conn, conn.Verdict)
|
||||
}
|
||||
}
|
||||
|
@ -313,9 +313,7 @@ func (conn *Connection) Accept(reason string) {
|
|||
|
||||
// BlockWithContext blocks the connection.
|
||||
func (conn *Connection) BlockWithContext(reason string, ctx interface{}) {
|
||||
if conn.SetVerdict(VerdictBlock, reason, ctx) {
|
||||
log.Infof("filter: blocking connection %s, %s", conn, conn.Reason)
|
||||
} else {
|
||||
if !conn.SetVerdict(VerdictBlock, reason, ctx) {
|
||||
log.Warningf("filter: tried to block %s, but current verdict is %s", conn, conn.Verdict)
|
||||
}
|
||||
}
|
||||
|
@ -327,9 +325,7 @@ func (conn *Connection) Block(reason string) {
|
|||
|
||||
// DropWithContext drops the connection.
|
||||
func (conn *Connection) DropWithContext(reason string, ctx interface{}) {
|
||||
if conn.SetVerdict(VerdictDrop, reason, ctx) {
|
||||
log.Infof("filter: dropping connection %s, %s", conn, conn.Reason)
|
||||
} else {
|
||||
if !conn.SetVerdict(VerdictDrop, reason, ctx) {
|
||||
log.Warningf("filter: tried to drop %s, but current verdict is %s", conn, conn.Verdict)
|
||||
}
|
||||
}
|
||||
|
@ -355,9 +351,7 @@ func (conn *Connection) Deny(reason string) {
|
|||
|
||||
// FailedWithContext marks the connection with VerdictFailed and stores the reason.
|
||||
func (conn *Connection) FailedWithContext(reason string, ctx interface{}) {
|
||||
if conn.SetVerdict(VerdictFailed, reason, ctx) {
|
||||
log.Infof("filter: dropping connection %s because of an internal error: %s", conn, reason)
|
||||
} else {
|
||||
if !conn.SetVerdict(VerdictFailed, reason, ctx) {
|
||||
log.Warningf("filter: tried to drop %s due to error but current verdict is %s", conn, conn.Verdict)
|
||||
}
|
||||
}
|
||||
|
@ -495,6 +489,9 @@ func (conn *Connection) packetHandler() {
|
|||
} else {
|
||||
defaultFirewallHandler(conn, pkt)
|
||||
}
|
||||
// log verdict
|
||||
log.Tracer(pkt.Ctx()).Infof("filter: connection %s %s: %s", conn, conn.Verdict.Verb(), conn.Reason)
|
||||
|
||||
// save does not touch any changing data
|
||||
// must not be locked, will deadlock with cleaner functions
|
||||
if conn.saveWhenFinished {
|
||||
|
|
|
@ -22,8 +22,8 @@ const (
|
|||
|
||||
var (
|
||||
profileDB = database.NewInterface(&database.Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
Local: true,
|
||||
Internal: true,
|
||||
})
|
||||
)
|
||||
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
package endpoints
|
||||
|
||||
import "github.com/safing/portmaster/intel"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/safing/portmaster/intel"
|
||||
)
|
||||
|
||||
// EndpointAny matches anything.
|
||||
type EndpointAny struct {
|
||||
|
@ -8,7 +12,7 @@ type EndpointAny struct {
|
|||
}
|
||||
|
||||
// Matches checks whether the given entity matches this endpoint definition.
|
||||
func (ep *EndpointAny) Matches(entity *intel.Entity) (EPResult, Reason) {
|
||||
func (ep *EndpointAny) Matches(_ context.Context, entity *intel.Entity) (EPResult, Reason) {
|
||||
return ep.match(ep, entity, "*", "matches")
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
@ -20,8 +21,8 @@ type EndpointASN struct {
|
|||
}
|
||||
|
||||
// Matches checks whether the given entity matches this endpoint definition.
|
||||
func (ep *EndpointASN) Matches(entity *intel.Entity) (EPResult, Reason) {
|
||||
asn, ok := entity.GetASN()
|
||||
func (ep *EndpointASN) Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) {
|
||||
asn, ok := entity.GetASN(ctx)
|
||||
if !ok {
|
||||
return Undeterminable, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
|
@ -19,8 +20,8 @@ type EndpointCountry struct {
|
|||
}
|
||||
|
||||
// Matches checks whether the given entity matches this endpoint definition.
|
||||
func (ep *EndpointCountry) Matches(entity *intel.Entity) (EPResult, Reason) {
|
||||
country, ok := entity.GetCountry()
|
||||
func (ep *EndpointCountry) Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) {
|
||||
country, ok := entity.GetCountry(ctx)
|
||||
if !ok {
|
||||
return Undeterminable, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
|
@ -62,19 +63,20 @@ func (ep *EndpointDomain) check(entity *intel.Entity, domain string) (EPResult,
|
|||
}
|
||||
|
||||
// Matches checks whether the given entity matches this endpoint definition.
|
||||
func (ep *EndpointDomain) Matches(entity *intel.Entity) (EPResult, Reason) {
|
||||
if entity.Domain == "" {
|
||||
func (ep *EndpointDomain) Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) {
|
||||
domain, ok := entity.GetDomain(ctx, true /* mayUseReverseDomain */)
|
||||
if !ok {
|
||||
return NoMatch, nil
|
||||
}
|
||||
|
||||
result, reason := ep.check(entity, entity.Domain)
|
||||
result, reason := ep.check(entity, domain)
|
||||
if result != NoMatch {
|
||||
return result, reason
|
||||
}
|
||||
|
||||
if entity.CNAMECheckEnabled() {
|
||||
for _, domain := range entity.CNAME {
|
||||
result, reason = ep.check(entity, domain)
|
||||
for _, cname := range entity.CNAME {
|
||||
result, reason = ep.check(entity, cname)
|
||||
if result == Denied {
|
||||
return result, reason
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/safing/portmaster/intel"
|
||||
|
@ -14,7 +15,7 @@ type EndpointIP struct {
|
|||
}
|
||||
|
||||
// Matches checks whether the given entity matches this endpoint definition.
|
||||
func (ep *EndpointIP) Matches(entity *intel.Entity) (EPResult, Reason) {
|
||||
func (ep *EndpointIP) Matches(_ context.Context, entity *intel.Entity) (EPResult, Reason) {
|
||||
if entity.IP == nil {
|
||||
return Undeterminable, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/safing/portmaster/intel"
|
||||
|
@ -14,7 +15,7 @@ type EndpointIPRange struct {
|
|||
}
|
||||
|
||||
// Matches checks whether the given entity matches this endpoint definition.
|
||||
func (ep *EndpointIPRange) Matches(entity *intel.Entity) (EPResult, Reason) {
|
||||
func (ep *EndpointIPRange) Matches(_ context.Context, entity *intel.Entity) (EPResult, Reason) {
|
||||
if entity.IP == nil {
|
||||
return Undeterminable, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/intel"
|
||||
|
@ -15,8 +16,8 @@ type EndpointLists struct {
|
|||
}
|
||||
|
||||
// Matches checks whether the given entity matches this endpoint definition.
|
||||
func (ep *EndpointLists) Matches(entity *intel.Entity) (EPResult, Reason) {
|
||||
if !entity.LoadLists() {
|
||||
func (ep *EndpointLists) Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) {
|
||||
if !entity.LoadLists(ctx) {
|
||||
return Undeterminable, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/network/netutils"
|
||||
|
@ -30,7 +31,7 @@ type EndpointScope struct {
|
|||
}
|
||||
|
||||
// Matches checks whether the given entity matches this endpoint definition.
|
||||
func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) {
|
||||
func (ep *EndpointScope) Matches(_ context.Context, entity *intel.Entity) (EPResult, Reason) {
|
||||
if entity.IP == nil {
|
||||
return Undeterminable, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -11,7 +12,7 @@ import (
|
|||
|
||||
// Endpoint describes an Endpoint Matcher
|
||||
type Endpoint interface {
|
||||
Matches(entity *intel.Entity) (EPResult, Reason)
|
||||
Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason)
|
||||
String() string
|
||||
}
|
||||
|
||||
|
@ -69,11 +70,11 @@ func (ep *EndpointBase) matchesPPP(entity *intel.Entity) (result EPResult) {
|
|||
// only check if port is defined
|
||||
if ep.StartPort > 0 {
|
||||
// if port is unknown, return Undeterminable
|
||||
if entity.Port == 0 {
|
||||
if entity.DstPort() == 0 {
|
||||
return Undeterminable
|
||||
}
|
||||
// if port does not match, return NoMatch
|
||||
if entity.Port < ep.StartPort || entity.Port > ep.EndPort {
|
||||
if entity.DstPort() < ep.StartPort || entity.DstPort() > ep.EndPort {
|
||||
return NoMatch
|
||||
}
|
||||
}
|
||||
|
@ -219,10 +220,6 @@ func parseEndpoint(value string) (endpoint Endpoint, err error) { //nolint:gocog
|
|||
if endpoint, err = parseTypeIPRange(fields); endpoint != nil || err != nil {
|
||||
return
|
||||
}
|
||||
// domain
|
||||
if endpoint, err = parseTypeDomain(fields); endpoint != nil || err != nil {
|
||||
return
|
||||
}
|
||||
// country
|
||||
if endpoint, err = parseTypeCountry(fields); endpoint != nil || err != nil {
|
||||
return
|
||||
|
@ -239,6 +236,10 @@ func parseEndpoint(value string) (endpoint Endpoint, err error) { //nolint:gocog
|
|||
if endpoint, err = parseTypeList(fields); endpoint != nil || err != nil {
|
||||
return
|
||||
}
|
||||
// domain
|
||||
if endpoint, err = parseTypeDomain(fields); endpoint != nil || err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf(`unknown endpoint definition: "%s"`, value)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
|
@ -63,10 +64,10 @@ func (e Endpoints) IsSet() bool {
|
|||
}
|
||||
|
||||
// Match checks whether the given entity matches any of the endpoint definitions in the list.
|
||||
func (e Endpoints) Match(entity *intel.Entity) (result EPResult, reason Reason) {
|
||||
func (e Endpoints) Match(ctx context.Context, entity *intel.Entity) (result EPResult, reason Reason) {
|
||||
for _, entry := range e {
|
||||
if entry != nil {
|
||||
if result, reason = entry.Matches(entity); result != NoMatch {
|
||||
if result, reason = entry.Matches(ctx, entity); result != NoMatch {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
@ -16,7 +17,9 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
|
||||
func testEndpointMatch(t *testing.T, ep Endpoint, entity *intel.Entity, expectedResult EPResult) {
|
||||
result, _ := ep.Matches(entity)
|
||||
entity.SetDstPort(entity.Port)
|
||||
|
||||
result, _ := ep.Matches(context.TODO(), entity)
|
||||
if result != expectedResult {
|
||||
t.Errorf(
|
||||
"line %d: unexpected result for endpoint %s and entity %+v: result=%s, expected=%s",
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package profile
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
|
@ -221,10 +222,10 @@ func (lp *LayeredProfile) DefaultAction() uint8 {
|
|||
}
|
||||
|
||||
// MatchEndpoint checks if the given endpoint matches an entry in any of the profiles.
|
||||
func (lp *LayeredProfile) MatchEndpoint(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) {
|
||||
func (lp *LayeredProfile) MatchEndpoint(ctx context.Context, entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) {
|
||||
for _, layer := range lp.layers {
|
||||
if layer.endpoints.IsSet() {
|
||||
result, reason := layer.endpoints.Match(entity)
|
||||
result, reason := layer.endpoints.Match(ctx, entity)
|
||||
if endpoints.IsDecision(result) {
|
||||
return result, reason
|
||||
}
|
||||
|
@ -233,16 +234,16 @@ func (lp *LayeredProfile) MatchEndpoint(entity *intel.Entity) (endpoints.EPResul
|
|||
|
||||
cfgLock.RLock()
|
||||
defer cfgLock.RUnlock()
|
||||
return cfgEndpoints.Match(entity)
|
||||
return cfgEndpoints.Match(ctx, entity)
|
||||
}
|
||||
|
||||
// MatchServiceEndpoint checks if the given endpoint of an inbound connection matches an entry in any of the profiles.
|
||||
func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) {
|
||||
func (lp *LayeredProfile) MatchServiceEndpoint(ctx context.Context, entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) {
|
||||
entity.EnableReverseResolving()
|
||||
|
||||
for _, layer := range lp.layers {
|
||||
if layer.serviceEndpoints.IsSet() {
|
||||
result, reason := layer.serviceEndpoints.Match(entity)
|
||||
result, reason := layer.serviceEndpoints.Match(ctx, entity)
|
||||
if endpoints.IsDecision(result) {
|
||||
return result, reason
|
||||
}
|
||||
|
@ -251,19 +252,19 @@ func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (endpoints.
|
|||
|
||||
cfgLock.RLock()
|
||||
defer cfgLock.RUnlock()
|
||||
return cfgServiceEndpoints.Match(entity)
|
||||
return cfgServiceEndpoints.Match(ctx, entity)
|
||||
}
|
||||
|
||||
// MatchFilterLists matches the entity against the set of filter
|
||||
// lists.
|
||||
func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) {
|
||||
entity.ResolveSubDomainLists(lp.FilterSubDomains())
|
||||
entity.EnableCNAMECheck(lp.FilterCNAMEs())
|
||||
func (lp *LayeredProfile) MatchFilterLists(ctx context.Context, entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) {
|
||||
entity.ResolveSubDomainLists(ctx, lp.FilterSubDomains())
|
||||
entity.EnableCNAMECheck(ctx, lp.FilterCNAMEs())
|
||||
|
||||
for _, layer := range lp.layers {
|
||||
// search for the first layer that has filterListIDs set
|
||||
if len(layer.filterListIDs) > 0 {
|
||||
entity.LoadLists()
|
||||
entity.LoadLists(ctx)
|
||||
|
||||
if entity.MatchLists(layer.filterListIDs) {
|
||||
return endpoints.Denied, entity.ListBlockReason()
|
||||
|
@ -276,7 +277,7 @@ func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPRe
|
|||
cfgLock.RLock()
|
||||
defer cfgLock.RUnlock()
|
||||
if len(cfgFilterLists) > 0 {
|
||||
entity.LoadLists()
|
||||
entity.LoadLists(ctx)
|
||||
|
||||
if entity.MatchLists(cfgFilterLists) {
|
||||
return endpoints.Denied, entity.ListBlockReason()
|
||||
|
|
|
@ -7,12 +7,26 @@ import (
|
|||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
// IPInfoProfileScopeGlobal is the profile scope used for unscoped IPInfo entries.
|
||||
IPInfoProfileScopeGlobal = "global"
|
||||
)
|
||||
|
||||
var (
|
||||
ipInfoDatabase = database.NewInterface(&database.Options{
|
||||
AlwaysSetRelativateExpiry: 86400, // 24 hours
|
||||
Local: true,
|
||||
Internal: true,
|
||||
|
||||
// Cache entries because new/updated entries will often be queries soon
|
||||
// after inserted.
|
||||
CacheSize: 256,
|
||||
|
||||
// We only use the cache database here, so we can delay and batch all our
|
||||
// writes. Also, no one else accesses these records, so we are fine using
|
||||
// this.
|
||||
DelayCachedWrites: "cache",
|
||||
})
|
||||
)
|
||||
|
||||
|
@ -25,6 +39,11 @@ type ResolvedDomain struct {
|
|||
// CNAMEs is a list of CNAMEs that have been resolved for
|
||||
// Domain.
|
||||
CNAMEs []string
|
||||
|
||||
// Expires holds the timestamp when this entry expires.
|
||||
// This does not mean that the entry may not be used anymore afterwards,
|
||||
// but that this is used to calcuate the TTL of the database record.
|
||||
Expires int64
|
||||
}
|
||||
|
||||
// String returns a string representation of ResolvedDomain including
|
||||
|
@ -54,29 +73,16 @@ func (rds ResolvedDomains) String() string {
|
|||
return strings.Join(domains, " or ")
|
||||
}
|
||||
|
||||
// MostRecentDomain returns the most recent domain.
|
||||
func (rds ResolvedDomains) MostRecentDomain() *ResolvedDomain {
|
||||
if len(rds) == 0 {
|
||||
return nil
|
||||
}
|
||||
// TODO(ppacher): we could also do that by using ResolvedAt()
|
||||
mostRecent := rds[len(rds)-1]
|
||||
return &mostRecent
|
||||
}
|
||||
|
||||
// IPInfo represents various information about an IP.
|
||||
type IPInfo struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
// IP holds the acutal IP address.
|
||||
// IP holds the actual IP address.
|
||||
IP string
|
||||
|
||||
// Domains holds a list of domains that have been
|
||||
// resolved to IP. This field is deprecated and should
|
||||
// be removed.
|
||||
// DEPRECATED: remove with alpha.
|
||||
Domains []string `json:"Domains,omitempty"`
|
||||
// ProfileID is used to scope this entry to a process group.
|
||||
ProfileID string
|
||||
|
||||
// ResolvedDomain is a slice of domains that
|
||||
// have been requested by various applications
|
||||
|
@ -84,35 +90,43 @@ type IPInfo struct {
|
|||
ResolvedDomains ResolvedDomains
|
||||
}
|
||||
|
||||
// AddDomain adds a new resolved domain to ipi.
|
||||
func (ipi *IPInfo) AddDomain(resolved ResolvedDomain) bool {
|
||||
for idx, d := range ipi.ResolvedDomains {
|
||||
if d.Domain == resolved.Domain {
|
||||
if utils.StringSliceEqual(d.CNAMEs, resolved.CNAMEs) {
|
||||
return false
|
||||
}
|
||||
// AddDomain adds a new resolved domain to IPInfo.
|
||||
func (info *IPInfo) AddDomain(resolved ResolvedDomain) {
|
||||
info.Lock()
|
||||
defer info.Unlock()
|
||||
|
||||
// we have a different CNAME chain now, remove the previous
|
||||
// entry and add it at the end.
|
||||
ipi.ResolvedDomains = append(ipi.ResolvedDomains[:idx], ipi.ResolvedDomains[idx+1:]...)
|
||||
ipi.ResolvedDomains = append(ipi.ResolvedDomains, resolved)
|
||||
return true
|
||||
// Delete old for the same domain.
|
||||
for idx, d := range info.ResolvedDomains {
|
||||
if d.Domain == resolved.Domain {
|
||||
info.ResolvedDomains = append(info.ResolvedDomains[:idx], info.ResolvedDomains[idx+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ipi.ResolvedDomains = append(ipi.ResolvedDomains, resolved)
|
||||
return true
|
||||
// Add new entry to the end.
|
||||
info.ResolvedDomains = append(info.ResolvedDomains, resolved)
|
||||
}
|
||||
|
||||
func makeIPInfoKey(ip string) string {
|
||||
return fmt.Sprintf("cache:intel/ipInfo/%s", ip)
|
||||
// MostRecentDomain returns the most recent domain.
|
||||
func (info *IPInfo) MostRecentDomain() *ResolvedDomain {
|
||||
info.Lock()
|
||||
defer info.Unlock()
|
||||
|
||||
if len(info.ResolvedDomains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
mostRecent := info.ResolvedDomains[len(info.ResolvedDomains)-1]
|
||||
return &mostRecent
|
||||
}
|
||||
|
||||
func makeIPInfoKey(profileID, ip string) string {
|
||||
return fmt.Sprintf("cache:intel/ipInfo/%s/%s", profileID, ip)
|
||||
}
|
||||
|
||||
// GetIPInfo gets an IPInfo record from the database.
|
||||
func GetIPInfo(ip string) (*IPInfo, error) {
|
||||
key := makeIPInfoKey(ip)
|
||||
|
||||
r, err := ipInfoDatabase.Get(key)
|
||||
func GetIPInfo(profileID, ip string) (*IPInfo, error) {
|
||||
r, err := ipInfoDatabase.Get(makeIPInfoKey(profileID, ip))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -126,18 +140,6 @@ func GetIPInfo(ip string) (*IPInfo, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Legacy support,
|
||||
// DEPRECATED: remove with alpha
|
||||
if len(new.Domains) > 0 && len(new.ResolvedDomains) == 0 {
|
||||
for _, d := range new.Domains {
|
||||
new.ResolvedDomains = append(new.ResolvedDomains, ResolvedDomain{
|
||||
Domain: d,
|
||||
// rest is empty...
|
||||
})
|
||||
}
|
||||
new.Domains = nil // clean up so we remove it from the database
|
||||
}
|
||||
|
||||
return new, nil
|
||||
}
|
||||
|
||||
|
@ -150,27 +152,38 @@ func GetIPInfo(ip string) (*IPInfo, error) {
|
|||
}
|
||||
|
||||
// Save saves the IPInfo record to the database.
|
||||
func (ipi *IPInfo) Save() error {
|
||||
ipi.Lock()
|
||||
if !ipi.KeyIsSet() {
|
||||
ipi.SetKey(makeIPInfoKey(ipi.IP))
|
||||
}
|
||||
ipi.Unlock()
|
||||
func (info *IPInfo) Save() error {
|
||||
info.Lock()
|
||||
|
||||
// Legacy support
|
||||
// Ensure we don't write new Domain fields into the
|
||||
// database.
|
||||
// DEPRECATED: remove with alpha
|
||||
if len(ipi.Domains) > 0 {
|
||||
ipi.Domains = nil
|
||||
// Set database key if not yet set already.
|
||||
if !info.KeyIsSet() {
|
||||
// Default to global scope if scope is unset.
|
||||
if info.ProfileID == "" {
|
||||
info.ProfileID = IPInfoProfileScopeGlobal
|
||||
}
|
||||
info.SetKey(makeIPInfoKey(info.ProfileID, info.IP))
|
||||
}
|
||||
|
||||
return ipInfoDatabase.Put(ipi)
|
||||
// Calculate and set cache expiry.
|
||||
var expires int64 = 86400 // Minimum TTL of one day.
|
||||
for _, rd := range info.ResolvedDomains {
|
||||
if rd.Expires > expires {
|
||||
expires = rd.Expires
|
||||
}
|
||||
}
|
||||
info.UpdateMeta()
|
||||
expires += 3600 // Add one hour to expiry as a buffer.
|
||||
info.Meta().SetAbsoluteExpiry(expires)
|
||||
|
||||
info.Unlock()
|
||||
|
||||
return ipInfoDatabase.Put(info)
|
||||
}
|
||||
|
||||
// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or "
|
||||
func (ipi *IPInfo) String() string {
|
||||
ipi.Lock()
|
||||
defer ipi.Unlock()
|
||||
return fmt.Sprintf("<IPInfo[%s] %s: %s", ipi.Key(), ipi.IP, ipi.ResolvedDomains.String())
|
||||
func (info *IPInfo) String() string {
|
||||
info.Lock()
|
||||
defer info.Unlock()
|
||||
|
||||
return fmt.Sprintf("<IPInfo[%s] %s: %s>", info.Key(), info.IP, info.ResolvedDomains.String())
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ func TestIPInfo(t *testing.T) {
|
|||
CNAMEs: []string{"example.com"},
|
||||
}
|
||||
|
||||
ipi := &IPInfo{
|
||||
info := &IPInfo{
|
||||
IP: "1.2.3.4",
|
||||
ResolvedDomains: ResolvedDomains{
|
||||
example,
|
||||
|
@ -27,22 +27,18 @@ func TestIPInfo(t *testing.T) {
|
|||
Domain: "sub2.example.com",
|
||||
CNAMEs: []string{"sub1.example.com", "example.com"},
|
||||
}
|
||||
added := ipi.AddDomain(sub2Example)
|
||||
|
||||
assert.True(t, added)
|
||||
assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, ipi.ResolvedDomains)
|
||||
info.AddDomain(sub2Example)
|
||||
assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, info.ResolvedDomains)
|
||||
|
||||
// try again, should do nothing now
|
||||
added = ipi.AddDomain(sub2Example)
|
||||
assert.False(t, added)
|
||||
assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, ipi.ResolvedDomains)
|
||||
info.AddDomain(sub2Example)
|
||||
assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, info.ResolvedDomains)
|
||||
|
||||
subOverWrite := ResolvedDomain{
|
||||
Domain: "sub1.example.com",
|
||||
CNAMEs: []string{}, // now without CNAMEs
|
||||
}
|
||||
|
||||
added = ipi.AddDomain(subOverWrite)
|
||||
assert.True(t, added)
|
||||
assert.Equal(t, ResolvedDomains{example, sub2Example, subOverWrite}, ipi.ResolvedDomains)
|
||||
info.AddDomain(subOverWrite)
|
||||
assert.Equal(t, ResolvedDomains{example, sub2Example, subOverWrite}, info.ResolvedDomains)
|
||||
}
|
||||
|
|
|
@ -93,6 +93,9 @@ func start() error {
|
|||
listenToMDNS,
|
||||
)
|
||||
|
||||
module.StartServiceWorker("name record delayed cache writer", 0, recordDatabase.DelayedCacheWriter)
|
||||
module.StartServiceWorker("ip info delayed cache writer", 0, ipInfoDatabase.DelayedCacheWriter)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -12,10 +12,24 @@ import (
|
|||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
const (
|
||||
// databaseOvertime defines how much longer than the TTL name records are
|
||||
// cached in the database.
|
||||
databaseOvertime = 86400 * 14 // two weeks
|
||||
)
|
||||
|
||||
var (
|
||||
recordDatabase = database.NewInterface(&database.Options{
|
||||
AlwaysSetRelativateExpiry: 2592000, // 30 days
|
||||
CacheSize: 256,
|
||||
Local: true,
|
||||
Internal: true,
|
||||
|
||||
// Cache entries because application often resolve domains multiple times.
|
||||
CacheSize: 256,
|
||||
|
||||
// We only use the cache database here, so we can delay and batch all our
|
||||
// writes. Also, no one else accesses these records, so we are fine using
|
||||
// this.
|
||||
DelayCachedWrites: "cache",
|
||||
})
|
||||
|
||||
nameRecordsKeyPrefix = "cache:intel/nameRecord/"
|
||||
|
@ -32,7 +46,7 @@ type NameRecord struct {
|
|||
Answer []string
|
||||
Ns []string
|
||||
Extra []string
|
||||
TTL int64
|
||||
Expires int64
|
||||
|
||||
Server string
|
||||
ServerScope int8
|
||||
|
@ -84,6 +98,9 @@ func (rec *NameRecord) Save() error {
|
|||
}
|
||||
|
||||
rec.SetKey(makeNameRecordKey(rec.Domain, rec.Question))
|
||||
rec.UpdateMeta()
|
||||
rec.Meta().SetAbsoluteExpiry(rec.Expires + databaseOvertime)
|
||||
|
||||
return recordDatabase.PutNew(rec)
|
||||
}
|
||||
|
||||
|
|
|
@ -220,19 +220,19 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
|
|||
log.Tracer(ctx).Tracef(
|
||||
"resolver: cache for %s will expire in %s, refreshing async now",
|
||||
q.ID(),
|
||||
time.Until(time.Unix(rrCache.TTL, 0)).Round(time.Second),
|
||||
time.Until(time.Unix(rrCache.Expires, 0)).Round(time.Second),
|
||||
)
|
||||
|
||||
// resolve async
|
||||
module.StartWorker("resolve async", func(ctx context.Context) error {
|
||||
ctx, tracer := log.AddTracer(ctx)
|
||||
module.StartWorker("resolve async", func(asyncCtx context.Context) error {
|
||||
tracingCtx, tracer := log.AddTracer(asyncCtx)
|
||||
defer tracer.Submit()
|
||||
tracer.Debugf("resolver: resolving %s async", q.ID())
|
||||
_, err := resolveAndCache(ctx, q, nil)
|
||||
tracer.Tracef("resolver: resolving %s async", q.ID())
|
||||
_, err := resolveAndCache(tracingCtx, q, nil)
|
||||
if err != nil {
|
||||
tracer.Warningf("resolver: async query for %s failed: %s", q.ID(), err)
|
||||
} else {
|
||||
tracer.Debugf("resolver: async query for %s succeeded", q.ID())
|
||||
tracer.Infof("resolver: async query for %s succeeded", q.ID())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
@ -242,7 +242,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
|
|||
|
||||
log.Tracer(ctx).Tracef(
|
||||
"resolver: using cached RR (expires in %s)",
|
||||
time.Until(time.Unix(rrCache.TTL, 0)).Round(time.Second),
|
||||
time.Until(time.Unix(rrCache.Expires, 0)).Round(time.Second),
|
||||
)
|
||||
return rrCache
|
||||
}
|
||||
|
|
|
@ -25,10 +25,10 @@ type RRCache struct {
|
|||
RCode int
|
||||
|
||||
// Response Content
|
||||
Answer []dns.RR
|
||||
Ns []dns.RR
|
||||
Extra []dns.RR
|
||||
TTL int64
|
||||
Answer []dns.RR
|
||||
Ns []dns.RR
|
||||
Extra []dns.RR
|
||||
Expires int64
|
||||
|
||||
// Source Information
|
||||
Server string
|
||||
|
@ -54,12 +54,12 @@ func (rrCache *RRCache) ID() string {
|
|||
|
||||
// Expired returns whether the record has expired.
|
||||
func (rrCache *RRCache) Expired() bool {
|
||||
return rrCache.TTL <= time.Now().Unix()
|
||||
return rrCache.Expires <= time.Now().Unix()
|
||||
}
|
||||
|
||||
// ExpiresSoon returns whether the record will expire soon and should already be refreshed.
|
||||
func (rrCache *RRCache) ExpiresSoon() bool {
|
||||
return rrCache.TTL <= time.Now().Unix()+refreshTTL
|
||||
return rrCache.Expires <= time.Now().Unix()+refreshTTL
|
||||
}
|
||||
|
||||
// Clean sets all TTLs to 17 and sets cache expiry with specified minimum.
|
||||
|
@ -99,7 +99,7 @@ func (rrCache *RRCache) Clean(minExpires uint32) {
|
|||
}
|
||||
|
||||
// log.Tracef("lowest TTL is %d", lowestTTL)
|
||||
rrCache.TTL = time.Now().Unix() + int64(lowestTTL)
|
||||
rrCache.Expires = time.Now().Unix() + int64(lowestTTL)
|
||||
}
|
||||
|
||||
// ExportAllARecords return of a list of all A and AAAA IP addresses.
|
||||
|
@ -131,7 +131,7 @@ func (rrCache *RRCache) ToNameRecord() *NameRecord {
|
|||
Domain: rrCache.Domain,
|
||||
Question: rrCache.Question.String(),
|
||||
RCode: rrCache.RCode,
|
||||
TTL: rrCache.TTL,
|
||||
Expires: rrCache.Expires,
|
||||
Server: rrCache.Server,
|
||||
ServerScope: rrCache.ServerScope,
|
||||
ServerInfo: rrCache.ServerInfo,
|
||||
|
@ -188,7 +188,7 @@ func GetRRCache(domain string, question dns.Type) (*RRCache, error) {
|
|||
}
|
||||
|
||||
rrCache.RCode = nameRecord.RCode
|
||||
rrCache.TTL = nameRecord.TTL
|
||||
rrCache.Expires = nameRecord.Expires
|
||||
for _, entry := range nameRecord.Answer {
|
||||
rrCache.Answer = parseRR(rrCache.Answer, entry)
|
||||
}
|
||||
|
@ -249,10 +249,10 @@ func (rrCache *RRCache) ShallowCopy() *RRCache {
|
|||
Question: rrCache.Question,
|
||||
RCode: rrCache.RCode,
|
||||
|
||||
Answer: rrCache.Answer,
|
||||
Ns: rrCache.Ns,
|
||||
Extra: rrCache.Extra,
|
||||
TTL: rrCache.TTL,
|
||||
Answer: rrCache.Answer,
|
||||
Ns: rrCache.Ns,
|
||||
Extra: rrCache.Extra,
|
||||
Expires: rrCache.Expires,
|
||||
|
||||
Server: rrCache.Server,
|
||||
ServerScope: rrCache.ServerScope,
|
||||
|
@ -310,9 +310,9 @@ func (rrCache *RRCache) GetExtraRRs(ctx context.Context, query *dns.Msg) (extra
|
|||
|
||||
// Add expiry and cache information.
|
||||
if rrCache.Expired() {
|
||||
extra = addExtra(ctx, extra, fmt.Sprintf("record expired since %s", time.Since(time.Unix(rrCache.TTL, 0)).Round(time.Second)))
|
||||
extra = addExtra(ctx, extra, fmt.Sprintf("record expired since %s", time.Since(time.Unix(rrCache.Expires, 0)).Round(time.Second)))
|
||||
} else {
|
||||
extra = addExtra(ctx, extra, fmt.Sprintf("record valid for %s", time.Until(time.Unix(rrCache.TTL, 0)).Round(time.Second)))
|
||||
extra = addExtra(ctx, extra, fmt.Sprintf("record valid for %s", time.Until(time.Unix(rrCache.Expires, 0)).Round(time.Second)))
|
||||
}
|
||||
if rrCache.RequestingNew {
|
||||
extra = addExtra(ctx, extra, "async request to refresh the cache has been started")
|
||||
|
|
Loading…
Add table
Reference in a new issue