Implement internal/updater: selfupdate via Forgejo Releases API

Adds a small, well-tested package that:
- Queries /api/v1/repos/root/drover-go/releases/latest (404 = no updates,
  not an error).
- Compares the published tag against the running Version using
  golang.org/x/mod/semver, so v0.1.0-rc.2 < v0.1.0. "dev" or any
  semver-invalid current version is treated as "always update".
- Downloads the windows-amd64 asset + SHA256SUMS.txt, verifies the
  sha256 of the binary against its line in the sums file (tolerates
  the asterisk binary-mode prefix), and atomically swaps the running
  exe via github.com/minio/selfupdate.
- Uses a 15s connect timeout with no overall request deadline, so
  large asset downloads aren't truncated.
- Reports progress via an optional callback.

Public surface: Source interface + ForgejoSource implementation,
CheckForUpdate, ApplyUpdate, SetVersion. No GUI/cobra/Wails imports
in the package, so the same code is reusable from the CLI, the
Windows service, and the future tray UI.

Wires the package into "drover update" / "drover update --check-only"
in cmd/drover/main.go. --check-only exits 0 whether or not an update
is available; only network/sha/apply errors are non-zero.

Tests cover CheckForUpdate (table-driven incl. semver pre-release
ordering, dev fallthrough, source errors), parseSHA256Sums (text and
binary modes, CRLF, malformed lines, missing entries),
ForgejoSource.Latest (httptest with canned JSON, 404, 500, missing
asset, missing SHA256SUMS), and downloadAndVerify (success, sha
mismatch, HTTP 404, context cancellation). All run with -race.

