Use context for network connections in updater

This commit is contained in:
Daniel 2020-07-21 15:04:05 +02:00
parent 91f759d148
commit 02bb6d1d9b
4 changed files with 40 additions and 14 deletions

View file

@ -2,6 +2,7 @@ package updater
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -16,10 +17,14 @@ import (
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
) )
func (reg *ResourceRegistry) fetchFile(rv *ResourceVersion, tries int) error { func (reg *ResourceRegistry) fetchFile(ctx context.Context, client *http.Client, rv *ResourceVersion, tries int) error {
// backoff when retrying // backoff when retrying
if tries > 0 { if tries > 0 {
time.Sleep(time.Duration(tries*tries) * time.Second) select {
case <-ctx.Done():
return nil // module is shutting down
case <-time.After(time.Duration(tries*tries) * time.Second):
}
} }
// create URL // create URL
@ -44,7 +49,11 @@ func (reg *ResourceRegistry) fetchFile(rv *ResourceVersion, tries int) error {
defer atomicFile.Cleanup() //nolint:errcheck // ignore error for now, tmp dir will be cleaned later again anyway defer atomicFile.Cleanup() //nolint:errcheck // ignore error for now, tmp dir will be cleaned later again anyway
// start file download // start file download
resp, err := http.Get(downloadURL) //nolint:gosec // url is variable on purpose req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, http.NoBody) //nolint:gosec
if err != nil {
return fmt.Errorf("error creating request (%s): %w", downloadURL, err)
}
resp, err := client.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("error fetching url (%s): %w", downloadURL, err) return fmt.Errorf("error fetching url (%s): %w", downloadURL, err)
} }
@ -81,10 +90,14 @@ func (reg *ResourceRegistry) fetchFile(rv *ResourceVersion, tries int) error {
return nil return nil
} }
func (reg *ResourceRegistry) fetchData(downloadPath string, tries int) ([]byte, error) { func (reg *ResourceRegistry) fetchData(ctx context.Context, client *http.Client, downloadPath string, tries int) ([]byte, error) {
// backoff when retrying // backoff when retrying
if tries > 0 { if tries > 0 {
time.Sleep(time.Duration(tries*tries) * time.Second) select {
case <-ctx.Done():
return nil, nil // module is shutting down
case <-time.After(time.Duration(tries*tries) * time.Second):
}
} }
// create URL // create URL
@ -94,7 +107,11 @@ func (reg *ResourceRegistry) fetchData(downloadPath string, tries int) ([]byte,
} }
// start file download // start file download
resp, err := http.Get(downloadURL) //nolint:gosec // url is variable on purpose req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, http.NoBody) //nolint:gosec
if err != nil {
return nil, fmt.Errorf("error creating request (%s): %w", downloadURL, err)
}
resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("error fetching url (%s): %w", downloadURL, err) return nil, fmt.Errorf("error fetching url (%s): %w", downloadURL, err)
} }

View file

@ -1,8 +1,10 @@
package updater package updater
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
) )
@ -43,8 +45,9 @@ func (reg *ResourceRegistry) GetFile(identifier string) (*File, error) {
// download file // download file
log.Tracef("%s: starting download of %s", reg.Name, file.versionedPath) log.Tracef("%s: starting download of %s", reg.Name, file.versionedPath)
client := &http.Client{}
for tries := 0; tries < 5; tries++ { for tries := 0; tries < 5; tries++ {
err = reg.fetchFile(file.version, tries) err = reg.fetchFile(context.TODO(), client, file.version, tries)
if err != nil { if err != nil {
log.Tracef("%s: failed to download %s: %s, retrying (%d)", reg.Name, file.versionedPath, err, tries+1) log.Tracef("%s: failed to download %s: %s, retrying (%d)", reg.Name, file.versionedPath, err, tries+1)
} else { } else {

View file

@ -1,10 +1,12 @@
package updater package updater
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -78,8 +80,9 @@ func (reg *ResourceRegistry) ScanStorage(root string) error {
// LoadIndexes loads the current release indexes from disk // LoadIndexes loads the current release indexes from disk
// or will fetch a new version if not available and the // or will fetch a new version if not available and the
// registry is marked as online. // registry is marked as online.
func (reg *ResourceRegistry) LoadIndexes() error { func (reg *ResourceRegistry) LoadIndexes(ctx context.Context) error {
var firstErr error var firstErr error
client := &http.Client{}
for _, idx := range reg.getIndexes() { for _, idx := range reg.getIndexes() {
err := reg.loadIndexFile(idx) err := reg.loadIndexFile(idx)
if err == nil { if err == nil {
@ -88,7 +91,7 @@ func (reg *ResourceRegistry) LoadIndexes() error {
// try to download the index file if a local disk version // try to download the index file if a local disk version
// does not exist or we don't have permission to read it. // does not exist or we don't have permission to read it.
if os.IsNotExist(err) || os.IsPermission(err) { if os.IsNotExist(err) || os.IsPermission(err) {
err = reg.downloadIndex(idx) err = reg.downloadIndex(ctx, client, idx)
} }
} }

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http"
"path/filepath" "path/filepath"
"github.com/safing/portbase/utils" "github.com/safing/portbase/utils"
@ -13,11 +14,12 @@ import (
) )
// UpdateIndexes downloads all indexes and returns the first error encountered. // UpdateIndexes downloads all indexes and returns the first error encountered.
func (reg *ResourceRegistry) UpdateIndexes() error { func (reg *ResourceRegistry) UpdateIndexes(ctx context.Context) error {
var firstErr error var firstErr error
client := &http.Client{}
for _, idx := range reg.getIndexes() { for _, idx := range reg.getIndexes() {
if err := reg.downloadIndex(idx); err != nil { if err := reg.downloadIndex(ctx, client, idx); err != nil {
if firstErr == nil { if firstErr == nil {
firstErr = err firstErr = err
} }
@ -27,13 +29,13 @@ func (reg *ResourceRegistry) UpdateIndexes() error {
return firstErr return firstErr
} }
func (reg *ResourceRegistry) downloadIndex(idx Index) error { func (reg *ResourceRegistry) downloadIndex(ctx context.Context, client *http.Client, idx Index) error {
var err error var err error
var data []byte var data []byte
// download new index // download new index
for tries := 0; tries < 3; tries++ { for tries := 0; tries < 3; tries++ {
data, err = reg.fetchData(idx.Path, tries) data, err = reg.fetchData(ctx, client, idx.Path, tries)
if err == nil { if err == nil {
break break
} }
@ -115,9 +117,10 @@ func (reg *ResourceRegistry) DownloadUpdates(ctx context.Context) error {
// download updates // download updates
log.Infof("%s: starting to download %d updates", reg.Name, len(toUpdate)) log.Infof("%s: starting to download %d updates", reg.Name, len(toUpdate))
client := &http.Client{}
for _, rv := range toUpdate { for _, rv := range toUpdate {
for tries := 0; tries < 3; tries++ { for tries := 0; tries < 3; tries++ {
err = reg.fetchFile(rv, tries) err = reg.fetchFile(ctx, client, rv, tries)
if err == nil { if err == nil {
rv.Available = true rv.Available = true
break break