Add support for network service

This commit is contained in:
Daniel 2021-01-19 15:43:22 +01:00
parent 3f8c99517f
commit 12f3c0ea8d
14 changed files with 320 additions and 65 deletions

View file

@ -18,7 +18,7 @@ func registerActions() error {
if err := api.RegisterEndpoint(api.Endpoint{ if err := api.RegisterEndpoint(api.Endpoint{
Path: "core/restart", Path: "core/restart",
Read: api.PermitSelf, Read: api.PermitAdmin,
ActionFunc: restart, ActionFunc: restart,
}); err != nil { }); err != nil {
return err return err

View file

@ -12,16 +12,23 @@ var (
CfgDevModeKey = "core/devMode" CfgDevModeKey = "core/devMode"
defaultDevMode bool defaultDevMode bool
CfgNetworkServiceKey = "core/networkService"
defaultNetworkServiceMode bool
CfgUseSystemNotificationsKey = "core/useSystemNotifications" CfgUseSystemNotificationsKey = "core/useSystemNotifications"
) )
func init() { func init() {
flag.BoolVar(&defaultDevMode, "devmode", false, "force development mode") flag.BoolVar(&defaultDevMode, "devmode", false, "force development mode")
flag.BoolVar(&defaultNetworkServiceMode, "network-service", false, "force network service mode")
} }
func logFlagOverrides() { func logFlagOverrides() {
if defaultDevMode { if defaultDevMode {
log.Warning("core: core/devMode default config is being forced by -devmode flag") log.Warningf("core: %s config is being forced by the -devmode flag", CfgDevModeKey)
}
if defaultNetworkServiceMode {
log.Warningf("core: %s config is being forced by the -network-service flag", CfgNetworkServiceKey)
} }
} }
@ -43,6 +50,23 @@ func registerConfig() error {
return err return err
} }
err = config.Register(&config.Option{
Name: "Network Service",
Key: CfgNetworkServiceKey,
Description: "Use the Portmaster as a network service, where applicable. You will have to take care of lots of network setup yourself in order to run this properly and securely.",
OptType: config.OptTypeBool,
ExpertiseLevel: config.ExpertiseLevelExpert,
ReleaseLevel: config.ReleaseLevelExperimental,
DefaultValue: defaultNetworkServiceMode,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: 513,
config.CategoryAnnotation: "Network Service",
},
})
if err != nil {
return err
}
err = config.Register(&config.Option{ err = config.Register(&config.Option{
Name: "Desktop Notifications", Name: "Desktop Notifications",
Key: CfgUseSystemNotificationsKey, Key: CfgUseSystemNotificationsKey,

View file

@ -6,10 +6,15 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"path/filepath"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/updates"
"github.com/safing/portbase/api" "github.com/safing/portbase/api"
"github.com/safing/portbase/dataroot" "github.com/safing/portbase/dataroot"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
@ -33,7 +38,10 @@ Checked process paths:
%s %s
The authorized root path is %s. The authorized root path is %s.
You can enable the Development Mode to disable API authentication for development purposes.` You can enable the Development Mode to disable API authentication for development purposes.
For production use please create an API key in the settings.`
deniedMsgMisconfigured = `%wThe authentication system is misconfigured.`
) )
var ( var (
@ -80,11 +88,18 @@ func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, er
return nil, fmt.Errorf("failed to get remote IP/Port: %s", err) return nil, fmt.Errorf("failed to get remote IP/Port: %s", err)
} }
// Check if the request is even local.
myIP, err := netenv.IsMyIP(remoteIP)
if err == nil && !myIP {
// Return to caller that the request was not handled.
return nil, nil
}
log.Tracer(r.Context()).Tracef("filter: authenticating API request from %s", r.RemoteAddr) log.Tracer(r.Context()).Tracef("filter: authenticating API request from %s", r.RemoteAddr)
// It is very important that this works, retry extensively (every 250ms for 5s) // It is very important that this works, retry extensively (every 250ms for 5s)
var retry bool var retry bool
for tries := 0; tries < 20; tries++ { for tries := 0; tries < 5; tries++ {
retry, err = authenticateAPIRequest( retry, err = authenticateAPIRequest(
r.Context(), r.Context(),
&packet.Info{ &packet.Info{
@ -102,7 +117,7 @@ func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, er
} }
// wait a little // wait a little
time.Sleep(250 * time.Millisecond) time.Sleep(500 * time.Millisecond)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -116,39 +131,51 @@ func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, er
func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bool, err error) { func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bool, err error) {
var procsChecked []string var procsChecked []string
var originalPid int
// get process // Get authenticated path.
authenticatedPath := updates.RootPath()
if authenticatedPath == "" {
return false, fmt.Errorf(deniedMsgMisconfigured, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
}
authenticatedPath += string(filepath.Separator)
// Get process of request.
proc, _, err := process.GetProcessByConnection(ctx, pktInfo) proc, _, err := process.GetProcessByConnection(ctx, pktInfo)
if err != nil { if err != nil {
return true, fmt.Errorf("failed to get process: %s", err) log.Tracer(ctx).Debugf("filter: failed to get process of api request: %s", err)
} originalPid = process.UnidentifiedProcessID
originalPid := proc.Pid } else {
originalPid = proc.Pid
var previousPid int var previousPid int
// go up up to two levels, if we don't match // Go up up to two levels, if we don't match the path.
for i := 0; i < 5; i++ { checkLevels := 2
// check for eligible PID for i := 0; i < checkLevels+1; i++ {
// Check for eligible path.
switch proc.Pid { switch proc.Pid {
case process.UnidentifiedProcessID, process.SystemProcessID: case process.UnidentifiedProcessID, process.SystemProcessID:
break break
default: // normal process default: // normal process
// check if the requesting process is in database root / updates dir // Check if the requesting process is in database root / updates dir.
if strings.HasPrefix(proc.Path, dataRoot.Path) { if strings.HasPrefix(proc.Path, authenticatedPath) {
return false, nil return false, nil
} }
} }
// add checked process to list // Add checked path to list.
procsChecked = append(procsChecked, proc.Path) procsChecked = append(procsChecked, proc.Path)
if i < 4 { // Get the parent process.
if i < checkLevels {
// save previous PID // save previous PID
previousPid = proc.Pid previousPid = proc.Pid
// get parent process // get parent process
proc, err = process.GetOrFindProcess(ctx, proc.ParentPid) proc, err = process.GetOrFindProcess(ctx, proc.ParentPid)
if err != nil { if err != nil {
return true, fmt.Errorf("failed to get process: %s", err) log.Tracer(ctx).Debugf("filter: failed to get parent process of api request: %s", err)
break
} }
// abort if we are looping // abort if we are looping
@ -158,6 +185,7 @@ func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bo
} }
} }
} }
}
switch originalPid { switch originalPid {
case process.UnidentifiedProcessID: case process.UnidentifiedProcessID:
@ -174,7 +202,7 @@ func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bo
deniedMsgUnauthorized, deniedMsgUnauthorized,
api.ErrAPIAccessDeniedMessage, api.ErrAPIAccessDeniedMessage,
strings.Join(procsChecked, "\n"), strings.Join(procsChecked, "\n"),
dataRoot.Path, authenticatedPath,
) )
} }
} }

