feat(update): add forge drivers and checksum validation hooks

This commit is contained in:
thibaud-lclr 2026-04-14 14:11:43 +02:00
parent 42e1345962
commit bf8e1285d8
5 changed files with 1267 additions and 74 deletions

View file

@ -94,9 +94,14 @@ docs_url = "https://docs.example.com/my-mcp"
[update] [update]
source_name = "Gitea releases" source_name = "Gitea releases"
driver = "gitea"
repository = "org/repo"
base_url = "https://gitea.example.com" base_url = "https://gitea.example.com"
latest_release_url = "https://gitea.example.com/api/v1/repos/org/repo/releases/latest" asset_name_template = "{binary}-{os}-{arch}{ext}"
checksum_asset_name = "{asset}.sha256"
checksum_required = false
token_header = "Authorization" token_header = "Authorization"
token_prefix = "token"
token_env_names = ["GITEA_TOKEN"] token_env_names = ["GITEA_TOKEN"]
[environment] [environment]
@ -120,9 +125,15 @@ Champs supportés :
- `[update]` : source de release consommée par `update`. - `[update]` : source de release consommée par `update`.
- `source_name` : nom humain de la source de release, utilisé dans certains messages d'erreur. - `source_name` : nom humain de la source de release, utilisé dans certains messages d'erreur.
- `driver` : driver de forge (`gitea`, `gitlab`, `github`) pour déduire automatiquement l'endpoint latest.
- `repository` : dépôt cible (`org/repo` ou `group/subgroup/repo`).
- `base_url` : base de la forge ou du service de release. - `base_url` : base de la forge ou du service de release.
- `latest_release_url` : URL complète qui retourne la release la plus récente. - `latest_release_url` : URL complète qui retourne la release la plus récente (prioritaire sur le driver).
- `asset_name_template` : template de nom d'asset (`{binary}`, `{os}`, `{arch}`, `{ext}`).
- `checksum_asset_name` : nom d'asset checksum, avec placeholder optionnel `{asset}`.
- `checksum_required` : si `true`, l'update échoue quand l'asset checksum est absent.
- `token_header` : header HTTP à utiliser pour l'authentification. - `token_header` : header HTTP à utiliser pour l'authentification.
- `token_prefix` : préfixe appliqué devant le token (`Bearer`, `token`, ...).
- `token_env_names` : liste de variables d'environnement candidates pour retrouver le token. - `token_env_names` : liste de variables d'environnement candidates pour retrouver le token.
- `[environment].known` : variables d'environnement connues du projet. - `[environment].known` : variables d'environnement connues du projet.
- `[secret_store].backend_policy` : politique de secret store (`auto`, `kwallet-only`, `keyring-any`, `env-only`). - `[secret_store].backend_policy` : politique de secret store (`auto`, `kwallet-only`, `keyring-any`, `env-only`).
@ -462,9 +473,14 @@ if report.HasFailures() {
## Auto-Update ## Auto-Update
Le package `update` ne déduit pas la forge ni l'authentification. Le package `update` supporte les drivers `gitea`, `gitlab` et `github`.
L'application cliente fournit l'URL de release, le header d'auth éventuel et, Si `latest_release_url` est vide, l'URL latest est déduite depuis
si besoin, les variables d'environnement à consulter. `driver + repository (+ base_url)`.
Le parseur de release supporte :
- format `assets.links` (Gitea/GitLab)
- format `assets[]` avec `browser_download_url` (GitHub et Gitea API)
Le format attendu pour la réponse `latest release` est actuellement : Le format attendu pour la réponse `latest release` est actuellement :
@ -501,12 +517,12 @@ if err != nil {
} }
``` ```
Contraintes actuelles : Comportement :
- le `latest_release_url` doit être renseigné explicitement - le nom de l'asset est configurable (`asset_name_template`) et supporte tout couple `GOOS/GOARCH`
- les assets supportés sont `darwin/amd64`, `darwin/arm64`, `linux/amd64` et `windows/amd64` - si un asset `<asset>.sha256` (ou `checksum_asset_name`) existe, le binaire téléchargé est vérifié avant remplacement
- le remplacement du binaire n'est pas supporté sur Windows - un hook `ValidateDownloaded` permet d'ajouter une validation custom (signature, scan, etc.)
- le nom de l'asset est dérivé de `BinaryName`, `GOOS` et `GOARCH` - sur Windows, le remplacement in-place n'est pas fait par défaut ; fournir `Options.ReplaceExecutable` pour une stratégie dédiée
## Exemple Minimal ## Exemple Minimal

View file

