// Package updater performs self-update via the Forgejo Releases API.
//
// The package is split into:
//
//   - Source: an interface that fetches "the latest release" from somewhere.
//     ForgejoSource is the production implementation; tests use a fake.
//   - CheckForUpdate: compares a release tag against a current version using
//     semver-aware ordering (golang.org/x/mod/semver).
//   - ApplyUpdate: downloads the asset + SHA256SUMS.txt, verifies the binary
//     against its hash, and atomically replaces the running executable via
//     github.com/minio/selfupdate.
//
// The package is GUI- and CLI-framework-agnostic; importers should not pull
// in cobra/Wails as a transitive dependency.
package updater

import (
	"bufio"
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"path"
	"strings"
	"time"

	"github.com/minio/selfupdate"
	"golang.org/x/mod/semver"
)

// version is the build version appended to the "drover-go/" User-Agent
// string sent with HTTP requests. It defaults to "dev" and is overridden
// through SetVersion.
var version = "dev"

// SetVersion lets the calling binary inject its build version into the
// User-Agent header used by HTTP requests. Optional; defaults to "dev".
//
// NewForgejoSource copies the version into the source when it is
// constructed, so call SetVersion before creating a source if the source's
// requests should carry the new value. ApplyUpdate reads the version at call
// time.
func SetVersion(v string) {
	version = v
}

// Release is the minimal subset of a Forgejo release we care about.
type Release struct {
	TagName    string // e.g. "v0.1.2"
	Name       string // human-readable name (often equals TagName)
	AssetURL   string // browser_download_url of the asset matching the source's pattern (e.g. the windows-amd64 exe)
	SumsURL    string // browser_download_url of SHA256SUMS.txt
	Prerelease bool
	CreatedAt  time.Time
}

// Source abstracts "where do releases come from" so callers can swap real
// HTTP sources for fakes in tests.
type Source interface {
	// Latest returns the latest non-draft release, or (nil, nil) if there
	// are no releases yet (e.g. Forgejo returned 404 from /releases/latest).
	// Errors should be returned wrapped with %w for any genuine failure.
	Latest(ctx context.Context) (*Release, error)
}

// CheckForUpdate compares the latest release's tag name against
// currentVersion. Returns (release, true, nil) if the latest is strictly
// newer; (nil, false, nil) if the current is up-to-date or newer, if the
// source reports no release at all, or if the release tag is not valid
// semver; or an error wrapping the source's error.
//
// If currentVersion fails semver.IsValid (e.g. it is "dev"), CheckForUpdate
// treats every available release with a valid tag as an update so dev builds
// get notified.
func CheckForUpdate(ctx context.Context, src Source, currentVersion string) (*Release, bool, error) {
	rel, err := src.Latest(ctx)
	if err != nil {
		return nil, false, fmt.Errorf("checking latest release: %w", err)
	}
	if rel == nil {
		// The source has nothing published; that is not an error.
		return nil, false, nil
	}
	if !isNewer(rel.TagName, currentVersion) {
		return nil, false, nil
	}
	return rel, true, nil
}

// isNewer reports whether tag is strictly newer than current under semver
// ordering. Both inputs are normalized to a leading "v". An invalid current
// version is treated as "older than anything" so dev builds always update.
// An invalid tag is never considered newer, and that check takes precedence
// over the invalid-current rule.
func isNewer(tag, current string) bool {
	tagV := ensureV(tag)
	curV := ensureV(current)
	if !semver.IsValid(tagV) {
		// Remote tag is malformed; refuse to update.
		return false
	}
	if !semver.IsValid(curV) {
		// Current version is "dev" or otherwise invalid → always update.
		return true
	}
	return semver.Compare(tagV, curV) > 0
}

// ensureV returns s unchanged when it already starts with a lowercase "v",
// and "v"+s otherwise, which is the prefix golang.org/x/mod/semver requires.
func ensureV(s string) string {
	if strings.HasPrefix(s, "v") {
		return s
	}
	return "v" + s
}

// ---------------------------------------------------------------------------
// SHA256SUMS parsing
// ---------------------------------------------------------------------------

