Implement internal/updater: selfupdate via Forgejo Releases API
Adds a small, well-tested package that: - Queries /api/v1/repos/root/drover-go/releases/latest (404 = no updates, not an error). - Compares the published tag against the running Version using golang.org/x/mod/semver, so v0.1.0-rc.2 < v0.1.0. "dev" or any semver-invalid current version is treated as "always update". - Downloads the windows-amd64 asset + SHA256SUMS.txt, verifies the sha256 of the binary against its line in the sums file (tolerates the asterisk binary-mode prefix), and atomically swaps the running exe via github.com/minio/selfupdate. - Uses a 15s connect timeout with no overall request deadline, so large asset downloads aren't truncated. - Reports progress via an optional callback. Public surface: Source interface + ForgejoSource implementation, CheckForUpdate, ApplyUpdate, SetVersion. No GUI/cobra/Wails imports in the package, so the same code is reusable from the CLI, the Windows service, and the future tray UI. Wires the package into "drover update" / "drover update --check-only" in cmd/drover/main.go. --check-only exits 0 whether or not an update is available; only network/sha/apply errors are non-zero. Tests cover CheckForUpdate (table-driven incl. semver pre-release ordering, dev fallthrough, source errors), parseSHA256Sums (text and binary modes, CRLF, malformed lines, missing entries), ForgejoSource.Latest (httptest with canned JSON, 404, 500, missing asset, missing SHA256SUMS), and downloadAndVerify (success, sha mismatch, HTTP 404, context cancellation). All run with -race. Smoke-tested manually: built drover.exe and "drover update --check-only" against git.okcu.io prints "No updates available" and exits 0 (no releases yet). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
+32
-2
@@ -6,6 +6,8 @@ import (
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"git.okcu.io/root/drover-go/internal/updater"
|
||||
)
|
||||
|
||||
// Build-time variables, populated via -ldflags "-X main.Version=... -X main.Commit=... -X main.BuildDate=...".
|
||||
@@ -20,6 +22,10 @@ var (
|
||||
var configPath string
|
||||
|
||||
func main() {
|
||||
// Inject our build version so the updater package can stamp it on the
|
||||
// User-Agent header it sends to git.okcu.io.
|
||||
updater.SetVersion(Version)
|
||||
|
||||
if err := newRootCmd().Execute(); err != nil {
|
||||
// Cobra already prints the error; just exit non-zero.
|
||||
os.Exit(1)
|
||||
@@ -64,8 +70,32 @@ func newUpdateCmd() *cobra.Command {
|
||||
Use: "update",
|
||||
Short: "Self-update via the Forgejo Releases API",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
fmt.Fprintln(cmd.OutOrStdout(), "TODO: update check")
|
||||
_ = checkOnly // wired up for the upcoming implementation
|
||||
ctx := cmd.Context()
|
||||
out := cmd.OutOrStdout()
|
||||
|
||||
src := updater.NewForgejoSource("git.okcu.io", "root", "drover-go", "windows-amd64.exe")
|
||||
rel, hasUpdate, err := updater.CheckForUpdate(ctx, src, Version)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check for update: %w", err)
|
||||
}
|
||||
if !hasUpdate {
|
||||
fmt.Fprintln(out, "No updates available")
|
||||
return nil
|
||||
}
|
||||
fmt.Fprintf(out, "Update available: %s (current v%s)\n", rel.TagName, Version)
|
||||
if checkOnly {
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintln(out, "Downloading...")
|
||||
if err := updater.ApplyUpdate(ctx, rel, func(d, t int64) {
|
||||
if t > 0 {
|
||||
fmt.Fprintf(out, "\r%d/%d bytes", d, t)
|
||||
}
|
||||
}); err != nil {
|
||||
return fmt.Errorf("apply update: %w", err)
|
||||
}
|
||||
fmt.Fprintln(out, "\nUpdate applied. Restart drover.")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2,9 +2,16 @@ module git.okcu.io/root/drover-go
|
||||
|
||||
go 1.23
|
||||
|
||||
require github.com/spf13/cobra v1.10.2
|
||||
require (
|
||||
github.com/minio/selfupdate v0.6.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
golang.org/x/mod v0.21.0
|
||||
)
|
||||
|
||||
require (
|
||||
aead.dev/minisign v0.2.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/spf13/pflag v1.0.9 // indirect
|
||||
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b // indirect
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,10 +1,34 @@
|
||||
aead.dev/minisign v0.2.0 h1:kAWrq/hBRu4AARY6AlciO83xhNnW9UaC8YipS2uhLPk=
|
||||
aead.dev/minisign v0.2.0/go.mod h1:zdq6LdSd9TbuSxchxwhpA9zEb9YXcVGoE8JakuiGaIQ=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU=
|
||||
github.com/minio/selfupdate v0.6.0/go.mod h1:bO02GTIPCMQFTEvE5h4DjYB58bCoZ35XLeBf0buTDdM=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
|
||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b h1:QAqMVf3pSa6eeTsuklijukjXBlj7Es2QQplab+/RbQ4=
|
||||
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
// Package updater performs self-update via the Forgejo Releases API.
|
||||
package updater
|
||||
@@ -0,0 +1,479 @@
|
||||
// Package updater performs self-update via the Forgejo Releases API.
|
||||
//
|
||||
// The package is split into:
|
||||
//
|
||||
// - Source: an interface that fetches "the latest release" from somewhere.
|
||||
// ForgejoSource is the production implementation; tests use a fake.
|
||||
// - CheckForUpdate: compares a release tag against a current version using
|
||||
// semver-aware ordering (golang.org/x/mod/semver).
|
||||
// - ApplyUpdate: downloads the asset + SHA256SUMS.txt, verifies the binary
|
||||
// against its hash, and atomically replaces the running executable via
|
||||
// github.com/minio/selfupdate.
|
||||
//
|
||||
// The package is GUI- and CLI-framework-agnostic; importers should not pull
|
||||
// in cobra/Wails as a transitive dependency.
|
||||
package updater
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/minio/selfupdate"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
// version is exposed for tests/User-Agent. Caller-overridable.
|
||||
var version = "dev"
|
||||
|
||||
// SetVersion lets the calling binary inject its build version into the
|
||||
// User-Agent header used by HTTP requests. Optional; defaults to "dev".
|
||||
func SetVersion(v string) { version = v }
|
||||
|
||||
// Release is the minimal subset of a Forgejo release we care about.
|
||||
type Release struct {
|
||||
TagName string // e.g. "v0.1.2"
|
||||
Name string // human-readable name (often equals TagName)
|
||||
AssetURL string // browser_download_url of the windows-amd64 exe asset
|
||||
SumsURL string // browser_download_url of SHA256SUMS.txt
|
||||
Prerelease bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// Source abstracts "where do releases come from" so callers can swap real
|
||||
// HTTP sources for fakes in tests.
|
||||
type Source interface {
|
||||
// Latest returns the latest non-draft release, or (nil, nil) if there
|
||||
// are no releases yet (e.g. Forgejo returned 404 from /releases/latest).
|
||||
// Errors should be returned wrapped with %w for any genuine failure.
|
||||
Latest(ctx context.Context) (*Release, error)
|
||||
}
|
||||
|
||||
// CheckForUpdate compares the latest release's tag name against
|
||||
// currentVersion. Returns (release, true, nil) if the latest is strictly
|
||||
// newer; (nil, false, nil) if the current is up-to-date or newer; or an
|
||||
// error wrapping the source's error.
|
||||
//
|
||||
// If currentVersion fails semver.IsValid (e.g. it is "dev"), CheckForUpdate
|
||||
// treats every available release as an update so dev builds get notified.
|
||||
func CheckForUpdate(ctx context.Context, src Source, currentVersion string) (*Release, bool, error) {
|
||||
rel, err := src.Latest(ctx)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("checking latest release: %w", err)
|
||||
}
|
||||
if rel == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
if !isNewer(rel.TagName, currentVersion) {
|
||||
return nil, false, nil
|
||||
}
|
||||
return rel, true, nil
|
||||
}
|
||||
|
||||
// isNewer reports whether tag is strictly newer than current under semver
|
||||
// ordering. Both inputs are normalized to a leading "v". An invalid current
|
||||
// version is treated as "older than anything" so dev builds always update.
|
||||
func isNewer(tag, current string) bool {
|
||||
tagV := ensureV(tag)
|
||||
curV := ensureV(current)
|
||||
if !semver.IsValid(tagV) {
|
||||
// Remote tag is malformed; refuse to update.
|
||||
return false
|
||||
}
|
||||
if !semver.IsValid(curV) {
|
||||
// Current version is "dev" or otherwise invalid → always update.
|
||||
return true
|
||||
}
|
||||
return semver.Compare(tagV, curV) > 0
|
||||
}
|
||||
|
||||
func ensureV(s string) string {
|
||||
if strings.HasPrefix(s, "v") {
|
||||
return s
|
||||
}
|
||||
return "v" + s
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SHA256SUMS parsing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// parseSHA256Sums returns the lowercase-hex SHA-256 digest associated with
|
||||
// targetName from a sha256sum-formatted SUMS file. The format is one line
|
||||
// per file: "<64-hex> <name>" (text mode, two spaces) or "<64-hex> *<name>"
|
||||
// (binary mode, space + asterisk). Both forms are accepted.
|
||||
//
|
||||
// Returns an error if the target is missing or the file is malformed.
|
||||
func parseSHA256Sums(data []byte, targetName string) (string, error) {
|
||||
scanner := bufio.NewScanner(bytes.NewReader(data))
|
||||
// Allow long lines (default is 64 KiB which is plenty for sums files).
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
|
||||
lineNum := 0
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := strings.TrimRight(scanner.Text(), "\r")
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
// Find the first space; everything before is hash, everything after
|
||||
// (with optional leading '*') is the filename.
|
||||
idx := strings.IndexByte(line, ' ')
|
||||
if idx < 0 {
|
||||
return "", fmt.Errorf("malformed SHA256SUMS line %d: missing separator", lineNum)
|
||||
}
|
||||
hashHex := line[:idx]
|
||||
rest := strings.TrimLeft(line[idx+1:], " ")
|
||||
rest = strings.TrimPrefix(rest, "*")
|
||||
if len(hashHex) != 64 {
|
||||
return "", fmt.Errorf("malformed SHA256SUMS line %d: hash is %d chars, want 64", lineNum, len(hashHex))
|
||||
}
|
||||
if _, err := hex.DecodeString(hashHex); err != nil {
|
||||
return "", fmt.Errorf("malformed SHA256SUMS line %d: hash is not valid hex: %w", lineNum, err)
|
||||
}
|
||||
if rest == targetName {
|
||||
return strings.ToLower(hashHex), nil
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", fmt.Errorf("scanning SHA256SUMS: %w", err)
|
||||
}
|
||||
return "", fmt.Errorf("no SHA256SUMS entry for %q", targetName)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Forgejo source
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ForgejoSource talks to a Forgejo instance's Releases API.
|
||||
type ForgejoSource struct {
|
||||
baseURL string // e.g. "https://git.okcu.io"
|
||||
owner string
|
||||
repo string
|
||||
assetPattern string // substring matched against asset name
|
||||
client *http.Client
|
||||
userAgent string
|
||||
}
|
||||
|
||||
// NewForgejoSource constructs a Source that talks to host (the Forgejo
|
||||
// hostname, e.g. "git.okcu.io"). assetPattern is a substring matched against
|
||||
// release asset names (e.g. "windows-amd64.exe").
|
||||
func NewForgejoSource(host, owner, repo, assetPattern string) Source {
|
||||
return &ForgejoSource{
|
||||
baseURL: "https://" + strings.TrimRight(host, "/"),
|
||||
owner: owner,
|
||||
repo: repo,
|
||||
assetPattern: assetPattern,
|
||||
client: defaultHTTPClient(),
|
||||
userAgent: "drover-go/" + version,
|
||||
}
|
||||
}
|
||||
|
||||
// defaultHTTPClient returns a context-aware http.Client with a 15s connect
|
||||
// timeout but no overall request deadline (asset downloads may be large).
|
||||
func defaultHTTPClient() *http.Client {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
tr := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: dialer.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 15 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
return &http.Client{Transport: tr}
|
||||
}
|
||||
|
||||
// forgejoRelease mirrors the JSON returned by /api/v1/repos/.../releases/latest.
|
||||
type forgejoRelease struct {
|
||||
TagName string `json:"tag_name"`
|
||||
Name string `json:"name"`
|
||||
Draft bool `json:"draft"`
|
||||
Prerelease bool `json:"prerelease"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Assets []forgejoAsset `json:"assets"`
|
||||
}
|
||||
|
||||
type forgejoAsset struct {
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
DownloadURL string `json:"browser_download_url"`
|
||||
}
|
||||
|
||||
// Latest fetches the latest release. Returns (nil, nil) on 404 (no releases).
|
||||
func (s *ForgejoSource) Latest(ctx context.Context) (*Release, error) {
|
||||
endpoint, err := url.JoinPath(s.baseURL, "api", "v1", "repos", s.owner, s.repo, "releases", "latest")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("building releases URL: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("building request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", s.userAgent)
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GET %s: %w", endpoint, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusNotFound:
|
||||
// No public releases exist.
|
||||
return nil, nil
|
||||
case resp.StatusCode < 200 || resp.StatusCode >= 300:
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return nil, fmt.Errorf("GET %s: HTTP %d: %s", endpoint, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var raw forgejoRelease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil {
|
||||
return nil, fmt.Errorf("decoding release JSON: %w", err)
|
||||
}
|
||||
if raw.Draft {
|
||||
// Forgejo's /latest is supposed to skip drafts, but be defensive.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
assetURL, assetName := pickAsset(raw.Assets, s.assetPattern)
|
||||
if assetURL == "" {
|
||||
return nil, fmt.Errorf("release %s has no asset matching %q", raw.TagName, s.assetPattern)
|
||||
}
|
||||
sumsURL := pickAssetByExactName(raw.Assets, "SHA256SUMS.txt")
|
||||
if sumsURL == "" {
|
||||
return nil, fmt.Errorf("release %s is missing SHA256SUMS.txt asset", raw.TagName)
|
||||
}
|
||||
|
||||
_ = assetName // currently unused outside the source; reserved for future logging
|
||||
return &Release{
|
||||
TagName: raw.TagName,
|
||||
Name: raw.Name,
|
||||
AssetURL: assetURL,
|
||||
SumsURL: sumsURL,
|
||||
Prerelease: raw.Prerelease,
|
||||
CreatedAt: raw.CreatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func pickAsset(assets []forgejoAsset, pattern string) (string, string) {
|
||||
for _, a := range assets {
|
||||
if strings.Contains(a.Name, pattern) {
|
||||
return a.DownloadURL, a.Name
|
||||
}
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func pickAssetByExactName(assets []forgejoAsset, name string) string {
|
||||
for _, a := range assets {
|
||||
if a.Name == name {
|
||||
return a.DownloadURL
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ApplyUpdate
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ApplyUpdate downloads the release's asset + SHA256SUMS.txt, verifies the
|
||||
// asset against its sha256 entry, and atomically replaces the current
|
||||
// executable using github.com/minio/selfupdate.
|
||||
//
|
||||
// progress, if non-nil, is called periodically with (downloaded, total)
|
||||
// while the asset is being downloaded. total is set from the Content-Length
|
||||
// header and may be 0 if the server omits it.
|
||||
func ApplyUpdate(ctx context.Context, rel *Release, progress func(downloaded, total int64)) error {
|
||||
if rel == nil {
|
||||
return errors.New("apply update: release is nil")
|
||||
}
|
||||
if rel.AssetURL == "" || rel.SumsURL == "" {
|
||||
return errors.New("apply update: release is missing AssetURL or SumsURL")
|
||||
}
|
||||
|
||||
assetName := assetNameFromURL(rel.AssetURL)
|
||||
if assetName == "" {
|
||||
return fmt.Errorf("apply update: cannot derive asset filename from URL %q", rel.AssetURL)
|
||||
}
|
||||
|
||||
client := defaultHTTPClient()
|
||||
binary, err := downloadAndVerify(ctx, rel, assetName, client, "drover-go/"+version, progress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply update: %w", err)
|
||||
}
|
||||
|
||||
// selfupdate.Apply takes an io.Reader and atomically swaps the running
|
||||
// binary. We pass a bytes.Reader since we already have the verified
|
||||
// bytes in memory.
|
||||
if err := selfupdate.Apply(bytes.NewReader(binary), selfupdate.Options{}); err != nil {
|
||||
// On failure, selfupdate may leave a .old file; surface that fact
|
||||
// to the caller via the wrapped error.
|
||||
if rerr := selfupdate.RollbackError(err); rerr != nil {
|
||||
return fmt.Errorf("apply update: replacing executable failed and rollback failed: %w (rollback: %v)", err, rerr)
|
||||
}
|
||||
return fmt.Errorf("apply update: replacing executable: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// assetNameFromURL returns the last path segment of u (used to look the file
|
||||
// up in SHA256SUMS.txt). Returns "" if u does not parse.
|
||||
func assetNameFromURL(u string) string {
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return path.Base(parsed.Path)
|
||||
}
|
||||
|
||||
// downloadAndVerify is the I/O-heavy core of ApplyUpdate, factored out so
|
||||
// it's testable without invoking selfupdate.Apply (which would replace the
|
||||
// running test binary). Returns the verified asset bytes.
|
||||
func downloadAndVerify(
|
||||
ctx context.Context,
|
||||
rel *Release,
|
||||
assetName string,
|
||||
client *http.Client,
|
||||
userAgent string,
|
||||
progress func(downloaded, total int64),
|
||||
) ([]byte, error) {
|
||||
// 1. Fetch SHA256SUMS.txt first — fail fast if it's missing/wrong before
|
||||
// we spend time on a possibly-large binary.
|
||||
sumsBody, err := httpGetAll(ctx, client, rel.SumsURL, userAgent)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("downloading SHA256SUMS.txt: %w", err)
|
||||
}
|
||||
wantHashHex, err := parseSHA256Sums(sumsBody, assetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading SHA256SUMS.txt: %w", err)
|
||||
}
|
||||
wantHash, err := hex.DecodeString(wantHashHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding expected sha256 hex: %w", err)
|
||||
}
|
||||
|
||||
// 2. Fetch the asset, hashing as we go.
|
||||
binBody, err := httpGetWithProgress(ctx, client, rel.AssetURL, userAgent, progress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("downloading asset %s: %w", assetName, err)
|
||||
}
|
||||
|
||||
// 3. Verify.
|
||||
gotHash := sha256.Sum256(binBody)
|
||||
if !bytes.Equal(gotHash[:], wantHash) {
|
||||
return nil, fmt.Errorf("sha256 checksum mismatch for %s: got %x, want %s", assetName, gotHash[:], wantHashHex)
|
||||
}
|
||||
return binBody, nil
|
||||
}
|
||||
|
||||
func httpGetAll(ctx context.Context, client *http.Client, rawurl, userAgent string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawurl, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("building request: %w", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return nil, fmt.Errorf("GET %s: HTTP %d: %s", rawurl, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
// Cap at 1 MiB — SHA256SUMS.txt is tiny; this is for the sums path only.
|
||||
// Asset downloads use httpGetWithProgress.
|
||||
return io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
|
||||
}
|
||||
|
||||
func httpGetWithProgress(
|
||||
ctx context.Context,
|
||||
client *http.Client,
|
||||
rawurl, userAgent string,
|
||||
progress func(downloaded, total int64),
|
||||
) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawurl, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("building request: %w", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return nil, fmt.Errorf("GET %s: HTTP %d: %s", rawurl, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
total := resp.ContentLength // -1 if unknown
|
||||
var totalReport int64
|
||||
if total > 0 {
|
||||
totalReport = total
|
||||
}
|
||||
|
||||
// Cap a reasonable absolute ceiling so a malicious server can't OOM us.
|
||||
// 256 MiB is far above any realistic drover-go binary.
|
||||
const maxBytes = 256 * 1024 * 1024
|
||||
body := resp.Body
|
||||
if total > 0 && total > maxBytes {
|
||||
return nil, fmt.Errorf("asset Content-Length %d exceeds %d byte cap", total, maxBytes)
|
||||
}
|
||||
|
||||
pr := &progressReader{
|
||||
r: io.LimitReader(body, maxBytes),
|
||||
total: totalReport,
|
||||
callback: progress,
|
||||
}
|
||||
buf, err := io.ReadAll(pr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading body: %w", err)
|
||||
}
|
||||
// Final progress tick if total was unknown (so callers get a "done" signal).
|
||||
if progress != nil && totalReport == 0 {
|
||||
progress(int64(len(buf)), int64(len(buf)))
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// progressReader wraps an io.Reader and emits progress updates after each
|
||||
// successful Read. It is allocation-light and never blocks; callbacks are
|
||||
// invoked synchronously from Read so they should be cheap.
|
||||
type progressReader struct {
|
||||
r io.Reader
|
||||
total int64
|
||||
read int64
|
||||
callback func(downloaded, total int64)
|
||||
}
|
||||
|
||||
func (p *progressReader) Read(b []byte) (int, error) {
|
||||
n, err := p.r.Read(b)
|
||||
if n > 0 {
|
||||
p.read += int64(n)
|
||||
if p.callback != nil {
|
||||
p.callback(p.read, p.total)
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
@@ -0,0 +1,612 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fakeSource is an in-memory Source implementation for tests.
|
||||
type fakeSource struct {
|
||||
rel *Release
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeSource) Latest(ctx context.Context) (*Release, error) {
|
||||
if f.err != nil {
|
||||
return nil, f.err
|
||||
}
|
||||
return f.rel, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CheckForUpdate
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCheckForUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rel := func(tag string) *Release {
|
||||
return &Release{TagName: tag, Name: tag, AssetURL: "u", SumsURL: "s"}
|
||||
}
|
||||
sentinel := errors.New("network down")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
current string
|
||||
src Source
|
||||
wantHasUpdate bool
|
||||
wantNilRelease bool
|
||||
wantErrIs error
|
||||
}{
|
||||
{
|
||||
name: "newer release available",
|
||||
current: "0.1.0",
|
||||
src: &fakeSource{rel: rel("v0.2.0")},
|
||||
wantHasUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "newer release with v-prefixed current version",
|
||||
current: "v0.1.0",
|
||||
src: &fakeSource{rel: rel("v0.2.0")},
|
||||
wantHasUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "same version — no update",
|
||||
current: "0.2.0",
|
||||
src: &fakeSource{rel: rel("v0.2.0")},
|
||||
wantHasUpdate: false,
|
||||
wantNilRelease: true,
|
||||
},
|
||||
{
|
||||
name: "older release on remote — no update",
|
||||
current: "0.5.0",
|
||||
src: &fakeSource{rel: rel("v0.2.0")},
|
||||
wantHasUpdate: false,
|
||||
wantNilRelease: true,
|
||||
},
|
||||
{
|
||||
name: "rc < final (semver pre-release ordering)",
|
||||
current: "0.1.0-rc.2",
|
||||
src: &fakeSource{rel: rel("v0.1.0")},
|
||||
wantHasUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "final >= rc",
|
||||
current: "0.1.0",
|
||||
src: &fakeSource{rel: rel("v0.1.0-rc.2")},
|
||||
wantHasUpdate: false,
|
||||
wantNilRelease: true,
|
||||
},
|
||||
{
|
||||
name: "dev current — always update",
|
||||
current: "dev",
|
||||
src: &fakeSource{rel: rel("v0.1.0")},
|
||||
wantHasUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "garbage current — always update (treated as invalid)",
|
||||
current: "not-a-version",
|
||||
src: &fakeSource{rel: rel("v0.1.0")},
|
||||
wantHasUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "source returns nil release — no update",
|
||||
current: "0.1.0",
|
||||
src: &fakeSource{rel: nil},
|
||||
wantHasUpdate: false,
|
||||
wantNilRelease: true,
|
||||
},
|
||||
{
|
||||
name: "source error propagates wrapped",
|
||||
current: "0.1.0",
|
||||
src: &fakeSource{err: sentinel},
|
||||
wantErrIs: sentinel,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
got, has, err := CheckForUpdate(ctx, tc.src, tc.current)
|
||||
if tc.wantErrIs != nil {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error wrapping %v, got nil", tc.wantErrIs)
|
||||
}
|
||||
if !errors.Is(err, tc.wantErrIs) {
|
||||
t.Fatalf("err = %v; expected to wrap %v", err, tc.wantErrIs)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if has != tc.wantHasUpdate {
|
||||
t.Fatalf("hasUpdate = %v; want %v", has, tc.wantHasUpdate)
|
||||
}
|
||||
if tc.wantNilRelease && got != nil {
|
||||
t.Fatalf("expected nil release; got %+v", got)
|
||||
}
|
||||
if !tc.wantNilRelease && tc.wantHasUpdate && got == nil {
|
||||
t.Fatalf("expected non-nil release on update; got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// parseSHA256Sums
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestParseSHA256Sums(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const a = "8da085332782708d8767bcace5327a6ec7283c17cfb85e40b03cd2323a90ddc2"
|
||||
const b = "c1e060ee19444a259b2162f8af0f3fe8c4428a1c6f694dce20de194ac8d7d9a2"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
wantName string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "standard text-mode line",
|
||||
input: a + " drover-v0.1.0-windows-amd64.exe\n",
|
||||
want: a,
|
||||
wantName: "drover-v0.1.0-windows-amd64.exe",
|
||||
},
|
||||
{
|
||||
name: "binary-mode asterisk prefix",
|
||||
input: b + " *drover-v0.1.0-setup.exe\n",
|
||||
want: b,
|
||||
wantName: "drover-v0.1.0-setup.exe",
|
||||
},
|
||||
{
|
||||
name: "multi-line, target on second",
|
||||
input: a + " some-other-file.txt\n" +
|
||||
b + " drover-v0.1.0-setup.exe\n",
|
||||
want: b,
|
||||
wantName: "drover-v0.1.0-setup.exe",
|
||||
},
|
||||
{
|
||||
name: "multi-line, blank lines and CRLF tolerated",
|
||||
input: "\r\n" +
|
||||
a + " drover-v0.1.0-windows-amd64.exe\r\n" +
|
||||
"\n",
|
||||
want: a,
|
||||
wantName: "drover-v0.1.0-windows-amd64.exe",
|
||||
},
|
||||
{
|
||||
name: "missing entry",
|
||||
input: a + " some-other-file.txt\n",
|
||||
wantName: "drover-v0.1.0-windows-amd64.exe",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed line — no separator",
|
||||
input: a + "drover-v0.1.0-windows-amd64.exe\n",
|
||||
wantName: "drover-v0.1.0-windows-amd64.exe",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed line — bad hex length",
|
||||
input: "deadbeef drover-v0.1.0-windows-amd64.exe\n",
|
||||
wantName: "drover-v0.1.0-windows-amd64.exe",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
wantName: "drover-v0.1.0-windows-amd64.exe",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := parseSHA256Sums([]byte(tc.input), tc.wantName)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error; got hash %q", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Fatalf("hash = %q; want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ForgejoSource.Latest
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const sampleReleaseJSON = `{
|
||||
"tag_name": "v0.1.0",
|
||||
"name": "v0.1.0",
|
||||
"draft": false,
|
||||
"prerelease": false,
|
||||
"created_at": "2026-04-25T13:35:33+03:00",
|
||||
"assets": [
|
||||
{
|
||||
"name": "drover-v0.1.0-windows-amd64.exe",
|
||||
"size": 5800000,
|
||||
"browser_download_url": "%s/drover-v0.1.0-windows-amd64.exe"
|
||||
},
|
||||
{
|
||||
"name": "SHA256SUMS.txt",
|
||||
"size": 200,
|
||||
"browser_download_url": "%s/SHA256SUMS.txt"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
func TestForgejoSource_Latest_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var srv *httptest.Server
|
||||
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/v1/repos/root/drover-go/releases/latest" {
|
||||
t.Errorf("unexpected request path: %s", r.URL.Path)
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if ua := r.Header.Get("User-Agent"); !strings.HasPrefix(ua, "drover-go/") {
|
||||
t.Errorf("expected User-Agent drover-go/...; got %q", ua)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintf(w, sampleReleaseJSON, srv.URL, srv.URL)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
src := &ForgejoSource{
|
||||
baseURL: srv.URL,
|
||||
owner: "root",
|
||||
repo: "drover-go",
|
||||
assetPattern: "windows-amd64.exe",
|
||||
client: srv.Client(),
|
||||
userAgent: "drover-go/test",
|
||||
}
|
||||
|
||||
rel, err := src.Latest(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Latest err = %v", err)
|
||||
}
|
||||
if rel == nil {
|
||||
t.Fatal("nil release returned")
|
||||
}
|
||||
if rel.TagName != "v0.1.0" {
|
||||
t.Errorf("TagName = %q; want v0.1.0", rel.TagName)
|
||||
}
|
||||
wantAsset := srv.URL + "/drover-v0.1.0-windows-amd64.exe"
|
||||
if rel.AssetURL != wantAsset {
|
||||
t.Errorf("AssetURL = %q; want %q", rel.AssetURL, wantAsset)
|
||||
}
|
||||
wantSums := srv.URL + "/SHA256SUMS.txt"
|
||||
if rel.SumsURL != wantSums {
|
||||
t.Errorf("SumsURL = %q; want %q", rel.SumsURL, wantSums)
|
||||
}
|
||||
if rel.Prerelease {
|
||||
t.Error("Prerelease = true; want false")
|
||||
}
|
||||
if rel.CreatedAt.IsZero() {
|
||||
t.Error("CreatedAt is zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgejoSource_Latest_404_NoUpdates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
src := &ForgejoSource{
|
||||
baseURL: srv.URL,
|
||||
owner: "root",
|
||||
repo: "drover-go",
|
||||
assetPattern: "windows-amd64.exe",
|
||||
client: srv.Client(),
|
||||
userAgent: "drover-go/test",
|
||||
}
|
||||
|
||||
rel, err := src.Latest(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error on 404; got %v", err)
|
||||
}
|
||||
if rel != nil {
|
||||
t.Fatalf("expected nil release on 404; got %+v", rel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgejoSource_Latest_500_Error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "boom", http.StatusInternalServerError)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
src := &ForgejoSource{
|
||||
baseURL: srv.URL,
|
||||
owner: "root",
|
||||
repo: "drover-go",
|
||||
assetPattern: "windows-amd64.exe",
|
||||
client: srv.Client(),
|
||||
userAgent: "drover-go/test",
|
||||
}
|
||||
|
||||
_, err := src.Latest(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error on 500; got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "500") {
|
||||
t.Errorf("expected error to mention HTTP 500; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgejoSource_Latest_AssetMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Sample with only SHA256SUMS, no exe asset.
|
||||
body := `{
|
||||
"tag_name": "v0.1.0",
|
||||
"name": "v0.1.0",
|
||||
"draft": false,
|
||||
"prerelease": false,
|
||||
"created_at": "2026-04-25T13:35:33+03:00",
|
||||
"assets": [
|
||||
{"name":"SHA256SUMS.txt","size":1,"browser_download_url":"https://example.invalid/s"}
|
||||
]
|
||||
}`
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
src := &ForgejoSource{
|
||||
baseURL: srv.URL,
|
||||
owner: "root",
|
||||
repo: "drover-go",
|
||||
assetPattern: "windows-amd64.exe",
|
||||
client: srv.Client(),
|
||||
userAgent: "drover-go/test",
|
||||
}
|
||||
_, err := src.Latest(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error when matching asset is missing; got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgejoSource_Latest_SumsMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := `{
|
||||
"tag_name": "v0.1.0",
|
||||
"name": "v0.1.0",
|
||||
"draft": false,
|
||||
"prerelease": false,
|
||||
"created_at": "2026-04-25T13:35:33+03:00",
|
||||
"assets": [
|
||||
{"name":"drover-v0.1.0-windows-amd64.exe","size":1,"browser_download_url":"https://example.invalid/x"}
|
||||
]
|
||||
}`
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
src := &ForgejoSource{
|
||||
baseURL: srv.URL,
|
||||
owner: "root",
|
||||
repo: "drover-go",
|
||||
assetPattern: "windows-amd64.exe",
|
||||
client: srv.Client(),
|
||||
userAgent: "drover-go/test",
|
||||
}
|
||||
_, err := src.Latest(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error when SHA256SUMS.txt is missing; got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "SHA256SUMS.txt") {
|
||||
t.Errorf("expected error to mention SHA256SUMS.txt; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForgejoSource_BuildsURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
src := NewForgejoSource("git.okcu.io", "root", "drover-go", "windows-amd64.exe").(*ForgejoSource)
|
||||
if want := "https://git.okcu.io"; src.baseURL != want {
|
||||
t.Errorf("baseURL = %q; want %q", src.baseURL, want)
|
||||
}
|
||||
if src.owner != "root" || src.repo != "drover-go" {
|
||||
t.Errorf("owner/repo wrong: %s/%s", src.owner, src.repo)
|
||||
}
|
||||
if src.assetPattern != "windows-amd64.exe" {
|
||||
t.Errorf("assetPattern = %q", src.assetPattern)
|
||||
}
|
||||
if src.client == nil {
|
||||
t.Error("client must not be nil")
|
||||
}
|
||||
if !strings.HasPrefix(src.userAgent, "drover-go/") {
|
||||
t.Errorf("userAgent = %q; want drover-go/...", src.userAgent)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// downloadAndVerify (the testable extract from ApplyUpdate)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestDownloadAndVerify_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exeData := []byte("MZ\x90fake-windows-binary-content")
|
||||
hash := sha256.Sum256(exeData)
|
||||
hexHash := hex.EncodeToString(hash[:])
|
||||
|
||||
exeName := "drover-v0.1.0-windows-amd64.exe"
|
||||
sumsBody := hexHash + " " + exeName + "\n"
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/exe":
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(exeData)))
|
||||
_, _ = w.Write(exeData)
|
||||
case "/sums":
|
||||
_, _ = io.WriteString(w, sumsBody)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
rel := &Release{
|
||||
TagName: "v0.1.0",
|
||||
Name: "v0.1.0",
|
||||
AssetURL: srv.URL + "/exe",
|
||||
SumsURL: srv.URL + "/sums",
|
||||
}
|
||||
|
||||
var pcDownloaded, pcTotal int64
|
||||
bin, err := downloadAndVerify(context.Background(), rel, exeName, srv.Client(), "drover-go/test", func(d, total int64) {
|
||||
pcDownloaded = d
|
||||
pcTotal = total
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("downloadAndVerify err = %v", err)
|
||||
}
|
||||
if !equalBytes(bin, exeData) {
|
||||
t.Fatalf("downloaded bytes do not match source")
|
||||
}
|
||||
if pcTotal != int64(len(exeData)) {
|
||||
t.Errorf("progress total = %d; want %d", pcTotal, len(exeData))
|
||||
}
|
||||
if pcDownloaded != int64(len(exeData)) {
|
||||
t.Errorf("progress downloaded final = %d; want %d", pcDownloaded, len(exeData))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAndVerify_SHAMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exeData := []byte("real-binary-bytes")
|
||||
exeName := "drover-v0.1.0-windows-amd64.exe"
|
||||
wrongHash := strings.Repeat("0", 64)
|
||||
sumsBody := wrongHash + " " + exeName + "\n"
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/exe":
|
||||
_, _ = w.Write(exeData)
|
||||
case "/sums":
|
||||
_, _ = io.WriteString(w, sumsBody)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
rel := &Release{
|
||||
TagName: "v0.1.0",
|
||||
Name: "v0.1.0",
|
||||
AssetURL: srv.URL + "/exe",
|
||||
SumsURL: srv.URL + "/sums",
|
||||
}
|
||||
|
||||
_, err := downloadAndVerify(context.Background(), rel, exeName, srv.Client(), "drover-go/test", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected SHA mismatch error; got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "sha256") && !strings.Contains(err.Error(), "checksum") {
|
||||
t.Errorf("expected error to mention sha256/checksum; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAndVerify_AssetHTTP404(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
rel := &Release{
|
||||
TagName: "v0.1.0",
|
||||
AssetURL: srv.URL + "/exe",
|
||||
SumsURL: srv.URL + "/sums",
|
||||
}
|
||||
_, err := downloadAndVerify(context.Background(), rel, "drover-v0.1.0-windows-amd64.exe", srv.Client(), "drover-go/test", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error on 404; got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAndVerify_ContextCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Block forever — handler unblocks when client closes connection
|
||||
// (which happens when the test's ctx is canceled).
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-r.Context().Done():
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
rel := &Release{AssetURL: srv.URL + "/exe", SumsURL: srv.URL + "/sums"}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // cancel immediately
|
||||
|
||||
_, err := downloadAndVerify(ctx, rel, "x", srv.Client(), "drover-go/test", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error from canceled context; got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// equalBytes is here to avoid a bytes import for one comparison.
|
||||
func equalBytes(a, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Sanity check that the Forgejo schema fields we care about decode correctly.
|
||||
func TestReleaseJSONShape(t *testing.T) {
|
||||
t.Parallel()
|
||||
var rel forgejoRelease
|
||||
body := fmt.Sprintf(sampleReleaseJSON, "https://example.invalid", "https://example.invalid")
|
||||
if err := json.Unmarshal([]byte(body), &rel); err != nil {
|
||||
t.Fatalf("decode err: %v", err)
|
||||
}
|
||||
if len(rel.Assets) != 2 {
|
||||
t.Fatalf("got %d assets; want 2", len(rel.Assets))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user