diff --git a/intel/entity.go b/intel/entity.go index 4ca3765e..15a9d9a4 100644 --- a/intel/entity.go +++ b/intel/entity.go @@ -43,6 +43,10 @@ type Entity struct { // Domain is the target domain of the connection. Domain string + // ReverseDomain is the domain the IP address points to. This is only + // resolved and populated when needed. + ReverseDomain string + // CNAME is a list of domain names that have been // resolved for Domain. CNAME []string @@ -150,11 +154,6 @@ func (e *Entity) EnableReverseResolving() { func (e *Entity) reverseResolve(ctx context.Context) { e.reverseResolveOnce.Do(func() { - // check if we should resolve - if !e.reverseResolveEnabled { - return - } - // need IP! if e.IP == nil { return @@ -170,13 +169,20 @@ func (e *Entity) reverseResolve(ctx context.Context) { log.Tracer(ctx).Warningf("intel: failed to resolve IP %s: %s", e.IP, err) return } - e.Domain = domain + e.ReverseDomain = domain }) } // GetDomain returns the domain and whether it is set. -func (e *Entity) GetDomain() (string, bool) { - e.reverseResolve() +func (e *Entity) GetDomain(ctx context.Context, mayUseReverseDomain bool) (string, bool) { + if mayUseReverseDomain && e.reverseResolveEnabled { + e.reverseResolve(ctx) + + if e.ReverseDomain == "" { + return "", false + } + return e.ReverseDomain, true + } if e.Domain == "" { return "", false @@ -268,7 +274,7 @@ func (e *Entity) getDomainLists(ctx context.Context) { return } - domain, ok := e.GetDomain() + domain, ok := e.GetDomain(ctx, false /* mayUseReverseDomain */) if !ok { return } diff --git a/profile/endpoints/endpoint-domain.go b/profile/endpoints/endpoint-domain.go index 6c43a7de..491b038d 100644 --- a/profile/endpoints/endpoint-domain.go +++ b/profile/endpoints/endpoint-domain.go @@ -63,19 +63,20 @@ func (ep *EndpointDomain) check(entity *intel.Entity, domain string) (EPResult, } // Matches checks whether the given entity matches this endpoint definition. -func (ep *EndpointDomain) Matches(entity *intel.Entity) (EPResult, Reason) { - if entity.Domain == "" { +func (ep *EndpointDomain) Matches(ctx context.Context, entity *intel.Entity) (EPResult, Reason) { + domain, ok := entity.GetDomain(ctx, true /* mayUseReverseDomain */) + if !ok { return NoMatch, nil } - result, reason := ep.check(entity, entity.Domain) + result, reason := ep.check(entity, domain) if result != NoMatch { return result, reason } if entity.CNAMECheckEnabled() { - for _, domain := range entity.CNAME { - result, reason = ep.check(entity, domain) + for _, cname := range entity.CNAME { + result, reason = ep.check(entity, cname) if result == Denied { return result, reason }