diff --git a/updater/fetch.go b/updater/fetch.go index 4638642..b4341ea 100644 --- a/updater/fetch.go +++ b/updater/fetch.go @@ -27,16 +27,9 @@ func (reg *ResourceRegistry) fetchFile(ctx context.Context, client *http.Client, } } - // create URL - downloadURL, err := joinURLandPath(reg.UpdateURLs[tries%len(reg.UpdateURLs)], rv.versionedPath()) - if err != nil { - return fmt.Errorf("error build url (%s + %s): %w", reg.UpdateURLs[tries%len(reg.UpdateURLs)], rv.versionedPath(), err) - } - // check destination dir dirPath := filepath.Dir(rv.storagePath()) - - err = reg.storageDir.EnsureAbsPath(dirPath) + err := reg.storageDir.EnsureAbsPath(dirPath) if err != nil { return fmt.Errorf("could not create updates folder: %s", dirPath) } @@ -49,27 +42,19 @@ func (reg *ResourceRegistry) fetchFile(ctx context.Context, client *http.Client, defer atomicFile.Cleanup() //nolint:errcheck // ignore error for now, tmp dir will be cleaned later again anyway // start file download - req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, http.NoBody) //nolint:gosec + resp, downloadURL, err := reg.makeRequest(ctx, client, rv.versionedPath(), tries) if err != nil { - return fmt.Errorf("error creating request (%s): %w", downloadURL, err) - } - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("error fetching url (%s): %w", downloadURL, err) + return err } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("error fetching url (%s): %s", downloadURL, resp.Status) - } - // download and write file n, err := io.Copy(atomicFile, resp.Body) if err != nil { - return fmt.Errorf("failed downloading %s: %w", downloadURL, err) + return fmt.Errorf("failed to download %q: %w", downloadURL, err) } if resp.ContentLength != n { - return fmt.Errorf("download unfinished, written %d out of %d bytes", n, resp.ContentLength) + return fmt.Errorf("failed to finish download of %q: written %d out of %d bytes", downloadURL, n, resp.ContentLength) } // finalize file @@ -100,46 +85,60 @@ func (reg *ResourceRegistry) fetchData(ctx context.Context, client *http.Client, } } - // create URL - downloadURL, err := joinURLandPath(reg.UpdateURLs[tries%len(reg.UpdateURLs)], downloadPath) - if err != nil { - return nil, fmt.Errorf("error build url (%s + %s): %w", reg.UpdateURLs[tries%len(reg.UpdateURLs)], downloadPath, err) - } - // start file download - req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, http.NoBody) //nolint:gosec + resp, downloadURL, err := reg.makeRequest(ctx, client, downloadPath, tries) if err != nil { - return nil, fmt.Errorf("error creating request (%s): %w", downloadURL, err) - } - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("error fetching url (%s): %w", downloadURL, err) + return nil, err } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("error fetching url (%s): %s", downloadURL, resp.Status) - } - // download and write file buf := bytes.NewBuffer(make([]byte, 0, resp.ContentLength)) n, err := io.Copy(buf, resp.Body) if err != nil { - return nil, fmt.Errorf("failed downloading %s: %w", downloadURL, err) + return nil, fmt.Errorf("failed to download %q: %w", downloadURL, err) } if resp.ContentLength != n { - return nil, fmt.Errorf("download unfinished, written %d out of %d bytes", n, resp.ContentLength) + return nil, fmt.Errorf("failed to finish download of %q: written %d out of %d bytes", downloadURL, n, resp.ContentLength) } return buf.Bytes(), nil } -func joinURLandPath(baseURL, urlPath string) (string, error) { - u, err := url.Parse(baseURL) +func (reg *ResourceRegistry) makeRequest(ctx context.Context, client *http.Client, downloadPath string, tries int) (resp *http.Response, downloadURL string, err error) { + // parse update URL + updateBaseURL := reg.UpdateURLs[tries%len(reg.UpdateURLs)] + u, err := url.Parse(updateBaseURL) if err != nil { - return "", err + return nil, "", fmt.Errorf("failed to parse update URL %q: %w", updateBaseURL, err) + } + // add download path + u.Path = path.Join(u.Path, downloadPath) + // compile URL + downloadURL = u.String() + + // create request + req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, http.NoBody) //nolint:gosec + if err != nil { + return nil, "", fmt.Errorf("failed to create request for %q: %w", downloadURL, err) } - u.Path = path.Join(u.Path, urlPath) - return u.String(), nil + // set user agent + if reg.UserAgent != "" { + req.Header.Set("User-Agent", reg.UserAgent) + } + + // start request + resp, err = client.Do(req) + if err != nil { + return nil, "", fmt.Errorf("failed to make request to %q: %w", downloadURL, err) + } + + // check return code + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, "", fmt.Errorf("failed to fetch %q: %d %s", downloadURL, resp.StatusCode, resp.Status) + } + + return resp, downloadURL, err } diff --git a/updater/registry.go b/updater/registry.go index 91e49a4..3e097cb 100644 --- a/updater/registry.go +++ b/updater/registry.go @@ -24,6 +24,7 @@ type ResourceRegistry struct { resources map[string]*Resource UpdateURLs []string + UserAgent string MandatoryUpdates []string Beta bool