mcp-framework/generate/generate_test.go

213 lines
6.2 KiB
Go

package generate
import (
"errors"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
"testing"
)
// TestGenerateCreatesManifestLoader checks that a default Generate run
// reports exactly one output file, mcpgen/manifest.go. It also checks that
// the file carries the generated-code header, the mcpgen package clause, the
// quoted manifest text and a LoadManifest wrapper that delegates to
// fwmanifest.LoadDefaultOrEmbedded.
func TestGenerateCreatesManifestLoader(t *testing.T) {
	dir := newProject(t, `
binary_name = "demo-mcp"
docs_url = "https://docs.example.com/demo"
[bootstrap]
description = "Demo MCP"
`)

	res, err := Generate(Options{ProjectDir: dir})
	if err != nil {
		t.Fatalf("Generate returned error: %v", err)
	}
	wantFiles := []string{filepath.Join("mcpgen", "manifest.go")}
	if !slices.Equal(res.Files, wantFiles) {
		t.Fatalf("result files = %v", res.Files)
	}

	raw, err := os.ReadFile(filepath.Join(dir, wantFiles[0]))
	if err != nil {
		t.Fatalf("ReadFile generated manifest: %v", err)
	}
	generated := string(raw)

	expected := []string{
		"// Code generated by mcp-framework generate. DO NOT EDIT.",
		"package mcpgen",
		"import fwmanifest \"gitea.lclr.dev/AI/mcp-framework/manifest\"",
		"const embeddedManifest = ",
		"func LoadManifest(startDir string) (fwmanifest.File, string, error) {",
		"return fwmanifest.LoadDefaultOrEmbedded(startDir, embeddedManifest)",
		// The manifest is embedded as a Go-quoted string, hence the escaped quotes.
		`binary_name = \"demo-mcp\"`,
	}
	for i := range expected {
		if snippet := expected[i]; !strings.Contains(generated, snippet) {
			t.Fatalf("generated manifest.go missing snippet %q:\n%s", snippet, raw)
		}
	}
}
// TestGenerateIsIdempotentAndCheckDetectsDrift runs Generate twice and
// requires byte-identical output. It then verifies that Check mode first
// accepts the fresh output, and reports ErrGeneratedFilesOutdated once the
// generated file has been edited by hand.
func TestGenerateIsIdempotentAndCheckDetectsDrift(t *testing.T) {
	dir := newProject(t, `binary_name = "demo-mcp"`)
	target := filepath.Join(dir, "mcpgen", "manifest.go")
	opts := Options{ProjectDir: dir}

	if _, err := Generate(opts); err != nil {
		t.Fatalf("first Generate returned error: %v", err)
	}
	before, err := os.ReadFile(target)
	if err != nil {
		t.Fatalf("ReadFile first generated file: %v", err)
	}

	if _, err = Generate(opts); err != nil {
		t.Fatalf("second Generate returned error: %v", err)
	}
	after, err := os.ReadFile(target)
	if err != nil {
		t.Fatalf("ReadFile second generated file: %v", err)
	}
	if !slices.Equal(after, before) {
		t.Fatal("second generation changed content")
	}

	checkOpts := opts
	checkOpts.Check = true
	if _, err = Generate(checkOpts); err != nil {
		t.Fatalf("check after generation returned error: %v", err)
	}

	// Simulate a hand edit so the on-disk file no longer matches what
	// Generate would produce.
	drifted := append(after, "// drift\n"...)
	if err = os.WriteFile(target, drifted, 0o600); err != nil {
		t.Fatalf("WriteFile drift: %v", err)
	}
	_, err = Generate(checkOpts)
	if !errors.Is(err, ErrGeneratedFilesOutdated) {
		t.Fatalf("check error = %v, want ErrGeneratedFilesOutdated", err)
	}
}
// TestGenerateSupportsManifestAndPackageFlags points Generate at a manifest
// outside the default location and overrides the output package directory
// and name. It then checks that the file lands in internal/generated and
// declares `package generated`.
func TestGenerateSupportsManifestAndPackageFlags(t *testing.T) {
	root := t.TempDir()
	customManifest := filepath.Join(root, "config", "custom.toml")
	if err := os.MkdirAll(filepath.Dir(customManifest), 0o755); err != nil {
		t.Fatalf("MkdirAll manifest dir: %v", err)
	}
	if err := os.WriteFile(customManifest, []byte(`binary_name = "demo-mcp"`), 0o600); err != nil {
		t.Fatalf("WriteFile manifest: %v", err)
	}

	outputRel := filepath.Join("internal", "generated", "manifest.go")
	res, err := Generate(Options{
		ProjectDir:   root,
		ManifestPath: customManifest,
		PackageDir:   "internal/generated",
		PackageName:  "generated",
	})
	if err != nil {
		t.Fatalf("Generate returned error: %v", err)
	}
	if want := []string{outputRel}; !slices.Equal(res.Files, want) {
		t.Fatalf("result files = %v", res.Files)
	}

	body, err := os.ReadFile(filepath.Join(root, outputRel))
	if err != nil {
		t.Fatalf("ReadFile generated manifest: %v", err)
	}
	if !strings.Contains(string(body), "package generated") {
		t.Fatalf("generated file should use package name: %s", body)
	}
}
// TestGenerateRejectsInvalidManifest feeds Generate a syntactically broken
// mcp.toml and expects an error whose message mentions "parse manifest".
func TestGenerateRejectsInvalidManifest(t *testing.T) {
	// The table header is never closed, so this is not valid TOML.
	dir := newProject(t, "[bootstrap\n")

	_, err := Generate(Options{ProjectDir: dir})
	switch {
	case err == nil:
		t.Fatal("expected error")
	case !strings.Contains(err.Error(), "parse manifest"):
		t.Fatalf("error = %v", err)
	}
}
// TestGeneratedLoaderFallsBackToEmbeddedManifest is an end-to-end check. It
// generates mcpgen/manifest.go into a scratch module built by writeModule and
// deletes the on-disk mcp.toml. It then runs `go test ./...` in that module,
// so the test written by writeModule can assert that LoadManifest returns the
// embedded manifest.
//
// Because it spawns the go command, the test is skipped in -short mode and
// when no go binary can be found on PATH.
func TestGeneratedLoaderFallsBackToEmbeddedManifest(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping go test subprocess in -short mode")
	}
	goBin, err := exec.LookPath("go")
	if err != nil {
		t.Skipf("go command not available: %v", err)
	}

	projectDir := newProject(t, `
binary_name = "embedded-demo"
docs_url = "https://docs.example.com/embedded"
`)
	writeModule(t, projectDir)
	if _, err := Generate(Options{ProjectDir: projectDir}); err != nil {
		t.Fatalf("Generate returned error: %v", err)
	}
	// Remove the runtime manifest so the generated loader can only succeed
	// through its embedded fallback.
	if err := os.Remove(filepath.Join(projectDir, "mcp.toml")); err != nil {
		t.Fatalf("Remove runtime manifest: %v", err)
	}

	cmd := exec.Command(goBin, "test", "./...")
	cmd.Dir = projectDir
	// The scratch module is a standalone module in a temp dir. GOWORK=off
	// stops an inherited GOWORK setting from resolving it against an
	// unrelated go.work (later duplicate env keys win in os/exec).
	cmd.Env = append(os.Environ(), "GOWORK=off")
	output, err := cmd.CombinedOutput()
	if err != nil {
		t.Fatalf("go test generated project: %v\n%s", err, output)
	}
}
// newProject creates a fresh temporary directory, writes manifest into
// <dir>/mcp.toml with mode 0o600, and returns the directory path. The
// directory comes from t.TempDir, so the testing package removes it when the
// test ends.
func newProject(t *testing.T, manifest string) string {
	t.Helper()
	dir := t.TempDir()
	manifestFile := filepath.Join(dir, "mcp.toml")
	err := os.WriteFile(manifestFile, []byte(manifest), 0o600)
	if err != nil {
		t.Fatalf("WriteFile manifest: %v", err)
	}
	return dir
}
// writeModule turns projectDir into a standalone Go module that consumes the
// generated mcpgen package. It writes three files, in this order:
//   - go.mod, with a replace directive pointing the framework module at the
//     parent of the current working directory;
//   - go.sum, copied from that parent directory;
//   - main_test.go, which asserts that mcpgen.LoadManifest(".") reports
//     fwmanifest.EmbeddedSource and the binary name "embedded-demo".
//
// NOTE(review): the toml version and go directive are hard-coded here and
// presumably need to track the framework's own go.mod — confirm on upgrades.
func writeModule(t *testing.T, projectDir string) {
	t.Helper()

	frameworkRoot, err := filepath.Abs("..")
	if err != nil {
		t.Fatalf("Abs repo root: %v", err)
	}

	writeFile := func(name string, data []byte) {
		t.Helper()
		if werr := os.WriteFile(filepath.Join(projectDir, name), data, 0o600); werr != nil {
			t.Fatalf("WriteFile %s: %v", name, werr)
		}
	}

	goMod := "module example.com/generated-demo\n" +
		"\n" +
		"go 1.25.0\n" +
		"\n" +
		"require (\n" +
		"\tgithub.com/BurntSushi/toml v1.6.0\n" +
		"\tgitea.lclr.dev/AI/mcp-framework v0.0.0\n" +
		")\n" +
		"\n" +
		"replace gitea.lclr.dev/AI/mcp-framework => " + filepath.ToSlash(frameworkRoot) + "\n"
	writeFile("go.mod", []byte(goMod))

	// Copy the framework's go.sum so the scratch module starts with checksums
	// for its requirements.
	goSum, err := os.ReadFile(filepath.Join(frameworkRoot, "go.sum"))
	if err != nil {
		t.Fatalf("ReadFile go.sum: %v", err)
	}
	writeFile("go.sum", goSum)

	const embeddedLoaderTest = `package main
import (
"testing"
"example.com/generated-demo/mcpgen"
fwmanifest "gitea.lclr.dev/AI/mcp-framework/manifest"
)
func TestGeneratedLoaderUsesEmbeddedManifest(t *testing.T) {
file, source, err := mcpgen.LoadManifest(".")
if err != nil {
t.Fatalf("LoadManifest returned error: %v", err)
}
if source != fwmanifest.EmbeddedSource {
t.Fatalf("source = %q, want %q", source, fwmanifest.EmbeddedSource)
}
if file.BinaryName != "embedded-demo" {
t.Fatalf("binary name = %q", file.BinaryName)
}
}
`
	writeFile("main_test.go", []byte(embeddedLoaderTest))
}