// parseSHA256Sums returns the lowercase-hex SHA-256 digest associated with
// targetName from a sha256sum-formatted SUMS file. The format is one line
// per file: "<64-hex>  <filename>" (text mode, two spaces) or
// "<64-hex> *<filename>" (binary mode, space + asterisk). Both forms are
// accepted.
//
// Returns an error if the target is missing or the file is malformed. Lines
// are validated in order, so a malformed line that precedes the target's
// entry also produces an error.
func parseSHA256Sums(data []byte, targetName string) (string, error) {
	sc := bufio.NewScanner(bytes.NewReader(data))
	// Start with a 64 KiB buffer and allow it to grow to 1 MiB per line.
	sc.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	// n is the 1-based number of the line currently being examined; blank
	// lines are counted too so error messages match the file.
	for n := 1; sc.Scan(); n++ {
		// TrimSpace also removes a trailing '\r' left over from CRLF files.
		entry := strings.TrimSpace(sc.Text())
		if entry == "" {
			continue
		}

		// The digest is everything before the first space; the remainder is
		// the file name, possibly preceded by more spaces and a '*'.
		digest, name, found := strings.Cut(entry, " ")
		if !found {
			return "", fmt.Errorf("malformed SHA256SUMS line %d: missing separator", n)
		}
		if len(digest) != 64 {
			return "", fmt.Errorf("malformed SHA256SUMS line %d: hash is %d chars, want 64", n, len(digest))
		}
		if _, err := hex.DecodeString(digest); err != nil {
			return "", fmt.Errorf("malformed SHA256SUMS line %d: hash is not valid hex: %w", n, err)
		}

		name = strings.TrimPrefix(strings.TrimLeft(name, " "), "*")
		if name != targetName {
			continue
		}
		return strings.ToLower(digest), nil
	}

	if err := sc.Err(); err != nil {
		return "", fmt.Errorf("scanning SHA256SUMS: %w", err)
	}
	return "", fmt.Errorf("no SHA256SUMS entry for %q", targetName)
}

// ---------------------------------------------------------------------------
// Forgejo source
// ---------------------------------------------------------------------------

// ForgejoSource fetches release metadata from a Forgejo instance's Releases
// API for a single owner/repo pair.
type ForgejoSource struct {
	baseURL      string // scheme + host, e.g. "https://git.okcu.io"
	owner, repo  string // repository coordinates used in the API path
	assetPattern string // substring matched against asset names
	client       *http.Client
	userAgent    string // value sent in the User-Agent request header
}

// NewForgejoSource constructs a Source that talks to host (the Forgejo
// hostname, e.g. "git.okcu.io") over HTTPS. assetPattern is a substring
// matched against release asset names (e.g. "windows-amd64.exe").
func NewForgejoSource(host, owner, repo, assetPattern string) Source {
	src := new(ForgejoSource)
	// Strip trailing slashes from host before prefixing the scheme.
	src.baseURL = "https://" + strings.TrimRight(host, "/")
	src.owner = owner
	src.repo = repo
	src.assetPattern = assetPattern
	src.client = defaultHTTPClient()
	// The package-level version is captured here, at construction time.
	src.userAgent = "drover-go/" + version
	return src
}

// defaultHTTPClient builds the http.Client used for release lookups and
// downloads. Connecting and the TLS handshake are each limited to 15s, but
// the client itself carries no overall Timeout (asset downloads may be
// large).
func defaultHTTPClient() *http.Client {
	const setupTimeout = 15 * time.Second

	dial := (&net.Dialer{
		Timeout:   setupTimeout,
		KeepAlive: 30 * time.Second,
	}).DialContext

	return &http.Client{
		Transport: &http.Transport{
			Proxy:                 http.ProxyFromEnvironment,
			DialContext:           dial,
			ForceAttemptHTTP2:     true,
			MaxIdleConns:          10,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   setupTimeout,
			ExpectContinueTimeout: 1 * time.Second,
		},
	}
}

// forgejoRelease is the decoding target for the JSON document returned by
// /api/v1/repos/.../releases/latest. Only the fields listed here are read.
type forgejoRelease struct {
	TagName    string         `json:"tag_name"`
	Name       string         `json:"name"`
	Draft      bool           `json:"draft"`
	Prerelease bool           `json:"prerelease"`
	CreatedAt  time.Time      `json:"created_at"`
	Assets     []forgejoAsset `json:"assets"`
}

