// Package update implements self-update for a command-line binary. It looks
// up the latest release from a configurable release source (with built-in
// defaults for Gitea, GitLab and GitHub drivers), downloads the asset for the
// selected platform, verifies it (HTML sniffing, optional SHA-256 checksum,
// optional Ed25519 signature, optional caller hook) and replaces the
// executable on disk.
package update

import (
	"bytes"
	"context"
	"crypto/ed25519"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"runtime"
	"sort"
	"strings"
	"time"
)

// defaultAssetNameTemplate is presumably the fallback asset-name template;
// its use is not visible in this section of the file (NOTE(review): confirm).
// The {binary}, {os}, {arch} and {ext} placeholders are presumably expanded
// by AssetNameWithTemplate.
const defaultAssetNameTemplate = "{binary}-{os}-{arch}{ext}"

// defaultMaxDownloadBytes (200 MiB) caps the size of a downloaded release
// artifact when no positive limit is configured.
const defaultMaxDownloadBytes int64 = 200 * 1024 * 1024

// downloadedArtifactSniffBytes is how many leading bytes of a downloaded
// artifact are inspected when checking whether it is really an HTML page.
const downloadedArtifactSniffBytes = 4096

// Options configures a single call to Run. Run substitutes defaults for zero
// values as noted on each field.
type Options struct {
	// Client performs every HTTP request. Nil means a client with a
	// 60-second timeout.
	Client *http.Client
	// CurrentVersion is the version of the running binary. Blank means
	// "dev", which is never considered up to date.
	CurrentVersion string
	// ExecutablePath is the executable to replace; Run resolves it with
	// ResolveUpdateTarget.
	ExecutablePath string
	// LatestReleaseURL is passed to ResolveLatestReleaseURL together with
	// the release source (presumably as an explicit override — confirm).
	LatestReleaseURL string
	// Stdout receives status messages. Nil means io.Discard.
	Stdout io.Writer
	// BinaryName is passed to AssetNameWithTemplate (presumably filling
	// the {binary} placeholder).
	BinaryName string
	// AssetNameTemplate overrides ReleaseSource.AssetNameTemplate when it
	// is not blank.
	AssetNameTemplate string
	// ReleaseSource describes the release provider plus authentication and
	// verification settings.
	ReleaseSource ReleaseSource
	// GOOS selects the target operating system. Blank means runtime.GOOS.
	GOOS string
	// GOARCH selects the target architecture. Blank means runtime.GOARCH.
	GOARCH string
	// MaxDownloadBytes caps the artifact size. Values <= 0 mean
	// defaultMaxDownloadBytes.
	MaxDownloadBytes int64
	// ValidateDownloaded is an optional extra check that runs after checksum
	// and signature verification and before the executable is replaced.
	ValidateDownloaded ValidateDownloadedFunc
	// ReplaceExecutable installs the verified download over the target.
	// Nil means the package-level ReplaceExecutable.
	ReplaceExecutable ReplaceExecutableFunc
}

// ReplaceExecutableFunc installs the verified download found at downloadPath
// as the executable at targetPath.
type ReplaceExecutableFunc func(downloadPath, targetPath string) error

// ValidateDownloadedFunc is a caller-supplied check on a downloaded artifact.
// A non-nil error aborts the update; Run wraps it with
// "validate downloaded artifact: ".
type ValidateDownloadedFunc func(context.Context, ValidationInput) error

// ValidationInput describes the artifact handed to a ValidateDownloadedFunc.
type ValidationInput struct {
	// DownloadPath is the temporary file holding the downloaded asset.
	DownloadPath string
	// TargetPath is the executable that will be replaced.
	TargetPath string
	// AssetName is the release asset that was downloaded.
	AssetName string
	// ReleaseTag is the tag of the release being installed.
	ReleaseTag string
	// ReleaseURL is the URL the release metadata was fetched from.
	ReleaseURL string
	// Source is the normalized release source.
	Source ReleaseSource
}

// ReleaseSource describes where releases come from and how downloads are
// authenticated and verified. normalizeSource trims every string field and
// fills driver-specific defaults.
type ReleaseSource struct {
	// Name is a display name used in authentication hints; defaulted per
	// driver.
	Name string
	// Driver selects defaults: "gitea", "gitlab" or "github"
	// (case-insensitive). Other values get no defaults.
	Driver string
	// Repository has surrounding slashes trimmed; presumably used to build
	// release URLs outside this section (confirm).
	Repository string
	// BaseURL is the provider base URL without trailing slashes. It
	// defaults to https://gitlab.com or https://api.github.com for those
	// drivers.
	BaseURL string
	// LatestReleaseURL is an explicit URL for the latest-release lookup
	// (presumably; consumed by ResolveLatestReleaseURL).
	LatestReleaseURL string
	// AssetNameTemplate is used when Options.AssetNameTemplate is blank.
	AssetNameTemplate string
	// ChecksumAssetName names the checksum asset; "{asset}" is replaced by
	// the artifact name. Blank means "<asset>.sha256".
	ChecksumAssetName string
	// ChecksumRequired makes a missing checksum asset a fatal error.
	ChecksumRequired bool
	// SignatureAssetName names the signature asset; "{asset}" is replaced
	// by the artifact name. Blank means "<asset>.sig".
	SignatureAssetName string
	// SignatureRequired makes a missing public key or signature asset a
	// fatal error.
	SignatureRequired bool
	// SignaturePublicKey is a hex- or base64-encoded Ed25519 public key.
	SignaturePublicKey string
	// SignaturePublicKeyEnvNames are environment variables consulted, in
	// order, when SignaturePublicKey is blank.
	SignaturePublicKeyEnvNames []string
	// Token is an explicit access token; Run passes it to ResolveAuth.
	Token string
	// TokenHeader is the auth header name ("Authorization" for gitea and
	// github, "PRIVATE-TOKEN" for gitlab by default).
	TokenHeader string
	// TokenPrefix is the auth scheme prefix ("token " for gitea,
	// "Bearer " for github by default).
	TokenPrefix string
	// TokenEnvNames are environment variables that may hold a token; they
	// are also named in authentication hints.
	TokenEnvNames []string
}

