From 5642581b9b6566a6b8253ba2a89080fafff1e35e Mon Sep 17 00:00:00 2001 From: thibaud-leclere Date: Mon, 13 Apr 2026 15:33:48 +0200 Subject: [PATCH] feat: add reusable mcp framework --- .gitignore | 3 + README.md | 15 ++ cli/cli.go | 101 ++++++++++++ config/config.go | 161 ++++++++++++++++++ config/config_test.go | 78 +++++++++ go.mod | 18 ++ go.sum | 40 +++++ secretstore/store.go | 90 ++++++++++ update/update.go | 371 ++++++++++++++++++++++++++++++++++++++++++ update/update_test.go | 363 +++++++++++++++++++++++++++++++++++++++++ 10 files changed, 1240 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 cli/cli.go create mode 100644 config/config.go create mode 100644 config/config_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 secretstore/store.go create mode 100644 update/update.go create mode 100644 update/update_test.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2c38c6f --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/bin/ +/build/ +.coverprofile diff --git a/README.md b/README.md new file mode 100644 index 0000000..6a0e7eb --- /dev/null +++ b/README.md @@ -0,0 +1,15 @@ +# mcp-framework + +Bibliotheque Go pour construire des binaires MCP avec : + +- resolution de profils CLI +- stockage JSON de configuration dans `os.UserConfigDir()` +- stockage de secrets dans le wallet natif selon l'OS +- pipeline d'auto-update via GitLab Releases + +Packages exposes : + +- `cli` +- `config` +- `secretstore` +- `update` diff --git a/cli/cli.go b/cli/cli.go new file mode 100644 index 0000000..bfae943 --- /dev/null +++ b/cli/cli.go @@ -0,0 +1,101 @@ +package cli + +import ( + "bufio" + "fmt" + "io" + "net/url" + "os" + "strings" + + "golang.org/x/term" +) + +type Candidate struct { + Value string + Source string +} + +func ResolveProfileName(flagProfile, envProfile, currentProfile string) string { + for _, candidate := range []string{flagProfile, envProfile, currentProfile} { + if value := strings.TrimSpace(candidate); value != "" { + return value + } + } + + return "default" +} + +func FirstNonEmpty(candidates ...Candidate) (string, string) { + for _, candidate := range candidates { + if value := strings.TrimSpace(candidate.Value); value != "" { + return value, candidate.Source + } + } + + return "", "" +} + +func ValidateBaseURL(raw string) error { + parsed, err := url.Parse(raw) + if err != nil { + return fmt.Errorf("parse URL: %w", err) + } + if parsed.Scheme == "" || parsed.Host == "" { + return fmt.Errorf("must include scheme and host") + } + return nil +} + +func PromptLine(reader *bufio.Reader, w io.Writer, label, defaultValue string) (string, error) { + if defaultValue != "" { + fmt.Fprintf(w, "%s [%s]: ", label, defaultValue) + } else { + fmt.Fprintf(w, "%s: ", label) + } + + line, err := reader.ReadString('\n') + if err != nil && err != io.EOF { + return "", err + } + + line = strings.TrimSpace(line) + if line == "" { + return defaultValue, nil + } + return line, nil +} + +func PromptSecret(stdin *os.File, w io.Writer, label string, hasStoredSecret bool, storedSecret string) (string, error) { + if hasStoredSecret { + fmt.Fprintf(w, "%s [stored, leave blank to keep]: ", label) + } else { + fmt.Fprintf(w, "%s: ", label) + } + + if term.IsTerminal(int(stdin.Fd())) { + secret, err := term.ReadPassword(int(stdin.Fd())) + fmt.Fprintln(w) + if err != nil { + return "", err + } + + value := strings.TrimSpace(string(secret)) + if value == "" && hasStoredSecret { + return storedSecret, nil + } + return value, nil + } + + reader := bufio.NewReader(stdin) + line, err := reader.ReadString('\n') + if err != nil && err != io.EOF { + return "", err + } + + value := strings.TrimSpace(line) + if value == "" && hasStoredSecret { + return storedSecret, nil + } + return value, nil +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..2603d5e --- /dev/null +++ b/config/config.go @@ -0,0 +1,161 @@ +package config + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" +) + +const ( + CurrentVersion = 1 + DefaultFile = "config.json" +) + +type FileConfig[T any] struct { + Version int `json:"version"` + CurrentProfile string `json:"current_profile"` + Profiles map[string]T `json:"profiles"` +} + +type Store[T any] struct { + dirName string + fileName string +} + +func NewStore[T any](dirName string) Store[T] { + return NewStoreWithFile[T](dirName, DefaultFile) +} + +func NewStoreWithFile[T any](dirName, fileName string) Store[T] { + return Store[T]{ + dirName: dirName, + fileName: fileName, + } +} + +func (s Store[T]) Default() FileConfig[T] { + return FileConfig[T]{ + Version: CurrentVersion, + Profiles: map[string]T{}, + } +} + +func (s Store[T]) ConfigPath() (string, error) { + userConfigDir, err := os.UserConfigDir() + if err != nil { + return "", fmt.Errorf("resolve user config dir: %w", err) + } + + return filepath.Join(userConfigDir, s.dirName, s.fileName), nil +} + +func (s Store[T]) Load(path string) (FileConfig[T], error) { + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return s.Default(), nil + } + return FileConfig[T]{}, fmt.Errorf("read config %s: %w", path, err) + } + + if len(data) == 0 { + return s.Default(), nil + } + + cfg := s.Default() + if err := json.Unmarshal(data, &cfg); err != nil { + return FileConfig[T]{}, fmt.Errorf("parse config %s: %w", path, err) + } + + s.normalize(&cfg) + return cfg, nil +} + +func (s Store[T]) LoadDefault() (FileConfig[T], string, error) { + path, err := s.ConfigPath() + if err != nil { + return FileConfig[T]{}, "", err + } + + cfg, err := s.Load(path) + if err != nil { + return FileConfig[T]{}, "", err + } + + return cfg, path, nil +} + +func (s Store[T]) Save(path string, cfg FileConfig[T]) error { + s.normalize(&cfg) + + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("create config dir %s: %w", dir, err) + } + if err := os.Chmod(dir, 0o700); err != nil { + return fmt.Errorf("set config dir permissions %s: %w", dir, err) + } + + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return fmt.Errorf("encode config: %w", err) + } + data = append(data, '\n') + + tmpFile, err := os.CreateTemp(dir, "config-*.json") + if err != nil { + return fmt.Errorf("create temp config in %s: %w", dir, err) + } + + tmpPath := tmpFile.Name() + cleanup := true + defer func() { + _ = tmpFile.Close() + if cleanup { + _ = os.Remove(tmpPath) + } + }() + + if err := tmpFile.Chmod(0o600); err != nil { + return fmt.Errorf("set temp config permissions %s: %w", tmpPath, err) + } + if _, err := tmpFile.Write(data); err != nil { + return fmt.Errorf("write temp config %s: %w", tmpPath, err) + } + if err := tmpFile.Close(); err != nil { + return fmt.Errorf("close temp config %s: %w", tmpPath, err) + } + if err := os.Rename(tmpPath, path); err != nil { + return fmt.Errorf("replace config %s: %w", path, err) + } + if err := os.Chmod(path, 0o600); err != nil { + return fmt.Errorf("set config permissions %s: %w", path, err) + } + + cleanup = false + return nil +} + +func (s Store[T]) SaveDefault(cfg FileConfig[T]) (string, error) { + path, err := s.ConfigPath() + if err != nil { + return "", err + } + + if err := s.Save(path, cfg); err != nil { + return "", err + } + + return path, nil +} + +func (s Store[T]) normalize(cfg *FileConfig[T]) { + if cfg.Version == 0 { + cfg.Version = CurrentVersion + } + if cfg.Profiles == nil { + cfg.Profiles = map[string]T{} + } +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 0000000..a2ba592 --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,78 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +type testProfile struct { + BaseURL string `json:"base_url"` + StreamID string `json:"stream_id"` +} + +func TestLoadMissingReturnsDefault(t *testing.T) { + store := NewStore[testProfile]("mcp-framework-test") + path := filepath.Join(t.TempDir(), "missing.json") + + cfg, err := store.Load(path) + if err != nil { + t.Fatalf("Load returned error: %v", err) + } + + if cfg.Version != CurrentVersion { + t.Fatalf("Version = %d, want %d", cfg.Version, CurrentVersion) + } + if len(cfg.Profiles) != 0 { + t.Fatalf("Profiles = %v, want empty", cfg.Profiles) + } +} + +func TestSaveAndLoadRoundTrip(t *testing.T) { + store := NewStore[testProfile]("mcp-framework-test") + dir := t.TempDir() + path := filepath.Join(dir, "mcp-framework", "config.json") + + input := FileConfig[testProfile]{ + Version: CurrentVersion, + CurrentProfile: "prod", + Profiles: map[string]testProfile{ + "prod": { + BaseURL: "https://graylog.example.com", + StreamID: "stream-1", + }, + }, + } + + if err := store.Save(path, input); err != nil { + t.Fatalf("Save returned error: %v", err) + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat returned error: %v", err) + } + if info.Mode().Perm() != 0o600 { + t.Fatalf("file mode = %o, want 600", info.Mode().Perm()) + } + + dirInfo, err := os.Stat(filepath.Dir(path)) + if err != nil { + t.Fatalf("Stat dir returned error: %v", err) + } + if dirInfo.Mode().Perm() != 0o700 { + t.Fatalf("dir mode = %o, want 700", dirInfo.Mode().Perm()) + } + + cfg, err := store.Load(path) + if err != nil { + t.Fatalf("Load returned error: %v", err) + } + + if cfg.CurrentProfile != "prod" { + t.Fatalf("CurrentProfile = %q, want prod", cfg.CurrentProfile) + } + if cfg.Profiles["prod"].BaseURL != "https://graylog.example.com" { + t.Fatalf("BaseURL = %q", cfg.Profiles["prod"].BaseURL) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d6a0f56 --- /dev/null +++ b/go.mod @@ -0,0 +1,18 @@ +module gitlab.lundimatin.app/artificial-intelligence-ia/claude/mcp-framework + +go 1.25.0 + +require ( + github.com/99designs/keyring v1.2.2 + golang.org/x/term v0.40.0 +) + +require ( + github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect + github.com/danieljoos/wincred v1.1.2 // indirect + github.com/dvsekhvalnov/jose2go v1.5.0 // indirect + github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect + github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect + github.com/mtibben/percent v0.2.1 // indirect + golang.org/x/sys v0.41.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..bdfc1a5 --- /dev/null +++ b/go.sum @@ -0,0 +1,40 @@ +github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 h1:/vQbFIOMbk2FiG/kXiLl8BRyzTWDw7gX/Hz7Dd5eDMs= +github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4/go.mod h1:hN7oaIRCjzsZ2dE+yG5k+rsdt3qcwykqK6HVGcKwsw4= +github.com/99designs/keyring v1.2.2 h1:pZd3neh/EmUzWONb35LxQfvuY7kiSXAq3HQd97+XBn0= +github.com/99designs/keyring v1.2.2/go.mod h1:wes/FrByc8j7lFOAGLGSNEg8f/PaI3cgTBqhFkHUrPk= +github.com/danieljoos/wincred v1.1.2 h1:QLdCxFs1/Yl4zduvBdcHB8goaYk9RARS2SgLLRuAyr0= +github.com/danieljoos/wincred v1.1.2/go.mod h1:GijpziifJoIBfYh+S7BbkdUTU4LfM+QnGqR5Vl2tAx0= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dvsekhvalnov/jose2go v1.5.0 h1:3j8ya4Z4kMCwT5nXIKFSV84YS+HdqSSO0VsTQxaLAeM= +github.com/dvsekhvalnov/jose2go v1.5.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= +github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 h1:ZpnhV/YsD2/4cESfV5+Hoeu/iUR3ruzNvZ+yQfO03a0= +github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= +github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c h1:6rhixN/i8ZofjG1Y75iExal34USq5p+wiN1tpie8IrU= +github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c/go.mod h1:NMPJylDgVpX0MLRlPy15sqSwOFv/U1GZ2m21JhFfek0= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mtibben/percent v0.2.1 h1:5gssi8Nqo8QU/r2pynCm+hBQHpkB/uNK7BJCFogWdzs= +github.com/mtibben/percent v0.2.1/go.mod h1:KG9uO+SZkUp+VkRHsCdYQV3XSZrrSpR3O9ibNBTZrns= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.3.0 h1:NGXK3lHquSN08v5vWalVI/L8XU9hdzE/G6xsrze47As= +github.com/stretchr/objx v0.3.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b h1:QRR6H1YWRnHb4Y/HeNFCTJLFVxaq6wH4YuVdsUOr75U= +gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/secretstore/store.go b/secretstore/store.go new file mode 100644 index 0000000..09c45a4 --- /dev/null +++ b/secretstore/store.go @@ -0,0 +1,90 @@ +package secretstore + +import ( + "errors" + "fmt" + "runtime" + "strings" + + "github.com/99designs/keyring" +) + +var ErrNotFound = errors.New("secret not found") + +type Options struct { + ServiceName string +} + +type Store interface { + SetSecret(name, label, secret string) error + GetSecret(name string) (string, error) + DeleteSecret(name string) error +} + +type keyringStore struct { + ring keyring.Keyring + serviceName string +} + +func Open(options Options) (Store, error) { + serviceName := strings.TrimSpace(options.ServiceName) + if serviceName == "" { + return nil, errors.New("service name must not be empty") + } + + ring, err := keyring.Open(keyring.Config{ + ServiceName: serviceName, + }) + if err != nil { + return nil, fmt.Errorf("open OS wallet backend %q for service %q: %w", BackendName(), serviceName, err) + } + + return &keyringStore{ + ring: ring, + serviceName: serviceName, + }, nil +} + +func BackendName() string { + switch runtime.GOOS { + case "darwin": + return "macOS Keychain" + case "windows": + return "Windows Credential Manager" + case "linux": + return "Linux Secret Service or KWallet" + default: + return "system wallet" + } +} + +func (s *keyringStore) SetSecret(name, label, secret string) error { + if err := s.ring.Set(keyring.Item{ + Key: name, + Label: label, + Data: []byte(secret), + }); err != nil { + return fmt.Errorf("save secret %q in OS wallet for service %q: %w", name, s.serviceName, err) + } + + return nil +} + +func (s *keyringStore) GetSecret(name string) (string, error) { + item, err := s.ring.Get(name) + if err != nil { + if errors.Is(err, keyring.ErrKeyNotFound) { + return "", ErrNotFound + } + return "", fmt.Errorf("read secret %q from OS wallet for service %q: %w", name, s.serviceName, err) + } + + return string(item.Data), nil +} + +func (s *keyringStore) DeleteSecret(name string) error { + if err := s.ring.Remove(name); err != nil && !errors.Is(err, keyring.ErrKeyNotFound) { + return fmt.Errorf("delete secret %q from OS wallet for service %q: %w", name, s.serviceName, err) + } + return nil +} diff --git a/update/update.go b/update/update.go new file mode 100644 index 0000000..3fcdf59 --- /dev/null +++ b/update/update.go @@ -0,0 +1,371 @@ +package update + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "runtime" + "strings" + "time" +) + +type Options struct { + Client *http.Client + CurrentVersion string + ExecutablePath string + LatestReleaseURL string + Stdout io.Writer + BinaryName string + ReleaseSource GitLabSource + GOOS string + GOARCH string +} + +type GitLabSource struct { + BaseURL string + ProjectPath string + Token string + TokenHeader string + TokenEnvNames []string +} + +type Auth struct { + Header string + Token string +} + +type Release struct { + TagName string `json:"tag_name"` + Assets struct { + Links []ReleaseLink `json:"links"` + } `json:"assets"` +} + +type ReleaseLink struct { + Name string `json:"name"` + URL string `json:"url"` +} + +func Run(ctx context.Context, opts Options) error { + if opts.Stdout == nil { + opts.Stdout = io.Discard + } + if opts.Client == nil { + opts.Client = &http.Client{Timeout: 60 * time.Second} + } + if strings.TrimSpace(opts.CurrentVersion) == "" { + opts.CurrentVersion = "dev" + } + if strings.TrimSpace(opts.GOOS) == "" { + opts.GOOS = runtime.GOOS + } + if strings.TrimSpace(opts.GOARCH) == "" { + opts.GOARCH = runtime.GOARCH + } + + source := normalizeSource(opts.ReleaseSource) + auth := ResolveGitLabAuth(source.Token, source) + + targetPath, err := ResolveUpdateTarget(opts.ExecutablePath) + if err != nil { + return err + } + + assetName, err := AssetName(opts.BinaryName, opts.GOOS, opts.GOARCH) + if err != nil { + return err + } + + releaseURL := opts.LatestReleaseURL + if strings.TrimSpace(releaseURL) == "" { + releaseURL = LatestReleaseAPIURL(source) + } + + release, err := FetchLatestRelease(ctx, opts.Client, releaseURL, auth, source) + if err != nil { + return err + } + if isCurrentRelease(opts.CurrentVersion, release.TagName) { + fmt.Fprintf(opts.Stdout, "Already up to date (%s)\n", release.TagName) + return nil + } + + assetURL, err := release.AssetURL(assetName, releaseURL) + if err != nil { + return err + } + + downloadPath, err := DownloadReleaseAsset(ctx, opts.Client, assetURL, targetPath, auth, source) + if err != nil { + return err + } + defer os.Remove(downloadPath) + + if err := ReplaceExecutable(downloadPath, targetPath); err != nil { + return err + } + + fmt.Fprintf(opts.Stdout, "Updated %s to %s\n", targetPath, release.TagName) + return nil +} + +func ResolveGitLabAuth(explicitToken string, source GitLabSource) Auth { + source = normalizeSource(source) + + if token := strings.TrimSpace(explicitToken); token != "" { + return Auth{Header: source.TokenHeader, Token: token} + } + + for _, envName := range source.TokenEnvNames { + if token := strings.TrimSpace(os.Getenv(envName)); token != "" { + return Auth{Header: source.TokenHeader, Token: token} + } + } + + return Auth{} +} + +func LatestReleaseAPIURL(source GitLabSource) string { + source = normalizeSource(source) + return fmt.Sprintf( + "%s/api/v4/projects/%s/releases/permalink/latest", + source.BaseURL, + url.PathEscape(source.ProjectPath), + ) +} + +func ResolveUpdateTarget(explicitPath string) (string, error) { + targetPath := strings.TrimSpace(explicitPath) + if targetPath == "" { + var err error + targetPath, err = os.Executable() + if err != nil { + return "", fmt.Errorf("resolve executable path: %w", err) + } + } + + resolvedPath, err := filepath.EvalSymlinks(targetPath) + if err != nil { + return "", fmt.Errorf("resolve executable symlink %q: %w", targetPath, err) + } + return resolvedPath, nil +} + +func AssetName(binaryName, goos, goarch string) (string, error) { + name := strings.TrimSpace(binaryName) + if name == "" { + return "", errors.New("binary name must not be empty") + } + + switch { + case goos == "darwin" && goarch == "amd64": + return name + "-darwin-amd64", nil + case goos == "darwin" && goarch == "arm64": + return name + "-darwin-arm64", nil + case goos == "linux" && goarch == "amd64": + return name + "-linux-amd64", nil + case goos == "windows" && goarch == "amd64": + return name + "-windows-amd64.exe", nil + default: + return "", fmt.Errorf("no release artifact for %s/%s", goos, goarch) + } +} + +func FetchLatestRelease(ctx context.Context, client *http.Client, releaseURL string, auth Auth, source GitLabSource) (Release, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, releaseURL, nil) + if err != nil { + return Release{}, fmt.Errorf("build latest release request: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "mcp updater") + auth.apply(req) + + resp, err := client.Do(req) + if err != nil { + return Release{}, fmt.Errorf("fetch latest release metadata: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + if err := auth.maybeHint(resp.StatusCode, body, source); err != nil { + return Release{}, fmt.Errorf("fetch latest release metadata: %w", err) + } + return Release{}, fmt.Errorf( + "fetch latest release metadata: unexpected status %d: %s", + resp.StatusCode, + strings.TrimSpace(string(body)), + ) + } + + var release Release + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return Release{}, fmt.Errorf("decode latest release metadata: %w", err) + } + if strings.TrimSpace(release.TagName) == "" { + return Release{}, errors.New("latest release metadata is missing tag_name") + } + return release, nil +} + +func (r Release) AssetURL(assetName, releaseURL string) (string, error) { + for _, link := range r.Assets.Links { + if link.Name == assetName { + if strings.TrimSpace(link.URL) == "" { + return "", fmt.Errorf("release asset %q has no URL", assetName) + } + + parsed, err := url.Parse(link.URL) + if err != nil { + return "", fmt.Errorf("parse release asset URL %q: %w", link.URL, err) + } + if parsed.IsAbs() { + return parsed.String(), nil + } + + baseURL, err := url.Parse(releaseURL) + if err != nil { + return "", fmt.Errorf("parse latest release URL %q: %w", releaseURL, err) + } + return baseURL.ResolveReference(parsed).String(), nil + } + } + + return "", fmt.Errorf("latest release does not contain asset %q", assetName) +} + +func DownloadReleaseAsset(ctx context.Context, client *http.Client, assetURL, targetPath string, auth Auth, source GitLabSource) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, assetURL, nil) + if err != nil { + return "", fmt.Errorf("build artifact download request: %w", err) + } + req.Header.Set("User-Agent", "mcp updater") + auth.apply(req) + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("download release artifact: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + if err := auth.maybeHint(resp.StatusCode, body, source); err != nil { + return "", fmt.Errorf("download release artifact: %w", err) + } + return "", fmt.Errorf( + "download release artifact: unexpected status %d: %s", + resp.StatusCode, + strings.TrimSpace(string(body)), + ) + } + + existingInfo, err := os.Stat(targetPath) + if err != nil { + return "", fmt.Errorf("stat executable %q: %w", targetPath, err) + } + + tempFile, err := os.CreateTemp(filepath.Dir(targetPath), filepath.Base(targetPath)+".download-*") + if err != nil { + return "", fmt.Errorf("create temporary file: %w", err) + } + + tempPath := tempFile.Name() + cleanup := func(copyErr error) (string, error) { + tempFile.Close() + os.Remove(tempPath) + return "", copyErr + } + + if _, err := io.Copy(tempFile, resp.Body); err != nil { + return cleanup(fmt.Errorf("write downloaded artifact: %w", err)) + } + if err := tempFile.Chmod(existingInfo.Mode().Perm()); err != nil { + return cleanup(fmt.Errorf("set executable mode on downloaded artifact: %w", err)) + } + if err := tempFile.Close(); err != nil { + os.Remove(tempPath) + return "", fmt.Errorf("close downloaded artifact: %w", err) + } + + return tempPath, nil +} + +func ReplaceExecutable(downloadPath, targetPath string) error { + if runtime.GOOS == "windows" { + return errors.New("self-update is not supported on windows") + } + if err := os.Rename(downloadPath, targetPath); err != nil { + return fmt.Errorf("replace executable %q: %w", targetPath, err) + } + return nil +} + +func normalizeSource(source GitLabSource) GitLabSource { + if source.TokenHeader == "" { + source.TokenHeader = "PRIVATE-TOKEN" + } + source.BaseURL = strings.TrimRight(strings.TrimSpace(source.BaseURL), "/") + source.ProjectPath = strings.TrimSpace(source.ProjectPath) + return source +} + +func isCurrentRelease(currentVersion, latestTag string) bool { + current := strings.TrimSpace(currentVersion) + latest := strings.TrimSpace(latestTag) + if latest == "" { + return false + } + if current == "" || current == "dev" { + return false + } + return current == latest +} + +func (a Auth) apply(req *http.Request) { + if strings.TrimSpace(a.Header) == "" || strings.TrimSpace(a.Token) == "" { + return + } + req.Header.Set(a.Header, a.Token) +} + +func (a Auth) maybeHint(statusCode int, body []byte, source GitLabSource) error { + source = normalizeSource(source) + if strings.TrimSpace(a.Token) != "" || len(source.TokenEnvNames) == 0 { + return nil + } + + switch statusCode { + case http.StatusUnauthorized, http.StatusForbidden, http.StatusNotFound: + default: + return nil + } + + message := strings.ToLower(strings.TrimSpace(string(body))) + if !strings.Contains(message, "project not found") && + !strings.Contains(message, "unauthorized") && + !strings.Contains(message, "forbidden") { + return nil + } + + if len(source.TokenEnvNames) == 1 { + return fmt.Errorf( + "GitLab release access requires authentication on %s; set %s and retry", + source.BaseURL, + source.TokenEnvNames[0], + ) + } + + return fmt.Errorf( + "GitLab release access requires authentication on %s; set %s (or %s) and retry", + source.BaseURL, + source.TokenEnvNames[0], + source.TokenEnvNames[1], + ) +} diff --git a/update/update_test.go b/update/update_test.go new file mode 100644 index 0000000..b49904b --- /dev/null +++ b/update/update_test.go @@ -0,0 +1,363 @@ +package update + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func TestAssetName(t *testing.T) { + tests := []struct { + name string + goos string + goarch string + want string + wantErr string + }{ + {name: "darwin amd64", goos: "darwin", goarch: "amd64", want: "graylog-mcp-darwin-amd64"}, + {name: "darwin arm64", goos: "darwin", goarch: "arm64", want: "graylog-mcp-darwin-arm64"}, + {name: "linux amd64", goos: "linux", goarch: "amd64", want: "graylog-mcp-linux-amd64"}, + {name: "windows amd64", goos: "windows", goarch: "amd64", want: "graylog-mcp-windows-amd64.exe"}, + {name: "unsupported", goos: "linux", goarch: "arm64", wantErr: "no release artifact"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := AssetName("graylog-mcp", tt.goos, tt.goarch) + if tt.wantErr != "" { + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("error = %v, want substring %q", err, tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.want { + t.Fatalf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestResolveUpdateTargetFollowsSymlink(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink behavior differs on windows") + } + + tempDir := t.TempDir() + target := filepath.Join(tempDir, "graylog-mcp") + link := filepath.Join(tempDir, "graylog-mcp-link") + + if err := os.WriteFile(target, []byte("old"), 0o755); err != nil { + t.Fatalf("WriteFile target: %v", err) + } + if err := os.Symlink(target, link); err != nil { + t.Fatalf("Symlink: %v", err) + } + + resolved, err := ResolveUpdateTarget(link) + if err != nil { + t.Fatalf("ResolveUpdateTarget: %v", err) + } + if resolved != target { + t.Fatalf("resolved = %q, want %q", resolved, target) + } +} + +func TestReleaseAssetURLResolvesRelativeLinks(t *testing.T) { + release := Release{} + release.Assets.Links = []ReleaseLink{ + {Name: "graylog-mcp-linux-amd64", URL: "/downloads/graylog-mcp-linux-amd64"}, + } + + got, err := release.AssetURL("graylog-mcp-linux-amd64", "https://gitlab.example.com/api/v4/projects/1/releases/permalink/latest") + if err != nil { + t.Fatalf("AssetURL: %v", err) + } + if got != "https://gitlab.example.com/downloads/graylog-mcp-linux-amd64" { + t.Fatalf("got %q", got) + } +} + +func TestResolveGitLabAuthPrefersExplicitToken(t *testing.T) { + t.Setenv("GITLAB_TOKEN", "env-token") + + auth := ResolveGitLabAuth("explicit-token", GitLabSource{ + BaseURL: "https://gitlab.example.com", + ProjectPath: "group/project", + TokenEnvNames: []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"}, + }) + if auth.Header != "PRIVATE-TOKEN" { + t.Fatalf("header = %q, want PRIVATE-TOKEN", auth.Header) + } + if auth.Token != "explicit-token" { + t.Fatalf("token = %q, want explicit token", auth.Token) + } +} + +func TestResolveGitLabAuthReadsEnvironment(t *testing.T) { + t.Setenv("GITLAB_PRIVATE_TOKEN", "env-token") + + auth := ResolveGitLabAuth("", GitLabSource{ + BaseURL: "https://gitlab.example.com", + ProjectPath: "group/project", + TokenEnvNames: []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"}, + }) + if auth.Header != "PRIVATE-TOKEN" { + t.Fatalf("header = %q, want PRIVATE-TOKEN", auth.Header) + } + if auth.Token != "env-token" { + t.Fatalf("token = %q, want env token", auth.Token) + } +} + +func TestFetchLatestReleaseAddsGitLabAuthHeader(t *testing.T) { + client := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if got := r.Header.Get("PRIVATE-TOKEN"); got != "secret-token" { + t.Fatalf("PRIVATE-TOKEN = %q, want secret-token", got) + } + + payload, err := json.Marshal(Release{TagName: "v1.2.3"}) + if err != nil { + t.Fatalf("Marshal release: %v", err) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(payload)), + }, nil + }), + } + + release, err := FetchLatestRelease( + context.Background(), + client, + "https://gitlab.example.com/latest", + Auth{Header: "PRIVATE-TOKEN", Token: "secret-token"}, + GitLabSource{BaseURL: "https://gitlab.example.com", ProjectPath: "group/project"}, + ) + if err != nil { + t.Fatalf("FetchLatestRelease: %v", err) + } + if release.TagName != "v1.2.3" { + t.Fatalf("tag = %q, want v1.2.3", release.TagName) + } +} + +func TestFetchLatestReleaseHintsWhenGitLabAuthIsMissing(t *testing.T) { + client := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusNotFound, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(`{"message":"404 Project Not Found"}`)), + }, nil + }), + } + + _, err := FetchLatestRelease( + context.Background(), + client, + "https://gitlab.example.com/latest", + Auth{}, + GitLabSource{ + BaseURL: "https://gitlab.example.com", + ProjectPath: "group/project", + TokenEnvNames: []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"}, + }, + ) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "GITLAB_TOKEN") { + t.Fatalf("error = %v, want token hint", err) + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func TestRunReplacesExecutableWithLatestArtifact(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("self-replace is not supported on windows") + } + + assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH) + if err != nil { + t.Skipf("unsupported test platform: %v", err) + } + + const newBinary = "new-binary" + client := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + switch r.URL.String() { + case "https://gitlab.example.com/latest": + release := Release{TagName: "v1.2.3"} + release.Assets.Links = []ReleaseLink{ + {Name: assetName, URL: "https://gitlab.example.com/artifact"}, + } + + payload, err := json.Marshal(release) + if err != nil { + t.Fatalf("Marshal release: %v", err) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(payload)), + }, nil + case "https://gitlab.example.com/artifact": + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(newBinary)), + }, nil + default: + return &http.Response{ + StatusCode: http.StatusNotFound, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("not found")), + }, nil + } + }), + } + + tempDir := t.TempDir() + target := filepath.Join(tempDir, "graylog-mcp") + link := filepath.Join(tempDir, "graylog-mcp-link") + + if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil { + t.Fatalf("WriteFile target: %v", err) + } + if err := os.Symlink(target, link); err != nil { + t.Fatalf("Symlink: %v", err) + } + + var stdout strings.Builder + err = Run(context.Background(), Options{ + Client: client, + CurrentVersion: "v1.2.2", + ExecutablePath: link, + LatestReleaseURL: "https://gitlab.example.com/latest", + Stdout: &stdout, + BinaryName: "graylog-mcp", + ReleaseSource: GitLabSource{ + BaseURL: "https://gitlab.example.com", + ProjectPath: "group/project", + TokenEnvNames: []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"}, + }, + }) + if err != nil { + t.Fatalf("Run: %v", err) + } + + got, err := os.ReadFile(target) + if err != nil { + t.Fatalf("ReadFile target: %v", err) + } + if string(got) != newBinary { + t.Fatalf("target content = %q, want %q", string(got), newBinary) + } + if info, err := os.Lstat(link); err != nil { + t.Fatalf("Lstat link: %v", err) + } else if info.Mode()&os.ModeSymlink == 0 { + t.Fatalf("link %q is no longer a symlink", link) + } + if !strings.Contains(stdout.String(), "v1.2.3") { + t.Fatalf("stdout = %q, want release tag", stdout.String()) + } +} + +func TestRunSkipsWhenAlreadyOnLatestRelease(t *testing.T) { + assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH) + if err != nil { + t.Skipf("unsupported test platform: %v", err) + } + + downloaded := false + client := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + switch r.URL.String() { + case "https://gitlab.example.com/latest": + release := Release{TagName: "v1.2.3"} + release.Assets.Links = []ReleaseLink{ + {Name: assetName, URL: "https://gitlab.example.com/artifact"}, + } + + payload, err := json.Marshal(release) + if err != nil { + t.Fatalf("Marshal release: %v", err) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(payload)), + }, nil + case "https://gitlab.example.com/artifact": + downloaded = true + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("should-not-download")), + }, nil + default: + return &http.Response{ + StatusCode: http.StatusNotFound, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("not found")), + }, nil + } + }), + } + + tempDir := t.TempDir() + target := filepath.Join(tempDir, "graylog-mcp") + if err := os.WriteFile(target, []byte("current-binary"), 0o755); err != nil { + t.Fatalf("WriteFile target: %v", err) + } + + var stdout strings.Builder + err = Run(context.Background(), Options{ + Client: client, + CurrentVersion: "v1.2.3", + ExecutablePath: target, + LatestReleaseURL: "https://gitlab.example.com/latest", + Stdout: &stdout, + BinaryName: "graylog-mcp", + ReleaseSource: GitLabSource{ + BaseURL: "https://gitlab.example.com", + ProjectPath: "group/project", + TokenEnvNames: []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"}, + }, + }) + if err != nil { + t.Fatalf("Run: %v", err) + } + + if downloaded { + t.Fatal("artifact should not have been downloaded") + } + + got, err := os.ReadFile(target) + if err != nil { + t.Fatalf("ReadFile target: %v", err) + } + if string(got) != "current-binary" { + t.Fatalf("target content = %q, want unchanged binary", string(got)) + } + if !strings.Contains(stdout.String(), "Already up to date") { + t.Fatalf("stdout = %q, want up-to-date message", stdout.String()) + } +}