fix: durcir le scaffold runtime et la sécurité des updates

This commit is contained in:
thibaud-lclr 2026-04-15 12:13:41 +02:00
parent 3eeb2fe173
commit 0d266cd5cc
5 changed files with 135 additions and 21 deletions

25
.gitea/workflows/ci.yml Normal file
View file

@ -0,0 +1,25 @@
name: CI
"on":
push:
branches:
- main
pull_request:
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Run go test
run: go test ./...
- name: Run go vet
run: go vet ./...

View file

@ -70,19 +70,24 @@ func Generate(options Options) (Result, error) {
} }
files := []generatedFile{ files := []generatedFile{
{Path: ".gitignore", Content: renderTemplate(gitignoreTemplate, normalized), Mode: 0o644}, {Path: ".gitignore", Template: gitignoreTemplate, Mode: 0o644},
{Path: "go.mod", Content: renderTemplate(goModTemplate, normalized), Mode: 0o644}, {Path: "go.mod", Template: goModTemplate, Mode: 0o644},
{Path: "README.md", Content: renderTemplate(readmeTemplate, normalized), Mode: 0o644}, {Path: "README.md", Template: readmeTemplate, Mode: 0o644},
{Path: "install.sh", Content: renderTemplate(installTemplate, normalized), Mode: 0o755}, {Path: "install.sh", Template: installTemplate, Mode: 0o755},
{Path: "mcp.toml", Content: renderTemplate(manifestTemplate, normalized), Mode: 0o644}, {Path: "mcp.toml", Template: manifestTemplate, Mode: 0o644},
{Path: filepath.Join("cmd", normalized.BinaryName, "main.go"), Content: renderTemplate(mainTemplate, normalized), Mode: 0o644}, {Path: filepath.Join("cmd", normalized.BinaryName, "main.go"), Template: mainTemplate, Mode: 0o644},
{Path: filepath.Join("internal", "app", "app.go"), Content: renderTemplate(appTemplate, normalized), Mode: 0o644}, {Path: filepath.Join("internal", "app", "app.go"), Template: appTemplate, Mode: 0o644},
} }
written := make([]string, 0, len(files)) written := make([]string, 0, len(files))
for _, file := range files { for _, file := range files {
content, err := renderTemplate(file.Template, normalized)
if err != nil {
return Result{}, fmt.Errorf("render scaffold file %q: %w", file.Path, err)
}
fullPath := filepath.Join(normalized.TargetDir, file.Path) fullPath := filepath.Join(normalized.TargetDir, file.Path)
if err := writeFile(fullPath, file.Content, file.Mode, normalized.Overwrite); err != nil { if err := writeFile(fullPath, content, file.Mode, normalized.Overwrite); err != nil {
return Result{}, err return Result{}, err
} }
written = append(written, file.Path) written = append(written, file.Path)
@ -97,7 +102,7 @@ func Generate(options Options) (Result, error) {
type generatedFile struct { type generatedFile struct {
Path string Path string
Content string Template string
Mode os.FileMode Mode os.FileMode
} }
@ -126,15 +131,18 @@ func writeFile(path, content string, mode os.FileMode, overwrite bool) error {
return nil return nil
} }
func renderTemplate(src string, data normalizedOptions) string { func renderTemplate(src string, data normalizedOptions) (string, error) {
tpl := template.Must(template.New("scaffold").Parse(src)) tpl, err := template.New("scaffold").Parse(src)
if err != nil {
return "", fmt.Errorf("parse template: %w", err)
}
var builder strings.Builder var builder strings.Builder
if err := tpl.Execute(&builder, data); err != nil { if err := tpl.Execute(&builder, data); err != nil {
panic(err) return "", fmt.Errorf("execute template: %w", err)
} }
return builder.String() return builder.String(), nil
} }
func normalizeOptions(options Options) (normalizedOptions, error) { func normalizeOptions(options Options) (normalizedOptions, error) {
@ -859,6 +867,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"path/filepath"
"strings" "strings"
"gitea.lclr.dev/AI/mcp-framework/bootstrap" "gitea.lclr.dev/AI/mcp-framework/bootstrap"
@ -895,10 +904,20 @@ func Run(ctx context.Context, args []string, version string) error {
} }
func NewRuntime(version string) (Runtime, error) { func NewRuntime(version string) (Runtime, error) {
manifestFile, _, err := manifest.LoadDefault(".") manifestStartDir := "."
if executablePath, err := os.Executable(); err == nil {
if dir := strings.TrimSpace(filepath.Dir(executablePath)); dir != "" {
manifestStartDir = dir
}
}
manifestFile, _, err := manifest.LoadDefault(manifestStartDir)
if err != nil { if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return Runtime{}, err return Runtime{}, err
} }
manifestFile = manifest.File{}
}
bootstrapInfo := manifestFile.BootstrapInfo() bootstrapInfo := manifestFile.BootstrapInfo()
scaffoldInfo := manifestFile.ScaffoldInfo() scaffoldInfo := manifestFile.ScaffoldInfo()
@ -1162,7 +1181,7 @@ repository = "{{.ReleaseRepository}}"
base_url = "{{.ReleaseBaseURL}}" base_url = "{{.ReleaseBaseURL}}"
asset_name_template = "{binary}-{os}-{arch}{ext}" asset_name_template = "{binary}-{os}-{arch}{ext}"
checksum_asset_name = "{asset}.sha256" checksum_asset_name = "{asset}.sha256"
checksum_required = false checksum_required = true
token_header = "Authorization" token_header = "Authorization"
token_prefix = "token" token_prefix = "token"
token_env_names = ["{{.ReleaseTokenEnv}}"] token_env_names = ["{{.ReleaseTokenEnv}}"]

View file

@ -70,6 +70,8 @@ func TestGenerateCreatesRecommendedSkeleton(t *testing.T) {
"update.Run", "update.Run",
"manifest.LoadDefault", "manifest.LoadDefault",
"bootstrap.Run", "bootstrap.Run",
"os.Executable()",
"errors.Is(err, os.ErrNotExist)",
} { } {
if !strings.Contains(string(appGo), snippet) { if !strings.Contains(string(appGo), snippet) {
t.Fatalf("app.go missing snippet %q", snippet) t.Fatalf("app.go missing snippet %q", snippet)
@ -86,6 +88,7 @@ func TestGenerateCreatesRecommendedSkeleton(t *testing.T) {
for _, snippet := range []string{ for _, snippet := range []string{
"binary_name = \"my-mcp\"", "binary_name = \"my-mcp\"",
"[update]", "[update]",
"checksum_required = true",
"[secret_store]", "[secret_store]",
"[environment]", "[environment]",
"[profiles]", "[profiles]",

View file

@ -20,6 +20,7 @@ import (
) )
const defaultAssetNameTemplate = "{binary}-{os}-{arch}{ext}" const defaultAssetNameTemplate = "{binary}-{os}-{arch}{ext}"
const defaultMaxDownloadBytes int64 = 200 * 1024 * 1024
type Options struct { type Options struct {
Client *http.Client Client *http.Client
@ -32,6 +33,7 @@ type Options struct {
ReleaseSource ReleaseSource ReleaseSource ReleaseSource
GOOS string GOOS string
GOARCH string GOARCH string
MaxDownloadBytes int64
ValidateDownloaded ValidateDownloadedFunc ValidateDownloaded ValidateDownloadedFunc
ReplaceExecutable ReplaceExecutableFunc ReplaceExecutable ReplaceExecutableFunc
} }
@ -124,6 +126,9 @@ func Run(ctx context.Context, opts Options) error {
if strings.TrimSpace(opts.GOARCH) == "" { if strings.TrimSpace(opts.GOARCH) == "" {
opts.GOARCH = runtime.GOARCH opts.GOARCH = runtime.GOARCH
} }
if opts.MaxDownloadBytes <= 0 {
opts.MaxDownloadBytes = defaultMaxDownloadBytes
}
source := normalizeSource(opts.ReleaseSource) source := normalizeSource(opts.ReleaseSource)
auth := ResolveAuth(source.Token, source) auth := ResolveAuth(source.Token, source)
@ -162,7 +167,7 @@ func Run(ctx context.Context, opts Options) error {
return err return err
} }
downloadPath, err := DownloadReleaseAsset(ctx, opts.Client, assetURL, targetPath, auth, source) downloadPath, err := DownloadReleaseAsset(ctx, opts.Client, assetURL, targetPath, auth, source, opts.MaxDownloadBytes)
if err != nil { if err != nil {
return err return err
} }
@ -394,7 +399,18 @@ func (r Release) AssetURL(assetName, releaseURL string) (string, error) {
) )
} }
func DownloadReleaseAsset(ctx context.Context, client *http.Client, assetURL, targetPath string, auth Auth, source ReleaseSource) (string, error) { func DownloadReleaseAsset(
ctx context.Context,
client *http.Client,
assetURL, targetPath string,
auth Auth,
source ReleaseSource,
maxDownloadBytes int64,
) (string, error) {
if maxDownloadBytes <= 0 {
maxDownloadBytes = defaultMaxDownloadBytes
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, assetURL, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, assetURL, nil)
if err != nil { if err != nil {
return "", fmt.Errorf("build artifact download request: %w", err) return "", fmt.Errorf("build artifact download request: %w", err)
@ -419,6 +435,13 @@ func DownloadReleaseAsset(ctx context.Context, client *http.Client, assetURL, ta
strings.TrimSpace(string(body)), strings.TrimSpace(string(body)),
) )
} }
if resp.ContentLength > 0 && resp.ContentLength > maxDownloadBytes {
return "", fmt.Errorf(
"download release artifact: content length %d exceeds limit %d bytes",
resp.ContentLength,
maxDownloadBytes,
)
}
existingInfo, err := os.Stat(targetPath) existingInfo, err := os.Stat(targetPath)
if err != nil { if err != nil {
@ -437,9 +460,14 @@ func DownloadReleaseAsset(ctx context.Context, client *http.Client, assetURL, ta
return "", copyErr return "", copyErr
} }
if _, err := io.Copy(tempFile, resp.Body); err != nil { limited := &io.LimitedReader{R: resp.Body, N: maxDownloadBytes + 1}
written, err := io.Copy(tempFile, limited)
if err != nil {
return cleanup(fmt.Errorf("write downloaded artifact: %w", err)) return cleanup(fmt.Errorf("write downloaded artifact: %w", err))
} }
if written > maxDownloadBytes {
return cleanup(fmt.Errorf("write downloaded artifact: size exceeds limit %d bytes", maxDownloadBytes))
}
if err := tempFile.Chmod(existingInfo.Mode().Perm()); err != nil { if err := tempFile.Chmod(existingInfo.Mode().Perm()); err != nil {
return cleanup(fmt.Errorf("set executable mode on downloaded artifact: %w", err)) return cleanup(fmt.Errorf("set executable mode on downloaded artifact: %w", err))
} }

View file

@ -384,6 +384,45 @@ func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error)
return fn(req) return fn(req)
} }
func TestDownloadReleaseAssetRejectsArtifactOverLimit(t *testing.T) {
client := &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.String() != "https://releases.example.com/artifact" {
t.Fatalf("unexpected url: %s", r.URL.String())
}
body := "123456"
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
ContentLength: int64(len(body)),
Body: io.NopCloser(strings.NewReader(body)),
}, nil
}),
}
tempDir := t.TempDir()
target := filepath.Join(tempDir, "graylog-mcp")
if err := os.WriteFile(target, []byte("old-binary"), 0o755); err != nil {
t.Fatalf("WriteFile target: %v", err)
}
_, err := DownloadReleaseAsset(
context.Background(),
client,
"https://releases.example.com/artifact",
target,
Auth{},
ReleaseSource{},
5,
)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "exceeds limit") {
t.Fatalf("error = %v", err)
}
}
func TestRunReplacesExecutableWithLatestArtifact(t *testing.T) { func TestRunReplacesExecutableWithLatestArtifact(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("self-replace is not supported on windows") t.Skip("self-replace is not supported on windows")