View file

@ -3,7 +3,6 @@ package firewall
import ( import (
"context" "context"
"net" "net"
"os"
"strings" "strings"
"time" "time"
@ -88,7 +87,7 @@ func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache) *res
p := conn.Process().Profile() p := conn.Process().Profile()
// do not modify own queries // do not modify own queries
if conn.Process().Pid == os.Getpid() { if conn.Process().Pid == ownPID {
return rrCache return rrCache
} }

View file

@ -8,6 +8,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/safing/portmaster/netenv"
"github.com/tevino/abool" "github.com/tevino/abool"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
@ -38,6 +40,8 @@ var (
blockedIPv4 = net.IPv4(0, 0, 0, 17) blockedIPv4 = net.IPv4(0, 0, 0, 17)
blockedIPv6 = net.ParseIP("::17") blockedIPv6 = net.ParseIP("::17")
ownPID = os.Getpid()
) )
func init() { func init() {
@ -177,6 +181,16 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
return false return false
} }
// Only fast-track local requests.
isMe, err := netenv.IsMyIP(meta.Src)
switch {
case err != nil:
log.Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err)
return false
case !isMe:
return false
}
// Log and permit. // Log and permit.
log.Debugf("filter: fast-track accepting api connection: %s", pkt) log.Debugf("filter: fast-track accepting api connection: %s", pkt)
_ = pkt.PermanentAccept() _ = pkt.PermanentAccept()
@ -196,6 +210,16 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
return false return false
} }
// Only fast-track local requests.
isMe, err := netenv.IsMyIP(meta.Src)
switch {
case err != nil:
log.Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err)
return false
case !isMe:
return false
}
// Log and permit. // Log and permit.
log.Debugf("filter: fast-track accepting local dns: %s", pkt) log.Debugf("filter: fast-track accepting local dns: %s", pkt)
_ = pkt.PermanentAccept() _ = pkt.PermanentAccept()
@ -224,7 +248,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
// Redirect rogue dns requests to the Portmaster. // Redirect rogue dns requests to the Portmaster.
if pkt.IsOutbound() && if pkt.IsOutbound() &&
pkt.Info().DstPort == 53 && pkt.Info().DstPort == 53 &&
conn.Process().Pid != os.Getpid() && conn.Process().Pid != ownPID &&
nameserverIPMatcherReady.IsSet() && nameserverIPMatcherReady.IsSet() &&
!nameserverIPMatcher(pkt.Info().Dst) { !nameserverIPMatcher(pkt.Info().Dst) {
conn.Verdict = network.VerdictRerouteToNameserver conn.Verdict = network.VerdictRerouteToNameserver

View file

@ -3,7 +3,6 @@ package firewall
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -118,8 +117,9 @@ func runDeciders(ctx context.Context, conn *network.Connection, pkt packet.Packe
// checkPortmasterConnection allows all connection that originate from // checkPortmasterConnection allows all connection that originate from
// portmaster itself. // portmaster itself.
func checkPortmasterConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool { func checkPortmasterConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool {
// grant self // Grant own outgoing connections.
if conn.Process().Pid == os.Getpid() { if conn.Process().Pid == ownPID &&
(pkt == nil || pkt.IsOutbound()) {
log.Tracer(ctx).Infof("filter: granting own connection %s", conn) log.Tracer(ctx).Infof("filter: granting own connection %s", conn)
conn.Accept("connection by Portmaster", noReasonOptionKey) conn.Accept("connection by Portmaster", noReasonOptionKey)
conn.Internal = true conn.Internal = true

View file

@ -6,6 +6,7 @@ import (
"github.com/safing/portbase/config" "github.com/safing/portbase/config"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/core"
) )
// Config Keys // Config Keys
@ -18,6 +19,8 @@ var (
nameserverAddressConfig config.StringOption nameserverAddressConfig config.StringOption
defaultNameserverAddress = "localhost:53" defaultNameserverAddress = "localhost:53"
networkServiceMode config.BoolOption
) )
func init() { func init() {
@ -65,5 +68,7 @@ func registerConfig() error {
} }
nameserverAddressConfig = config.GetAsString(CfgDefaultNameserverAddressKey, getDefaultNameserverAddress()) nameserverAddressConfig = config.GetAsString(CfgDefaultNameserverAddressKey, getDefaultNameserverAddress())
networkServiceMode = config.Concurrent.GetAsBool(core.CfgNetworkServiceKey, false)
return nil return nil
} }

View file

@ -12,7 +12,6 @@ import (
"github.com/safing/portmaster/netenv" "github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/resolver" "github.com/safing/portmaster/resolver"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -94,25 +93,36 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
return reply(nsutil.Localhost()) return reply(nsutil.Localhost())
} }
// Authenticate request - only requests from the local host, but with any of its IPs, are allowed.
local, err := netenv.IsMyIP(remoteAddr.IP)
if err != nil {
tracer.Warningf("nameserver: failed to check if request for %s%s is local: %s", q.FQDN, q.QType, err)
return nil // Do no reply, drop request immediately.
}
if !local {
tracer.Warningf("nameserver: external request for %s%s, ignoring", q.FQDN, q.QType)
return nil // Do no reply, drop request immediately.
}
// Validate domain name. // Validate domain name.
if !netutils.IsValidFqdn(q.FQDN) { if !netutils.IsValidFqdn(q.FQDN) {
tracer.Debugf("nameserver: domain name %s is invalid, refusing", q.FQDN) tracer.Debugf("nameserver: domain name %s is invalid, refusing", q.FQDN)
return reply(nsutil.Refused("invalid domain")) return reply(nsutil.Refused("invalid domain"))
} }
// Authenticate request - only requests from the local host, but with any of its IPs, are allowed.
local, err := netenv.IsMyIP(remoteAddr.IP)
if err != nil {
tracer.Warningf("nameserver: failed to check if request for %s%s is local: %s", q.FQDN, q.QType, err)
return nil // Do no reply, drop request immediately.
}
// Get connection for this request. This identifies the process behind the request. // Get connection for this request. This identifies the process behind the request.
conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, packet.IPv4, remoteAddr.IP, uint16(remoteAddr.Port)) var conn *network.Connection
switch {
case local:
conn = network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP, uint16(remoteAddr.Port))
case networkServiceMode():
conn, err = network.NewConnectionFromExternalDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP)
if err != nil {
tracer.Warningf("nameserver: failed to get host/profile for request for %s%s: %s", q.FQDN, q.QType, err)
return nil // Do no reply, drop request immediately.
}
default:
tracer.Warningf("nameserver: external request for %s%s, ignoring", q.FQDN, q.QType)
return nil // Do no reply, drop request immediately.
}
conn.Lock() conn.Lock()
defer conn.Unlock() defer conn.Unlock()

