diff --git a/update/update.go b/update/update.go
index 21aa789..31c60f5 100644
--- a/update/update.go
+++ b/update/update.go
@@ -1,6 +1,7 @@
package update
import (
+ "bytes"
"context"
"crypto/ed25519"
"crypto/sha256"
@@ -23,6 +24,7 @@ import (
const defaultAssetNameTemplate = "{binary}-{os}-{arch}{ext}"
const defaultMaxDownloadBytes int64 = 200 * 1024 * 1024
+const downloadedArtifactSniffBytes = 4096
type Options struct {
Client *http.Client
@@ -179,6 +181,10 @@ func Run(ctx context.Context, opts Options) error {
}
defer os.Remove(downloadPath)
+ if err := validateDownloadedArtifact(downloadPath, assetName); err != nil {
+ return err
+ }
+
if err := VerifyReleaseAssetChecksum(ctx, opts.Client, release, releaseURL, assetName, downloadPath, auth, source); err != nil {
return err
}
@@ -211,6 +217,52 @@ func Run(ctx context.Context, opts Options) error {
return nil
}
+func validateDownloadedArtifact(path, assetName string) error {
+ file, err := os.Open(path)
+ if err != nil {
+ return fmt.Errorf("validate downloaded artifact %q: %w", path, err)
+ }
+ defer file.Close()
+
+ head := make([]byte, downloadedArtifactSniffBytes)
+ n, readErr := file.Read(head)
+ if readErr != nil && !errors.Is(readErr, io.EOF) {
+ return fmt.Errorf("validate downloaded artifact %q: %w", path, readErr)
+ }
+ if n == 0 {
+ return fmt.Errorf("downloaded artifact %q is empty", assetName)
+ }
+
+ if looksLikeHTMLDocument(head[:n]) {
+ return fmt.Errorf(
+ "downloaded artifact %q looks like an HTML page (possible auth/forbidden response)",
+ assetName,
+ )
+ }
+
+ return nil
+}
+
+func looksLikeHTMLDocument(content []byte) bool {
+ if len(content) == 0 {
+ return false
+ }
+
+ detectedContentType := strings.ToLower(http.DetectContentType(content))
+ if strings.HasPrefix(detectedContentType, "text/html") {
+ return true
+ }
+
+ trimmed := bytes.TrimSpace(content)
+ trimmed = bytes.TrimPrefix(trimmed, []byte{0xEF, 0xBB, 0xBF}) // UTF-8 BOM
+ if len(trimmed) == 0 {
+ return false
+ }
+
+ lower := strings.ToLower(string(trimmed))
+ return strings.HasPrefix(lower, "
ForbiddenAccess denied"
+ if err := os.WriteFile(path, []byte(content), 0o755); err != nil {
+ t.Fatalf("WriteFile: %v", err)
+ }
+
+ err := validateDownloadedArtifact(path, "graylog-mcp-linux-amd64")
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "looks like an HTML page") {
+ t.Fatalf("error = %v", err)
+ }
+}
+
+func TestValidateDownloadedArtifactAcceptsShebangScript(t *testing.T) {
+ path := filepath.Join(t.TempDir(), "downloaded")
+ content := "#!/usr/bin/env sh\necho ok\n"
+ if err := os.WriteFile(path, []byte(content), 0o755); err != nil {
+ t.Fatalf("WriteFile: %v", err)
+ }
+
+ err := validateDownloadedArtifact(path, "graylog-mcp-linux-amd64")
+ if err != nil {
+ t.Fatalf("validateDownloadedArtifact: %v", err)
+ }
+}
+
func TestRunReplacesExecutableWithLatestArtifact(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("self-replace is not supported on windows")
@@ -516,6 +545,85 @@ func TestRunReplacesExecutableWithLatestArtifact(t *testing.T) {
}
}
+func TestRunStopsWhenArtifactLooksLikeHTML(t *testing.T) {
+ assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
+ if err != nil {
+ t.Skipf("unsupported test platform: %v", err)
+ }
+
+ client := &http.Client{
+ Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
+ switch r.URL.String() {
+ case "https://releases.example.com/latest":
+ release := Release{TagName: "v1.2.3"}
+ release.Assets.Links = []ReleaseLink{
+ {Name: assetName, URL: "https://releases.example.com/artifact"},
+ }
+
+ payload, marshalErr := json.Marshal(release)
+ if marshalErr != nil {
+ t.Fatalf("Marshal release: %v", marshalErr)
+ }
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Header: make(http.Header),
+ Body: io.NopCloser(bytes.NewReader(payload)),
+ }, nil
+ case "https://releases.example.com/artifact":
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader(
+ "Access denied",
+ )),
+ }, nil
+ default:
+ return &http.Response{
+ StatusCode: http.StatusNotFound,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader("not found")),
+ }, nil
+ }
+ }),
+ }
+
+ tempDir := t.TempDir()
+ target := filepath.Join(tempDir, "graylog-mcp")
+ if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil {
+ t.Fatalf("WriteFile target: %v", err)
+ }
+
+ replaceCalled := false
+ err = Run(context.Background(), Options{
+ Client: client,
+ CurrentVersion: "v1.2.2",
+ ExecutablePath: target,
+ LatestReleaseURL: "https://releases.example.com/latest",
+ BinaryName: "graylog-mcp",
+ ReplaceExecutable: func(downloadPath, targetPath string) error {
+ replaceCalled = true
+ return os.WriteFile(targetPath, []byte("unexpected"), 0o755)
+ },
+ })
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "looks like an HTML page") {
+ t.Fatalf("error = %v", err)
+ }
+ if replaceCalled {
+ t.Fatal("replace hook should not have been called")
+ }
+
+ got, readErr := os.ReadFile(target)
+ if readErr != nil {
+ t.Fatalf("ReadFile target: %v", readErr)
+ }
+ if string(got) != "old-binary" {
+ t.Fatalf("target content = %q, want unchanged binary", string(got))
+ }
+}
+
func TestRunUsesDriverWithoutExplicitLatestReleaseURL(t *testing.T) {
assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
if err != nil {