diff --git a/cmd/drover/main.go b/cmd/drover/main.go index b855e92..cdb8f27 100644 --- a/cmd/drover/main.go +++ b/cmd/drover/main.go @@ -6,6 +6,8 @@ import ( "os" "github.com/spf13/cobra" + + "git.okcu.io/root/drover-go/internal/updater" ) // Build-time variables, populated via -ldflags "-X main.Version=... -X main.Commit=... -X main.BuildDate=...". @@ -20,6 +22,10 @@ var ( var configPath string func main() { + // Inject our build version so the updater package can stamp it on the + // User-Agent header it sends to git.okcu.io. + updater.SetVersion(Version) + if err := newRootCmd().Execute(); err != nil { // Cobra already prints the error; just exit non-zero. os.Exit(1) @@ -64,8 +70,32 @@ func newUpdateCmd() *cobra.Command { Use: "update", Short: "Self-update via the Forgejo Releases API", RunE: func(cmd *cobra.Command, args []string) error { - fmt.Fprintln(cmd.OutOrStdout(), "TODO: update check") - _ = checkOnly // wired up for the upcoming implementation + ctx := cmd.Context() + out := cmd.OutOrStdout() + + src := updater.NewForgejoSource("git.okcu.io", "root", "drover-go", "windows-amd64.exe") + rel, hasUpdate, err := updater.CheckForUpdate(ctx, src, Version) + if err != nil { + return fmt.Errorf("check for update: %w", err) + } + if !hasUpdate { + fmt.Fprintln(out, "No updates available") + return nil + } + fmt.Fprintf(out, "Update available: %s (current v%s)\n", rel.TagName, Version) + if checkOnly { + return nil + } + + fmt.Fprintln(out, "Downloading...") + if err := updater.ApplyUpdate(ctx, rel, func(d, t int64) { + if t > 0 { + fmt.Fprintf(out, "\r%d/%d bytes", d, t) + } + }); err != nil { + return fmt.Errorf("apply update: %w", err) + } + fmt.Fprintln(out, "\nUpdate applied. Restart drover.") return nil }, } diff --git a/go.mod b/go.mod index e57c872..38ed9c4 100644 --- a/go.mod +++ b/go.mod @@ -2,9 +2,16 @@ module git.okcu.io/root/drover-go go 1.23 -require github.com/spf13/cobra v1.10.2 +require ( + github.com/minio/selfupdate v0.6.0 + github.com/spf13/cobra v1.10.2 + golang.org/x/mod v0.21.0 +) require ( + aead.dev/minisign v0.2.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/spf13/pflag v1.0.9 // indirect + golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b // indirect + golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect ) diff --git a/go.sum b/go.sum index a6ee3e0..afc37de 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,34 @@ +aead.dev/minisign v0.2.0 h1:kAWrq/hBRu4AARY6AlciO83xhNnW9UaC8YipS2uhLPk= +aead.dev/minisign v0.2.0/go.mod h1:zdq6LdSd9TbuSxchxwhpA9zEb9YXcVGoE8JakuiGaIQ= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU= +github.com/minio/selfupdate v0.6.0/go.mod h1:bO02GTIPCMQFTEvE5h4DjYB58bCoZ35XLeBf0buTDdM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b h1:QAqMVf3pSa6eeTsuklijukjXBlj7Es2QQplab+/RbQ4= +golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/updater/doc.go b/internal/updater/doc.go deleted file mode 100644 index b76d752..0000000 --- a/internal/updater/doc.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package updater performs self-update via the Forgejo Releases API. -package updater diff --git a/internal/updater/updater.go b/internal/updater/updater.go new file mode 100644 index 0000000..9bd7ede --- /dev/null +++ b/internal/updater/updater.go @@ -0,0 +1,479 @@ +// Package updater performs self-update via the Forgejo Releases API. +// +// The package is split into: +// +// - Source: an interface that fetches "the latest release" from somewhere. +// ForgejoSource is the production implementation; tests use a fake. +// - CheckForUpdate: compares a release tag against a current version using +// semver-aware ordering (golang.org/x/mod/semver). +// - ApplyUpdate: downloads the asset + SHA256SUMS.txt, verifies the binary +// against its hash, and atomically replaces the running executable via +// github.com/minio/selfupdate. +// +// The package is GUI- and CLI-framework-agnostic; importers should not pull +// in cobra/Wails as a transitive dependency. +package updater + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "path" + "strings" + "time" + + "github.com/minio/selfupdate" + "golang.org/x/mod/semver" +) + +// version is exposed for tests/User-Agent. Caller-overridable. +var version = "dev" + +// SetVersion lets the calling binary inject its build version into the +// User-Agent header used by HTTP requests. Optional; defaults to "dev". +func SetVersion(v string) { version = v } + +// Release is the minimal subset of a Forgejo release we care about. +type Release struct { + TagName string // e.g. "v0.1.2" + Name string // human-readable name (often equals TagName) + AssetURL string // browser_download_url of the windows-amd64 exe asset + SumsURL string // browser_download_url of SHA256SUMS.txt + Prerelease bool + CreatedAt time.Time +} + +// Source abstracts "where do releases come from" so callers can swap real +// HTTP sources for fakes in tests. +type Source interface { + // Latest returns the latest non-draft release, or (nil, nil) if there + // are no releases yet (e.g. Forgejo returned 404 from /releases/latest). + // Errors should be returned wrapped with %w for any genuine failure. + Latest(ctx context.Context) (*Release, error) +} + +// CheckForUpdate compares the latest release's tag name against +// currentVersion. Returns (release, true, nil) if the latest is strictly +// newer; (nil, false, nil) if the current is up-to-date or newer; or an +// error wrapping the source's error. +// +// If currentVersion fails semver.IsValid (e.g. it is "dev"), CheckForUpdate +// treats every available release as an update so dev builds get notified. +func CheckForUpdate(ctx context.Context, src Source, currentVersion string) (*Release, bool, error) { + rel, err := src.Latest(ctx) + if err != nil { + return nil, false, fmt.Errorf("checking latest release: %w", err) + } + if rel == nil { + return nil, false, nil + } + if !isNewer(rel.TagName, currentVersion) { + return nil, false, nil + } + return rel, true, nil +} + +// isNewer reports whether tag is strictly newer than current under semver +// ordering. Both inputs are normalized to a leading "v". An invalid current +// version is treated as "older than anything" so dev builds always update. +func isNewer(tag, current string) bool { + tagV := ensureV(tag) + curV := ensureV(current) + if !semver.IsValid(tagV) { + // Remote tag is malformed; refuse to update. + return false + } + if !semver.IsValid(curV) { + // Current version is "dev" or otherwise invalid → always update. + return true + } + return semver.Compare(tagV, curV) > 0 +} + +func ensureV(s string) string { + if strings.HasPrefix(s, "v") { + return s + } + return "v" + s +} + +// --------------------------------------------------------------------------- +// SHA256SUMS parsing +// --------------------------------------------------------------------------- + +// parseSHA256Sums returns the lowercase-hex SHA-256 digest associated with +// targetName from a sha256sum-formatted SUMS file. The format is one line +// per file: "<64-hex> " (text mode, two spaces) or "<64-hex> *" +// (binary mode, space + asterisk). Both forms are accepted. +// +// Returns an error if the target is missing or the file is malformed. +func parseSHA256Sums(data []byte, targetName string) (string, error) { + scanner := bufio.NewScanner(bytes.NewReader(data)) + // Allow long lines (default is 64 KiB which is plenty for sums files). + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + + lineNum := 0 + for scanner.Scan() { + lineNum++ + line := strings.TrimRight(scanner.Text(), "\r") + line = strings.TrimSpace(line) + if line == "" { + continue + } + // Find the first space; everything before is hash, everything after + // (with optional leading '*') is the filename. + idx := strings.IndexByte(line, ' ') + if idx < 0 { + return "", fmt.Errorf("malformed SHA256SUMS line %d: missing separator", lineNum) + } + hashHex := line[:idx] + rest := strings.TrimLeft(line[idx+1:], " ") + rest = strings.TrimPrefix(rest, "*") + if len(hashHex) != 64 { + return "", fmt.Errorf("malformed SHA256SUMS line %d: hash is %d chars, want 64", lineNum, len(hashHex)) + } + if _, err := hex.DecodeString(hashHex); err != nil { + return "", fmt.Errorf("malformed SHA256SUMS line %d: hash is not valid hex: %w", lineNum, err) + } + if rest == targetName { + return strings.ToLower(hashHex), nil + } + } + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("scanning SHA256SUMS: %w", err) + } + return "", fmt.Errorf("no SHA256SUMS entry for %q", targetName) +} + +// --------------------------------------------------------------------------- +// Forgejo source +// --------------------------------------------------------------------------- + +// ForgejoSource talks to a Forgejo instance's Releases API. +type ForgejoSource struct { + baseURL string // e.g. "https://git.okcu.io" + owner string + repo string + assetPattern string // substring matched against asset name + client *http.Client + userAgent string +} + +// NewForgejoSource constructs a Source that talks to host (the Forgejo +// hostname, e.g. "git.okcu.io"). assetPattern is a substring matched against +// release asset names (e.g. "windows-amd64.exe"). +func NewForgejoSource(host, owner, repo, assetPattern string) Source { + return &ForgejoSource{ + baseURL: "https://" + strings.TrimRight(host, "/"), + owner: owner, + repo: repo, + assetPattern: assetPattern, + client: defaultHTTPClient(), + userAgent: "drover-go/" + version, + } +} + +// defaultHTTPClient returns a context-aware http.Client with a 15s connect +// timeout but no overall request deadline (asset downloads may be large). +func defaultHTTPClient() *http.Client { + dialer := &net.Dialer{ + Timeout: 15 * time.Second, + KeepAlive: 30 * time.Second, + } + tr := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 15 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + return &http.Client{Transport: tr} +} + +// forgejoRelease mirrors the JSON returned by /api/v1/repos/.../releases/latest. +type forgejoRelease struct { + TagName string `json:"tag_name"` + Name string `json:"name"` + Draft bool `json:"draft"` + Prerelease bool `json:"prerelease"` + CreatedAt time.Time `json:"created_at"` + Assets []forgejoAsset `json:"assets"` +} + +type forgejoAsset struct { + Name string `json:"name"` + Size int64 `json:"size"` + DownloadURL string `json:"browser_download_url"` +} + +// Latest fetches the latest release. Returns (nil, nil) on 404 (no releases). +func (s *ForgejoSource) Latest(ctx context.Context) (*Release, error) { + endpoint, err := url.JoinPath(s.baseURL, "api", "v1", "repos", s.owner, s.repo, "releases", "latest") + if err != nil { + return nil, fmt.Errorf("building releases URL: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, fmt.Errorf("building request: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", s.userAgent) + + resp, err := s.client.Do(req) + if err != nil { + return nil, fmt.Errorf("GET %s: %w", endpoint, err) + } + defer resp.Body.Close() + + switch { + case resp.StatusCode == http.StatusNotFound: + // No public releases exist. + return nil, nil + case resp.StatusCode < 200 || resp.StatusCode >= 300: + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return nil, fmt.Errorf("GET %s: HTTP %d: %s", endpoint, resp.StatusCode, strings.TrimSpace(string(body))) + } + + var raw forgejoRelease + if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil { + return nil, fmt.Errorf("decoding release JSON: %w", err) + } + if raw.Draft { + // Forgejo's /latest is supposed to skip drafts, but be defensive. + return nil, nil + } + + assetURL, assetName := pickAsset(raw.Assets, s.assetPattern) + if assetURL == "" { + return nil, fmt.Errorf("release %s has no asset matching %q", raw.TagName, s.assetPattern) + } + sumsURL := pickAssetByExactName(raw.Assets, "SHA256SUMS.txt") + if sumsURL == "" { + return nil, fmt.Errorf("release %s is missing SHA256SUMS.txt asset", raw.TagName) + } + + _ = assetName // currently unused outside the source; reserved for future logging + return &Release{ + TagName: raw.TagName, + Name: raw.Name, + AssetURL: assetURL, + SumsURL: sumsURL, + Prerelease: raw.Prerelease, + CreatedAt: raw.CreatedAt, + }, nil +} + +func pickAsset(assets []forgejoAsset, pattern string) (string, string) { + for _, a := range assets { + if strings.Contains(a.Name, pattern) { + return a.DownloadURL, a.Name + } + } + return "", "" +} + +func pickAssetByExactName(assets []forgejoAsset, name string) string { + for _, a := range assets { + if a.Name == name { + return a.DownloadURL + } + } + return "" +} + +// --------------------------------------------------------------------------- +// ApplyUpdate +// --------------------------------------------------------------------------- + +// ApplyUpdate downloads the release's asset + SHA256SUMS.txt, verifies the +// asset against its sha256 entry, and atomically replaces the current +// executable using github.com/minio/selfupdate. +// +// progress, if non-nil, is called periodically with (downloaded, total) +// while the asset is being downloaded. total is set from the Content-Length +// header and may be 0 if the server omits it. +func ApplyUpdate(ctx context.Context, rel *Release, progress func(downloaded, total int64)) error { + if rel == nil { + return errors.New("apply update: release is nil") + } + if rel.AssetURL == "" || rel.SumsURL == "" { + return errors.New("apply update: release is missing AssetURL or SumsURL") + } + + assetName := assetNameFromURL(rel.AssetURL) + if assetName == "" { + return fmt.Errorf("apply update: cannot derive asset filename from URL %q", rel.AssetURL) + } + + client := defaultHTTPClient() + binary, err := downloadAndVerify(ctx, rel, assetName, client, "drover-go/"+version, progress) + if err != nil { + return fmt.Errorf("apply update: %w", err) + } + + // selfupdate.Apply takes an io.Reader and atomically swaps the running + // binary. We pass a bytes.Reader since we already have the verified + // bytes in memory. + if err := selfupdate.Apply(bytes.NewReader(binary), selfupdate.Options{}); err != nil { + // On failure, selfupdate may leave a .old file; surface that fact + // to the caller via the wrapped error. + if rerr := selfupdate.RollbackError(err); rerr != nil { + return fmt.Errorf("apply update: replacing executable failed and rollback failed: %w (rollback: %v)", err, rerr) + } + return fmt.Errorf("apply update: replacing executable: %w", err) + } + return nil +} + +// assetNameFromURL returns the last path segment of u (used to look the file +// up in SHA256SUMS.txt). Returns "" if u does not parse. +func assetNameFromURL(u string) string { + parsed, err := url.Parse(u) + if err != nil { + return "" + } + return path.Base(parsed.Path) +} + +// downloadAndVerify is the I/O-heavy core of ApplyUpdate, factored out so +// it's testable without invoking selfupdate.Apply (which would replace the +// running test binary). Returns the verified asset bytes. +func downloadAndVerify( + ctx context.Context, + rel *Release, + assetName string, + client *http.Client, + userAgent string, + progress func(downloaded, total int64), +) ([]byte, error) { + // 1. Fetch SHA256SUMS.txt first — fail fast if it's missing/wrong before + // we spend time on a possibly-large binary. + sumsBody, err := httpGetAll(ctx, client, rel.SumsURL, userAgent) + if err != nil { + return nil, fmt.Errorf("downloading SHA256SUMS.txt: %w", err) + } + wantHashHex, err := parseSHA256Sums(sumsBody, assetName) + if err != nil { + return nil, fmt.Errorf("reading SHA256SUMS.txt: %w", err) + } + wantHash, err := hex.DecodeString(wantHashHex) + if err != nil { + return nil, fmt.Errorf("decoding expected sha256 hex: %w", err) + } + + // 2. Fetch the asset, hashing as we go. + binBody, err := httpGetWithProgress(ctx, client, rel.AssetURL, userAgent, progress) + if err != nil { + return nil, fmt.Errorf("downloading asset %s: %w", assetName, err) + } + + // 3. Verify. + gotHash := sha256.Sum256(binBody) + if !bytes.Equal(gotHash[:], wantHash) { + return nil, fmt.Errorf("sha256 checksum mismatch for %s: got %x, want %s", assetName, gotHash[:], wantHashHex) + } + return binBody, nil +} + +func httpGetAll(ctx context.Context, client *http.Client, rawurl, userAgent string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawurl, nil) + if err != nil { + return nil, fmt.Errorf("building request: %w", err) + } + req.Header.Set("User-Agent", userAgent) + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return nil, fmt.Errorf("GET %s: HTTP %d: %s", rawurl, resp.StatusCode, strings.TrimSpace(string(body))) + } + // Cap at 1 MiB — SHA256SUMS.txt is tiny; this is for the sums path only. + // Asset downloads use httpGetWithProgress. + return io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) +} + +func httpGetWithProgress( + ctx context.Context, + client *http.Client, + rawurl, userAgent string, + progress func(downloaded, total int64), +) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawurl, nil) + if err != nil { + return nil, fmt.Errorf("building request: %w", err) + } + req.Header.Set("User-Agent", userAgent) + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return nil, fmt.Errorf("GET %s: HTTP %d: %s", rawurl, resp.StatusCode, strings.TrimSpace(string(body))) + } + + total := resp.ContentLength // -1 if unknown + var totalReport int64 + if total > 0 { + totalReport = total + } + + // Cap a reasonable absolute ceiling so a malicious server can't OOM us. + // 256 MiB is far above any realistic drover-go binary. + const maxBytes = 256 * 1024 * 1024 + body := resp.Body + if total > 0 && total > maxBytes { + return nil, fmt.Errorf("asset Content-Length %d exceeds %d byte cap", total, maxBytes) + } + + pr := &progressReader{ + r: io.LimitReader(body, maxBytes), + total: totalReport, + callback: progress, + } + buf, err := io.ReadAll(pr) + if err != nil { + return nil, fmt.Errorf("reading body: %w", err) + } + // Final progress tick if total was unknown (so callers get a "done" signal). + if progress != nil && totalReport == 0 { + progress(int64(len(buf)), int64(len(buf))) + } + return buf, nil +} + +// progressReader wraps an io.Reader and emits progress updates after each +// successful Read. It is allocation-light and never blocks; callbacks are +// invoked synchronously from Read so they should be cheap. +type progressReader struct { + r io.Reader + total int64 + read int64 + callback func(downloaded, total int64) +} + +func (p *progressReader) Read(b []byte) (int, error) { + n, err := p.r.Read(b) + if n > 0 { + p.read += int64(n) + if p.callback != nil { + p.callback(p.read, p.total) + } + } + return n, err +} diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go new file mode 100644 index 0000000..7ac69ad --- /dev/null +++ b/internal/updater/updater_test.go @@ -0,0 +1,612 @@ +package updater + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +// fakeSource is an in-memory Source implementation for tests. +type fakeSource struct { + rel *Release + err error +} + +func (f *fakeSource) Latest(ctx context.Context) (*Release, error) { + if f.err != nil { + return nil, f.err + } + return f.rel, nil +} + +// --------------------------------------------------------------------------- +// CheckForUpdate +// --------------------------------------------------------------------------- + +func TestCheckForUpdate(t *testing.T) { + t.Parallel() + + rel := func(tag string) *Release { + return &Release{TagName: tag, Name: tag, AssetURL: "u", SumsURL: "s"} + } + sentinel := errors.New("network down") + + tests := []struct { + name string + current string + src Source + wantHasUpdate bool + wantNilRelease bool + wantErrIs error + }{ + { + name: "newer release available", + current: "0.1.0", + src: &fakeSource{rel: rel("v0.2.0")}, + wantHasUpdate: true, + }, + { + name: "newer release with v-prefixed current version", + current: "v0.1.0", + src: &fakeSource{rel: rel("v0.2.0")}, + wantHasUpdate: true, + }, + { + name: "same version — no update", + current: "0.2.0", + src: &fakeSource{rel: rel("v0.2.0")}, + wantHasUpdate: false, + wantNilRelease: true, + }, + { + name: "older release on remote — no update", + current: "0.5.0", + src: &fakeSource{rel: rel("v0.2.0")}, + wantHasUpdate: false, + wantNilRelease: true, + }, + { + name: "rc < final (semver pre-release ordering)", + current: "0.1.0-rc.2", + src: &fakeSource{rel: rel("v0.1.0")}, + wantHasUpdate: true, + }, + { + name: "final >= rc", + current: "0.1.0", + src: &fakeSource{rel: rel("v0.1.0-rc.2")}, + wantHasUpdate: false, + wantNilRelease: true, + }, + { + name: "dev current — always update", + current: "dev", + src: &fakeSource{rel: rel("v0.1.0")}, + wantHasUpdate: true, + }, + { + name: "garbage current — always update (treated as invalid)", + current: "not-a-version", + src: &fakeSource{rel: rel("v0.1.0")}, + wantHasUpdate: true, + }, + { + name: "source returns nil release — no update", + current: "0.1.0", + src: &fakeSource{rel: nil}, + wantHasUpdate: false, + wantNilRelease: true, + }, + { + name: "source error propagates wrapped", + current: "0.1.0", + src: &fakeSource{err: sentinel}, + wantErrIs: sentinel, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + got, has, err := CheckForUpdate(ctx, tc.src, tc.current) + if tc.wantErrIs != nil { + if err == nil { + t.Fatalf("expected error wrapping %v, got nil", tc.wantErrIs) + } + if !errors.Is(err, tc.wantErrIs) { + t.Fatalf("err = %v; expected to wrap %v", err, tc.wantErrIs) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if has != tc.wantHasUpdate { + t.Fatalf("hasUpdate = %v; want %v", has, tc.wantHasUpdate) + } + if tc.wantNilRelease && got != nil { + t.Fatalf("expected nil release; got %+v", got) + } + if !tc.wantNilRelease && tc.wantHasUpdate && got == nil { + t.Fatalf("expected non-nil release on update; got nil") + } + }) + } +} + +// --------------------------------------------------------------------------- +// parseSHA256Sums +// --------------------------------------------------------------------------- + +func TestParseSHA256Sums(t *testing.T) { + t.Parallel() + + const a = "8da085332782708d8767bcace5327a6ec7283c17cfb85e40b03cd2323a90ddc2" + const b = "c1e060ee19444a259b2162f8af0f3fe8c4428a1c6f694dce20de194ac8d7d9a2" + + tests := []struct { + name string + input string + want string + wantName string + wantErr bool + }{ + { + name: "standard text-mode line", + input: a + " drover-v0.1.0-windows-amd64.exe\n", + want: a, + wantName: "drover-v0.1.0-windows-amd64.exe", + }, + { + name: "binary-mode asterisk prefix", + input: b + " *drover-v0.1.0-setup.exe\n", + want: b, + wantName: "drover-v0.1.0-setup.exe", + }, + { + name: "multi-line, target on second", + input: a + " some-other-file.txt\n" + + b + " drover-v0.1.0-setup.exe\n", + want: b, + wantName: "drover-v0.1.0-setup.exe", + }, + { + name: "multi-line, blank lines and CRLF tolerated", + input: "\r\n" + + a + " drover-v0.1.0-windows-amd64.exe\r\n" + + "\n", + want: a, + wantName: "drover-v0.1.0-windows-amd64.exe", + }, + { + name: "missing entry", + input: a + " some-other-file.txt\n", + wantName: "drover-v0.1.0-windows-amd64.exe", + wantErr: true, + }, + { + name: "malformed line — no separator", + input: a + "drover-v0.1.0-windows-amd64.exe\n", + wantName: "drover-v0.1.0-windows-amd64.exe", + wantErr: true, + }, + { + name: "malformed line — bad hex length", + input: "deadbeef drover-v0.1.0-windows-amd64.exe\n", + wantName: "drover-v0.1.0-windows-amd64.exe", + wantErr: true, + }, + { + name: "empty input", + input: "", + wantName: "drover-v0.1.0-windows-amd64.exe", + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := parseSHA256Sums([]byte(tc.input), tc.wantName) + if tc.wantErr { + if err == nil { + t.Fatalf("expected error; got hash %q", got) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tc.want { + t.Fatalf("hash = %q; want %q", got, tc.want) + } + }) + } +} + +// --------------------------------------------------------------------------- +// ForgejoSource.Latest +// --------------------------------------------------------------------------- + +const sampleReleaseJSON = `{ + "tag_name": "v0.1.0", + "name": "v0.1.0", + "draft": false, + "prerelease": false, + "created_at": "2026-04-25T13:35:33+03:00", + "assets": [ + { + "name": "drover-v0.1.0-windows-amd64.exe", + "size": 5800000, + "browser_download_url": "%s/drover-v0.1.0-windows-amd64.exe" + }, + { + "name": "SHA256SUMS.txt", + "size": 200, + "browser_download_url": "%s/SHA256SUMS.txt" + } + ] +}` + +func TestForgejoSource_Latest_Success(t *testing.T) { + t.Parallel() + + var srv *httptest.Server + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/repos/root/drover-go/releases/latest" { + t.Errorf("unexpected request path: %s", r.URL.Path) + http.NotFound(w, r) + return + } + if ua := r.Header.Get("User-Agent"); !strings.HasPrefix(ua, "drover-go/") { + t.Errorf("expected User-Agent drover-go/...; got %q", ua) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, sampleReleaseJSON, srv.URL, srv.URL) + })) + defer srv.Close() + + src := &ForgejoSource{ + baseURL: srv.URL, + owner: "root", + repo: "drover-go", + assetPattern: "windows-amd64.exe", + client: srv.Client(), + userAgent: "drover-go/test", + } + + rel, err := src.Latest(context.Background()) + if err != nil { + t.Fatalf("Latest err = %v", err) + } + if rel == nil { + t.Fatal("nil release returned") + } + if rel.TagName != "v0.1.0" { + t.Errorf("TagName = %q; want v0.1.0", rel.TagName) + } + wantAsset := srv.URL + "/drover-v0.1.0-windows-amd64.exe" + if rel.AssetURL != wantAsset { + t.Errorf("AssetURL = %q; want %q", rel.AssetURL, wantAsset) + } + wantSums := srv.URL + "/SHA256SUMS.txt" + if rel.SumsURL != wantSums { + t.Errorf("SumsURL = %q; want %q", rel.SumsURL, wantSums) + } + if rel.Prerelease { + t.Error("Prerelease = true; want false") + } + if rel.CreatedAt.IsZero() { + t.Error("CreatedAt is zero") + } +} + +func TestForgejoSource_Latest_404_NoUpdates(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer srv.Close() + + src := &ForgejoSource{ + baseURL: srv.URL, + owner: "root", + repo: "drover-go", + assetPattern: "windows-amd64.exe", + client: srv.Client(), + userAgent: "drover-go/test", + } + + rel, err := src.Latest(context.Background()) + if err != nil { + t.Fatalf("expected nil error on 404; got %v", err) + } + if rel != nil { + t.Fatalf("expected nil release on 404; got %+v", rel) + } +} + +func TestForgejoSource_Latest_500_Error(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "boom", http.StatusInternalServerError) + })) + defer srv.Close() + + src := &ForgejoSource{ + baseURL: srv.URL, + owner: "root", + repo: "drover-go", + assetPattern: "windows-amd64.exe", + client: srv.Client(), + userAgent: "drover-go/test", + } + + _, err := src.Latest(context.Background()) + if err == nil { + t.Fatal("expected error on 500; got nil") + } + if !strings.Contains(err.Error(), "500") { + t.Errorf("expected error to mention HTTP 500; got %v", err) + } +} + +func TestForgejoSource_Latest_AssetMissing(t *testing.T) { + t.Parallel() + + // Sample with only SHA256SUMS, no exe asset. + body := `{ + "tag_name": "v0.1.0", + "name": "v0.1.0", + "draft": false, + "prerelease": false, + "created_at": "2026-04-25T13:35:33+03:00", + "assets": [ + {"name":"SHA256SUMS.txt","size":1,"browser_download_url":"https://example.invalid/s"} + ] + }` + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, body) + })) + defer srv.Close() + + src := &ForgejoSource{ + baseURL: srv.URL, + owner: "root", + repo: "drover-go", + assetPattern: "windows-amd64.exe", + client: srv.Client(), + userAgent: "drover-go/test", + } + _, err := src.Latest(context.Background()) + if err == nil { + t.Fatal("expected error when matching asset is missing; got nil") + } +} + +func TestForgejoSource_Latest_SumsMissing(t *testing.T) { + t.Parallel() + + body := `{ + "tag_name": "v0.1.0", + "name": "v0.1.0", + "draft": false, + "prerelease": false, + "created_at": "2026-04-25T13:35:33+03:00", + "assets": [ + {"name":"drover-v0.1.0-windows-amd64.exe","size":1,"browser_download_url":"https://example.invalid/x"} + ] + }` + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, body) + })) + defer srv.Close() + + src := &ForgejoSource{ + baseURL: srv.URL, + owner: "root", + repo: "drover-go", + assetPattern: "windows-amd64.exe", + client: srv.Client(), + userAgent: "drover-go/test", + } + _, err := src.Latest(context.Background()) + if err == nil { + t.Fatal("expected error when SHA256SUMS.txt is missing; got nil") + } + if !strings.Contains(err.Error(), "SHA256SUMS.txt") { + t.Errorf("expected error to mention SHA256SUMS.txt; got %v", err) + } +} + +func TestNewForgejoSource_BuildsURL(t *testing.T) { + t.Parallel() + src := NewForgejoSource("git.okcu.io", "root", "drover-go", "windows-amd64.exe").(*ForgejoSource) + if want := "https://git.okcu.io"; src.baseURL != want { + t.Errorf("baseURL = %q; want %q", src.baseURL, want) + } + if src.owner != "root" || src.repo != "drover-go" { + t.Errorf("owner/repo wrong: %s/%s", src.owner, src.repo) + } + if src.assetPattern != "windows-amd64.exe" { + t.Errorf("assetPattern = %q", src.assetPattern) + } + if src.client == nil { + t.Error("client must not be nil") + } + if !strings.HasPrefix(src.userAgent, "drover-go/") { + t.Errorf("userAgent = %q; want drover-go/...", src.userAgent) + } +} + +// --------------------------------------------------------------------------- +// downloadAndVerify (the testable extract from ApplyUpdate) +// --------------------------------------------------------------------------- + +func TestDownloadAndVerify_Success(t *testing.T) { + t.Parallel() + + exeData := []byte("MZ\x90fake-windows-binary-content") + hash := sha256.Sum256(exeData) + hexHash := hex.EncodeToString(hash[:]) + + exeName := "drover-v0.1.0-windows-amd64.exe" + sumsBody := hexHash + " " + exeName + "\n" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/exe": + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(exeData))) + _, _ = w.Write(exeData) + case "/sums": + _, _ = io.WriteString(w, sumsBody) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + rel := &Release{ + TagName: "v0.1.0", + Name: "v0.1.0", + AssetURL: srv.URL + "/exe", + SumsURL: srv.URL + "/sums", + } + + var pcDownloaded, pcTotal int64 + bin, err := downloadAndVerify(context.Background(), rel, exeName, srv.Client(), "drover-go/test", func(d, total int64) { + pcDownloaded = d + pcTotal = total + }) + if err != nil { + t.Fatalf("downloadAndVerify err = %v", err) + } + if !equalBytes(bin, exeData) { + t.Fatalf("downloaded bytes do not match source") + } + if pcTotal != int64(len(exeData)) { + t.Errorf("progress total = %d; want %d", pcTotal, len(exeData)) + } + if pcDownloaded != int64(len(exeData)) { + t.Errorf("progress downloaded final = %d; want %d", pcDownloaded, len(exeData)) + } +} + +func TestDownloadAndVerify_SHAMismatch(t *testing.T) { + t.Parallel() + + exeData := []byte("real-binary-bytes") + exeName := "drover-v0.1.0-windows-amd64.exe" + wrongHash := strings.Repeat("0", 64) + sumsBody := wrongHash + " " + exeName + "\n" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/exe": + _, _ = w.Write(exeData) + case "/sums": + _, _ = io.WriteString(w, sumsBody) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + rel := &Release{ + TagName: "v0.1.0", + Name: "v0.1.0", + AssetURL: srv.URL + "/exe", + SumsURL: srv.URL + "/sums", + } + + _, err := downloadAndVerify(context.Background(), rel, exeName, srv.Client(), "drover-go/test", nil) + if err == nil { + t.Fatal("expected SHA mismatch error; got nil") + } + if !strings.Contains(err.Error(), "sha256") && !strings.Contains(err.Error(), "checksum") { + t.Errorf("expected error to mention sha256/checksum; got %v", err) + } +} + +func TestDownloadAndVerify_AssetHTTP404(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer srv.Close() + + rel := &Release{ + TagName: "v0.1.0", + AssetURL: srv.URL + "/exe", + SumsURL: srv.URL + "/sums", + } + _, err := downloadAndVerify(context.Background(), rel, "drover-v0.1.0-windows-amd64.exe", srv.Client(), "drover-go/test", nil) + if err == nil { + t.Fatal("expected error on 404; got nil") + } +} + +func TestDownloadAndVerify_ContextCanceled(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Block forever — handler unblocks when client closes connection + // (which happens when the test's ctx is canceled). + select { + case <-time.After(5 * time.Second): + case <-r.Context().Done(): + } + })) + defer srv.Close() + + rel := &Release{AssetURL: srv.URL + "/exe", SumsURL: srv.URL + "/sums"} + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + _, err := downloadAndVerify(ctx, rel, "x", srv.Client(), "drover-go/test", nil) + if err == nil { + t.Fatal("expected error from canceled context; got nil") + } +} + +// equalBytes is here to avoid a bytes import for one comparison. +func equalBytes(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +// Sanity check that the Forgejo schema fields we care about decode correctly. +func TestReleaseJSONShape(t *testing.T) { + t.Parallel() + var rel forgejoRelease + body := fmt.Sprintf(sampleReleaseJSON, "https://example.invalid", "https://example.invalid") + if err := json.Unmarshal([]byte(body), &rel); err != nil { + t.Fatalf("decode err: %v", err) + } + if len(rel.Assets) != 2 { + t.Fatalf("got %d assets; want 2", len(rel.Assets)) + } +}