Smoke-tested manually: built drover.exe and "drover update --check-only"
against git.okcu.io prints "No updates available" and exits 0 (no
releases yet).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-01 00:20:24 +03:00
parent 25df64213c
commit 1ad8de32f2
6 changed files with 1155 additions and 5 deletions
+32 -2
View File
@@ -6,6 +6,8 @@ import (
"os"
"github.com/spf13/cobra"
"git.okcu.io/root/drover-go/internal/updater"
)
// Build-time variables, populated via -ldflags "-X main.Version=... -X main.Commit=... -X main.BuildDate=...".
@@ -20,6 +22,10 @@ var (
var configPath string
func main() {
// Inject our build version so the updater package can stamp it on the
// User-Agent header it sends to git.okcu.io.
updater.SetVersion(Version)
if err := newRootCmd().Execute(); err != nil {
// Cobra already prints the error; just exit non-zero.
os.Exit(1)
@@ -64,8 +70,32 @@ func newUpdateCmd() *cobra.Command {
Use: "update",
Short: "Self-update via the Forgejo Releases API",
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Fprintln(cmd.OutOrStdout(), "TODO: update check")
_ = checkOnly // wired up for the upcoming implementation
ctx := cmd.Context()
out := cmd.OutOrStdout()
src := updater.NewForgejoSource("git.okcu.io", "root", "drover-go", "windows-amd64.exe")
rel, hasUpdate, err := updater.CheckForUpdate(ctx, src, Version)
if err != nil {
return fmt.Errorf("check for update: %w", err)
}
if !hasUpdate {
fmt.Fprintln(out, "No updates available")
return nil
}
fmt.Fprintf(out, "Update available: %s (current v%s)\n", rel.TagName, Version)
if checkOnly {
return nil
}
fmt.Fprintln(out, "Downloading...")
if err := updater.ApplyUpdate(ctx, rel, func(d, t int64) {
if t > 0 {
fmt.Fprintf(out, "\r%d/%d bytes", d, t)
}
}); err != nil {
return fmt.Errorf("apply update: %w", err)
}
fmt.Fprintln(out, "\nUpdate applied. Restart drover.")
return nil
},
}
+8 -1
View File
@@ -2,9 +2,16 @@ module git.okcu.io/root/drover-go
go 1.23
require github.com/spf13/cobra v1.10.2
require (
github.com/minio/selfupdate v0.6.0
github.com/spf13/cobra v1.10.2
golang.org/x/mod v0.21.0
)
require (
aead.dev/minisign v0.2.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/pflag v1.0.9 // indirect
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b // indirect
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect
)
+24
View File
@@ -1,10 +1,34 @@
aead.dev/minisign v0.2.0 h1:kAWrq/hBRu4AARY6AlciO83xhNnW9UaC8YipS2uhLPk=
aead.dev/minisign v0.2.0/go.mod h1:zdq6LdSd9TbuSxchxwhpA9zEb9YXcVGoE8JakuiGaIQ=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU=
github.com/minio/selfupdate v0.6.0/go.mod h1:bO02GTIPCMQFTEvE5h4DjYB58bCoZ35XLeBf0buTDdM=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b h1:QAqMVf3pSa6eeTsuklijukjXBlj7Es2QQplab+/RbQ4=
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-2
View File
@@ -1,2 +0,0 @@
// Package updater performs self-update via the Forgejo Releases API.
package updater
+479
View File
@@ -0,0 +1,479 @@
// Package updater performs self-update via the Forgejo Releases API.
//
// The package is split into:
//
// - Source: an interface that fetches "the latest release" from somewhere.
// ForgejoSource is the production implementation; tests use a fake.
// - CheckForUpdate: compares a release tag against a current version using
// semver-aware ordering (golang.org/x/mod/semver).
// - ApplyUpdate: downloads the asset + SHA256SUMS.txt, verifies the binary
// against its hash, and atomically replaces the running executable via
// github.com/minio/selfupdate.
//
// The package is GUI- and CLI-framework-agnostic; importers should not pull
// in cobra/Wails as a transitive dependency.
package updater
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"path"
"strings"
"time"
"github.com/minio/selfupdate"
"golang.org/x/mod/semver"
)
// version is exposed for tests/User-Agent. Caller-overridable.
var version = "dev"
// SetVersion lets the calling binary inject its build version into the
// User-Agent header used by HTTP requests. Optional; defaults to "dev".
func SetVersion(v string) { version = v }
// Release is the minimal subset of a Forgejo release we care about.
type Release struct {
TagName string // e.g. "v0.1.2"
Name string // human-readable name (often equals TagName)
AssetURL string // browser_download_url of the windows-amd64 exe asset
SumsURL string // browser_download_url of SHA256SUMS.txt
Prerelease bool
CreatedAt time.Time
}
// Source abstracts "where do releases come from" so callers can swap real
// HTTP sources for fakes in tests.
type Source interface {
// Latest returns the latest non-draft release, or (nil, nil) if there
// are no releases yet (e.g. Forgejo returned 404 from /releases/latest).
// Errors should be returned wrapped with %w for any genuine failure.
Latest(ctx context.Context) (*Release, error)
}
// CheckForUpdate compares the latest release's tag name against
// currentVersion. Returns (release, true, nil) if the latest is strictly
// newer; (nil, false, nil) if the current is up-to-date or newer; or an
// error wrapping the source's error.
//
// If currentVersion fails semver.IsValid (e.g. it is "dev"), CheckForUpdate
// treats every available release as an update so dev builds get notified.
func CheckForUpdate(ctx context.Context, src Source, currentVersion string) (*Release, bool, error) {
rel, err := src.Latest(ctx)
if err != nil {
return nil, false, fmt.Errorf("checking latest release: %w", err)
}
if rel == nil {
return nil, false, nil
}
if !isNewer(rel.TagName, currentVersion) {
return nil, false, nil
}
return rel, true, nil
}
// isNewer reports whether tag is strictly newer than current under semver
// ordering. Both inputs are normalized to a leading "v". An invalid current
// version is treated as "older than anything" so dev builds always update.
func isNewer(tag, current string) bool {
tagV := ensureV(tag)
curV := ensureV(current)
if !semver.IsValid(tagV) {
// Remote tag is malformed; refuse to update.
return false
}
if !semver.IsValid(curV) {
// Current version is "dev" or otherwise invalid → always update.
return true
}
return semver.Compare(tagV, curV) > 0
}
func ensureV(s string) string {
if strings.HasPrefix(s, "v") {
return s
}
return "v" + s
}
// ---------------------------------------------------------------------------
// SHA256SUMS parsing
// ---------------------------------------------------------------------------
// parseSHA256Sums returns the lowercase-hex SHA-256 digest associated with
// targetName from a sha256sum-formatted SUMS file. The format is one line
// per file: "<64-hex> <name>" (text mode, two spaces) or "<64-hex> *<name>"
// (binary mode, space + asterisk). Both forms are accepted.
//
// Returns an error if the target is missing or the file is malformed.
func parseSHA256Sums(data []byte, targetName string) (string, error) {
scanner := bufio.NewScanner(bytes.NewReader(data))
// Allow long lines (default is 64 KiB which is plenty for sums files).
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
lineNum := 0
for scanner.Scan() {
lineNum++
line := strings.TrimRight(scanner.Text(), "\r")
line = strings.TrimSpace(line)
if line == "" {
continue
}
// Find the first space; everything before is hash, everything after
// (with optional leading '*') is the filename.
idx := strings.IndexByte(line, ' ')
if idx < 0 {
return "", fmt.Errorf("malformed SHA256SUMS line %d: missing separator", lineNum)
}
hashHex := line[:idx]
rest := strings.TrimLeft(line[idx+1:], " ")
rest = strings.TrimPrefix(rest, "*")
if len(hashHex) != 64 {
return "", fmt.Errorf("malformed SHA256SUMS line %d: hash is %d chars, want 64", lineNum, len(hashHex))
}
if _, err := hex.DecodeString(hashHex); err != nil {
return "", fmt.Errorf("malformed SHA256SUMS line %d: hash is not valid hex: %w", lineNum, err)
}
if rest == targetName {
return strings.ToLower(hashHex), nil
}
}
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("scanning SHA256SUMS: %w", err)
}
return "", fmt.Errorf("no SHA256SUMS entry for %q", targetName)
}
// ---------------------------------------------------------------------------
// Forgejo source
// ---------------------------------------------------------------------------
// ForgejoSource talks to a Forgejo instance's Releases API.
type ForgejoSource struct {
baseURL string // e.g. "https://git.okcu.io"
owner string
repo string
assetPattern string // substring matched against asset name
client *http.Client
userAgent string
}
// NewForgejoSource constructs a Source that talks to host (the Forgejo
// hostname, e.g. "git.okcu.io"). assetPattern is a substring matched against
// release asset names (e.g. "windows-amd64.exe").
func NewForgejoSource(host, owner, repo, assetPattern string) Source {
return &ForgejoSource{
baseURL: "https://" + strings.TrimRight(host, "/"),
owner: owner,
repo: repo,
assetPattern: assetPattern,
client: defaultHTTPClient(),
userAgent: "drover-go/" + version,
}
}
// defaultHTTPClient returns a context-aware http.Client with a 15s connect
// timeout but no overall request deadline (asset downloads may be large).
func defaultHTTPClient() *http.Client {
dialer := &net.Dialer{
Timeout: 15 * time.Second,
KeepAlive: 30 * time.Second,
}
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 15 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
return &http.Client{Transport: tr}
}
// forgejoRelease mirrors the JSON returned by /api/v1/repos/.../releases/latest.
type forgejoRelease struct {
TagName string `json:"tag_name"`
Name string `json:"name"`
Draft bool `json:"draft"`
Prerelease bool `json:"prerelease"`
CreatedAt time.Time `json:"created_at"`
Assets []forgejoAsset `json:"assets"`
}
type forgejoAsset struct {
Name string `json:"name"`
Size int64 `json:"size"`
DownloadURL string `json:"browser_download_url"`
}
// Latest fetches the latest release. Returns (nil, nil) on 404 (no releases).
func (s *ForgejoSource) Latest(ctx context.Context) (*Release, error) {
endpoint, err := url.JoinPath(s.baseURL, "api", "v1", "repos", s.owner, s.repo, "releases", "latest")
if err != nil {
return nil, fmt.Errorf("building releases URL: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("building request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", s.userAgent)
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("GET %s: %w", endpoint, err)
}
defer resp.Body.Close()
switch {
case resp.StatusCode == http.StatusNotFound:
// No public releases exist.
return nil, nil
case resp.StatusCode < 200 || resp.StatusCode >= 300:
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, fmt.Errorf("GET %s: HTTP %d: %s", endpoint, resp.StatusCode, strings.TrimSpace(string(body)))
}
var raw forgejoRelease
if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil {
return nil, fmt.Errorf("decoding release JSON: %w", err)
}
if raw.Draft {
// Forgejo's /latest is supposed to skip drafts, but be defensive.
return nil, nil
}
assetURL, assetName := pickAsset(raw.Assets, s.assetPattern)
if assetURL == "" {
return nil, fmt.Errorf("release %s has no asset matching %q", raw.TagName, s.assetPattern)
}
sumsURL := pickAssetByExactName(raw.Assets, "SHA256SUMS.txt")
if sumsURL == "" {
return nil, fmt.Errorf("release %s is missing SHA256SUMS.txt asset", raw.TagName)
}
_ = assetName // currently unused outside the source; reserved for future logging
return &Release{
TagName: raw.TagName,
Name: raw.Name,
AssetURL: assetURL,
SumsURL: sumsURL,
Prerelease: raw.Prerelease,
CreatedAt: raw.CreatedAt,
}, nil
}
func pickAsset(assets []forgejoAsset, pattern string) (string, string) {
for _, a := range assets {
if strings.Contains(a.Name, pattern) {
return a.DownloadURL, a.Name
}
}
return "", ""
}
func pickAssetByExactName(assets []forgejoAsset, name string) string {
for _, a := range assets {
if a.Name == name {
return a.DownloadURL
}
}
return ""
}
// ---------------------------------------------------------------------------
// ApplyUpdate
// ---------------------------------------------------------------------------
// ApplyUpdate downloads the release's asset + SHA256SUMS.txt, verifies the
// asset against its sha256 entry, and atomically replaces the current
// executable using github.com/minio/selfupdate.
//
// progress, if non-nil, is called periodically with (downloaded, total)
// while the asset is being downloaded. total is set from the Content-Length
// header and may be 0 if the server omits it.
func ApplyUpdate(ctx context.Context, rel *Release, progress func(downloaded, total int64)) error {
if rel == nil {
return errors.New("apply update: release is nil")
}
if rel.AssetURL == "" || rel.SumsURL == "" {
return errors.New("apply update: release is missing AssetURL or SumsURL")
}
assetName := assetNameFromURL(rel.AssetURL)
if assetName == "" {
return fmt.Errorf("apply update: cannot derive asset filename from URL %q", rel.AssetURL)
}
client := defaultHTTPClient()
binary, err := downloadAndVerify(ctx, rel, assetName, client, "drover-go/"+version, progress)
if err != nil {
return fmt.Errorf("apply update: %w", err)
}
// selfupdate.Apply takes an io.Reader and atomically swaps the running
// binary. We pass a bytes.Reader since we already have the verified
// bytes in memory.
if err := selfupdate.Apply(bytes.NewReader(binary), selfupdate.Options{}); err != nil {
// On failure, selfupdate may leave a .old file; surface that fact
// to the caller via the wrapped error.
if rerr := selfupdate.RollbackError(err); rerr != nil {
return fmt.Errorf("apply update: replacing executable failed and rollback failed: %w (rollback: %v)", err, rerr)
}
return fmt.Errorf("apply update: replacing executable: %w", err)
}
return nil
}
// assetNameFromURL returns the last path segment of u (used to look the file
// up in SHA256SUMS.txt). Returns "" if u does not parse.
func assetNameFromURL(u string) string {
parsed, err := url.Parse(u)
if err != nil {
return ""
}
return path.Base(parsed.Path)
}
// downloadAndVerify is the I/O-heavy core of ApplyUpdate, factored out so
// it's testable without invoking selfupdate.Apply (which would replace the
// running test binary). Returns the verified asset bytes.
func downloadAndVerify(
ctx context.Context,
rel *Release,
assetName string,
client *http.Client,
userAgent string,
progress func(downloaded, total int64),
) ([]byte, error) {
// 1. Fetch SHA256SUMS.txt first — fail fast if it's missing/wrong before
// we spend time on a possibly-large binary.
sumsBody, err := httpGetAll(ctx, client, rel.SumsURL, userAgent)
if err != nil {
return nil, fmt.Errorf("downloading SHA256SUMS.txt: %w", err)
}
wantHashHex, err := parseSHA256Sums(sumsBody, assetName)
if err != nil {
return nil, fmt.Errorf("reading SHA256SUMS.txt: %w", err)
}
wantHash, err := hex.DecodeString(wantHashHex)
if err != nil {
return nil, fmt.Errorf("decoding expected sha256 hex: %w", err)
}
// 2. Fetch the asset, hashing as we go.
binBody, err := httpGetWithProgress(ctx, client, rel.AssetURL, userAgent, progress)
if err != nil {
return nil, fmt.Errorf("downloading asset %s: %w", assetName, err)
}
// 3. Verify.
gotHash := sha256.Sum256(binBody)
if !bytes.Equal(gotHash[:], wantHash) {
return nil, fmt.Errorf("sha256 checksum mismatch for %s: got %x, want %s", assetName, gotHash[:], wantHashHex)
}
return binBody, nil
}
func httpGetAll(ctx context.Context, client *http.Client, rawurl, userAgent string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawurl, nil)
if err != nil {
return nil, fmt.Errorf("building request: %w", err)
}
req.Header.Set("User-Agent", userAgent)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, fmt.Errorf("GET %s: HTTP %d: %s", rawurl, resp.StatusCode, strings.TrimSpace(string(body)))
}
// Cap at 1 MiB — SHA256SUMS.txt is tiny; this is for the sums path only.
// Asset downloads use httpGetWithProgress.
return io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
}
func httpGetWithProgress(
ctx context.Context,
client *http.Client,
rawurl, userAgent string,
progress func(downloaded, total int64),
) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawurl, nil)
if err != nil {
return nil, fmt.Errorf("building request: %w", err)
}
req.Header.Set("User-Agent", userAgent)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, fmt.Errorf("GET %s: HTTP %d: %s", rawurl, resp.StatusCode, strings.TrimSpace(string(body)))
}
total := resp.ContentLength // -1 if unknown
var totalReport int64
if total > 0 {
totalReport = total
}
// Cap a reasonable absolute ceiling so a malicious server can't OOM us.
// 256 MiB is far above any realistic drover-go binary.
const maxBytes = 256 * 1024 * 1024
body := resp.Body
if total > 0 && total > maxBytes {
return nil, fmt.Errorf("asset Content-Length %d exceeds %d byte cap", total, maxBytes)
}
pr := &progressReader{
r: io.LimitReader(body, maxBytes),
total: totalReport,
callback: progress,
}
buf, err := io.ReadAll(pr)
if err != nil {
return nil, fmt.Errorf("reading body: %w", err)
}
// Final progress tick if total was unknown (so callers get a "done" signal).
if progress != nil && totalReport == 0 {
progress(int64(len(buf)), int64(len(buf)))
}
return buf, nil
}
// progressReader wraps an io.Reader and emits progress updates after each
// successful Read. It is allocation-light and never blocks; callbacks are
// invoked synchronously from Read so they should be cheap.
type progressReader struct {
r io.Reader
total int64
read int64
callback func(downloaded, total int64)
}
func (p *progressReader) Read(b []byte) (int, error) {
n, err := p.r.Read(b)
if n > 0 {
p.read += int64(n)
if p.callback != nil {
p.callback(p.read, p.total)
}
}
return n, err
}
+612
View File
@@ -0,0 +1,612 @@
package updater
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// fakeSource is an in-memory Source implementation for tests.
type fakeSource struct {
rel *Release
err error
}
func (f *fakeSource) Latest(ctx context.Context) (*Release, error) {
if f.err != nil {
return nil, f.err
}
return f.rel, nil
}
// ---------------------------------------------------------------------------
// CheckForUpdate
// ---------------------------------------------------------------------------
func TestCheckForUpdate(t *testing.T) {
t.Parallel()
rel := func(tag string) *Release {
return &Release{TagName: tag, Name: tag, AssetURL: "u", SumsURL: "s"}
}
sentinel := errors.New("network down")
tests := []struct {
name string
current string
src Source
wantHasUpdate bool
wantNilRelease bool
wantErrIs error
}{
{
name: "newer release available",
current: "0.1.0",
src: &fakeSource{rel: rel("v0.2.0")},
wantHasUpdate: true,
},
{
name: "newer release with v-prefixed current version",
current: "v0.1.0",
src: &fakeSource{rel: rel("v0.2.0")},
wantHasUpdate: true,
},
{
name: "same version — no update",
current: "0.2.0",
src: &fakeSource{rel: rel("v0.2.0")},
wantHasUpdate: false,
wantNilRelease: true,
},
{
name: "older release on remote — no update",
current: "0.5.0",
src: &fakeSource{rel: rel("v0.2.0")},
wantHasUpdate: false,
wantNilRelease: true,
},
{
name: "rc < final (semver pre-release ordering)",
current: "0.1.0-rc.2",
src: &fakeSource{rel: rel("v0.1.0")},
wantHasUpdate: true,
},
{
name: "final >= rc",
current: "0.1.0",
src: &fakeSource{rel: rel("v0.1.0-rc.2")},
wantHasUpdate: false,
wantNilRelease: true,
},
{
name: "dev current — always update",
current: "dev",
src: &fakeSource{rel: rel("v0.1.0")},
wantHasUpdate: true,
},
{
name: "garbage current — always update (treated as invalid)",
current: "not-a-version",
src: &fakeSource{rel: rel("v0.1.0")},
wantHasUpdate: true,
},
{
name: "source returns nil release — no update",
current: "0.1.0",
src: &fakeSource{rel: nil},
wantHasUpdate: false,
wantNilRelease: true,
},
{
name: "source error propagates wrapped",
current: "0.1.0",
src: &fakeSource{err: sentinel},
wantErrIs: sentinel,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
got, has, err := CheckForUpdate(ctx, tc.src, tc.current)
if tc.wantErrIs != nil {
if err == nil {
t.Fatalf("expected error wrapping %v, got nil", tc.wantErrIs)
}
if !errors.Is(err, tc.wantErrIs) {
t.Fatalf("err = %v; expected to wrap %v", err, tc.wantErrIs)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if has != tc.wantHasUpdate {
t.Fatalf("hasUpdate = %v; want %v", has, tc.wantHasUpdate)
}
if tc.wantNilRelease && got != nil {
t.Fatalf("expected nil release; got %+v", got)
}
if !tc.wantNilRelease && tc.wantHasUpdate && got == nil {
t.Fatalf("expected non-nil release on update; got nil")
}
})
}
}
// ---------------------------------------------------------------------------
// parseSHA256Sums
// ---------------------------------------------------------------------------
func TestParseSHA256Sums(t *testing.T) {
t.Parallel()
const a = "8da085332782708d8767bcace5327a6ec7283c17cfb85e40b03cd2323a90ddc2"
const b = "c1e060ee19444a259b2162f8af0f3fe8c4428a1c6f694dce20de194ac8d7d9a2"
tests := []struct {
name string
input string
want string
wantName string
wantErr bool
}{
{
name: "standard text-mode line",
input: a + " drover-v0.1.0-windows-amd64.exe\n",
want: a,
wantName: "drover-v0.1.0-windows-amd64.exe",
},
{
name: "binary-mode asterisk prefix",
input: b + " *drover-v0.1.0-setup.exe\n",
want: b,
wantName: "drover-v0.1.0-setup.exe",
},
{
name: "multi-line, target on second",
input: a + " some-other-file.txt\n" +
b + " drover-v0.1.0-setup.exe\n",
want: b,
wantName: "drover-v0.1.0-setup.exe",
},
{
name: "multi-line, blank lines and CRLF tolerated",
input: "\r\n" +
a + " drover-v0.1.0-windows-amd64.exe\r\n" +
"\n",
want: a,
wantName: "drover-v0.1.0-windows-amd64.exe",
},
{
name: "missing entry",
input: a + " some-other-file.txt\n",
wantName: "drover-v0.1.0-windows-amd64.exe",
wantErr: true,
},
{
name: "malformed line — no separator",
input: a + "drover-v0.1.0-windows-amd64.exe\n",
wantName: "drover-v0.1.0-windows-amd64.exe",
wantErr: true,
},
{
name: "malformed line — bad hex length",
input: "deadbeef drover-v0.1.0-windows-amd64.exe\n",
wantName: "drover-v0.1.0-windows-amd64.exe",
wantErr: true,
},
{
name: "empty input",
input: "",
wantName: "drover-v0.1.0-windows-amd64.exe",
wantErr: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := parseSHA256Sums([]byte(tc.input), tc.wantName)
if tc.wantErr {
if err == nil {
t.Fatalf("expected error; got hash %q", got)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tc.want {
t.Fatalf("hash = %q; want %q", got, tc.want)
}
})
}
}
// ---------------------------------------------------------------------------
// ForgejoSource.Latest
// ---------------------------------------------------------------------------
const sampleReleaseJSON = `{
"tag_name": "v0.1.0",
"name": "v0.1.0",
"draft": false,
"prerelease": false,
"created_at": "2026-04-25T13:35:33+03:00",
"assets": [
{
"name": "drover-v0.1.0-windows-amd64.exe",
"size": 5800000,
"browser_download_url": "%s/drover-v0.1.0-windows-amd64.exe"
},
{
"name": "SHA256SUMS.txt",
"size": 200,
"browser_download_url": "%s/SHA256SUMS.txt"
}
]
}`
func TestForgejoSource_Latest_Success(t *testing.T) {
t.Parallel()
var srv *httptest.Server
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v1/repos/root/drover-go/releases/latest" {
t.Errorf("unexpected request path: %s", r.URL.Path)
http.NotFound(w, r)
return
}
if ua := r.Header.Get("User-Agent"); !strings.HasPrefix(ua, "drover-go/") {
t.Errorf("expected User-Agent drover-go/...; got %q", ua)
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, sampleReleaseJSON, srv.URL, srv.URL)
}))
defer srv.Close()
src := &ForgejoSource{
baseURL: srv.URL,
owner: "root",
repo: "drover-go",
assetPattern: "windows-amd64.exe",
client: srv.Client(),
userAgent: "drover-go/test",
}
rel, err := src.Latest(context.Background())
if err != nil {
t.Fatalf("Latest err = %v", err)
}
if rel == nil {
t.Fatal("nil release returned")
}
if rel.TagName != "v0.1.0" {
t.Errorf("TagName = %q; want v0.1.0", rel.TagName)
}
wantAsset := srv.URL + "/drover-v0.1.0-windows-amd64.exe"
if rel.AssetURL != wantAsset {
t.Errorf("AssetURL = %q; want %q", rel.AssetURL, wantAsset)
}
wantSums := srv.URL + "/SHA256SUMS.txt"
if rel.SumsURL != wantSums {
t.Errorf("SumsURL = %q; want %q", rel.SumsURL, wantSums)
}
if rel.Prerelease {
t.Error("Prerelease = true; want false")
}
if rel.CreatedAt.IsZero() {
t.Error("CreatedAt is zero")
}
}
func TestForgejoSource_Latest_404_NoUpdates(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
defer srv.Close()
src := &ForgejoSource{
baseURL: srv.URL,
owner: "root",
repo: "drover-go",
assetPattern: "windows-amd64.exe",
client: srv.Client(),
userAgent: "drover-go/test",
}
rel, err := src.Latest(context.Background())
if err != nil {
t.Fatalf("expected nil error on 404; got %v", err)
}
if rel != nil {
t.Fatalf("expected nil release on 404; got %+v", rel)
}
}
func TestForgejoSource_Latest_500_Error(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "boom", http.StatusInternalServerError)
}))
defer srv.Close()
src := &ForgejoSource{
baseURL: srv.URL,
owner: "root",
repo: "drover-go",
assetPattern: "windows-amd64.exe",
client: srv.Client(),
userAgent: "drover-go/test",
}
_, err := src.Latest(context.Background())
if err == nil {
t.Fatal("expected error on 500; got nil")
}
if !strings.Contains(err.Error(), "500") {
t.Errorf("expected error to mention HTTP 500; got %v", err)
}
}
func TestForgejoSource_Latest_AssetMissing(t *testing.T) {
t.Parallel()
// Sample with only SHA256SUMS, no exe asset.
body := `{
"tag_name": "v0.1.0",
"name": "v0.1.0",
"draft": false,
"prerelease": false,
"created_at": "2026-04-25T13:35:33+03:00",
"assets": [
{"name":"SHA256SUMS.txt","size":1,"browser_download_url":"https://example.invalid/s"}
]
}`
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, body)
}))
defer srv.Close()
src := &ForgejoSource{
baseURL: srv.URL,
owner: "root",
repo: "drover-go",
assetPattern: "windows-amd64.exe",
client: srv.Client(),
userAgent: "drover-go/test",
}
_, err := src.Latest(context.Background())
if err == nil {
t.Fatal("expected error when matching asset is missing; got nil")
}
}
func TestForgejoSource_Latest_SumsMissing(t *testing.T) {
t.Parallel()
body := `{
"tag_name": "v0.1.0",
"name": "v0.1.0",
"draft": false,
"prerelease": false,
"created_at": "2026-04-25T13:35:33+03:00",
"assets": [
{"name":"drover-v0.1.0-windows-amd64.exe","size":1,"browser_download_url":"https://example.invalid/x"}
]
}`
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, body)
}))
defer srv.Close()
src := &ForgejoSource{
baseURL: srv.URL,
owner: "root",
repo: "drover-go",
assetPattern: "windows-amd64.exe",
client: srv.Client(),
userAgent: "drover-go/test",
}
_, err := src.Latest(context.Background())
if err == nil {
t.Fatal("expected error when SHA256SUMS.txt is missing; got nil")
}
if !strings.Contains(err.Error(), "SHA256SUMS.txt") {
t.Errorf("expected error to mention SHA256SUMS.txt; got %v", err)
}
}
func TestNewForgejoSource_BuildsURL(t *testing.T) {
t.Parallel()
src := NewForgejoSource("git.okcu.io", "root", "drover-go", "windows-amd64.exe").(*ForgejoSource)
if want := "https://git.okcu.io"; src.baseURL != want {
t.Errorf("baseURL = %q; want %q", src.baseURL, want)
}
if src.owner != "root" || src.repo != "drover-go" {
t.Errorf("owner/repo wrong: %s/%s", src.owner, src.repo)
}
if src.assetPattern != "windows-amd64.exe" {
t.Errorf("assetPattern = %q", src.assetPattern)
}
if src.client == nil {
t.Error("client must not be nil")
}
if !strings.HasPrefix(src.userAgent, "drover-go/") {
t.Errorf("userAgent = %q; want drover-go/...", src.userAgent)
}
}
// ---------------------------------------------------------------------------
// downloadAndVerify (the testable extract from ApplyUpdate)
// ---------------------------------------------------------------------------
func TestDownloadAndVerify_Success(t *testing.T) {
t.Parallel()
exeData := []byte("MZ\x90fake-windows-binary-content")
hash := sha256.Sum256(exeData)
hexHash := hex.EncodeToString(hash[:])
exeName := "drover-v0.1.0-windows-amd64.exe"
sumsBody := hexHash + " " + exeName + "\n"
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/exe":
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(exeData)))
_, _ = w.Write(exeData)
case "/sums":
_, _ = io.WriteString(w, sumsBody)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
rel := &Release{
TagName: "v0.1.0",
Name: "v0.1.0",
AssetURL: srv.URL + "/exe",
SumsURL: srv.URL + "/sums",
}
var pcDownloaded, pcTotal int64
bin, err := downloadAndVerify(context.Background(), rel, exeName, srv.Client(), "drover-go/test", func(d, total int64) {
pcDownloaded = d
pcTotal = total
})
if err != nil {
t.Fatalf("downloadAndVerify err = %v", err)
}
if !equalBytes(bin, exeData) {
t.Fatalf("downloaded bytes do not match source")
}
if pcTotal != int64(len(exeData)) {
t.Errorf("progress total = %d; want %d", pcTotal, len(exeData))
}
if pcDownloaded != int64(len(exeData)) {
t.Errorf("progress downloaded final = %d; want %d", pcDownloaded, len(exeData))
}
}
func TestDownloadAndVerify_SHAMismatch(t *testing.T) {
t.Parallel()
exeData := []byte("real-binary-bytes")
exeName := "drover-v0.1.0-windows-amd64.exe"
wrongHash := strings.Repeat("0", 64)
sumsBody := wrongHash + " " + exeName + "\n"
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/exe":
_, _ = w.Write(exeData)
case "/sums":
_, _ = io.WriteString(w, sumsBody)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
rel := &Release{
TagName: "v0.1.0",
Name: "v0.1.0",
AssetURL: srv.URL + "/exe",
SumsURL: srv.URL + "/sums",
}
_, err := downloadAndVerify(context.Background(), rel, exeName, srv.Client(), "drover-go/test", nil)
if err == nil {
t.Fatal("expected SHA mismatch error; got nil")
}
if !strings.Contains(err.Error(), "sha256") && !strings.Contains(err.Error(), "checksum") {
t.Errorf("expected error to mention sha256/checksum; got %v", err)
}
}
func TestDownloadAndVerify_AssetHTTP404(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
defer srv.Close()
rel := &Release{
TagName: "v0.1.0",
AssetURL: srv.URL + "/exe",
SumsURL: srv.URL + "/sums",
}
_, err := downloadAndVerify(context.Background(), rel, "drover-v0.1.0-windows-amd64.exe", srv.Client(), "drover-go/test", nil)
if err == nil {
t.Fatal("expected error on 404; got nil")
}
}
func TestDownloadAndVerify_ContextCanceled(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Block forever — handler unblocks when client closes connection
// (which happens when the test's ctx is canceled).
select {
case <-time.After(5 * time.Second):
case <-r.Context().Done():
}
}))
defer srv.Close()
rel := &Release{AssetURL: srv.URL + "/exe", SumsURL: srv.URL + "/sums"}
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately
_, err := downloadAndVerify(ctx, rel, "x", srv.Client(), "drover-go/test", nil)
if err == nil {
t.Fatal("expected error from canceled context; got nil")
}
}
// equalBytes is here to avoid a bytes import for one comparison.
func equalBytes(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// Sanity check that the Forgejo schema fields we care about decode correctly.
func TestReleaseJSONShape(t *testing.T) {
t.Parallel()
var rel forgejoRelease
body := fmt.Sprintf(sampleReleaseJSON, "https://example.invalid", "https://example.invalid")
if err := json.Unmarshal([]byte(body), &rel); err != nil {
t.Fatalf("decode err: %v", err)
}
if len(rel.Assets) != 2 {
t.Fatalf("got %d assets; want 2", len(rel.Assets))
}
}