View file

@ -188,7 +188,13 @@ func getProcessContext(ctx context.Context, proc *process.Process) ProcessContex
} }
// NewConnectionFromDNSRequest returns a new connection based on the given dns request. // NewConnectionFromDNSRequest returns a new connection based on the given dns request.
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, ipVersion packet.IPVersion, localIP net.IP, localPort uint16) *Connection { func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, localIP net.IP, localPort uint16) *Connection {
// Determine IP version.
ipVersion := packet.IPv6
if localIP.To4() != nil {
ipVersion = packet.IPv4
}
// get Process // get Process
proc, _, err := process.GetProcessByConnection( proc, _, err := process.GetProcessByConnection(
ctx, ctx,
@ -222,6 +228,26 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
return dnsConn return dnsConn
} }
func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cnames []string, remoteIP net.IP) (*Connection, error) {
remoteHost, err := process.GetNetworkHost(ctx, remoteIP)
if err != nil {
return nil, err
}
timestamp := time.Now().Unix()
return &Connection{
Scope: fqdn,
Entity: &intel.Entity{
Domain: fqdn,
CNAME: cnames,
},
process: remoteHost,
ProcessContext: getProcessContext(ctx, remoteHost),
Started: timestamp,
Ended: timestamp,
}, nil
}
// NewConnectionFromFirstPacket returns a new connection based on the given packet. // NewConnectionFromFirstPacket returns a new connection based on the given packet.
func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
// get Process // get Process

