diff --git a/updater/fetch.go b/updater/fetch.go index c2239a8..4638642 100644 --- a/updater/fetch.go +++ b/updater/fetch.go @@ -2,6 +2,7 @@ package updater import ( "bytes" + "context" "fmt" "io" "net/http" @@ -16,10 +17,14 @@ import ( "github.com/safing/portbase/log" ) -func (reg *ResourceRegistry) fetchFile(rv *ResourceVersion, tries int) error { +func (reg *ResourceRegistry) fetchFile(ctx context.Context, client *http.Client, rv *ResourceVersion, tries int) error { // backoff when retrying if tries > 0 { - time.Sleep(time.Duration(tries*tries) * time.Second) + select { + case <-ctx.Done(): + return nil // module is shutting down + case <-time.After(time.Duration(tries*tries) * time.Second): + } } // create URL @@ -44,7 +49,11 @@ func (reg *ResourceRegistry) fetchFile(rv *ResourceVersion, tries int) error { defer atomicFile.Cleanup() //nolint:errcheck // ignore error for now, tmp dir will be cleaned later again anyway // start file download - resp, err := http.Get(downloadURL) //nolint:gosec // url is variable on purpose + req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, http.NoBody) //nolint:gosec + if err != nil { + return fmt.Errorf("error creating request (%s): %w", downloadURL, err) + } + resp, err := client.Do(req) if err != nil { return fmt.Errorf("error fetching url (%s): %w", downloadURL, err) } @@ -81,10 +90,14 @@ func (reg *ResourceRegistry) fetchFile(rv *ResourceVersion, tries int) error { return nil } -func (reg *ResourceRegistry) fetchData(downloadPath string, tries int) ([]byte, error) { +func (reg *ResourceRegistry) fetchData(ctx context.Context, client *http.Client, downloadPath string, tries int) ([]byte, error) { // backoff when retrying if tries > 0 { - time.Sleep(time.Duration(tries*tries) * time.Second) + select { + case <-ctx.Done(): + return nil, nil // module is shutting down + case <-time.After(time.Duration(tries*tries) * time.Second): + } } // create URL @@ -94,7 +107,11 @@ func (reg *ResourceRegistry) fetchData(downloadPath string, tries int) ([]byte, } // start file download - resp, err := http.Get(downloadURL) //nolint:gosec // url is variable on purpose + req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, http.NoBody) //nolint:gosec + if err != nil { + return nil, fmt.Errorf("error creating request (%s): %w", downloadURL, err) + } + resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("error fetching url (%s): %w", downloadURL, err) } diff --git a/updater/get.go b/updater/get.go index 9b92124..fd0d133 100644 --- a/updater/get.go +++ b/updater/get.go @@ -1,8 +1,10 @@ package updater import ( + "context" "errors" "fmt" + "net/http" "github.com/safing/portbase/log" ) @@ -43,8 +45,9 @@ func (reg *ResourceRegistry) GetFile(identifier string) (*File, error) { // download file log.Tracef("%s: starting download of %s", reg.Name, file.versionedPath) + client := &http.Client{} for tries := 0; tries < 5; tries++ { - err = reg.fetchFile(file.version, tries) + err = reg.fetchFile(context.TODO(), client, file.version, tries) if err != nil { log.Tracef("%s: failed to download %s: %s, retrying (%d)", reg.Name, file.versionedPath, err, tries+1) } else { diff --git a/updater/storage.go b/updater/storage.go index f444d04..4916501 100644 --- a/updater/storage.go +++ b/updater/storage.go @@ -1,10 +1,12 @@ package updater import ( + "context" "encoding/json" "errors" "fmt" "io/ioutil" + "net/http" "os" "path/filepath" "strings" @@ -78,8 +80,9 @@ func (reg *ResourceRegistry) ScanStorage(root string) error { // LoadIndexes loads the current release indexes from disk // or will fetch a new version if not available and the // registry is marked as online. -func (reg *ResourceRegistry) LoadIndexes() error { +func (reg *ResourceRegistry) LoadIndexes(ctx context.Context) error { var firstErr error + client := &http.Client{} for _, idx := range reg.getIndexes() { err := reg.loadIndexFile(idx) if err == nil { @@ -88,7 +91,7 @@ func (reg *ResourceRegistry) LoadIndexes() error { // try to download the index file if a local disk version // does not exist or we don't have permission to read it. if os.IsNotExist(err) || os.IsPermission(err) { - err = reg.downloadIndex(idx) + err = reg.downloadIndex(ctx, client, idx) } } diff --git a/updater/updating.go b/updater/updating.go index 1650b5f..221e8ee 100644 --- a/updater/updating.go +++ b/updater/updating.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "net/http" "path/filepath" "github.com/safing/portbase/utils" @@ -13,11 +14,12 @@ import ( ) // UpdateIndexes downloads all indexes and returns the first error encountered. -func (reg *ResourceRegistry) UpdateIndexes() error { +func (reg *ResourceRegistry) UpdateIndexes(ctx context.Context) error { var firstErr error + client := &http.Client{} for _, idx := range reg.getIndexes() { - if err := reg.downloadIndex(idx); err != nil { + if err := reg.downloadIndex(ctx, client, idx); err != nil { if firstErr == nil { firstErr = err } @@ -27,13 +29,13 @@ func (reg *ResourceRegistry) UpdateIndexes() error { return firstErr } -func (reg *ResourceRegistry) downloadIndex(idx Index) error { +func (reg *ResourceRegistry) downloadIndex(ctx context.Context, client *http.Client, idx Index) error { var err error var data []byte // download new index for tries := 0; tries < 3; tries++ { - data, err = reg.fetchData(idx.Path, tries) + data, err = reg.fetchData(ctx, client, idx.Path, tries) if err == nil { break } @@ -115,9 +117,10 @@ func (reg *ResourceRegistry) DownloadUpdates(ctx context.Context) error { // download updates log.Infof("%s: starting to download %d updates", reg.Name, len(toUpdate)) + client := &http.Client{} for _, rv := range toUpdate { for tries := 0; tries < 3; tries++ { - err = reg.fetchFile(rv, tries) + err = reg.fetchFile(ctx, client, rv, tries) if err == nil { rv.Available = true break