package update import ( "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "os" "path/filepath" "runtime" "strings" "time" ) type Options struct { Client *http.Client CurrentVersion string ExecutablePath string LatestReleaseURL string Stdout io.Writer BinaryName string ReleaseSource ReleaseSource GOOS string GOARCH string } type ReleaseSource struct { Name string BaseURL string LatestReleaseURL string Token string TokenHeader string TokenEnvNames []string } type Auth struct { Header string Token string } type Release struct { TagName string `json:"tag_name"` Assets struct { Links []ReleaseLink `json:"links"` } `json:"assets"` } type ReleaseLink struct { Name string `json:"name"` URL string `json:"url"` } func Run(ctx context.Context, opts Options) error { if opts.Stdout == nil { opts.Stdout = io.Discard } if opts.Client == nil { opts.Client = &http.Client{Timeout: 60 * time.Second} } if strings.TrimSpace(opts.CurrentVersion) == "" { opts.CurrentVersion = "dev" } if strings.TrimSpace(opts.GOOS) == "" { opts.GOOS = runtime.GOOS } if strings.TrimSpace(opts.GOARCH) == "" { opts.GOARCH = runtime.GOARCH } source := normalizeSource(opts.ReleaseSource) auth := ResolveAuth(source.Token, source) releaseURL := opts.LatestReleaseURL if strings.TrimSpace(releaseURL) == "" { releaseURL = strings.TrimSpace(source.LatestReleaseURL) } if releaseURL == "" { return errors.New("latest release URL must not be empty") } targetPath, err := ResolveUpdateTarget(opts.ExecutablePath) if err != nil { return err } assetName, err := AssetName(opts.BinaryName, opts.GOOS, opts.GOARCH) if err != nil { return err } release, err := FetchLatestRelease(ctx, opts.Client, releaseURL, auth, source) if err != nil { return err } if isCurrentRelease(opts.CurrentVersion, release.TagName) { fmt.Fprintf(opts.Stdout, "Already up to date (%s)\n", release.TagName) return nil } assetURL, err := release.AssetURL(assetName, releaseURL) if err != nil { return err } downloadPath, err := DownloadReleaseAsset(ctx, opts.Client, assetURL, targetPath, auth, source) if err != nil { return err } defer os.Remove(downloadPath) if err := ReplaceExecutable(downloadPath, targetPath); err != nil { return err } fmt.Fprintf(opts.Stdout, "Updated %s to %s\n", targetPath, release.TagName) return nil } func ResolveAuth(explicitToken string, source ReleaseSource) Auth { source = normalizeSource(source) if token := strings.TrimSpace(explicitToken); token != "" { return Auth{Header: source.TokenHeader, Token: token} } for _, envName := range source.TokenEnvNames { if token := strings.TrimSpace(os.Getenv(envName)); token != "" { return Auth{Header: source.TokenHeader, Token: token} } } return Auth{} } func ResolveUpdateTarget(explicitPath string) (string, error) { targetPath := strings.TrimSpace(explicitPath) if targetPath == "" { var err error targetPath, err = os.Executable() if err != nil { return "", fmt.Errorf("resolve executable path: %w", err) } } resolvedPath, err := filepath.EvalSymlinks(targetPath) if err != nil { return "", fmt.Errorf("resolve executable symlink %q: %w", targetPath, err) } return resolvedPath, nil } func AssetName(binaryName, goos, goarch string) (string, error) { name := strings.TrimSpace(binaryName) if name == "" { return "", errors.New("binary name must not be empty") } switch { case goos == "darwin" && goarch == "amd64": return name + "-darwin-amd64", nil case goos == "darwin" && goarch == "arm64": return name + "-darwin-arm64", nil case goos == "linux" && goarch == "amd64": return name + "-linux-amd64", nil case goos == "windows" && goarch == "amd64": return name + "-windows-amd64.exe", nil default: return "", fmt.Errorf("no release artifact for %s/%s", goos, goarch) } } func FetchLatestRelease(ctx context.Context, client *http.Client, releaseURL string, auth Auth, source ReleaseSource) (Release, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, releaseURL, nil) if err != nil { return Release{}, fmt.Errorf("build latest release request: %w", err) } req.Header.Set("Accept", "application/json") req.Header.Set("User-Agent", "mcp updater") auth.apply(req) resp, err := client.Do(req) if err != nil { return Release{}, fmt.Errorf("fetch latest release metadata: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) if err := auth.maybeHint(resp.StatusCode, body, source); err != nil { return Release{}, fmt.Errorf("fetch latest release metadata: %w", err) } return Release{}, fmt.Errorf( "fetch latest release metadata: unexpected status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)), ) } var release Release if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { return Release{}, fmt.Errorf("decode latest release metadata: %w", err) } if strings.TrimSpace(release.TagName) == "" { return Release{}, errors.New("latest release metadata is missing tag_name") } return release, nil } func (r Release) AssetURL(assetName, releaseURL string) (string, error) { for _, link := range r.Assets.Links { if link.Name == assetName { if strings.TrimSpace(link.URL) == "" { return "", fmt.Errorf("release asset %q has no URL", assetName) } parsed, err := url.Parse(link.URL) if err != nil { return "", fmt.Errorf("parse release asset URL %q: %w", link.URL, err) } if parsed.IsAbs() { return parsed.String(), nil } baseURL, err := url.Parse(releaseURL) if err != nil { return "", fmt.Errorf("parse latest release URL %q: %w", releaseURL, err) } return baseURL.ResolveReference(parsed).String(), nil } } return "", fmt.Errorf("latest release does not contain asset %q", assetName) } func DownloadReleaseAsset(ctx context.Context, client *http.Client, assetURL, targetPath string, auth Auth, source ReleaseSource) (string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, assetURL, nil) if err != nil { return "", fmt.Errorf("build artifact download request: %w", err) } req.Header.Set("User-Agent", "mcp updater") auth.apply(req) resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("download release artifact: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) if err := auth.maybeHint(resp.StatusCode, body, source); err != nil { return "", fmt.Errorf("download release artifact: %w", err) } return "", fmt.Errorf( "download release artifact: unexpected status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)), ) } existingInfo, err := os.Stat(targetPath) if err != nil { return "", fmt.Errorf("stat executable %q: %w", targetPath, err) } tempFile, err := os.CreateTemp(filepath.Dir(targetPath), filepath.Base(targetPath)+".download-*") if err != nil { return "", fmt.Errorf("create temporary file: %w", err) } tempPath := tempFile.Name() cleanup := func(copyErr error) (string, error) { tempFile.Close() os.Remove(tempPath) return "", copyErr } if _, err := io.Copy(tempFile, resp.Body); err != nil { return cleanup(fmt.Errorf("write downloaded artifact: %w", err)) } if err := tempFile.Chmod(existingInfo.Mode().Perm()); err != nil { return cleanup(fmt.Errorf("set executable mode on downloaded artifact: %w", err)) } if err := tempFile.Close(); err != nil { os.Remove(tempPath) return "", fmt.Errorf("close downloaded artifact: %w", err) } return tempPath, nil } func ReplaceExecutable(downloadPath, targetPath string) error { if runtime.GOOS == "windows" { return errors.New("self-update is not supported on windows") } if err := os.Rename(downloadPath, targetPath); err != nil { return fmt.Errorf("replace executable %q: %w", targetPath, err) } return nil } func normalizeSource(source ReleaseSource) ReleaseSource { source.Name = strings.TrimSpace(source.Name) source.BaseURL = strings.TrimRight(strings.TrimSpace(source.BaseURL), "/") source.LatestReleaseURL = strings.TrimSpace(source.LatestReleaseURL) source.TokenHeader = strings.TrimSpace(source.TokenHeader) return source } func isCurrentRelease(currentVersion, latestTag string) bool { current := strings.TrimSpace(currentVersion) latest := strings.TrimSpace(latestTag) if latest == "" { return false } if current == "" || current == "dev" { return false } return current == latest } func (a Auth) apply(req *http.Request) { if strings.TrimSpace(a.Header) == "" || strings.TrimSpace(a.Token) == "" { return } req.Header.Set(a.Header, a.Token) } func (a Auth) maybeHint(statusCode int, body []byte, source ReleaseSource) error { source = normalizeSource(source) if strings.TrimSpace(a.Token) != "" || len(source.TokenEnvNames) == 0 { return nil } switch statusCode { case http.StatusUnauthorized, http.StatusForbidden, http.StatusNotFound: default: return nil } message := strings.ToLower(strings.TrimSpace(string(body))) if !strings.Contains(message, "project not found") && !strings.Contains(message, "not found") && !strings.Contains(message, "unauthorized") && !strings.Contains(message, "forbidden") { return nil } target := source.BaseURL if target == "" { target = "release endpoint" } name := source.Name if name == "" { name = "release" } if len(source.TokenEnvNames) == 1 { return fmt.Errorf( "%s access requires authentication on %s; set %s and retry", name, target, source.TokenEnvNames[0], ) } return fmt.Errorf( "%s access requires authentication on %s; set %s (or %s) and retry", name, target, source.TokenEnvNames[0], source.TokenEnvNames[1], ) }