@ -25,11 +25,17 @@ type File struct {
} }
type Update struct { type Update struct {
SourceName string `toml:"source_name"` SourceName string `toml:"source_name"`
BaseURL string `toml:"base_url"` Driver string `toml:"driver"`
LatestReleaseURL string `toml:"latest_release_url"` Repository string `toml:"repository"`
TokenHeader string `toml:"token_header"` BaseURL string `toml:"base_url"`
TokenEnvNames []string `toml:"token_env_names"` LatestReleaseURL string `toml:"latest_release_url"`
AssetNameTemplate string `toml:"asset_name_template"`
ChecksumAssetName string `toml:"checksum_asset_name"`
ChecksumRequired bool `toml:"checksum_required"`
TokenHeader string `toml:"token_header"`
TokenPrefix string `toml:"token_prefix"`
TokenEnvNames []string `toml:"token_env_names"`
} }
type Environment struct { type Environment struct {
@ -143,9 +149,14 @@ func (f *File) normalize() {
func (u *Update) normalize() { func (u *Update) normalize() {
u.SourceName = strings.TrimSpace(u.SourceName) u.SourceName = strings.TrimSpace(u.SourceName)
u.Driver = strings.ToLower(strings.TrimSpace(u.Driver))
u.Repository = strings.Trim(strings.TrimSpace(u.Repository), "/")
u.BaseURL = strings.TrimRight(strings.TrimSpace(u.BaseURL), "/") u.BaseURL = strings.TrimRight(strings.TrimSpace(u.BaseURL), "/")
u.LatestReleaseURL = strings.TrimSpace(u.LatestReleaseURL) u.LatestReleaseURL = strings.TrimSpace(u.LatestReleaseURL)
u.AssetNameTemplate = strings.TrimSpace(u.AssetNameTemplate)
u.ChecksumAssetName = strings.TrimSpace(u.ChecksumAssetName)
u.TokenHeader = strings.TrimSpace(u.TokenHeader) u.TokenHeader = strings.TrimSpace(u.TokenHeader)
u.TokenPrefix = strings.TrimSpace(u.TokenPrefix)
u.TokenEnvNames = normalizeStringList(u.TokenEnvNames) u.TokenEnvNames = normalizeStringList(u.TokenEnvNames)
} }
@ -170,11 +181,17 @@ func (u Update) ReleaseSource() update.ReleaseSource {
u.normalize() u.normalize()
return update.ReleaseSource{ return update.ReleaseSource{
Name: u.SourceName, Name: u.SourceName,
BaseURL: u.BaseURL, Driver: u.Driver,
LatestReleaseURL: u.LatestReleaseURL, Repository: u.Repository,
TokenHeader: u.TokenHeader, BaseURL: u.BaseURL,
TokenEnvNames: append([]string(nil), u.TokenEnvNames...), LatestReleaseURL: u.LatestReleaseURL,
AssetNameTemplate: u.AssetNameTemplate,
ChecksumAssetName: u.ChecksumAssetName,
ChecksumRequired: u.ChecksumRequired,
TokenHeader: u.TokenHeader,
TokenPrefix: u.TokenPrefix,
TokenEnvNames: append([]string(nil), u.TokenEnvNames...),
} }
} }

View file

@ -46,9 +46,15 @@ func TestLoadParsesUpdateConfig(t *testing.T) {
const content = ` const content = `
[update] [update]
source_name = " Gitea releases " source_name = " Gitea releases "
driver = " Gitea "
repository = " org/repo "
base_url = "https://gitea.example.com/" base_url = "https://gitea.example.com/"
latest_release_url = "https://gitea.example.com/api/releases/latest" latest_release_url = "https://gitea.example.com/api/releases/latest"
asset_name_template = "{binary}_{os}_{arch}{ext}"
checksum_asset_name = "{asset}.sha256"
checksum_required = true
token_header = " Authorization " token_header = " Authorization "
token_prefix = " token "
token_env_names = [" GITEA_TOKEN ", "", "GITEA_RELEASE_TOKEN"] token_env_names = [" GITEA_TOKEN ", "", "GITEA_RELEASE_TOKEN"]
` `
@ -65,15 +71,33 @@ token_env_names = [" GITEA_TOKEN ", "", "GITEA_RELEASE_TOKEN"]
if source.Name != "Gitea releases" { if source.Name != "Gitea releases" {
t.Fatalf("source name = %q", source.Name) t.Fatalf("source name = %q", source.Name)
} }
if source.Driver != "gitea" {
t.Fatalf("driver = %q", source.Driver)
}
if source.Repository != "org/repo" {
t.Fatalf("repository = %q", source.Repository)
}
if source.BaseURL != "https://gitea.example.com" { if source.BaseURL != "https://gitea.example.com" {
t.Fatalf("base URL = %q", source.BaseURL) t.Fatalf("base URL = %q", source.BaseURL)
} }
if source.LatestReleaseURL != "https://gitea.example.com/api/releases/latest" { if source.LatestReleaseURL != "https://gitea.example.com/api/releases/latest" {
t.Fatalf("latest release URL = %q", source.LatestReleaseURL) t.Fatalf("latest release URL = %q", source.LatestReleaseURL)
} }
if source.AssetNameTemplate != "{binary}_{os}_{arch}{ext}" {
t.Fatalf("asset name template = %q", source.AssetNameTemplate)
}
if source.ChecksumAssetName != "{asset}.sha256" {
t.Fatalf("checksum asset name = %q", source.ChecksumAssetName)
}
if !source.ChecksumRequired {
t.Fatal("checksum required should be true")
}
if source.TokenHeader != "Authorization" { if source.TokenHeader != "Authorization" {
t.Fatalf("token header = %q", source.TokenHeader) t.Fatalf("token header = %q", source.TokenHeader)
} }
if source.TokenPrefix != "token" {
t.Fatalf("token prefix = %q", source.TokenPrefix)
}
if len(source.TokenEnvNames) != 2 { if len(source.TokenEnvNames) != 2 {
t.Fatalf("token env names = %v", source.TokenEnvNames) t.Fatalf("token env names = %v", source.TokenEnvNames)
} }

View file

@ -2,6 +2,8 @@ package update
import ( import (
"context" "context"
"crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -9,31 +11,57 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path"
"path/filepath" "path/filepath"
"runtime" "runtime"
"sort"
"strings" "strings"
"time" "time"
) )
const defaultAssetNameTemplate = "{binary}-{os}-{arch}{ext}"
type Options struct { type Options struct {
Client *http.Client Client *http.Client
CurrentVersion string CurrentVersion string
ExecutablePath string ExecutablePath string
LatestReleaseURL string LatestReleaseURL string
Stdout io.Writer Stdout io.Writer
BinaryName string BinaryName string
ReleaseSource ReleaseSource AssetNameTemplate string
GOOS string ReleaseSource ReleaseSource
GOARCH string GOOS string
GOARCH string
ValidateDownloaded ValidateDownloadedFunc
ReplaceExecutable ReplaceExecutableFunc
}
type ReplaceExecutableFunc func(downloadPath, targetPath string) error
type ValidateDownloadedFunc func(context.Context, ValidationInput) error
type ValidationInput struct {
DownloadPath string
TargetPath string
AssetName string
ReleaseTag string
ReleaseURL string
Source ReleaseSource
} }
type ReleaseSource struct { type ReleaseSource struct {
Name string Name string
BaseURL string Driver string
LatestReleaseURL string Repository string
Token string BaseURL string
TokenHeader string LatestReleaseURL string
TokenEnvNames []string AssetNameTemplate string
ChecksumAssetName string
ChecksumRequired bool
Token string
TokenHeader string
TokenPrefix string
TokenEnvNames []string
} }
type Auth struct { type Auth struct {
@ -53,6 +81,33 @@ type ReleaseLink struct {
URL string `json:"url"` URL string `json:"url"`
} }
type releasePayload struct {
TagName string `json:"tag_name"`
Assets json.RawMessage `json:"assets"`
}
type releaseAssetsPayload struct {
Links []releaseLinkPayload `json:"links"`
}
type releaseLinkPayload struct {
Name string `json:"name"`
URL string `json:"url"`
BrowserDownloadURL string `json:"browser_download_url"`
DirectAssetURL string `json:"direct_asset_url"`
}
func (r *Release) UnmarshalJSON(data []byte) error {
var payload releasePayload
if err := json.Unmarshal(data, &payload); err != nil {
return err
}
r.TagName = strings.TrimSpace(payload.TagName)
r.Assets.Links = parseReleaseLinks(payload.Assets)
return nil
}
func Run(ctx context.Context, opts Options) error { func Run(ctx context.Context, opts Options) error {
if opts.Stdout == nil { if opts.Stdout == nil {
opts.Stdout = io.Discard opts.Stdout = io.Discard
@ -73,12 +128,9 @@ func Run(ctx context.Context, opts Options) error {
source := normalizeSource(opts.ReleaseSource) source := normalizeSource(opts.ReleaseSource)
auth := ResolveAuth(source.Token, source) auth := ResolveAuth(source.Token, source)
releaseURL := opts.LatestReleaseURL releaseURL, err := ResolveLatestReleaseURL(opts.LatestReleaseURL, source)
if strings.TrimSpace(releaseURL) == "" { if err != nil {
releaseURL = strings.TrimSpace(source.LatestReleaseURL) return err
}
if releaseURL == "" {
return errors.New("latest release URL must not be empty")
} }
targetPath, err := ResolveUpdateTarget(opts.ExecutablePath) targetPath, err := ResolveUpdateTarget(opts.ExecutablePath)
@ -86,7 +138,12 @@ func Run(ctx context.Context, opts Options) error {
return err return err
} }
assetName, err := AssetName(opts.BinaryName, opts.GOOS, opts.GOARCH) assetTemplate := strings.TrimSpace(opts.AssetNameTemplate)
if assetTemplate == "" {
assetTemplate = source.AssetNameTemplate
}
assetName, err := AssetNameWithTemplate(opts.BinaryName, opts.GOOS, opts.GOARCH, assetTemplate)
if err != nil { if err != nil {
return err return err
} }
@ -111,7 +168,28 @@ func Run(ctx context.Context, opts Options) error {
} }
defer os.Remove(downloadPath) defer os.Remove(downloadPath)
if err := ReplaceExecutable(downloadPath, targetPath); err != nil { if err := VerifyReleaseAssetChecksum(ctx, opts.Client, release, releaseURL, assetName, downloadPath, auth, source); err != nil {
return err
}
if opts.ValidateDownloaded != nil {
if err := opts.ValidateDownloaded(ctx, ValidationInput{
DownloadPath: downloadPath,
TargetPath: targetPath,
AssetName: assetName,
ReleaseTag: release.TagName,
ReleaseURL: releaseURL,
Source: source,
}); err != nil {
return fmt.Errorf("validate downloaded artifact: %w", err)
}
}
replaceExecutable := opts.ReplaceExecutable
if replaceExecutable == nil {
replaceExecutable = ReplaceExecutable
}
if err := replaceExecutable(downloadPath, targetPath); err != nil {
return err return err
} }
@ -119,16 +197,55 @@ func Run(ctx context.Context, opts Options) error {
return nil return nil
} }
func ResolveLatestReleaseURL(explicit string, source ReleaseSource) (string, error) {
if releaseURL := strings.TrimSpace(explicit); releaseURL != "" {
return releaseURL, nil
}
source = normalizeSource(source)
if source.LatestReleaseURL != "" {
return source.LatestReleaseURL, nil
}
if source.Driver == "" {
return "", errors.New("latest release URL must not be empty (set latest_release_url or configure driver+repository)")
}
if source.Repository == "" {
return "", fmt.Errorf("release source %q requires repository when driver is set", source.Driver)
}
switch source.Driver {
case "gitea":
if source.BaseURL == "" {
return "", errors.New("release source gitea requires base_url")
}
return fmt.Sprintf("%s/api/v1/repos/%s/releases/latest", source.BaseURL, source.Repository), nil
case "gitlab":
projectPath := url.PathEscape(source.Repository)
return fmt.Sprintf("%s/api/v4/projects/%s/releases/permalink/latest", source.BaseURL, projectPath), nil
case "github":
return fmt.Sprintf("%s/repos/%s/releases/latest", source.BaseURL, source.Repository), nil
default:
return "", fmt.Errorf("unsupported release driver %q (expected gitea, gitlab or github)", source.Driver)
}
}
func ResolveAuth(explicitToken string, source ReleaseSource) Auth { func ResolveAuth(explicitToken string, source ReleaseSource) Auth {
source = normalizeSource(source) source = normalizeSource(source)
if token := strings.TrimSpace(explicitToken); token != "" { if token := strings.TrimSpace(explicitToken); token != "" {
return Auth{Header: source.TokenHeader, Token: token} return Auth{
Header: source.TokenHeader,
Token: withTokenPrefix(token, source.TokenPrefix),
}
} }
for _, envName := range source.TokenEnvNames { for _, envName := range source.TokenEnvNames {
if token := strings.TrimSpace(os.Getenv(envName)); token != "" { if token := strings.TrimSpace(os.Getenv(envName)); token != "" {
return Auth{Header: source.TokenHeader, Token: token} return Auth{
Header: source.TokenHeader,
Token: withTokenPrefix(token, source.TokenPrefix),
}
} }
} }
@ -153,23 +270,46 @@ func ResolveUpdateTarget(explicitPath string) (string, error) {
} }
func AssetName(binaryName, goos, goarch string) (string, error) { func AssetName(binaryName, goos, goarch string) (string, error) {
return AssetNameWithTemplate(binaryName, goos, goarch, defaultAssetNameTemplate)
}
func AssetNameWithTemplate(binaryName, goos, goarch, template string) (string, error) {
name := strings.TrimSpace(binaryName) name := strings.TrimSpace(binaryName)
if name == "" { if name == "" {
return "", errors.New("binary name must not be empty") return "", errors.New("binary name must not be empty")
} }
switch { osName := strings.ToLower(strings.TrimSpace(goos))
case goos == "darwin" && goarch == "amd64": archName := strings.ToLower(strings.TrimSpace(goarch))
return name + "-darwin-amd64", nil if osName == "" || archName == "" {
case goos == "darwin" && goarch == "arm64": return "", errors.New("goos and goarch must not be empty")
return name + "-darwin-arm64", nil
case goos == "linux" && goarch == "amd64":
return name + "-linux-amd64", nil
case goos == "windows" && goarch == "amd64":
return name + "-windows-amd64.exe", nil
default:
return "", fmt.Errorf("no release artifact for %s/%s", goos, goarch)
} }
assetTemplate := strings.TrimSpace(template)
if assetTemplate == "" {
assetTemplate = defaultAssetNameTemplate
}
ext := ""
if osName == "windows" {
ext = ".exe"
}
replaced := strings.NewReplacer(
"{binary}", name,
"{os}", osName,
"{arch}", archName,
"{ext}", ext,
).Replace(assetTemplate)
replaced = strings.TrimSpace(replaced)
if replaced == "" {
return "", errors.New("asset name template resolved to an empty value")
}
if strings.ContainsRune(replaced, '/') || strings.ContainsRune(replaced, '\\') {
return "", fmt.Errorf("asset name %q must not contain path separators", replaced)
}
return replaced, nil
} }
func FetchLatestRelease(ctx context.Context, client *http.Client, releaseURL string, auth Auth, source ReleaseSource) (Release, error) { func FetchLatestRelease(ctx context.Context, client *http.Client, releaseURL string, auth Auth, source ReleaseSource) (Release, error) {
@ -232,7 +372,26 @@ func (r Release) AssetURL(assetName, releaseURL string) (string, error) {
} }
} }
return "", fmt.Errorf("latest release does not contain asset %q", assetName) availableAssets := make([]string, 0, len(r.Assets.Links))
for _, link := range r.Assets.Links {
if name := strings.TrimSpace(link.Name); name != "" {
availableAssets = append(availableAssets, name)
}
}
sort.Strings(availableAssets)
if len(availableAssets) == 0 {
return "", fmt.Errorf("latest release does not contain asset %q", assetName)
}
preview := availableAssets
if len(preview) > 8 {
preview = preview[:8]
}
return "", fmt.Errorf(
"latest release does not contain asset %q (available: %s)",
assetName,
strings.Join(preview, ", "),
)
} }
func DownloadReleaseAsset(ctx context.Context, client *http.Client, assetURL, targetPath string, auth Auth, source ReleaseSource) (string, error) { func DownloadReleaseAsset(ctx context.Context, client *http.Client, assetURL, targetPath string, auth Auth, source ReleaseSource) (string, error) {
@ -292,9 +451,57 @@ func DownloadReleaseAsset(ctx context.Context, client *http.Client, assetURL, ta
return tempPath, nil return tempPath, nil
} }
func VerifyReleaseAssetChecksum(
ctx context.Context,
client *http.Client,
release Release,
releaseURL string,
assetName string,
artifactPath string,
auth Auth,
source ReleaseSource,
) error {
source = normalizeSource(source)
checksumAssetName := resolveChecksumAssetName(assetName, source.ChecksumAssetName)
checksumURL, err := release.AssetURL(checksumAssetName, releaseURL)
if err != nil {
if source.ChecksumRequired {
return fmt.Errorf("checksum verification: %w", err)
}
return nil
}
checksumBody, err := downloadAssetBytes(ctx, client, checksumURL, auth, source)
if err != nil {
return fmt.Errorf("checksum verification: %w", err)
}
expected, err := parseChecksum(string(checksumBody), assetName)
if err != nil {
return fmt.Errorf("checksum verification: %w", err)
}
actual, err := fileSHA256(artifactPath)
if err != nil {
return fmt.Errorf("checksum verification: %w", err)
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf(
"checksum mismatch for asset %q: expected %s, got %s",
assetName,
expected,
actual,
)
}
return nil
}
func ReplaceExecutable(downloadPath, targetPath string) error { func ReplaceExecutable(downloadPath, targetPath string) error {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
return errors.New("self-update is not supported on windows") return errors.New("self-update is not supported on windows without a custom ReplaceExecutable hook")
} }
if err := os.Rename(downloadPath, targetPath); err != nil { if err := os.Rename(downloadPath, targetPath); err != nil {
return fmt.Errorf("replace executable %q: %w", targetPath, err) return fmt.Errorf("replace executable %q: %w", targetPath, err)
@ -304,9 +511,69 @@ func ReplaceExecutable(downloadPath, targetPath string) error {
func normalizeSource(source ReleaseSource) ReleaseSource { func normalizeSource(source ReleaseSource) ReleaseSource {
source.Name = strings.TrimSpace(source.Name) source.Name = strings.TrimSpace(source.Name)
source.Driver = strings.ToLower(strings.TrimSpace(source.Driver))
source.Repository = strings.Trim(strings.TrimSpace(source.Repository), "/")
source.BaseURL = strings.TrimRight(strings.TrimSpace(source.BaseURL), "/") source.BaseURL = strings.TrimRight(strings.TrimSpace(source.BaseURL), "/")
source.LatestReleaseURL = strings.TrimSpace(source.LatestReleaseURL) source.LatestReleaseURL = strings.TrimSpace(source.LatestReleaseURL)
source.AssetNameTemplate = strings.TrimSpace(source.AssetNameTemplate)
source.ChecksumAssetName = strings.TrimSpace(source.ChecksumAssetName)
source.Token = strings.TrimSpace(source.Token)
source.TokenHeader = strings.TrimSpace(source.TokenHeader) source.TokenHeader = strings.TrimSpace(source.TokenHeader)
source.TokenPrefix = strings.TrimSpace(source.TokenPrefix)
envNames := source.TokenEnvNames[:0]
for _, envName := range source.TokenEnvNames {
if trimmed := strings.TrimSpace(envName); trimmed != "" {
envNames = append(envNames, trimmed)
}
}
source.TokenEnvNames = envNames
switch source.Driver {
case "gitea":
if source.Name == "" {
source.Name = "Gitea releases"
}
if source.TokenHeader == "" {
source.TokenHeader = "Authorization"
}
if source.TokenPrefix == "" {
source.TokenPrefix = "token "
}
if len(source.TokenEnvNames) == 0 {
source.TokenEnvNames = []string{"GITEA_TOKEN"}
}
case "gitlab":
if source.Name == "" {
source.Name = "GitLab releases"
}
if source.BaseURL == "" {
source.BaseURL = "https://gitlab.com"
}
if source.TokenHeader == "" {
source.TokenHeader = "PRIVATE-TOKEN"
}
if len(source.TokenEnvNames) == 0 {
source.TokenEnvNames = []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"}
}
case "github":
if source.Name == "" {
source.Name = "GitHub releases"
}
if source.BaseURL == "" {
source.BaseURL = "https://api.github.com"
}
if source.TokenHeader == "" {
source.TokenHeader = "Authorization"
}
if source.TokenPrefix == "" {
source.TokenPrefix = "Bearer "
}
if len(source.TokenEnvNames) == 0 {
source.TokenEnvNames = []string{"GITHUB_TOKEN"}
}
}
return source return source
} }
@ -375,3 +642,211 @@ func (a Auth) maybeHint(statusCode int, body []byte, source ReleaseSource) error
source.TokenEnvNames[1], source.TokenEnvNames[1],
) )
} }
func parseReleaseLinks(raw json.RawMessage) []ReleaseLink {
if len(raw) == 0 || strings.TrimSpace(string(raw)) == "null" {
return nil
}
parseLinks := func(payload []releaseLinkPayload) []ReleaseLink {
links := make([]ReleaseLink, 0, len(payload))
for _, item := range payload {
name := strings.TrimSpace(item.Name)
assetURL := firstNonEmpty(item.DirectAssetURL, item.BrowserDownloadURL, item.URL)
if name == "" || strings.TrimSpace(assetURL) == "" {
continue
}
links = append(links, ReleaseLink{
Name: name,
URL: strings.TrimSpace(assetURL),
})
}
return links
}
var asObject releaseAssetsPayload
if err := json.Unmarshal(raw, &asObject); err == nil && len(asObject.Links) > 0 {
return parseLinks(asObject.Links)
}
var asArray []releaseLinkPayload
if err := json.Unmarshal(raw, &asArray); err == nil && len(asArray) > 0 {
return parseLinks(asArray)
}
return nil
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
func withTokenPrefix(token, prefix string) string {
trimmedToken := strings.TrimSpace(token)
if trimmedToken == "" {
return ""
}
trimmedPrefix := strings.TrimSpace(prefix)
if trimmedPrefix == "" {
return trimmedToken
}
lowerToken := strings.ToLower(trimmedToken)
lowerPrefix := strings.ToLower(trimmedPrefix)
if strings.HasPrefix(lowerToken, lowerPrefix) {
return trimmedToken
}
return trimmedPrefix + " " + trimmedToken
}
func resolveChecksumAssetName(assetName, configured string) string {
value := strings.TrimSpace(configured)
if value == "" {
return assetName + ".sha256"
}
return strings.ReplaceAll(value, "{asset}", assetName)
}
func downloadAssetBytes(ctx context.Context, client *http.Client, assetURL string, auth Auth, source ReleaseSource) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, assetURL, nil)
if err != nil {
return nil, fmt.Errorf("build checksum download request: %w", err)
}
req.Header.Set("User-Agent", "mcp updater")
auth.apply(req)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("download checksum asset: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
if hint := auth.maybeHint(resp.StatusCode, body, source); hint != nil {
return nil, fmt.Errorf("download checksum asset: %w", hint)
}
return nil, fmt.Errorf(
"download checksum asset: unexpected status %d: %s",
resp.StatusCode,
strings.TrimSpace(string(body)),
)
}
content, err := io.ReadAll(io.LimitReader(resp.Body, 256*1024))
if err != nil {
return nil, fmt.Errorf("read checksum asset: %w", err)
}
return content, nil
}
func parseChecksum(content, assetName string) (string, error) {
lines := strings.Split(content, "\n")
fallbackSingle := ""
for _, raw := range lines {
line := strings.TrimSpace(raw)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if strings.HasPrefix(strings.ToUpper(line), "SHA256 (") {
openIndex := strings.Index(line, "(")
closeIndex := strings.LastIndex(line, ")")
equalIndex := strings.LastIndex(line, "=")
if openIndex >= 0 && closeIndex > openIndex && equalIndex > closeIndex {
name := strings.TrimSpace(line[openIndex+1 : closeIndex])
hash := strings.TrimSpace(line[equalIndex+1:])
if isSHA256Hex(hash) && matchesAssetName(name, assetName) {
return strings.ToLower(hash), nil
}
}
continue
}
fields := strings.Fields(line)
if len(fields) > 0 && isSHA256Hex(fields[0]) {
if len(fields) == 1 {
if fallbackSingle == "" {
fallbackSingle = strings.ToLower(fields[0])
}
continue
}
name := strings.TrimSpace(strings.TrimPrefix(fields[1], "*"))
if matchesAssetName(name, assetName) {
return strings.ToLower(fields[0]), nil
}
continue
}
colonIndex := strings.Index(line, ":")
if colonIndex > 0 && colonIndex < len(line)-1 {
left := strings.TrimSpace(line[:colonIndex])
right := strings.TrimSpace(line[colonIndex+1:])
switch {
case isSHA256Hex(left) && matchesAssetName(right, assetName):
return strings.ToLower(left), nil
case isSHA256Hex(right) && matchesAssetName(left, assetName):
return strings.ToLower(right), nil
}
}
}
if fallbackSingle != "" {
return fallbackSingle, nil
}
return "", fmt.Errorf("checksum file does not contain a sha256 for asset %q", assetName)
}
func matchesAssetName(candidate, assetName string) bool {
name := strings.TrimSpace(strings.TrimPrefix(candidate, "*"))
name = strings.TrimPrefix(name, "./")
if name == assetName {
return true
}
name = strings.ReplaceAll(name, "\\", "/")
if path.Base(name) == assetName {
return true
}
return false
}
func isSHA256Hex(value string) bool {
if len(value) != 64 {
return false
}
for _, r := range value {
switch {
case r >= '0' && r <= '9':
case r >= 'a' && r <= 'f':
case r >= 'A' && r <= 'F':
default:
return false
}
}
return true
}
func fileSHA256(path string) (string, error) {
file, err := os.Open(path)
if err != nil {
return "", fmt.Errorf("open downloaded artifact: %w", err)
}
defer file.Close()
hash := sha256.New()
if _, err := io.Copy(hash, file); err != nil {
return "", fmt.Errorf("hash downloaded artifact: %w", err)
}
return hex.EncodeToString(hash.Sum(nil)), nil
}

View file

@ -3,7 +3,10 @@ package update
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
"os" "os"
@ -24,13 +27,20 @@ func TestAssetName(t *testing.T) {
{name: "darwin amd64", goos: "darwin", goarch: "amd64", want: "graylog-mcp-darwin-amd64"}, {name: "darwin amd64", goos: "darwin", goarch: "amd64", want: "graylog-mcp-darwin-amd64"},
{name: "darwin arm64", goos: "darwin", goarch: "arm64", want: "graylog-mcp-darwin-arm64"}, {name: "darwin arm64", goos: "darwin", goarch: "arm64", want: "graylog-mcp-darwin-arm64"},
{name: "linux amd64", goos: "linux", goarch: "amd64", want: "graylog-mcp-linux-amd64"}, {name: "linux amd64", goos: "linux", goarch: "amd64", want: "graylog-mcp-linux-amd64"},
{name: "windows amd64", goos: "windows", goarch: "amd64", want: "graylog-mcp-windows-amd64.exe"}, {name: "linux arm64", goos: "linux", goarch: "arm64", want: "graylog-mcp-linux-arm64"},
{name: "unsupported", goos: "linux", goarch: "arm64", wantErr: "no release artifact"}, {name: "windows arm64", goos: "windows", goarch: "arm64", want: "graylog-mcp-windows-arm64.exe"},
{name: "missing binary", goos: "linux", goarch: "amd64", wantErr: "binary name must not be empty"},
{name: "missing platform", goos: "", goarch: "amd64", wantErr: "goos and goarch must not be empty"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := AssetName("graylog-mcp", tt.goos, tt.goarch) binaryName := "graylog-mcp"
if tt.name == "missing binary" {
binaryName = " "
}
got, err := AssetName(binaryName, tt.goos, tt.goarch)
if tt.wantErr != "" { if tt.wantErr != "" {
if err == nil || !strings.Contains(err.Error(), tt.wantErr) { if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("error = %v, want substring %q", err, tt.wantErr) t.Fatalf("error = %v, want substring %q", err, tt.wantErr)
@ -47,6 +57,91 @@ func TestAssetName(t *testing.T) {
} }
} }
func TestAssetNameWithTemplate(t *testing.T) {
got, err := AssetNameWithTemplate("graylog-mcp", "linux", "amd64", "{binary}_{os}_{arch}")
if err != nil {
t.Fatalf("AssetNameWithTemplate: %v", err)
}
if got != "graylog-mcp_linux_amd64" {
t.Fatalf("got %q", got)
}
_, err = AssetNameWithTemplate("graylog-mcp", "linux", "amd64", "{binary}/{os}/{arch}")
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "must not contain path separators") {
t.Fatalf("error = %v", err)
}
}
func TestResolveLatestReleaseURL(t *testing.T) {
got, err := ResolveLatestReleaseURL("https://custom/latest", ReleaseSource{
Driver: "gitea",
Repository: "org/repo",
BaseURL: "https://gitea.example.com",
})
if err != nil {
t.Fatalf("ResolveLatestReleaseURL explicit: %v", err)
}
if got != "https://custom/latest" {
t.Fatalf("release url = %q", got)
}
got, err = ResolveLatestReleaseURL("", ReleaseSource{
LatestReleaseURL: "https://manifest/latest",
})
if err != nil {
t.Fatalf("ResolveLatestReleaseURL from source: %v", err)
}
if got != "https://manifest/latest" {
t.Fatalf("release url = %q", got)
}
got, err = ResolveLatestReleaseURL("", ReleaseSource{
Driver: "gitea",
Repository: "org/repo",
BaseURL: "https://gitea.example.com",
})
if err != nil {
t.Fatalf("ResolveLatestReleaseURL gitea: %v", err)
}
if got != "https://gitea.example.com/api/v1/repos/org/repo/releases/latest" {
t.Fatalf("release url = %q", got)
}
got, err = ResolveLatestReleaseURL("", ReleaseSource{
Driver: "gitlab",
Repository: "group/sub/repo",
BaseURL: "https://gitlab.example.com",
})
if err != nil {
t.Fatalf("ResolveLatestReleaseURL gitlab: %v", err)
}
if got != "https://gitlab.example.com/api/v4/projects/group%2Fsub%2Frepo/releases/permalink/latest" {
t.Fatalf("release url = %q", got)
}
got, err = ResolveLatestReleaseURL("", ReleaseSource{
Driver: "github",
Repository: "org/repo",
})
if err != nil {
t.Fatalf("ResolveLatestReleaseURL github: %v", err)
}
if got != "https://api.github.com/repos/org/repo/releases/latest" {
t.Fatalf("release url = %q", got)
}
_, err = ResolveLatestReleaseURL("", ReleaseSource{Driver: "gitea"})
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "requires repository") {
t.Fatalf("error = %v", err)
}
}
func TestResolveUpdateTargetFollowsSymlink(t *testing.T) { func TestResolveUpdateTargetFollowsSymlink(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("symlink behavior differs on windows") t.Skip("symlink behavior differs on windows")
@ -87,20 +182,37 @@ func TestReleaseAssetURLResolvesRelativeLinks(t *testing.T) {
} }
} }
func TestReleaseAssetURLErrorIncludesAvailableAssets(t *testing.T) {
release := Release{}
release.Assets.Links = []ReleaseLink{
{Name: "my-mcp-linux-amd64", URL: "/downloads/my-mcp-linux-amd64"},
{Name: "my-mcp-darwin-arm64", URL: "/downloads/my-mcp-darwin-arm64"},
}
_, err := release.AssetURL("my-mcp-linux-arm64", "https://releases.example.com/latest")
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "available: my-mcp-darwin-arm64, my-mcp-linux-amd64") {
t.Fatalf("error = %v", err)
}
}
func TestResolveAuthPrefersExplicitToken(t *testing.T) { func TestResolveAuthPrefersExplicitToken(t *testing.T) {
t.Setenv("RELEASE_TOKEN", "env-token") t.Setenv("RELEASE_TOKEN", "env-token")
auth := ResolveAuth("explicit-token", ReleaseSource{ auth := ResolveAuth("explicit-token", ReleaseSource{
Name: "release endpoint", TokenHeader: "Authorization",
BaseURL: "https://releases.example.com", TokenPrefix: "Bearer",
TokenHeader: "X-Release-Token", TokenEnvNames: []string{
TokenEnvNames: []string{"RELEASE_TOKEN", "RELEASE_PRIVATE_TOKEN"}, "RELEASE_TOKEN",
},
}) })
if auth.Header != "X-Release-Token" { if auth.Header != "Authorization" {
t.Fatalf("header = %q, want X-Release-Token", auth.Header) t.Fatalf("header = %q, want Authorization", auth.Header)
} }
if auth.Token != "explicit-token" { if auth.Token != "Bearer explicit-token" {
t.Fatalf("token = %q, want explicit token", auth.Token) t.Fatalf("token = %q, want prefixed explicit token", auth.Token)
} }
} }
@ -108,10 +220,11 @@ func TestResolveAuthReadsEnvironment(t *testing.T) {
t.Setenv("RELEASE_PRIVATE_TOKEN", "env-token") t.Setenv("RELEASE_PRIVATE_TOKEN", "env-token")
auth := ResolveAuth("", ReleaseSource{ auth := ResolveAuth("", ReleaseSource{
Name: "release endpoint", TokenHeader: "X-Release-Token",
BaseURL: "https://releases.example.com", TokenEnvNames: []string{
TokenHeader: "X-Release-Token", "RELEASE_TOKEN",
TokenEnvNames: []string{"RELEASE_TOKEN", "RELEASE_PRIVATE_TOKEN"}, "RELEASE_PRIVATE_TOKEN",
},
}) })
if auth.Header != "X-Release-Token" { if auth.Header != "X-Release-Token" {
t.Fatalf("header = %q, want X-Release-Token", auth.Header) t.Fatalf("header = %q, want X-Release-Token", auth.Header)
@ -155,6 +268,86 @@ func TestFetchLatestReleaseAddsConfiguredAuthHeader(t *testing.T) {
} }
} }
func TestFetchLatestReleaseSupportsGitHubAssets(t *testing.T) {
client := &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
body := `{
"tag_name":"v1.2.3",
"assets":[
{"name":"my-mcp-linux-amd64","browser_download_url":"https://example.com/my-mcp-linux-amd64"}
]
}`
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}, nil
}),
}
release, err := FetchLatestRelease(
context.Background(),
client,
"https://api.github.com/repos/org/repo/releases/latest",
Auth{},
ReleaseSource{Name: "GitHub releases"},
)
if err != nil {
t.Fatalf("FetchLatestRelease: %v", err)
}
assetURL, err := release.AssetURL("my-mcp-linux-amd64", "https://api.github.com/repos/org/repo/releases/latest")
if err != nil {
t.Fatalf("AssetURL: %v", err)
}
if assetURL != "https://example.com/my-mcp-linux-amd64" {
t.Fatalf("assetURL = %q", assetURL)
}
}
func TestFetchLatestReleaseSupportsGitLabDirectAssetURL(t *testing.T) {
client := &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
body := `{
"tag_name":"v1.2.3",
"assets":{
"links":[
{
"name":"my-mcp-linux-amd64",
"url":"https://gitlab.example.com/fallback",
"direct_asset_url":"https://gitlab.example.com/direct/my-mcp-linux-amd64"
}
]
}
}`
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}, nil
}),
}
release, err := FetchLatestRelease(
context.Background(),
client,
"https://gitlab.example.com/api/v4/projects/org%2Frepo/releases/permalink/latest",
Auth{},
ReleaseSource{Name: "GitLab releases"},
)
if err != nil {
t.Fatalf("FetchLatestRelease: %v", err)
}
assetURL, err := release.AssetURL("my-mcp-linux-amd64", "https://gitlab.example.com/api/v4/projects/org%2Frepo/releases/permalink/latest")
if err != nil {
t.Fatalf("AssetURL: %v", err)
}
if assetURL != "https://gitlab.example.com/direct/my-mcp-linux-amd64" {
t.Fatalf("assetURL = %q", assetURL)
}
}
func TestFetchLatestReleaseHintsWhenAuthIsMissing(t *testing.T) { func TestFetchLatestReleaseHintsWhenAuthIsMissing(t *testing.T) {
client := &http.Client{ client := &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
@ -282,6 +475,474 @@ func TestRunReplacesExecutableWithLatestArtifact(t *testing.T) {
} }
} }
func TestRunUsesDriverWithoutExplicitLatestReleaseURL(t *testing.T) {
assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
if err != nil {
t.Skipf("unsupported test platform: %v", err)
}
latestURL := "https://gitea.example.com/api/v1/repos/org/graylog-mcp/releases/latest"
replaceCalled := false
client := &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.String() {
case latestURL:
release := Release{TagName: "v1.2.3"}
release.Assets.Links = []ReleaseLink{
{Name: assetName, URL: "https://gitea.example.com/downloads/artifact"},
}
payload, err := json.Marshal(release)
if err != nil {
t.Fatalf("Marshal release: %v", err)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(payload)),
}, nil
case "https://gitea.example.com/downloads/artifact":
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("new-binary")),
}, nil
default:
return &http.Response{
StatusCode: http.StatusNotFound,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("not found")),
}, nil
}
}),
}
tempDir := t.TempDir()
target := filepath.Join(tempDir, "graylog-mcp")
if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil {
t.Fatalf("WriteFile target: %v", err)
}
err = Run(context.Background(), Options{
Client: client,
CurrentVersion: "v1.2.2",
ExecutablePath: target,
BinaryName: "graylog-mcp",
ReleaseSource: ReleaseSource{
Driver: "gitea",
Repository: "org/graylog-mcp",
BaseURL: "https://gitea.example.com",
},
ReplaceExecutable: func(downloadPath, targetPath string) error {
replaceCalled = true
data, err := os.ReadFile(downloadPath)
if err != nil {
return err
}
return os.WriteFile(targetPath, data, 0o755)
},
})
if err != nil {
t.Fatalf("Run: %v", err)
}
if !replaceCalled {
t.Fatal("custom replace hook was not called")
}
}
func TestRunVerifiesChecksumWhenSidecarAvailable(t *testing.T) {
assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
if err != nil {
t.Skipf("unsupported test platform: %v", err)
}
const newBinary = "new-binary"
hash := sha256.Sum256([]byte(newBinary))
checksum := hex.EncodeToString(hash[:]) + " " + assetName + "\n"
client := &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.String() {
case "https://releases.example.com/latest":
release := Release{TagName: "v1.2.3"}
release.Assets.Links = []ReleaseLink{
{Name: assetName, URL: "https://releases.example.com/artifact"},
{Name: assetName + ".sha256", URL: "https://releases.example.com/artifact.sha256"},
}
payload, err := json.Marshal(release)
if err != nil {
t.Fatalf("Marshal release: %v", err)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(payload)),
}, nil
case "https://releases.example.com/artifact":
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(newBinary)),
}, nil
case "https://releases.example.com/artifact.sha256":
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(checksum)),
}, nil
default:
return &http.Response{
StatusCode: http.StatusNotFound,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("not found")),
}, nil
}
}),
}
tempDir := t.TempDir()
target := filepath.Join(tempDir, "graylog-mcp")
if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil {
t.Fatalf("WriteFile target: %v", err)
}
err = Run(context.Background(), Options{
Client: client,
CurrentVersion: "v1.2.2",
ExecutablePath: target,
LatestReleaseURL: "https://releases.example.com/latest",
BinaryName: "graylog-mcp",
ReplaceExecutable: func(downloadPath, targetPath string) error {
data, err := os.ReadFile(downloadPath)
if err != nil {
return err
}
return os.WriteFile(targetPath, data, 0o755)
},
})
if err != nil {
t.Fatalf("Run: %v", err)
}
got, err := os.ReadFile(target)
if err != nil {
t.Fatalf("ReadFile target: %v", err)
}
if string(got) != newBinary {
t.Fatalf("target content = %q, want %q", string(got), newBinary)
}
}
func TestRunFailsOnChecksumMismatch(t *testing.T) {
assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
if err != nil {
t.Skipf("unsupported test platform: %v", err)
}
client := &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.String() {
case "https://releases.example.com/latest":
release := Release{TagName: "v1.2.3"}
release.Assets.Links = []ReleaseLink{
{Name: assetName, URL: "https://releases.example.com/artifact"},
{Name: assetName + ".sha256", URL: "https://releases.example.com/artifact.sha256"},
}
payload, err := json.Marshal(release)
if err != nil {
t.Fatalf("Marshal release: %v", err)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(payload)),
}, nil
case "https://releases.example.com/artifact":
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("new-binary")),
}, nil
case "https://releases.example.com/artifact.sha256":
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa " + assetName)),
}, nil
default:
return &http.Response{
StatusCode: http.StatusNotFound,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("not found")),
}, nil
}
}),
}
tempDir := t.TempDir()
target := filepath.Join(tempDir, "graylog-mcp")
if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil {
t.Fatalf("WriteFile target: %v", err)
}
err = Run(context.Background(), Options{
Client: client,
CurrentVersion: "v1.2.2",
ExecutablePath: target,
LatestReleaseURL: "https://releases.example.com/latest",
BinaryName: "graylog-mcp",
ReplaceExecutable: func(downloadPath, targetPath string) error {
data, readErr := os.ReadFile(downloadPath)
if readErr != nil {
return readErr
}
return os.WriteFile(targetPath, data, 0o755)
},
})
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "checksum mismatch") {
t.Fatalf("error = %v", err)
}
}
func TestRunFailsWhenChecksumRequiredAndMissing(t *testing.T) {
assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
if err != nil {
t.Skipf("unsupported test platform: %v", err)
}
client := &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.String() {
case "https://releases.example.com/latest":
release := Release{TagName: "v1.2.3"}
release.Assets.Links = []ReleaseLink{
{Name: assetName, URL: "https://releases.example.com/artifact"},
}
payload, err := json.Marshal(release)
if err != nil {
t.Fatalf("Marshal release: %v", err)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(payload)),
}, nil
case "https://releases.example.com/artifact":
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("new-binary")),
}, nil
default:
return &http.Response{
StatusCode: http.StatusNotFound,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("not found")),
}, nil
}
}),
}
tempDir := t.TempDir()
target := filepath.Join(tempDir, "graylog-mcp")
if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil {
t.Fatalf("WriteFile target: %v", err)
}
err = Run(context.Background(), Options{
Client: client,
CurrentVersion: "v1.2.2",
ExecutablePath: target,
LatestReleaseURL: "https://releases.example.com/latest",
BinaryName: "graylog-mcp",
ReleaseSource: ReleaseSource{
ChecksumRequired: true,
},
ReplaceExecutable: func(downloadPath, targetPath string) error {
data, readErr := os.ReadFile(downloadPath)
if readErr != nil {
return readErr
}
return os.WriteFile(targetPath, data, 0o755)
},
})
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "checksum verification") {
t.Fatalf("error = %v", err)
}
}
func TestRunInvokesValidationHook(t *testing.T) {
assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
if err != nil {
t.Skipf("unsupported test platform: %v", err)
}
validateCalled := false
replaceCalled := false
client := &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.String() {
case "https://releases.example.com/latest":
release := Release{TagName: "v1.2.3"}
release.Assets.Links = []ReleaseLink{
{Name: assetName, URL: "https://releases.example.com/artifact"},
}
payload, err := json.Marshal(release)
if err != nil {
t.Fatalf("Marshal release: %v", err)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(payload)),
}, nil
case "https://releases.example.com/artifact":
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("new-binary")),
}, nil
default:
return &http.Response{
StatusCode: http.StatusNotFound,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("not found")),
}, nil
}
}),
}
tempDir := t.TempDir()
target := filepath.Join(tempDir, "graylog-mcp")
if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil {
t.Fatalf("WriteFile target: %v", err)
}
err = Run(context.Background(), Options{
Client: client,
CurrentVersion: "v1.2.2",
ExecutablePath: target,
LatestReleaseURL: "https://releases.example.com/latest",
BinaryName: "graylog-mcp",
ValidateDownloaded: func(_ context.Context, input ValidationInput) error {
validateCalled = true
if input.AssetName != assetName {
t.Fatalf("asset name = %q", input.AssetName)
}
data, readErr := os.ReadFile(input.DownloadPath)
if readErr != nil {
return readErr
}
if string(data) != "new-binary" {
t.Fatalf("downloaded content = %q", string(data))
}
return nil
},
ReplaceExecutable: func(downloadPath, targetPath string) error {
replaceCalled = true
data, readErr := os.ReadFile(downloadPath)
if readErr != nil {
return readErr
}
return os.WriteFile(targetPath, data, 0o755)
},
})
if err != nil {
t.Fatalf("Run: %v", err)
}
if !validateCalled {
t.Fatal("validation hook was not called")
}
if !replaceCalled {
t.Fatal("replace hook was not called")
}
}
func TestRunStopsWhenValidationHookFails(t *testing.T) {
assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
if err != nil {
t.Skipf("unsupported test platform: %v", err)
}
client := &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.String() {
case "https://releases.example.com/latest":
release := Release{TagName: "v1.2.3"}
release.Assets.Links = []ReleaseLink{
{Name: assetName, URL: "https://releases.example.com/artifact"},
}
payload, err := json.Marshal(release)
if err != nil {
t.Fatalf("Marshal release: %v", err)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(payload)),
}, nil
case "https://releases.example.com/artifact":
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("new-binary")),
}, nil
default:
return &http.Response{
StatusCode: http.StatusNotFound,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("not found")),
}, nil
}
}),
}
tempDir := t.TempDir()
target := filepath.Join(tempDir, "graylog-mcp")
if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil {
t.Fatalf("WriteFile target: %v", err)
}
replaceCalled := false
err = Run(context.Background(), Options{
Client: client,
CurrentVersion: "v1.2.2",
ExecutablePath: target,
LatestReleaseURL: "https://releases.example.com/latest",
BinaryName: "graylog-mcp",
ValidateDownloaded: func(context.Context, ValidationInput) error {
return errors.New("signature invalid")
},
ReplaceExecutable: func(downloadPath, targetPath string) error {
replaceCalled = true
data, readErr := os.ReadFile(downloadPath)
if readErr != nil {
return readErr
}
return os.WriteFile(targetPath, data, 0o755)
},
})
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "validate downloaded artifact") {
t.Fatalf("error = %v", err)
}
if replaceCalled {
t.Fatal("replace hook should not have been called")
}
}
func TestRunSkipsWhenAlreadyOnLatestRelease(t *testing.T) { func TestRunSkipsWhenAlreadyOnLatestRelease(t *testing.T) {
assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH) assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
if err != nil { if err != nil {
@ -364,7 +1025,7 @@ func TestRunSkipsWhenAlreadyOnLatestRelease(t *testing.T) {
} }
} }
func TestRunRequiresLatestReleaseURL(t *testing.T) { func TestRunRequiresLatestReleaseURLOrDriver(t *testing.T) {
target := filepath.Join(t.TempDir(), "graylog-mcp") target := filepath.Join(t.TempDir(), "graylog-mcp")
if err := os.WriteFile(target, []byte("current-binary"), 0o755); err != nil { if err := os.WriteFile(target, []byte("current-binary"), 0o755); err != nil {
t.Fatalf("WriteFile target: %v", err) t.Fatalf("WriteFile target: %v", err)