Implement internal/updater: selfupdate via Forgejo Releases API

Adds a small, well-tested package that:
- Queries /api/v1/repos/root/drover-go/releases/latest (404 = no updates,
  not an error).
- Compares the published tag against the running Version using
  golang.org/x/mod/semver, so v0.1.0-rc.2 < v0.1.0. "dev" or any
  semver-invalid current version is treated as "always update".
- Downloads the windows-amd64 asset + SHA256SUMS.txt, verifies the
  sha256 of the binary against its line in the sums file (tolerates
  the asterisk binary-mode prefix), and atomically swaps the running
  exe via github.com/minio/selfupdate.
- Uses a 15s connect timeout with no overall request deadline, so
  large asset downloads aren't truncated.
- Reports progress via an optional callback.

Public surface: Source interface + ForgejoSource implementation,
CheckForUpdate, ApplyUpdate, SetVersion. No GUI/cobra/Wails imports
in the package, so the same code is reusable from the CLI, the
Windows service, and the future tray UI.

Wires the package into "drover update" / "drover update --check-only"
in cmd/drover/main.go. --check-only exits 0 whether or not an update
is available; only network/sha/apply errors are non-zero.

Tests cover CheckForUpdate (table-driven incl. semver pre-release
ordering, dev fallthrough, source errors), parseSHA256Sums (text and
binary modes, CRLF, malformed lines, missing entries),
ForgejoSource.Latest (httptest with canned JSON, 404, 500, missing
asset, missing SHA256SUMS), and downloadAndVerify (success, sha
mismatch, HTTP 404, context cancellation). All run with -race.