View file

@ -2,11 +2,14 @@ package process
import ( import (
"context" "context"
"fmt"
"github.com/safing/portmaster/network/state" "net"
"time"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/network/state"
"github.com/safing/portmaster/profile"
) )
// GetProcessByConnection returns the process that owns the described connection. // GetProcessByConnection returns the process that owns the described connection.
@ -41,3 +44,39 @@ func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process
return process, connInbound, nil return process, connInbound, nil
} }
func GetNetworkHost(ctx context.Context, remoteIP net.IP) (process *Process, err error) {
now := time.Now().Unix()
networkHost := &Process{
Name: fmt.Sprintf("Network Host %s", remoteIP),
UserName: "Unknown",
UserID: -255,
Pid: -255,
ParentPid: -255,
Path: fmt.Sprintf("net:%s", remoteIP),
FirstSeen: now,
LastSeen: now,
}
// Get the (linked) local profile.
networkHostProfile, err := profile.GetNetworkHostProfile(remoteIP.String())
if err != nil {
return nil, err
}
// Assign profile to process.
networkHost.LocalProfileKey = networkHostProfile.Key()
networkHost.profile = networkHostProfile.LayeredProfile()
if networkHostProfile.Name == "" {
// Assign name and save.
networkHostProfile.Name = networkHost.Name
err := networkHostProfile.Save()
if err != nil {
log.Warningf("process: failed to save profile %s: %s", networkHostProfile.ScopedID(), err)
}
}
return networkHost, nil
}

