fix(update): reject HTML artifacts during self-update

This commit is contained in:
thibaud-lclr 2026-04-15 14:23:15 +02:00
parent f0e2e9304b
commit 01c0c7e1bc
2 changed files with 160 additions and 0 deletions

View file

@ -1,6 +1,7 @@
package update package update
import ( import (
"bytes"
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"crypto/sha256" "crypto/sha256"
@ -23,6 +24,7 @@ import (
const defaultAssetNameTemplate = "{binary}-{os}-{arch}{ext}" const defaultAssetNameTemplate = "{binary}-{os}-{arch}{ext}"
const defaultMaxDownloadBytes int64 = 200 * 1024 * 1024 const defaultMaxDownloadBytes int64 = 200 * 1024 * 1024
const downloadedArtifactSniffBytes = 4096
type Options struct { type Options struct {
Client *http.Client Client *http.Client
@ -179,6 +181,10 @@ func Run(ctx context.Context, opts Options) error {
} }
defer os.Remove(downloadPath) defer os.Remove(downloadPath)
if err := validateDownloadedArtifact(downloadPath, assetName); err != nil {
return err
}
if err := VerifyReleaseAssetChecksum(ctx, opts.Client, release, releaseURL, assetName, downloadPath, auth, source); err != nil { if err := VerifyReleaseAssetChecksum(ctx, opts.Client, release, releaseURL, assetName, downloadPath, auth, source); err != nil {
return err return err
} }
@ -211,6 +217,52 @@ func Run(ctx context.Context, opts Options) error {
return nil return nil
} }
func validateDownloadedArtifact(path, assetName string) error {
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("validate downloaded artifact %q: %w", path, err)
}
defer file.Close()
head := make([]byte, downloadedArtifactSniffBytes)
n, readErr := file.Read(head)
if readErr != nil && !errors.Is(readErr, io.EOF) {
return fmt.Errorf("validate downloaded artifact %q: %w", path, readErr)
}
if n == 0 {
return fmt.Errorf("downloaded artifact %q is empty", assetName)
}
if looksLikeHTMLDocument(head[:n]) {
return fmt.Errorf(
"downloaded artifact %q looks like an HTML page (possible auth/forbidden response)",
assetName,
)
}
return nil
}
func looksLikeHTMLDocument(content []byte) bool {
if len(content) == 0 {
return false
}
detectedContentType := strings.ToLower(http.DetectContentType(content))
if strings.HasPrefix(detectedContentType, "text/html") {
return true
}
trimmed := bytes.TrimSpace(content)
trimmed = bytes.TrimPrefix(trimmed, []byte{0xEF, 0xBB, 0xBF}) // UTF-8 BOM
if len(trimmed) == 0 {
return false
}
lower := strings.ToLower(string(trimmed))
return strings.HasPrefix(lower, "<!doctype html") || strings.HasPrefix(lower, "<html")
}
func ResolveLatestReleaseURL(explicit string, source ReleaseSource) (string, error) { func ResolveLatestReleaseURL(explicit string, source ReleaseSource) (string, error) {
if releaseURL := strings.TrimSpace(explicit); releaseURL != "" { if releaseURL := strings.TrimSpace(explicit); releaseURL != "" {
return releaseURL, nil return releaseURL, nil

View file

@ -425,6 +425,35 @@ func TestDownloadReleaseAssetRejectsArtifactOverLimit(t *testing.T) {
} }
} }
func TestValidateDownloadedArtifactRejectsHTMLDocument(t *testing.T) {
path := filepath.Join(t.TempDir(), "downloaded")
content := "<!DOCTYPE html><html><head><title>Forbidden</title></head><body>Access denied</body></html>"
if err := os.WriteFile(path, []byte(content), 0o755); err != nil {
t.Fatalf("WriteFile: %v", err)
}
err := validateDownloadedArtifact(path, "graylog-mcp-linux-amd64")
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "looks like an HTML page") {
t.Fatalf("error = %v", err)
}
}
func TestValidateDownloadedArtifactAcceptsShebangScript(t *testing.T) {
path := filepath.Join(t.TempDir(), "downloaded")
content := "#!/usr/bin/env sh\necho ok\n"
if err := os.WriteFile(path, []byte(content), 0o755); err != nil {
t.Fatalf("WriteFile: %v", err)
}
err := validateDownloadedArtifact(path, "graylog-mcp-linux-amd64")
if err != nil {
t.Fatalf("validateDownloadedArtifact: %v", err)
}
}
func TestRunReplacesExecutableWithLatestArtifact(t *testing.T) { func TestRunReplacesExecutableWithLatestArtifact(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("self-replace is not supported on windows") t.Skip("self-replace is not supported on windows")
@ -516,6 +545,85 @@ func TestRunReplacesExecutableWithLatestArtifact(t *testing.T) {
} }
} }
func TestRunStopsWhenArtifactLooksLikeHTML(t *testing.T) {
assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
if err != nil {
t.Skipf("unsupported test platform: %v", err)
}
client := &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.String() {
case "https://releases.example.com/latest":
release := Release{TagName: "v1.2.3"}
release.Assets.Links = []ReleaseLink{
{Name: assetName, URL: "https://releases.example.com/artifact"},
}
payload, marshalErr := json.Marshal(release)
if marshalErr != nil {
t.Fatalf("Marshal release: %v", marshalErr)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(payload)),
}, nil
case "https://releases.example.com/artifact":
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(
"<!DOCTYPE html><html><body>Access denied</body></html>",
)),
}, nil
default:
return &http.Response{
StatusCode: http.StatusNotFound,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("not found")),
}, nil
}
}),
}
tempDir := t.TempDir()
target := filepath.Join(tempDir, "graylog-mcp")
if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil {
t.Fatalf("WriteFile target: %v", err)
}
replaceCalled := false
err = Run(context.Background(), Options{
Client: client,
CurrentVersion: "v1.2.2",
ExecutablePath: target,
LatestReleaseURL: "https://releases.example.com/latest",
BinaryName: "graylog-mcp",
ReplaceExecutable: func(downloadPath, targetPath string) error {
replaceCalled = true
return os.WriteFile(targetPath, []byte("unexpected"), 0o755)
},
})
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "looks like an HTML page") {
t.Fatalf("error = %v", err)
}
if replaceCalled {
t.Fatal("replace hook should not have been called")
}
got, readErr := os.ReadFile(target)
if readErr != nil {
t.Fatalf("ReadFile target: %v", readErr)
}
if string(got) != "old-binary" {
t.Fatalf("target content = %q, want unchanged binary", string(got))
}
}
func TestRunUsesDriverWithoutExplicitLatestReleaseURL(t *testing.T) { func TestRunUsesDriverWithoutExplicitLatestReleaseURL(t *testing.T) {
assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH) assetName, err := AssetName("graylog-mcp", runtime.GOOS, runtime.GOARCH)
if err != nil { if err != nil {