// Auth is a ready-to-send HTTP header: Header is the header name and Token
// is sent verbatim as its value. It is applied only when both are non-blank.
type Auth struct {
	Header string
	Token  string
}

// Release is the subset of a release API response used by the updater. Its
// UnmarshalJSON method accepts several layouts for the "assets" field.
type Release struct {
	TagName string `json:"tag_name"`
	Assets  struct {
		Links []ReleaseLink `json:"links"`
	} `json:"assets"`
}

// ReleaseLink is a named, downloadable release asset.
type ReleaseLink struct {
	Name string `json:"name"`
	URL  string `json:"url"`
}

// releasePayload is the raw decoding target for Release. Assets stays
// undecoded because its shape differs between providers.
type releasePayload struct {
	TagName string          `json:"tag_name"`
	Assets  json.RawMessage `json:"assets"`
}

// releaseAssetsPayload is the object form of "assets": {"links": [...]}.
type releaseAssetsPayload struct {
	Links []releaseLinkPayload `json:"links"`
}

// releaseLinkPayload is one asset entry. parseReleaseLinks uses the first
// non-blank of DirectAssetURL, BrowserDownloadURL and URL.
type releaseLinkPayload struct {
	Name               string `json:"name"`
	URL                string `json:"url"`
	BrowserDownloadURL string `json:"browser_download_url"`
	DirectAssetURL     string `json:"direct_asset_url"`
}
// UnmarshalJSON decodes a release payload. It trims the tag name and
// normalizes the "assets" field with parseReleaseLinks, so both an object
// with a "links" array and a bare array of asset objects are accepted.
func (r *Release) UnmarshalJSON(data []byte) error {
	var payload releasePayload
	if err := json.Unmarshal(data, &payload); err != nil {
		return err
	}
	r.TagName = strings.TrimSpace(payload.TagName)
	r.Assets.Links = parseReleaseLinks(payload.Assets)
	return nil
}

// Run performs one self-update.
//
// Preparation: it fills in option defaults and normalizes the release
// source. It then resolves auth, the release URL, the target executable and
// the platform asset name.
//
// Version check: it fetches the latest release. When the release tag equals
// CurrentVersion, it prints "Already up to date" and returns nil.
//
// Install: otherwise it downloads the asset next to the target and rejects
// empty or HTML-looking downloads. It then verifies the checksum and
// signature (each skipped unless available or required), runs the optional
// ValidateDownloaded hook, and finally replaces the executable.
func Run(ctx context.Context, opts Options) error {
	if opts.Stdout == nil {
		opts.Stdout = io.Discard
	}
	if opts.Client == nil {
		opts.Client = &http.Client{Timeout: 60 * time.Second}
	}
	if strings.TrimSpace(opts.CurrentVersion) == "" {
		opts.CurrentVersion = "dev"
	}
	if strings.TrimSpace(opts.GOOS) == "" {
		opts.GOOS = runtime.GOOS
	}
	if strings.TrimSpace(opts.GOARCH) == "" {
		opts.GOARCH = runtime.GOARCH
	}
	if opts.MaxDownloadBytes <= 0 {
		opts.MaxDownloadBytes = defaultMaxDownloadBytes
	}
	source := normalizeSource(opts.ReleaseSource)
	auth := ResolveAuth(source.Token, source)
	releaseURL, err := ResolveLatestReleaseURL(opts.LatestReleaseURL, source)
	if err != nil {
		return err
	}
	targetPath, err := ResolveUpdateTarget(opts.ExecutablePath)
	if err != nil {
		return err
	}
	// An explicit template in Options wins over the source's template.
	assetTemplate := strings.TrimSpace(opts.AssetNameTemplate)
	if assetTemplate == "" {
		assetTemplate = source.AssetNameTemplate
	}
	assetName, err := AssetNameWithTemplate(opts.BinaryName, opts.GOOS, opts.GOARCH, assetTemplate)
	if err != nil {
		return err
	}
	release, err := FetchLatestRelease(ctx, opts.Client, releaseURL, auth, source)
	if err != nil {
		return err
	}
	if isCurrentRelease(opts.CurrentVersion, release.TagName) {
		fmt.Fprintf(opts.Stdout, "Already up to date (%s)\n", release.TagName)
		return nil
	}
	assetURL, err := release.AssetURL(assetName, releaseURL)
	if err != nil {
		return err
	}
	downloadPath, err := DownloadReleaseAsset(ctx, opts.Client, assetURL, targetPath, auth, source, opts.MaxDownloadBytes)
	if err != nil {
		return err
	}
	// Best-effort cleanup of the temporary download on every exit path.
	// When the default ReplaceExecutable has renamed it onto the target,
	// the path no longer exists and the ignored error is harmless.
	defer os.Remove(downloadPath)
	if err := validateDownloadedArtifact(downloadPath, assetName); err != nil {
		return err
	}
	if err := VerifyReleaseAssetChecksum(ctx, opts.Client, release, releaseURL, assetName, downloadPath, auth, source); err != nil {
		return err
	}
	if err := VerifyReleaseAssetSignature(ctx, opts.Client, release, releaseURL, assetName, downloadPath, auth, source); err != nil {
		return err
	}
	if opts.ValidateDownloaded != nil {
		if err := opts.ValidateDownloaded(ctx, ValidationInput{
			DownloadPath: downloadPath,
			TargetPath:   targetPath,
			AssetName:    assetName,
			ReleaseTag:   release.TagName,
			ReleaseURL:   releaseURL,
			Source:       source,
		}); err != nil {
			return fmt.Errorf("validate downloaded artifact: %w", err)
		}
	}
	replaceExecutable := opts.ReplaceExecutable
	if replaceExecutable == nil {
		replaceExecutable = ReplaceExecutable
	}
	if err := replaceExecutable(downloadPath, targetPath); err != nil {
		return err
	}
	fmt.Fprintf(opts.Stdout, "Updated %s to %s\n", targetPath, release.TagName)
	return nil
}