View file

@ -2,11 +2,16 @@ package process
import ( import (
"context" "context"
"os"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/profile" "github.com/safing/portmaster/profile"
) )
var (
ownPID = os.Getpid()
)
// GetProfile finds and assigns a profile set to the process. // GetProfile finds and assigns a profile set to the process.
func (p *Process) GetProfile(ctx context.Context) (changed bool, err error) { func (p *Process) GetProfile(ctx context.Context) (changed bool, err error) {
// Update profile metadata outside of *Process lock. // Update profile metadata outside of *Process lock.
@ -31,6 +36,8 @@ func (p *Process) GetProfile(ctx context.Context) (changed bool, err error) {
profileID = profile.UnidentifiedProfileID profileID = profile.UnidentifiedProfileID
case SystemProcessID: case SystemProcessID:
profileID = profile.SystemProfileID profileID = profile.SystemProfileID
case ownPID:
profileID = profile.PortmasterProfileID
} }
// Get the (linked) local profile. // Get the (linked) local profile.
@ -56,7 +63,7 @@ func (p *Process) UpdateProfileMetadata() {
} }
// Update metadata of profile. // Update metadata of profile.
metadataUpdated := localProfile.UpdateMetadata(p.Name) metadataUpdated := localProfile.UpdateMetadata(p.Name, p.Path)
// Mark profile as used. // Mark profile as used.
profileChanged := localProfile.MarkUsed() profileChanged := localProfile.MarkUsed()

View file

