Simplify profile reloading

Also, increase prompt decision timeout.
This commit is contained in:
Daniel 2021-01-25 17:04:59 +01:00
parent cad957bae0
commit 9cf214fdff
12 changed files with 48 additions and 125 deletions

View file

@ -97,7 +97,7 @@ func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, er
log.Tracer(r.Context()).Tracef("filter: authenticating API request from %s", r.RemoteAddr) log.Tracer(r.Context()).Tracef("filter: authenticating API request from %s", r.RemoteAddr)
// It is very important that this works, retry extensively (every 250ms for 5s) // It is important that this works, retry 5 times: every 500ms for 2.5s.
var retry bool var retry bool
for tries := 0; tries < 5; tries++ { for tries := 0; tries < 5; tries++ {
retry, err = authenticateAPIRequest( retry, err = authenticateAPIRequest(

View file

@ -6,7 +6,7 @@ import (
"github.com/safing/portmaster/core" "github.com/safing/portmaster/core"
) )
// Configuration Keys // Configuration Keys.
var ( var (
CfgOptionEnableFilterKey = "filter/enable" CfgOptionEnableFilterKey = "filter/enable"

View file

@ -32,6 +32,8 @@ const (
var ( var (
promptNotificationCreation sync.Mutex promptNotificationCreation sync.Mutex
decisionTimeout int64 = 10 // in seconds
) )
type promptData struct { type promptData struct {
@ -45,10 +47,16 @@ type promptProfile struct {
LinkedPath string LinkedPath string
} }
func prompt(ctx context.Context, conn *network.Connection, pkt packet.Packet) { //nolint:gocognit // TODO func prompt(ctx context.Context, conn *network.Connection, pkt packet.Packet) {
// Create notification. // Create notification.
n := createPrompt(ctx, conn, pkt) n := createPrompt(ctx, conn, pkt)
// Get decision timeout and make sure it does not exceed the ask timeout.
timeout := decisionTimeout
if timeout > askTimeout() {
timeout = askTimeout()
}
// wait for response/timeout // wait for response/timeout
select { select {
case promptResponse := <-n.Response(): case promptResponse := <-n.Response():
@ -59,7 +67,7 @@ func prompt(ctx context.Context, conn *network.Connection, pkt packet.Packet) {
conn.Deny("blocked via prompt", profile.CfgOptionEndpointsKey) conn.Deny("blocked via prompt", profile.CfgOptionEndpointsKey)
} }
case <-time.After(1 * time.Second): case <-time.After(time.Duration(timeout) * time.Second):
log.Tracer(ctx).Debugf("filter: continuing prompting async") log.Tracer(ctx).Debugf("filter: continuing prompting async")
conn.Deny("prompting in progress, please respond to prompt", profile.CfgOptionDefaultActionKey) conn.Deny("prompting in progress, please respond to prompt", profile.CfgOptionDefaultActionKey)

View file

@ -4,10 +4,11 @@ import (
"github.com/safing/portbase/config" "github.com/safing/portbase/config"
) )
// Configuration Keys // Configuration Keys.
var ( var (
CfgOptionEnableProcessDetectionKey = "core/enableProcessDetection" CfgOptionEnableProcessDetectionKey = "core/enableProcessDetection"
enableProcessDetection config.BoolOption
enableProcessDetection config.BoolOption
) )
func registerConfiguration() error { func registerConfiguration() error {

View file

@ -45,21 +45,21 @@ func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process
return process, connInbound, nil return process, connInbound, nil
} }
func GetNetworkHost(ctx context.Context, remoteIP net.IP) (process *Process, err error) { func GetNetworkHost(ctx context.Context, remoteIP net.IP) (process *Process, err error) { //nolint:interfacer
now := time.Now().Unix() now := time.Now().Unix()
networkHost := &Process{ networkHost := &Process{
Name: fmt.Sprintf("Network Host %s", remoteIP), Name: fmt.Sprintf("Network Host %s", remoteIP),
UserName: "Unknown", UserName: "Unknown",
UserID: -255, UserID: NetworkHostProcessID,
Pid: -255, Pid: NetworkHostProcessID,
ParentPid: -255, ParentPid: NetworkHostProcessID,
Path: fmt.Sprintf("net:%s", remoteIP), Path: fmt.Sprintf("net:%s", remoteIP),
FirstSeen: now, FirstSeen: now,
LastSeen: now, LastSeen: now,
} }
// Get the (linked) local profile. // Get the (linked) local profile.
networkHostProfile, err := profile.GetNetworkHostProfile(remoteIP.String()) networkHostProfile, err := profile.GetProfile(profile.SourceNetwork, remoteIP.String(), "")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -25,7 +25,7 @@ const (
var getProcessSingleInflight singleflight.Group var getProcessSingleInflight singleflight.Group
// A Process represents a process running on the operating system // A Process represents a process running on the operating system.
type Process struct { type Process struct {
record.Base record.Base
sync.Mutex sync.Mutex

View file

@ -9,9 +9,14 @@ import (
"golang.org/x/sync/singleflight" "golang.org/x/sync/singleflight"
) )
// UnidentifiedProcessID is the PID used for anything that could not be const (
// attributed to a PID for any reason. // UnidentifiedProcessID is the PID used for anything that could not be
const UnidentifiedProcessID = -1 // attributed to a PID for any reason.
UnidentifiedProcessID = -1
// NetworkHostProcessID is the PID used for requests served to the network.
NetworkHostProcessID = -255
)
var ( var (
// unidentifiedProcess is used when a process cannot be found. // unidentifiedProcess is used when a process cannot be found.

View file

@ -58,6 +58,13 @@ func addActiveProfile(profile *Profile) {
activeProfilesLock.Lock() activeProfilesLock.Lock()
defer activeProfilesLock.Unlock() defer activeProfilesLock.Unlock()
// Mark any previous profile as outdated.
previous, ok := activeProfiles[profile.ScopedID()]
if ok {
previous.outdated.Set()
}
// Mark new profile active and add to active profiles.
profile.MarkStillActive() profile.MarkStillActive()
activeProfiles[profile.ScopedID()] = profile activeProfiles[profile.ScopedID()] = profile
} }

View file

@ -8,7 +8,7 @@ import (
"github.com/safing/portmaster/status" "github.com/safing/portmaster/status"
) )
// Configuration Keys // Configuration Keys.
var ( var (
cfgStringOptions = make(map[string]config.StringOption) cfgStringOptions = make(map[string]config.StringOption)
cfgStringArrayOptions = make(map[string]config.StringArrayOption) cfgStringArrayOptions = make(map[string]config.StringArrayOption)

View file

@ -47,7 +47,6 @@ func startProfileUpdateChecker() error {
} }
module.StartServiceWorker("update active profiles", 0, func(ctx context.Context) (err error) { module.StartServiceWorker("update active profiles", 0, func(ctx context.Context) (err error) {
feedSelect:
for { for {
select { select {
case r := <-profilesSub.Feed: case r := <-profilesSub.Feed:
@ -56,14 +55,6 @@ func startProfileUpdateChecker() error {
return errors.New("subscription canceled") return errors.New("subscription canceled")
} }
// check if internal save
if !r.IsWrapped() {
profile, ok := r.(*Profile)
if ok && profile.internalSave {
continue feedSelect
}
}
// mark as outdated // mark as outdated
markActiveProfileAsOutdated(strings.TrimPrefix(r.Key(), profilesDBPath)) markActiveProfileAsOutdated(strings.TrimPrefix(r.Key(), profilesDBPath))
case <-ctx.Done(): case <-ctx.Done():

View file

@ -28,7 +28,7 @@ const (
var getProfileSingleInflight singleflight.Group var getProfileSingleInflight singleflight.Group
// GetProfile fetches a profile. This function ensure that the profile loaded // GetProfile fetches a profile. This function ensures that the loaded profile
// is shared among all callers. You must always supply both the scopedID and // is shared among all callers. You must always supply both the scopedID and
// linkedPath parameters whenever available. // linkedPath parameters whenever available.
func GetProfile(source profileSource, id, linkedPath string) ( //nolint:gocognit func GetProfile(source profileSource, id, linkedPath string) ( //nolint:gocognit
@ -105,12 +105,8 @@ func GetProfile(source profileSource, id, linkedPath string) ( //nolint:gocognit
// Process profiles coming directly from the database. // Process profiles coming directly from the database.
// As we don't use any caching, these will be new objects. // As we don't use any caching, these will be new objects.
// Mark the profile as being saved internally in order to not trigger an // Add a layeredProfile to local and network profiles.
// update after saving it to the database. if profile.Source == SourceLocal || profile.Source == SourceNetwork {
profile.internalSave = true
// Add a layeredProfile to local profiles.
if profile.Source == SourceLocal {
// If we are refetching, assign the layered profile from the previous version. // If we are refetching, assign the layered profile from the previous version.
if previousVersion != nil { if previousVersion != nil {
profile.layeredProfile = previousVersion.layeredProfile profile.layeredProfile = previousVersion.layeredProfile
@ -138,76 +134,6 @@ func GetProfile(source profileSource, id, linkedPath string) ( //nolint:gocognit
return p.(*Profile), nil return p.(*Profile), nil
} }
func GetNetworkHostProfile(remoteIP string) ( //nolint:gocognit
profile *Profile,
err error,
) {
scopedID := makeScopedID(SourceNetwork, remoteIP)
p, err, _ := getProfileSingleInflight.Do(scopedID, func() (interface{}, error) {
var previousVersion *Profile
// Get profile via the scoped ID.
// Check if there already is an active and not outdated profile.
profile = getActiveProfile(scopedID)
if profile != nil {
profile.MarkStillActive()
if profile.outdated.IsSet() {
previousVersion = profile
} else {
return profile, nil
}
}
// Get from database.
profile, err = getProfile(scopedID)
switch {
case err == nil:
// Continue.
case errors.Is(err, database.ErrNotFound):
// Create new profile.
// If there was no profile in the database, create a new one, and return it.
profile = New(SourceNetwork, remoteIP, "")
default:
return nil, err
}
// Process profiles coming directly from the database.
// As we don't use any caching, these will be new objects.
// Mark the profile as being saved internally in order to not trigger an
// update after saving it to the database.
profile.internalSave = true
// Add a layeredProfile to network profiles.
// If we are refetching, assign the layered profile from the previous version.
if previousVersion != nil {
profile.layeredProfile = previousVersion.layeredProfile
}
// Network profiles must have a layered profile, create a new one if it
// does not yet exist.
if profile.layeredProfile == nil {
profile.layeredProfile = NewLayeredProfile(profile)
}
// Add the profile to the currently active profiles.
addActiveProfile(profile)
return profile, nil
})
if err != nil {
return nil, err
}
if p == nil {
return nil, errors.New("profile getter returned nil")
}
return p.(*Profile), nil
}
// getProfile fetches the profile for the given scoped ID. // getProfile fetches the profile for the given scoped ID.
func getProfile(scopedID string) (profile *Profile, err error) { func getProfile(scopedID string) (profile *Profile, err error) {
// Get profile from the database. // Get profile from the database.
@ -266,13 +192,13 @@ func prepProfile(r record.Record) (*Profile, error) {
// prepare config // prepare config
err = profile.prepConfig() err = profile.prepConfig()
if err != nil { if err != nil {
log.Warningf("profiles: profile %s has (partly) invalid configuration: %s", profile.ID, err) log.Errorf("profiles: profile %s has (partly) invalid configuration: %s", profile.ID, err)
} }
// parse config // parse config
err = profile.parseConfig() err = profile.parseConfig()
if err != nil { if err != nil {
log.Warningf("profiles: profile %s has (partly) invalid configuration: %s", profile.ID, err) log.Errorf("profiles: profile %s has (partly) invalid configuration: %s", profile.ID, err)
} }
// return parsed profile // return parsed profile

View file

@ -127,8 +127,6 @@ type Profile struct { //nolint:maligned // not worth the effort
// Lifecycle Management // Lifecycle Management
outdated *abool.AtomicBool outdated *abool.AtomicBool
lastActive *int64 lastActive *int64
internalSave bool
} }
func (profile *Profile) prepConfig() (err error) { func (profile *Profile) prepConfig() (err error) {
@ -197,12 +195,11 @@ func (profile *Profile) parseConfig() error {
// New returns a new Profile. // New returns a new Profile.
func New(source profileSource, id string, linkedPath string) *Profile { func New(source profileSource, id string, linkedPath string) *Profile {
profile := &Profile{ profile := &Profile{
ID: id, ID: id,
Source: source, Source: source,
LinkedPath: linkedPath, LinkedPath: linkedPath,
Created: time.Now().Unix(), Created: time.Now().Unix(),
Config: make(map[string]interface{}), Config: make(map[string]interface{}),
internalSave: true,
} }
// Generate random ID if none is given. // Generate random ID if none is given.
@ -301,18 +298,6 @@ func (profile *Profile) addEndpointyEntry(cfgKey, newEntry string) {
} }
}() }()
// When finished increase the revision counter of the layered profile.
defer func() {
if !changed || profile.layeredProfile == nil {
return
}
profile.layeredProfile.Lock()
defer profile.layeredProfile.Unlock()
profile.layeredProfile.RevisionCounter++
}()
// Lock the profile for editing. // Lock the profile for editing.
profile.Lock() profile.Lock()
defer profile.Unlock() defer profile.Unlock()
@ -350,7 +335,7 @@ func (profile *Profile) addEndpointyEntry(cfgKey, newEntry string) {
profile.dataParsed = false profile.dataParsed = false
err := profile.parseConfig() err := profile.parseConfig()
if err != nil { if err != nil {
log.Warningf("profile: failed to parse %s config after adding endpoint: %s", profile, err) log.Errorf("profile: failed to parse %s config after adding endpoint: %s", profile, err)
} }
} }