diff --git a/core/core.go b/core/core.go index f15a2709..d70569f0 100644 --- a/core/core.go +++ b/core/core.go @@ -18,7 +18,7 @@ var ( ) func init() { - modules.Register("base", nil, registerDatabases, nil, "database", "config", "random") + modules.Register("base", nil, registerDatabases, nil, "database", "config", "rng") module = modules.Register("core", nil, start, nil, "base", "subsystems", "status", "updates", "api", "notifications", "ui") subsystems.Register( diff --git a/firewall/firewall.go b/firewall/firewall.go index 530f7437..923e87f2 100644 --- a/firewall/firewall.go +++ b/firewall/firewall.go @@ -233,6 +233,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) { if ps.isMe { // approve conn.Accept("internally approved") + conn.Internal = true // finish conn.StopFirewallHandler() issueVerdict(conn, pkt, 0, true) diff --git a/firewall/master.go b/firewall/master.go index f09ad644..86196e36 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -50,6 +50,7 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint: if conn.Process().Pid == os.Getpid() { log.Infof("filter: granting own connection %s", conn) conn.Verdict = network.VerdictAccept + conn.Internal = true return } @@ -75,6 +76,7 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { //nolint: log.Warningf("filter: failed to find load local peer process with PID %d: %s", otherPid, err) } else if otherProcess.Pid == conn.Process().Pid { conn.Accept("connection to self") + conn.Internal = true return } } diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 7eee0c4d..8f2dbd65 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -199,6 +199,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er } }() + // TODO: this has been obsoleted due to special profiles if conn.Process().Profile() == nil { tracer.Infof("nameserver: failed to find process for request %s, returning NXDOMAIN", conn) returnNXDomain(w, query) diff --git a/network/clean.go b/network/clean.go index 2c8cff02..ec51b611 100644 --- a/network/clean.go +++ b/network/clean.go @@ -57,8 +57,7 @@ func cleanConnections() (activePIDs map[int]struct{}) { // Step 2: mark end activePIDs[conn.process.Pid] = struct{}{} conn.Ended = now - // "save" - dbController.PushUpdate(conn) + conn.Save() } case conn.Ended < deleteOlderThan: // Step 3: delete diff --git a/network/connection.go b/network/connection.go index d2ada322..b9bef333 100644 --- a/network/connection.go +++ b/network/connection.go @@ -41,6 +41,7 @@ type Connection struct { //nolint:maligned // TODO: fix alignment VerdictPermanent bool Inspecting bool Encrypted bool // TODO + Internal bool // Portmaster internal connections are marked in order to easily filter these out in the UI pktQueue chan packet.Packet firewallHandler FirewallHandler @@ -58,7 +59,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, ip net.IP, po proc, err := process.GetProcessByEndpoints(ctx, ip, port, dnsAddress, dnsPort, packet.UDP) if err != nil { log.Warningf("network: failed to find process of dns request for %s: %s", fqdn, err) - proc = process.UnknownProcess + proc = process.GetUnidentifiedProcess(ctx) } timestamp := time.Now().Unix() @@ -80,7 +81,7 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { proc, inbound, err := process.GetProcessByPacket(pkt) if err != nil { log.Warningf("network: failed to find process of packet %s: %s", pkt, err) - proc = process.UnknownProcess + proc = process.GetUnidentifiedProcess(pkt.Ctx()) } var scope string @@ -229,39 +230,31 @@ func (conn *Connection) SaveWhenFinished() { // Save saves the connection in the storage and propagates the change through the database system. func (conn *Connection) Save() { - if conn.ID == "" { + conn.UpdateMeta() - // dns request - if !conn.KeyIsSet() { + if !conn.KeyIsSet() { + if conn.ID == "" { + // dns request + + // set key conn.SetKey(fmt.Sprintf("network:tree/%d/%s", conn.process.Pid, conn.Scope)) - conn.UpdateMeta() - } - // save to internal state - // check if it already exists - mapKey := strconv.Itoa(conn.process.Pid) + "/" + conn.Scope - dnsConnsLock.Lock() - _, ok := dnsConns[mapKey] - if !ok { + mapKey := strconv.Itoa(conn.process.Pid) + "/" + conn.Scope + + // save + dnsConnsLock.Lock() dnsConns[mapKey] = conn - } - dnsConnsLock.Unlock() + dnsConnsLock.Unlock() + } else { + // network connection - } else { - - // connection - if !conn.KeyIsSet() { + // set key conn.SetKey(fmt.Sprintf("network:tree/%d/%s/%s", conn.process.Pid, conn.Scope, conn.ID)) - conn.UpdateMeta() - } - // save to internal state - // check if it already exists - connsLock.Lock() - _, ok := conns[conn.ID] - if !ok { - conns[conn.ID] = conn - } - connsLock.Unlock() + // save + connsLock.Lock() + conns[conn.ID] = conn + connsLock.Unlock() + } } // notify database controller @@ -270,7 +263,11 @@ func (conn *Connection) Save() { // delete deletes a link from the storage and propagates the change. Nothing is locked - both the conns map and the connection itself require locking func (conn *Connection) delete() { - delete(conns, conn.ID) + if conn.ID == "" { + delete(dnsConns, strconv.Itoa(conn.process.Pid)+"/"+conn.Scope) + } else { + delete(conns, conn.ID) + } conn.Meta().Delete() dbController.PushUpdate(conn) diff --git a/network/database.go b/network/database.go index d7dca398..073dcbc0 100644 --- a/network/database.go +++ b/network/database.go @@ -77,7 +77,7 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { if slashes <= 1 { // processes for _, proc := range process.All() { - if strings.HasPrefix(proc.DatabaseKey(), q.DatabaseKeyPrefix()) { + if q.Matches(proc) { it.Next <- proc } } @@ -86,9 +86,9 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { if slashes <= 2 { // dns scopes only dnsConnsLock.RLock() - for _, dnsConns := range dnsConns { - if strings.HasPrefix(dnsConns.DatabaseKey(), q.DatabaseKeyPrefix()) { - it.Next <- dnsConns + for _, dnsConn := range dnsConns { + if q.Matches(dnsConn) { + it.Next <- dnsConn } } dnsConnsLock.RUnlock() @@ -98,7 +98,7 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { // connections connsLock.RLock() for _, conn := range conns { - if strings.HasPrefix(conn.DatabaseKey(), q.DatabaseKeyPrefix()) { + if q.Matches(conn) { it.Next <- conn } } diff --git a/network/dns.go b/network/dns.go index 88ca7be3..b33fa99f 100644 --- a/network/dns.go +++ b/network/dns.go @@ -5,6 +5,8 @@ import ( "strconv" "sync" "time" + + "github.com/safing/portmaster/process" ) var ( @@ -16,6 +18,9 @@ var ( // duration after which DNS requests without a following connection are logged openDNSRequestLimit = 3 * time.Second + + // scope prefix + unidentifiedProcessScopePrefix = strconv.Itoa(process.UnidentifiedProcessID) + "/" ) func removeOpenDNSRequest(pid int, fqdn string) { @@ -23,7 +28,13 @@ func removeOpenDNSRequest(pid int, fqdn string) { defer openDNSRequestsLock.Unlock() key := strconv.Itoa(pid) + "/" + fqdn - delete(openDNSRequests, key) + _, ok := openDNSRequests[key] + if ok { + delete(openDNSRequests, key) + } else if pid != process.UnidentifiedProcessID { + // check if there is an open dns request from an unidentified process + delete(openDNSRequests, unidentifiedProcessScopePrefix+fqdn) + } } // SaveOpenDNSRequest saves a dns request connection that was allowed to proceed. diff --git a/process/database.go b/process/database.go index b4ce09b3..1ce5295d 100644 --- a/process/database.go +++ b/process/database.go @@ -53,16 +53,13 @@ func (p *Process) Save() { p.Lock() defer p.Unlock() + p.UpdateMeta() + if !p.KeyIsSet() { + // set key p.SetKey(fmt.Sprintf("%s/%d", processDatabaseNamespace, p.Pid)) - p.CreateMeta() - } - processesLock.RLock() - _, ok := processes[p.Pid] - processesLock.RUnlock() - - if !ok { + // save processesLock.Lock() processes[p.Pid] = p processesLock.Unlock() @@ -113,7 +110,9 @@ func CleanProcessStorage(activePIDs map[int]struct{}) { _, active := activePIDs[p.Pid] switch { - case p.Pid <= 0: + case p.Pid == UnidentifiedProcessID: + // internal + case p.Pid == SystemProcessID: // internal case active: // process in system process table or recently seen on the network diff --git a/process/find.go b/process/find.go index 997d487b..30f93f2d 100644 --- a/process/find.go +++ b/process/find.go @@ -49,7 +49,7 @@ func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) { case pkt.Info().Protocol == packet.UDP && pkt.Info().Version == packet.IPv6: return getUDP6PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) default: - return -1, false, errors.New("unsupported protocol for finding process") + return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") } } @@ -58,7 +58,7 @@ func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) { func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, err error) { if !enableProcessDetection() { log.Tracer(pkt.Ctx()).Tracef("process: process detection disabled") - return UnknownProcess, direction, nil + return GetUnidentifiedProcess(pkt.Ctx()), pkt.Info().Direction, nil } log.Tracer(pkt.Ctx()).Tracef("process: getting process and profile by packet") @@ -107,7 +107,7 @@ func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remote case protocol == packet.UDP && ipVersion == packet.IPv6: return getUDP6PacketInfo(localIP, localPort, remoteIP, remotePort, false) default: - return -1, false, errors.New("unsupported protocol for finding process") + return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") } } @@ -116,7 +116,7 @@ func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remote func GetProcessByEndpoints(ctx context.Context, localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) { if !enableProcessDetection() { log.Tracer(ctx).Tracef("process: process detection disabled") - return UnknownProcess, nil + return GetUnidentifiedProcess(ctx), nil } log.Tracer(ctx).Tracef("process: getting process and profile by endpoints") diff --git a/process/iphelper/get.go b/process/iphelper/get.go index 99c0f821..6487ea06 100644 --- a/process/iphelper/get.go +++ b/process/iphelper/get.go @@ -9,6 +9,10 @@ import ( "time" ) +const ( + unidentifiedProcessID = -1 +) + var ( tcp4Connections []*ConnectionEntry tcp4Listeners []*ConnectionEntry @@ -55,7 +59,7 @@ func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote } lock.Unlock() if err != nil { - return -1, pktDirection, err + return unidentifiedProcessID, pktDirection, err } // search @@ -67,7 +71,7 @@ func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote time.Sleep(waitTime) } - return -1, pktDirection, nil + return unidentifiedProcessID, pktDirection, nil } // GetTCP6PacketInfo returns the pid of the given IPv6/TCP connection. @@ -91,7 +95,7 @@ func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote } lock.Unlock() if err != nil { - return -1, pktDirection, err + return unidentifiedProcessID, pktDirection, err } // search @@ -103,7 +107,7 @@ func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote time.Sleep(waitTime) } - return -1, pktDirection, nil + return unidentifiedProcessID, pktDirection, nil } // GetUDP4PacketInfo returns the pid of the given IPv4/UDP connection. @@ -127,7 +131,7 @@ func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote } lock.Unlock() if err != nil { - return -1, pktDirection, err + return unidentifiedProcessID, pktDirection, err } // search @@ -139,7 +143,7 @@ func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote time.Sleep(waitTime) } - return -1, pktDirection, nil + return unidentifiedProcessID, pktDirection, nil } // GetUDP6PacketInfo returns the pid of the given IPv6/UDP connection. @@ -163,7 +167,7 @@ func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote } lock.Unlock() if err != nil { - return -1, pktDirection, err + return unidentifiedProcessID, pktDirection, err } // search @@ -175,7 +179,7 @@ func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remote time.Sleep(waitTime) } - return -1, pktDirection, nil + return unidentifiedProcessID, pktDirection, nil } func search(connections, listeners []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16, pktDirection bool) (pid int, direction bool) { //nolint:unparam // TODO: use direction, it may not be used because results caused problems, investigate. @@ -204,7 +208,7 @@ func search(connections, listeners []*ConnectionEntry, localIP, remoteIP net.IP, } } - return -1, pktDirection + return unidentifiedProcessID, pktDirection } func searchConnections(list []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16) (pid int) { @@ -218,7 +222,7 @@ func searchConnections(list []*ConnectionEntry, localIP, remoteIP net.IP, localP } } - return -1 + return unidentifiedProcessID } func searchListeners(list []*ConnectionEntry, localIP net.IP, localPort uint16) (pid int) { @@ -231,7 +235,7 @@ func searchListeners(list []*ConnectionEntry, localIP net.IP, localPort uint16) } } - return -1 + return unidentifiedProcessID } // GetActiveConnectionIDs returns all currently active connection IDs. diff --git a/process/proc/gather.go b/process/proc/gather.go index 436b0bb2..1413b3c9 100644 --- a/process/proc/gather.go +++ b/process/proc/gather.go @@ -33,7 +33,7 @@ func GetPidOfConnection(localIP net.IP, localPort uint16, protocol uint8) (pid i } } if !ok { - return -1, NoSocket + return unidentifiedProcessID, NoSocket } } @@ -45,7 +45,7 @@ func GetPidOfConnection(localIP net.IP, localPort uint16, protocol uint8) (pid i pid, ok = GetPidOfInode(uid, inode) } if !ok { - return -1, NoProcess + return unidentifiedProcessID, NoProcess } return @@ -64,7 +64,7 @@ func GetPidOfIncomingConnection(localIP net.IP, localPort uint16, protocol uint8 } if !ok { - return -1, NoSocket + return unidentifiedProcessID, NoSocket } } @@ -76,7 +76,7 @@ func GetPidOfIncomingConnection(localIP net.IP, localPort uint16, protocol uint8 pid, ok = GetPidOfInode(uid, inode) } if !ok { - return -1, NoProcess + return unidentifiedProcessID, NoProcess } return diff --git a/process/proc/get.go b/process/proc/get.go index dec27e23..52974b3e 100644 --- a/process/proc/get.go +++ b/process/proc/get.go @@ -7,6 +7,10 @@ import ( "net" ) +const ( + unidentifiedProcessID = -1 +) + // GetTCP4PacketInfo searches the network state tables for a TCP4 connection func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { return search(TCP4, localIP, localPort, pktDirection) @@ -52,11 +56,11 @@ func search(protocol uint8, localIP net.IP, localPort uint16, pktDirection bool) switch status { case NoSocket: - return -1, direction, errors.New("could not find socket") + return unidentifiedProcessID, direction, errors.New("could not find socket") case NoProcess: - return -1, direction, errors.New("could not find PID") + return unidentifiedProcessID, direction, errors.New("could not find PID") default: - return -1, direction, nil + return unidentifiedProcessID, direction, nil } } diff --git a/process/proc/processfinder.go b/process/proc/processfinder.go index 5ee1bb4b..5e6ed7cc 100644 --- a/process/proc/processfinder.go +++ b/process/proc/processfinder.go @@ -77,7 +77,7 @@ func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO } } - return -1, false + return unidentifiedProcessID, false } func findSocketFromPid(pid, inode int) bool { diff --git a/process/proc/sockets.go b/process/proc/sockets.go index c82b078c..bcdd91d4 100644 --- a/process/proc/sockets.go +++ b/process/proc/sockets.go @@ -100,7 +100,7 @@ func getConnectionSocket(localIP net.IP, localPort uint16, protocol uint8) (int, socketData, err := os.Open(procFile) if err != nil { log.Warningf("process/proc: could not read %s: %s", procFile, err) - return -1, -1, false + return unidentifiedProcessID, unidentifiedProcessID, false } defer socketData.Close() @@ -146,7 +146,7 @@ func getConnectionSocket(localIP net.IP, localPort uint16, protocol uint8) (int, } - return -1, -1, false + return unidentifiedProcessID, unidentifiedProcessID, false } @@ -187,7 +187,7 @@ func getListeningSocket(localIP net.IP, localPort uint16, protocol uint8) (uid, return data[0], data[1], true } - return -1, -1, false + return unidentifiedProcessID, unidentifiedProcessID, false } func procDelimiter(c rune) bool { diff --git a/process/process.go b/process/process.go index 2ab29383..8ef1ad73 100644 --- a/process/process.go +++ b/process/process.go @@ -75,11 +75,11 @@ func (p *Process) String() string { func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) { log.Tracer(ctx).Tracef("process: getting primary process for PID %d", pid) - if pid == -1 { - return UnknownProcess, nil - } - if pid == 0 { - return OSProcess, nil + switch pid { + case UnidentifiedProcessID: + return GetUnidentifiedProcess(ctx), nil + case SystemProcessID: + return GetSystemProcess(ctx), nil } process, err := loadProcess(ctx, pid) @@ -88,8 +88,8 @@ func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) { } for { - if process.ParentPid == 0 { - return OSProcess, nil + if process.ParentPid <= 0 { + return process, nil } parentProcess, err := loadProcess(ctx, process.ParentPid) if err != nil { @@ -121,11 +121,11 @@ func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) { func GetOrFindProcess(ctx context.Context, pid int) (*Process, error) { log.Tracer(ctx).Tracef("process: getting process for PID %d", pid) - if pid == -1 { - return UnknownProcess, nil - } - if pid == 0 { - return OSProcess, nil + switch pid { + case UnidentifiedProcessID: + return GetUnidentifiedProcess(ctx), nil + case SystemProcessID: + return GetSystemProcess(ctx), nil } p, err := loadProcess(ctx, pid) @@ -184,11 +184,12 @@ func deduplicateRequest(ctx context.Context, pid int) (finishRequest func()) { } func loadProcess(ctx context.Context, pid int) (*Process, error) { - if pid == -1 { - return UnknownProcess, nil - } - if pid == 0 { - return OSProcess, nil + + switch pid { + case UnidentifiedProcessID: + return GetUnidentifiedProcess(ctx), nil + case SystemProcessID: + return GetSystemProcess(ctx), nil } process, ok := GetProcessFromStorage(pid) diff --git a/process/profile.go b/process/profile.go index 388684fa..0f0ad5c6 100644 --- a/process/profile.go +++ b/process/profile.go @@ -15,6 +15,8 @@ func (p *Process) GetProfile(ctx context.Context) error { // only find profiles if not already done. if p.profile != nil { log.Tracer(ctx).Trace("process: profile already loaded") + // mark profile as used + p.profile.MarkUsed() return nil } log.Tracer(ctx).Trace("process: loading profile") @@ -29,10 +31,8 @@ func (p *Process) GetProfile(ctx context.Context) error { localProfile.Name = p.ExecName } - // mark as used and save - if localProfile.MarkUsed() { - _ = localProfile.Save() - } + // mark profile as used + localProfile.MarkUsed() p.LocalProfileKey = localProfile.Key() p.profile = profile.NewLayeredProfile(localProfile) diff --git a/process/special.go b/process/special.go new file mode 100644 index 00000000..277d337b --- /dev/null +++ b/process/special.go @@ -0,0 +1,84 @@ +package process + +import ( + "context" + "time" + + "github.com/safing/portbase/log" + "github.com/safing/portmaster/profile" +) + +// Special Process IDs +const ( + UnidentifiedProcessID = -1 + SystemProcessID = 0 +) + +var ( + // unidentifiedProcess is used when a process cannot be found. + unidentifiedProcess = &Process{ + UserID: UnidentifiedProcessID, + UserName: "Unknown", + Pid: UnidentifiedProcessID, + ParentPid: UnidentifiedProcessID, + Name: "Unidentified Processes", + } + + // systemProcess is used to represent the Kernel. + systemProcess = &Process{ + UserID: SystemProcessID, + UserName: "Kernel", + Pid: SystemProcessID, + ParentPid: SystemProcessID, + Name: "Operating System", + } +) + +// GetUnidentifiedProcess returns the special process assigned to unidentified processes. +func GetUnidentifiedProcess(ctx context.Context) *Process { + return getSpecialProcess(ctx, UnidentifiedProcessID, unidentifiedProcess, profile.GetUnidentifiedProfile) +} + +// GetSystemProcess returns the special process used for the Kernel. +func GetSystemProcess(ctx context.Context) *Process { + return getSpecialProcess(ctx, SystemProcessID, systemProcess, profile.GetSystemProfile) +} + +func getSpecialProcess(ctx context.Context, pid int, template *Process, getProfile func() *profile.Profile) *Process { + // check storage + p, ok := GetProcessFromStorage(pid) + if ok { + return p + } + + // assign template + p = template + + p.Lock() + defer p.Unlock() + + if p.FirstSeen == 0 { + p.FirstSeen = time.Now().Unix() + } + + // only find profiles if not already done. + if p.profile != nil { + log.Tracer(ctx).Trace("process: special profile already loaded") + // mark profile as used + p.profile.MarkUsed() + return p + } + log.Tracer(ctx).Trace("process: loading special profile") + + // get profile + localProfile := getProfile() + + // mark profile as used + localProfile.MarkUsed() + + p.LocalProfileKey = localProfile.Key() + p.profile = profile.NewLayeredProfile(localProfile) + + go p.Save() + return p +} diff --git a/process/unknown.go b/process/unknown.go deleted file mode 100644 index 4f6cde12..00000000 --- a/process/unknown.go +++ /dev/null @@ -1,26 +0,0 @@ -package process - -var ( - // UnknownProcess is used when a process cannot be found. - UnknownProcess = &Process{ - UserID: -1, - UserName: "Unknown", - Pid: -1, - ParentPid: -1, - Name: "Unknown Processes", - } - - // OSProcess is used to represent the Kernel. - OSProcess = &Process{ - UserID: 0, - UserName: "Kernel", - Pid: 0, - ParentPid: 0, - Name: "Operating System", - } -) - -func init() { - UnknownProcess.Save() - OSProcess.Save() -} diff --git a/profile/active.go b/profile/active.go index ff8f3d3e..ff6a71c8 100644 --- a/profile/active.go +++ b/profile/active.go @@ -1,7 +1,14 @@ package profile import ( + "context" "sync" + "time" +) + +const ( + activeProfileCleanerTickDuration = 10 * time.Minute + activeProfileCleanerThreshold = 1 * time.Hour ) var ( @@ -38,7 +45,34 @@ func markActiveProfileAsOutdated(scopedID string) { profile, ok := activeProfiles[scopedID] if ok { - profile.oudated.Set() + profile.outdated.Set() delete(activeProfiles, scopedID) } } + +func cleanActiveProfiles(ctx context.Context) error { + for { + select { + case <-time.After(activeProfileCleanerTickDuration): + + threshold := time.Now().Add(-activeProfileCleanerThreshold) + + activeProfilesLock.Lock() + for id, profile := range activeProfiles { + // get last used + profile.Lock() + lastUsed := profile.lastUsed + profile.Unlock() + // remove if not used for a while + if lastUsed.Before(threshold) { + profile.outdated.Set() + delete(activeProfiles, id) + } + } + activeProfilesLock.Unlock() + + case <-ctx.Done(): + return nil + } + } +} diff --git a/profile/config-update.go b/profile/config-update.go index 15e62b8e..b4c3f5e4 100644 --- a/profile/config-update.go +++ b/profile/config-update.go @@ -70,8 +70,8 @@ func updateGlobalConfigProfile(ctx context.Context, data interface{}) error { // build global profile for reference profile := &Profile{ - ID: "config", - Source: SourceGlobal, + ID: "global-config", + Source: SourceSpecial, Name: "Global Configuration", Config: make(map[string]interface{}), internalSave: true, diff --git a/profile/module.go b/profile/module.go index 615eddcf..8d962a8a 100644 --- a/profile/module.go +++ b/profile/module.go @@ -42,6 +42,8 @@ func start() error { return err } + module.StartServiceWorker("clean active profiles", 0, cleanActiveProfiles) + err = updateGlobalConfigProfile(module.Ctx, nil) if err != nil { log.Warningf("profile: error during loading global profile from configuration: %s", err) diff --git a/profile/profile-layered.go b/profile/profile-layered.go index f00dbfe7..dd0af165 100644 --- a/profile/profile-layered.go +++ b/profile/profile-layered.go @@ -128,7 +128,7 @@ func (lp *LayeredProfile) Update() (revisionCounter uint64) { var changed bool for i, layer := range lp.layers { - if layer.oudated.IsSet() { + if layer.outdated.IsSet() { changed = true // update layer newLayer, err := GetProfile(layer.Source, layer.ID) @@ -175,6 +175,11 @@ func (lp *LayeredProfile) updateCaches() { // TODO: ignore community profiles } +// MarkUsed marks the localProfile as used. +func (lp *LayeredProfile) MarkUsed() { + lp.localProfile.MarkUsed() +} + // SecurityLevel returns the highest security level of all layered profiles. func (lp *LayeredProfile) SecurityLevel() uint8 { return uint8(atomic.LoadUint32(lp.securityLevel)) diff --git a/profile/profile.go b/profile/profile.go index 46e3b447..70f7c162 100644 --- a/profile/profile.go +++ b/profile/profile.go @@ -19,15 +19,15 @@ import ( ) var ( - lastUsedUpdateThreshold = 1 * time.Hour + lastUsedUpdateThreshold = 24 * time.Hour ) // Profile Sources const ( - SourceLocal string = "local" + SourceLocal string = "local" // local, editable + SourceSpecial string = "special" // specials (read-only) SourceCommunity string = "community" SourceEnterprise string = "enterprise" - SourceGlobal string = "global" ) // Default Action IDs @@ -77,7 +77,8 @@ type Profile struct { //nolint:maligned // not worth the effort filterListIDs []string // Lifecycle Management - oudated *abool.AtomicBool + outdated *abool.AtomicBool + lastUsed time.Time // Framework // If a Profile is declared as a Framework (i.e. an Interpreter and the likes), then the real process/actor must be found @@ -94,7 +95,7 @@ type Profile struct { //nolint:maligned // not worth the effort func (profile *Profile) prepConfig() (err error) { // prepare configuration profile.configPerspective, err = config.NewPerspective(profile.Config) - profile.oudated = abool.New() + profile.outdated = abool.New() return } @@ -156,10 +157,11 @@ func (profile *Profile) parseConfig() error { // New returns a new Profile. func New() *Profile { profile := &Profile{ - ID: uuid.NewV4().String(), - Source: SourceLocal, - Created: time.Now().Unix(), - Config: make(map[string]interface{}), + ID: uuid.NewV4().String(), + Source: SourceLocal, + Created: time.Now().Unix(), + Config: make(map[string]interface{}), + internalSave: true, } // create placeholders @@ -190,13 +192,26 @@ func (profile *Profile) Save() error { return profileDB.Put(profile) } -// MarkUsed marks the profile as used, eventually. -func (profile *Profile) MarkUsed() (updated bool) { +// MarkUsed marks the profile as used and saves it when it has changed. +func (profile *Profile) MarkUsed() { + profile.Lock() + // lastUsed + profile.lastUsed = time.Now() + + // ApproxLastUsed + save := false if time.Now().Add(-lastUsedUpdateThreshold).Unix() > profile.ApproxLastUsed { profile.ApproxLastUsed = time.Now().Unix() - return true + save = true + } + profile.Unlock() + + if save { + err := profile.Save() + if err != nil { + log.Warningf("profiles: failed to save profile %s after marking as used: %s", profile.ScopedID(), err) + } } - return false } // String returns a string representation of the Profile. @@ -224,8 +239,6 @@ func (profile *Profile) addEndpointyEntry(cfgKey, newEntry string) { endpointList = append(endpointList, newEntry) profile.Config[cfgKey] = endpointList - // save without full reload - profile.internalSave = true profile.Unlock() err := profile.Save() if err != nil { @@ -233,10 +246,13 @@ func (profile *Profile) addEndpointyEntry(cfgKey, newEntry string) { } // reload manually + profile.Lock() + profile.dataParsed = false err = profile.parseConfig() if err != nil { log.Warningf("profile: failed to parse profile config after adding endpoint: %s", err) } + profile.Unlock() } // GetProfile loads a profile from the database. @@ -249,6 +265,7 @@ func GetProfileByScopedID(scopedID string) (*Profile, error) { // check cache profile := getActiveProfile(scopedID) if profile != nil { + profile.MarkUsed() return profile, nil } @@ -266,7 +283,6 @@ func GetProfileByScopedID(scopedID string) (*Profile, error) { // lock for prepping profile.Lock() - defer profile.Unlock() // prepare config err = profile.prepConfig() @@ -280,7 +296,13 @@ func GetProfileByScopedID(scopedID string) (*Profile, error) { log.Warningf("profiles: profile %s has (partly) invalid configuration: %s", profile.ID, err) } + // mark as internal + profile.internalSave = true + + profile.Unlock() + // mark active + profile.MarkUsed() markProfileActive(profile) return profile, nil diff --git a/profile/special.go b/profile/special.go new file mode 100644 index 00000000..6bce01d7 --- /dev/null +++ b/profile/special.go @@ -0,0 +1,56 @@ +package profile + +import ( + "github.com/safing/portbase/log" +) + +const ( + unidentifiedProfileID = "_unidentified" + systemProfileID = "_system" +) + +// GetUnidentifiedProfile returns the special profile assigned to unidentified processes. +func GetUnidentifiedProfile() *Profile { + // get profile + profile, err := GetProfile(SourceLocal, unidentifiedProfileID) + if err == nil { + return profile + } + + // create if not available (or error) + profile = New() + profile.Name = "Unidentified Processes" + profile.Source = SourceLocal + profile.ID = unidentifiedProfileID + + // save to db + err = profile.Save() + if err != nil { + log.Warningf("profiles: failed to save %s: %s", profile.ScopedID(), err) + } + + return profile +} + +// GetSystemProfile returns the special profile used for the Kernel. +func GetSystemProfile() *Profile { + // get profile + profile, err := GetProfile(SourceLocal, systemProfileID) + if err == nil { + return profile + } + + // create if not available (or error) + profile = New() + profile.Name = "Operating System" + profile.Source = SourceLocal + profile.ID = systemProfileID + + // save to db + err = profile.Save() + if err != nil { + log.Warningf("profiles: failed to save %s: %s", profile.ScopedID(), err) + } + + return profile +}