mirror of
https://github.com/safing/portmaster
synced 2025-09-01 18:19:12 +00:00
Add support for network service
This commit is contained in:
parent
3f8c99517f
commit
12f3c0ea8d
14 changed files with 320 additions and 65 deletions
|
@ -18,7 +18,7 @@ func registerActions() error {
|
|||
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: "core/restart",
|
||||
Read: api.PermitSelf,
|
||||
Read: api.PermitAdmin,
|
||||
ActionFunc: restart,
|
||||
}); err != nil {
|
||||
return err
|
||||
|
|
|
@ -12,16 +12,23 @@ var (
|
|||
CfgDevModeKey = "core/devMode"
|
||||
defaultDevMode bool
|
||||
|
||||
CfgNetworkServiceKey = "core/networkService"
|
||||
defaultNetworkServiceMode bool
|
||||
|
||||
CfgUseSystemNotificationsKey = "core/useSystemNotifications"
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&defaultDevMode, "devmode", false, "force development mode")
|
||||
flag.BoolVar(&defaultNetworkServiceMode, "network-service", false, "force network service mode")
|
||||
}
|
||||
|
||||
func logFlagOverrides() {
|
||||
if defaultDevMode {
|
||||
log.Warning("core: core/devMode default config is being forced by -devmode flag")
|
||||
log.Warningf("core: %s config is being forced by the -devmode flag", CfgDevModeKey)
|
||||
}
|
||||
if defaultNetworkServiceMode {
|
||||
log.Warningf("core: %s config is being forced by the -network-service flag", CfgNetworkServiceKey)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -43,6 +50,23 @@ func registerConfig() error {
|
|||
return err
|
||||
}
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Network Service",
|
||||
Key: CfgNetworkServiceKey,
|
||||
Description: "Use the Portmaster as a network service, where applicable. You will have to take care of lots of network setup yourself in order to run this properly and securely.",
|
||||
OptType: config.OptTypeBool,
|
||||
ExpertiseLevel: config.ExpertiseLevelExpert,
|
||||
ReleaseLevel: config.ReleaseLevelExperimental,
|
||||
DefaultValue: defaultNetworkServiceMode,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: 513,
|
||||
config.CategoryAnnotation: "Network Service",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Desktop Notifications",
|
||||
Key: CfgUseSystemNotificationsKey,
|
||||
|
|
104
firewall/api.go
104
firewall/api.go
|
@ -6,10 +6,15 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/netenv"
|
||||
|
||||
"github.com/safing/portmaster/updates"
|
||||
|
||||
"github.com/safing/portbase/api"
|
||||
"github.com/safing/portbase/dataroot"
|
||||
"github.com/safing/portbase/log"
|
||||
|
@ -33,7 +38,10 @@ Checked process paths:
|
|||
%s
|
||||
|
||||
The authorized root path is %s.
|
||||
You can enable the Development Mode to disable API authentication for development purposes.`
|
||||
You can enable the Development Mode to disable API authentication for development purposes.
|
||||
For production use please create an API key in the settings.`
|
||||
|
||||
deniedMsgMisconfigured = `%wThe authentication system is misconfigured.`
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -80,11 +88,18 @@ func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, er
|
|||
return nil, fmt.Errorf("failed to get remote IP/Port: %s", err)
|
||||
}
|
||||
|
||||
// Check if the request is even local.
|
||||
myIP, err := netenv.IsMyIP(remoteIP)
|
||||
if err == nil && !myIP {
|
||||
// Return to caller that the request was not handled.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Tracer(r.Context()).Tracef("filter: authenticating API request from %s", r.RemoteAddr)
|
||||
|
||||
// It is very important that this works, retry extensively (every 250ms for 5s)
|
||||
var retry bool
|
||||
for tries := 0; tries < 20; tries++ {
|
||||
for tries := 0; tries < 5; tries++ {
|
||||
retry, err = authenticateAPIRequest(
|
||||
r.Context(),
|
||||
&packet.Info{
|
||||
|
@ -102,7 +117,7 @@ func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, er
|
|||
}
|
||||
|
||||
// wait a little
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -116,45 +131,58 @@ func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, er
|
|||
|
||||
func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bool, err error) {
|
||||
var procsChecked []string
|
||||
var originalPid int
|
||||
|
||||
// get process
|
||||
// Get authenticated path.
|
||||
authenticatedPath := updates.RootPath()
|
||||
if authenticatedPath == "" {
|
||||
return false, fmt.Errorf(deniedMsgMisconfigured, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
|
||||
}
|
||||
authenticatedPath += string(filepath.Separator)
|
||||
|
||||
// Get process of request.
|
||||
proc, _, err := process.GetProcessByConnection(ctx, pktInfo)
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("failed to get process: %s", err)
|
||||
}
|
||||
originalPid := proc.Pid
|
||||
var previousPid int
|
||||
log.Tracer(ctx).Debugf("filter: failed to get process of api request: %s", err)
|
||||
originalPid = process.UnidentifiedProcessID
|
||||
} else {
|
||||
originalPid = proc.Pid
|
||||
var previousPid int
|
||||
|
||||
// go up up to two levels, if we don't match
|
||||
for i := 0; i < 5; i++ {
|
||||
// check for eligible PID
|
||||
switch proc.Pid {
|
||||
case process.UnidentifiedProcessID, process.SystemProcessID:
|
||||
break
|
||||
default: // normal process
|
||||
// check if the requesting process is in database root / updates dir
|
||||
if strings.HasPrefix(proc.Path, dataRoot.Path) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// add checked process to list
|
||||
procsChecked = append(procsChecked, proc.Path)
|
||||
|
||||
if i < 4 {
|
||||
// save previous PID
|
||||
previousPid = proc.Pid
|
||||
|
||||
// get parent process
|
||||
proc, err = process.GetOrFindProcess(ctx, proc.ParentPid)
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("failed to get process: %s", err)
|
||||
}
|
||||
|
||||
// abort if we are looping
|
||||
if proc.Pid == previousPid {
|
||||
// this also catches -1 pid loops
|
||||
// Go up up to two levels, if we don't match the path.
|
||||
checkLevels := 2
|
||||
for i := 0; i < checkLevels+1; i++ {
|
||||
// Check for eligible path.
|
||||
switch proc.Pid {
|
||||
case process.UnidentifiedProcessID, process.SystemProcessID:
|
||||
break
|
||||
default: // normal process
|
||||
// Check if the requesting process is in database root / updates dir.
|
||||
if strings.HasPrefix(proc.Path, authenticatedPath) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Add checked path to list.
|
||||
procsChecked = append(procsChecked, proc.Path)
|
||||
|
||||
// Get the parent process.
|
||||
if i < checkLevels {
|
||||
// save previous PID
|
||||
previousPid = proc.Pid
|
||||
|
||||
// get parent process
|
||||
proc, err = process.GetOrFindProcess(ctx, proc.ParentPid)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Debugf("filter: failed to get parent process of api request: %s", err)
|
||||
break
|
||||
}
|
||||
|
||||
// abort if we are looping
|
||||
if proc.Pid == previousPid {
|
||||
// this also catches -1 pid loops
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -174,7 +202,7 @@ func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bo
|
|||
deniedMsgUnauthorized,
|
||||
api.ErrAPIAccessDeniedMessage,
|
||||
strings.Join(procsChecked, "\n"),
|
||||
dataRoot.Path,
|
||||
authenticatedPath,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package firewall
|
|||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -88,7 +87,7 @@ func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache) *res
|
|||
p := conn.Process().Profile()
|
||||
|
||||
// do not modify own queries
|
||||
if conn.Process().Pid == os.Getpid() {
|
||||
if conn.Process().Pid == ownPID {
|
||||
return rrCache
|
||||
}
|
||||
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/netenv"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
|
@ -38,6 +40,8 @@ var (
|
|||
|
||||
blockedIPv4 = net.IPv4(0, 0, 0, 17)
|
||||
blockedIPv6 = net.ParseIP("::17")
|
||||
|
||||
ownPID = os.Getpid()
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -177,6 +181,16 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
|
|||
return false
|
||||
}
|
||||
|
||||
// Only fast-track local requests.
|
||||
isMe, err := netenv.IsMyIP(meta.Src)
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err)
|
||||
return false
|
||||
case !isMe:
|
||||
return false
|
||||
}
|
||||
|
||||
// Log and permit.
|
||||
log.Debugf("filter: fast-track accepting api connection: %s", pkt)
|
||||
_ = pkt.PermanentAccept()
|
||||
|
@ -196,6 +210,16 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
|
|||
return false
|
||||
}
|
||||
|
||||
// Only fast-track local requests.
|
||||
isMe, err := netenv.IsMyIP(meta.Src)
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err)
|
||||
return false
|
||||
case !isMe:
|
||||
return false
|
||||
}
|
||||
|
||||
// Log and permit.
|
||||
log.Debugf("filter: fast-track accepting local dns: %s", pkt)
|
||||
_ = pkt.PermanentAccept()
|
||||
|
@ -224,7 +248,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
|
|||
// Redirect rogue dns requests to the Portmaster.
|
||||
if pkt.IsOutbound() &&
|
||||
pkt.Info().DstPort == 53 &&
|
||||
conn.Process().Pid != os.Getpid() &&
|
||||
conn.Process().Pid != ownPID &&
|
||||
nameserverIPMatcherReady.IsSet() &&
|
||||
!nameserverIPMatcher(pkt.Info().Dst) {
|
||||
conn.Verdict = network.VerdictRerouteToNameserver
|
||||
|
|
|
@ -3,7 +3,6 @@ package firewall
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
|
@ -118,8 +117,9 @@ func runDeciders(ctx context.Context, conn *network.Connection, pkt packet.Packe
|
|||
// checkPortmasterConnection allows all connection that originate from
|
||||
// portmaster itself.
|
||||
func checkPortmasterConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool {
|
||||
// grant self
|
||||
if conn.Process().Pid == os.Getpid() {
|
||||
// Grant own outgoing connections.
|
||||
if conn.Process().Pid == ownPID &&
|
||||
(pkt == nil || pkt.IsOutbound()) {
|
||||
log.Tracer(ctx).Infof("filter: granting own connection %s", conn)
|
||||
conn.Accept("connection by Portmaster", noReasonOptionKey)
|
||||
conn.Internal = true
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/core"
|
||||
)
|
||||
|
||||
// Config Keys
|
||||
|
@ -18,6 +19,8 @@ var (
|
|||
nameserverAddressConfig config.StringOption
|
||||
|
||||
defaultNameserverAddress = "localhost:53"
|
||||
|
||||
networkServiceMode config.BoolOption
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -65,5 +68,7 @@ func registerConfig() error {
|
|||
}
|
||||
nameserverAddressConfig = config.GetAsString(CfgDefaultNameserverAddressKey, getDefaultNameserverAddress())
|
||||
|
||||
networkServiceMode = config.Concurrent.GetAsBool(core.CfgNetworkServiceKey, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"github.com/safing/portmaster/netenv"
|
||||
"github.com/safing/portmaster/network"
|
||||
"github.com/safing/portmaster/network/netutils"
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
"github.com/safing/portmaster/resolver"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
@ -94,25 +93,36 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
|
|||
return reply(nsutil.Localhost())
|
||||
}
|
||||
|
||||
// Authenticate request - only requests from the local host, but with any of its IPs, are allowed.
|
||||
local, err := netenv.IsMyIP(remoteAddr.IP)
|
||||
if err != nil {
|
||||
tracer.Warningf("nameserver: failed to check if request for %s%s is local: %s", q.FQDN, q.QType, err)
|
||||
return nil // Do no reply, drop request immediately.
|
||||
}
|
||||
if !local {
|
||||
tracer.Warningf("nameserver: external request for %s%s, ignoring", q.FQDN, q.QType)
|
||||
return nil // Do no reply, drop request immediately.
|
||||
}
|
||||
|
||||
// Validate domain name.
|
||||
if !netutils.IsValidFqdn(q.FQDN) {
|
||||
tracer.Debugf("nameserver: domain name %s is invalid, refusing", q.FQDN)
|
||||
return reply(nsutil.Refused("invalid domain"))
|
||||
}
|
||||
|
||||
// Authenticate request - only requests from the local host, but with any of its IPs, are allowed.
|
||||
local, err := netenv.IsMyIP(remoteAddr.IP)
|
||||
if err != nil {
|
||||
tracer.Warningf("nameserver: failed to check if request for %s%s is local: %s", q.FQDN, q.QType, err)
|
||||
return nil // Do no reply, drop request immediately.
|
||||
}
|
||||
|
||||
// Get connection for this request. This identifies the process behind the request.
|
||||
conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, packet.IPv4, remoteAddr.IP, uint16(remoteAddr.Port))
|
||||
var conn *network.Connection
|
||||
switch {
|
||||
case local:
|
||||
conn = network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP, uint16(remoteAddr.Port))
|
||||
|
||||
case networkServiceMode():
|
||||
conn, err = network.NewConnectionFromExternalDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP)
|
||||
if err != nil {
|
||||
tracer.Warningf("nameserver: failed to get host/profile for request for %s%s: %s", q.FQDN, q.QType, err)
|
||||
return nil // Do no reply, drop request immediately.
|
||||
}
|
||||
|
||||
default:
|
||||
tracer.Warningf("nameserver: external request for %s%s, ignoring", q.FQDN, q.QType)
|
||||
return nil // Do no reply, drop request immediately.
|
||||
}
|
||||
conn.Lock()
|
||||
defer conn.Unlock()
|
||||
|
||||
|
|
|
@ -188,7 +188,13 @@ func getProcessContext(ctx context.Context, proc *process.Process) ProcessContex
|
|||
}
|
||||
|
||||
// NewConnectionFromDNSRequest returns a new connection based on the given dns request.
|
||||
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, ipVersion packet.IPVersion, localIP net.IP, localPort uint16) *Connection {
|
||||
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, localIP net.IP, localPort uint16) *Connection {
|
||||
// Determine IP version.
|
||||
ipVersion := packet.IPv6
|
||||
if localIP.To4() != nil {
|
||||
ipVersion = packet.IPv4
|
||||
}
|
||||
|
||||
// get Process
|
||||
proc, _, err := process.GetProcessByConnection(
|
||||
ctx,
|
||||
|
@ -222,6 +228,26 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
|
|||
return dnsConn
|
||||
}
|
||||
|
||||
func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cnames []string, remoteIP net.IP) (*Connection, error) {
|
||||
remoteHost, err := process.GetNetworkHost(ctx, remoteIP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
timestamp := time.Now().Unix()
|
||||
return &Connection{
|
||||
Scope: fqdn,
|
||||
Entity: &intel.Entity{
|
||||
Domain: fqdn,
|
||||
CNAME: cnames,
|
||||
},
|
||||
process: remoteHost,
|
||||
ProcessContext: getProcessContext(ctx, remoteHost),
|
||||
Started: timestamp,
|
||||
Ended: timestamp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewConnectionFromFirstPacket returns a new connection based on the given packet.
|
||||
func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
|
||||
// get Process
|
||||
|
|
|
@ -2,11 +2,14 @@ package process
|
|||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/safing/portmaster/network/state"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
"github.com/safing/portmaster/network/state"
|
||||
"github.com/safing/portmaster/profile"
|
||||
)
|
||||
|
||||
// GetProcessByConnection returns the process that owns the described connection.
|
||||
|
@ -41,3 +44,39 @@ func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process
|
|||
|
||||
return process, connInbound, nil
|
||||
}
|
||||
|
||||
func GetNetworkHost(ctx context.Context, remoteIP net.IP) (process *Process, err error) {
|
||||
now := time.Now().Unix()
|
||||
networkHost := &Process{
|
||||
Name: fmt.Sprintf("Network Host %s", remoteIP),
|
||||
UserName: "Unknown",
|
||||
UserID: -255,
|
||||
Pid: -255,
|
||||
ParentPid: -255,
|
||||
Path: fmt.Sprintf("net:%s", remoteIP),
|
||||
FirstSeen: now,
|
||||
LastSeen: now,
|
||||
}
|
||||
|
||||
// Get the (linked) local profile.
|
||||
networkHostProfile, err := profile.GetNetworkHostProfile(remoteIP.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Assign profile to process.
|
||||
networkHost.LocalProfileKey = networkHostProfile.Key()
|
||||
networkHost.profile = networkHostProfile.LayeredProfile()
|
||||
|
||||
if networkHostProfile.Name == "" {
|
||||
// Assign name and save.
|
||||
networkHostProfile.Name = networkHost.Name
|
||||
|
||||
err := networkHostProfile.Save()
|
||||
if err != nil {
|
||||
log.Warningf("process: failed to save profile %s: %s", networkHostProfile.ScopedID(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return networkHost, nil
|
||||
}
|
||||
|
|
|
@ -2,11 +2,16 @@ package process
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/profile"
|
||||
)
|
||||
|
||||
var (
|
||||
ownPID = os.Getpid()
|
||||
)
|
||||
|
||||
// GetProfile finds and assigns a profile set to the process.
|
||||
func (p *Process) GetProfile(ctx context.Context) (changed bool, err error) {
|
||||
// Update profile metadata outside of *Process lock.
|
||||
|
@ -31,6 +36,8 @@ func (p *Process) GetProfile(ctx context.Context) (changed bool, err error) {
|
|||
profileID = profile.UnidentifiedProfileID
|
||||
case SystemProcessID:
|
||||
profileID = profile.SystemProfileID
|
||||
case ownPID:
|
||||
profileID = profile.PortmasterProfileID
|
||||
}
|
||||
|
||||
// Get the (linked) local profile.
|
||||
|
@ -56,7 +63,7 @@ func (p *Process) UpdateProfileMetadata() {
|
|||
}
|
||||
|
||||
// Update metadata of profile.
|
||||
metadataUpdated := localProfile.UpdateMetadata(p.Name)
|
||||
metadataUpdated := localProfile.UpdateMetadata(p.Name, p.Path)
|
||||
|
||||
// Mark profile as used.
|
||||
profileChanged := localProfile.MarkUsed()
|
||||
|
|
|
@ -21,6 +21,9 @@ const (
|
|||
|
||||
// SystemProfileID is the profile ID used for the system/kernel.
|
||||
SystemProfileID = "_system"
|
||||
|
||||
// SystemProfileID is the profile ID used for the Portmaster itself.
|
||||
PortmasterProfileID = "_portmaster"
|
||||
)
|
||||
|
||||
var getProfileSingleInflight singleflight.Group
|
||||
|
@ -71,6 +74,9 @@ func GetProfile(source profileSource, id, linkedPath string) ( //nolint:gocognit
|
|||
case SystemProfileID:
|
||||
profile = New(SourceLocal, SystemProfileID, linkedPath)
|
||||
err = nil
|
||||
case PortmasterProfileID:
|
||||
profile = New(SourceLocal, PortmasterProfileID, linkedPath)
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -132,6 +138,76 @@ func GetProfile(source profileSource, id, linkedPath string) ( //nolint:gocognit
|
|||
return p.(*Profile), nil
|
||||
}
|
||||
|
||||
func GetNetworkHostProfile(remoteIP string) ( //nolint:gocognit
|
||||
profile *Profile,
|
||||
err error,
|
||||
) {
|
||||
scopedID := makeScopedID(SourceNetwork, remoteIP)
|
||||
|
||||
p, err, _ := getProfileSingleInflight.Do(scopedID, func() (interface{}, error) {
|
||||
var previousVersion *Profile
|
||||
|
||||
// Get profile via the scoped ID.
|
||||
// Check if there already is an active and not outdated profile.
|
||||
profile = getActiveProfile(scopedID)
|
||||
if profile != nil {
|
||||
profile.MarkStillActive()
|
||||
|
||||
if profile.outdated.IsSet() {
|
||||
previousVersion = profile
|
||||
} else {
|
||||
return profile, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Get from database.
|
||||
profile, err = getProfile(scopedID)
|
||||
switch {
|
||||
case err == nil:
|
||||
// Continue.
|
||||
case errors.Is(err, database.ErrNotFound):
|
||||
// Create new profile.
|
||||
// If there was no profile in the database, create a new one, and return it.
|
||||
profile = New(SourceNetwork, remoteIP, "")
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process profiles coming directly from the database.
|
||||
// As we don't use any caching, these will be new objects.
|
||||
|
||||
// Mark the profile as being saved internally in order to not trigger an
|
||||
// update after saving it to the database.
|
||||
profile.internalSave = true
|
||||
|
||||
// Add a layeredProfile to network profiles.
|
||||
|
||||
// If we are refetching, assign the layered profile from the previous version.
|
||||
if previousVersion != nil {
|
||||
profile.layeredProfile = previousVersion.layeredProfile
|
||||
}
|
||||
|
||||
// Network profiles must have a layered profile, create a new one if it
|
||||
// does not yet exist.
|
||||
if profile.layeredProfile == nil {
|
||||
profile.layeredProfile = NewLayeredProfile(profile)
|
||||
}
|
||||
|
||||
// Add the profile to the currently active profiles.
|
||||
addActiveProfile(profile)
|
||||
|
||||
return profile, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p == nil {
|
||||
return nil, errors.New("profile getter returned nil")
|
||||
}
|
||||
|
||||
return p.(*Profile), nil
|
||||
}
|
||||
|
||||
// getProfile fetches the profile for the given scoped ID.
|
||||
func getProfile(scopedID string) (profile *Profile, err error) {
|
||||
// Get profile from the database.
|
||||
|
|
|
@ -32,6 +32,7 @@ type profileSource string
|
|||
const (
|
||||
SourceLocal profileSource = "local" // local, editable
|
||||
SourceSpecial profileSource = "special" // specials (read-only)
|
||||
SourceNetwork profileSource = "network"
|
||||
SourceCommunity profileSource = "community"
|
||||
SourceEnterprise profileSource = "enterprise"
|
||||
)
|
||||
|
@ -386,7 +387,7 @@ func EnsureProfile(r record.Record) (*Profile, error) {
|
|||
// the profile was changed. If there is data that needs to be fetched from the
|
||||
// operating system, it will start an async worker to fetch that data and save
|
||||
// the profile afterwards.
|
||||
func (profile *Profile) UpdateMetadata(processName string) (changed bool) {
|
||||
func (profile *Profile) UpdateMetadata(processName, binaryPath string) (changed bool) {
|
||||
// Check if this is a local profile, else warn and return.
|
||||
if profile.Source != SourceLocal {
|
||||
log.Warningf("tried to update metadata for non-local profile %s", profile.ScopedID())
|
||||
|
@ -397,7 +398,7 @@ func (profile *Profile) UpdateMetadata(processName string) (changed bool) {
|
|||
defer profile.Unlock()
|
||||
|
||||
// Check if this is a special profile.
|
||||
if profile.LinkedPath == "" {
|
||||
if binaryPath == "" {
|
||||
// This is a special profile, just assign the processName, if needed, and
|
||||
// return.
|
||||
if profile.Name != processName {
|
||||
|
@ -407,6 +408,13 @@ func (profile *Profile) UpdateMetadata(processName string) (changed bool) {
|
|||
return false
|
||||
}
|
||||
|
||||
// Update LinkedPath if if differs from the process path.
|
||||
// This will (at the moment) only be the case for the Portmaster profile.
|
||||
if profile.LinkedPath != binaryPath {
|
||||
profile.LinkedPath = binaryPath
|
||||
changed = true
|
||||
}
|
||||
|
||||
var needsUpdateFromSystem bool
|
||||
|
||||
// Check profile name.
|
||||
|
|
|
@ -336,3 +336,12 @@ func stagingActive() bool {
|
|||
_, err := os.Stat(filepath.Join(registry.StorageDir().Path, "staging.json"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// RootPath returns the root path used for storing updates.
|
||||
func RootPath() string {
|
||||
if !module.Online() {
|
||||
return ""
|
||||
}
|
||||
|
||||
return registry.StorageDir().Path
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue