Add CNAME blocking support

This commit is contained in:
Patrick Pacher 2020-04-17 15:55:52 +02:00
parent 1f90c05654
commit bffe4a9eaf
No known key found for this signature in database
GPG key ID: E8CD2DA160925A6D
9 changed files with 445 additions and 136 deletions

View file

@ -32,14 +32,32 @@ type Entity struct {
asnListLoaded bool
reverseResolveEnabled bool
resolveSubDomainLists bool
checkCNAMEs bool
// Protocol is the protcol number used by the connection.
Protocol uint8
Port uint16
Domain string
IP net.IP
Country string
ASN uint
// Port is the destination port of the connection
Port uint16
// Domain is the target domain of the connection.
Domain string
// CNAME is a list of domain names that have been
// resolved for Domain.
CNAME []string
// IP is the IP address of the connection. If domain is
// set, IP has been resolved by following all CNAMEs.
IP net.IP
// Country holds the country the IP address (ASN) is
// located in.
Country string
// ASN holds the autonomous system number of the IP.
ASN uint
location *geoip.Location
Lists []string
@ -79,6 +97,7 @@ func (e *Entity) ResetLists() {
e.countryListLoaded = false
e.asnListLoaded = false
e.resolveSubDomainLists = false
e.checkCNAMEs = false
e.loadDomainListOnce = sync.Once{}
e.loadIPListOnce = sync.Once{}
e.loadCoutryListOnce = sync.Once{}
@ -94,6 +113,21 @@ func (e *Entity) ResolveSubDomainLists(enabled bool) {
e.resolveSubDomainLists = enabled
}
// EnableCNAMECheck enalbes or disables list lookups for
// entity CNAMEs.
func (e *Entity) EnableCNAMECheck(enabled bool) {
if e.domainListLoaded {
log.Warningf("intel/filterlists: tried to change CNAME resolving for %s but lists are already fetched", e.Domain)
}
e.checkCNAMEs = enabled
}
// CNAMECheckEnabled returns true if the entities CNAMEs should
// also be checked.
func (e *Entity) CNAMECheckEnabled() bool {
return e.checkCNAMEs
}
// Domain and IP
// EnableReverseResolving enables reverse resolving the domain from the IP on demand.
@ -220,10 +254,23 @@ func (e *Entity) getDomainLists() {
}
e.loadDomainListOnce.Do(func() {
var domains = []string{domain}
var domainsToInspect = []string{domain}
if e.checkCNAMEs {
log.Tracef("intel: CNAME filtering enabled, checking %v too", e.CNAME)
domainsToInspect = append(domainsToInspect, e.CNAME...)
}
var domains []string
if e.resolveSubDomainLists {
domains = splitDomain(domain)
log.Tracef("intel: subdomain list resolving is enabled, checking %v", domains)
for _, domain := range domainsToInspect {
subdomains := splitDomain(domain)
domains = append(domains, subdomains...)
log.Tracef("intel: subdomain list resolving is enabled: %s => %v", domains, subdomains)
}
} else {
domains = domainsToInspect
}
for _, d := range domains {

View file

@ -2,9 +2,11 @@ package nameserver
import (
"context"
"fmt"
"net"
"strings"
"github.com/safing/portbase/database"
"github.com/safing/portbase/modules/subsystems"
"github.com/safing/portbase/log"
@ -14,6 +16,7 @@ import (
"github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/profile/endpoints"
"github.com/safing/portmaster/resolver"
"github.com/miekg/dns"
@ -87,9 +90,11 @@ func stop() error {
return nil
}
func returnNXDomain(w dns.ResponseWriter, query *dns.Msg) {
func returnNXDomain(w dns.ResponseWriter, query *dns.Msg, reason string) {
m := new(dns.Msg)
m.SetRcode(query, dns.RcodeNameError)
rr, _ := dns.NewRR("portmaster.block.reason. 0 IN TXT " + fmt.Sprintf("%q", reason))
m.Extra = []dns.RR{rr}
_ = w.WriteMsg(m)
}
@ -126,7 +131,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
if question.Qclass != dns.ClassINET {
// we only serve IN records, return nxdomain
log.Warningf("nameserver: only IN record requests are supported but received Qclass %d, returning NXDOMAIN", question.Qclass)
returnNXDomain(w, query)
returnNXDomain(w, query, "wrong type")
return nil
}
@ -166,7 +171,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
// check if valid domain name
if !netutils.IsValidFqdn(q.FQDN) {
log.Debugf("nameserver: domain name %s is invalid, returning nxdomain", q.FQDN)
returnNXDomain(w, query)
returnNXDomain(w, query, "invalid domain")
return nil
}
@ -177,7 +182,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
// TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain
// get connection
conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, remoteAddr.IP, uint16(remoteAddr.Port))
conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP, uint16(remoteAddr.Port))
// once we decided on the connection we might need to save it to the database
// so we defer that check right now.
@ -202,7 +207,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
// TODO: this has been obsoleted due to special profiles
if conn.Process().Profile() == nil {
tracer.Infof("nameserver: failed to find process for request %s, returning NXDOMAIN", conn)
returnNXDomain(w, query)
returnNXDomain(w, query, "unknown process")
// NOTE(ppacher): saving unknown process connection might end up in a lot of
// processes. Consider disabling that via config.
conn.Failed("Unknown process")
@ -218,7 +223,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
// log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms)
if lms < 10 {
tracer.Warningf("nameserver: possible data tunnel by %s: %s has lms score of %f, returning nxdomain", conn.Process(), q.FQDN, lms)
returnNXDomain(w, query)
returnNXDomain(w, query, "lms")
conn.Block("Possible data tunnel")
return nil
}
@ -229,7 +234,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
switch conn.Verdict {
case network.VerdictBlock:
tracer.Infof("nameserver: %s blocked, returning nxdomain", conn)
returnNXDomain(w, query)
returnNXDomain(w, query, conn.Reason)
return nil
case network.VerdictDrop, network.VerdictFailed:
tracer.Infof("nameserver: %s dropped, not replying", conn)
@ -241,7 +246,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
if err != nil {
// TODO: analyze nxdomain requests, malware could be trying DGA-domains
tracer.Warningf("nameserver: %s requested %s%s: %s", conn.Process(), q.FQDN, q.QType, err)
returnNXDomain(w, query)
returnNXDomain(w, query, conn.Reason)
conn.Failed("failed to resolve: " + err.Error())
return nil
}
@ -251,41 +256,92 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
// TODO: FilterDNSResponse also sets a connection verdict
if rrCache == nil {
tracer.Infof("nameserver: %s implicitly denied by filtering the dns response, returning nxdomain", conn)
returnNXDomain(w, query)
returnNXDomain(w, query, conn.Reason)
conn.Block("DNS response filtered")
return nil
}
// save IP addresses to IPInfo
cnames := make(map[string]string)
ips := make(map[string]struct{})
for _, rr := range append(rrCache.Answer, rrCache.Extra...) {
switch v := rr.(type) {
case *dns.CNAME:
cnames[v.Hdr.Name] = v.Target
case *dns.A:
ipInfo, err := resolver.GetIPInfo(v.A.String())
if err != nil {
ipInfo = &resolver.IPInfo{
IP: v.A.String(),
Domains: []string{q.FQDN},
}
_ = ipInfo.Save()
} else {
added := ipInfo.AddDomain(q.FQDN)
if added {
_ = ipInfo.Save()
}
}
ips[v.A.String()] = struct{}{}
case *dns.AAAA:
ipInfo, err := resolver.GetIPInfo(v.AAAA.String())
if err != nil {
ipInfo = &resolver.IPInfo{
IP: v.AAAA.String(),
Domains: []string{q.FQDN},
}
_ = ipInfo.Save()
} else {
added := ipInfo.AddDomain(q.FQDN)
if added {
_ = ipInfo.Save()
}
ips[v.AAAA.String()] = struct{}{}
}
}
for ip := range ips {
record := resolver.ResolvedDomain{
Domain: q.FQDN,
}
// resolve all CNAMEs in the correct order.
var domain = q.FQDN
for {
nextDomain, isCNAME := cnames[domain]
if !isCNAME {
break
}
record.CNAMEs = append(record.CNAMEs, nextDomain)
domain = nextDomain
}
// update the entity to include the cnames
conn.Entity.CNAME = record.CNAMEs
// get the existing IP info or create a new one
var save bool
info, err := resolver.GetIPInfo(ip)
if err != nil {
if err != database.ErrNotFound {
log.Errorf("nameserver: failed to search for IP info record: %s", err)
}
info = &resolver.IPInfo{
IP: ip,
}
save = true
}
// and the new resolved domain record and save
if new := info.AddDomain(record); new {
save = true
}
if save {
if err := info.Save(); err != nil {
log.Errorf("nameserver: failed to save IP info record: %s", err)
}
}
}
// if we have CNAMEs and the profile is configured to filter them
// we need to re-check the lists and endpoints here
if conn.Process().Profile().FilterCNAMEs() {
conn.Entity.ResetLists()
conn.Entity.EnableCNAMECheck(true)
result, reason := conn.Process().Profile().MatchEndpoint(conn.Entity)
if result == endpoints.Denied {
conn.Block("endpoint in blocklist: " + reason)
returnNXDomain(w, query, conn.Reason)
return nil
}
if result == endpoints.NoMatch {
result, reason = conn.Process().Profile().MatchFilterLists(conn.Entity)
if result == endpoints.Denied {
conn.Block("endpoint in filterlists: " + reason)
returnNXDomain(w, query, conn.Reason)
return nil
}
}
}

View file

@ -5,6 +5,7 @@ import (
"net"
"strings"
"github.com/safing/portbase/database"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/netenv"
@ -164,35 +165,60 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
}
// save IP addresses to IPInfo
cnames := make(map[string]string)
ips := make(map[string]struct{})
for _, rr := range append(rrCache.Answer, rrCache.Extra...) {
switch v := rr.(type) {
case *dns.CNAME:
cnames[v.Hdr.Name] = v.Target
case *dns.A:
ipInfo, err := resolver.GetIPInfo(v.A.String())
if err != nil {
ipInfo = &resolver.IPInfo{
IP: v.A.String(),
Domains: []string{q.FQDN},
}
_ = ipInfo.Save()
} else {
added := ipInfo.AddDomain(q.FQDN)
if added {
_ = ipInfo.Save()
}
}
ips[v.A.String()] = struct{}{}
case *dns.AAAA:
ipInfo, err := resolver.GetIPInfo(v.AAAA.String())
if err != nil {
ipInfo = &resolver.IPInfo{
IP: v.AAAA.String(),
Domains: []string{q.FQDN},
}
_ = ipInfo.Save()
} else {
added := ipInfo.AddDomain(q.FQDN)
if added {
_ = ipInfo.Save()
}
ips[v.AAAA.String()] = struct{}{}
}
}
for ip := range ips {
record := resolver.ResolvedDomain{
Domain: q.FQDN,
}
// resolve all CNAMEs in the correct order.
var domain = q.FQDN
for {
nextDomain, isCNAME := cnames[domain]
if !isCNAME {
break
}
record.CNAMEs = append(record.CNAMEs, nextDomain)
domain = nextDomain
}
// get the existing IP info or create a new one
var save bool
info, err := resolver.GetIPInfo(ip)
if err != nil {
if err != database.ErrNotFound {
log.Errorf("nameserver: failed to search for IP info record: %s", err)
}
info = &resolver.IPInfo{
IP: ip,
}
save = true
}
// and the new resolved domain record and save
if new := info.AddDomain(record); new {
save = true
}
if save {
if err := info.Save(); err != nil {
log.Errorf("nameserver: failed to save IP info record: %s", err)
}
}
}

View file

@ -54,9 +54,9 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
}
// NewConnectionFromDNSRequest returns a new connection based on the given dns request.
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, ip net.IP, port uint16) *Connection {
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, localIP net.IP, localPort uint16) *Connection {
// get Process
proc, err := process.GetProcessByEndpoints(ctx, ip, port, dnsAddress, dnsPort, packet.UDP)
proc, err := process.GetProcessByEndpoints(ctx, localIP, localPort, dnsAddress, dnsPort, packet.UDP)
if err != nil {
log.Warningf("network: failed to find process of dns request for %s: %s", fqdn, err)
proc = process.GetUnidentifiedProcess(ctx)
@ -67,7 +67,8 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, ip net.IP, po
Scope: fqdn,
Entity: (&intel.Entity{
Domain: fqdn,
}).Init(),
CNAME: cnames,
}),
process: proc,
Started: timestamp,
Ended: timestamp,
@ -104,7 +105,7 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
IP: pkt.Info().Src,
Protocol: uint8(pkt.Info().Protocol),
Port: pkt.Info().SrcPort,
}).Init()
})
} else {
@ -113,18 +114,21 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
IP: pkt.Info().Dst,
Protocol: uint8(pkt.Info().Protocol),
Port: pkt.Info().DstPort,
}).Init()
})
// check if we can find a domain for that IP
ipinfo, err := resolver.GetIPInfo(pkt.Info().Dst.String())
if err == nil {
lastResolvedDomain := ipinfo.ResolvedDomains.MostRecentDomain()
if lastResolvedDomain != nil {
scope = lastResolvedDomain.Domain
entity.Domain = lastResolvedDomain.Domain
entity.CNAME = lastResolvedDomain.CNAMEs
removeOpenDNSRequest(proc.Pid, lastResolvedDomain.Domain)
}
}
// outbound to domain
scope = ipinfo.Domains[0]
entity.Domain = scope
removeOpenDNSRequest(proc.Pid, scope)
} else {
if scope == "" {
// outbound direct (possibly P2P) connection
switch netutils.ClassifyIP(pkt.Info().Dst) {

View file

@ -30,6 +30,9 @@ var (
CfgOptionFilterSubDomainsKey = "filter/includeSubdomains"
cfgOptionFilterSubDomains config.IntOption // security level option
CfgOptionFilterCNAMEKey = "filter/includeCNAMEs"
cfgOptionFilterCNAME config.IntOption // security level option
CfgOptionBlockScopeLocalKey = "filter/blockLocal"
cfgOptionBlockScopeLocal config.IntOption // security level option
@ -180,6 +183,23 @@ Examples:
cfgOptionFilterLists = config.Concurrent.GetAsStringArray(CfgOptionFilterListKey, []string{})
cfgStringArrayOptions[CfgOptionFilterListKey] = cfgOptionFilterLists
// Include CNAMEs
err = config.Register(&config.Option{
Name: "Filter CNAMEs",
Key: CfgOptionFilterCNAMEKey,
Description: "Also filter requests where a CNAME would be blocked",
OptType: config.OptTypeInt,
ExternalOptType: "security level",
DefaultValue: status.SecurityLevelsAll,
ValidationRegex: "^(7|6|4)$",
})
if err != nil {
return err
}
cfgOptionFilterCNAME = config.Concurrent.GetAsInt(CfgOptionFilterCNAMEKey, int64(status.SecurityLevelsAll))
cfgIntOptions[CfgOptionFilterCNAMEKey] = cfgOptionFilterCNAME
// Include subdomains
err = config.Register(&config.Option{
Name: "Filter SubDomains",
Key: CfgOptionFilterSubDomainsKey,

View file

@ -31,35 +31,77 @@ type EndpointDomain struct {
Reason string
}
func (ep *EndpointDomain) check(entity *intel.Entity, domain string) (EPResult, string) {
switch ep.MatchType {
case domainMatchTypeExact:
if domain == ep.Domain {
return ep.matchesPPP(entity), ep.Reason
}
case domainMatchTypeZone:
if domain == ep.Domain {
return ep.matchesPPP(entity), ep.Reason
}
if strings.HasSuffix(domain, ep.DomainZone) {
return ep.matchesPPP(entity), ep.Reason
}
case domainMatchTypeSuffix:
if strings.HasSuffix(domain, ep.Domain) {
return ep.matchesPPP(entity), ep.Reason
}
case domainMatchTypePrefix:
if strings.HasPrefix(domain, ep.Domain) {
return ep.matchesPPP(entity), ep.Reason
}
case domainMatchTypeContains:
if strings.Contains(domain, ep.Domain) {
return ep.matchesPPP(entity), ep.Reason
}
}
return NoMatch, ""
}
// Matches checks whether the given entity matches this endpoint definition.
func (ep *EndpointDomain) Matches(entity *intel.Entity) (result EPResult, reason string) {
if entity.Domain == "" {
return NoMatch, ""
}
switch ep.MatchType {
case domainMatchTypeExact:
if entity.Domain == ep.Domain {
return ep.matchesPPP(entity), ep.Reason
}
case domainMatchTypeZone:
if entity.Domain == ep.Domain {
return ep.matchesPPP(entity), ep.Reason
}
if strings.HasSuffix(entity.Domain, ep.DomainZone) {
return ep.matchesPPP(entity), ep.Reason
}
case domainMatchTypeSuffix:
if strings.HasSuffix(entity.Domain, ep.Domain) {
return ep.matchesPPP(entity), ep.Reason
}
case domainMatchTypePrefix:
if strings.HasPrefix(entity.Domain, ep.Domain) {
return ep.matchesPPP(entity), ep.Reason
}
case domainMatchTypeContains:
if strings.Contains(entity.Domain, ep.Domain) {
return ep.matchesPPP(entity), ep.Reason
result, reason = ep.check(entity, entity.Domain)
if result != NoMatch {
return
}
if entity.CNAMECheckEnabled() {
for _, domain := range entity.CNAME {
switch ep.MatchType {
case domainMatchTypeExact:
if domain == ep.Domain {
result, reason = ep.matchesPPP(entity), ep.Reason
}
case domainMatchTypeZone:
if domain == ep.Domain {
result, reason = ep.matchesPPP(entity), ep.Reason
}
if strings.HasSuffix(domain, ep.DomainZone) {
result, reason = ep.matchesPPP(entity), ep.Reason
}
case domainMatchTypeSuffix:
if strings.HasSuffix(domain, ep.Domain) {
result, reason = ep.matchesPPP(entity), ep.Reason
}
case domainMatchTypePrefix:
if strings.HasPrefix(domain, ep.Domain) {
result, reason = ep.matchesPPP(entity), ep.Reason
}
case domainMatchTypeContains:
if strings.Contains(domain, ep.Domain) {
result, reason = ep.matchesPPP(entity), ep.Reason
}
}
if result == Denied {
return result, reason
}
}
}

View file

@ -43,6 +43,7 @@ type LayeredProfile struct {
RemoveOutOfScopeDNS config.BoolOption
RemoveBlockedDNS config.BoolOption
FilterSubDomains config.BoolOption
FilterCNAMEs config.BoolOption
PreventBypassing config.BoolOption
}
@ -99,6 +100,10 @@ func NewLayeredProfile(localProfile *Profile) *LayeredProfile {
CfgOptionFilterSubDomainsKey,
cfgOptionFilterSubDomains,
)
new.FilterCNAMEs = new.wrapSecurityLevelOption(
CfgOptionFilterCNAMEKey,
cfgOptionFilterCNAME,
)
new.PreventBypassing = new.wrapSecurityLevelOption(
CfgOptionPreventBypassingKey,
cfgOptionPreventBypassing,
@ -236,6 +241,7 @@ func (lp *LayeredProfile) MatchServiceEndpoint(entity *intel.Entity) (result end
// lists.
func (lp *LayeredProfile) MatchFilterLists(entity *intel.Entity) (endpoints.EPResult, string) {
entity.ResolveSubDomainLists(lp.FilterSubDomains())
entity.EnableCNAMECheck(lp.FilterCNAMEs())
lookupMap, hasLists := entity.GetListsMap()
if !hasLists {

View file

@ -16,13 +16,92 @@ var (
})
)
// ResolvedDomain holds a Domain name and a list of
// CNAMES that have been resolved.
type ResolvedDomain struct {
// Domain is the domain as requested by the application.
Domain string
// CNAMEs is a list of CNAMEs that have been resolved for
// Domain.
CNAMEs []string
}
// String returns a string representation of ResolvedDomain including
// the CNAME chain. It implements fmt.Stringer
func (resolved *ResolvedDomain) String() string {
ret := resolved.Domain
cnames := ""
if len(resolved.CNAMEs) > 0 {
cnames = " (-> " + strings.Join(resolved.CNAMEs, "->") + ")"
}
return ret + cnames
}
// ResolvedDomains is a helper type for operating on a slice
// of ResolvedDomain
type ResolvedDomains []ResolvedDomain
// String returns a string representation of all domains joined
// to a single string.
func (rds ResolvedDomains) String() string {
var domains []string
for _, n := range rds {
domains = append(domains, n.String())
}
return strings.Join(domains, " or ")
}
// MostRecentDomain returns the most recent domain.
func (rds ResolvedDomains) MostRecentDomain() *ResolvedDomain {
if len(rds) == 0 {
return nil
}
// TODO(ppacher): we could also do that by using ResolvedAt()
mostRecent := rds[len(rds)-1]
return &mostRecent
}
// IPInfo represents various information about an IP.
type IPInfo struct {
record.Base
sync.Mutex
IP string
Domains []string
// IP holds the acutal IP address.
IP string
// Domains holds a list of domains that have been
// resolved to IP. This field is deprecated and should
// be removed.
// DEPRECATED: remove with alpha.
Domains []string `json:"Domains,omitempty"`
// ResolvedDomain is a slice of domains that
// have been requested by various applications
// and have been resolved to IP.
ResolvedDomains ResolvedDomains
}
// AddDomain adds a new resolved domain to ipi.
func (ipi *IPInfo) AddDomain(resolved ResolvedDomain) bool {
for idx, d := range ipi.ResolvedDomains {
if d.Domain == resolved.Domain {
if utils.StringSliceEqual(d.CNAMEs, resolved.CNAMEs) {
return false
}
// we have a different CNAME chain now, remove the previous
// entry and add it at the end.
ipi.ResolvedDomains = append(ipi.ResolvedDomains[:idx], ipi.ResolvedDomains[idx+1:]...)
ipi.ResolvedDomains = append(ipi.ResolvedDomains, resolved)
return true
}
}
ipi.ResolvedDomains = append(ipi.ResolvedDomains, resolved)
return true
}
func makeIPInfoKey(ip string) string {
@ -46,6 +125,19 @@ func GetIPInfo(ip string) (*IPInfo, error) {
if err != nil {
return nil, err
}
// Legacy support,
// DEPRECATED: remove with alpha
if len(new.Domains) > 0 && len(new.ResolvedDomains) == 0 {
for _, d := range new.Domains {
new.ResolvedDomains = append(new.ResolvedDomains, ResolvedDomain{
Domain: d,
// rest is empty...
})
}
new.Domains = nil // clean up so we remove it from the database
}
return new, nil
}
@ -57,17 +149,6 @@ func GetIPInfo(ip string) (*IPInfo, error) {
return new, nil
}
// AddDomain adds a domain to the list and reports back if it was added, or was already present.
func (ipi *IPInfo) AddDomain(domain string) (added bool) {
ipi.Lock()
defer ipi.Unlock()
if !utils.StringInSlice(ipi.Domains, domain) {
ipi.Domains = append([]string{domain}, ipi.Domains...)
return true
}
return false
}
// Save saves the IPInfo record to the database.
func (ipi *IPInfo) Save() error {
ipi.Lock()
@ -75,17 +156,21 @@ func (ipi *IPInfo) Save() error {
ipi.SetKey(makeIPInfoKey(ipi.IP))
}
ipi.Unlock()
return ipInfoDatabase.Put(ipi)
}
// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or "
func (ipi *IPInfo) FmtDomains() string {
return strings.Join(ipi.Domains, " or ")
// Legacy support
// Ensure we don't write new Domain fields into the
// database.
// DEPRECATED: remove with alpha
if len(ipi.Domains) > 0 {
ipi.Domains = nil
}
return ipInfoDatabase.Put(ipi)
}
// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or "
func (ipi *IPInfo) String() string {
ipi.Lock()
defer ipi.Unlock()
return fmt.Sprintf("<IPInfo[%s] %s: %s", ipi.Key(), ipi.IP, ipi.FmtDomains())
return fmt.Sprintf("<IPInfo[%s] %s: %s", ipi.Key(), ipi.IP, ipi.ResolvedDomains.String())
}

View file

@ -1,25 +1,48 @@
package resolver
import "testing"
import (
"testing"
func testDomains(t *testing.T, ipi *IPInfo, expectedDomains string) {
if ipi.FmtDomains() != expectedDomains {
t.Errorf("unexpected domains '%s', expected '%s'", ipi.FmtDomains(), expectedDomains)
}
}
"github.com/stretchr/testify/assert"
)
func TestIPInfo(t *testing.T) {
ipi := &IPInfo{
IP: "1.2.3.4",
Domains: []string{"example.com.", "sub.example.com."},
example := ResolvedDomain{
Domain: "example.com.",
}
subExample := ResolvedDomain{
Domain: "sub1.example.com",
CNAMEs: []string{"example.com"},
}
testDomains(t, ipi, "example.com. or sub.example.com.")
ipi.AddDomain("added.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
ipi.AddDomain("sub.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
ipi.AddDomain("added.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
ipi := &IPInfo{
IP: "1.2.3.4",
ResolvedDomains: ResolvedDomains{
example,
subExample,
},
}
sub2Example := ResolvedDomain{
Domain: "sub2.example.com",
CNAMEs: []string{"sub1.example.com", "example.com"},
}
added := ipi.AddDomain(sub2Example)
assert.True(t, added)
assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, ipi.ResolvedDomains)
// try again, should do nothing now
added = ipi.AddDomain(sub2Example)
assert.False(t, added)
assert.Equal(t, ResolvedDomains{example, subExample, sub2Example}, ipi.ResolvedDomains)
subOverWrite := ResolvedDomain{
Domain: "sub1.example.com",
CNAMEs: []string{}, // now without CNAMEs
}
added = ipi.AddDomain(subOverWrite)
assert.True(t, added)
assert.Equal(t, ResolvedDomains{example, sub2Example, subOverWrite}, ipi.ResolvedDomains)
}