@ -21,6 +21,9 @@ const (
// SystemProfileID is the profile ID used for the system/kernel. // SystemProfileID is the profile ID used for the system/kernel.
SystemProfileID = "_system" SystemProfileID = "_system"
// SystemProfileID is the profile ID used for the Portmaster itself.
PortmasterProfileID = "_portmaster"
) )
var getProfileSingleInflight singleflight.Group var getProfileSingleInflight singleflight.Group
@ -71,6 +74,9 @@ func GetProfile(source profileSource, id, linkedPath string) ( //nolint:gocognit
case SystemProfileID: case SystemProfileID:
profile = New(SourceLocal, SystemProfileID, linkedPath) profile = New(SourceLocal, SystemProfileID, linkedPath)
err = nil err = nil
case PortmasterProfileID:
profile = New(SourceLocal, PortmasterProfileID, linkedPath)
err = nil
} }
} }
@ -132,6 +138,76 @@ func GetProfile(source profileSource, id, linkedPath string) ( //nolint:gocognit
return p.(*Profile), nil return p.(*Profile), nil
} }
func GetNetworkHostProfile(remoteIP string) ( //nolint:gocognit
profile *Profile,
err error,
) {
scopedID := makeScopedID(SourceNetwork, remoteIP)
p, err, _ := getProfileSingleInflight.Do(scopedID, func() (interface{}, error) {
var previousVersion *Profile
// Get profile via the scoped ID.
// Check if there already is an active and not outdated profile.
profile = getActiveProfile(scopedID)
if profile != nil {
profile.MarkStillActive()
if profile.outdated.IsSet() {
previousVersion = profile
} else {
return profile, nil
}
}
// Get from database.
profile, err = getProfile(scopedID)
switch {
case err == nil:
// Continue.
case errors.Is(err, database.ErrNotFound):
// Create new profile.
// If there was no profile in the database, create a new one, and return it.
profile = New(SourceNetwork, remoteIP, "")
default:
return nil, err
}
// Process profiles coming directly from the database.
// As we don't use any caching, these will be new objects.
// Mark the profile as being saved internally in order to not trigger an
// update after saving it to the database.
profile.internalSave = true
// Add a layeredProfile to network profiles.
// If we are refetching, assign the layered profile from the previous version.
if previousVersion != nil {
profile.layeredProfile = previousVersion.layeredProfile
}
// Network profiles must have a layered profile, create a new one if it
// does not yet exist.
if profile.layeredProfile == nil {
profile.layeredProfile = NewLayeredProfile(profile)
}
// Add the profile to the currently active profiles.
addActiveProfile(profile)
return profile, nil
})
if err != nil {
return nil, err
}
if p == nil {
return nil, errors.New("profile getter returned nil")
}
return p.(*Profile), nil
}
// getProfile fetches the profile for the given scoped ID. // getProfile fetches the profile for the given scoped ID.
func getProfile(scopedID string) (profile *Profile, err error) { func getProfile(scopedID string) (profile *Profile, err error) {
// Get profile from the database. // Get profile from the database.

View file

@ -32,6 +32,7 @@ type profileSource string
const ( const (
SourceLocal profileSource = "local" // local, editable SourceLocal profileSource = "local" // local, editable
SourceSpecial profileSource = "special" // specials (read-only) SourceSpecial profileSource = "special" // specials (read-only)
SourceNetwork profileSource = "network"
SourceCommunity profileSource = "community" SourceCommunity profileSource = "community"
SourceEnterprise profileSource = "enterprise" SourceEnterprise profileSource = "enterprise"
) )
@ -386,7 +387,7 @@ func EnsureProfile(r record.Record) (*Profile, error) {
// the profile was changed. If there is data that needs to be fetched from the // the profile was changed. If there is data that needs to be fetched from the
// operating system, it will start an async worker to fetch that data and save // operating system, it will start an async worker to fetch that data and save
// the profile afterwards. // the profile afterwards.
func (profile *Profile) UpdateMetadata(processName string) (changed bool) { func (profile *Profile) UpdateMetadata(processName, binaryPath string) (changed bool) {
// Check if this is a local profile, else warn and return. // Check if this is a local profile, else warn and return.
if profile.Source != SourceLocal { if profile.Source != SourceLocal {
log.Warningf("tried to update metadata for non-local profile %s", profile.ScopedID()) log.Warningf("tried to update metadata for non-local profile %s", profile.ScopedID())
@ -397,7 +398,7 @@ func (profile *Profile) UpdateMetadata(processName string) (changed bool) {
defer profile.Unlock() defer profile.Unlock()
// Check if this is a special profile. // Check if this is a special profile.
if profile.LinkedPath == "" { if binaryPath == "" {
// This is a special profile, just assign the processName, if needed, and // This is a special profile, just assign the processName, if needed, and
// return. // return.
if profile.Name != processName { if profile.Name != processName {
@ -407,6 +408,13 @@ func (profile *Profile) UpdateMetadata(processName string) (changed bool) {
return false return false
} }
// Update LinkedPath if if differs from the process path.
// This will (at the moment) only be the case for the Portmaster profile.
if profile.LinkedPath != binaryPath {
profile.LinkedPath = binaryPath
changed = true
}
var needsUpdateFromSystem bool var needsUpdateFromSystem bool
// Check profile name. // Check profile name.

View file

@ -336,3 +336,12 @@ func stagingActive() bool {
_, err := os.Stat(filepath.Join(registry.StorageDir().Path, "staging.json")) _, err := os.Stat(filepath.Join(registry.StorageDir().Path, "staging.json"))
return err == nil return err == nil
} }
// RootPath returns the root path used for storing updates.
func RootPath() string {
if !module.Online() {
return ""
}
return registry.StorageDir().Path
}