package endpoints import ( "fmt" "strconv" "strings" "github.com/safing/portmaster/intel" "github.com/safing/portmaster/network/reference" ) // Endpoint describes an Endpoint Matcher type Endpoint interface { Matches(entity *intel.Entity) (result EPResult, reason string) String() string } // EndpointBase provides general functions for implementing an Endpoint to reduce boilerplate. type EndpointBase struct { //nolint:maligned // TODO Protocol uint8 StartPort uint16 EndPort uint16 Permitted bool } func (ep *EndpointBase) matchesPPP(entity *intel.Entity) (result EPResult) { // only check if protocol is defined if ep.Protocol > 0 { // if protocol is unknown, return Undeterminable if entity.Protocol == 0 { return Undeterminable } // if protocol does not match, return NoMatch if entity.Protocol != ep.Protocol { return NoMatch } } // only check if port is defined if ep.StartPort > 0 { // if port is unknown, return Undeterminable if entity.Port == 0 { return Undeterminable } // if port does not match, return NoMatch if entity.Port < ep.StartPort || entity.Port > ep.EndPort { return NoMatch } } // protocol and port matched or were defined as any if ep.Permitted { return Permitted } return Denied } func (ep *EndpointBase) renderPPP(s string) string { var rendered string if ep.Permitted { rendered = "+ " + s } else { rendered = "- " + s } if ep.Protocol > 0 || ep.StartPort > 0 { if ep.Protocol > 0 { rendered += " " + reference.GetProtocolName(ep.Protocol) } else { rendered += " *" } if ep.StartPort > 0 { if ep.StartPort == ep.EndPort { rendered += "/" + reference.GetPortName(ep.StartPort) } else { rendered += "/" + strconv.Itoa(int(ep.StartPort)) + "-" + strconv.Itoa(int(ep.EndPort)) } } } return rendered } func (ep *EndpointBase) parsePPP(typedEp Endpoint, fields []string) (Endpoint, error) { //nolint:gocognit // TODO switch len(fields) { case 2: // nothing else to do here case 3: // parse protocol and port(s) var ok bool splitted := strings.Split(fields[2], "/") if len(splitted) > 2 { return nil, invalidDefinitionError(fields, "protocol and port must be in format /") } // protocol switch splitted[0] { case "": return nil, invalidDefinitionError(fields, "protocol can't be empty") case "*": // any protocol that supports ports default: n, err := strconv.ParseUint(splitted[0], 10, 8) n8 := uint8(n) if err != nil { // maybe it's a name? n8, ok = reference.GetProtocolNumber(splitted[0]) if !ok { return nil, invalidDefinitionError(fields, "protocol number parsing error") } } ep.Protocol = n8 } // port(s) if len(splitted) > 1 { switch splitted[1] { case "", "*": return nil, invalidDefinitionError(fields, "omit port if should match any") default: portSplitted := strings.Split(splitted[1], "-") if len(portSplitted) > 2 { return nil, invalidDefinitionError(fields, "ports must be in format from-to") } // parse start port n, err := strconv.ParseUint(portSplitted[0], 10, 16) n16 := uint16(n) if err != nil { // maybe it's a name? n16, ok = reference.GetPortNumber(portSplitted[0]) if !ok { return nil, invalidDefinitionError(fields, "port number parsing error") } } ep.StartPort = n16 // parse end port if len(portSplitted) > 1 { n, err = strconv.ParseUint(portSplitted[1], 10, 16) n16 = uint16(n) if err != nil { // maybe it's a name? n16, ok = reference.GetPortNumber(portSplitted[1]) if !ok { return nil, invalidDefinitionError(fields, "port number parsing error") } } } ep.EndPort = n16 } } // check if anything was parsed if ep.Protocol == 0 && ep.StartPort == 0 { return nil, invalidDefinitionError(fields, "omit protocol/port if should match any") } default: return nil, invalidDefinitionError(fields, "there should be only 2 or 3 segments") } switch fields[0] { case "+": ep.Permitted = true case "-": ep.Permitted = false default: return nil, invalidDefinitionError(fields, "invalid permission prefix") } return typedEp, nil } func invalidDefinitionError(fields []string, msg string) error { return fmt.Errorf(`invalid endpoint definition: "%s" - %s`, strings.Join(fields, " "), msg) } func parseEndpoint(value string) (endpoint Endpoint, err error) { fields := strings.Fields(value) if len(fields) < 2 { return nil, fmt.Errorf(`invalid endpoint definition: "%s"`, value) } // any if endpoint, err = parseTypeAny(fields); endpoint != nil || err != nil { return } // ip if endpoint, err = parseTypeIP(fields); endpoint != nil || err != nil { return } // ip range if endpoint, err = parseTypeIPRange(fields); endpoint != nil || err != nil { return } // domain if endpoint, err = parseTypeDomain(fields); endpoint != nil || err != nil { return } // country if endpoint, err = parseTypeCountry(fields); endpoint != nil || err != nil { return } // asn if endpoint, err = parseTypeASN(fields); endpoint != nil || err != nil { return } // lists if endpoint, err = parseTypeList(fields); endpoint != nil || err != nil { return } return nil, fmt.Errorf(`unknown endpoint definition: "%s"`, value) }