diff --git a/README.md b/README.md index 119084d..ce2a5be 100644 --- a/README.md +++ b/README.md @@ -94,9 +94,14 @@ docs_url = "https://docs.example.com/my-mcp" [update] source_name = "Gitea releases" +driver = "gitea" +repository = "org/repo" base_url = "https://gitea.example.com" -latest_release_url = "https://gitea.example.com/api/v1/repos/org/repo/releases/latest" +asset_name_template = "{binary}-{os}-{arch}{ext}" +checksum_asset_name = "{asset}.sha256" +checksum_required = false token_header = "Authorization" +token_prefix = "token" token_env_names = ["GITEA_TOKEN"] [environment] @@ -120,9 +125,15 @@ Champs supportés : - `[update]` : source de release consommée par `update`. - `source_name` : nom humain de la source de release, utilisé dans certains messages d'erreur. +- `driver` : driver de forge (`gitea`, `gitlab`, `github`) pour déduire automatiquement l'endpoint latest. +- `repository` : dépôt cible (`org/repo` ou `group/subgroup/repo`). - `base_url` : base de la forge ou du service de release. -- `latest_release_url` : URL complète qui retourne la release la plus récente. +- `latest_release_url` : URL complète qui retourne la release la plus récente (prioritaire sur le driver). +- `asset_name_template` : template de nom d'asset (`{binary}`, `{os}`, `{arch}`, `{ext}`). +- `checksum_asset_name` : nom d'asset checksum, avec placeholder optionnel `{asset}`. +- `checksum_required` : si `true`, l'update échoue quand l'asset checksum est absent. - `token_header` : header HTTP à utiliser pour l'authentification. +- `token_prefix` : préfixe appliqué devant le token (`Bearer`, `token`, ...). - `token_env_names` : liste de variables d'environnement candidates pour retrouver le token. - `[environment].known` : variables d'environnement connues du projet. - `[secret_store].backend_policy` : politique de secret store (`auto`, `kwallet-only`, `keyring-any`, `env-only`). @@ -462,9 +473,14 @@ if report.HasFailures() { ## Auto-Update -Le package `update` ne déduit pas la forge ni l'authentification. -L'application cliente fournit l'URL de release, le header d'auth éventuel et, -si besoin, les variables d'environnement à consulter. +Le package `update` supporte les drivers `gitea`, `gitlab` et `github`. +Si `latest_release_url` est vide, l'URL latest est déduite depuis +`driver + repository (+ base_url)`. + +Le parseur de release supporte : + +- format `assets.links` (Gitea/GitLab) +- format `assets[]` avec `browser_download_url` (GitHub et Gitea API) Le format attendu pour la réponse `latest release` est actuellement : @@ -501,12 +517,12 @@ if err != nil { } ``` -Contraintes actuelles : +Comportement : -- le `latest_release_url` doit être renseigné explicitement -- les assets supportés sont `darwin/amd64`, `darwin/arm64`, `linux/amd64` et `windows/amd64` -- le remplacement du binaire n'est pas supporté sur Windows -- le nom de l'asset est dérivé de `BinaryName`, `GOOS` et `GOARCH` +- le nom de l'asset est configurable (`asset_name_template`) et supporte tout couple `GOOS/GOARCH` +- si un asset `.sha256` (ou `checksum_asset_name`) existe, le binaire téléchargé est vérifié avant remplacement +- un hook `ValidateDownloaded` permet d'ajouter une validation custom (signature, scan, etc.) +- sur Windows, le remplacement in-place n'est pas fait par défaut ; fournir `Options.ReplaceExecutable` pour une stratégie dédiée ## Exemple Minimal diff --git a/manifest/manifest.go b/manifest/manifest.go index 21c0bcd..b3c9414 100644 --- a/manifest/manifest.go +++ b/manifest/manifest.go @@ -25,11 +25,17 @@ type File struct { } type Update struct { - SourceName string `toml:"source_name"` - BaseURL string `toml:"base_url"` - LatestReleaseURL string `toml:"latest_release_url"` - TokenHeader string `toml:"token_header"` - TokenEnvNames []string `toml:"token_env_names"` + SourceName string `toml:"source_name"` + Driver string `toml:"driver"` + Repository string `toml:"repository"` + BaseURL string `toml:"base_url"` + LatestReleaseURL string `toml:"latest_release_url"` + AssetNameTemplate string `toml:"asset_name_template"` + ChecksumAssetName string `toml:"checksum_asset_name"` + ChecksumRequired bool `toml:"checksum_required"` + TokenHeader string `toml:"token_header"` + TokenPrefix string `toml:"token_prefix"` + TokenEnvNames []string `toml:"token_env_names"` } type Environment struct { @@ -143,9 +149,14 @@ func (f *File) normalize() { func (u *Update) normalize() { u.SourceName = strings.TrimSpace(u.SourceName) + u.Driver = strings.ToLower(strings.TrimSpace(u.Driver)) + u.Repository = strings.Trim(strings.TrimSpace(u.Repository), "/") u.BaseURL = strings.TrimRight(strings.TrimSpace(u.BaseURL), "/") u.LatestReleaseURL = strings.TrimSpace(u.LatestReleaseURL) + u.AssetNameTemplate = strings.TrimSpace(u.AssetNameTemplate) + u.ChecksumAssetName = strings.TrimSpace(u.ChecksumAssetName) u.TokenHeader = strings.TrimSpace(u.TokenHeader) + u.TokenPrefix = strings.TrimSpace(u.TokenPrefix) u.TokenEnvNames = normalizeStringList(u.TokenEnvNames) } @@ -170,11 +181,17 @@ func (u Update) ReleaseSource() update.ReleaseSource { u.normalize() return update.ReleaseSource{ - Name: u.SourceName, - BaseURL: u.BaseURL, - LatestReleaseURL: u.LatestReleaseURL, - TokenHeader: u.TokenHeader, - TokenEnvNames: append([]string(nil), u.TokenEnvNames...), + Name: u.SourceName, + Driver: u.Driver, + Repository: u.Repository, + BaseURL: u.BaseURL, + LatestReleaseURL: u.LatestReleaseURL, + AssetNameTemplate: u.AssetNameTemplate, + ChecksumAssetName: u.ChecksumAssetName, + ChecksumRequired: u.ChecksumRequired, + TokenHeader: u.TokenHeader, + TokenPrefix: u.TokenPrefix, + TokenEnvNames: append([]string(nil), u.TokenEnvNames...), } } diff --git a/manifest/manifest_test.go b/manifest/manifest_test.go index ebc9454..83ea7d5 100644 --- a/manifest/manifest_test.go +++ b/manifest/manifest_test.go @@ -46,9 +46,15 @@ func TestLoadParsesUpdateConfig(t *testing.T) { const content = ` [update] source_name = " Gitea releases " +driver = " Gitea " +repository = " org/repo " base_url = "https://gitea.example.com/" latest_release_url = "https://gitea.example.com/api/releases/latest" +asset_name_template = "{binary}_{os}_{arch}{ext}" +checksum_asset_name = "{asset}.sha256" +checksum_required = true token_header = " Authorization " +token_prefix = " token " token_env_names = [" GITEA_TOKEN ", "", "GITEA_RELEASE_TOKEN"] ` @@ -65,15 +71,33 @@ token_env_names = [" GITEA_TOKEN ", "", "GITEA_RELEASE_TOKEN"] if source.Name != "Gitea releases" { t.Fatalf("source name = %q", source.Name) } + if source.Driver != "gitea" { + t.Fatalf("driver = %q", source.Driver) + } + if source.Repository != "org/repo" { + t.Fatalf("repository = %q", source.Repository) + } if source.BaseURL != "https://gitea.example.com" { t.Fatalf("base URL = %q", source.BaseURL) } if source.LatestReleaseURL != "https://gitea.example.com/api/releases/latest" { t.Fatalf("latest release URL = %q", source.LatestReleaseURL) } + if source.AssetNameTemplate != "{binary}_{os}_{arch}{ext}" { + t.Fatalf("asset name template = %q", source.AssetNameTemplate) + } + if source.ChecksumAssetName != "{asset}.sha256" { + t.Fatalf("checksum asset name = %q", source.ChecksumAssetName) + } + if !source.ChecksumRequired { + t.Fatal("checksum required should be true") + } if source.TokenHeader != "Authorization" { t.Fatalf("token header = %q", source.TokenHeader) } + if source.TokenPrefix != "token" { + t.Fatalf("token prefix = %q", source.TokenPrefix) + } if len(source.TokenEnvNames) != 2 { t.Fatalf("token env names = %v", source.TokenEnvNames) } diff --git a/update/update.go b/update/update.go index 12593d8..7ba393a 100644 --- a/update/update.go +++ b/update/update.go @@ -2,6 +2,8 @@ package update import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -9,31 +11,57 @@ import ( "net/http" "net/url" "os" + "path" "path/filepath" "runtime" + "sort" "strings" "time" ) +const defaultAssetNameTemplate = "{binary}-{os}-{arch}{ext}" + type Options struct { - Client *http.Client - CurrentVersion string - ExecutablePath string - LatestReleaseURL string - Stdout io.Writer - BinaryName string - ReleaseSource ReleaseSource - GOOS string - GOARCH string + Client *http.Client + CurrentVersion string + ExecutablePath string + LatestReleaseURL string + Stdout io.Writer + BinaryName string + AssetNameTemplate string + ReleaseSource ReleaseSource + GOOS string + GOARCH string + ValidateDownloaded ValidateDownloadedFunc + ReplaceExecutable ReplaceExecutableFunc +} + +type ReplaceExecutableFunc func(downloadPath, targetPath string) error + +type ValidateDownloadedFunc func(context.Context, ValidationInput) error + +type ValidationInput struct { + DownloadPath string + TargetPath string + AssetName string + ReleaseTag string + ReleaseURL string + Source ReleaseSource } type ReleaseSource struct { - Name string - BaseURL string - LatestReleaseURL string - Token string - TokenHeader string - TokenEnvNames []string + Name string + Driver string + Repository string + BaseURL string + LatestReleaseURL string + AssetNameTemplate string + ChecksumAssetName string + ChecksumRequired bool + Token string + TokenHeader string + TokenPrefix string + TokenEnvNames []string } type Auth struct { @@ -53,6 +81,33 @@ type ReleaseLink struct { URL string `json:"url"` } +type releasePayload struct { + TagName string `json:"tag_name"` + Assets json.RawMessage `json:"assets"` +} + +type releaseAssetsPayload struct { + Links []releaseLinkPayload `json:"links"` +} + +type releaseLinkPayload struct { + Name string `json:"name"` + URL string `json:"url"` + BrowserDownloadURL string `json:"browser_download_url"` + DirectAssetURL string `json:"direct_asset_url"` +} + +func (r *Release) UnmarshalJSON(data []byte) error { + var payload releasePayload + if err := json.Unmarshal(data, &payload); err != nil { + return err + } + + r.TagName = strings.TrimSpace(payload.TagName) + r.Assets.Links = parseReleaseLinks(payload.Assets) + return nil +} + func Run(ctx context.Context, opts Options) error { if opts.Stdout == nil { opts.Stdout = io.Discard @@ -73,12 +128,9 @@ func Run(ctx context.Context, opts Options) error { source := normalizeSource(opts.ReleaseSource) auth := ResolveAuth(source.Token, source) - releaseURL := opts.LatestReleaseURL - if strings.TrimSpace(releaseURL) == "" { - releaseURL = strings.TrimSpace(source.LatestReleaseURL) - } - if releaseURL == "" { - return errors.New("latest release URL must not be empty") + releaseURL, err := ResolveLatestReleaseURL(opts.LatestReleaseURL, source) + if err != nil { + return err } targetPath, err := ResolveUpdateTarget(opts.ExecutablePath) @@ -86,7 +138,12 @@ func Run(ctx context.Context, opts Options) error { return err } - assetName, err := AssetName(opts.BinaryName, opts.GOOS, opts.GOARCH) + assetTemplate := strings.TrimSpace(opts.AssetNameTemplate) + if assetTemplate == "" { + assetTemplate = source.AssetNameTemplate + } + + assetName, err := AssetNameWithTemplate(opts.BinaryName, opts.GOOS, opts.GOARCH, assetTemplate) if err != nil { return err } @@ -111,7 +168,28 @@ func Run(ctx context.Context, opts Options) error { } defer os.Remove(downloadPath) - if err := ReplaceExecutable(downloadPath, targetPath); err != nil { + if err := VerifyReleaseAssetChecksum(ctx, opts.Client, release, releaseURL, assetName, downloadPath, auth, source); err != nil { + return err + } + + if opts.ValidateDownloaded != nil { + if err := opts.ValidateDownloaded(ctx, ValidationInput{ + DownloadPath: downloadPath, + TargetPath: targetPath, + AssetName: assetName, + ReleaseTag: release.TagName, + ReleaseURL: releaseURL, + Source: source, + }); err != nil { + return fmt.Errorf("validate downloaded artifact: %w", err) + } + } + + replaceExecutable := opts.ReplaceExecutable + if replaceExecutable == nil { + replaceExecutable = ReplaceExecutable + } + if err := replaceExecutable(downloadPath, targetPath); err != nil { return err } @@ -119,16 +197,55 @@ func Run(ctx context.Context, opts Options) error { return nil } +func ResolveLatestReleaseURL(explicit string, source ReleaseSource) (string, error) { + if releaseURL := strings.TrimSpace(explicit); releaseURL != "" { + return releaseURL, nil + } + + source = normalizeSource(source) + if source.LatestReleaseURL != "" { + return source.LatestReleaseURL, nil + } + + if source.Driver == "" { + return "", errors.New("latest release URL must not be empty (set latest_release_url or configure driver+repository)") + } + if source.Repository == "" { + return "", fmt.Errorf("release source %q requires repository when driver is set", source.Driver) + } + + switch source.Driver { + case "gitea": + if source.BaseURL == "" { + return "", errors.New("release source gitea requires base_url") + } + return fmt.Sprintf("%s/api/v1/repos/%s/releases/latest", source.BaseURL, source.Repository), nil + case "gitlab": + projectPath := url.PathEscape(source.Repository) + return fmt.Sprintf("%s/api/v4/projects/%s/releases/permalink/latest", source.BaseURL, projectPath), nil + case "github": + return fmt.Sprintf("%s/repos/%s/releases/latest", source.BaseURL, source.Repository), nil + default: + return "", fmt.Errorf("unsupported release driver %q (expected gitea, gitlab or github)", source.Driver) + } +} + func ResolveAuth(explicitToken string, source ReleaseSource) Auth { source = normalizeSource(source) if token := strings.TrimSpace(explicitToken); token != "" { - return Auth{Header: source.TokenHeader, Token: token} + return Auth{ + Header: source.TokenHeader, + Token: withTokenPrefix(token, source.TokenPrefix), + } } for _, envName := range source.TokenEnvNames { if token := strings.TrimSpace(os.Getenv(envName)); token != "" { - return Auth{Header: source.TokenHeader, Token: token} + return Auth{ + Header: source.TokenHeader, + Token: withTokenPrefix(token, source.TokenPrefix), + } } } @@ -153,23 +270,46 @@ func ResolveUpdateTarget(explicitPath string) (string, error) { } func AssetName(binaryName, goos, goarch string) (string, error) { + return AssetNameWithTemplate(binaryName, goos, goarch, defaultAssetNameTemplate) +} + +func AssetNameWithTemplate(binaryName, goos, goarch, template string) (string, error) { name := strings.TrimSpace(binaryName) if name == "" { return "", errors.New("binary name must not be empty") } - switch { - case goos == "darwin" && goarch == "amd64": - return name + "-darwin-amd64", nil - case goos == "darwin" && goarch == "arm64": - return name + "-darwin-arm64", nil - case goos == "linux" && goarch == "amd64": - return name + "-linux-amd64", nil - case goos == "windows" && goarch == "amd64": - return name + "-windows-amd64.exe", nil - default: - return "", fmt.Errorf("no release artifact for %s/%s", goos, goarch) + osName := strings.ToLower(strings.TrimSpace(goos)) + archName := strings.ToLower(strings.TrimSpace(goarch)) + if osName == "" || archName == "" { + return "", errors.New("goos and goarch must not be empty") } + + assetTemplate := strings.TrimSpace(template) + if assetTemplate == "" { + assetTemplate = defaultAssetNameTemplate + } + + ext := "" + if osName == "windows" { + ext = ".exe" + } + + replaced := strings.NewReplacer( + "{binary}", name, + "{os}", osName, + "{arch}", archName, + "{ext}", ext, + ).Replace(assetTemplate) + replaced = strings.TrimSpace(replaced) + if replaced == "" { + return "", errors.New("asset name template resolved to an empty value") + } + if strings.ContainsRune(replaced, '/') || strings.ContainsRune(replaced, '\\') { + return "", fmt.Errorf("asset name %q must not contain path separators", replaced) + } + + return replaced, nil } func FetchLatestRelease(ctx context.Context, client *http.Client, releaseURL string, auth Auth, source ReleaseSource) (Release, error) { @@ -232,7 +372,26 @@ func (r Release) AssetURL(assetName, releaseURL string) (string, error) { } } - return "", fmt.Errorf("latest release does not contain asset %q", assetName) + availableAssets := make([]string, 0, len(r.Assets.Links)) + for _, link := range r.Assets.Links { + if name := strings.TrimSpace(link.Name); name != "" { + availableAssets = append(availableAssets, name) + } + } + sort.Strings(availableAssets) + if len(availableAssets) == 0 { + return "", fmt.Errorf("latest release does not contain asset %q", assetName) + } + + preview := availableAssets + if len(preview) > 8 { + preview = preview[:8] + } + return "", fmt.Errorf( + "latest release does not contain asset %q (available: %s)", + assetName, + strings.Join(preview, ", "), + ) } func DownloadReleaseAsset(ctx context.Context, client *http.Client, assetURL, targetPath string, auth Auth, source ReleaseSource) (string, error) { @@ -292,9 +451,57 @@ func DownloadReleaseAsset(ctx context.Context, client *http.Client, assetURL, ta return tempPath, nil } +func VerifyReleaseAssetChecksum( + ctx context.Context, + client *http.Client, + release Release, + releaseURL string, + assetName string, + artifactPath string, + auth Auth, + source ReleaseSource, +) error { + source = normalizeSource(source) + + checksumAssetName := resolveChecksumAssetName(assetName, source.ChecksumAssetName) + checksumURL, err := release.AssetURL(checksumAssetName, releaseURL) + if err != nil { + if source.ChecksumRequired { + return fmt.Errorf("checksum verification: %w", err) + } + return nil + } + + checksumBody, err := downloadAssetBytes(ctx, client, checksumURL, auth, source) + if err != nil { + return fmt.Errorf("checksum verification: %w", err) + } + + expected, err := parseChecksum(string(checksumBody), assetName) + if err != nil { + return fmt.Errorf("checksum verification: %w", err) + } + + actual, err := fileSHA256(artifactPath) + if err != nil { + return fmt.Errorf("checksum verification: %w", err) + } + + if !strings.EqualFold(expected, actual) { + return fmt.Errorf( + "checksum mismatch for asset %q: expected %s, got %s", + assetName, + expected, + actual, + ) + } + + return nil +} + func ReplaceExecutable(downloadPath, targetPath string) error { if runtime.GOOS == "windows" { - return errors.New("self-update is not supported on windows") + return errors.New("self-update is not supported on windows without a custom ReplaceExecutable hook") } if err := os.Rename(downloadPath, targetPath); err != nil { return fmt.Errorf("replace executable %q: %w", targetPath, err) @@ -304,9 +511,69 @@ func ReplaceExecutable(downloadPath, targetPath string) error { func normalizeSource(source ReleaseSource) ReleaseSource { source.Name = strings.TrimSpace(source.Name) + source.Driver = strings.ToLower(strings.TrimSpace(source.Driver)) + source.Repository = strings.Trim(strings.TrimSpace(source.Repository), "/") source.BaseURL = strings.TrimRight(strings.TrimSpace(source.BaseURL), "/") source.LatestReleaseURL = strings.TrimSpace(source.LatestReleaseURL) + source.AssetNameTemplate = strings.TrimSpace(source.AssetNameTemplate) + source.ChecksumAssetName = strings.TrimSpace(source.ChecksumAssetName) + source.Token = strings.TrimSpace(source.Token) source.TokenHeader = strings.TrimSpace(source.TokenHeader) + source.TokenPrefix = strings.TrimSpace(source.TokenPrefix) + + envNames := source.TokenEnvNames[:0] + for _, envName := range source.TokenEnvNames { + if trimmed := strings.TrimSpace(envName); trimmed != "" { + envNames = append(envNames, trimmed) + } + } + source.TokenEnvNames = envNames + + switch source.Driver { + case "gitea": + if source.Name == "" { + source.Name = "Gitea releases" + } + if source.TokenHeader == "" { + source.TokenHeader = "Authorization" + } + if source.TokenPrefix == "" { + source.TokenPrefix = "token " + } + if len(source.TokenEnvNames) == 0 { + source.TokenEnvNames = []string{"GITEA_TOKEN"} + } + case "gitlab": + if source.Name == "" { + source.Name = "GitLab releases" + } + if source.BaseURL == "" { + source.BaseURL = "https://gitlab.com" + } + if source.TokenHeader == "" { + source.TokenHeader = "PRIVATE-TOKEN" + } + if len(source.TokenEnvNames) == 0 { + source.TokenEnvNames = []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"} + } + case "github": + if source.Name == "" { + source.Name = "GitHub releases" + } + if source.BaseURL == "" { + source.BaseURL = "https://api.github.com" + } + if source.TokenHeader == "" { + source.TokenHeader = "Authorization" + } + if source.TokenPrefix == "" { + source.TokenPrefix = "Bearer " + } + if len(source.TokenEnvNames) == 0 { + source.TokenEnvNames = []string{"GITHUB_TOKEN"} + } + } + return source } @@ -375,3 +642,211 @@ func (a Auth) maybeHint(statusCode int, body []byte, source ReleaseSource) error source.TokenEnvNames[1], ) } + +func parseReleaseLinks(raw json.RawMessage) []ReleaseLink { + if len(raw) == 0 || strings.TrimSpace(string(raw)) == "null" { + return nil + } + + parseLinks := func(payload []releaseLinkPayload) []ReleaseLink { + links := make([]ReleaseLink, 0, len(payload)) + for _, item := range payload { + name := strings.TrimSpace(item.Name) + assetURL := firstNonEmpty(item.DirectAssetURL, item.BrowserDownloadURL, item.URL) + if name == "" || strings.TrimSpace(assetURL) == "" { + continue + } + links = append(links, ReleaseLink{ + Name: name, + URL: strings.TrimSpace(assetURL), + }) + } + return links + } + + var asObject releaseAssetsPayload + if err := json.Unmarshal(raw, &asObject); err == nil && len(asObject.Links) > 0 { + return parseLinks(asObject.Links) + } + + var asArray []releaseLinkPayload + if err := json.Unmarshal(raw, &asArray); err == nil && len(asArray) > 0 { + return parseLinks(asArray) + } + + return nil +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + +func withTokenPrefix(token, prefix string) string { + trimmedToken := strings.TrimSpace(token) + if trimmedToken == "" { + return "" + } + + trimmedPrefix := strings.TrimSpace(prefix) + if trimmedPrefix == "" { + return trimmedToken + } + + lowerToken := strings.ToLower(trimmedToken) + lowerPrefix := strings.ToLower(trimmedPrefix) + if strings.HasPrefix(lowerToken, lowerPrefix) { + return trimmedToken + } + + return trimmedPrefix + " " + trimmedToken +} + +func resolveChecksumAssetName(assetName, configured string) string { + value := strings.TrimSpace(configured) + if value == "" { + return assetName + ".sha256" + } + return strings.ReplaceAll(value, "{asset}", assetName) +} + +func downloadAssetBytes(ctx context.Context, client *http.Client, assetURL string, auth Auth, source ReleaseSource) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, assetURL, nil) + if err != nil { + return nil, fmt.Errorf("build checksum download request: %w", err) + } + req.Header.Set("User-Agent", "mcp updater") + auth.apply(req) + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("download checksum asset: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + if hint := auth.maybeHint(resp.StatusCode, body, source); hint != nil { + return nil, fmt.Errorf("download checksum asset: %w", hint) + } + return nil, fmt.Errorf( + "download checksum asset: unexpected status %d: %s", + resp.StatusCode, + strings.TrimSpace(string(body)), + ) + } + + content, err := io.ReadAll(io.LimitReader(resp.Body, 256*1024)) + if err != nil { + return nil, fmt.Errorf("read checksum asset: %w", err) + } + return content, nil +} + +func parseChecksum(content, assetName string) (string, error) { + lines := strings.Split(content, "\n") + fallbackSingle := "" + + for _, raw := range lines { + line := strings.TrimSpace(raw) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + if strings.HasPrefix(strings.ToUpper(line), "SHA256 (") { + openIndex := strings.Index(line, "(") + closeIndex := strings.LastIndex(line, ")") + equalIndex := strings.LastIndex(line, "=") + if openIndex >= 0 && closeIndex > openIndex && equalIndex > closeIndex { + name := strings.TrimSpace(line[openIndex+1 : closeIndex]) + hash := strings.TrimSpace(line[equalIndex+1:]) + if isSHA256Hex(hash) && matchesAssetName(name, assetName) { + return strings.ToLower(hash), nil + } + } + continue + } + + fields := strings.Fields(line) + if len(fields) > 0 && isSHA256Hex(fields[0]) { + if len(fields) == 1 { + if fallbackSingle == "" { + fallbackSingle = strings.ToLower(fields[0]) + } + continue + } + + name := strings.TrimSpace(strings.TrimPrefix(fields[1], "*")) + if matchesAssetName(name, assetName) { + return strings.ToLower(fields[0]), nil + } + continue + } + + colonIndex := strings.Index(line, ":") + if colonIndex > 0 && colonIndex < len(line)-1 { + left := strings.TrimSpace(line[:colonIndex]) + right := strings.TrimSpace(line[colonIndex+1:]) + switch { + case isSHA256Hex(left) && matchesAssetName(right, assetName): + return strings.ToLower(left), nil + case isSHA256Hex(right) && matchesAssetName(left, assetName): + return strings.ToLower(right), nil + } + } + } + + if fallbackSingle != "" { + return fallbackSingle, nil + } + + return "", fmt.Errorf("checksum file does not contain a sha256 for asset %q", assetName) +} + +func matchesAssetName(candidate, assetName string) bool { + name := strings.TrimSpace(strings.TrimPrefix(candidate, "*")) + name = strings.TrimPrefix(name, "./") + if name == assetName { + return true + } + + name = strings.ReplaceAll(name, "\\", "/") + if path.Base(name) == assetName { + return true + } + return false +} + +func isSHA256Hex(value string) bool { + if len(value) != 64 { + return false + } + for _, r := range value { + switch { + case r >= '0' && r <= '9': + case r >= 'a' && r <= 'f': + case r >= 'A' && r <= 'F': + default: + return false + } + } + return true +} + +func fileSHA256(path string) (string, error) { + file, err := os.Open(path) + if err != nil { + return "", fmt.Errorf("open downloaded artifact: %w", err) + } + defer file.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, file); err != nil { + return "", fmt.Errorf("hash downloaded artifact: %w", err) + } + return hex.EncodeToString(hash.Sum(nil)), nil +} diff --git a/update/update_test.go b/update/update_test.go index 047ac74..ab427e3 100644 --- a/update/update_test.go +++ b/update/update_test.go @@ -3,7 +3,10 @@ package update import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "encoding/json" + "errors" "io" "net/http" "os" @@ -24,13 +27,20 @@ func TestAssetName(t *testing.T) { {name: "darwin amd64", goos: "darwin", goarch: "amd64", want: "graylog-mcp-darwin-amd64"}, {name: "darwin arm64", goos: "darwin", goarch: "arm64", want: "graylog-mcp-darwin-arm64"}, {name: "linux amd64", goos: "linux", goarch: "amd64", want: "graylog-mcp-linux-amd64"}, - {name: "windows amd64", goos: "windows", goarch: "amd64", want: "graylog-mcp-windows-amd64.exe"}, - {name: "unsupported", goos: "linux", goarch: "arm64", wantErr: "no release artifact"}, + {name: "linux arm64", goos: "linux", goarch: "arm64", want: "graylog-mcp-linux-arm64"}, + {name: "windows arm64", goos: "windows", goarch: "arm64", want: "graylog-mcp-windows-arm64.exe"}, + {name: "missing binary", goos: "linux", goarch: "amd64", wantErr: "binary name must not be empty"}, + {name: "missing platform", goos: "", goarch: "amd64", wantErr: "goos and goarch must not be empty"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := AssetName("graylog-mcp", tt.goos, tt.goarch) + binaryName := "graylog-mcp" + if tt.name == "missing binary" { + binaryName = " " + } + + got, err := AssetName(binaryName, tt.goos, tt.goarch) if tt.wantErr != "" { if err == nil || !strings.Contains(err.Error(), tt.wantErr) { t.Fatalf("error = %v, want substring %q", err, tt.wantErr) @@ -47,6 +57,91 @@ func TestAssetName(t *testing.T) { } } +func TestAssetNameWithTemplate(t *testing.T) { + got, err := AssetNameWithTemplate("graylog-mcp", "linux", "amd64", "{binary}_{os}_{arch}") + if err != nil { + t.Fatalf("AssetNameWithTemplate: %v", err) + } + if got != "graylog-mcp_linux_amd64" { + t.Fatalf("got %q", got) + } + + _, err = AssetNameWithTemplate("graylog-mcp", "linux", "amd64", "{binary}/{os}/{arch}") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "must not contain path separators") { + t.Fatalf("error = %v", err) + } +} + +func TestResolveLatestReleaseURL(t *testing.T) { + got, err := ResolveLatestReleaseURL("https://custom/latest", ReleaseSource{ + Driver: "gitea", + Repository: "org/repo", + BaseURL: "https://gitea.example.com", + }) + if err != nil { + t.Fatalf("ResolveLatestReleaseURL explicit: %v", err) + } + if got != "https://custom/latest" { + t.Fatalf("release url = %q", got) + } + + got, err = ResolveLatestReleaseURL("", ReleaseSource{ + LatestReleaseURL: "https://manifest/latest", + }) + if err != nil { + t.Fatalf("ResolveLatestReleaseURL from source: %v", err) + } + if got != "https://manifest/latest" { + t.Fatalf("release url = %q", got) + } + + got, err = ResolveLatestReleaseURL("", ReleaseSource{ + Driver: "gitea", + Repository: "org/repo", + BaseURL: "https://gitea.example.com", + }) + if err != nil { + t.Fatalf("ResolveLatestReleaseURL gitea: %v", err) + } + if got != "https://gitea.example.com/api/v1/repos/org/repo/releases/latest" { + t.Fatalf("release url = %q", got) + } + + got, err = ResolveLatestReleaseURL("", ReleaseSource{ + Driver: "gitlab", + Repository: "group/sub/repo", + BaseURL: "https://gitlab.example.com", + }) + if err != nil { + t.Fatalf("ResolveLatestReleaseURL gitlab: %v", err) + } + if got != "https://gitlab.example.com/api/v4/projects/group%2Fsub%2Frepo/releases/permalink/latest" { + t.Fatalf("release url = %q", got) + } + + got, err = ResolveLatestReleaseURL("", ReleaseSource{ + Driver: "github", + Repository: "org/repo", + }) + if err != nil { + t.Fatalf("ResolveLatestReleaseURL github: %v", err) + } + if got != "https://api.github.com/repos/org/repo/releases/latest" { + t.Fatalf("release url = %q", got) + } + + _, err = ResolveLatestReleaseURL("", ReleaseSource{Driver: "gitea"}) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "requires repository") { + t.Fatalf("error = %v", err) + } +} + func TestResolveUpdateTargetFollowsSymlink(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("symlink behavior differs on windows") @@ -87,20 +182,37 @@ func TestReleaseAssetURLResolvesRelativeLinks(t *testing.T) { } } +func TestReleaseAssetURLErrorIncludesAvailableAssets(t *testing.T) { + release := Release{} + release.Assets.Links = []ReleaseLink{ + {Name: "my-mcp-linux-amd64", URL: "/downloads/my-mcp-linux-amd64"}, + {Name: "my-mcp-darwin-arm64", URL: "/downloads/my-mcp-darwin-arm64"}, + } + + _, err := release.AssetURL("my-mcp-linux-arm64", "https://releases.example.com/latest") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "available: my-mcp-darwin-arm64, my-mcp-linux-amd64") { + t.Fatalf("error = %v", err) + } +} + func TestResolveAuthPrefersExplicitToken(t *testing.T) { t.Setenv("RELEASE_TOKEN", "env-token") auth := ResolveAuth("explicit-token", ReleaseSource{ - Name: "release endpoint", - BaseURL: "https://releases.example.com", - TokenHeader: "X-Release-Token", - TokenEnvNames: []string{"RELEASE_TOKEN", "RELEASE_PRIVATE_TOKEN"}, + TokenHeader: "Authorization", + TokenPrefix: "Bearer", + TokenEnvNames: []string{ + "RELEASE_TOKEN", + }, }) - if auth.Header != "X-Release-Token" { - t.Fatalf("header = %q, want X-Release-Token", auth.Header) + if auth.Header != "Authorization" { + t.Fatalf("header = %q, want Authorization", auth.Header) } - if auth.Token != "explicit-token" { - t.Fatalf("token = %q, want explicit token", auth.Token) + if auth.Token != "Bearer explicit-token" { + t.Fatalf("token = %q, want prefixed explicit token", auth.Token) } } @@ -108,10 +220,11 @@ func TestResolveAuthReadsEnvironment(t *testing.T) { t.Setenv("RELEASE_PRIVATE_TOKEN", "env-token") auth := ResolveAuth("", ReleaseSource{ - Name: "release endpoint", - BaseURL: "https://releases.example.com", - TokenHeader: "X-Release-Token", - TokenEnvNames: []string{"RELEASE_TOKEN", "RELEASE_PRIVATE_TOKEN"}, + TokenHeader: "X-Release-Token", + TokenEnvNames: []string{ + "RELEASE_TOKEN", + "RELEASE_PRIVATE_TOKEN", + }, }) if auth.Header != "X-Release-Token" { t.Fatalf("header = %q, want X-Release-Token", auth.Header) @@ -155,6 +268,86 @@ func TestFetchLatestReleaseAddsConfiguredAuthHeader(t *testing.T) { } } +func TestFetchLatestReleaseSupportsGitHubAssets(t *testing.T) { + client := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + body := `{ + "tag_name":"v1.2.3", + "assets":[ + {"name":"my-mcp-linux-amd64","browser_download_url":"https://example.com/my-mcp-linux-amd64"} + ] + }` + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + }, nil + }), + } + + release, err := FetchLatestRelease( + context.Background(), + client, + "https://api.github.com/repos/org/repo/releases/latest", + Auth{}, + ReleaseSource{Name: "GitHub releases"}, + ) + if err != nil { + t.Fatalf("FetchLatestRelease: %v", err) + } + + assetURL, err := release.AssetURL("my-mcp-linux-amd64", "https://api.github.com/repos/org/repo/releases/latest") + if err != nil { + t.Fatalf("AssetURL: %v", err) + } + if assetURL != "https://example.com/my-mcp-linux-amd64" { + t.Fatalf("assetURL = %q", assetURL) + } +} + +func TestFetchLatestReleaseSupportsGitLabDirectAssetURL(t *testing.T) { + client := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + body := `{ + "tag_name":"v1.2.3", + "assets":{ + "links":[ + { + "name":"my-mcp-linux-amd64", + "url":"https://gitlab.example.com/fallback", + "direct_asset_url":"https://gitlab.example.com/direct/my-mcp-linux-amd64" + } + ] + } + }` + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + }, nil + }), + } + + release, err := FetchLatestRelease( + context.Background(), + client, + "https://gitlab.example.com/api/v4/projects/org%2Frepo/releases/permalink/latest", + Auth{}, + ReleaseSource{Name: "GitLab releases"}, + ) + if err != nil { + t.Fatalf("FetchLatestRelease: %v", err) + } + + assetURL, err := release.AssetURL("my-mcp-linux-amd64", "https://gitlab.example.com/api/v4/projects/org%2Frepo/releases/permalink/latest") + if err != nil { + t.Fatalf("AssetURL: %v", err) + } + if assetURL != "https://gitlab.example.com/direct/my-mcp-linux-amd64" { + t.Fatalf("assetURL = %q", assetURL) + } +} + func TestFetchLatestReleaseHintsWhenAuthIsMissing(t *testing.T) { client := &http.Client{ Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { @@ -282,6 +475,474 @@ func TestRunReplacesExecutableWithLatestArtifact(t *testing.T) { } } +func TestRunUsesDriverWithoutExplicitLatestReleaseURL(t *testing.T) { + assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH) + if err != nil { + t.Skipf("unsupported test platform: %v", err) + } + + latestURL := "https://gitea.example.com/api/v1/repos/org/graylog-mcp/releases/latest" + replaceCalled := false + + client := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + switch r.URL.String() { + case latestURL: + release := Release{TagName: "v1.2.3"} + release.Assets.Links = []ReleaseLink{ + {Name: assetName, URL: "https://gitea.example.com/downloads/artifact"}, + } + payload, err := json.Marshal(release) + if err != nil { + t.Fatalf("Marshal release: %v", err) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(payload)), + }, nil + case "https://gitea.example.com/downloads/artifact": + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("new-binary")), + }, nil + default: + return &http.Response{ + StatusCode: http.StatusNotFound, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("not found")), + }, nil + } + }), + } + + tempDir := t.TempDir() + target := filepath.Join(tempDir, "graylog-mcp") + if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil { + t.Fatalf("WriteFile target: %v", err) + } + + err = Run(context.Background(), Options{ + Client: client, + CurrentVersion: "v1.2.2", + ExecutablePath: target, + BinaryName: "graylog-mcp", + ReleaseSource: ReleaseSource{ + Driver: "gitea", + Repository: "org/graylog-mcp", + BaseURL: "https://gitea.example.com", + }, + ReplaceExecutable: func(downloadPath, targetPath string) error { + replaceCalled = true + + data, err := os.ReadFile(downloadPath) + if err != nil { + return err + } + return os.WriteFile(targetPath, data, 0o755) + }, + }) + if err != nil { + t.Fatalf("Run: %v", err) + } + + if !replaceCalled { + t.Fatal("custom replace hook was not called") + } +} + +func TestRunVerifiesChecksumWhenSidecarAvailable(t *testing.T) { + assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH) + if err != nil { + t.Skipf("unsupported test platform: %v", err) + } + + const newBinary = "new-binary" + hash := sha256.Sum256([]byte(newBinary)) + checksum := hex.EncodeToString(hash[:]) + " " + assetName + "\n" + + client := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + switch r.URL.String() { + case "https://releases.example.com/latest": + release := Release{TagName: "v1.2.3"} + release.Assets.Links = []ReleaseLink{ + {Name: assetName, URL: "https://releases.example.com/artifact"}, + {Name: assetName + ".sha256", URL: "https://releases.example.com/artifact.sha256"}, + } + payload, err := json.Marshal(release) + if err != nil { + t.Fatalf("Marshal release: %v", err) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(payload)), + }, nil + case "https://releases.example.com/artifact": + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(newBinary)), + }, nil + case "https://releases.example.com/artifact.sha256": + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(checksum)), + }, nil + default: + return &http.Response{ + StatusCode: http.StatusNotFound, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("not found")), + }, nil + } + }), + } + + tempDir := t.TempDir() + target := filepath.Join(tempDir, "graylog-mcp") + if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil { + t.Fatalf("WriteFile target: %v", err) + } + + err = Run(context.Background(), Options{ + Client: client, + CurrentVersion: "v1.2.2", + ExecutablePath: target, + LatestReleaseURL: "https://releases.example.com/latest", + BinaryName: "graylog-mcp", + ReplaceExecutable: func(downloadPath, targetPath string) error { + data, err := os.ReadFile(downloadPath) + if err != nil { + return err + } + return os.WriteFile(targetPath, data, 0o755) + }, + }) + if err != nil { + t.Fatalf("Run: %v", err) + } + + got, err := os.ReadFile(target) + if err != nil { + t.Fatalf("ReadFile target: %v", err) + } + if string(got) != newBinary { + t.Fatalf("target content = %q, want %q", string(got), newBinary) + } +} + +func TestRunFailsOnChecksumMismatch(t *testing.T) { + assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH) + if err != nil { + t.Skipf("unsupported test platform: %v", err) + } + + client := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + switch r.URL.String() { + case "https://releases.example.com/latest": + release := Release{TagName: "v1.2.3"} + release.Assets.Links = []ReleaseLink{ + {Name: assetName, URL: "https://releases.example.com/artifact"}, + {Name: assetName + ".sha256", URL: "https://releases.example.com/artifact.sha256"}, + } + payload, err := json.Marshal(release) + if err != nil { + t.Fatalf("Marshal release: %v", err) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(payload)), + }, nil + case "https://releases.example.com/artifact": + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("new-binary")), + }, nil + case "https://releases.example.com/artifact.sha256": + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa " + assetName)), + }, nil + default: + return &http.Response{ + StatusCode: http.StatusNotFound, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("not found")), + }, nil + } + }), + } + + tempDir := t.TempDir() + target := filepath.Join(tempDir, "graylog-mcp") + if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil { + t.Fatalf("WriteFile target: %v", err) + } + + err = Run(context.Background(), Options{ + Client: client, + CurrentVersion: "v1.2.2", + ExecutablePath: target, + LatestReleaseURL: "https://releases.example.com/latest", + BinaryName: "graylog-mcp", + ReplaceExecutable: func(downloadPath, targetPath string) error { + data, readErr := os.ReadFile(downloadPath) + if readErr != nil { + return readErr + } + return os.WriteFile(targetPath, data, 0o755) + }, + }) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "checksum mismatch") { + t.Fatalf("error = %v", err) + } +} + +func TestRunFailsWhenChecksumRequiredAndMissing(t *testing.T) { + assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH) + if err != nil { + t.Skipf("unsupported test platform: %v", err) + } + + client := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + switch r.URL.String() { + case "https://releases.example.com/latest": + release := Release{TagName: "v1.2.3"} + release.Assets.Links = []ReleaseLink{ + {Name: assetName, URL: "https://releases.example.com/artifact"}, + } + payload, err := json.Marshal(release) + if err != nil { + t.Fatalf("Marshal release: %v", err) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(payload)), + }, nil + case "https://releases.example.com/artifact": + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("new-binary")), + }, nil + default: + return &http.Response{ + StatusCode: http.StatusNotFound, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("not found")), + }, nil + } + }), + } + + tempDir := t.TempDir() + target := filepath.Join(tempDir, "graylog-mcp") + if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil { + t.Fatalf("WriteFile target: %v", err) + } + + err = Run(context.Background(), Options{ + Client: client, + CurrentVersion: "v1.2.2", + ExecutablePath: target, + LatestReleaseURL: "https://releases.example.com/latest", + BinaryName: "graylog-mcp", + ReleaseSource: ReleaseSource{ + ChecksumRequired: true, + }, + ReplaceExecutable: func(downloadPath, targetPath string) error { + data, readErr := os.ReadFile(downloadPath) + if readErr != nil { + return readErr + } + return os.WriteFile(targetPath, data, 0o755) + }, + }) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "checksum verification") { + t.Fatalf("error = %v", err) + } +} + +func TestRunInvokesValidationHook(t *testing.T) { + assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH) + if err != nil { + t.Skipf("unsupported test platform: %v", err) + } + + validateCalled := false + replaceCalled := false + + client := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + switch r.URL.String() { + case "https://releases.example.com/latest": + release := Release{TagName: "v1.2.3"} + release.Assets.Links = []ReleaseLink{ + {Name: assetName, URL: "https://releases.example.com/artifact"}, + } + payload, err := json.Marshal(release) + if err != nil { + t.Fatalf("Marshal release: %v", err) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(payload)), + }, nil + case "https://releases.example.com/artifact": + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("new-binary")), + }, nil + default: + return &http.Response{ + StatusCode: http.StatusNotFound, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("not found")), + }, nil + } + }), + } + + tempDir := t.TempDir() + target := filepath.Join(tempDir, "graylog-mcp") + if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil { + t.Fatalf("WriteFile target: %v", err) + } + + err = Run(context.Background(), Options{ + Client: client, + CurrentVersion: "v1.2.2", + ExecutablePath: target, + LatestReleaseURL: "https://releases.example.com/latest", + BinaryName: "graylog-mcp", + ValidateDownloaded: func(_ context.Context, input ValidationInput) error { + validateCalled = true + if input.AssetName != assetName { + t.Fatalf("asset name = %q", input.AssetName) + } + data, readErr := os.ReadFile(input.DownloadPath) + if readErr != nil { + return readErr + } + if string(data) != "new-binary" { + t.Fatalf("downloaded content = %q", string(data)) + } + return nil + }, + ReplaceExecutable: func(downloadPath, targetPath string) error { + replaceCalled = true + data, readErr := os.ReadFile(downloadPath) + if readErr != nil { + return readErr + } + return os.WriteFile(targetPath, data, 0o755) + }, + }) + if err != nil { + t.Fatalf("Run: %v", err) + } + + if !validateCalled { + t.Fatal("validation hook was not called") + } + if !replaceCalled { + t.Fatal("replace hook was not called") + } +} + +func TestRunStopsWhenValidationHookFails(t *testing.T) { + assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH) + if err != nil { + t.Skipf("unsupported test platform: %v", err) + } + + client := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + switch r.URL.String() { + case "https://releases.example.com/latest": + release := Release{TagName: "v1.2.3"} + release.Assets.Links = []ReleaseLink{ + {Name: assetName, URL: "https://releases.example.com/artifact"}, + } + payload, err := json.Marshal(release) + if err != nil { + t.Fatalf("Marshal release: %v", err) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(payload)), + }, nil + case "https://releases.example.com/artifact": + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("new-binary")), + }, nil + default: + return &http.Response{ + StatusCode: http.StatusNotFound, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("not found")), + }, nil + } + }), + } + + tempDir := t.TempDir() + target := filepath.Join(tempDir, "graylog-mcp") + if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil { + t.Fatalf("WriteFile target: %v", err) + } + + replaceCalled := false + err = Run(context.Background(), Options{ + Client: client, + CurrentVersion: "v1.2.2", + ExecutablePath: target, + LatestReleaseURL: "https://releases.example.com/latest", + BinaryName: "graylog-mcp", + ValidateDownloaded: func(context.Context, ValidationInput) error { + return errors.New("signature invalid") + }, + ReplaceExecutable: func(downloadPath, targetPath string) error { + replaceCalled = true + data, readErr := os.ReadFile(downloadPath) + if readErr != nil { + return readErr + } + return os.WriteFile(targetPath, data, 0o755) + }, + }) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "validate downloaded artifact") { + t.Fatalf("error = %v", err) + } + if replaceCalled { + t.Fatal("replace hook should not have been called") + } +} + func TestRunSkipsWhenAlreadyOnLatestRelease(t *testing.T) { assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH) if err != nil { @@ -364,7 +1025,7 @@ func TestRunSkipsWhenAlreadyOnLatestRelease(t *testing.T) { } } -func TestRunRequiresLatestReleaseURL(t *testing.T) { +func TestRunRequiresLatestReleaseURLOrDriver(t *testing.T) { target := filepath.Join(t.TempDir(), "graylog-mcp") if err := os.WriteFile(target, []byte("current-binary"), 0o755); err != nil { t.Fatalf("WriteFile target: %v", err)