// validateDownloadedArtifact reads up to downloadedArtifactSniffBytes from
// the start of the downloaded file. It rejects the file when it is empty or
// looks like an HTML document, e.g. an auth/forbidden page served in place
// of the binary.
//
// NOTE(review): a single file.Read may legally return fewer bytes than
// requested. io.ReadFull (treating io.ErrUnexpectedEOF as success) would
// guarantee the full sniff window is examined.
func validateDownloadedArtifact(path, assetName string) error {
	file, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("validate downloaded artifact %q: %w", path, err)
	}
	defer file.Close()
	head := make([]byte, downloadedArtifactSniffBytes)
	n, readErr := file.Read(head)
	if readErr != nil && !errors.Is(readErr, io.EOF) {
		return fmt.Errorf("validate downloaded artifact %q: %w", path, readErr)
	}
	if n == 0 {
		return fmt.Errorf("downloaded artifact %q is empty", assetName)
	}
	if looksLikeHTMLDocument(head[:n]) {
		return fmt.Errorf(
			"downloaded artifact %q looks like an HTML page (possible auth/forbidden response)",
			assetName,
		)
	}
	return nil
}

// looksLikeHTMLDocument reports whether content appears to be HTML. That is
// the case when http.DetectContentType classifies it as text/html or,
// presumably, when the lower-cased text (after trimming whitespace and a
// UTF-8 BOM) starts with an HTML marker. The marker literal is lost in this
// copy of the file; see the note below.
func looksLikeHTMLDocument(content []byte) bool {
	if len(content) == 0 {
		return false
	}
	detectedContentType := strings.ToLower(http.DetectContentType(content))
	if strings.HasPrefix(detectedContentType, "text/html") {
		return true
	}
	trimmed := bytes.TrimSpace(content)
	trimmed = bytes.TrimPrefix(trimmed, []byte{0xEF, 0xBB, 0xBF}) // UTF-8 BOM
	if len(trimmed) == 0 {
		return false
	}
	lower := strings.ToLower(string(trimmed))
	// NOTE(review): the next two lines are reproduced verbatim but are
	// corrupted and do not compile. Everything between a '<' inside a
	// string literal here and a later '>' (apparently a "len(preview) > 8"
	// check) seems to have been stripped from the file.
	//
	// That lost region held:
	//   - the end of this function;
	//   - presumably the definitions referenced elsewhere but missing from
	//     this file: (Release).AssetURL, FetchLatestRelease, ResolveAuth,
	//     ResolveLatestReleaseURL, ResolveUpdateTarget and
	//     AssetNameWithTemplate.
	//
	// The surviving tail looks like the end of the asset lookup: it lists
	// up to 8 available asset names in a "does not contain asset" error.
	// Restore this region from version control.
	return strings.HasPrefix(lower, " 8 { preview = preview[:8] } return "", fmt.Errorf( "latest release does not contain asset %q (available: %s)", assetName,
	strings.Join(preview, ", "), ) }

