mcp-framework/update/update_test.go
2026-04-13 15:33:48 +02:00

363 lines
10 KiB
Go

package update
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
// TestAssetName checks the release artifact names AssetName produces for the
// "graylog-mcp" binary on each supported GOOS/GOARCH pair (including the
// ".exe" suffix on windows), and that an unsupported pair yields an error
// mentioning "no release artifact".
func TestAssetName(t *testing.T) {
	cases := []struct {
		label, goos, goarch string
		expected, errSubstr string
	}{
		{"darwin amd64", "darwin", "amd64", "graylog-mcp-darwin-amd64", ""},
		{"darwin arm64", "darwin", "arm64", "graylog-mcp-darwin-arm64", ""},
		{"linux amd64", "linux", "amd64", "graylog-mcp-linux-amd64", ""},
		{"windows amd64", "windows", "amd64", "graylog-mcp-windows-amd64.exe", ""},
		{"unsupported", "linux", "arm64", "", "no release artifact"},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got, err := AssetName("graylog-mcp", tc.goos, tc.goarch)
			switch {
			case tc.errSubstr != "":
				// Error case: only the error matters; the name is ignored.
				if err == nil || !strings.Contains(err.Error(), tc.errSubstr) {
					t.Fatalf("error = %v, want substring %q", err, tc.errSubstr)
				}
			case err != nil:
				t.Fatalf("unexpected error: %v", err)
			case got != tc.expected:
				t.Fatalf("got %q, want %q", got, tc.expected)
			}
		})
	}
}
// TestResolveUpdateTargetFollowsSymlink checks that ResolveUpdateTarget, given
// a symlink to the executable, returns the path of the file the link points to
// rather than the link itself.
func TestResolveUpdateTargetFollowsSymlink(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("symlink behavior differs on windows")
	}
	tempDir := t.TempDir()
	target := filepath.Join(tempDir, "graylog-mcp")
	link := filepath.Join(tempDir, "graylog-mcp-link")
	if err := os.WriteFile(target, []byte("old"), 0o755); err != nil {
		t.Fatalf("WriteFile target: %v", err)
	}
	if err := os.Symlink(target, link); err != nil {
		t.Fatalf("Symlink: %v", err)
	}
	resolved, err := ResolveUpdateTarget(link)
	if err != nil {
		t.Fatalf("ResolveUpdateTarget: %v", err)
	}
	// The temp dir can itself sit below a symlinked directory. On macOS it lives
	// under /var, which is a symlink to /private/var. An implementation built on
	// filepath.EvalSymlinks then returns the fully canonical path, so accept
	// either the literal target or its canonical form. Both are the real file,
	// not the link.
	canonical, err := filepath.EvalSymlinks(target)
	if err != nil {
		t.Fatalf("EvalSymlinks target: %v", err)
	}
	if resolved != target && resolved != canonical {
		t.Fatalf("resolved = %q, want %q (or canonical %q)", resolved, target, canonical)
	}
}
// TestReleaseAssetURLResolvesRelativeLinks checks that a release link whose URL
// is host-relative ("/downloads/...") is resolved against the scheme and host
// of the release API URL handed to AssetURL.
func TestReleaseAssetURLResolvesRelativeLinks(t *testing.T) {
	const (
		asset   = "graylog-mcp-linux-amd64"
		apiURL  = "https://gitlab.example.com/api/v4/projects/1/releases/permalink/latest"
		wantURL = "https://gitlab.example.com/downloads/graylog-mcp-linux-amd64"
	)
	var release Release
	relativeLink := ReleaseLink{Name: asset, URL: "/downloads/" + asset}
	release.Assets.Links = []ReleaseLink{relativeLink}

	got, err := release.AssetURL(asset, apiURL)
	if err != nil {
		t.Fatalf("AssetURL: %v", err)
	}
	if got != wantURL {
		t.Fatalf("got %q", got)
	}
}
func TestResolveGitLabAuthPrefersExplicitToken(t *testing.T) {
t.Setenv("GITLAB_TOKEN", "env-token")
auth := ResolveGitLabAuth("explicit-token", GitLabSource{
BaseURL: "https://gitlab.example.com",
ProjectPath: "group/project",
TokenEnvNames: []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"},
})
if auth.Header != "PRIVATE-TOKEN" {
t.Fatalf("header = %q, want PRIVATE-TOKEN", auth.Header)
}
if auth.Token != "explicit-token" {
t.Fatalf("token = %q, want explicit token", auth.Token)
}
}
// TestResolveGitLabAuthReadsEnvironment checks that, with no explicit token,
// ResolveGitLabAuth takes the token from an environment variable listed in
// TokenEnvNames (here GITLAB_PRIVATE_TOKEN) and pairs it with the
// PRIVATE-TOKEN header.
func TestResolveGitLabAuthReadsEnvironment(t *testing.T) {
	// GITLAB_TOKEN is listed before GITLAB_PRIVATE_TOKEN. If ResolveGitLabAuth
	// consults the names in order (as the list suggests), a GITLAB_TOKEN
	// exported in the developer's or CI's shell would win. The test would then
	// depend on the host environment. t.Setenv records the original value and
	// restores it at cleanup; Unsetenv then removes the variable for the
	// duration of this test.
	t.Setenv("GITLAB_TOKEN", "")
	if err := os.Unsetenv("GITLAB_TOKEN"); err != nil {
		t.Fatalf("Unsetenv GITLAB_TOKEN: %v", err)
	}
	t.Setenv("GITLAB_PRIVATE_TOKEN", "env-token")
	auth := ResolveGitLabAuth("", GitLabSource{
		BaseURL:       "https://gitlab.example.com",
		ProjectPath:   "group/project",
		TokenEnvNames: []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"},
	})
	if auth.Header != "PRIVATE-TOKEN" {
		t.Fatalf("header = %q, want PRIVATE-TOKEN", auth.Header)
	}
	if auth.Token != "env-token" {
		t.Fatalf("token = %q, want env token", auth.Token)
	}
}
// TestFetchLatestReleaseAddsGitLabAuthHeader checks that the Auth handed to
// FetchLatestRelease is sent as a PRIVATE-TOKEN request header, and that the
// JSON release body is decoded into the returned Release.
func TestFetchLatestReleaseAddsGitLabAuthHeader(t *testing.T) {
	const token = "secret-token"

	fake := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
		if header := req.Header.Get("PRIVATE-TOKEN"); header != token {
			t.Fatalf("PRIVATE-TOKEN = %q, want secret-token", header)
		}
		body, err := json.Marshal(Release{TagName: "v1.2.3"})
		if err != nil {
			t.Fatalf("Marshal release: %v", err)
		}
		resp := &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{},
			Body:       io.NopCloser(bytes.NewReader(body)),
		}
		return resp, nil
	})
	auth := Auth{Header: "PRIVATE-TOKEN", Token: token}
	source := GitLabSource{BaseURL: "https://gitlab.example.com", ProjectPath: "group/project"}

	release, err := FetchLatestRelease(
		context.Background(),
		&http.Client{Transport: fake},
		"https://gitlab.example.com/latest",
		auth,
		source,
	)
	if err != nil {
		t.Fatalf("FetchLatestRelease: %v", err)
	}
	if release.TagName != "v1.2.3" {
		t.Fatalf("tag = %q, want v1.2.3", release.TagName)
	}
}
// TestFetchLatestReleaseHintsWhenGitLabAuthIsMissing checks that when the
// release endpoint answers 404 and no Auth is supplied, FetchLatestRelease
// fails with an error mentioning GITLAB_TOKEN. The mention hints which token
// variable to set.
func TestFetchLatestReleaseHintsWhenGitLabAuthIsMissing(t *testing.T) {
	notFound := roundTripperFunc(func(*http.Request) (*http.Response, error) {
		resp := &http.Response{
			StatusCode: http.StatusNotFound,
			Header:     http.Header{},
			Body:       io.NopCloser(strings.NewReader(`{"message":"404 Project Not Found"}`)),
		}
		return resp, nil
	})
	source := GitLabSource{
		BaseURL:       "https://gitlab.example.com",
		ProjectPath:   "group/project",
		TokenEnvNames: []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"},
	}

	_, err := FetchLatestRelease(
		context.Background(),
		&http.Client{Transport: notFound},
		"https://gitlab.example.com/latest",
		Auth{},
		source,
	)
	switch {
	case err == nil:
		t.Fatal("expected error")
	case !strings.Contains(err.Error(), "GITLAB_TOKEN"):
		t.Fatalf("error = %v, want token hint", err)
	}
}
// roundTripperFunc adapts a plain function to the http.RoundTripper interface,
// letting tests script HTTP responses in-process without a network listener.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time check that roundTripperFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripperFunc(nil)

