Fix endpoint scope

This commit is contained in:
Daniel 2020-05-20 14:53:14 +02:00
parent f1765a7abb
commit c48f8e5782
4 changed files with 19 additions and 19 deletions

View file

@ -29,10 +29,6 @@ type EndpointScope struct {
scopes uint8 scopes uint8
} }
// Localhost
// LAN
// Internet
// Matches checks whether the given entity matches this endpoint definition. // Matches checks whether the given entity matches this endpoint definition.
func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) { func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) {
if entity.IP == nil { if entity.IP == nil {
@ -64,16 +60,14 @@ func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) {
// Scopes returns the string representation of all scopes. // Scopes returns the string representation of all scopes.
func (ep *EndpointScope) Scopes() string { func (ep *EndpointScope) Scopes() string {
if ep.scopes == 3 || ep.scopes > 4 { // single scope
// single scope switch ep.scopes {
switch ep.scopes { case scopeLocalhost:
case scopeLocalhost: return scopeLocalhostName
return scopeLocalhostName case scopeLAN:
case scopeLAN: return scopeLANName
return scopeLANName case scopeInternet:
case scopeInternet: return scopeInternetName
return scopeInternetName
}
} }
// multiple scopes // multiple scopes
@ -99,11 +93,11 @@ func parseTypeScope(fields []string) (Endpoint, error) {
for _, val := range strings.Split(strings.ToLower(fields[1]), ",") { for _, val := range strings.Split(strings.ToLower(fields[1]), ",") {
switch val { switch val {
case scopeLocalhostMatcher: case scopeLocalhostMatcher:
ep.scopes &= scopeLocalhost ep.scopes ^= scopeLocalhost
case scopeLANMatcher: case scopeLANMatcher:
ep.scopes &= scopeLAN ep.scopes ^= scopeLAN
case scopeInternetMatcher: case scopeInternetMatcher:
ep.scopes &= scopeInternet ep.scopes ^= scopeInternet
default: default:
return nil, nil return nil, nil
} }

View file

@ -201,7 +201,7 @@ func invalidDefinitionError(fields []string, msg string) error {
return fmt.Errorf(`invalid endpoint definition: "%s" - %s`, strings.Join(fields, " "), msg) return fmt.Errorf(`invalid endpoint definition: "%s" - %s`, strings.Join(fields, " "), msg)
} }
func parseEndpoint(value string) (endpoint Endpoint, err error) { func parseEndpoint(value string) (endpoint Endpoint, err error) { //nolint:gocognit
fields := strings.Fields(value) fields := strings.Fields(value)
if len(fields) < 2 { if len(fields) < 2 {
return nil, fmt.Errorf(`invalid endpoint definition: "%s"`, value) return nil, fmt.Errorf(`invalid endpoint definition: "%s"`, value)

View file

@ -43,6 +43,12 @@ func TestEndpointParsing(t *testing.T) {
testParsing(t, "+ AS1234") testParsing(t, "+ AS1234")
testParsing(t, "+ AS12345") testParsing(t, "+ AS12345")
// network scope
testParsing(t, "+ Localhost")
testParsing(t, "+ LAN")
testParsing(t, "+ Internet")
testParsing(t, "+ Localhost,LAN,Internet")
// protocol and ports // protocol and ports
testParsing(t, "+ * TCP/1-1024") testParsing(t, "+ * TCP/1-1024")
testParsing(t, "+ * */DNS") testParsing(t, "+ * */DNS")

View file

@ -358,7 +358,7 @@ func TestEndpointMatching(t *testing.T) {
// Lists // Lists
ep, err = parseEndpoint("+ L:A,B,C") _, err = parseEndpoint("+ L:A,B,C")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }