Rewrite network tree saving and cleaning procedures

This commit is contained in:
Daniel 2019-05-22 16:10:05 +02:00
parent 1873999b38
commit fb4fb20d4b
9 changed files with 312 additions and 139 deletions

View file

@ -95,11 +95,11 @@ func stop() error {
func handlePacket(pkt packet.Packet) { func handlePacket(pkt packet.Packet) {
// allow localhost, for now // allow localhost, for now
if pkt.Info().Src.Equal(pkt.Info().Dst) { // if pkt.Info().Src.Equal(pkt.Info().Dst) {
log.Debugf("accepting localhost communication: %s", pkt) // log.Debugf("accepting localhost communication: %s", pkt)
pkt.PermanentAccept() // pkt.PermanentAccept()
return // return
} // }
// allow local dns // allow local dns
if (pkt.Info().DstPort == 53 || pkt.Info().SrcPort == 53) && pkt.Info().Src.Equal(pkt.Info().Dst) { if (pkt.Info().DstPort == 53 || pkt.Info().SrcPort == 53) && pkt.Info().Src.Equal(pkt.Info().Dst) {
@ -160,6 +160,9 @@ func handlePacket(pkt packet.Packet) {
// associate packet to link and handle // associate packet to link and handle
link, created := network.GetOrCreateLinkByPacket(pkt) link, created := network.GetOrCreateLinkByPacket(pkt)
defer func() {
go link.SaveIfNeeded()
}()
if created { if created {
link.SetFirewallHandler(initialHandler) link.SetFirewallHandler(initialHandler)
link.HandlePacket(pkt) link.HandlePacket(pkt)
@ -169,7 +172,7 @@ func handlePacket(pkt packet.Packet) {
link.HandlePacket(pkt) link.HandlePacket(pkt)
return return
} }
issueVerdict(pkt, link, 0, true, false) issueVerdict(pkt, link, 0, true)
} }
func initialHandler(pkt packet.Packet, link *network.Link) { func initialHandler(pkt packet.Packet, link *network.Link) {
@ -186,6 +189,9 @@ func initialHandler(pkt packet.Packet, link *network.Link) {
} else { } else {
comm.AddLink(link) comm.AddLink(link)
} }
defer func() {
go comm.SaveIfNeeded()
}()
// approve // approve
link.Accept("internally approved") link.Accept("internally approved")
@ -193,8 +199,7 @@ func initialHandler(pkt packet.Packet, link *network.Link) {
// finish // finish
link.StopFirewallHandler() link.StopFirewallHandler()
issueVerdict(pkt, link, 0, true, true) issueVerdict(pkt, link, 0, true)
return return
} }
@ -212,11 +217,13 @@ func initialHandler(pkt packet.Packet, link *network.Link) {
log.Tracer(pkt.Ctx()).Errorf("firewall: could not get unknown comm: %s", err) log.Tracer(pkt.Ctx()).Errorf("firewall: could not get unknown comm: %s", err)
link.UpdateVerdict(network.VerdictDrop) link.UpdateVerdict(network.VerdictDrop)
link.StopFirewallHandler() link.StopFirewallHandler()
issueVerdict(pkt, link, 0, true, true) issueVerdict(pkt, link, 0, true)
return return
} }
} }
defer func() {
go comm.SaveIfNeeded()
}()
// add new Link to Communication (and save both) // add new Link to Communication (and save both)
comm.AddLink(link) comm.AddLink(link)
@ -226,11 +233,12 @@ func initialHandler(pkt packet.Packet, link *network.Link) {
if comm.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.Info().DstPort == 53 && !pkt.Info().Src.Equal(pkt.Info().Dst) { if comm.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.Info().DstPort == 53 && !pkt.Info().Src.Equal(pkt.Info().Dst) {
link.UpdateVerdict(network.VerdictRerouteToNameserver) link.UpdateVerdict(network.VerdictRerouteToNameserver)
link.StopFirewallHandler() link.StopFirewallHandler()
issueVerdict(pkt, link, 0, true, true) issueVerdict(pkt, link, 0, true)
return return
} }
log.Tracer(pkt.Ctx()).Trace("firewall: starting decision process") log.Tracer(pkt.Ctx()).Trace("firewall: starting decision process")
DecideOnCommunication(comm, pkt) DecideOnCommunication(comm, pkt)
DecideOnLink(comm, link, pkt) DecideOnLink(comm, link, pkt)
@ -253,7 +261,7 @@ func initialHandler(pkt packet.Packet, link *network.Link) {
inspectThenVerdict(pkt, link) inspectThenVerdict(pkt, link)
default: default:
link.StopFirewallHandler() link.StopFirewallHandler()
issueVerdict(pkt, link, 0, true, false) issueVerdict(pkt, link, 0, true)
} }
} }
@ -261,23 +269,23 @@ func initialHandler(pkt packet.Packet, link *network.Link) {
func inspectThenVerdict(pkt packet.Packet, link *network.Link) { func inspectThenVerdict(pkt packet.Packet, link *network.Link) {
pktVerdict, continueInspection := inspection.RunInspectors(pkt, link) pktVerdict, continueInspection := inspection.RunInspectors(pkt, link)
if continueInspection { if continueInspection {
issueVerdict(pkt, link, pktVerdict, false, false) issueVerdict(pkt, link, pktVerdict, false)
return return
} }
// we are done with inspecting // we are done with inspecting
link.StopFirewallHandler() link.StopFirewallHandler()
issueVerdict(pkt, link, 0, true, false) issueVerdict(pkt, link, 0, true)
} }
func issueVerdict(pkt packet.Packet, link *network.Link, verdict network.Verdict, allowPermanent, forceSave bool) { func issueVerdict(pkt packet.Packet, link *network.Link, verdict network.Verdict, allowPermanent bool) {
link.Lock() link.Lock()
// enable permanent verdict // enable permanent verdict
if allowPermanent && !link.VerdictPermanent { if allowPermanent && !link.VerdictPermanent {
link.VerdictPermanent = permanentVerdicts() link.VerdictPermanent = permanentVerdicts()
if link.VerdictPermanent { if link.VerdictPermanent {
forceSave = true link.SaveWhenFinished()
} }
} }
@ -320,11 +328,6 @@ func issueVerdict(pkt packet.Packet, link *network.Link, verdict network.Verdict
link.Unlock() link.Unlock()
log.InfoTracef(pkt.Ctx(), "firewall: %s %s", link.Verdict, link) log.InfoTracef(pkt.Ctx(), "firewall: %s %s", link.Verdict, link)
if forceSave && !link.KeyIsSet() {
// always save if not yet saved
go link.Save()
}
} }
// func tunnelHandler(pkt packet.Packet) { // func tunnelHandler(pkt packet.Packet) {

View file

@ -139,6 +139,9 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
nxDomain(w, query) nxDomain(w, query)
return return
} }
defer func() {
go comm.SaveIfNeeded()
}()
// check for possible DNS tunneling / data transmission // check for possible DNS tunneling / data transmission
// TODO: improve this // TODO: improve this
@ -152,6 +155,9 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
// check profile before we even get intel and rr // check profile before we even get intel and rr
firewall.DecideOnCommunicationBeforeIntel(comm, fqdn) firewall.DecideOnCommunicationBeforeIntel(comm, fqdn)
comm.Lock()
comm.SaveWhenFinished()
comm.Unlock()
if comm.GetVerdict() == network.VerdictBlock || comm.GetVerdict() == network.VerdictDrop { if comm.GetVerdict() == network.VerdictBlock || comm.GetVerdict() == network.VerdictDrop {
log.InfoTracef(ctx, "nameserver: %s denied before intel, returning nxdomain", comm) log.InfoTracef(ctx, "nameserver: %s denied before intel, returning nxdomain", comm)
@ -172,7 +178,6 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
comm.Lock() comm.Lock()
comm.Intel = domainIntel comm.Intel = domainIntel
comm.Unlock() comm.Unlock()
comm.Save()
// check with intel // check with intel
firewall.DecideOnCommunicationAfterIntel(comm, fqdn, rrCache) firewall.DecideOnCommunicationAfterIntel(comm, fqdn, rrCache)

View file

@ -5,35 +5,33 @@ package network
import ( import (
"time" "time"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/process" "github.com/Safing/portmaster/process"
) )
var ( var (
cleanerTickDuration = 10 * time.Second cleanerTickDuration = 10 * time.Second
deadLinksTimeout = 3 * time.Minute deleteLinksAfterEndedThreshold = 5 * time.Minute
thresholdDuration = 3 * time.Minute deleteCommsWithoutLinksThreshhold = 3 * time.Minute
lastEstablishedUpdateThreshold = 30 * time.Second
) )
func cleaner() { func cleaner() {
for { for {
time.Sleep(cleanerTickDuration) time.Sleep(cleanerTickDuration)
cleanLinks() activeComms := cleanLinks()
time.Sleep(2 * time.Second) activeProcs := cleanComms(activeComms)
cleanComms() process.CleanProcessStorage(activeProcs)
time.Sleep(2 * time.Second)
cleanProcesses()
} }
} }
func cleanLinks() { func cleanLinks() (activeComms map[string]struct{}) {
activeComms = make(map[string]struct{})
activeIDs := process.GetActiveConnectionIDs() activeIDs := process.GetActiveConnectionIDs()
now := time.Now().Unix() now := time.Now().Unix()
deleteOlderThan := time.Now().Add(-deadLinksTimeout).Unix() deleteOlderThan := time.Now().Add(-deleteLinksAfterEndedThreshold).Unix()
// log.Tracef("network.clean: now=%d", now)
// log.Tracef("network.clean: deleteOlderThan=%d", deleteOlderThan)
linksLock.RLock() linksLock.RLock()
defer linksLock.RUnlock() defer linksLock.RUnlock()
@ -42,18 +40,21 @@ func cleanLinks() {
for key, link := range links { for key, link := range links {
// delete dead links // delete dead links
if link.Ended > 0 { link.Lock()
link.Lock() deleteThis := link.Ended > 0 && link.Ended < deleteOlderThan
deleteThis := link.Ended < deleteOlderThan link.Unlock()
link.Unlock() if deleteThis {
if deleteThis { log.Tracef("network.clean: deleted %s (ended at %d)", link.DatabaseKey(), link.Ended)
// log.Tracef("network.clean: deleted %s", link.DatabaseKey()) go link.Delete()
go link.Delete()
}
continue continue
} }
// not yet deleted, so its still a valid link regarding link count
comm := link.Communication()
comm.Lock()
markActive(activeComms, comm.DatabaseKey())
comm.Unlock()
// check if link is dead // check if link is dead
found = false found = false
for _, activeID := range activeIDs { for _, activeID := range activeIDs {
@ -63,31 +64,53 @@ func cleanLinks() {
} }
} }
// mark end time
if !found { if !found {
// mark end time
link.Lock()
link.Ended = now link.Ended = now
// log.Tracef("network.clean: marked %s as ended.", link.DatabaseKey()) link.Unlock()
go link.Save() log.Tracef("network.clean: marked %s as ended", link.DatabaseKey())
go link.save()
} }
} }
return
} }
func cleanComms() { func cleanComms(activeLinks map[string]struct{}) (activeComms map[string]struct{}) {
activeComms = make(map[string]struct{})
commsLock.RLock() commsLock.RLock()
defer commsLock.RUnlock() defer commsLock.RUnlock()
threshold := time.Now().Add(-thresholdDuration).Unix() threshold := time.Now().Add(-deleteCommsWithoutLinksThreshhold).Unix()
for _, comm := range comms { for _, comm := range comms {
// has links?
_, hasLinks := activeLinks[comm.DatabaseKey()]
// comm created
comm.Lock() comm.Lock()
if comm.FirstLinkEstablished < threshold && comm.LinkCount == 0 { created := comm.Meta().Created
// log.Tracef("network.clean: deleted %s", comm.DatabaseKey())
go comm.Delete()
}
comm.Unlock() comm.Unlock()
if !hasLinks && created < threshold {
log.Tracef("network.clean: deleted %s", comm.DatabaseKey())
go comm.Delete()
} else {
p := comm.Process()
p.Lock()
markActive(activeComms, p.DatabaseKey())
p.Unlock()
}
} }
return
} }
func cleanProcesses() { func markActive(activeMap map[string]struct{}, key string) {
process.CleanProcessStorage(thresholdDuration) _, ok := activeMap[key]
if !ok {
activeMap[key] = struct{}{}
}
} }

View file

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/Safing/portbase/database/record" "github.com/Safing/portbase/database/record"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/intel" "github.com/Safing/portmaster/intel"
"github.com/Safing/portmaster/network/netutils" "github.com/Safing/portmaster/network/netutils"
"github.com/Safing/portmaster/network/packet" "github.com/Safing/portmaster/network/packet"
@ -33,9 +34,9 @@ type Communication struct {
FirstLinkEstablished int64 FirstLinkEstablished int64
LastLinkEstablished int64 LastLinkEstablished int64
LinkCount uint
profileUpdateVersion uint32 profileUpdateVersion uint32
saveWhenFinished bool
} }
// Process returns the process that owns the connection. // Process returns the process that owns the connection.
@ -53,6 +54,7 @@ func (comm *Communication) ResetVerdict() {
comm.Verdict = VerdictUndecided comm.Verdict = VerdictUndecided
comm.Reason = "" comm.Reason = ""
comm.saveWhenFinished = true
} }
// GetVerdict returns the current verdict. // GetVerdict returns the current verdict.
@ -97,7 +99,7 @@ func (comm *Communication) UpdateVerdict(newVerdict Verdict) {
if newVerdict > comm.Verdict { if newVerdict > comm.Verdict {
comm.Verdict = newVerdict comm.Verdict = newVerdict
go comm.Save() comm.saveWhenFinished = true
} }
} }
@ -110,6 +112,7 @@ func (comm *Communication) SetReason(reason string) {
comm.Lock() comm.Lock()
defer comm.Unlock() defer comm.Unlock()
comm.Reason = reason comm.Reason = reason
comm.saveWhenFinished = true
} }
// AddReason adds a human readable string as to why a certain verdict was set in regard to this communication. // AddReason adds a human readable string as to why a certain verdict was set in regard to this communication.
@ -174,6 +177,7 @@ func GetCommunicationByFirstPacket(pkt packet.Packet) (*Communication, error) {
process: proc, process: proc,
Inspect: true, Inspect: true,
FirstLinkEstablished: time.Now().Unix(), FirstLinkEstablished: time.Now().Unix(),
saveWhenFinished: true,
} }
} }
communication.process.AddCommunication() communication.process.AddCommunication()
@ -206,6 +210,7 @@ func GetCommunicationByFirstPacket(pkt packet.Packet) (*Communication, error) {
process: proc, process: proc,
Inspect: true, Inspect: true,
FirstLinkEstablished: time.Now().Unix(), FirstLinkEstablished: time.Now().Unix(),
saveWhenFinished: true,
} }
} }
communication.process.AddCommunication() communication.process.AddCommunication()
@ -222,6 +227,7 @@ func GetCommunicationByFirstPacket(pkt packet.Packet) (*Communication, error) {
process: proc, process: proc,
Inspect: true, Inspect: true,
FirstLinkEstablished: time.Now().Unix(), FirstLinkEstablished: time.Now().Unix(),
saveWhenFinished: true,
} }
} }
communication.process.AddCommunication() communication.process.AddCommunication()
@ -246,12 +252,13 @@ func GetCommunicationByDNSRequest(ctx context.Context, ip net.IP, port uint16, f
communication, ok := GetCommunication(proc.Pid, fqdn) communication, ok := GetCommunication(proc.Pid, fqdn)
if !ok { if !ok {
communication = &Communication{ communication = &Communication{
Domain: fqdn, Domain: fqdn,
process: proc, process: proc,
Inspect: true, Inspect: true,
saveWhenFinished: true,
} }
communication.process.AddCommunication() communication.process.AddCommunication()
communication.Save() communication.saveWhenFinished = true
} }
return communication, nil return communication, nil
} }
@ -268,21 +275,47 @@ func (comm *Communication) makeKey() string {
return fmt.Sprintf("%d/%s", comm.process.Pid, comm.Domain) return fmt.Sprintf("%d/%s", comm.process.Pid, comm.Domain)
} }
// Save saves the connection object in the storage and propagates the change. // SaveWhenFinished marks the Connection for saving after all current actions are finished.
func (comm *Communication) Save() error { func (comm *Communication) SaveWhenFinished() {
comm.Lock() comm.saveWhenFinished = true
defer comm.Unlock() }
// SaveIfNeeded saves the Connection if it is marked for saving when finished.
func (comm *Communication) SaveIfNeeded() {
comm.Lock()
save := comm.saveWhenFinished
if save {
comm.saveWhenFinished = false
}
comm.Unlock()
if save {
comm.save()
}
}
// Save saves the Connection object in the storage and propagates the change.
func (comm *Communication) save() error {
// update comm
comm.Lock()
if comm.process == nil { if comm.process == nil {
comm.Unlock()
return errors.New("cannot save connection without process") return errors.New("cannot save connection without process")
} }
if !comm.KeyIsSet() { if !comm.KeyIsSet() {
comm.SetKey(fmt.Sprintf("network:tree/%d/%s", comm.process.Pid, comm.Domain)) comm.SetKey(fmt.Sprintf("network:tree/%d/%s", comm.process.Pid, comm.Domain))
comm.CreateMeta() comm.UpdateMeta()
}
if comm.Meta().Deleted > 0 {
log.Criticalf("network: revieving dead comm %s", comm)
comm.Meta().Deleted = 0
} }
key := comm.makeKey() key := comm.makeKey()
comm.saveWhenFinished = false
comm.Unlock()
// save comm
commsLock.RLock() commsLock.RLock()
_, ok := comms[key] _, ok := comms[key]
commsLock.RUnlock() commsLock.RUnlock()
@ -299,46 +332,42 @@ func (comm *Communication) Save() error {
// Delete deletes a connection from the storage and propagates the change. // Delete deletes a connection from the storage and propagates the change.
func (comm *Communication) Delete() { func (comm *Communication) Delete() {
commsLock.Lock()
defer commsLock.Unlock()
comm.Lock() comm.Lock()
defer comm.Unlock() defer comm.Unlock()
commsLock.Lock()
delete(comms, comm.makeKey()) delete(comms, comm.makeKey())
commsLock.Unlock()
comm.Meta().Delete() comm.Meta().Delete()
go dbController.PushUpdate(comm) go dbController.PushUpdate(comm)
comm.process.RemoveCommunication()
go comm.process.Save()
} }
// AddLink applies the Communication to the Link and increases sets counter and timestamps. // AddLink applies the Communication to the Link and sets timestamps.
func (comm *Communication) AddLink(link *Link) { func (comm *Communication) AddLink(link *Link) {
// apply comm to link
link.Lock() link.Lock()
link.comm = comm link.comm = comm
link.Verdict = comm.Verdict link.Verdict = comm.Verdict
link.Inspect = comm.Inspect link.Inspect = comm.Inspect
link.saveWhenFinished = true
link.Unlock() link.Unlock()
link.Save()
// update comm LastLinkEstablished
comm.Lock() comm.Lock()
comm.LinkCount++
// check if we should save
if comm.LastLinkEstablished < time.Now().Add(-3*time.Second).Unix() {
comm.saveWhenFinished = true
}
// update LastLinkEstablished
comm.LastLinkEstablished = time.Now().Unix() comm.LastLinkEstablished = time.Now().Unix()
if comm.FirstLinkEstablished == 0 { if comm.FirstLinkEstablished == 0 {
comm.FirstLinkEstablished = comm.LastLinkEstablished comm.FirstLinkEstablished = comm.LastLinkEstablished
} }
comm.Unlock() comm.Unlock()
comm.Save()
}
// RemoveLink lowers the link counter by one.
func (comm *Communication) RemoveLink() {
comm.Lock()
defer comm.Unlock()
if comm.LinkCount > 0 {
comm.LinkCount--
}
} }
// String returns a string representation of Communication. // String returns a string representation of Communication.

View file

@ -1,6 +1,7 @@
package network package network
import ( import (
"fmt"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -14,9 +15,9 @@ import (
) )
var ( var (
links = make(map[string]*Link) links = make(map[string]*Link) // key: Link ID
linksLock sync.RWMutex linksLock sync.RWMutex
comms = make(map[string]*Communication) comms = make(map[string]*Communication) // key: PID/Domain
commsLock sync.RWMutex commsLock sync.RWMutex
dbController *database.Controller dbController *database.Controller
@ -45,7 +46,7 @@ func (s *StorageInterface) Get(key string) (record.Record, error) {
case 3: case 3:
commsLock.RLock() commsLock.RLock()
defer commsLock.RUnlock() defer commsLock.RUnlock()
conn, ok := comms[splitted[2]] conn, ok := comms[fmt.Sprintf("%s/%s", splitted[1], splitted[2])]
if ok { if ok {
return conn, nil return conn, nil
} }
@ -72,30 +73,38 @@ func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterato
} }
func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
// processes slashes := strings.Count(q.DatabaseKeyPrefix(), "/")
for _, proc := range process.All() {
if strings.HasPrefix(proc.DatabaseKey(), q.DatabaseKeyPrefix()) { if slashes <= 1 {
it.Next <- proc // processes
for _, proc := range process.All() {
if strings.HasPrefix(proc.DatabaseKey(), q.DatabaseKeyPrefix()) {
it.Next <- proc
}
} }
} }
// comms if slashes <= 2 {
commsLock.RLock() // comms
for _, conn := range comms { commsLock.RLock()
if strings.HasPrefix(conn.DatabaseKey(), q.DatabaseKeyPrefix()) { for _, conn := range comms {
it.Next <- conn if strings.HasPrefix(conn.DatabaseKey(), q.DatabaseKeyPrefix()) {
it.Next <- conn
}
} }
commsLock.RUnlock()
} }
commsLock.RUnlock()
// links if slashes <= 3 {
linksLock.RLock() // links
for _, link := range links { linksLock.RLock()
if strings.HasPrefix(link.DatabaseKey(), q.DatabaseKeyPrefix()) { for _, link := range links {
it.Next <- link if strings.HasPrefix(link.DatabaseKey(), q.DatabaseKeyPrefix()) {
it.Next <- link
}
} }
linksLock.RUnlock()
} }
linksLock.RUnlock()
it.Finish(nil) it.Finish(nil)
} }

View file

@ -42,6 +42,7 @@ type Link struct {
activeInspectors []bool activeInspectors []bool
inspectorData map[uint8]interface{} inspectorData map[uint8]interface{}
saveWhenFinished bool
} }
// Communication returns the Communication the Link is part of // Communication returns the Communication the Link is part of
@ -148,7 +149,7 @@ func (link *Link) UpdateVerdict(newVerdict Verdict) {
if newVerdict > link.Verdict { if newVerdict > link.Verdict {
link.Verdict = newVerdict link.Verdict = newVerdict
go link.Save() link.saveWhenFinished = true
} }
} }
@ -165,6 +166,8 @@ func (link *Link) AddReason(reason string) {
link.Reason += " | " link.Reason += " | "
} }
link.Reason += reason link.Reason += reason
link.saveWhenFinished = true
} }
// packetHandler sequentially handles queued packets // packetHandler sequentially handles queued packets
@ -223,20 +226,42 @@ func (link *Link) ApplyVerdict(pkt packet.Packet) {
} }
} }
// Save saves the link object in the storage and propagates the change. // SaveWhenFinished marks the Link for saving after all current actions are finished.
func (link *Link) Save() error { func (link *Link) SaveWhenFinished() {
link.Lock() link.saveWhenFinished = true
defer link.Unlock() }
// SaveIfNeeded saves the Link if it is marked for saving when finished.
func (link *Link) SaveIfNeeded() {
link.Lock()
save := link.saveWhenFinished
if save {
link.saveWhenFinished = false
}
link.Unlock()
if save {
link.save()
}
}
// Save saves the link object in the storage and propagates the change.
func (link *Link) save() error {
// update link
link.Lock()
if link.comm == nil { if link.comm == nil {
link.Unlock()
return errors.New("cannot save link without comms") return errors.New("cannot save link without comms")
} }
if !link.KeyIsSet() { if !link.KeyIsSet() {
link.SetKey(fmt.Sprintf("network:tree/%d/%s/%s", link.comm.Process().Pid, link.comm.Domain, link.ID)) link.SetKey(fmt.Sprintf("network:tree/%d/%s/%s", link.comm.Process().Pid, link.comm.Domain, link.ID))
link.CreateMeta() link.UpdateMeta()
} }
link.saveWhenFinished = false
link.Unlock()
// save link
linksLock.RLock() linksLock.RLock()
_, ok := links[link.ID] _, ok := links[link.ID]
linksLock.RUnlock() linksLock.RUnlock()
@ -253,17 +278,15 @@ func (link *Link) Save() error {
// Delete deletes a link from the storage and propagates the change. // Delete deletes a link from the storage and propagates the change.
func (link *Link) Delete() { func (link *Link) Delete() {
linksLock.Lock()
defer linksLock.Unlock()
link.Lock() link.Lock()
defer link.Unlock() defer link.Unlock()
linksLock.Lock()
delete(links, link.ID) delete(links, link.ID)
linksLock.Unlock()
link.Meta().Delete() link.Meta().Delete()
go dbController.PushUpdate(link) go dbController.PushUpdate(link)
link.comm.RemoveLink()
go link.comm.Save()
} }
// GetLink fetches a Link from the database from the default namespace for this object // GetLink fetches a Link from the database from the default namespace for this object
@ -294,6 +317,7 @@ func CreateLinkFromPacket(pkt packet.Packet) *Link {
Verdict: VerdictUndecided, Verdict: VerdictUndecided,
Started: time.Now().Unix(), Started: time.Now().Unix(),
RemoteAddress: pkt.FmtRemoteAddress(), RemoteAddress: pkt.FmtRemoteAddress(),
saveWhenFinished: true,
} }
return link return link
} }

View file

@ -52,7 +52,7 @@ func getOrCreateUnknownCommunication(pkt packet.Packet, connClass string) (*Comm
Verdict: VerdictDrop, Verdict: VerdictDrop,
Reason: ReasonUnknownProcess, Reason: ReasonUnknownProcess,
process: process.UnknownProcess, process: process.UnknownProcess,
Inspect: true, Inspect: false,
FirstLinkEstablished: time.Now().Unix(), FirstLinkEstablished: time.Now().Unix(),
} }
if pkt.IsOutbound() { if pkt.IsOutbound() {

View file

@ -5,17 +5,27 @@ import (
"sync" "sync"
"time" "time"
processInfo "github.com/shirou/gopsutil/process"
"github.com/Safing/portbase/database" "github.com/Safing/portbase/database"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/profile" "github.com/Safing/portmaster/profile"
"github.com/tevino/abool" "github.com/tevino/abool"
) )
const (
processDatabaseNamespace = "network:tree"
)
var ( var (
processes = make(map[int]*Process) processes = make(map[int]*Process)
processesLock sync.RWMutex processesLock sync.RWMutex
dbController *database.Controller dbController *database.Controller
dbControllerFlag = abool.NewBool(false) dbControllerFlag = abool.NewBool(false)
deleteProcessesThreshold = 15 * time.Minute
lastEstablishedUpdateThreshold = 30 * time.Second
) )
// GetProcessFromStorage returns a process from the internal storage. // GetProcessFromStorage returns a process from the internal storage.
@ -28,13 +38,13 @@ func GetProcessFromStorage(pid int) (*Process, bool) {
} }
// All returns a copy of all process objects. // All returns a copy of all process objects.
func All() []*Process { func All() map[int]*Process {
processesLock.RLock() processesLock.RLock()
defer processesLock.RUnlock() defer processesLock.RUnlock()
all := make([]*Process, 0, len(processes)) all := make(map[int]*Process)
for _, proc := range processes { for _, proc := range processes {
all = append(all, proc) all[proc.Pid] = proc
} }
return all return all
@ -46,7 +56,7 @@ func (p *Process) Save() {
defer p.Unlock() defer p.Unlock()
if !p.KeyIsSet() { if !p.KeyIsSet() {
p.SetKey(fmt.Sprintf("network:tree/%d", p.Pid)) p.SetKey(fmt.Sprintf("%s/%d", processDatabaseNamespace, p.Pid))
p.CreateMeta() p.CreateMeta()
} }
@ -89,49 +99,90 @@ func (p *Process) Delete() {
} }
// CleanProcessStorage cleans the storage from old processes. // CleanProcessStorage cleans the storage from old processes.
func CleanProcessStorage(thresholdDuration time.Duration) { func CleanProcessStorage(activeComms map[string]struct{}) {
activePIDs, err := getActivePIDs()
if err != nil {
log.Warningf("process: failed to get list of active PIDs: %s", err)
activePIDs = nil
}
processesCopy := All() processesCopy := All()
threshold := time.Now().Add(-thresholdDuration).Unix() threshold := time.Now().Add(-deleteProcessesThreshold).Unix()
delete := false delete := false
// clean primary processes // clean primary processes
for _, p := range processesCopy { for _, p := range processesCopy {
p.Lock() p.Lock()
if !p.Virtual && p.LastCommEstablished < threshold && p.CommCount == 0 { // check if internal
delete = true if p.Pid <= 0 {
p.Unlock()
continue
}
// has comms?
_, hasComms := activeComms[p.DatabaseKey()]
// virtual / active
virtual := p.Virtual
active := false
if activePIDs != nil {
_, active = activePIDs[p.Pid]
} }
p.Unlock() p.Unlock()
if delete { if !virtual && !hasComms && !active && p.LastCommEstablished < threshold {
p.Delete() go p.Delete()
delete = false
} }
} }
// clean virtual/failed processes // clean virtual/failed processes
for _, p := range processesCopy { for _, p := range processesCopy {
p.Lock() p.Lock()
// check if internal
if p.Pid <= 0 {
p.Unlock()
continue
}
switch { switch {
case p.Error != "": case p.Error != "":
if p.Meta().Created < threshold { if p.Meta().Created < threshold {
delete = true delete = true
} }
case p.Virtual: case p.Virtual:
_, parentIsAlive := processes[p.ParentPid] _, parentIsActive := processesCopy[p.ParentPid]
if !parentIsAlive { active := true
if activePIDs != nil {
_, active = activePIDs[p.Pid]
}
if !parentIsActive || !active {
delete = true delete = true
} }
} }
p.Unlock() p.Unlock()
if delete { if delete {
p.Delete() log.Tracef("process.clean: deleted %s", p.DatabaseKey())
go p.Delete()
delete = false delete = false
} }
} }
} }
func getActivePIDs() (map[int]struct{}, error) {
procs, err := processInfo.Processes()
if err != nil {
return nil, err
}
activePIDs := make(map[int]struct{})
for _, p := range procs {
activePIDs[int(p.Pid)] = struct{}{}
}
return activePIDs, nil
}
// SetDBController sets the database controller and allows the package to push database updates on a save. It must be set by the package that registers the "network" database. // SetDBController sets the database controller and allows the package to push database updates on a save. It must be set by the package that registers the "network" database.
func SetDBController(controller *database.Controller) { func SetDBController(controller *database.Controller) {
dbController = controller dbController = controller

View file

@ -50,7 +50,6 @@ type Process struct {
FirstCommEstablished int64 FirstCommEstablished int64
LastCommEstablished int64 LastCommEstablished int64
CommCount uint
Virtual bool // This process is either merged into another process or is not needed. Virtual bool // This process is either merged into another process or is not needed.
Error string // If this is set, the process is invalid. This is used to cache failing or inexistent processes. Error string // If this is set, the process is invalid. This is used to cache failing or inexistent processes.
@ -80,23 +79,53 @@ func (p *Process) AddCommunication() {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
p.CommCount++ // check if we should save
save := false
if p.LastCommEstablished < time.Now().Add(-3*time.Second).Unix() {
save = true
}
// update LastCommEstablished
p.LastCommEstablished = time.Now().Unix() p.LastCommEstablished = time.Now().Unix()
if p.FirstCommEstablished == 0 { if p.FirstCommEstablished == 0 {
p.FirstCommEstablished = p.LastCommEstablished p.FirstCommEstablished = p.LastCommEstablished
} }
}
// RemoveCommunication lowers the connection counter by one. if save {
func (p *Process) RemoveCommunication() { go p.Save()
p.Lock()
defer p.Unlock()
if p.CommCount > 0 {
p.CommCount--
} }
} }
// var db = database.NewInterface(nil)
// CountConnections returns the count of connections of a process
// func (p *Process) CountConnections() int {
// q, err := query.New(fmt.Sprintf("%s/%d/", processDatabaseNamespace, p.Pid)).
// Where(query.Where("Pid", query.Exists, nil)).
// Check()
// if err != nil {
// log.Warningf("process: failed to build query to get connection count of process: %s", err)
// return -1
// }
//
// it, err := db.Query(q)
// if err != nil {
// log.Warningf("process: failed to query db to get connection count of process: %s", err)
// return -1
// }
//
// cnt := 0
// for _ = range it.Next {
// cnt++
// }
// if it.Err() != nil {
// log.Warningf("process: failed to query db to get connection count of process: %s", err)
// return -1
// }
//
// return cnt
// }
// GetOrFindPrimaryProcess returns the highest process in the tree that matches the given PID. // GetOrFindPrimaryProcess returns the highest process in the tree that matches the given PID.
func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) { func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) {
log.Tracer(ctx).Tracef("process: getting primary process for PID %d", pid) log.Tracer(ctx).Tracef("process: getting primary process for PID %d", pid)