// forgejoAsset is one element of a release's "assets" array.
type forgejoAsset struct {
	Name        string `json:"name"`
	Size        int64  `json:"size"`
	DownloadURL string `json:"browser_download_url"`
}

// Latest fetches the latest release. Returns (nil, nil) on 404 (no releases)
// and when the returned release is marked as a draft.
func (s *ForgejoSource) Latest(ctx context.Context) (*Release, error) { endpoint, err := url.JoinPath(s.baseURL, "api", "v1", "repos", s.owner, s.repo, "releases", "latest") if err != nil { return nil, fmt.Errorf("building releases URL: %w", err) } req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("building request: %w", err) } req.Header.Set("Accept", "application/json") req.Header.Set("User-Agent", s.userAgent) resp, err := s.client.Do(req) if err != nil { return nil, fmt.Errorf("GET %s: %w", endpoint, err) } defer resp.Body.Close() switch { case resp.StatusCode == http.StatusNotFound: // No public releases exist. return nil, nil case resp.StatusCode < 200 || resp.StatusCode >= 300: body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) return nil, fmt.Errorf("GET %s: HTTP %d: %s", endpoint, resp.StatusCode, strings.TrimSpace(string(body))) } var raw forgejoRelease if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil { return nil, fmt.Errorf("decoding release JSON: %w", err) } if raw.Draft { // Forgejo's /latest is supposed to skip drafts, but be defensive. 
return nil, nil } assetURL, assetName := pickAsset(raw.Assets, s.assetPattern) if assetURL == "" { return nil, fmt.Errorf("release %s has no asset matching %q", raw.TagName, s.assetPattern) } sumsURL := pickAssetByExactName(raw.Assets, "SHA256SUMS.txt") if sumsURL == "" { return nil, fmt.Errorf("release %s is missing SHA256SUMS.txt asset", raw.TagName) } _ = assetName // currently unused outside the source; reserved for future logging return &Release{ TagName: raw.TagName, Name: raw.Name, AssetURL: assetURL, SumsURL: sumsURL, Prerelease: raw.Prerelease, CreatedAt: raw.CreatedAt, }, nil } func pickAsset(assets []forgejoAsset, pattern string) (string, string) { for _, a := range assets { if strings.Contains(a.Name, pattern) { return a.DownloadURL, a.Name } } return "", "" } func pickAssetByExactName(assets []forgejoAsset, name string) string { for _, a := range assets { if a.Name == name { return a.DownloadURL } } return "" } // --------------------------------------------------------------------------- // ApplyUpdate // --------------------------------------------------------------------------- // ApplyUpdate downloads the release's asset + SHA256SUMS.txt, verifies the // asset against its sha256 entry, and atomically replaces the current // executable using github.com/minio/selfupdate. // // progress, if non-nil, is called periodically with (downloaded, total) // while the asset is being downloaded. total is set from the Content-Length // header and may be 0 if the server omits it. 
func ApplyUpdate(ctx context.Context, rel *Release, progress func(downloaded, total int64)) error { if rel == nil { return errors.New("apply update: release is nil") } if rel.AssetURL == "" || rel.SumsURL == "" { return errors.New("apply update: release is missing AssetURL or SumsURL") } assetName := assetNameFromURL(rel.AssetURL) if assetName == "" { return fmt.Errorf("apply update: cannot derive asset filename from URL %q", rel.AssetURL) } client := defaultHTTPClient() binary, err := downloadAndVerify(ctx, rel, assetName, client, "drover-go/"+version, progress) if err != nil { return fmt.Errorf("apply update: %w", err) } // selfupdate.Apply takes an io.Reader and atomically swaps the running // binary. We pass a bytes.Reader since we already have the verified // bytes in memory. if err := selfupdate.Apply(bytes.NewReader(binary), selfupdate.Options{}); err != nil { // On failure, selfupdate may leave a .old file; surface that fact // to the caller via the wrapped error. if rerr := selfupdate.RollbackError(err); rerr != nil { return fmt.Errorf("apply update: replacing executable failed and rollback failed: %w (rollback: %v)", err, rerr) } return fmt.Errorf("apply update: replacing executable: %w", err) } return nil } // assetNameFromURL returns the last path segment of u (used to look the file // up in SHA256SUMS.txt). Returns "" if u does not parse. func assetNameFromURL(u string) string { parsed, err := url.Parse(u) if err != nil { return "" } return path.Base(parsed.Path) } // downloadAndVerify is the I/O-heavy core of ApplyUpdate, factored out so // it's testable without invoking selfupdate.Apply (which would replace the // running test binary). Returns the verified asset bytes. func downloadAndVerify( ctx context.Context, rel *Release, assetName string, client *http.Client, userAgent string, progress func(downloaded, total int64), ) ([]byte, error) { // 1. 
Fetch SHA256SUMS.txt first — fail fast if it's missing/wrong before // we spend time on a possibly-large binary. sumsBody, err := httpGetAll(ctx, client, rel.SumsURL, userAgent) if err != nil { return nil, fmt.Errorf("downloading SHA256SUMS.txt: %w", err) } wantHashHex, err := parseSHA256Sums(sumsBody, assetName) if err != nil { return nil, fmt.Errorf("reading SHA256SUMS.txt: %w", err) } wantHash, err := hex.DecodeString(wantHashHex) if err != nil { return nil, fmt.Errorf("decoding expected sha256 hex: %w", err) } // 2. Fetch the asset, hashing as we go. binBody, err := httpGetWithProgress(ctx, client, rel.AssetURL, userAgent, progress) if err != nil { return nil, fmt.Errorf("downloading asset %s: %w", assetName, err) } // 3. Verify. gotHash := sha256.Sum256(binBody) if !bytes.Equal(gotHash[:], wantHash) { return nil, fmt.Errorf("sha256 checksum mismatch for %s: got %x, want %s", assetName, gotHash[:], wantHashHex) } return binBody, nil } func httpGetAll(ctx context.Context, client *http.Client, rawurl, userAgent string) ([]byte, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawurl, nil) if err != nil { return nil, fmt.Errorf("building request: %w", err) } req.Header.Set("User-Agent", userAgent) resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) return nil, fmt.Errorf("GET %s: HTTP %d: %s", rawurl, resp.StatusCode, strings.TrimSpace(string(body))) } // Cap at 1 MiB — SHA256SUMS.txt is tiny; this is for the sums path only. // Asset downloads use httpGetWithProgress. 
return io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) } func httpGetWithProgress( ctx context.Context, client *http.Client, rawurl, userAgent string, progress func(downloaded, total int64), ) ([]byte, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawurl, nil) if err != nil { return nil, fmt.Errorf("building request: %w", err) } req.Header.Set("User-Agent", userAgent) resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) return nil, fmt.Errorf("GET %s: HTTP %d: %s", rawurl, resp.StatusCode, strings.TrimSpace(string(body))) } total := resp.ContentLength // -1 if unknown var totalReport int64 if total > 0 { totalReport = total } // Cap a reasonable absolute ceiling so a malicious server can't OOM us. // 256 MiB is far above any realistic drover-go binary. const maxBytes = 256 * 1024 * 1024 body := resp.Body if total > 0 && total > maxBytes { return nil, fmt.Errorf("asset Content-Length %d exceeds %d byte cap", total, maxBytes) } pr := &progressReader{ r: io.LimitReader(body, maxBytes), total: totalReport, callback: progress, } buf, err := io.ReadAll(pr) if err != nil { return nil, fmt.Errorf("reading body: %w", err) } // Final progress tick if total was unknown (so callers get a "done" signal). if progress != nil && totalReport == 0 { progress(int64(len(buf)), int64(len(buf))) } return buf, nil } // progressReader wraps an io.Reader and emits progress updates after each // successful Read. It is allocation-light and never blocks; callbacks are // invoked synchronously from Read so they should be cheap. type progressReader struct { r io.Reader total int64 read int64 callback func(downloaded, total int64) } func (p *progressReader) Read(b []byte) (int, error) { n, err := p.r.Read(b) if n > 0 { p.read += int64(n) if p.callback != nil { p.callback(p.read, p.total) } } return n, err }