// RoundTrip implements http.RoundTripper by delegating to the wrapped function.
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req)
}
// TestRunReplacesExecutableWithLatestArtifact runs an update through a
// symlinked executable path. It checks that the file the symlink points to
// receives the downloaded artifact, that the symlink itself is left in place,
// and that the new release tag is written to Stdout.
func TestRunReplacesExecutableWithLatestArtifact(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("self-replace is not supported on windows")
	}
	assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
	if err != nil {
		t.Skipf("unsupported test platform: %v", err)
	}
	const newBinary = "new-binary"
	client := newFakeReleaseClient(t, assetName, newBinary, nil)
	tempDir := t.TempDir()
	target := filepath.Join(tempDir, "graylog-mcp")
	link := filepath.Join(tempDir, "graylog-mcp-link")
	if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil {
		t.Fatalf("WriteFile target: %v", err)
	}
	if err := os.Symlink(target, link); err != nil {
		t.Fatalf("Symlink: %v", err)
	}
	var stdout strings.Builder
	err = Run(context.Background(), Options{
		Client:           client,
		CurrentVersion:   "v1.2.2",
		ExecutablePath:   link,
		LatestReleaseURL: "https://gitlab.example.com/latest",
		Stdout:           &stdout,
		BinaryName:       "graylog-mcp",
		ReleaseSource: GitLabSource{
			BaseURL:       "https://gitlab.example.com",
			ProjectPath:   "group/project",
			TokenEnvNames: []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"},
		},
	})
	if err != nil {
		t.Fatalf("Run: %v", err)
	}
	got, err := os.ReadFile(target)
	if err != nil {
		t.Fatalf("ReadFile target: %v", err)
	}
	if string(got) != newBinary {
		t.Fatalf("target content = %q, want %q", string(got), newBinary)
	}
	if info, err := os.Lstat(link); err != nil {
		t.Fatalf("Lstat link: %v", err)
	} else if info.Mode()&os.ModeSymlink == 0 {
		t.Fatalf("link %q is no longer a symlink", link)
	}
	if !strings.Contains(stdout.String(), "v1.2.3") {
		t.Fatalf("stdout = %q, want release tag", stdout.String())
	}
}

