Add flags to specify custom update server

This commit is contained in:
Daniel 2023-04-20 12:54:59 +02:00
parent 16c756144a
commit 8273894f87
3 changed files with 58 additions and 19 deletions

View file

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"net/url"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
@ -24,11 +25,13 @@ import (
var ( var (
dataDir string dataDir string
staging bool
maxRetries int maxRetries int
dataRoot *utils.DirStructure dataRoot *utils.DirStructure
logsRoot *utils.DirStructure logsRoot *utils.DirStructure
updateURLFlag string
userAgentFlag string
// Create registry. // Create registry.
registry = &updater.ResourceRegistry{ registry = &updater.ResourceRegistry{
Name: "updates", Name: "updates",
@ -67,8 +70,8 @@ func init() {
flags := rootCmd.PersistentFlags() flags := rootCmd.PersistentFlags()
{ {
flags.StringVar(&dataDir, "data", "", "Configures the data directory. Alternatively, this can also be set via the environment variable PORTMASTER_DATA.") flags.StringVar(&dataDir, "data", "", "Configures the data directory. Alternatively, this can also be set via the environment variable PORTMASTER_DATA.")
flags.StringVar(&registry.UserAgent, "update-agent", "Start", "Sets the user agent for requests to the update server") flags.StringVar(&updateURLFlag, "update-server", "", "Set an alternative update server (full URL)")
flags.BoolVar(&staging, "staging", false, "Deprecated, configure in settings instead.") flags.StringVar(&userAgentFlag, "update-agent", "", "Set an alternative user agent for requests to the update server")
flags.IntVar(&maxRetries, "max-retries", 5, "Maximum number of retries when starting a Portmaster component") flags.IntVar(&maxRetries, "max-retries", 5, "Maximum number of retries when starting a Portmaster component")
flags.BoolVar(&stdinSignals, "input-signals", false, "Emulate signals using stdin.") flags.BoolVar(&stdinSignals, "input-signals", false, "Emulate signals using stdin.")
_ = rootCmd.MarkPersistentFlagDirname("data") _ = rootCmd.MarkPersistentFlagDirname("data")
@ -137,6 +140,25 @@ func initCobra() {
} }
func configureRegistry(mustLoadIndex bool) error { func configureRegistry(mustLoadIndex bool) error {
// Check if update server URL supplied via flag is a valid URL.
if updateURLFlag != "" {
u, err := url.Parse(updateURLFlag)
if err != nil {
return fmt.Errorf("supplied update server URL is invalid: %w", err)
}
if u.Scheme != "https" {
return errors.New("supplied update server URL must use HTTPS")
}
}
// Override values from flags.
if userAgentFlag != "" {
registry.UserAgent = userAgentFlag
}
if updateURLFlag != "" {
registry.UpdateURLs = []string{updateURLFlag}
}
// If dataDir is not set, check the environment variable. // If dataDir is not set, check the environment variable.
if dataDir == "" { if dataDir == "" {
dataDir = os.Getenv("PORTMASTER_DATA") dataDir = os.Getenv("PORTMASTER_DATA")

View file

@ -134,14 +134,12 @@ func logProgress(state *updater.RegistryState) {
len(downloadDetails.Resources), len(downloadDetails.Resources),
downloadDetails.Resources[downloadDetails.FinishedUpTo], downloadDetails.Resources[downloadDetails.FinishedUpTo],
) )
} else { } else if state.Updates.LastDownloadAt == nil {
if state.Updates.LastDownloadAt == nil {
log.Println("finalizing downloads") log.Println("finalizing downloads")
} }
} }
} }
} }
}
func purge() error { func purge() error {
portlog.SetLogLevel(portlog.TraceLevel) portlog.SetLogLevel(portlog.TraceLevel)

View file

@ -2,8 +2,10 @@ package updates
import ( import (
"context" "context"
"errors"
"flag" "flag"
"fmt" "fmt"
"net/url"
"runtime" "runtime"
"time" "time"
@ -43,7 +45,9 @@ const (
var ( var (
module *modules.Module module *modules.Module
registry *updater.ResourceRegistry registry *updater.ResourceRegistry
userAgentFromFlag string userAgentFromFlag string
updateServerFromFlag string
updateTask *modules.Task updateTask *modules.Task
updateASAP bool updateASAP bool
@ -59,6 +63,11 @@ var (
// fetching resources from the update server. // fetching resources from the update server.
UserAgent = fmt.Sprintf("Portmaster (%s %s)", runtime.GOOS, runtime.GOARCH) UserAgent = fmt.Sprintf("Portmaster (%s %s)", runtime.GOOS, runtime.GOARCH)
// DefaultUpdateURLs defines the default base URLs of the update server.
DefaultUpdateURLs = []string{
"https://updates.safing.io",
}
// DisableSoftwareAutoUpdate specifies whether software updates should be disabled. // DisableSoftwareAutoUpdate specifies whether software updates should be disabled.
// This is used on Android, as it will never require binary updates. // This is used on Android, as it will never require binary updates.
DisableSoftwareAutoUpdate = false DisableSoftwareAutoUpdate = false
@ -75,10 +84,8 @@ func init() {
module.RegisterEvent(VersionUpdateEvent, true) module.RegisterEvent(VersionUpdateEvent, true)
module.RegisterEvent(ResourceUpdateEvent, true) module.RegisterEvent(ResourceUpdateEvent, true)
flag.StringVar(&userAgentFromFlag, "update-agent", "", "set the user agent for requests to the update server") flag.StringVar(&updateServerFromFlag, "update-server", "", "set an alternative update server (full URL)")
flag.StringVar(&userAgentFromFlag, "update-agent", "", "set an alternative user agent for requests to the update server")
var dummy bool
flag.BoolVar(&dummy, "staging", false, "deprecated, configure in settings instead")
} }
func prep() error { func prep() error {
@ -86,6 +93,17 @@ func prep() error {
return err return err
} }
// Check if update server URL supplied via flag is a valid URL.
if updateServerFromFlag != "" {
u, err := url.Parse(updateServerFromFlag)
if err != nil {
return fmt.Errorf("supplied update server URL is invalid: %w", err)
}
if u.Scheme != "https" {
return errors.New("supplied update server URL must use HTTPS")
}
}
return registerAPIEndpoints() return registerAPIEndpoints()
} }
@ -105,9 +123,7 @@ func start() error {
// create registry // create registry
registry = &updater.ResourceRegistry{ registry = &updater.ResourceRegistry{
Name: ModuleName, Name: ModuleName,
UpdateURLs: []string{ UpdateURLs: DefaultUpdateURLs,
"https://updates.safing.io",
},
UserAgent: UserAgent, UserAgent: UserAgent,
MandatoryUpdates: helper.MandatoryUpdates(), MandatoryUpdates: helper.MandatoryUpdates(),
AutoUnpack: helper.AutoUnpackUpdates(), AutoUnpack: helper.AutoUnpackUpdates(),
@ -115,10 +131,13 @@ func start() error {
DevMode: devMode(), DevMode: devMode(),
Online: true, Online: true,
} }
// Override values from flags.
if userAgentFromFlag != "" { if userAgentFromFlag != "" {
// override with flag value
registry.UserAgent = userAgentFromFlag registry.UserAgent = userAgentFromFlag
} }
if updateServerFromFlag != "" {
registry.UpdateURLs = []string{updateServerFromFlag}
}
// pre-init state // pre-init state
updateStateExport, err := LoadStateExport() updateStateExport, err := LoadStateExport()