// DownloadReleaseAsset downloads assetURL into a temporary file and returns
// the temporary file's path. It removes the temporary file on error.
//
// Placement: the file is created in the directory of targetPath and named
// "<target base>.download-*". This keeps the later os.Rename in
// ReplaceExecutable within a single directory.
//
// Request: it applies auth. A non-200 status is an error, carrying an
// authentication hint from auth.maybeHint when one applies.
//
// Checks: it enforces maxDownloadBytes (<= 0 means defaultMaxDownloadBytes)
// against both the announced Content-Length and the bytes actually copied.
// It also copies the permission bits of the existing executable onto the
// download.
func DownloadReleaseAsset(
	ctx context.Context,
	client *http.Client,
	assetURL, targetPath string,
	auth Auth,
	source ReleaseSource,
	maxDownloadBytes int64,
) (string, error) {
	if maxDownloadBytes <= 0 {
		maxDownloadBytes = defaultMaxDownloadBytes
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, assetURL, nil)
	if err != nil {
		return "", fmt.Errorf("build artifact download request: %w", err)
	}
	req.Header.Set("User-Agent", "mcp updater")
	auth.apply(req)
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("download release artifact: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Only a short prefix of the error body is kept for the message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		if err := auth.maybeHint(resp.StatusCode, body, source); err != nil {
			return "", fmt.Errorf("download release artifact: %w", err)
		}
		return "", fmt.Errorf(
			"download release artifact: unexpected status %d: %s",
			resp.StatusCode,
			strings.TrimSpace(string(body)),
		)
	}
	// Fail fast when the server announces an oversized body.
	if resp.ContentLength > 0 && resp.ContentLength > maxDownloadBytes {
		return "", fmt.Errorf(
			"download release artifact: content length %d exceeds limit %d bytes",
			resp.ContentLength,
			maxDownloadBytes,
		)
	}
	existingInfo, err := os.Stat(targetPath)
	if err != nil {
		return "", fmt.Errorf("stat executable %q: %w", targetPath, err)
	}
	tempFile, err := os.CreateTemp(filepath.Dir(targetPath), filepath.Base(targetPath)+".download-*")
	if err != nil {
		return "", fmt.Errorf("create temporary file: %w", err)
	}
	tempPath := tempFile.Name()
	// cleanup closes and deletes the partial download, then returns copyErr.
	cleanup := func(copyErr error) (string, error) {
		tempFile.Close()
		os.Remove(tempPath)
		return "", copyErr
	}
	// Allow one byte past the limit so an oversized body is detected even
	// when Content-Length was missing or wrong.
	limited := &io.LimitedReader{R: resp.Body, N: maxDownloadBytes + 1}
	written, err := io.Copy(tempFile, limited)
	if err != nil {
		return cleanup(fmt.Errorf("write downloaded artifact: %w", err))
	}
	if written > maxDownloadBytes {
		return cleanup(fmt.Errorf("write downloaded artifact: size exceeds limit %d bytes", maxDownloadBytes))
	}
	if err := tempFile.Chmod(existingInfo.Mode().Perm()); err != nil {
		return cleanup(fmt.Errorf("set executable mode on downloaded artifact: %w", err))
	}
	if err := tempFile.Close(); err != nil {
		os.Remove(tempPath)
		return "", fmt.Errorf("close downloaded artifact: %w", err)
	}
	return tempPath, nil
}

// VerifyReleaseAssetChecksum checks artifactPath against a SHA-256 digest
// published as a release asset. The asset name defaults to "<asset>.sha256"
// (see resolveChecksumAssetName).
//
// When release.AssetURL cannot resolve that asset (presumably because the
// release does not list it), the check is skipped unless
// source.ChecksumRequired is set. Once the asset resolves, download, parse
// and mismatch errors are always fatal. Digests are compared
// case-insensitively.
func VerifyReleaseAssetChecksum(
	ctx context.Context,
	client *http.Client,
	release Release,
	releaseURL string,
	assetName string,
	artifactPath string,
	auth Auth,
	source ReleaseSource,
) error {
	source = normalizeSource(source)
	checksumAssetName := resolveChecksumAssetName(assetName, source.ChecksumAssetName)
	checksumURL, err := release.AssetURL(checksumAssetName, releaseURL)
	if err != nil {
		if source.ChecksumRequired {
			return fmt.Errorf("checksum verification: %w", err)
		}
		return nil
	}
	checksumBody, err := downloadAssetBytes(ctx, client, checksumURL, auth, source)
	if err != nil {
		return fmt.Errorf("checksum verification: %w", err)
	}
	expected, err := parseChecksum(string(checksumBody), assetName)
	if err != nil {
		return fmt.Errorf("checksum verification: %w", err)
	}
	actual, err := fileSHA256(artifactPath)
	if err != nil {
		return fmt.Errorf("checksum verification: %w", err)
	}
	if !strings.EqualFold(expected, actual) {
		return fmt.Errorf(
			"checksum mismatch for asset %q: expected %s, got %s",
			assetName,
			expected,
			actual,
		)
	}
	return nil
}

// VerifyReleaseAssetSignature verifies an Ed25519 signature over the
// artifact's SHA-256 digest. The signed message is the 32 raw digest bytes,
// not the file contents.
//
// Public key: taken from source.SignaturePublicKey or, failing that, from
// the first non-blank variable in source.SignaturePublicKeyEnvNames. A
// configured key that does not parse is always an error.
//
// Skipping: without a key, or when the signature asset (default
// "<asset>.sig") cannot be resolved, the check is skipped unless
// source.SignatureRequired is set. Once the signature asset resolves, any
// download, parse or verification failure is fatal.
func VerifyReleaseAssetSignature(
	ctx context.Context,
	client *http.Client,
	release Release,
	releaseURL string,
	assetName string,
	artifactPath string,
	auth Auth,
	source ReleaseSource,
) error {
	source = normalizeSource(source)
	publicKey, hasPublicKey, err := resolveEd25519PublicKey(source.SignaturePublicKey, source.SignaturePublicKeyEnvNames)
	if err != nil {
		return fmt.Errorf("signature verification: %w", err)
	}
	if !hasPublicKey {
		if source.SignatureRequired {
			if len(source.SignaturePublicKeyEnvNames) > 0 {
				return fmt.Errorf(
					"signature verification: no Ed25519 public key configured (set %s)",
					strings.Join(source.SignaturePublicKeyEnvNames, " or "),
				)
			}
			return errors.New("signature verification: no Ed25519 public key configured")
		}
		return nil
	}
	signatureAssetName := resolveSignatureAssetName(assetName, source.SignatureAssetName)
	signatureURL, err := release.AssetURL(signatureAssetName, releaseURL)
	if err != nil {
		if source.SignatureRequired {
			return fmt.Errorf("signature verification: %w", err)
		}
		return nil
	}
	signatureBody, err := downloadAssetBytes(ctx, client, signatureURL, auth, source)
	if err != nil {
		return fmt.Errorf("signature verification: %w", err)
	}
	signature, err := parseEd25519Signature(string(signatureBody), assetName)
	if err != nil {
		return fmt.Errorf("signature verification: %w", err)
	}
	digestHex, err := fileSHA256(artifactPath)
	if err != nil {
		return fmt.Errorf("signature verification: %w", err)
	}
	// The signature covers the raw digest bytes, so decode the hex form.
	digest, err := hex.DecodeString(digestHex)
	if err != nil {
		return fmt.Errorf("signature verification: decode local artifact digest: %w", err)
	}
	if !ed25519.Verify(publicKey, digest, signature) {
		return fmt.Errorf("signature mismatch for asset %q", assetName)
	}
	return nil
}

