Fix endpoint scope

This commit is contained in:
Daniel 2020-05-20 14:53:14 +02:00
parent f1765a7abb
commit c48f8e5782
4 changed files with 19 additions and 19 deletions

View file

@ -29,10 +29,6 @@ type EndpointScope struct {
scopes uint8
}
// Localhost
// LAN
// Internet
// Matches checks whether the given entity matches this endpoint definition.
func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) {
if entity.IP == nil {
@ -64,16 +60,14 @@ func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) {
// Scopes returns the string representation of all scopes.
func (ep *EndpointScope) Scopes() string {
if ep.scopes == 3 || ep.scopes > 4 {
// single scope
switch ep.scopes {
case scopeLocalhost:
return scopeLocalhostName
case scopeLAN:
return scopeLANName
case scopeInternet:
return scopeInternetName
}
// single scope
switch ep.scopes {
case scopeLocalhost:
return scopeLocalhostName
case scopeLAN:
return scopeLANName
case scopeInternet:
return scopeInternetName
}
// multiple scopes
@ -99,11 +93,11 @@ func parseTypeScope(fields []string) (Endpoint, error) {
for _, val := range strings.Split(strings.ToLower(fields[1]), ",") {
switch val {
case scopeLocalhostMatcher:
ep.scopes &= scopeLocalhost
ep.scopes ^= scopeLocalhost
case scopeLANMatcher:
ep.scopes &= scopeLAN
ep.scopes ^= scopeLAN
case scopeInternetMatcher:
ep.scopes &= scopeInternet
ep.scopes ^= scopeInternet
default:
return nil, nil
}

View file

@ -201,7 +201,7 @@ func invalidDefinitionError(fields []string, msg string) error {
return fmt.Errorf(`invalid endpoint definition: "%s" - %s`, strings.Join(fields, " "), msg)
}
func parseEndpoint(value string) (endpoint Endpoint, err error) {
func parseEndpoint(value string) (endpoint Endpoint, err error) { //nolint:gocognit
fields := strings.Fields(value)
if len(fields) < 2 {
return nil, fmt.Errorf(`invalid endpoint definition: "%s"`, value)

View file

@ -43,6 +43,12 @@ func TestEndpointParsing(t *testing.T) {
testParsing(t, "+ AS1234")
testParsing(t, "+ AS12345")
// network scope
testParsing(t, "+ Localhost")
testParsing(t, "+ LAN")
testParsing(t, "+ Internet")
testParsing(t, "+ Localhost,LAN,Internet")
// protocol and ports
testParsing(t, "+ * TCP/1-1024")
testParsing(t, "+ * */DNS")

View file

@ -358,7 +358,7 @@ func TestEndpointMatching(t *testing.T) {
// Lists
ep, err = parseEndpoint("+ L:A,B,C")
_, err = parseEndpoint("+ L:A,B,C")
if err != nil {
t.Fatal(err)
}