Smoke-tested manually: built drover.exe and "drover update --check-only"
against git.okcu.io prints "No updates available" and exits 0 (no
releases yet).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-01 00:20:24 +03:00
parent 25df64213c
commit 1ad8de32f2
6 changed files with 1155 additions and 5 deletions
+479
View File
@@ -0,0 +1,479 @@
// Package updater performs self-update via the Forgejo Releases API.
//
// The package is split into:
//
// - Source: an interface that fetches "the latest release" from somewhere.
// ForgejoSource is the production implementation; tests use a fake.
// - CheckForUpdate: compares a release tag against a current version using
// semver-aware ordering (golang.org/x/mod/semver).
// - ApplyUpdate: downloads the asset + SHA256SUMS.txt, verifies the binary
// against its hash, and atomically replaces the running executable via
// github.com/minio/selfupdate.
//
// The package is GUI- and CLI-framework-agnostic; importers should not pull
// in cobra/Wails as a transitive dependency.
package updater
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"path"
"strings"
"time"
"github.com/minio/selfupdate"
"golang.org/x/mod/semver"
)
// version is exposed for tests/User-Agent. Caller-overridable.
var version = "dev"
// SetVersion lets the calling binary inject its build version into the
// User-Agent header used by HTTP requests. Optional; defaults to "dev".
func SetVersion(v string) { version = v }
// Release is the minimal subset of a Forgejo release we care about.
type Release struct {
TagName string // e.g. "v0.1.2"
Name string // human-readable name (often equals TagName)
AssetURL string // browser_download_url of the windows-amd64 exe asset
SumsURL string // browser_download_url of SHA256SUMS.txt
Prerelease bool
CreatedAt time.Time
}
// Source abstracts "where do releases come from" so callers can swap real
// HTTP sources for fakes in tests.
type Source interface {
// Latest returns the latest non-draft release, or (nil, nil) if there
// are no releases yet (e.g. Forgejo returned 404 from /releases/latest).
// Errors should be returned wrapped with %w for any genuine failure.
Latest(ctx context.Context) (*Release, error)
}
// CheckForUpdate compares the latest release's tag name against
// currentVersion. Returns (release, true, nil) if the latest is strictly
// newer; (nil, false, nil) if the current is up-to-date or newer; or an
// error wrapping the source's error.
//
// If currentVersion fails semver.IsValid (e.g. it is "dev"), CheckForUpdate
// treats every available release as an update so dev builds get notified.
func CheckForUpdate(ctx context.Context, src Source, currentVersion string) (*Release, bool, error) {
rel, err := src.Latest(ctx)
if err != nil {
return nil, false, fmt.Errorf("checking latest release: %w", err)
}
if rel == nil {
return nil, false, nil
}
if !isNewer(rel.TagName, currentVersion) {
return nil, false, nil
}
return rel, true, nil
}
// isNewer reports whether tag is strictly newer than current under semver
// ordering. Both inputs are normalized to a leading "v". An invalid current
// version is treated as "older than anything" so dev builds always update.
func isNewer(tag, current string) bool {
tagV := ensureV(tag)
curV := ensureV(current)
if !semver.IsValid(tagV) {
// Remote tag is malformed; refuse to update.
return false
}
if !semver.IsValid(curV) {
// Current version is "dev" or otherwise invalid → always update.
return true
}
return semver.Compare(tagV, curV) > 0
}
func ensureV(s string) string {
if strings.HasPrefix(s, "v") {
return s
}
return "v" + s
}
// ---------------------------------------------------------------------------
// SHA256SUMS parsing
// ---------------------------------------------------------------------------
// parseSHA256Sums returns the lowercase-hex SHA-256 digest associated with
// targetName from a sha256sum-formatted SUMS file. The format is one line
// per file: "<64-hex> <name>" (text mode, two spaces) or "<64-hex> *<name>"
// (binary mode, space + asterisk). Both forms are accepted.
//
// Returns an error if the target is missing or the file is malformed.
func parseSHA256Sums(data []byte, targetName string) (string, error) {
scanner := bufio.NewScanner(bytes.NewReader(data))
// Allow long lines (default is 64 KiB which is plenty for sums files).
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
lineNum := 0
for scanner.Scan() {
lineNum++
line := strings.TrimRight(scanner.Text(), "\r")
line = strings.TrimSpace(line)
if line == "" {
continue
}
// Find the first space; everything before is hash, everything after
// (with optional leading '*') is the filename.
idx := strings.IndexByte(line, ' ')
if idx < 0 {
return "", fmt.Errorf("malformed SHA256SUMS line %d: missing separator", lineNum)
}
hashHex := line[:idx]
rest := strings.TrimLeft(line[idx+1:], " ")
rest = strings.TrimPrefix(rest, "*")
if len(hashHex) != 64 {
return "", fmt.Errorf("malformed SHA256SUMS line %d: hash is %d chars, want 64", lineNum, len(hashHex))
}
if _, err := hex.DecodeString(hashHex); err != nil {
return "", fmt.Errorf("malformed SHA256SUMS line %d: hash is not valid hex: %w", lineNum, err)
}
if rest == targetName {
return strings.ToLower(hashHex), nil
}
}
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("scanning SHA256SUMS: %w", err)
}
return "", fmt.Errorf("no SHA256SUMS entry for %q", targetName)
}
// ---------------------------------------------------------------------------
// Forgejo source
// ---------------------------------------------------------------------------
// ForgejoSource talks to a Forgejo instance's Releases API.
type ForgejoSource struct {
baseURL string // e.g. "https://git.okcu.io"
owner string
repo string
assetPattern string // substring matched against asset name
client *http.Client
userAgent string
}
// NewForgejoSource constructs a Source that talks to host (the Forgejo
// hostname, e.g. "git.okcu.io"). assetPattern is a substring matched against
// release asset names (e.g. "windows-amd64.exe").
func NewForgejoSource(host, owner, repo, assetPattern string) Source {
return &ForgejoSource{
baseURL: "https://" + strings.TrimRight(host, "/"),
owner: owner,
repo: repo,
assetPattern: assetPattern,
client: defaultHTTPClient(),
userAgent: "drover-go/" + version,
}
}
// defaultHTTPClient returns a context-aware http.Client with a 15s connect
// timeout but no overall request deadline (asset downloads may be large).
func defaultHTTPClient() *http.Client {
dialer := &net.Dialer{
Timeout: 15 * time.Second,
KeepAlive: 30 * time.Second,
}
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 15 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
return &http.Client{Transport: tr}
}
// forgejoRelease mirrors the JSON returned by /api/v1/repos/.../releases/latest.
type forgejoRelease struct {
TagName string `json:"tag_name"`
Name string `json:"name"`
Draft bool `json:"draft"`
Prerelease bool `json:"prerelease"`
CreatedAt time.Time `json:"created_at"`
Assets []forgejoAsset `json:"assets"`
}
type forgejoAsset struct {
Name string `json:"name"`
Size int64 `json:"size"`
DownloadURL string `json:"browser_download_url"`
}
// Latest fetches the latest release. Returns (nil, nil) on 404 (no releases).
func (s *ForgejoSource) Latest(ctx context.Context) (*Release, error) {
endpoint, err := url.JoinPath(s.baseURL, "api", "v1", "repos", s.owner, s.repo, "releases", "latest")
if err != nil {
return nil, fmt.Errorf("building releases URL: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("building request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", s.userAgent)
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("GET %s: %w", endpoint, err)
}
defer resp.Body.Close()
switch {
case resp.StatusCode == http.StatusNotFound:
// No public releases exist.
return nil, nil
case resp.StatusCode < 200 || resp.StatusCode >= 300:
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, fmt.Errorf("GET %s: HTTP %d: %s", endpoint, resp.StatusCode, strings.TrimSpace(string(body)))
}
var raw forgejoRelease
if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil {
return nil, fmt.Errorf("decoding release JSON: %w", err)
}
if raw.Draft {
// Forgejo's /latest is supposed to skip drafts, but be defensive.
return nil, nil
}
assetURL, assetName := pickAsset(raw.Assets, s.assetPattern)
if assetURL == "" {
return nil, fmt.Errorf("release %s has no asset matching %q", raw.TagName, s.assetPattern)
}
sumsURL := pickAssetByExactName(raw.Assets, "SHA256SUMS.txt")
if sumsURL == "" {
return nil, fmt.Errorf("release %s is missing SHA256SUMS.txt asset", raw.TagName)
}
_ = assetName // currently unused outside the source; reserved for future logging
return &Release{
TagName: raw.TagName,
Name: raw.Name,
AssetURL: assetURL,
SumsURL: sumsURL,
Prerelease: raw.Prerelease,
CreatedAt: raw.CreatedAt,
}, nil
}
func pickAsset(assets []forgejoAsset, pattern string) (string, string) {
for _, a := range assets {
if strings.Contains(a.Name, pattern) {
return a.DownloadURL, a.Name
}
}
return "", ""
}
func pickAssetByExactName(assets []forgejoAsset, name string) string {
for _, a := range assets {
if a.Name == name {
return a.DownloadURL
}
}
return ""
}
// ---------------------------------------------------------------------------
// ApplyUpdate
// ---------------------------------------------------------------------------
// ApplyUpdate downloads the release's asset + SHA256SUMS.txt, verifies the
// asset against its sha256 entry, and atomically replaces the current
// executable using github.com/minio/selfupdate.
//
// progress, if non-nil, is called periodically with (downloaded, total)
// while the asset is being downloaded. total is set from the Content-Length
// header and may be 0 if the server omits it.
func ApplyUpdate(ctx context.Context, rel *Release, progress func(downloaded, total int64)) error {
if rel == nil {
return errors.New("apply update: release is nil")
}
if rel.AssetURL == "" || rel.SumsURL == "" {
return errors.New("apply update: release is missing AssetURL or SumsURL")
}
assetName := assetNameFromURL(rel.AssetURL)
if assetName == "" {
return fmt.Errorf("apply update: cannot derive asset filename from URL %q", rel.AssetURL)
}
client := defaultHTTPClient()
binary, err := downloadAndVerify(ctx, rel, assetName, client, "drover-go/"+version, progress)
if err != nil {
return fmt.Errorf("apply update: %w", err)
}
// selfupdate.Apply takes an io.Reader and atomically swaps the running
// binary. We pass a bytes.Reader since we already have the verified
// bytes in memory.
if err := selfupdate.Apply(bytes.NewReader(binary), selfupdate.Options{}); err != nil {
// On failure, selfupdate may leave a .old file; surface that fact
// to the caller via the wrapped error.
if rerr := selfupdate.RollbackError(err); rerr != nil {
return fmt.Errorf("apply update: replacing executable failed and rollback failed: %w (rollback: %v)", err, rerr)
}
return fmt.Errorf("apply update: replacing executable: %w", err)
}
return nil
}
// assetNameFromURL returns the last path segment of u (used to look the file
// up in SHA256SUMS.txt). Returns "" if u does not parse.
func assetNameFromURL(u string) string {
parsed, err := url.Parse(u)
if err != nil {
return ""
}
return path.Base(parsed.Path)
}
// downloadAndVerify is the I/O-heavy core of ApplyUpdate, factored out so
// it's testable without invoking selfupdate.Apply (which would replace the
// running test binary). Returns the verified asset bytes.
func downloadAndVerify(
ctx context.Context,
rel *Release,
assetName string,
client *http.Client,
userAgent string,
progress func(downloaded, total int64),
) ([]byte, error) {
// 1. Fetch SHA256SUMS.txt first — fail fast if it's missing/wrong before
// we spend time on a possibly-large binary.
sumsBody, err := httpGetAll(ctx, client, rel.SumsURL, userAgent)
if err != nil {
return nil, fmt.Errorf("downloading SHA256SUMS.txt: %w", err)
}
wantHashHex, err := parseSHA256Sums(sumsBody, assetName)
if err != nil {
return nil, fmt.Errorf("reading SHA256SUMS.txt: %w", err)
}
wantHash, err := hex.DecodeString(wantHashHex)
if err != nil {
return nil, fmt.Errorf("decoding expected sha256 hex: %w", err)
}
// 2. Fetch the asset, hashing as we go.
binBody, err := httpGetWithProgress(ctx, client, rel.AssetURL, userAgent, progress)
if err != nil {
return nil, fmt.Errorf("downloading asset %s: %w", assetName, err)
}
// 3. Verify.
gotHash := sha256.Sum256(binBody)
if !bytes.Equal(gotHash[:], wantHash) {
return nil, fmt.Errorf("sha256 checksum mismatch for %s: got %x, want %s", assetName, gotHash[:], wantHashHex)
}
return binBody, nil
}
func httpGetAll(ctx context.Context, client *http.Client, rawurl, userAgent string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawurl, nil)
if err != nil {
return nil, fmt.Errorf("building request: %w", err)
}
req.Header.Set("User-Agent", userAgent)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, fmt.Errorf("GET %s: HTTP %d: %s", rawurl, resp.StatusCode, strings.TrimSpace(string(body)))
}
// Cap at 1 MiB — SHA256SUMS.txt is tiny; this is for the sums path only.
// Asset downloads use httpGetWithProgress.
return io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
}
func httpGetWithProgress(
ctx context.Context,
client *http.Client,
rawurl, userAgent string,
progress func(downloaded, total int64),
) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawurl, nil)
if err != nil {
return nil, fmt.Errorf("building request: %w", err)
}
req.Header.Set("User-Agent", userAgent)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, fmt.Errorf("GET %s: HTTP %d: %s", rawurl, resp.StatusCode, strings.TrimSpace(string(body)))
}
total := resp.ContentLength // -1 if unknown
var totalReport int64
if total > 0 {
totalReport = total
}
// Cap a reasonable absolute ceiling so a malicious server can't OOM us.
// 256 MiB is far above any realistic drover-go binary.
const maxBytes = 256 * 1024 * 1024
body := resp.Body
if total > 0 && total > maxBytes {
return nil, fmt.Errorf("asset Content-Length %d exceeds %d byte cap", total, maxBytes)
}
pr := &progressReader{
r: io.LimitReader(body, maxBytes),
total: totalReport,
callback: progress,
}
buf, err := io.ReadAll(pr)
if err != nil {
return nil, fmt.Errorf("reading body: %w", err)
}
// Final progress tick if total was unknown (so callers get a "done" signal).
if progress != nil && totalReport == 0 {
progress(int64(len(buf)), int64(len(buf)))
}
return buf, nil
}
// progressReader wraps an io.Reader and emits progress updates after each
// successful Read. It is allocation-light and never blocks; callbacks are
// invoked synchronously from Read so they should be cheap.
type progressReader struct {
r io.Reader
total int64
read int64
callback func(downloaded, total int64)
}
func (p *progressReader) Read(b []byte) (int, error) {
n, err := p.r.Read(b)
if n > 0 {
p.read += int64(n)
if p.callback != nil {
p.callback(p.read, p.total)
}
}
return n, err
}