// ReplaceExecutable is the default replacement strategy: it renames
// downloadPath over targetPath. It refuses to run when the host
// (runtime.GOOS, not Options.GOOS) is Windows. Presumably this is because a
// running executable cannot be replaced by rename there; callers must supply
// Options.ReplaceExecutable instead.
func ReplaceExecutable(downloadPath, targetPath string) error {
	if runtime.GOOS == "windows" {
		return errors.New("self-update is not supported on windows without a custom ReplaceExecutable hook")
	}
	if err := os.Rename(downloadPath, targetPath); err != nil {
		return fmt.Errorf("replace executable %q: %w", targetPath, err)
	}
	return nil
}

// normalizeSource returns a copy of source in canonical form.
//
// Cleanup applied to every source:
//   - string fields are trimmed and Driver is lower-cased;
//   - surrounding slashes are trimmed from Repository and trailing slashes
//     from BaseURL;
//   - blank entries are dropped from both env-name lists.
//
// Defaults for the "gitea", "gitlab" and "github" drivers are filled only
// where the field is empty: display name, token header, token prefix
// (gitea/github), token env names, and the API base URL (gitlab/github).
// Unknown drivers get no defaults.
//
// NOTE(review): the env-name lists are filtered in place via `[:0]`.
// ReleaseSource is passed by value, but its slices share the caller's
// backing arrays, so a caller's list with blank or padded entries is
// rewritten. Filtering into a fresh slice would avoid this aliasing.
//
// NOTE(review): re-normalizing is not strictly idempotent. A defaulted
// TokenPrefix such as "token " is trimmed to "token" on a second pass.
func normalizeSource(source ReleaseSource) ReleaseSource {
	source.Name = strings.TrimSpace(source.Name)
	source.Driver = strings.ToLower(strings.TrimSpace(source.Driver))
	source.Repository = strings.Trim(strings.TrimSpace(source.Repository), "/")
	source.BaseURL = strings.TrimRight(strings.TrimSpace(source.BaseURL), "/")
	source.LatestReleaseURL = strings.TrimSpace(source.LatestReleaseURL)
	source.AssetNameTemplate = strings.TrimSpace(source.AssetNameTemplate)
	source.ChecksumAssetName = strings.TrimSpace(source.ChecksumAssetName)
	source.SignatureAssetName = strings.TrimSpace(source.SignatureAssetName)
	source.SignaturePublicKey = strings.TrimSpace(source.SignaturePublicKey)
	source.Token = strings.TrimSpace(source.Token)
	source.TokenHeader = strings.TrimSpace(source.TokenHeader)
	source.TokenPrefix = strings.TrimSpace(source.TokenPrefix)
	envNames := source.TokenEnvNames[:0]
	for _, envName := range source.TokenEnvNames {
		if trimmed := strings.TrimSpace(envName); trimmed != "" {
			envNames = append(envNames, trimmed)
		}
	}
	source.TokenEnvNames = envNames
	publicKeyEnvNames := source.SignaturePublicKeyEnvNames[:0]
	for _, envName := range source.SignaturePublicKeyEnvNames {
		if trimmed := strings.TrimSpace(envName); trimmed != "" {
			publicKeyEnvNames = append(publicKeyEnvNames, trimmed)
		}
	}
	source.SignaturePublicKeyEnvNames = publicKeyEnvNames
	switch source.Driver {
	case "gitea":
		if source.Name == "" {
			source.Name = "Gitea releases"
		}
		if source.TokenHeader == "" {
			source.TokenHeader = "Authorization"
		}
		if source.TokenPrefix == "" {
			source.TokenPrefix = "token "
		}
		if len(source.TokenEnvNames) == 0 {
			source.TokenEnvNames = []string{"GITEA_TOKEN"}
		}
	case "gitlab":
		if source.Name == "" {
			source.Name = "GitLab releases"
		}
		if source.BaseURL == "" {
			source.BaseURL = "https://gitlab.com"
		}
		if source.TokenHeader == "" {
			source.TokenHeader = "PRIVATE-TOKEN"
		}
		if len(source.TokenEnvNames) == 0 {
			source.TokenEnvNames = []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"}
		}
	case "github":
		if source.Name == "" {
			source.Name = "GitHub releases"
		}
		if source.BaseURL == "" {
			source.BaseURL = "https://api.github.com"
		}
		if source.TokenHeader == "" {
			source.TokenHeader = "Authorization"
		}
		if source.TokenPrefix == "" {
			source.TokenPrefix = "Bearer "
		}
		if len(source.TokenEnvNames) == 0 {
			source.TokenEnvNames = []string{"GITHUB_TOKEN"}
		}
	}
	return source
}

// isCurrentRelease reports whether latestTag names the running version.
// Both values are trimmed. An empty latest tag, an empty current version or
// the "dev" placeholder never match. The comparison is exact, so "v1.2.3"
// and "1.2.3" count as different versions.
func isCurrentRelease(currentVersion, latestTag string) bool {
	current := strings.TrimSpace(currentVersion)
	latest := strings.TrimSpace(latestTag)
	if latest == "" {
		return false
	}
	if current == "" || current == "dev" {
		return false
	}
	return current == latest
}

// apply sets the configured auth header on req, using Token verbatim as the
// header value. It does nothing unless both Header and Token are non-blank.
func (a Auth) apply(req *http.Request) {
	if strings.TrimSpace(a.Header) == "" || strings.TrimSpace(a.Token) == "" {
		return
	}
	req.Header.Set(a.Header, a.Token)
}

// maybeHint turns a failed response into an actionable
// "... set <ENV> and retry" error. All of these must hold:
//   - the request carried no token;
//   - the source defines token env names;
//   - the status is 401, 403 or 404;
//   - the body mentions "not found", "unauthorized" or "forbidden"
//     (case-insensitive).
//
// It returns nil when no hint applies, leaving the caller to report the raw
// status. At most the first two env names are mentioned.
//
// NOTE(review): the "project not found" test is subsumed by "not found".
func (a Auth) maybeHint(statusCode int, body []byte, source ReleaseSource) error {
	source = normalizeSource(source)
	if strings.TrimSpace(a.Token) != "" || len(source.TokenEnvNames) == 0 {
		return nil
	}
	switch statusCode {
	case http.StatusUnauthorized, http.StatusForbidden, http.StatusNotFound:
	default:
		return nil
	}
	message := strings.ToLower(strings.TrimSpace(string(body)))
	if !strings.Contains(message, "project not found") &&
		!strings.Contains(message, "not found") &&
		!strings.Contains(message, "unauthorized") &&
		!strings.Contains(message, "forbidden") {
		return nil
	}
	target := source.BaseURL
	if target == "" {
		target = "release endpoint"
	}
	name := source.Name
	if name == "" {
		name = "release"
	}
	if len(source.TokenEnvNames) == 1 {
		return fmt.Errorf(
			"%s access requires authentication on %s; set %s and retry",
			name,
			target,
			source.TokenEnvNames[0],
		)
	}
	return fmt.Errorf(
		"%s access requires authentication on %s; set %s (or %s) and retry",
		name,
		target,
		source.TokenEnvNames[0],
		source.TokenEnvNames[1],
	)
}

// parseReleaseLinks normalizes the raw "assets" JSON of a release into
// ReleaseLinks. It accepts an object with a non-empty "links" array,
// otherwise a non-empty bare array of asset objects.
//
// Each entry's URL is the first non-blank of direct_asset_url,
// browser_download_url and url. Entries with a blank name or URL are
// dropped. Null, empty or unrecognized input yields nil.
func parseReleaseLinks(raw json.RawMessage) []ReleaseLink {
	if len(raw) == 0 || strings.TrimSpace(string(raw)) == "null" {
		return nil
	}
	parseLinks := func(payload []releaseLinkPayload) []ReleaseLink {
		links := make([]ReleaseLink, 0, len(payload))
		for _, item := range payload {
			name := strings.TrimSpace(item.Name)
			assetURL := firstNonEmpty(item.DirectAssetURL, item.BrowserDownloadURL, item.URL)
			if name == "" || strings.TrimSpace(assetURL) == "" {
				continue
			}
			links = append(links, ReleaseLink{
				Name: name,
				URL:  strings.TrimSpace(assetURL),
			})
		}
		return links
	}
	var asObject releaseAssetsPayload
	if err := json.Unmarshal(raw, &asObject); err == nil && len(asObject.Links) > 0 {
		return parseLinks(asObject.Links)
	}
	var asArray []releaseLinkPayload
	if err := json.Unmarshal(raw, &asArray); err == nil && len(asArray) > 0 {
		return parseLinks(asArray)
	}
	return nil
}

// firstNonEmpty returns the first value that is non-blank after trimming,
// in trimmed form, or "" when every value is blank.
func firstNonEmpty(values ...string) string {
	for _, value := range values {
		if trimmed := strings.TrimSpace(value); trimmed != "" {
			return trimmed
		}
	}
	return ""
}

// withTokenPrefix builds an auth header value from token and prefix:
//   - a blank token yields "";
//   - a blank prefix, or a token that already starts with the prefix
//     (case-insensitive), yields the trimmed token;
//   - otherwise it yields "<trimmed prefix> <trimmed token>".
//
// NOTE(review): the "already prefixed" test compares against the prefix
// without its separator. A raw token that merely begins with the prefix
// letters (e.g. "tokenXYZ" with prefix "token ") is therefore sent with no
// scheme. Checking for prefix+" " would be stricter.
func withTokenPrefix(token, prefix string) string {
	trimmedToken := strings.TrimSpace(token)
	if trimmedToken == "" {
		return ""
	}
	trimmedPrefix := strings.TrimSpace(prefix)
	if trimmedPrefix == "" {
		return trimmedToken
	}
	lowerToken := strings.ToLower(trimmedToken)
	lowerPrefix := strings.ToLower(trimmedPrefix)
	if strings.HasPrefix(lowerToken, lowerPrefix) {
		return trimmedToken
	}
	return trimmedPrefix + " " + trimmedToken
}