// TestRunSkipsWhenAlreadyOnLatestRelease checks that when CurrentVersion
// already equals the latest release tag, Run does not request the artifact. It
// also checks that Run leaves the executable untouched and reports
// "Already up to date" on Stdout.
func TestRunSkipsWhenAlreadyOnLatestRelease(t *testing.T) {
	assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
	if err != nil {
		t.Skipf("unsupported test platform: %v", err)
	}
	downloaded := false
	client := newFakeReleaseClient(t, assetName, "should-not-download", func() { downloaded = true })
	tempDir := t.TempDir()
	target := filepath.Join(tempDir, "graylog-mcp")
	if err := os.WriteFile(target, []byte("current-binary"), 0o755); err != nil {
		t.Fatalf("WriteFile target: %v", err)
	}
	var stdout strings.Builder
	err = Run(context.Background(), Options{
		Client:           client,
		CurrentVersion:   "v1.2.3",
		ExecutablePath:   target,
		LatestReleaseURL: "https://gitlab.example.com/latest",
		Stdout:           &stdout,
		BinaryName:       "graylog-mcp",
		ReleaseSource: GitLabSource{
			BaseURL:       "https://gitlab.example.com",
			ProjectPath:   "group/project",
			TokenEnvNames: []string{"GITLAB_TOKEN", "GITLAB_PRIVATE_TOKEN"},
		},
	})
	if err != nil {
		t.Fatalf("Run: %v", err)
	}
	if downloaded {
		t.Fatal("artifact should not have been downloaded")
	}
	got, err := os.ReadFile(target)
	if err != nil {
		t.Fatalf("ReadFile target: %v", err)
	}
	if string(got) != "current-binary" {
		t.Fatalf("target content = %q, want unchanged binary", string(got))
	}
	if !strings.Contains(stdout.String(), "Already up to date") {
		t.Fatalf("stdout = %q, want up-to-date message", stdout.String())
	}
}

// newFakeReleaseClient returns an *http.Client whose in-memory transport
// imitates the GitLab endpoints used by the Run tests:
//
//   - https://gitlab.example.com/latest answers 200 with a JSON Release tagged
//     v1.2.3. Its only asset link is named assetName and points at
//     https://gitlab.example.com/artifact.
//   - https://gitlab.example.com/artifact answers 200 with artifactBody. When
//     onDownload is non-nil, it is called first, so tests can observe whether
//     the artifact was fetched.
//   - Any other URL answers 404 with the body "not found".
func newFakeReleaseClient(t *testing.T, assetName, artifactBody string, onDownload func()) *http.Client {
	t.Helper()
	respond := func(status int, body io.Reader) *http.Response {
		return &http.Response{
			StatusCode: status,
			Header:     make(http.Header),
			Body:       io.NopCloser(body),
		}
	}
	return &http.Client{
		Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
			switch r.URL.String() {
			case "https://gitlab.example.com/latest":
				release := Release{TagName: "v1.2.3"}
				release.Assets.Links = []ReleaseLink{
					{Name: assetName, URL: "https://gitlab.example.com/artifact"},
				}
				payload, err := json.Marshal(release)
				if err != nil {
					t.Fatalf("Marshal release: %v", err)
				}
				return respond(http.StatusOK, bytes.NewReader(payload)), nil
			case "https://gitlab.example.com/artifact":
				if onDownload != nil {
					onDownload()
				}
				return respond(http.StatusOK, strings.NewReader(artifactBody)), nil
			default:
				return respond(http.StatusNotFound, strings.NewReader("not found")), nil
			}
		}),
	}
}