diff --git a/firewall/api.go b/firewall/api.go index 4b63d890..b73729c8 100644 --- a/firewall/api.go +++ b/firewall/api.go @@ -33,11 +33,11 @@ func startAPIAuth() { var err error _, apiPort, err = parseHostPort(apiListenAddress()) if err != nil { - log.Warningf("firewall: failed to parse API address for improved api auth mechanism: %s", err) + log.Warningf("filter: failed to parse API address for improved api auth mechanism: %s", err) return } apiPortSet = true - log.Tracef("firewall: api port set to %d", apiPort) + log.Tracef("filter: api port set to %d", apiPort) } func apiAuthenticator(s *http.Server, r *http.Request) (grantAccess bool, err error) { @@ -83,7 +83,7 @@ func apiAuthenticator(s *http.Server, r *http.Request) (grantAccess bool, err er } } - log.Debugf("firewall: denying api access to %s - also checked %s (trusted root is %s)", procsChecked[0], strings.Join(procsChecked[1:], " "), dataRoot.Path) + log.Debugf("filter: denying api access to %s - also checked %s (trusted root is %s)", procsChecked[0], strings.Join(procsChecked[1:], " "), dataRoot.Path) return false, nil } diff --git a/firewall/firewall.go b/firewall/firewall.go index 0b32a9a0..ac3e85d8 100644 --- a/firewall/firewall.go +++ b/firewall/firewall.go @@ -2,7 +2,7 @@ package firewall import ( "context" - "fmt" + "net" "os" "sync/atomic" "time" @@ -34,7 +34,7 @@ var ( // localNet4 *net.IPNet - // localhost4 = net.IPv4(127, 0, 0, 1) + localhost4 = net.IPv4(127, 0, 0, 1) // localhost6 = net.IPv6loopback // tunnelNet4 *net.IPNet @@ -61,6 +61,8 @@ func init() { DefaultValue: true, }, ) + + network.SetDefaultFirewallHandler(defaultHandler) } func prep() (err error) { @@ -78,16 +80,16 @@ func prep() (err error) { // // Yes, this would normally be 127.0.0.0/8 // // TODO: figure out any side effects // if err != nil { - // return fmt.Errorf("firewall: failed to parse cidr 127.0.0.0/24: %s", err) + // return fmt.Errorf("filter: failed to parse cidr 127.0.0.0/24: %s", err) // } // _, tunnelNet4, err = net.ParseCIDR("127.17.0.0/16") // if err != nil { - // return fmt.Errorf("firewall: failed to parse cidr 127.17.0.0/16: %s", err) + // return fmt.Errorf("filter: failed to parse cidr 127.17.0.0/16: %s", err) // } // _, tunnelNet6, err = net.ParseCIDR("fd17::/64") // if err != nil { - // return fmt.Errorf("firewall: failed to parse cidr fd17::/64: %s", err) + // return fmt.Errorf("filter: failed to parse cidr fd17::/64: %s", err) // } var pA uint64 @@ -135,7 +137,10 @@ func handlePacket(pkt packet.Packet) { // } // allow local dns - if (pkt.Info().DstPort == 53 || pkt.Info().SrcPort == 53) && pkt.Info().Src.Equal(pkt.Info().Dst) { + if (pkt.Info().DstPort == 53 || pkt.Info().SrcPort == 53) && + (pkt.Info().Src.Equal(pkt.Info().Dst) || // Windows redirects back to same interface + pkt.Info().Src.Equal(localhost4) || // Linux sometimes does 127.0.0.1->127.0.0.53 + pkt.Info().Dst.Equal(localhost4)) { log.Debugf("accepting local dns: %s", pkt) _ = pkt.PermanentAccept() return @@ -180,11 +185,11 @@ func handlePacket(pkt packet.Packet) { // TODO: Howto handle NetBios? } - // log.Debugf("firewall: pkt %s has ID %s", pkt, pkt.GetLinkID()) + // log.Debugf("filter: pkt %s has ID %s", pkt, pkt.GetLinkID()) // use this to time how long it takes process packet // timed := time.Now() - // defer log.Tracef("firewall: took %s to process packet %s", time.Now().Sub(timed).String(), pkt) + // defer log.Tracef("filter: took %s to process packet %s", time.Now().Sub(timed).String(), pkt) // check if packet is destined for tunnel // switch pkt.IPVersion() { @@ -201,156 +206,110 @@ func handlePacket(pkt packet.Packet) { traceCtx, tracer := log.AddTracer(context.Background()) if tracer != nil { pkt.SetCtx(traceCtx) - tracer.Tracef("firewall: handling packet: %s", pkt) + tracer.Tracef("filter: handling packet: %s", pkt) } // associate packet to link and handle - link, created := network.GetOrCreateLinkByPacket(pkt) - if created { - link.SetFirewallHandler(initialHandler) + conn, ok := network.GetConnection(pkt.GetConnectionID()) + if ok { + tracer.Tracef("filter: assigned to connection %s", conn.ID) + } else { + conn = network.NewConnectionFromFirstPacket(pkt) + tracer.Tracef("filter: created new connection %s", conn.ID) + conn.SetFirewallHandler(initialHandler) } - link.HandlePacket(pkt) + // handle packet + conn.HandlePacket(pkt) } -func initialHandler(pkt packet.Packet, link *network.Link) { - defer func() { - go link.SaveIfNeeded() - }() - - log.Tracer(pkt.Ctx()).Trace("firewall: [initial handler]") +func initialHandler(conn *network.Connection, pkt packet.Packet) { + log.Tracer(pkt.Ctx()).Trace("filter: [initial handler]") // check for internal firewall bypass ps := getPortStatusAndMarkUsed(pkt.Info().LocalPort()) if ps.isMe { - // connect to comms - comm, err := network.GetOwnComm(pkt) - if err != nil { - // log.Warningf("firewall: could not get own comm: %s", err) - log.Tracer(pkt.Ctx()).Warningf("firewall: could not get own comm: %s", err) - } else { - comm.AddLink(link) - } - // approve - link.Accept("internally approved") - log.Tracer(pkt.Ctx()).Tracef("firewall: internally approved link (via local port %d)", pkt.Info().LocalPort()) - + conn.Accept("internally approved") // finish - link.StopFirewallHandler() - issueVerdict(pkt, link, 0, true) + conn.StopFirewallHandler() + issueVerdict(conn, pkt, 0, true) return } - // get Communication - comm, err := network.GetCommunicationByFirstPacket(pkt) - if err != nil { - log.Tracer(pkt.Ctx()).Warningf("firewall: could not get process, denying link: %s", err) - - // get "unknown" comm - link.Deny(fmt.Sprintf("could not get process: %s", err)) - comm, err = network.GetUnknownCommunication(pkt) - - if err != nil { - // all failed - log.Tracer(pkt.Ctx()).Errorf("firewall: could not get unknown comm: %s", err) - link.UpdateVerdict(network.VerdictDrop) - link.StopFirewallHandler() - issueVerdict(pkt, link, 0, true) - return - } - } - - // add new Link to Communication (and save both) - comm.AddLink(link) - log.Tracer(pkt.Ctx()).Tracef("firewall: link attached to %s", comm) - // reroute dns requests to nameserver - if comm.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.Info().DstPort == 53 && !pkt.Info().Src.Equal(pkt.Info().Dst) { - link.UpdateVerdict(network.VerdictRerouteToNameserver) - link.StopFirewallHandler() - issueVerdict(pkt, link, 0, true) + if conn.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.Info().DstPort == 53 && !pkt.Info().Src.Equal(pkt.Info().Dst) { + conn.Verdict = network.VerdictRerouteToNameserver + conn.StopFirewallHandler() + issueVerdict(conn, pkt, 0, true) return } - log.Tracer(pkt.Ctx()).Trace("firewall: starting decision process") - - // TODO: filter lists may have IPs in the future! - DecideOnCommunication(comm) - DecideOnLink(comm, link, pkt) - - // TODO: link this to real status - // gate17Active := mode.Client() + log.Tracer(pkt.Ctx()).Trace("filter: starting decision process") + DecideOnConnection(conn, pkt) + conn.Inspecting = false // TODO: enable inspecting again switch { - // case gate17Active && link.Inspect: - // // tunnel link, but also inspect (after reroute) - // link.Tunneled = true - // link.SetFirewallHandler(inspectThenVerdict) - // verdict(pkt, link.GetVerdict()) - // case gate17Active: - // // tunnel link, don't inspect - // link.Tunneled = true - // link.StopFirewallHandler() - // permanentVerdict(pkt, network.VerdictAccept) - case link.Inspect: - log.Tracer(pkt.Ctx()).Trace("firewall: start inspecting") - link.SetFirewallHandler(inspectThenVerdict) - inspectThenVerdict(pkt, link) + case conn.Inspecting: + log.Tracer(pkt.Ctx()).Trace("filter: start inspecting") + conn.SetFirewallHandler(inspectThenVerdict) + inspectThenVerdict(conn, pkt) default: - link.StopFirewallHandler() - issueVerdict(pkt, link, 0, true) + conn.StopFirewallHandler() + issueVerdict(conn, pkt, 0, true) } } -func inspectThenVerdict(pkt packet.Packet, link *network.Link) { - pktVerdict, continueInspection := inspection.RunInspectors(pkt, link) +func defaultHandler(conn *network.Connection, pkt packet.Packet) { + issueVerdict(conn, pkt, 0, true) +} + +func inspectThenVerdict(conn *network.Connection, pkt packet.Packet) { + pktVerdict, continueInspection := inspection.RunInspectors(conn, pkt) if continueInspection { - issueVerdict(pkt, link, pktVerdict, false) + issueVerdict(conn, pkt, pktVerdict, false) return } // we are done with inspecting - link.StopFirewallHandler() - issueVerdict(pkt, link, 0, true) + conn.StopFirewallHandler() + issueVerdict(conn, pkt, 0, true) } -func issueVerdict(pkt packet.Packet, link *network.Link, verdict network.Verdict, allowPermanent bool) { - link.Lock() - +func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.Verdict, allowPermanent bool) { // enable permanent verdict - if allowPermanent && !link.VerdictPermanent { - link.VerdictPermanent = permanentVerdicts() - if link.VerdictPermanent { - link.SaveWhenFinished() + if allowPermanent && !conn.VerdictPermanent { + conn.VerdictPermanent = permanentVerdicts() + if conn.VerdictPermanent { + conn.SaveWhenFinished() } } - // do not allow to circumvent link decision: e.g. to ACCEPT packets from a DROP-ed link - if verdict < link.Verdict { - verdict = link.Verdict + // do not allow to circumvent decision: e.g. to ACCEPT packets from a DROP-ed connection + if verdict < conn.Verdict { + verdict = conn.Verdict } var err error switch verdict { case network.VerdictAccept: atomic.AddUint64(packetsAccepted, 1) - if link.VerdictPermanent { + if conn.VerdictPermanent { err = pkt.PermanentAccept() } else { err = pkt.Accept() } case network.VerdictBlock: atomic.AddUint64(packetsBlocked, 1) - if link.VerdictPermanent { + if conn.VerdictPermanent { err = pkt.PermanentBlock() } else { err = pkt.Block() } case network.VerdictDrop: atomic.AddUint64(packetsDropped, 1) - if link.VerdictPermanent { + if conn.VerdictPermanent { err = pkt.PermanentDrop() } else { err = pkt.Drop() @@ -364,13 +323,9 @@ func issueVerdict(pkt packet.Packet, link *network.Link, verdict network.Verdict err = pkt.Drop() } - link.Unlock() - if err != nil { - log.Warningf("firewall: failed to apply verdict to pkt %s: %s", pkt, err) + log.Warningf("filter: failed to apply verdict to pkt %s: %s", pkt, err) } - - log.Tracer(pkt.Ctx()).Infof("firewall: %s %s", link.Verdict, link) } // func tunnelHandler(pkt packet.Packet) { @@ -381,7 +336,7 @@ func issueVerdict(pkt packet.Packet, link *network.Link, verdict network.Verdict // } // // entry.CreateTunnel(pkt, tunnelInfo.Domain, tunnelInfo.RRCache.ExportAllARecords()) -// log.Tracef("firewall: rerouting %s to tunnel entry point", pkt) +// log.Tracef("filter: rerouting %s to tunnel entry point", pkt) // pkt.RerouteToTunnel() // return // } @@ -403,7 +358,7 @@ func statLogger() { case <-module.Stopping(): return case <-time.After(10 * time.Second): - log.Tracef("firewall: packets accepted %d, blocked %d, dropped %d", atomic.LoadUint64(packetsAccepted), atomic.LoadUint64(packetsBlocked), atomic.LoadUint64(packetsDropped)) + log.Tracef("filter: packets accepted %d, blocked %d, dropped %d", atomic.LoadUint64(packetsAccepted), atomic.LoadUint64(packetsBlocked), atomic.LoadUint64(packetsDropped)) atomic.StoreUint64(packetsAccepted, 0) atomic.StoreUint64(packetsBlocked, 0) atomic.StoreUint64(packetsDropped, 0) diff --git a/firewall/inspection/inspection.go b/firewall/inspection/inspection.go index 70ef4e06..750eb4ba 100644 --- a/firewall/inspection/inspection.go +++ b/firewall/inspection/inspection.go @@ -12,12 +12,12 @@ const ( DO_NOTHING uint8 = iota BLOCK_PACKET DROP_PACKET - BLOCK_LINK - DROP_LINK + BLOCK_CONN + DROP_CONN STOP_INSPECTING ) -type inspectorFn func(packet.Packet, *network.Link) uint8 +type inspectorFn func(*network.Connection, packet.Packet) uint8 var ( inspectors []inspectorFn @@ -38,20 +38,20 @@ func RegisterInspector(name string, inspector inspectorFn, inspectVerdict networ } // RunInspectors runs all the applicable inspectors on the given packet. -func RunInspectors(pkt packet.Packet, link *network.Link) (network.Verdict, bool) { +func RunInspectors(conn *network.Connection, pkt packet.Packet) (network.Verdict, bool) { // inspectorsLock.Lock() // defer inspectorsLock.Unlock() - activeInspectors := link.GetActiveInspectors() + activeInspectors := conn.GetActiveInspectors() if activeInspectors == nil { activeInspectors = make([]bool, len(inspectors)) - link.SetActiveInspectors(activeInspectors) + conn.SetActiveInspectors(activeInspectors) } - inspectorData := link.GetInspectorData() + inspectorData := conn.GetInspectorData() if inspectorData == nil { inspectorData = make(map[uint8]interface{}) - link.SetInspectorData(inspectorData) + conn.SetInspectorData(inspectorData) } continueInspection := false @@ -62,12 +62,12 @@ func RunInspectors(pkt packet.Packet, link *network.Link) (network.Verdict, bool if skip { continue } - if link.Verdict > inspectVerdicts[key] { + if conn.Verdict > inspectVerdicts[key] { activeInspectors[key] = true continue } - action := inspectors[key](pkt, link) // Actually run inspector + action := inspectors[key](conn, pkt) // Actually run inspector switch action { case DO_NOTHING: if verdict < network.VerdictAccept { @@ -82,16 +82,14 @@ func RunInspectors(pkt packet.Packet, link *network.Link) (network.Verdict, bool case DROP_PACKET: verdict = network.VerdictDrop continueInspection = true - case BLOCK_LINK: - link.UpdateVerdict(network.VerdictBlock) + case BLOCK_CONN: + conn.SetVerdict(network.VerdictBlock) + verdict = conn.Verdict activeInspectors[key] = true - if verdict < network.VerdictBlock { - verdict = network.VerdictBlock - } - case DROP_LINK: - link.UpdateVerdict(network.VerdictDrop) + case DROP_CONN: + conn.SetVerdict(network.VerdictDrop) + verdict = conn.Verdict activeInspectors[key] = true - verdict = network.VerdictDrop case STOP_INSPECTING: activeInspectors[key] = true } diff --git a/firewall/interception/nfqueue/packet.go b/firewall/interception/nfqueue/packet.go index 35ffa091..98b214c3 100644 --- a/firewall/interception/nfqueue/packet.go +++ b/firewall/interception/nfqueue/packet.go @@ -51,7 +51,7 @@ func (pkt *Packet) setVerdict(v uint32) (err error) { }() pkt.verdict <- v close(pkt.verdict) - // log.Tracef("firewall: packet %s verdict %d", pkt, v) + // log.Tracef("filter: packet %s verdict %d", pkt, v) return err } diff --git a/firewall/master.go b/firewall/master.go index 6ab79820..ba33d2db 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -31,91 +31,170 @@ import ( // 4. DecideOnLink // is called when when the first packet of a link arrives only if communication has verdict UNDECIDED or CANTSAY -// DecideOnCommunicationBeforeDNS makes a decision about a communication before the dns query is resolved and intel is gathered. -func DecideOnCommunicationBeforeDNS(comm *network.Communication) { +// DecideOnConnection makes a decision about a connection. +func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { // update profiles and check if communication needs reevaluation - if comm.UpdateAndCheck() { - log.Infof("firewall: re-evaluating verdict on %s", comm) - comm.ResetVerdict() - } - - // check if need to run - if comm.GetVerdict() != network.VerdictUndecided { - return + if conn.UpdateAndCheck() { + log.Infof("filter: re-evaluating verdict on %s", conn) + conn.Verdict = network.VerdictUndecided } // grant self - if comm.Process().Pid == os.Getpid() { - log.Infof("firewall: granting own communication %s", comm) - comm.Accept("") + if conn.Process().Pid == os.Getpid() { + log.Infof("filter: granting own connection %s", conn) + conn.Verdict = network.VerdictAccept return } + // check if process is communicating with itself + if pkt != nil { + if conn.Process().Pid >= 0 && pkt.Info().Src.Equal(pkt.Info().Dst) { + // get PID + otherPid, _, err := process.GetPidByEndpoints( + pkt.Info().RemoteIP(), + pkt.Info().RemotePort(), + pkt.Info().LocalIP(), + pkt.Info().LocalPort(), + pkt.Info().Protocol, + ) + if err == nil { + + // get primary process + otherProcess, err := process.GetOrFindPrimaryProcess(pkt.Ctx(), otherPid) + if err == nil { + + if otherProcess.Pid == conn.Process().Pid { + conn.Accept("connection to self") + return + } + } + } + } + } + // get profile - p := comm.Process().Profile() - - // check for any network access - if p.BlockScopeInternet() && p.BlockScopeLAN() { - log.Infof("firewall: denying communication %s, accessing Internet or LAN not permitted", comm) - comm.Deny("accessing Internet or LAN not permitted") + p := conn.Process().Profile() + if p == nil { + conn.Block("no profile") return } - // continueing with access to either Internet or LAN - // check endpoint list - // FIXME: comm.Entity.Lock() - result, reason := p.MatchEndpoint(comm.Entity) - // FIXME: comm.Entity.Unlock() + // check conn type + switch conn.Scope { + case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid: + if p.BlockInbound() { + if conn.Scope == network.IncomingHost { + conn.Block("inbound connections blocked") + } else { + conn.Deny("inbound connections blocked") + } + return + } + case network.PeerLAN, network.PeerInternet, network.PeerInvalid: + // Important: PeerHost is and should be missing! + if p.BlockP2P() { + conn.Block("direct connections (P2P) blocked") + return + } + } + + // check scopes + if conn.Entity.IP != nil { + classification := netutils.ClassifyIP(conn.Entity.IP) + + switch classification { + case netutils.Global, netutils.GlobalMulticast: + if p.BlockScopeInternet() { + conn.Deny("Internet access blocked") // Block Outbound / Drop Inbound + return + } + case netutils.SiteLocal, netutils.LinkLocal, netutils.LocalMulticast: + if p.BlockScopeLAN() { + conn.Block("LAN access blocked") // Block Outbound / Drop Inbound + return + } + case netutils.HostLocal: + if p.BlockScopeLocal() { + conn.Block("Localhost access blocked") // Block Outbound / Drop Inbound + return + } + default: // netutils.Invalid + conn.Deny("invalid IP") // Block Outbound / Drop Inbound + return + } + } else if conn.Entity.Domain != "" { + // DNS Query + // DNS is expected to resolve to LAN or Internet addresses + // TODO: handle domains mapped to localhost + if p.BlockScopeInternet() && p.BlockScopeLAN() { + conn.Block("Internet and LAN access blocked") + return + } + } + + // check endpoints list + var result endpoints.EPResult + var reason string + if conn.Inbound { + result, reason = p.MatchServiceEndpoint(conn.Entity) + } else { + result, reason = p.MatchEndpoint(conn.Entity) + } switch result { - case endpoints.Undeterminable: - comm.UpdateVerdict(network.VerdictUndeterminable) - return case endpoints.Denied: - log.Infof("firewall: denying communication %s, domain is blacklisted: %s", comm, reason) - comm.Deny(fmt.Sprintf("domain is blacklisted: %s", reason)) + conn.Deny("endpoint is blacklisted: " + reason) // Block Outbound / Drop Inbound return case endpoints.Permitted: - log.Infof("firewall: permitting communication %s, domain is whitelisted: %s", comm, reason) - comm.Accept(fmt.Sprintf("domain is whitelisted: %s", reason)) + conn.Accept("endpoint is whitelisted: " + reason) + return + } + // continuing with result == NoMatch + + // implicit default=block for inbound + if conn.Inbound { + conn.Drop("endpoint is not whitelisted (incoming is always default=block)") return } - // continueing with result == NoMatch // check default action if p.DefaultAction() == profile.DefaultActionPermit { - log.Infof("firewall: permitting communication %s, domain is not blacklisted (default=permit)", comm) - comm.Accept("domain is not blacklisted (default=permit)") + conn.Accept("endpoint is not blacklisted (default=permit)") return } // check relation if !p.DisableAutoPermit() { - if checkRelation(comm) { + related, reason := checkRelation(conn) + if related { + conn.Accept(reason) return } } // prompt if p.DefaultAction() == profile.DefaultActionAsk { - prompt(comm, nil, nil) + prompt(conn, pkt) return } // DefaultAction == DefaultActionBlock - log.Infof("firewall: denying communication %s, domain is not whitelisted (default=block)", comm) - comm.Deny("domain is not whitelisted (default=block)") + conn.Deny("endpoint is not whitelisted (default=block)") return } // FilterDNSResponse filters a dns response according to the application profile and settings. -func FilterDNSResponse(comm *network.Communication, q *resolver.Query, rrCache *resolver.RRCache) *resolver.RRCache { //nolint:gocognit // TODO +func FilterDNSResponse(conn *network.Connection, q *resolver.Query, rrCache *resolver.RRCache) *resolver.RRCache { //nolint:gocognit // TODO // do not modify own queries - this should not happen anyway - if comm.Process().Pid == os.Getpid() { + if conn.Process().Pid == os.Getpid() { return rrCache } // get profile - p := comm.Process().Profile() + p := conn.Process().Profile() + if p == nil { + conn.Block("no profile") + return nil + } // check if DNS response filtering is completely turned off if !p.RemoveOutOfScopeDNS() && !p.RemoveBlockedDNS() { @@ -201,14 +280,13 @@ func FilterDNSResponse(comm *network.Communication, q *resolver.Query, rrCache * if addressesRemoved > 0 { rrCache.Filtered = true if addressesOk == 0 { - comm.Deny("no addresses returned for this domain are permitted") - log.Infof("firewall: fully dns responses for communication %s", comm) + conn.Block("no addresses returned for this domain are permitted") return nil } } if rrCache.Filtered { - log.Infof("firewall: filtered DNS replies for %s: %s", comm, strings.Join(rrCache.FilteredEntries, ", ")) + log.Infof("filter: filtered DNS replies for %s: %s", conn, strings.Join(rrCache.FilteredEntries, ", ")) } // TODO: Gate17 integration @@ -217,231 +295,22 @@ func FilterDNSResponse(comm *network.Communication, q *resolver.Query, rrCache * return rrCache } -// DecideOnCommunication makes a decision about a communication with its first packet. -func DecideOnCommunication(comm *network.Communication) { - // update profiles and check if communication needs reevaluation - if comm.UpdateAndCheck() { - log.Infof("firewall: re-evaluating verdict on %s", comm) - comm.ResetVerdict() - - // if communicating with a domain entity, re-evaluate with BeforeDNS - if strings.HasSuffix(comm.Scope, ".") { - DecideOnCommunicationBeforeDNS(comm) - } - } - - // check if need to run - if comm.GetVerdict() != network.VerdictUndecided { - return - } - - // grant self - if comm.Process().Pid == os.Getpid() { - log.Infof("firewall: granting own communication %s", comm) - comm.Accept("") - return - } - - // get profile - p := comm.Process().Profile() - - // check comm type - switch comm.Scope { - case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid: - if p.BlockInbound() { - log.Infof("firewall: denying communication %s, not a service", comm) - if comm.Scope == network.IncomingHost { - comm.Block("not a service") - } else { - comm.Deny("not a service") - } - return - } - case network.PeerLAN, network.PeerInternet, network.PeerInvalid: - // Important: PeerHost is and should be missing! - if p.BlockP2P() { - log.Infof("firewall: denying communication %s, peer to peer comms (to an IP) not allowed", comm) - comm.Deny("peer to peer comms (to an IP) not allowed") - return - } - } - - // check network scope - switch comm.Scope { - case network.IncomingHost: - if p.BlockScopeLocal() { - log.Infof("firewall: denying communication %s, serving localhost not allowed", comm) - comm.Block("serving localhost not allowed") - return - } - case network.IncomingLAN: - if p.BlockScopeLAN() { - log.Infof("firewall: denying communication %s, serving LAN not allowed", comm) - comm.Deny("serving LAN not allowed") - return - } - case network.IncomingInternet: - if p.BlockScopeInternet() { - log.Infof("firewall: denying communication %s, serving Internet not allowed", comm) - comm.Deny("serving Internet not allowed") - return - } - case network.IncomingInvalid: - log.Infof("firewall: denying communication %s, invalid IP address", comm) - comm.Drop("invalid IP address") - return - case network.PeerHost: - if p.BlockScopeLocal() { - log.Infof("firewall: denying communication %s, accessing localhost not allowed", comm) - comm.Block("accessing localhost not allowed") - return - } - case network.PeerLAN: - if p.BlockScopeLAN() { - log.Infof("firewall: denying communication %s, accessing the LAN not allowed", comm) - comm.Deny("accessing the LAN not allowed") - return - } - case network.PeerInternet: - if p.BlockScopeInternet() { - log.Infof("firewall: denying communication %s, accessing the Internet not allowed", comm) - comm.Deny("accessing the Internet not allowed") - return - } - case network.PeerInvalid: - log.Infof("firewall: denying communication %s, invalid IP address", comm) - comm.Deny("invalid IP address") - return - } - - log.Infof("firewall: undeterminable verdict for communication %s", comm) - comm.UpdateVerdict(network.VerdictUndeterminable) -} - -// DecideOnLink makes a decision about a link with the first packet. -func DecideOnLink(comm *network.Communication, link *network.Link, pkt packet.Packet) { - - // grant self - if comm.Process().Pid == os.Getpid() { - log.Infof("firewall: granting own link %s", comm) - link.Accept("") - return - } - - // check if process is communicating with itself - if comm.Process().Pid >= 0 && pkt.Info().Src.Equal(pkt.Info().Dst) { - // get PID - otherPid, _, err := process.GetPidByEndpoints( - pkt.Info().RemoteIP(), - pkt.Info().RemotePort(), - pkt.Info().LocalIP(), - pkt.Info().LocalPort(), - pkt.Info().Protocol, - ) - if err == nil { - - // get primary process - otherProcess, err := process.GetOrFindPrimaryProcess(pkt.Ctx(), otherPid) - if err == nil { - - if otherProcess.Pid == comm.Process().Pid { - log.Infof("firewall: permitting connection to self %s", comm) - link.AddReason("connection to self") - - link.Lock() - link.Verdict = network.VerdictAccept - link.SaveWhenFinished() - link.Unlock() - return - } - - } - } - } - - // check if we aleady have a verdict - switch comm.GetVerdict() { - case network.VerdictUndecided, network.VerdictUndeterminable: - // continue - default: - link.UpdateVerdict(comm.GetVerdict()) - return - } - - // get profile - p := comm.Process().Profile() - - // check endpoints list - var result endpoints.EPResult - var reason string - // FIXME: link.Entity.Lock() - if comm.Direction { - result, reason = p.MatchServiceEndpoint(link.Entity) - } else { - result, reason = p.MatchEndpoint(link.Entity) - } - // FIXME: link.Entity.Unlock() - switch result { - case endpoints.Denied: - log.Infof("firewall: denying link %s, endpoint is blacklisted: %s", link, reason) - link.Deny(fmt.Sprintf("endpoint is blacklisted: %s", reason)) - return - case endpoints.Permitted: - log.Infof("firewall: permitting link %s, endpoint is whitelisted: %s", link, reason) - link.Accept(fmt.Sprintf("endpoint is whitelisted: %s", reason)) - return - } - // continueing with result == NoMatch - - // implicit default=block for incoming - if comm.Direction { - log.Infof("firewall: denying link %s: endpoint is not whitelisted (incoming is always default=block)", link) - link.Deny("endpoint is not whitelisted (incoming is always default=block)") - return - } - - // check default action - if p.DefaultAction() == profile.DefaultActionPermit { - log.Infof("firewall: permitting link %s: endpoint is not blacklisted (default=permit)", link) - link.Accept("endpoint is not blacklisted (default=permit)") - return - } - - // check relation - if !p.DisableAutoPermit() { - if checkRelation(comm) { - return - } - } - - // prompt - if p.DefaultAction() == profile.DefaultActionAsk { - prompt(comm, link, pkt) - return - } - - // DefaultAction == DefaultActionBlock - log.Infof("firewall: denying link %s: endpoint is not whitelisted (default=block)", link) - link.Deny("endpoint is not whitelisted (default=block)") - return -} - // checkRelation tries to find a relation between a process and a communication. This is for better out of the box experience and is _not_ meant to thwart intentional malware. -func checkRelation(comm *network.Communication) (related bool) { - if comm.Entity.Domain != "" { - return false +func checkRelation(conn *network.Connection) (related bool, reason string) { + if conn.Entity.Domain != "" { + return false, "" } // don't check for unknown processes - if comm.Process().Pid < 0 { - return false + if conn.Process().Pid < 0 { + return false, "" } - pathElements := strings.Split(comm.Process().Path, string(filepath.Separator)) + pathElements := strings.Split(conn.Process().Path, string(filepath.Separator)) // only look at the last two path segments if len(pathElements) > 2 { pathElements = pathElements[len(pathElements)-2:] } - domainElements := strings.Split(comm.Entity.Domain, ".") + domainElements := strings.Split(conn.Entity.Domain, ".") var domainElement string var processElement string @@ -455,21 +324,20 @@ matchLoop: break matchLoop } } - if levenshtein.Match(domainElement, comm.Process().Name, nil) > 0.5 { + if levenshtein.Match(domainElement, conn.Process().Name, nil) > 0.5 { related = true - processElement = comm.Process().Name + processElement = conn.Process().Name break matchLoop } - if levenshtein.Match(domainElement, comm.Process().ExecName, nil) > 0.5 { + if levenshtein.Match(domainElement, conn.Process().ExecName, nil) > 0.5 { related = true - processElement = comm.Process().ExecName + processElement = conn.Process().ExecName break matchLoop } } if related { - log.Infof("firewall: permitting communication %s, match to domain was found: %s is related to %s", comm, domainElement, processElement) - comm.Accept(fmt.Sprintf("domain is related to process: %s is related to %s", domainElement, processElement)) + reason = fmt.Sprintf("domain is related to process: %s is related to %s", domainElement, processElement) } - return related + return } diff --git a/firewall/ports.go b/firewall/ports.go index 60a1ceef..50b39d31 100644 --- a/firewall/ports.go +++ b/firewall/ports.go @@ -50,7 +50,7 @@ func GetPermittedPort() uint16 { // generate port between 10000 and 65535 rN, err := rng.Number(55535) if err != nil { - log.Warningf("firewall: failed to generate random port: %s", err) + log.Warningf("filter: failed to generate random port: %s", err) return 0 } port := uint16(rN + 10000) diff --git a/firewall/prompt.go b/firewall/prompt.go index 9e2cf17f..9805fb82 100644 --- a/firewall/prompt.go +++ b/firewall/prompt.go @@ -29,27 +29,17 @@ var ( mtSaveProfile = "save profile" ) -//nolint:gocognit // FIXME -func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet) { +func prompt(conn *network.Connection, pkt packet.Packet) { //nolint:gocognit // TODO nTTL := time.Duration(promptTimeout()) * time.Second // first check if there is an existing notification for this. // build notification ID var nID string switch { - case comm.Direction, comm.Entity.Domain == "": // connection to/from IP - if pkt == nil { - log.Error("firewall: could not prompt for incoming/direct connection: missing pkt") - if link != nil { - link.Deny("internal error") - } else { - comm.Deny("internal error") - } - return - } - nID = fmt.Sprintf("firewall-prompt-%d-%s-%s", comm.Process().Pid, comm.Scope, pkt.Info().RemoteIP()) + case conn.Inbound, conn.Entity.Domain == "": // connection to/from IP + nID = fmt.Sprintf("firewall-prompt-%d-%s-%s", conn.Process().Pid, conn.Scope, pkt.Info().RemoteIP()) default: // connection to domain - nID = fmt.Sprintf("firewall-prompt-%d-%s", comm.Process().Pid, comm.Scope) + nID = fmt.Sprintf("firewall-prompt-%d-%s", conn.Process().Pid, conn.Scope) } n := notifications.Get(nID) saveResponse := true @@ -69,8 +59,8 @@ func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet) // add message and actions switch { - case comm.Direction: // incoming - n.Message = fmt.Sprintf("Application %s wants to accept connections from %s (%d/%d)", comm.Process(), link.Entity.IP.String(), link.Entity.Protocol, link.Entity.Port) + case conn.Inbound: + n.Message = fmt.Sprintf("Application %s wants to accept connections from %s (%d/%d)", conn.Process(), conn.Entity.IP.String(), conn.Entity.Protocol, conn.Entity.Port) n.AvailableActions = []*notifications.Action{ { ID: permitServingIP, @@ -81,8 +71,8 @@ func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet) Text: "Deny", }, } - case comm.Entity.Domain == "": // direct connection - n.Message = fmt.Sprintf("Application %s wants to connect to %s (%d/%d)", comm.Process(), link.Entity.IP.String(), link.Entity.Protocol, link.Entity.Port) + case conn.Entity.Domain == "": // direct connection + n.Message = fmt.Sprintf("Application %s wants to connect to %s (%d/%d)", conn.Process(), conn.Entity.IP.String(), conn.Entity.Protocol, conn.Entity.Port) n.AvailableActions = []*notifications.Action{ { ID: permitIP, @@ -94,10 +84,10 @@ func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet) }, } default: // connection to domain - if link != nil { - n.Message = fmt.Sprintf("Application %s wants to connect to %s (%s %d/%d)", comm.Process(), comm.Entity.Domain, link.Entity.IP.String(), link.Entity.Protocol, link.Entity.Port) + if pkt != nil { + n.Message = fmt.Sprintf("Application %s wants to connect to %s (%s %d/%d)", conn.Process(), conn.Entity.Domain, conn.Entity.IP.String(), conn.Entity.Protocol, conn.Entity.Port) } else { - n.Message = fmt.Sprintf("Application %s wants to connect to %s", comm.Process(), comm.Entity.Domain) + n.Message = fmt.Sprintf("Application %s wants to connect to %s", conn.Process(), conn.Entity.Domain) } n.AvailableActions = []*notifications.Action{ { @@ -123,17 +113,9 @@ func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet) case promptResponse := <-n.Response(): switch promptResponse { case permitDomainAll, permitDomainDistinct, permitIP, permitServingIP: - if link != nil { - link.Accept("permitted by user") - } else { - comm.Accept("permitted by user") - } + conn.Accept("permitted by user") default: // deny - if link != nil { - link.Accept("denied by user") - } else { - comm.Accept("denied by user") - } + conn.Deny("denied by user") } // end here if we won't save the response to the profile @@ -142,42 +124,43 @@ func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet) } // get profile - p := comm.Process().Profile() + p := conn.Process().Profile() var ep endpoints.Endpoint switch promptResponse { case permitDomainAll: ep = &endpoints.EndpointDomain{ EndpointBase: endpoints.EndpointBase{Permitted: true}, - Domain: "." + comm.Entity.Domain, + Domain: "." + conn.Entity.Domain, } case permitDomainDistinct: ep = &endpoints.EndpointDomain{ EndpointBase: endpoints.EndpointBase{Permitted: true}, - Domain: comm.Entity.Domain, + Domain: conn.Entity.Domain, } case denyDomainAll: ep = &endpoints.EndpointDomain{ EndpointBase: endpoints.EndpointBase{Permitted: false}, - Domain: "." + comm.Entity.Domain, + Domain: "." + conn.Entity.Domain, } case denyDomainDistinct: ep = &endpoints.EndpointDomain{ EndpointBase: endpoints.EndpointBase{Permitted: false}, - Domain: comm.Entity.Domain, + Domain: conn.Entity.Domain, } case permitIP, permitServingIP: ep = &endpoints.EndpointIP{ EndpointBase: endpoints.EndpointBase{Permitted: true}, - IP: comm.Entity.IP, + IP: conn.Entity.IP, } case denyIP, denyServingIP: ep = &endpoints.EndpointIP{ EndpointBase: endpoints.EndpointBase{Permitted: false}, - IP: comm.Entity.IP, + IP: conn.Entity.IP, } default: log.Warningf("filter: unknown prompt response: %s", promptResponse) + return } switch promptResponse { @@ -188,10 +171,6 @@ func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet) } case <-n.Expired(): - if link != nil { - link.Deny("no response to prompt") - } else { - comm.Deny("no response to prompt") - } + conn.Deny("no response to prompt") } } diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 6c7a7601..c2ee2ef3 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -174,38 +174,32 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain // get connection - comm, err := network.GetCommunicationByDNSRequest(ctx, remoteAddr.IP, uint16(remoteAddr.Port), q.FQDN) - if err != nil { - tracer.Errorf("nameserver: could not identify process of %s:%d, returning nxdomain: %s", remoteAddr.IP, remoteAddr.Port, err) - returnNXDomain(w, query) - return nil - } - defer func() { - go comm.SaveIfNeeded() - }() + conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, remoteAddr.IP, uint16(remoteAddr.Port)) // save security level to query - q.SecurityLevel = comm.Process().Profile().SecurityLevel() + q.SecurityLevel = conn.Process().Profile().SecurityLevel() // check for possible DNS tunneling / data transmission // TODO: improve this lms := dga.LmsScoreOfDomain(q.FQDN) // log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms) if lms < 10 { - tracer.Warningf("nameserver: possible data tunnel by %s: %s has lms score of %f, returning nxdomain", comm.Process(), q.FQDN, lms) + tracer.Warningf("nameserver: possible data tunnel by %s: %s has lms score of %f, returning nxdomain", conn.Process(), q.FQDN, lms) returnNXDomain(w, query) return nil } // check profile before we even get intel and rr - firewall.DecideOnCommunicationBeforeDNS(comm) - comm.Lock() - comm.SaveWhenFinished() - comm.Unlock() - - if comm.GetVerdict() == network.VerdictBlock || comm.GetVerdict() == network.VerdictDrop { - tracer.Infof("nameserver: %s denied before intel, returning nxdomain", comm) + firewall.DecideOnConnection(conn, nil) + switch conn.Verdict { + case network.VerdictBlock: + tracer.Infof("nameserver: %s blocked, returning nxdomain", conn) returnNXDomain(w, query) + // FIXME: save denied dns connection + return nil + case network.VerdictDrop: + tracer.Infof("nameserver: %s dropped, not replying", conn) + // FIXME: save denied dns connection return nil } @@ -213,16 +207,18 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er rrCache, err := resolver.Resolve(ctx, q) if err != nil { // TODO: analyze nxdomain requests, malware could be trying DGA-domains - tracer.Warningf("nameserver: %s requested %s%s: %s", comm.Process(), q.FQDN, q.QType, err) + tracer.Warningf("nameserver: %s requested %s%s: %s", conn.Process(), q.FQDN, q.QType, err) returnNXDomain(w, query) return nil } // filter DNS response - rrCache = firewall.FilterDNSResponse(comm, q, rrCache) + rrCache = firewall.FilterDNSResponse(conn, q, rrCache) + // TODO: FilterDNSResponse also sets a connection verdict if rrCache == nil { - tracer.Infof("nameserver: %s implicitly denied by filtering the dns response, returning nxdomain", comm) + tracer.Infof("nameserver: %s implicitly denied by filtering the dns response, returning nxdomain", conn) returnNXDomain(w, query) + // FIXME: save denied dns connection return nil } @@ -267,7 +263,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er m.Ns = rrCache.Ns m.Extra = rrCache.Extra _ = w.WriteMsg(m) - tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, comm.Process()) + tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, conn.Process()) return nil } diff --git a/network/clean.go b/network/clean.go index c4f575b7..25fdfb94 100644 --- a/network/clean.go +++ b/network/clean.go @@ -9,113 +9,73 @@ import ( ) var ( - cleanerTickDuration = 10 * time.Second - deleteLinksAfterEndedThreshold = 5 * time.Minute - deleteCommsWithoutLinksThreshhold = 3 * time.Minute - - mtSaveLink = "save network link" + cleanerTickDuration = 5 * time.Second + deleteConnsAfterEndedThreshold = 5 * time.Minute ) -func cleaner() { +func connectionCleaner(ctx context.Context) error { + ticker := time.NewTicker(cleanerTickDuration) + for { - time.Sleep(cleanerTickDuration) - - activeComms := cleanLinks() - activeProcs := cleanComms(activeComms) - process.CleanProcessStorage(activeProcs) + select { + case <-ctx.Done(): + ticker.Stop() + return nil + case <-ticker.C: + activePIDs := cleanConnections() + process.CleanProcessStorage(activePIDs) + } } } -func cleanLinks() (activeComms map[string]struct{}) { - activeComms = make(map[string]struct{}) - activeIDs := process.GetActiveConnectionIDs() +func cleanConnections() (activePIDs map[int]struct{}) { + activePIDs = make(map[int]struct{}) - now := time.Now().Unix() - deleteOlderThan := time.Now().Add(-deleteLinksAfterEndedThreshold).Unix() - - linksLock.RLock() - defer linksLock.RUnlock() - - var found bool - for key, link := range links { - - // delete dead links - link.lock.Lock() - deleteThis := link.Ended > 0 && link.Ended < deleteOlderThan - link.lock.Unlock() - if deleteThis { - log.Tracef("network.clean: deleted %s (ended at %d)", link.DatabaseKey(), link.Ended) - go link.Delete() - continue + name := "clean connections" // TODO: change to new fn + module.RunMediumPriorityMicroTask(&name, func(ctx context.Context) error { + activeIDs := make(map[string]struct{}) + for _, cID := range process.GetActiveConnectionIDs() { + activeIDs[cID] = struct{}{} } - // not yet deleted, so its still a valid link regarding link count - comm := link.Communication() - comm.lock.Lock() - markActive(activeComms, comm.DatabaseKey()) - comm.lock.Unlock() + now := time.Now().Unix() + deleteOlderThan := time.Now().Add(-deleteConnsAfterEndedThreshold).Unix() - // check if link is dead - found = false - for _, activeID := range activeIDs { - if key == activeID { - found = true - break + connsLock.Lock() + defer connsLock.Unlock() + + for key, conn := range conns { + // get conn.Ended + conn.Lock() + ended := conn.Ended + conn.Unlock() + + // delete inactive connections + switch { + case ended == 0: + // Step 1: check if still active + _, ok := activeIDs[key] + if ok { + activePIDs[conn.process.Pid] = struct{}{} + } else { + // Step 2: mark end + activePIDs[conn.process.Pid] = struct{}{} + conn.Lock() + conn.Ended = now + conn.Unlock() + // "save" + dbController.PushUpdate(conn) + } + case ended < deleteOlderThan: + // Step 3: delete + log.Tracef("network.clean: deleted %s (ended at %s)", conn.DatabaseKey(), time.Unix(conn.Ended, 0)) + conn.delete() } + } - if !found { - // mark end time - link.lock.Lock() - link.Ended = now - link.lock.Unlock() - log.Tracef("network.clean: marked %s as ended", link.DatabaseKey()) - // save - linkToSave := link - module.StartMicroTask(&mtSaveLink, func(ctx context.Context) error { - linkToSave.saveAndLog() - return nil - }) - } + return nil + }) - } - - return activeComms -} - -func cleanComms(activeLinks map[string]struct{}) (activeComms map[string]struct{}) { - activeComms = make(map[string]struct{}) - - commsLock.RLock() - defer commsLock.RUnlock() - - threshold := time.Now().Add(-deleteCommsWithoutLinksThreshhold).Unix() - for _, comm := range comms { - // has links? - _, hasLinks := activeLinks[comm.DatabaseKey()] - - // comm created - comm.lock.Lock() - created := comm.Meta().Created - comm.lock.Unlock() - - if !hasLinks && created < threshold { - log.Tracef("network.clean: deleted %s", comm.DatabaseKey()) - go comm.Delete() - } else { - p := comm.Process() - p.Lock() - markActive(activeComms, p.DatabaseKey()) - p.Unlock() - } - - } - return -} - -func markActive(activeMap map[string]struct{}, key string) { - _, ok := activeMap[key] - if !ok { - activeMap[key] = struct{}{} - } + return activePIDs } diff --git a/network/communication.go b/network/communication.go deleted file mode 100644 index 266e032d..00000000 --- a/network/communication.go +++ /dev/null @@ -1,417 +0,0 @@ -package network - -import ( - "context" - "errors" - "fmt" - "net" - "sync" - "time" - - "github.com/safing/portmaster/resolver" - - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" - "github.com/safing/portmaster/intel" - "github.com/safing/portmaster/network/netutils" - "github.com/safing/portmaster/network/packet" - "github.com/safing/portmaster/process" -) - -// Communication describes a logical connection between a process and a domain. -//nolint:maligned // TODO: fix alignment -type Communication struct { - record.Base - lock sync.Mutex - - Scope string - Entity *intel.Entity - Direction bool - - Verdict Verdict - Reason string - ReasonID string // format source[:id[:id]] - Inspect bool - process *process.Process - profileRevisionCounter uint64 - - FirstLinkEstablished int64 - LastLinkEstablished int64 - - saveWhenFinished bool -} - -// Lock locks the communication and the communication's Entity. -func (comm *Communication) Lock() { - comm.lock.Lock() - comm.Entity.Lock() -} - -// Lock unlocks the communication and the communication's Entity. -func (comm *Communication) Unlock() { - comm.Entity.Unlock() - comm.lock.Unlock() -} - -// Process returns the process that owns the connection. -func (comm *Communication) Process() *process.Process { - comm.lock.Lock() - defer comm.lock.Unlock() - - return comm.process -} - -// ResetVerdict resets the verdict to VerdictUndecided. -func (comm *Communication) ResetVerdict() { - comm.lock.Lock() - defer comm.lock.Unlock() - - comm.Verdict = VerdictUndecided - comm.Reason = "" - comm.saveWhenFinished = true -} - -// GetVerdict returns the current verdict. -func (comm *Communication) GetVerdict() Verdict { - comm.lock.Lock() - defer comm.lock.Unlock() - - return comm.Verdict -} - -// Accept accepts the communication and adds the given reason. -func (comm *Communication) Accept(reason string) { - comm.AddReason(reason) - comm.UpdateVerdict(VerdictAccept) -} - -// Deny blocks or drops the communication depending on the connection direction and adds the given reason. -func (comm *Communication) Deny(reason string) { - if comm.Direction { - comm.Drop(reason) - } else { - comm.Block(reason) - } -} - -// Block blocks the communication and adds the given reason. -func (comm *Communication) Block(reason string) { - comm.AddReason(reason) - comm.UpdateVerdict(VerdictBlock) -} - -// Drop drops the communication and adds the given reason. -func (comm *Communication) Drop(reason string) { - comm.AddReason(reason) - comm.UpdateVerdict(VerdictDrop) -} - -// UpdateVerdict sets a new verdict for this link, making sure it does not interfere with previous verdicts. -func (comm *Communication) UpdateVerdict(newVerdict Verdict) { - comm.lock.Lock() - defer comm.lock.Unlock() - - if newVerdict > comm.Verdict { - comm.Verdict = newVerdict - comm.saveWhenFinished = true - } -} - -// SetReason sets/replaces a human readable string as to why a certain verdict was set in regard to this communication. -func (comm *Communication) SetReason(reason string) { - if reason == "" { - return - } - - comm.lock.Lock() - defer comm.lock.Unlock() - comm.Reason = reason - comm.saveWhenFinished = true -} - -// AddReason adds a human readable string as to why a certain verdict was set in regard to this communication. -func (comm *Communication) AddReason(reason string) { - if reason == "" { - return - } - - comm.lock.Lock() - defer comm.lock.Unlock() - - if comm.Reason != "" { - comm.Reason += " | " - } - comm.Reason += reason -} - -// UpdateAndCheck updates profiles and checks whether a reevaluation is needed. -func (comm *Communication) UpdateAndCheck() (needsReevaluation bool) { - revCnt := comm.Process().Profile().Update() - - comm.lock.Lock() - defer comm.lock.Unlock() - if comm.profileRevisionCounter != revCnt { - comm.profileRevisionCounter = revCnt - needsReevaluation = true - } - - return -} - -// GetCommunicationByFirstPacket returns the matching communication from the internal storage. -func GetCommunicationByFirstPacket(pkt packet.Packet) (*Communication, error) { - // get Process - proc, direction, err := process.GetProcessByPacket(pkt) - if err != nil { - return nil, err - } - var scope string - - // Incoming - if direction { - switch netutils.ClassifyIP(pkt.Info().Src) { - case netutils.HostLocal: - scope = IncomingHost - case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: - scope = IncomingLAN - case netutils.Global, netutils.GlobalMulticast: - scope = IncomingInternet - case netutils.Invalid: - scope = IncomingInvalid - } - - communication, ok := GetCommunication(proc.Pid, scope) - if !ok { - communication = &Communication{ - Scope: scope, - Entity: (&intel.Entity{}).Init(), - Direction: Inbound, - process: proc, - Inspect: true, - FirstLinkEstablished: time.Now().Unix(), - saveWhenFinished: true, - } - } - communication.process.AddCommunication() - return communication, nil - } - - // get domain - ipinfo, err := resolver.GetIPInfo(pkt.FmtRemoteIP()) - - // PeerToPeer - if err != nil { - // if no domain could be found, it must be a direct connection (ie. no DNS) - - switch netutils.ClassifyIP(pkt.Info().Dst) { - case netutils.HostLocal: - scope = PeerHost - case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: - scope = PeerLAN - case netutils.Global, netutils.GlobalMulticast: - scope = PeerInternet - case netutils.Invalid: - scope = PeerInvalid - } - - communication, ok := GetCommunication(proc.Pid, scope) - if !ok { - communication = &Communication{ - Scope: scope, - Entity: (&intel.Entity{}).Init(), - Direction: Outbound, - process: proc, - Inspect: true, - FirstLinkEstablished: time.Now().Unix(), - saveWhenFinished: true, - } - } - communication.process.AddCommunication() - return communication, nil - } - - // To Domain - // FIXME: how to handle multiple possible domains? - communication, ok := GetCommunication(proc.Pid, ipinfo.Domains[0]) - if !ok { - communication = &Communication{ - Scope: ipinfo.Domains[0], - Entity: (&intel.Entity{ - Domain: ipinfo.Domains[0], - }).Init(), - Direction: Outbound, - process: proc, - Inspect: true, - FirstLinkEstablished: time.Now().Unix(), - saveWhenFinished: true, - } - } - communication.process.AddCommunication() - return communication, nil -} - -// var localhost = net.IPv4(127, 0, 0, 1) - -var ( - dnsAddress = net.IPv4(127, 0, 0, 1) - dnsPort uint16 = 53 -) - -// GetCommunicationByDNSRequest returns the matching communication from the internal storage. -func GetCommunicationByDNSRequest(ctx context.Context, ip net.IP, port uint16, fqdn string) (*Communication, error) { - // get Process - proc, err := process.GetProcessByEndpoints(ctx, ip, port, dnsAddress, dnsPort, packet.UDP) - if err != nil { - return nil, err - } - - communication, ok := GetCommunication(proc.Pid, fqdn) - if !ok { - communication = &Communication{ - Scope: fqdn, - Entity: (&intel.Entity{ - Domain: fqdn, - }).Init(), - process: proc, - Inspect: true, - saveWhenFinished: true, - } - communication.process.AddCommunication() - communication.saveWhenFinished = true - } - return communication, nil -} - -// GetCommunication fetches a connection object from the internal storage. -func GetCommunication(pid int, domain string) (comm *Communication, ok bool) { - commsLock.RLock() - defer commsLock.RUnlock() - comm, ok = comms[fmt.Sprintf("%d/%s", pid, domain)] - return -} - -func (comm *Communication) makeKey() string { - return fmt.Sprintf("%d/%s", comm.process.Pid, comm.Scope) -} - -// SaveWhenFinished marks the Connection for saving after all current actions are finished. -func (comm *Communication) SaveWhenFinished() { - comm.saveWhenFinished = true -} - -// SaveIfNeeded saves the Connection if it is marked for saving when finished. -func (comm *Communication) SaveIfNeeded() { - comm.lock.Lock() - save := comm.saveWhenFinished - if save { - comm.saveWhenFinished = false - } - comm.lock.Unlock() - - if save { - err := comm.save() - if err != nil { - log.Warningf("network: failed to save comm %s: %s", comm, err) - } - } -} - -// Save saves the Connection object in the storage and propagates the change. -func (comm *Communication) save() error { - // update comm - comm.lock.Lock() - if comm.process == nil { - comm.lock.Unlock() - return errors.New("cannot save connection without process") - } - - if !comm.KeyIsSet() { - comm.SetKey(fmt.Sprintf("network:tree/%d/%s", comm.process.Pid, comm.Scope)) - comm.UpdateMeta() - } - if comm.Meta().Deleted > 0 { - log.Criticalf("network: revieving dead comm %s", comm) - comm.Meta().Deleted = 0 - } - key := comm.makeKey() - comm.saveWhenFinished = false - comm.lock.Unlock() - - // save comm - commsLock.RLock() - _, ok := comms[key] - commsLock.RUnlock() - - if !ok { - commsLock.Lock() - comms[key] = comm - commsLock.Unlock() - } - - go dbController.PushUpdate(comm) - return nil -} - -// Delete deletes a connection from the storage and propagates the change. -func (comm *Communication) Delete() { - commsLock.Lock() - defer commsLock.Unlock() - comm.lock.Lock() - defer comm.lock.Unlock() - - delete(comms, comm.makeKey()) - - comm.Meta().Delete() - go dbController.PushUpdate(comm) -} - -// AddLink applies the Communication to the Link and sets timestamps. -func (comm *Communication) AddLink(link *Link) { - comm.lock.Lock() - defer comm.lock.Unlock() - - // apply comm to link - link.lock.Lock() - link.comm = comm - link.Verdict = comm.Verdict - link.Inspect = comm.Inspect - // FIXME: use new copy methods - link.Entity.Domain = comm.Entity.Domain - link.saveWhenFinished = true - link.lock.Unlock() - - // check if we should save - if comm.LastLinkEstablished < time.Now().Add(-3*time.Second).Unix() { - comm.saveWhenFinished = true - } - - // update LastLinkEstablished - comm.LastLinkEstablished = time.Now().Unix() - if comm.FirstLinkEstablished == 0 { - comm.FirstLinkEstablished = comm.LastLinkEstablished - } -} - -// String returns a string representation of Communication. -func (comm *Communication) String() string { - comm.Lock() - defer comm.Unlock() - - switch comm.Scope { - case IncomingHost, IncomingLAN, IncomingInternet, IncomingInvalid: - if comm.process == nil { - return "? <- *" - } - return fmt.Sprintf("%s <- *", comm.process.String()) - case PeerHost, PeerLAN, PeerInternet, PeerInvalid: - if comm.process == nil { - return "? -> *" - } - return fmt.Sprintf("%s -> *", comm.process.String()) - default: - if comm.process == nil { - return fmt.Sprintf("? -> %s", comm.Scope) - } - return fmt.Sprintf("%s -> %s", comm.process.String(), comm.Scope) - } -} diff --git a/network/connection.go b/network/connection.go new file mode 100644 index 00000000..9876193d --- /dev/null +++ b/network/connection.go @@ -0,0 +1,379 @@ +package network + +import ( + "context" + "fmt" + "net" + "strconv" + "sync" + "time" + + "github.com/safing/portbase/database/record" + "github.com/safing/portbase/log" + "github.com/safing/portmaster/intel" + "github.com/safing/portmaster/network/netutils" + "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/process" + "github.com/safing/portmaster/resolver" +) + +// FirewallHandler defines the function signature for a firewall handle function +type FirewallHandler func(conn *Connection, pkt packet.Packet) + +// Connection describes a distinct physical network connection identified by the IP/Port pair. +type Connection struct { //nolint:maligned // TODO: fix alignment + record.Base + sync.Mutex + + ID string + Scope string + Inbound bool + Entity *intel.Entity // needs locking, instance is never shared + process *process.Process + + Verdict Verdict + Reason string + ReasonID string // format source[:id[:id]] // TODO + + Started int64 + Ended int64 + Tunneled bool + VerdictPermanent bool + Inspecting bool + Encrypted bool // TODO + + pktQueue chan packet.Packet + firewallHandler FirewallHandler + + activeInspectors []bool + inspectorData map[uint8]interface{} + + saveWhenFinished bool + profileRevisionCounter uint64 +} + +// NewConnectionFromDNSRequest +func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, ip net.IP, port uint16) *Connection { + // get Process + proc, err := process.GetProcessByEndpoints(ctx, ip, port, dnsAddress, dnsPort, packet.UDP) + if err != nil { + log.Warningf("network: failed to find process of dns request for %s: %s", fqdn, err) + proc = process.UnknownProcess + } + + timestamp := time.Now().Unix() + dnsConn := &Connection{ + Scope: fqdn, + Entity: (&intel.Entity{ + Domain: fqdn, + }).Init(), + process: proc, + Started: timestamp, + Ended: timestamp, + } + saveOpenDNSRequest(dnsConn) + return dnsConn +} + +func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { + // get Process + proc, inbound, err := process.GetProcessByPacket(pkt) + if err != nil { + log.Warningf("network: failed to find process of packet %s: %s", pkt, err) + proc = process.UnknownProcess + } + + var scope string + var entity *intel.Entity + + if inbound { + + // inbound connection + switch netutils.ClassifyIP(pkt.Info().Src) { + case netutils.HostLocal: + scope = IncomingHost + case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: + scope = IncomingLAN + case netutils.Global, netutils.GlobalMulticast: + scope = IncomingInternet + default: // netutils.Invalid + scope = IncomingInvalid + } + entity = (&intel.Entity{ + IP: pkt.Info().Src, + Protocol: uint8(pkt.Info().Protocol), + Port: pkt.Info().SrcPort, + }).Init() + + } else { + + // outbound connection + entity = (&intel.Entity{ + IP: pkt.Info().Dst, + Protocol: uint8(pkt.Info().Protocol), + Port: pkt.Info().DstPort, + }).Init() + + // check if we can find a domain for that IP + ipinfo, err := resolver.GetIPInfo(pkt.Info().Dst.String()) + if err == nil { + + // outbound to domain + scope = ipinfo.Domains[0] + entity.Domain = scope + removeOpenDNSRequest(proc.Pid, scope) + + } else { + + // outbound direct (possibly P2P) connection + switch netutils.ClassifyIP(pkt.Info().Dst) { + case netutils.HostLocal: + scope = PeerHost + case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: + scope = PeerLAN + case netutils.Global, netutils.GlobalMulticast: + scope = PeerInternet + default: // netutils.Invalid + scope = PeerInvalid + } + + } + } + + timestamp := time.Now().Unix() + return &Connection{ + ID: pkt.GetConnectionID(), + Scope: scope, + Entity: entity, + process: proc, + Started: timestamp, + } +} + +// GetConnection fetches a Connection from the database. +func GetConnection(id string) (*Connection, bool) { + connsLock.RLock() + defer connsLock.RUnlock() + + conn, ok := conns[id] + return conn, ok +} + +// Accept accepts the connection. +func (conn *Connection) Accept(reason string) { + if conn.SetVerdict(VerdictAccept) { + conn.Reason = reason + log.Infof("filter: granting connection %s, %s", conn, conn.Reason) + } else { + log.Warningf("filter: tried to accept %s, but current verdict is %s", conn, conn.Verdict) + } +} + +// Block blocks the connection. +func (conn *Connection) Block(reason string) { + if conn.SetVerdict(VerdictBlock) { + conn.Reason = reason + log.Infof("filter: blocking connection %s, %s", conn, conn.Reason) + } else { + log.Warningf("filter: tried to block %s, but current verdict is %s", conn, conn.Verdict) + } +} + +// Drop drops the connection. +func (conn *Connection) Drop(reason string) { + if conn.SetVerdict(VerdictDrop) { + conn.Reason = reason + log.Infof("filter: dropping connection %s, %s", conn, conn.Reason) + } else { + log.Warningf("filter: tried to drop %s, but current verdict is %s", conn, conn.Verdict) + } +} + +// Deny blocks or drops the link depending on the connection direction. +func (conn *Connection) Deny(reason string) { + if conn.Inbound { + conn.Drop(reason) + } else { + conn.Block(reason) + } +} + +// SetVerdict sets a new verdict for the connection, making sure it does not interfere with previous verdicts. +func (conn *Connection) SetVerdict(newVerdict Verdict) (ok bool) { + if newVerdict >= conn.Verdict { + conn.Verdict = newVerdict + return true + } + return false +} + +// Process returns the connection's process. +func (conn *Connection) Process() *process.Process { + return conn.process +} + +// SaveWhenFinished marks the connection for saving it after the firewall handler. +func (conn *Connection) SaveWhenFinished() { + conn.saveWhenFinished = true +} + +// save saves the link object in the storage and propagates the change. +func (conn *Connection) save() { + if conn.ID == "" { + + // dns request + if !conn.KeyIsSet() { + conn.SetKey(fmt.Sprintf("network:tree/%d/%s", conn.process.Pid, conn.Scope)) + conn.UpdateMeta() + } + // save to internal state + // check if it already exists + mapKey := strconv.Itoa(conn.process.Pid) + "/" + conn.Scope + dnsConnsLock.RLock() + _, ok := dnsConns[mapKey] + dnsConnsLock.RUnlock() + if !ok { + dnsConnsLock.Lock() + dnsConns[mapKey] = conn + dnsConnsLock.Unlock() + } + + } else { + + // connection + if !conn.KeyIsSet() { + conn.SetKey(fmt.Sprintf("network:tree/%d/%s/%s", conn.process.Pid, conn.Scope, conn.ID)) + conn.UpdateMeta() + } + // save to internal state + // check if it already exists + connsLock.RLock() + _, ok := conns[conn.ID] + connsLock.RUnlock() + if !ok { + connsLock.Lock() + conns[conn.ID] = conn + connsLock.Unlock() + } + + } + + // notify database controller + dbController.PushUpdate(conn) +} + +// delete deletes a link from the storage and propagates the change. Nothing is locked - both the conns map and the connection itself require locking +func (conn *Connection) delete() { + delete(conns, conn.ID) + + conn.Meta().Delete() + dbController.PushUpdate(conn) +} + +// UpdateAndCheck updates profiles and checks whether a reevaluation is needed. +func (conn *Connection) UpdateAndCheck() (needsReevaluation bool) { + p := conn.process.Profile() + if p == nil { + return false + } + revCnt := p.Update() + + if conn.profileRevisionCounter != revCnt { + conn.profileRevisionCounter = revCnt + needsReevaluation = true + } + return +} + +// SetFirewallHandler sets the firewall handler for this link, and starts a worker to handle the packets. +func (conn *Connection) SetFirewallHandler(handler FirewallHandler) { + if conn.firewallHandler == nil { + conn.pktQueue = make(chan packet.Packet, 1000) + + // start handling + module.StartWorker("packet handler", func(ctx context.Context) error { + conn.packetHandler() + return nil + }) + } + conn.firewallHandler = handler +} + +// StopFirewallHandler unsets the firewall handler and stops the handler worker. +func (conn *Connection) StopFirewallHandler() { + conn.firewallHandler = nil + conn.pktQueue <- nil +} + +// HandlePacket queues packet of Link for handling +func (conn *Connection) HandlePacket(pkt packet.Packet) { + conn.Lock() + defer conn.Unlock() + + // execute handler or verdict + if conn.firewallHandler != nil { + conn.pktQueue <- pkt + // TODO: drop if overflowing? + } else { + defaultFirewallHandler(conn, pkt) + } +} + +// packetHandler sequentially handles queued packets +func (conn *Connection) packetHandler() { + for { + pkt := <-conn.pktQueue + if pkt == nil { + return + } + // get handler + conn.Lock() + // execute handler or verdict + if conn.firewallHandler != nil { + conn.firewallHandler(conn, pkt) + } else { + defaultFirewallHandler(conn, pkt) + } + conn.Unlock() + // save does not touch any changing data + // must not be locked, will deadlock with cleaner functions + if conn.saveWhenFinished { + conn.saveWhenFinished = false + conn.save() + } + // submit trace logs + log.Tracer(pkt.Ctx()).Submit() + } +} + +// GetActiveInspectors returns the list of active inspectors. +func (conn *Connection) GetActiveInspectors() []bool { + return conn.activeInspectors +} + +// SetActiveInspectors sets the list of active inspectors. +func (conn *Connection) SetActiveInspectors(new []bool) { + conn.activeInspectors = new +} + +// GetInspectorData returns the list of inspector data. +func (conn *Connection) GetInspectorData() map[uint8]interface{} { + return conn.inspectorData +} + +// SetInspectorData set the list of inspector data. +func (conn *Connection) SetInspectorData(new map[uint8]interface{}) { + conn.inspectorData = new +} + +// String returns a string representation of conn. +func (conn *Connection) String() string { + switch conn.Scope { + case IncomingHost, IncomingLAN, IncomingInternet, IncomingInvalid: + return fmt.Sprintf("%s <- %s", conn.process, conn.Entity.IP) + case PeerHost, PeerLAN, PeerInternet, PeerInvalid: + return fmt.Sprintf("%s -> %s", conn.process, conn.Entity.IP) + default: + return fmt.Sprintf("%s to %s (%s)", conn.process, conn.Entity.Domain, conn.Entity.IP) + } +} diff --git a/network/database.go b/network/database.go index 9c31200b..d7dca398 100644 --- a/network/database.go +++ b/network/database.go @@ -1,7 +1,6 @@ package network import ( - "fmt" "strconv" "strings" "sync" @@ -15,10 +14,10 @@ import ( ) var ( - links = make(map[string]*Link) // key: Link ID - linksLock sync.RWMutex - comms = make(map[string]*Communication) // key: PID/Domain - commsLock sync.RWMutex + dnsConns = make(map[string]*Connection) // key: /Scope + dnsConnsLock sync.RWMutex + conns = make(map[string]*Connection) // key: Connection ID + connsLock sync.RWMutex dbController *database.Controller ) @@ -44,18 +43,18 @@ func (s *StorageInterface) Get(key string) (record.Record, error) { } } case 3: - commsLock.RLock() - defer commsLock.RUnlock() - conn, ok := comms[fmt.Sprintf("%s/%s", splitted[1], splitted[2])] + dnsConnsLock.RLock() + defer dnsConnsLock.RUnlock() + conn, ok := dnsConns[splitted[1]+"/"+splitted[2]] if ok { return conn, nil } case 4: - linksLock.RLock() - defer linksLock.RUnlock() - link, ok := links[splitted[3]] + connsLock.RLock() + defer connsLock.RUnlock() + conn, ok := conns[splitted[3]] if ok { - return link, nil + return conn, nil } } } @@ -85,25 +84,25 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { } if slashes <= 2 { - // comms - commsLock.RLock() - for _, conn := range comms { + // dns scopes only + dnsConnsLock.RLock() + for _, dnsConns := range dnsConns { + if strings.HasPrefix(dnsConns.DatabaseKey(), q.DatabaseKeyPrefix()) { + it.Next <- dnsConns + } + } + dnsConnsLock.RUnlock() + } + + if slashes <= 3 { + // connections + connsLock.RLock() + for _, conn := range conns { if strings.HasPrefix(conn.DatabaseKey(), q.DatabaseKeyPrefix()) { it.Next <- conn } } - commsLock.RUnlock() - } - - if slashes <= 3 { - // links - linksLock.RLock() - for _, link := range links { - if strings.HasPrefix(link.DatabaseKey(), q.DatabaseKeyPrefix()) { - it.Next <- link - } - } - linksLock.RUnlock() + connsLock.RUnlock() } it.Finish(nil) diff --git a/network/dns.go b/network/dns.go new file mode 100644 index 00000000..2b0f2301 --- /dev/null +++ b/network/dns.go @@ -0,0 +1,73 @@ +package network + +import ( + "context" + "strconv" + "sync" + "time" +) + +var ( + openDNSRequests = make(map[string]*Connection) // key: /fqdn + openDNSRequestsLock sync.Mutex + + // write open dns requests every + writeOpenDNSRequestsTickDuration = 5 * time.Second + + // duration after which DNS requests without a following connection are logged + openDNSRequestLimit = 3 * time.Second +) + +func removeOpenDNSRequest(pid int, fqdn string) { + openDNSRequestsLock.Lock() + defer openDNSRequestsLock.Unlock() + + key := strconv.Itoa(pid) + "/" + fqdn + delete(openDNSRequests, key) +} + +func saveOpenDNSRequest(conn *Connection) { + openDNSRequestsLock.Lock() + defer openDNSRequestsLock.Unlock() + + key := strconv.Itoa(conn.process.Pid) + "/" + conn.Scope + + existingConn, ok := openDNSRequests[key] + if ok { + existingConn.Lock() + defer existingConn.Unlock() + + existingConn.Ended = conn.Started + } else { + openDNSRequests[key] = conn + } +} + +func openDNSRequestWriter(ctx context.Context) error { + ticker := time.NewTicker(writeOpenDNSRequestsTickDuration) + + for { + select { + case <-ctx.Done(): + ticker.Stop() + return nil + case <-ticker.C: + writeOpenDNSRequestsToDB() + } + } +} + +func writeOpenDNSRequestsToDB() { + openDNSRequestsLock.Lock() + defer openDNSRequestsLock.Unlock() + + threshold := time.Now().Add(-openDNSRequestLimit).Unix() + for id, conn := range openDNSRequests { + conn.Lock() + if conn.Ended < threshold { + conn.save() + delete(openDNSRequests, id) + } + conn.Unlock() + } +} diff --git a/network/link.go b/network/link.go deleted file mode 100644 index 94784def..00000000 --- a/network/link.go +++ /dev/null @@ -1,428 +0,0 @@ -package network - -import ( - "context" - "errors" - "fmt" - "sync" - "time" - - "github.com/safing/portmaster/intel" - - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" - "github.com/safing/portmaster/network/packet" -) - -// FirewallHandler defines the function signature for a firewall handle function -type FirewallHandler func(pkt packet.Packet, link *Link) - -// Link describes a distinct physical connection (e.g. TCP connection) - like an instance - of a Connection. -type Link struct { //nolint:maligned // TODO: fix alignment - record.Base - lock sync.Mutex - - ID string - Entity *intel.Entity - Direction bool - - Verdict Verdict - Reason string - ReasonID string // format source[:id[:id]] - Tunneled bool - VerdictPermanent bool - Inspect bool - Started int64 - Ended int64 - - pktQueue chan packet.Packet - firewallHandler FirewallHandler - comm *Communication - - activeInspectors []bool - inspectorData map[uint8]interface{} - saveWhenFinished bool -} - -// Lock locks the link and the link's Entity. -func (link *Link) Lock() { - link.lock.Lock() - link.Entity.Lock() -} - -// Lock unlocks the link and the link's Entity. -func (link *Link) Unlock() { - link.Entity.Unlock() - link.lock.Unlock() -} - -// Communication returns the Communication the Link is part of -func (link *Link) Communication() *Communication { - link.lock.Lock() - defer link.lock.Unlock() - - return link.comm -} - -// GetVerdict returns the current verdict. -func (link *Link) GetVerdict() Verdict { - link.lock.Lock() - defer link.lock.Unlock() - - return link.Verdict -} - -// FirewallHandlerIsSet returns whether a firewall handler is set or not -func (link *Link) FirewallHandlerIsSet() bool { - link.lock.Lock() - defer link.lock.Unlock() - - return link.firewallHandler != nil -} - -// SetFirewallHandler sets the firewall handler for this link -func (link *Link) SetFirewallHandler(handler FirewallHandler) { - link.lock.Lock() - defer link.lock.Unlock() - - if link.firewallHandler == nil { - link.pktQueue = make(chan packet.Packet, 1000) - - // start handling - module.StartWorker("packet handler", func(ctx context.Context) error { - link.packetHandler() - return nil - }) - } - link.firewallHandler = handler -} - -// StopFirewallHandler unsets the firewall handler -func (link *Link) StopFirewallHandler() { - link.lock.Lock() - link.firewallHandler = nil - link.lock.Unlock() - link.pktQueue <- nil -} - -// HandlePacket queues packet of Link for handling -func (link *Link) HandlePacket(pkt packet.Packet) { - // get handler - link.lock.Lock() - handler := link.firewallHandler - link.lock.Unlock() - - // send to queue - if handler != nil { - link.pktQueue <- pkt - return - } - - // no handler! - log.Warningf("network: link %s does not have a firewallHandler, dropping packet", link) - err := pkt.Drop() - if err != nil { - log.Warningf("network: failed to drop packet %s: %s", pkt, err) - } -} - -// Accept accepts the link and adds the given reason. -func (link *Link) Accept(reason string) { - link.AddReason(reason) - link.UpdateVerdict(VerdictAccept) -} - -// Deny blocks or drops the link depending on the connection direction and adds the given reason. -func (link *Link) Deny(reason string) { - if link.Direction { - link.Drop(reason) - } else { - link.Block(reason) - } -} - -// Block blocks the link and adds the given reason. -func (link *Link) Block(reason string) { - link.AddReason(reason) - link.UpdateVerdict(VerdictBlock) -} - -// Drop drops the link and adds the given reason. -func (link *Link) Drop(reason string) { - link.AddReason(reason) - link.UpdateVerdict(VerdictDrop) -} - -// RerouteToNameserver reroutes the link to the portmaster nameserver. -func (link *Link) RerouteToNameserver() { - link.UpdateVerdict(VerdictRerouteToNameserver) -} - -// RerouteToTunnel reroutes the link to the tunnel entrypoint and adds the given reason for accepting the connection. -func (link *Link) RerouteToTunnel(reason string) { - link.AddReason(reason) - link.UpdateVerdict(VerdictRerouteToTunnel) -} - -// UpdateVerdict sets a new verdict for this link, making sure it does not interfere with previous verdicts -func (link *Link) UpdateVerdict(newVerdict Verdict) { - link.lock.Lock() - defer link.lock.Unlock() - - if newVerdict > link.Verdict { - link.Verdict = newVerdict - link.saveWhenFinished = true - } -} - -// AddReason adds a human readable string as to why a certain verdict was set in regard to this link -func (link *Link) AddReason(reason string) { - if reason == "" { - return - } - - link.lock.Lock() - defer link.lock.Unlock() - - if link.Reason != "" { - link.Reason += " | " - } - link.Reason += reason - - link.saveWhenFinished = true -} - -// packetHandler sequentially handles queued packets -func (link *Link) packetHandler() { - for { - pkt := <-link.pktQueue - if pkt == nil { - return - } - // get handler - link.lock.Lock() - handler := link.firewallHandler - link.lock.Unlock() - // execute handler or verdict - if handler != nil { - handler(pkt, link) - } else { - link.ApplyVerdict(pkt) - } - // submit trace logs - log.Tracer(pkt.Ctx()).Submit() - } -} - -// ApplyVerdict appies the link verdict to a packet. -func (link *Link) ApplyVerdict(pkt packet.Packet) { - link.lock.Lock() - defer link.lock.Unlock() - - var err error - - if link.VerdictPermanent { - switch link.Verdict { - case VerdictAccept: - err = pkt.PermanentAccept() - case VerdictBlock: - err = pkt.PermanentBlock() - case VerdictDrop: - err = pkt.PermanentDrop() - case VerdictRerouteToNameserver: - err = pkt.RerouteToNameserver() - case VerdictRerouteToTunnel: - err = pkt.RerouteToTunnel() - default: - err = pkt.Drop() - } - } else { - switch link.Verdict { - case VerdictAccept: - err = pkt.Accept() - case VerdictBlock: - err = pkt.Block() - case VerdictDrop: - err = pkt.Drop() - case VerdictRerouteToNameserver: - err = pkt.RerouteToNameserver() - case VerdictRerouteToTunnel: - err = pkt.RerouteToTunnel() - default: - err = pkt.Drop() - } - } - - if err != nil { - log.Warningf("network: failed to apply link verdict to packet %s: %s", pkt, err) - } -} - -// SaveWhenFinished marks the Link for saving after all current actions are finished. -func (link *Link) SaveWhenFinished() { - // FIXME: check if we should lock here - link.saveWhenFinished = true -} - -// SaveIfNeeded saves the Link if it is marked for saving when finished. -func (link *Link) SaveIfNeeded() { - link.lock.Lock() - save := link.saveWhenFinished - if save { - link.saveWhenFinished = false - } - link.lock.Unlock() - - if save { - link.saveAndLog() - } -} - -// saveAndLog saves the link object in the storage and propagates the change. It does not return an error, but logs it. -func (link *Link) saveAndLog() { - err := link.save() - if err != nil { - log.Warningf("network: failed to save link %s: %s", link, err) - } -} - -// save saves the link object in the storage and propagates the change. -func (link *Link) save() error { - // update link - link.lock.Lock() - if link.comm == nil { - link.lock.Unlock() - return errors.New("cannot save link without comms") - } - - if !link.KeyIsSet() { - link.SetKey(fmt.Sprintf("network:tree/%d/%s/%s", link.comm.Process().Pid, link.comm.Scope, link.ID)) - link.UpdateMeta() - } - link.saveWhenFinished = false - link.lock.Unlock() - - // save link - linksLock.RLock() - _, ok := links[link.ID] - linksLock.RUnlock() - - if !ok { - linksLock.Lock() - links[link.ID] = link - linksLock.Unlock() - } - - go dbController.PushUpdate(link) - return nil -} - -// Delete deletes a link from the storage and propagates the change. -func (link *Link) Delete() { - linksLock.Lock() - defer linksLock.Unlock() - link.lock.Lock() - defer link.lock.Unlock() - - delete(links, link.ID) - - link.Meta().Delete() - go dbController.PushUpdate(link) -} - -// GetLink fetches a Link from the database from the default namespace for this object -func GetLink(id string) (*Link, bool) { - linksLock.RLock() - defer linksLock.RUnlock() - - link, ok := links[id] - return link, ok -} - -// GetOrCreateLinkByPacket returns the associated Link for a packet and a bool expressing if the Link was newly created -func GetOrCreateLinkByPacket(pkt packet.Packet) (*Link, bool) { - link, ok := GetLink(pkt.GetLinkID()) - if ok { - log.Tracer(pkt.Ctx()).Tracef("network: assigned to link %s", link.ID) - return link, false - } - link = CreateLinkFromPacket(pkt) - log.Tracer(pkt.Ctx()).Tracef("network: created new link %s", link.ID) - return link, true -} - -// CreateLinkFromPacket creates a new Link based on Packet. -func CreateLinkFromPacket(pkt packet.Packet) *Link { - link := &Link{ - ID: pkt.GetLinkID(), - Entity: (&intel.Entity{ - IP: pkt.Info().RemoteIP(), - Protocol: uint8(pkt.Info().Protocol), - Port: pkt.Info().RemotePort(), - }).Init(), - Direction: pkt.IsInbound(), - Verdict: VerdictUndecided, - Started: time.Now().Unix(), - saveWhenFinished: true, - } - return link -} - -// GetActiveInspectors returns the list of active inspectors. -func (link *Link) GetActiveInspectors() []bool { - link.lock.Lock() - defer link.lock.Unlock() - - return link.activeInspectors -} - -// SetActiveInspectors sets the list of active inspectors. -func (link *Link) SetActiveInspectors(new []bool) { - link.lock.Lock() - defer link.lock.Unlock() - - link.activeInspectors = new -} - -// GetInspectorData returns the list of inspector data. -func (link *Link) GetInspectorData() map[uint8]interface{} { - link.lock.Lock() - defer link.lock.Unlock() - - return link.inspectorData -} - -// SetInspectorData set the list of inspector data. -func (link *Link) SetInspectorData(new map[uint8]interface{}) { - link.lock.Lock() - defer link.lock.Unlock() - - link.inspectorData = new -} - -// String returns a string representation of Link. -func (link *Link) String() string { - link.lock.Lock() - defer link.lock.Unlock() - - if link.comm == nil { - return fmt.Sprintf("? <-> %s", link.Entity.IP.String()) - } - switch link.comm.Scope { - case IncomingHost, IncomingLAN, IncomingInternet, IncomingInvalid: - if link.comm.process == nil { - return fmt.Sprintf("? <- %s", link.Entity.IP.String()) - } - return fmt.Sprintf("%s <- %s", link.comm.process.String(), link.Entity.IP.String()) - case PeerHost, PeerLAN, PeerInternet, PeerInvalid: - if link.comm.process == nil { - return fmt.Sprintf("? -> %s", link.Entity.IP.String()) - } - return fmt.Sprintf("%s -> %s", link.comm.process.String(), link.Entity.IP.String()) - default: - if link.comm.process == nil { - return fmt.Sprintf("? -> %s (%s)", link.comm.Scope, link.Entity.IP.String()) - } - return fmt.Sprintf("%s to %s (%s)", link.comm.process.String(), link.comm.Scope, link.Entity.IP.String()) - } -} diff --git a/network/module.go b/network/module.go index ed905d91..8b2c8309 100644 --- a/network/module.go +++ b/network/module.go @@ -1,24 +1,39 @@ package network import ( + "net" + "github.com/safing/portbase/modules" ) var ( module *modules.Module + + dnsAddress = net.IPv4(127, 0, 0, 1) + dnsPort uint16 = 53 + + defaultFirewallHandler FirewallHandler ) func init() { module = modules.Register("network", nil, start, nil, "core", "processes") } +// SetDefaultFirewallHandler sets the default firewall handler. +func SetDefaultFirewallHandler(handler FirewallHandler) { + if defaultFirewallHandler == nil { + defaultFirewallHandler = handler + } +} + func start() error { err := registerAsDatabase() if err != nil { return err } - go cleaner() + module.StartServiceWorker("clean connections", 0, connectionCleaner) + module.StartServiceWorker("write open dns requests", 0, openDNSRequestWriter) return nil } diff --git a/network/packet/packet.go b/network/packet/packet.go index 69bf590c..942dd215 100644 --- a/network/packet/packet.go +++ b/network/packet/packet.go @@ -10,7 +10,7 @@ import ( type Base struct { ctx context.Context info Info - linkID string + connID string Payload []byte } @@ -70,26 +70,26 @@ func (pkt *Base) GetPayload() ([]byte, error) { return pkt.Payload, ErrFailedToLoadPayload } -// GetLinkID returns the link ID for this packet. -func (pkt *Base) GetLinkID() string { - if pkt.linkID == "" { - pkt.createLinkID() +// GetConnectionID returns the link ID for this packet. +func (pkt *Base) GetConnectionID() string { + if pkt.connID == "" { + pkt.createConnectionID() } - return pkt.linkID + return pkt.connID } -func (pkt *Base) createLinkID() { +func (pkt *Base) createConnectionID() { if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP { if pkt.info.Direction { - pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort) + pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort) } else { - pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort) + pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort) } } else { if pkt.info.Direction { - pkt.linkID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src) + pkt.connID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src) } else { - pkt.linkID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst) + pkt.connID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst) } } } @@ -215,7 +215,7 @@ type Packet interface { SetOutbound() HasPorts() bool GetPayload() ([]byte, error) - GetLinkID() string + GetConnectionID() string // MATCHING MatchesAddress(bool, IPProtocol, *net.IPNet, uint16) bool diff --git a/network/self.go b/network/self.go deleted file mode 100644 index b57bd981..00000000 --- a/network/self.go +++ /dev/null @@ -1,79 +0,0 @@ -package network - -import ( - "fmt" - "os" - "time" - - "github.com/safing/portmaster/intel" - "github.com/safing/portmaster/network/netutils" - "github.com/safing/portmaster/network/packet" - "github.com/safing/portmaster/process" -) - -// GetOwnComm returns the communication for the given packet, that originates from the Portmaster itself. -func GetOwnComm(pkt packet.Packet) (*Communication, error) { - var scope string - - // Incoming - if pkt.IsInbound() { - switch netutils.ClassifyIP(pkt.Info().RemoteIP()) { - case netutils.HostLocal: - scope = IncomingHost - case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: - scope = IncomingLAN - case netutils.Global, netutils.GlobalMulticast: - scope = IncomingInternet - case netutils.Invalid: - scope = IncomingInvalid - } - - communication, ok := GetCommunication(os.Getpid(), scope) - if !ok { - proc, err := process.GetOrFindProcess(pkt.Ctx(), os.Getpid()) - if err != nil { - return nil, fmt.Errorf("could not get own process") - } - communication = &Communication{ - Scope: scope, - Entity: (&intel.Entity{}).Init(), - Direction: Inbound, - process: proc, - Inspect: true, - FirstLinkEstablished: time.Now().Unix(), - } - } - communication.process.AddCommunication() - return communication, nil - } - - // PeerToPeer - switch netutils.ClassifyIP(pkt.Info().RemoteIP()) { - case netutils.HostLocal: - scope = PeerHost - case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: - scope = PeerLAN - case netutils.Global, netutils.GlobalMulticast: - scope = PeerInternet - case netutils.Invalid: - scope = PeerInvalid - } - - communication, ok := GetCommunication(os.Getpid(), scope) - if !ok { - proc, err := process.GetOrFindProcess(pkt.Ctx(), os.Getpid()) - if err != nil { - return nil, fmt.Errorf("could not get own process") - } - communication = &Communication{ - Scope: scope, - Entity: (&intel.Entity{}).Init(), - Direction: Outbound, - process: proc, - Inspect: true, - FirstLinkEstablished: time.Now().Unix(), - } - } - communication.process.AddCommunication() - return communication, nil -} diff --git a/network/unknown.go b/network/unknown.go deleted file mode 100644 index 7277b40d..00000000 --- a/network/unknown.go +++ /dev/null @@ -1,66 +0,0 @@ -package network - -import ( - "time" - - "github.com/safing/portmaster/intel" - "github.com/safing/portmaster/network/netutils" - "github.com/safing/portmaster/network/packet" - "github.com/safing/portmaster/process" -) - -// Static reasons -const ( - ReasonUnknownProcess = "unknown connection owner: process could not be found" -) - -// GetUnknownCommunication returns the connection to a packet of unknown owner. -func GetUnknownCommunication(pkt packet.Packet) (*Communication, error) { - if pkt.IsInbound() { - switch netutils.ClassifyIP(pkt.Info().Src) { - case netutils.HostLocal: - return getOrCreateUnknownCommunication(pkt, IncomingHost) - case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: - return getOrCreateUnknownCommunication(pkt, IncomingLAN) - case netutils.Global, netutils.GlobalMulticast: - return getOrCreateUnknownCommunication(pkt, IncomingInternet) - case netutils.Invalid: - return getOrCreateUnknownCommunication(pkt, IncomingInvalid) - } - } - - switch netutils.ClassifyIP(pkt.Info().Dst) { - case netutils.HostLocal: - return getOrCreateUnknownCommunication(pkt, PeerHost) - case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: - return getOrCreateUnknownCommunication(pkt, PeerLAN) - case netutils.Global, netutils.GlobalMulticast: - return getOrCreateUnknownCommunication(pkt, PeerInternet) - case netutils.Invalid: - return getOrCreateUnknownCommunication(pkt, PeerInvalid) - } - - // this should never happen - return getOrCreateUnknownCommunication(pkt, PeerInvalid) -} - -func getOrCreateUnknownCommunication(pkt packet.Packet, connScope string) (*Communication, error) { - connection, ok := GetCommunication(process.UnknownProcess.Pid, connScope) - if !ok { - connection = &Communication{ - Scope: connScope, - Entity: (&intel.Entity{}).Init(), - Direction: pkt.IsInbound(), - Verdict: VerdictDrop, - Reason: ReasonUnknownProcess, - process: process.UnknownProcess, - Inspect: false, - FirstLinkEstablished: time.Now().Unix(), - } - if pkt.IsOutbound() { - connection.Verdict = VerdictBlock - } - } - connection.process.AddCommunication() - return connection, nil -} diff --git a/process/database.go b/process/database.go index e7240af2..b4ce09b3 100644 --- a/process/database.go +++ b/process/database.go @@ -23,7 +23,7 @@ var ( dbController *database.Controller dbControllerFlag = abool.NewBool(false) - deleteProcessesThreshold = 15 * time.Minute + deleteProcessesThreshold = 7 * time.Minute ) // GetProcessFromStorage returns a process from the internal storage. @@ -68,7 +68,7 @@ func (p *Process) Save() { processesLock.Unlock() } - if dbControllerFlag.IsSet() && p.Error == "" { + if dbControllerFlag.IsSet() { go dbController.PushUpdate(p) } } @@ -93,90 +93,49 @@ func (p *Process) Delete() { } // CleanProcessStorage cleans the storage from old processes. -func CleanProcessStorage(activeComms map[string]struct{}) { - activePIDs, err := getActivePIDs() +func CleanProcessStorage(activePIDs map[int]struct{}) { + // add system table of processes + procs, err := processInfo.Processes() if err != nil { log.Warningf("process: failed to get list of active PIDs: %s", err) - activePIDs = nil + } else { + for _, p := range procs { + activePIDs[int(p.Pid)] = struct{}{} + } } - processesCopy := All() + processesCopy := All() threshold := time.Now().Add(-deleteProcessesThreshold).Unix() - delete := false // clean primary processes for _, p := range processesCopy { p.Lock() - // check if internal - if p.Pid <= 0 { - p.Unlock() - continue - } - - // has comms? - _, hasComms := activeComms[p.DatabaseKey()] - - // virtual / active - virtual := p.Virtual - active := false - if activePIDs != nil { - _, active = activePIDs[p.Pid] - } - p.Unlock() - - if !virtual && !hasComms && !active && p.LastCommEstablished < threshold { - go p.Delete() - } - } - - // clean virtual/failed processes - for _, p := range processesCopy { - p.Lock() - // check if internal - if p.Pid <= 0 { - p.Unlock() - continue - } + _, active := activePIDs[p.Pid] switch { - case p.Error != "": - if p.Meta().Created < threshold { - delete = true - } - case p.Virtual: - _, parentIsActive := processesCopy[p.ParentPid] - active := true - if activePIDs != nil { - _, active = activePIDs[p.Pid] - } - if !parentIsActive || !active { - delete = true + case p.Pid <= 0: + // internal + case active: + // process in system process table or recently seen on the network + default: + // delete now or soon + switch { + case p.LastSeen == 0: + // add last + p.LastSeen = time.Now().Unix() + case p.LastSeen > threshold: + // within keep period + default: + // delete now + log.Tracef("process.clean: deleted %s", p.DatabaseKey()) + go p.Delete() } } + p.Unlock() - - if delete { - log.Tracef("process.clean: deleted %s", p.DatabaseKey()) - go p.Delete() - delete = false - } } } -func getActivePIDs() (map[int]struct{}, error) { - procs, err := processInfo.Processes() - if err != nil { - return nil, err - } - - activePIDs := make(map[int]struct{}) - for _, p := range procs { - activePIDs[int(p.Pid)] = struct{}{} - } - - return activePIDs, nil -} - // SetDBController sets the database controller and allows the package to push database updates on a save. It must be set by the package that registers the "network" database. func SetDBController(controller *database.Controller) { dbController = controller diff --git a/process/process.go b/process/process.go index 608c39ad..2ab29383 100644 --- a/process/process.go +++ b/process/process.go @@ -46,11 +46,10 @@ type Process struct { Icon string // Icon is a path to the icon and is either prefixed "f:" for filepath, "d:" for database cache path or "c:"/"a:" for a the icon key to fetch it from a company / authoritative node and cache it in its own cache. - FirstCommEstablished int64 - LastCommEstablished int64 + FirstSeen int64 + LastSeen int64 - Virtual bool // This process is either merged into another process or is not needed. - Error string // If this is set, the process is invalid. This is used to cache failing or inexistent processes. + Virtual bool // This process is either merged into another process or is not needed. } // Profile returns the assigned layered profile. @@ -63,67 +62,15 @@ func (p *Process) Profile() *profile.LayeredProfile { // Strings returns a string representation of process. func (p *Process) String() string { - p.Lock() - defer p.Unlock() - if p == nil { return "?" } - return fmt.Sprintf("%s:%s:%d", p.UserName, p.Path, p.Pid) -} -// AddCommunication increases the connection counter and the last connection timestamp. -func (p *Process) AddCommunication() { p.Lock() defer p.Unlock() - - // check if we should save - save := false - if p.LastCommEstablished == 0 || p.LastCommEstablished < time.Now().Add(-3*time.Second).Unix() { - save = true - } - - // update LastCommEstablished - p.LastCommEstablished = time.Now().Unix() - if p.FirstCommEstablished == 0 { - p.FirstCommEstablished = p.LastCommEstablished - } - - if save { - go p.Save() - } + return fmt.Sprintf("%s:%s:%d", p.UserName, p.Path, p.Pid) } -// var db = database.NewInterface(nil) - -// CountConnections returns the count of connections of a process -// func (p *Process) CountConnections() int { -// q, err := query.New(fmt.Sprintf("%s/%d/", processDatabaseNamespace, p.Pid)). -// Where(query.Where("Pid", query.Exists, nil)). -// Check() -// if err != nil { -// log.Warningf("process: failed to build query to get connection count of process: %s", err) -// return -1 -// } -// -// it, err := db.Query(q) -// if err != nil { -// log.Warningf("process: failed to query db to get connection count of process: %s", err) -// return -1 -// } -// -// cnt := 0 -// for _ = range it.Next { -// cnt++ -// } -// if it.Err() != nil { -// log.Warningf("process: failed to query db to get connection count of process: %s", err) -// return -1 -// } -// -// return cnt -// } - // GetOrFindPrimaryProcess returns the highest process in the tree that matches the given PID. func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) { log.Tracer(ctx).Tracef("process: getting primary process for PID %d", pid) @@ -139,9 +86,6 @@ func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) { if err != nil { return nil, err } - if process.Error != "" { - return nil, fmt.Errorf("%s [cached error]", process.Error) - } for { if process.ParentPid == 0 { @@ -152,10 +96,6 @@ func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) { log.Tracer(ctx).Tracef("process: could not get parent of %d: %d: %s", process.Pid, process.ParentPid, err) return process, nil } - if parentProcess.Error != "" { - log.Tracer(ctx).Tracef("process: could not get parent of %d: %d: %s [cached error]", process.Pid, process.ParentPid, parentProcess.Error) - return process, nil - } // if parent process path does not match, we have reached the top of the tree of matching processes if process.Path != parentProcess.Path { @@ -192,9 +132,6 @@ func GetOrFindProcess(ctx context.Context, pid int) (*Process, error) { if err != nil { return nil, err } - if p.Error != "" { - return nil, fmt.Errorf("%s [cached error]", p.Error) - } // mark for use, save to storage p.Lock() @@ -275,8 +212,9 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { // create new process new := &Process{ - Pid: pid, - Virtual: true, // caller must decide to actually use the process - we need to save now. + Pid: pid, + Virtual: true, // caller must decide to actually use the process - we need to save now. + FirstSeen: time.Now().Unix(), } switch { @@ -302,7 +240,7 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { var uids []int32 uids, err = pInfo.Uids() if err != nil { - return failedToLoad(new, fmt.Errorf("failed to get UID for p%d: %s", pid, err)) + return nil, fmt.Errorf("failed to get UID for p%d: %s", pid, err) } new.UserID = int(uids[0]) } @@ -310,7 +248,7 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { // Username new.UserName, err = pInfo.Username() if err != nil { - return failedToLoad(new, fmt.Errorf("process: failed to get Username for p%d: %s", pid, err)) + return nil, fmt.Errorf("process: failed to get Username for p%d: %s", pid, err) } // TODO: User Home @@ -319,14 +257,14 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { // PPID ppid, err := pInfo.Ppid() if err != nil { - return failedToLoad(new, fmt.Errorf("failed to get PPID for p%d: %s", pid, err)) + return nil, fmt.Errorf("failed to get PPID for p%d: %s", pid, err) } new.ParentPid = int(ppid) // Path new.Path, err = pInfo.Exe() if err != nil { - return failedToLoad(new, fmt.Errorf("failed to get Path for p%d: %s", pid, err)) + return nil, fmt.Errorf("failed to get Path for p%d: %s", pid, err) } // Executable Name _, new.ExecName = filepath.Split(new.Path) @@ -341,13 +279,13 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { // Command line arguments new.CmdLine, err = pInfo.Cmdline() if err != nil { - return failedToLoad(new, fmt.Errorf("failed to get Cmdline for p%d: %s", pid, err)) + return nil, fmt.Errorf("failed to get Cmdline for p%d: %s", pid, err) } // Name new.Name, err = pInfo.Name() if err != nil { - return failedToLoad(new, fmt.Errorf("failed to get Name for p%d: %s", pid, err)) + return nil, fmt.Errorf("failed to get Name for p%d: %s", pid, err) } if new.Name == "" { new.Name = new.ExecName @@ -436,9 +374,3 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { new.Save() return new, nil } - -func failedToLoad(p *Process, err error) (*Process, error) { - p.Error = err.Error() - p.Save() - return nil, err -} diff --git a/profile/profile-layered.go b/profile/profile-layered.go index c306d7fe..11c9b004 100644 --- a/profile/profile-layered.go +++ b/profile/profile-layered.go @@ -174,14 +174,12 @@ func (lp *LayeredProfile) SecurityLevel() uint8 { func (lp *LayeredProfile) DefaultAction() uint8 { for _, layer := range lp.layers { if layer.defaultAction > 0 { - log.Tracef("profile: default action by layer = %d", layer.defaultAction) return layer.defaultAction } } cfgLock.RLock() defer cfgLock.RUnlock() - log.Tracef("profile: default action from global = %d", cfgDefaultAction) return cfgDefaultAction }