package update import ( "context" "crypto/sha256" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "os" "path" "path/filepath" "runtime" "sort" "strings" "time" ) const defaultAssetNameTemplate = "{binary}-{os}-{arch}{ext}" const defaultMaxDownloadBytes int64 = 200 * 1024 * 1024 type Options struct { Client *http.Client CurrentVersion string ExecutablePath string LatestReleaseURL string Stdout io.Writer BinaryName string AssetNameTemplate string ReleaseSource ReleaseSource GOOS string GOARCH string MaxDownloadBytes int64 ValidateDownloaded ValidateDownloadedFunc ReplaceExecutable ReplaceExecutableFunc } type ReplaceExecutableFunc func(downloadPath, targetPath string) error type ValidateDownloadedFunc func(context.Context, ValidationInput) error type ValidationInput struct { DownloadPath string TargetPath string AssetName string ReleaseTag string ReleaseURL string Source ReleaseSource } type ReleaseSource struct { Name string Driver string Repository string BaseURL string LatestReleaseURL string AssetNameTemplate string ChecksumAssetName string ChecksumRequired bool Token string TokenHeader string TokenPrefix string TokenEnvNames []string } type Auth struct { Header string Token string } type Release struct { TagName string `json:"tag_name"` Assets struct { Links []ReleaseLink `json:"links"` } `json:"assets"` } type ReleaseLink struct { Name string `json:"name"` URL string `json:"url"` } type releasePayload struct { TagName string `json:"tag_name"` Assets json.RawMessage `json:"assets"` } type releaseAssetsPayload struct { Links []releaseLinkPayload `json:"links"` } type releaseLinkPayload struct { Name string `json:"name"` URL string `json:"url"` BrowserDownloadURL string `json:"browser_download_url"` DirectAssetURL string `json:"direct_asset_url"` } func (r *Release) UnmarshalJSON(data []byte) error { var payload releasePayload if err := json.Unmarshal(data, &payload); err != nil { return err } r.TagName = strings.TrimSpace(payload.TagName) r.Assets.Links = parseReleaseLinks(payload.Assets) return nil } func Run(ctx context.Context, opts Options) error { if opts.Stdout == nil { opts.Stdout = io.Discard } if opts.Client == nil { opts.Client = &http.Client{Timeout: 60 * time.Second} } if strings.TrimSpace(opts.CurrentVersion) == "" { opts.CurrentVersion = "dev" } if strings.TrimSpace(opts.GOOS) == "" { opts.GOOS = runtime.GOOS } if strings.TrimSpace(opts.GOARCH) == "" { opts.GOARCH = runtime.GOARCH } if opts.MaxDownloadBytes <= 0 { opts.MaxDownloadBytes = defaultMaxDownloadBytes } source := normalizeSource(opts.ReleaseSource) auth := ResolveAuth(source.Token, source) releaseURL, err := ResolveLatestReleaseURL(opts.LatestReleaseURL, source) if err != nil { return err } targetPath, err := ResolveUpdateTarget(opts.ExecutablePath) if err != nil { return err } assetTemplate := strings.TrimSpace(opts.AssetNameTemplate) if assetTemplate == "" { assetTemplate = source.AssetNameTemplate } assetName, err := AssetNameWithTemplate(opts.BinaryName, opts.GOOS, opts.GOARCH, assetTemplate) if err != nil { return err } release, err := FetchLatestRelease(ctx, opts.Client, releaseURL, auth, source) if err != nil { return err } if isCurrentRelease(opts.CurrentVersion, release.TagName) { fmt.Fprintf(opts.Stdout, "Already up to date (%s)\n", release.TagName) return nil } assetURL, err := release.AssetURL(assetName, releaseURL) if err != nil { return err } downloadPath, err := DownloadReleaseAsset(ctx, opts.Client, assetURL, targetPath, auth, source, opts.MaxDownloadBytes) if err != nil { return err } defer os.Remove(downloadPath) if err := VerifyReleaseAssetChecksum(ctx, opts.Client, release, releaseURL, assetName, downloadPath, auth, source); err != nil { return err } if opts.ValidateDownloaded != nil { if err := opts.ValidateDownloaded(ctx, ValidationInput{ DownloadPath: downloadPath, TargetPath: targetPath, AssetName: assetName, ReleaseTag: release.TagName, ReleaseURL: releaseURL, Source: source, }); err != nil { return fmt.Errorf("validate downloaded artifact: %w", err) } } replaceExecutable := opts.ReplaceExecutable if replaceExecutable == nil { replaceExecutable = ReplaceExecutable } if err := replaceExecutable(downloadPath, targetPath); err != nil { return err } fmt.Fprintf(opts.Stdout, "Updated %s to %s\n", targetPath, release.TagName) return nil } func ResolveLatestReleaseURL(explicit string, source ReleaseSource) (string, error) { if releaseURL := strings.TrimSpace(explicit); releaseURL != "" { return releaseURL, nil } source = normalizeSource(source) if source.LatestReleaseURL != "" { return source.LatestReleaseURL, nil } if source.Driver == "" { return "", errors.New("latest release URL must not be empty (set latest_release_url or configure driver+repository)") } if source.Repository == "" { return "", fmt.Errorf("release source %q requires repository when driver is set", source.Driver) } switch source.Driver { case "gitea": if source.BaseURL == "" { return "", errors.New("release source gitea requires base_url") } return fmt.Sprintf("%s/api/v1/repos/%s/releases/latest", source.BaseURL, source.Repository), nil case "gitlab": projectPath := url.PathEscape(source.Repository) return fmt.Sprintf("%s/api/v4/projects/%s/releases/permalink/latest", source.BaseURL, projectPath), nil case "github": return fmt.Sprintf("%s/repos/%s/releases/latest", source.BaseURL, source.Repository), nil default: return "", fmt.Errorf("unsupported release driver %q (expected gitea, gitlab or github)", source.Driver) } } func ResolveAuth(explicitToken string, source ReleaseSource) Auth { source = normalizeSource(source) if token := strings.TrimSpace(explicitToken); token != "" { return Auth{ Header: source.TokenHeader, Token: withTokenPrefix(token, source.TokenPrefix), } } for _, envName := range source.TokenEnvNames { if token := strings.TrimSpace(os.Getenv(envName)); token != "" { return Auth{ Header: source.TokenHeader, Token: withTokenPrefix(token, source.TokenPrefix), } } } return Auth{} } func ResolveUpdateTarget(explicitPath string) (string, error) { targetPath := strings.TrimSpace(explicitPath) if targetPath == "" { var err error targetPath, err = os.Executable() if err != nil { return "", fmt.Errorf("resolve executable path: %w", err) } } resolvedPath, err := filepath.EvalSymlinks(targetPath) if err != nil { return "", fmt.Errorf("resolve executable symlink %q: %w", targetPath, err) } return resolvedPath, nil } func AssetName(binaryName, goos, goarch string) (string, error) { return AssetNameWithTemplate(binaryName, goos, goarch, defaultAssetNameTemplate) } func AssetNameWithTemplate(binaryName, goos, goarch, template string) (string, error) { name := strings.TrimSpace(binaryName) if name == "" { return "", errors.New("binary name must not be empty") } osName := strings.ToLower(strings.TrimSpace(goos)) archName := strings.ToLower(strings.TrimSpace(goarch)) if osName == "" || archName == "" { return "", errors.New("goos and goarch must not be empty") } assetTemplate := strings.TrimSpace(template) if assetTemplate == "" { assetTemplate = defaultAssetNameTemplate } ext := "" if osName == "windows" { ext = ".exe" } replaced := strings.NewReplacer( "{binary}", name, "{os}", osName, "{arch}", archName, "{ext}", ext, ).Replace(assetTemplate) replaced = strings.TrimSpace(replaced) if replaced == "" { return "", errors.New("asset name template resolved to an empty value") } if strings.ContainsRune(replaced, '/') || strings.ContainsRune(replaced, '\\') { return "", fmt.Errorf("asset name %q must not contain path separators", replaced) } return replaced, nil } func FetchLatestRelease(ctx context.Context, client *http.Client, releaseURL string, auth Auth, source ReleaseSource) (Release, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, releaseURL, nil) if err != nil { return Release{}, fmt.Errorf("build latest release request: %w", err) } req.Header.Set("Accept", "application/json") req.Header.Set("User-Agent", "mcp updater") auth.apply(req) resp, err := client.Do(req) if err != nil { return Release{}, fmt.Errorf("fetch latest release metadata: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) if err := auth.maybeHint(resp.StatusCode, body, source); err != nil { return Release{}, fmt.Errorf("fetch latest release metadata: %w", err) } return Release{}, fmt.Errorf( "fetch latest release metadata: unexpected status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)), ) } var release Release if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { return Release{}, fmt.Errorf("decode latest release metadata: %w", err) } if strings.TrimSpace(release.TagName) == "" { return Release{}, errors.New("latest release metadata is missing tag_name") } return release, nil } func (r Release) AssetURL(assetName, releaseURL string) (string, error) { for _, link := range r.Assets.Links { if link.Name == assetName { if strings.TrimSpace(link.URL) == "" { return "", fmt.Errorf("release asset %q has no URL", assetName) } parsed, err := url.Parse(link.URL) if err != nil { return "", fmt.Errorf("parse release asset URL %q: %w", link.URL, err) } if parsed.IsAbs() { return parsed.String(), nil } baseURL, err := url.Parse(releaseURL) if err != nil { return "", fmt.Errorf("parse latest release URL %q: %w", releaseURL, err) } return baseURL.ResolveReference(parsed).String(), nil } } availableAssets := make([]string, 0, len(r.Assets.Links)) for _, link := range r.Assets.Links { if name := strings.TrimSpace(link.Name); name != "" { availableAssets = append(availableAssets, name) } } sort.Strings(availableAssets) if len(availableAssets) == 0 { return "", fmt.Errorf("latest release does not contain asset %q", assetName) } preview := availableAssets if len(preview) > 8 { preview = preview[:8] } return "", fmt.Errorf( "latest release does not contain asset %q (available: %s)", assetName, strings.Join(preview, ", "), ) } func DownloadReleaseAsset( ctx context.Context, client *http.Client, assetURL, targetPath string, auth Auth, source ReleaseSource, maxDownloadBytes int64, ) (string, error) { if maxDownloadBytes <= 0 { maxDownloadBytes = defaultMaxDownloadBytes } req, err := http.NewRequestWithContext(ctx, http.MethodGet, assetURL, nil) if err != nil { return "", fmt.Errorf("build artifact download request: %w", err) } req.Header.Set("User-Agent", "mcp updater") auth.apply(req) resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("download release artifact: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) if err := auth.maybeHint(resp.StatusCode, body, source); err != nil { return "", fmt.Errorf("download release artifact: %w", err) } return "", fmt.Errorf( "download release artifact: unexpected status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)), ) } if resp.ContentLength > 0 && resp.ContentLength > maxDownloadBytes { return "", fmt.Errorf( "download release artifact: content length %d exceeds limit %d bytes", resp.ContentLength, maxDownloadBytes, ) } existingInfo, err := os.Stat(targetPath) if err != nil { return "", fmt.Errorf("stat executable %q: %w", targetPath, err) } tempFile, err := os.CreateTemp(filepath.Dir(targetPath), filepath.Base(targetPath)+".download-*") if err != nil { return "", fmt.Errorf("create temporary file: %w", err) } tempPath := tempFile.Name() cleanup := func(copyErr error) (string, error) { tempFile.Close() os.Remove(tempPath) return "", copyErr } limited := &io.LimitedReader{R: resp.Body, N: maxDownloadBytes + 1} written, err := io.Copy(tempFile, limited) if err != nil { return cleanup(fmt.Errorf("write downloaded artifact: %w", err)) } if written > maxDownloadBytes { return cleanup(fmt.Errorf("write downloaded artifact: size exceeds limit %d bytes", maxDownloadBytes)) } if err := tempFile.Chmod(existingInfo.Mode().Perm()); err != nil { return cleanup(fmt.Errorf("set executable mode on downloaded artifact: %w", err)) } if err := tempFile.Close(); err != nil { os.Remove(tempPath) return "", fmt.Errorf("close downloaded artifact: %w", err) } return tempPath, nil } func VerifyReleaseAssetChecksum( ctx context.Context, client *http.Client, release Release, releaseURL string, assetName string, artifactPath string, auth Auth, source ReleaseSource, ) error { source = normalizeSource(source) checksumAssetName := resolveChecksumAssetName(assetName, source.ChecksumAssetName) checksumURL, err := release.AssetURL(checksumAssetName, releaseURL) if err != nil { if source.ChecksumRequired { return fmt.Errorf("checksum verification: %w", err) } return nil } checksumBody, err := downloadAssetBytes(ctx, client, checksumURL, auth, source) if err != nil { return fmt.Errorf("checksum verification: %w", err) } expected, err := parseChecksum(string(checksumBody), assetName) if err != nil { return fmt.Errorf("checksum verification: %w", err) } actual, err := fileSHA256(artifactPath) if err != nil { return fmt.Errorf("checksum verification: %w", err) } if !strings.EqualFold(expected, actual) { return fmt.Errorf( "checksum mismatch for asset %q: expected %s, got %s", assetName, expected, actual, ) } return nil } func ReplaceExecutable(downloadPath, targetPath string) error { if runtime.GOOS == "windows" { return errors.New("self-update is not supported on windows without a custom ReplaceExecutable hook") } if err := os.Rename(downloadPath, targetPath); err != nil { return fmt.Errorf("replace executable %q: %w", targetPath, err) } return nil } func normalizeSource(source ReleaseSource) ReleaseSource { source.Name = strings.TrimSpace(source.Name) source.Driver = strings.ToLower(strings.TrimSpace(source.Driver)) source.Repository = strings.Trim(strings.TrimSpace(source.Repository), "/") source.BaseURL = strings.TrimRight(strings.TrimSpace(source.BaseURL), "/") source.LatestReleaseURL = strings.TrimSpace(source.LatestReleaseURL) source.AssetNameTemplate = strings.TrimSpace(source.AssetNameTemplate) source.ChecksumAssetName = strings.TrimSpace(source.ChecksumAssetName) source.Token = strings.TrimSpace(source.Token) source.TokenHeader = strings.TrimSpace(source.TokenHeader) source.TokenPrefix = strings.TrimSpace(source.TokenPrefix) envNames := source.TokenEnvNames[:0] for _, envName := range source.TokenEnvNames { if trimmed := strings.TrimSpace(envName); trimmed != "" { envNames = append(envNames, trimmed) } } source.TokenEnvNames = envNames switch source.Driver { case "gitea": if source.Name == "" { source.Name = "Gitea releases" } if source.TokenHeader == "" { source.TokenHeader = "Authorization" } if source.TokenPrefix == "" { source.TokenPrefix = "token " } if len(source.TokenEnvNames) == 0 { source.TokenEnvNames = []string{"GITEA_TOKEN"} } case "gitlab": if source.Name == "" { source.Name = "GitLab releases" } if source.BaseURL == "" { source.BaseURL = "https://gitlab.com" } if source.TokenHeader == "" { source.TokenHeader = "PRIVATE-TOKEN" } if len(source.TokenEnvNames) == 0 { source.TokenEnvNames = []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"} } case "github": if source.Name == "" { source.Name = "GitHub releases" } if source.BaseURL == "" { source.BaseURL = "https://api.github.com" } if source.TokenHeader == "" { source.TokenHeader = "Authorization" } if source.TokenPrefix == "" { source.TokenPrefix = "Bearer " } if len(source.TokenEnvNames) == 0 { source.TokenEnvNames = []string{"GITHUB_TOKEN"} } } return source } func isCurrentRelease(currentVersion, latestTag string) bool { current := strings.TrimSpace(currentVersion) latest := strings.TrimSpace(latestTag) if latest == "" { return false } if current == "" || current == "dev" { return false } return current == latest } func (a Auth) apply(req *http.Request) { if strings.TrimSpace(a.Header) == "" || strings.TrimSpace(a.Token) == "" { return } req.Header.Set(a.Header, a.Token) } func (a Auth) maybeHint(statusCode int, body []byte, source ReleaseSource) error { source = normalizeSource(source) if strings.TrimSpace(a.Token) != "" || len(source.TokenEnvNames) == 0 { return nil } switch statusCode { case http.StatusUnauthorized, http.StatusForbidden, http.StatusNotFound: default: return nil } message := strings.ToLower(strings.TrimSpace(string(body))) if !strings.Contains(message, "project not found") && !strings.Contains(message, "not found") && !strings.Contains(message, "unauthorized") && !strings.Contains(message, "forbidden") { return nil } target := source.BaseURL if target == "" { target = "release endpoint" } name := source.Name if name == "" { name = "release" } if len(source.TokenEnvNames) == 1 { return fmt.Errorf( "%s access requires authentication on %s; set %s and retry", name, target, source.TokenEnvNames[0], ) } return fmt.Errorf( "%s access requires authentication on %s; set %s (or %s) and retry", name, target, source.TokenEnvNames[0], source.TokenEnvNames[1], ) } func parseReleaseLinks(raw json.RawMessage) []ReleaseLink { if len(raw) == 0 || strings.TrimSpace(string(raw)) == "null" { return nil } parseLinks := func(payload []releaseLinkPayload) []ReleaseLink { links := make([]ReleaseLink, 0, len(payload)) for _, item := range payload { name := strings.TrimSpace(item.Name) assetURL := firstNonEmpty(item.DirectAssetURL, item.BrowserDownloadURL, item.URL) if name == "" || strings.TrimSpace(assetURL) == "" { continue } links = append(links, ReleaseLink{ Name: name, URL: strings.TrimSpace(assetURL), }) } return links } var asObject releaseAssetsPayload if err := json.Unmarshal(raw, &asObject); err == nil && len(asObject.Links) > 0 { return parseLinks(asObject.Links) } var asArray []releaseLinkPayload if err := json.Unmarshal(raw, &asArray); err == nil && len(asArray) > 0 { return parseLinks(asArray) } return nil } func firstNonEmpty(values ...string) string { for _, value := range values { if trimmed := strings.TrimSpace(value); trimmed != "" { return trimmed } } return "" } func withTokenPrefix(token, prefix string) string { trimmedToken := strings.TrimSpace(token) if trimmedToken == "" { return "" } trimmedPrefix := strings.TrimSpace(prefix) if trimmedPrefix == "" { return trimmedToken } lowerToken := strings.ToLower(trimmedToken) lowerPrefix := strings.ToLower(trimmedPrefix) if strings.HasPrefix(lowerToken, lowerPrefix) { return trimmedToken } return trimmedPrefix + " " + trimmedToken } func resolveChecksumAssetName(assetName, configured string) string { value := strings.TrimSpace(configured) if value == "" { return assetName + ".sha256" } return strings.ReplaceAll(value, "{asset}", assetName) } func downloadAssetBytes(ctx context.Context, client *http.Client, assetURL string, auth Auth, source ReleaseSource) ([]byte, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, assetURL, nil) if err != nil { return nil, fmt.Errorf("build checksum download request: %w", err) } req.Header.Set("User-Agent", "mcp updater") auth.apply(req) resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("download checksum asset: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) if hint := auth.maybeHint(resp.StatusCode, body, source); hint != nil { return nil, fmt.Errorf("download checksum asset: %w", hint) } return nil, fmt.Errorf( "download checksum asset: unexpected status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)), ) } content, err := io.ReadAll(io.LimitReader(resp.Body, 256*1024)) if err != nil { return nil, fmt.Errorf("read checksum asset: %w", err) } return content, nil } func parseChecksum(content, assetName string) (string, error) { lines := strings.Split(content, "\n") fallbackSingle := "" for _, raw := range lines { line := strings.TrimSpace(raw) if line == "" || strings.HasPrefix(line, "#") { continue } if strings.HasPrefix(strings.ToUpper(line), "SHA256 (") { openIndex := strings.Index(line, "(") closeIndex := strings.LastIndex(line, ")") equalIndex := strings.LastIndex(line, "=") if openIndex >= 0 && closeIndex > openIndex && equalIndex > closeIndex { name := strings.TrimSpace(line[openIndex+1 : closeIndex]) hash := strings.TrimSpace(line[equalIndex+1:]) if isSHA256Hex(hash) && matchesAssetName(name, assetName) { return strings.ToLower(hash), nil } } continue } fields := strings.Fields(line) if len(fields) > 0 && isSHA256Hex(fields[0]) { if len(fields) == 1 { if fallbackSingle == "" { fallbackSingle = strings.ToLower(fields[0]) } continue } name := strings.TrimSpace(strings.TrimPrefix(fields[1], "*")) if matchesAssetName(name, assetName) { return strings.ToLower(fields[0]), nil } continue } colonIndex := strings.Index(line, ":") if colonIndex > 0 && colonIndex < len(line)-1 { left := strings.TrimSpace(line[:colonIndex]) right := strings.TrimSpace(line[colonIndex+1:]) switch { case isSHA256Hex(left) && matchesAssetName(right, assetName): return strings.ToLower(left), nil case isSHA256Hex(right) && matchesAssetName(left, assetName): return strings.ToLower(right), nil } } } if fallbackSingle != "" { return fallbackSingle, nil } return "", fmt.Errorf("checksum file does not contain a sha256 for asset %q", assetName) } func matchesAssetName(candidate, assetName string) bool { name := strings.TrimSpace(strings.TrimPrefix(candidate, "*")) name = strings.TrimPrefix(name, "./") if name == assetName { return true } name = strings.ReplaceAll(name, "\\", "/") if path.Base(name) == assetName { return true } return false } func isSHA256Hex(value string) bool { if len(value) != 64 { return false } for _, r := range value { switch { case r >= '0' && r <= '9': case r >= 'a' && r <= 'f': case r >= 'A' && r <= 'F': default: return false } } return true } func fileSHA256(path string) (string, error) { file, err := os.Open(path) if err != nil { return "", fmt.Errorf("open downloaded artifact: %w", err) } defer file.Close() hash := sha256.New() if _, err := io.Copy(hash, file); err != nil { return "", fmt.Errorf("hash downloaded artifact: %w", err) } return hex.EncodeToString(hash.Sum(nil)), nil }