// resolveChecksumAssetName returns "<assetName>.sha256" when configured is
// blank. Otherwise it returns configured with every "{asset}" replaced by
// assetName.
func resolveChecksumAssetName(assetName, configured string) string {
	value := strings.TrimSpace(configured)
	if value == "" {
		return assetName + ".sha256"
	}
	return strings.ReplaceAll(value, "{asset}", assetName)
}

// resolveSignatureAssetName returns "<assetName>.sig" when configured is
// blank. Otherwise it returns configured with every "{asset}" replaced by
// assetName.
func resolveSignatureAssetName(assetName, configured string) string {
	value := strings.TrimSpace(configured)
	if value == "" {
		return assetName + ".sig"
	}
	return strings.ReplaceAll(value, "{asset}", assetName)
}

// downloadAssetBytes fetches a small release asset; here it is used for
// both checksum and signature files. It returns at most the first 256 KiB
// of the body, silently truncating anything longer. A non-200 status is an
// error, carrying an authentication hint from auth.maybeHint when one
// applies.
//
// NOTE(review): the error messages say "checksum" even when the asset is a
// signature file.
func downloadAssetBytes(ctx context.Context, client *http.Client, assetURL string, auth Auth, source ReleaseSource) ([]byte, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, assetURL, nil)
	if err != nil {
		return nil, fmt.Errorf("build checksum download request: %w", err)
	}
	req.Header.Set("User-Agent", "mcp updater")
	auth.apply(req)
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("download checksum asset: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		if hint := auth.maybeHint(resp.StatusCode, body, source); hint != nil {
			return nil, fmt.Errorf("download checksum asset: %w", hint)
		}
		return nil, fmt.Errorf(
			"download checksum asset: unexpected status %d: %s",
			resp.StatusCode,
			strings.TrimSpace(string(body)),
		)
	}
	content, err := io.ReadAll(io.LimitReader(resp.Body, 256*1024))
	if err != nil {
		return nil, fmt.Errorf("read checksum asset: %w", err)
	}
	return content, nil
}

// parseChecksum extracts the SHA-256 hex digest for assetName from a
// checksum file. Blank lines and '#' comments are skipped. Recognized lines:
//   - "SHA256 (name) = <hex>" (prefix matched case-insensitively)
//   - "<hex> name" or "<hex> *name"
//   - "<hex>:name" or "name:<hex>"
//   - a lone "<hex>"
//
// The first lone digest is used only when no line names the asset. Names
// are compared with matchesAssetName. The digest is returned lower-cased.
func parseChecksum(content, assetName string) (string, error) {
	lines := strings.Split(content, "\n")
	fallbackSingle := ""
	for _, raw := range lines {
		line := strings.TrimSpace(raw)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		if strings.HasPrefix(strings.ToUpper(line), "SHA256 (") {
			openIndex := strings.Index(line, "(")
			closeIndex := strings.LastIndex(line, ")")
			equalIndex := strings.LastIndex(line, "=")
			if openIndex >= 0 && closeIndex > openIndex && equalIndex > closeIndex {
				name := strings.TrimSpace(line[openIndex+1 : closeIndex])
				hash := strings.TrimSpace(line[equalIndex+1:])
				if isSHA256Hex(hash) && matchesAssetName(name, assetName) {
					return strings.ToLower(hash), nil
				}
			}
			continue
		}
		fields := strings.Fields(line)
		if len(fields) > 0 && isSHA256Hex(fields[0]) {
			if len(fields) == 1 {
				if fallbackSingle == "" {
					fallbackSingle = strings.ToLower(fields[0])
				}
				continue
			}
			name := strings.TrimSpace(strings.TrimPrefix(fields[1], "*"))
			if matchesAssetName(name, assetName) {
				return strings.ToLower(fields[0]), nil
			}
			continue
		}
		colonIndex := strings.Index(line, ":")
		if colonIndex > 0 && colonIndex < len(line)-1 {
			left := strings.TrimSpace(line[:colonIndex])
			right := strings.TrimSpace(line[colonIndex+1:])
			switch {
			case isSHA256Hex(left) && matchesAssetName(right, assetName):
				return strings.ToLower(left), nil
			case isSHA256Hex(right) && matchesAssetName(left, assetName):
				return strings.ToLower(right), nil
			}
		}
	}
	if fallbackSingle != "" {
		return fallbackSingle, nil
	}
	return "", fmt.Errorf("checksum file does not contain a sha256 for asset %q", assetName)
}

// parseEd25519Signature extracts the Ed25519 signature for assetName from a
// signature file. Blank lines and '#' comments are skipped. Each remaining
// line may be "<sig> name", "<sig>:name", "name:<sig>" or a lone "<sig>".
// Signatures are hex or base64; see decodeBinaryValue.
//
// The first lone signature is used only when no line names the asset.
func parseEd25519Signature(content, assetName string) ([]byte, error) {
	lines := strings.Split(content, "\n")
	var fallbackSingle []byte
	for _, raw := range lines {
		line := strings.TrimSpace(raw)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		fields := strings.Fields(line)
		if len(fields) > 0 {
			if signature, ok := parseEd25519SignatureToken(fields[0]); ok {
				if len(fields) == 1 {
					if fallbackSingle == nil {
						fallbackSingle = signature
					}
					continue
				}
				name := strings.TrimSpace(strings.TrimPrefix(fields[1], "*"))
				if matchesAssetName(name, assetName) {
					return signature, nil
				}
			}
		}
		colonIndex := strings.Index(line, ":")
		if colonIndex > 0 && colonIndex < len(line)-1 {
			left := strings.TrimSpace(line[:colonIndex])
			right := strings.TrimSpace(line[colonIndex+1:])
			if signature, ok := parseEd25519SignatureToken(left); ok && matchesAssetName(right, assetName) {
				return signature, nil
			}
			if signature, ok := parseEd25519SignatureToken(right); ok && matchesAssetName(left, assetName) {
				return signature, nil
			}
		}
	}
	if fallbackSingle != nil {
		return fallbackSingle, nil
	}
	return nil, fmt.Errorf("signature file does not contain a valid Ed25519 signature for asset %q", assetName)
}

// parseEd25519SignatureToken decodes value as a hex or base64 signature of
// exactly ed25519.SignatureSize bytes. ok is false when it is not one.
func parseEd25519SignatureToken(value string) ([]byte, bool) {
	decoded, err := decodeBinaryValue(value, ed25519.SignatureSize)
	if err != nil {
		return nil, false
	}
	return decoded, true
}

// resolveEd25519PublicKey picks the verification key. The explicit key wins;
// otherwise the first environment variable in envNames with a non-blank
// value is used. Return values:
//   - (key, true, nil) when a key is found;
//   - (nil, false, nil) when none is configured;
//   - an error when the chosen value does not parse. Later env vars are not
//     tried in that case.
func resolveEd25519PublicKey(explicit string, envNames []string) (ed25519.PublicKey, bool, error) {
	key := strings.TrimSpace(explicit)
	if key != "" {
		publicKey, err := parseEd25519PublicKey(key)
		if err != nil {
			return nil, false, fmt.Errorf("parse ed25519 public key: %w", err)
		}
		return publicKey, true, nil
	}
	for _, envName := range envNames {
		if value := strings.TrimSpace(os.Getenv(envName)); value != "" {
			publicKey, err := parseEd25519PublicKey(value)
			if err != nil {
				return nil, false, fmt.Errorf("parse ed25519 public key from %s: %w", envName, err)
			}
			return publicKey, true, nil
		}
	}
	return nil, false, nil
}

// parseEd25519PublicKey decodes a hex or base64 public key of exactly
// ed25519.PublicKeySize bytes.
func parseEd25519PublicKey(value string) (ed25519.PublicKey, error) {
	decoded, err := decodeBinaryValue(value, ed25519.PublicKeySize)
	if err != nil {
		return nil, err
	}
	return ed25519.PublicKey(decoded), nil
}

// decodeBinaryValue decodes trimmed value by trying, in order: hex, then
// standard, raw standard, URL and raw URL base64. It returns the first
// decoding whose length is exactly expectedLength. The error distinguishes
// "decodes but has the wrong length" from "is not hex or base64 at all".
func decodeBinaryValue(value string, expectedLength int) ([]byte, error) {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return nil, errors.New("value must not be empty")
	}
	decoders := []func(string) ([]byte, error){
		hex.DecodeString,
		base64.StdEncoding.DecodeString,
		base64.RawStdEncoding.DecodeString,
		base64.URLEncoding.DecodeString,
		base64.RawURLEncoding.DecodeString,
	}
	lengthMismatch := false
	for _, decode := range decoders {
		decoded, err := decode(trimmed)
		if err != nil {
			continue
		}
		if len(decoded) == expectedLength {
			return decoded, nil
		}
		lengthMismatch = true
	}
	if lengthMismatch {
		return nil, fmt.Errorf("decoded value has invalid length (expected %d bytes)", expectedLength)
	}
	return nil, errors.New("value must be hex or base64 encoded")
}

// matchesAssetName reports whether a file name taken from a checksum or
// signature line refers to assetName. It first strips a leading '*', spaces
// and a leading "./" and compares exactly. Failing that, it converts
// backslashes to slashes and compares the last path element. Comparisons
// are case-sensitive.
func matchesAssetName(candidate, assetName string) bool {
	name := strings.TrimSpace(strings.TrimPrefix(candidate, "*"))
	name = strings.TrimPrefix(name, "./")
	if name == assetName {
		return true
	}
	name = strings.ReplaceAll(name, "\\", "/")
	if path.Base(name) == assetName {
		return true
	}
	return false
}

// isSHA256Hex reports whether value is exactly 64 hexadecimal digits, in
// either case.
func isSHA256Hex(value string) bool {
	if len(value) != 64 {
		return false
	}
	for _, r := range value {
		switch {
		case r >= '0' && r <= '9':
		case r >= 'a' && r <= 'f':
		case r >= 'A' && r <= 'F':
		default:
			return false
		}
	}
	return true
}

// fileSHA256 streams the file at path through SHA-256 and returns the digest
// as lower-case hex.
func fileSHA256(path string) (string, error) {
	file, err := os.Open(path)
	if err != nil {
		return "", fmt.Errorf("open downloaded artifact: %w", err)
	}
	defer file.Close()
	hash := sha256.New()
	if _, err := io.Copy(hash, file); err != nil {
		return "", fmt.Errorf("hash downloaded artifact: %w", err)
	}
	return hex.EncodeToString(hash.Sum(nil)), nil
}