From 3943bfb8ccf213d3b0e6ce6ee8c7982702d25473 Mon Sep 17 00:00:00 2001 From: thibaud-leclere Date: Tue, 12 May 2026 09:52:51 +0200 Subject: [PATCH] feat: add get_callers and get_callees tools Co-Authored-By: Claude Sonnet 4.6 --- internal/tools/callees.go | 84 ++++++++++++++++++++++++++ internal/tools/callers.go | 111 +++++++++++++++++++++++++++++++++++ internal/tools/tools_test.go | 39 ++++++++++++ 3 files changed, 234 insertions(+) create mode 100644 internal/tools/callees.go create mode 100644 internal/tools/callers.go diff --git a/internal/tools/callees.go b/internal/tools/callees.go new file mode 100644 index 0000000..8dfc23b --- /dev/null +++ b/internal/tools/callees.go @@ -0,0 +1,84 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "forge.lclr.dev/AI/xdebug-mcp/internal/cache" + "forge.lclr.dev/AI/xdebug-mcp/internal/cachegrind" +) + +// CalleesTool returns the MCP tool definition for get_callees. +func CalleesTool() mcp.Tool { + return mcp.NewTool("get_callees", + mcp.WithDescription("List functions called by a given function in an Xdebug profiling file, sorted by call cost descending."), + mcp.WithString("file_path", + mcp.Required(), + mcp.Description("Absolute or relative path to the cachegrind file"), + ), + mcp.WithString("function_name", + mcp.Required(), + mcp.Description("Exact function name or substring to search for"), + ), + mcp.WithNumber("top_n", + mcp.Description("Maximum number of callees to return (default: 10)"), + ), + ) +} + +// CalleesHandler returns the MCP handler for get_callees. +func CalleesHandler(c *cache.Cache) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + filePath := req.GetString("file_path", "") + if filePath == "" { + return mcp.NewToolResultError("file_path is required"), nil + } + name := req.GetString("function_name", "") + if name == "" { + return mcp.NewToolResultError("function_name is required"), nil + } + topN := req.GetInt("top_n", 10) + if topN <= 0 { + topN = 10 + } + p, err := loadProfile(filePath, c) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + return mcp.NewToolResultText(Callees(p, name, topN)), nil + } +} + +// Callees formats the callees of name in p. Exported for testing. +func Callees(p *cachegrind.Profile, name string, topN int) string { + fns, errMsg := findFunctions(p, name) + if errMsg != "" { + return errMsg + } + + var sb strings.Builder + for _, fn := range fns { + fmt.Fprintf(&sb, "Callees of %q [%s]\n", fn.Name, fn.File) + if len(fn.Calls) == 0 { + fmt.Fprintf(&sb, " no outgoing calls recorded\n\n") + continue + } + + callees := sortedCalls(fn.Calls) + if topN < len(callees) { + callees = callees[:topN] + } + fmt.Fprintf(&sb, " calls %d function(s):\n\n", len(fn.Calls)) + for i, call := range callees { + fmt.Fprintf(&sb, " %3d. %-60s calls=%d %s\n", + i+1, call.Callee.Name, call.Count, formatCosts(call.Costs, p.Events)) + fmt.Fprintf(&sb, " %s\n", call.Callee.File) + } + fmt.Fprintln(&sb) + } + return sb.String() +} diff --git a/internal/tools/callers.go b/internal/tools/callers.go new file mode 100644 index 0000000..23d0e9d --- /dev/null +++ b/internal/tools/callers.go @@ -0,0 +1,111 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "forge.lclr.dev/AI/xdebug-mcp/internal/cache" + "forge.lclr.dev/AI/xdebug-mcp/internal/cachegrind" +) + +// CallersTool returns the MCP tool definition for get_callers. +func CallersTool() mcp.Tool { + return mcp.NewTool("get_callers", + mcp.WithDescription("List functions that call a given function in an Xdebug profiling file, sorted by call cost descending."), + mcp.WithString("file_path", + mcp.Required(), + mcp.Description("Absolute or relative path to the cachegrind file"), + ), + mcp.WithString("function_name", + mcp.Required(), + mcp.Description("Exact function name or substring to search for"), + ), + mcp.WithNumber("top_n", + mcp.Description("Maximum number of callers to return (default: 10)"), + ), + ) +} + +// CallersHandler returns the MCP handler for get_callers. +func CallersHandler(c *cache.Cache) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + filePath := req.GetString("file_path", "") + if filePath == "" { + return mcp.NewToolResultError("file_path is required"), nil + } + name := req.GetString("function_name", "") + if name == "" { + return mcp.NewToolResultError("function_name is required"), nil + } + topN := req.GetInt("top_n", 10) + if topN <= 0 { + topN = 10 + } + p, err := loadProfile(filePath, c) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + return mcp.NewToolResultText(Callers(p, name, topN)), nil + } +} + +// Callers formats the callers of name in p. Exported for testing. +func Callers(p *cachegrind.Profile, name string, topN int) string { + fns, errMsg := findFunctions(p, name) + if errMsg != "" { + return errMsg + } + + var sb strings.Builder + for _, fn := range fns { + fmt.Fprintf(&sb, "Callers of %q [%s]\n", fn.Name, fn.File) + if len(fn.CalledBy) == 0 { + fmt.Fprintf(&sb, " no incoming calls recorded\n\n") + continue + } + + callers := sortedCalls(fn.CalledBy) + if topN < len(callers) { + callers = callers[:topN] + } + fmt.Fprintf(&sb, " called by %d function(s):\n\n", len(fn.CalledBy)) + for i, call := range callers { + fmt.Fprintf(&sb, " %3d. %-60s calls=%d %s\n", + i+1, call.Caller.Name, call.Count, formatCosts(call.Costs, p.Events)) + fmt.Fprintf(&sb, " %s\n", call.Caller.File) + } + fmt.Fprintln(&sb) + } + return sb.String() +} + +// sortedCalls returns a copy of calls sorted by Costs[0] descending. +func sortedCalls(calls []*cachegrind.Call) []*cachegrind.Call { + out := make([]*cachegrind.Call, len(calls)) + copy(out, calls) + sortCallSlice(out) + return out +} + +func sortCallSlice(calls []*cachegrind.Call) { + for i := 1; i < len(calls); i++ { + for j := i; j > 0; j-- { + ci, cj := int64(0), int64(0) + if len(calls[j].Costs) > 0 { + ci = calls[j].Costs[0] + } + if len(calls[j-1].Costs) > 0 { + cj = calls[j-1].Costs[0] + } + if ci > cj { + calls[j], calls[j-1] = calls[j-1], calls[j] + } else { + break + } + } + } +} diff --git a/internal/tools/tools_test.go b/internal/tools/tools_test.go index 9765097..fe80f57 100644 --- a/internal/tools/tools_test.go +++ b/internal/tools/tools_test.go @@ -59,3 +59,42 @@ func TestAnalyze_TopNLimited(t *testing.T) { // require is used to avoid unused import error if tests change var _ = require.New + +func TestCallers_Found(t *testing.T) { + p := makeTestProfile() + result := tools.Callers(p, "query", 10) + + assert.Contains(t, result, "main") + assert.Contains(t, result, "calls=1") + require.NotContains(t, result, "not found") +} + +func TestCallers_ContainsFallback(t *testing.T) { + p := makeTestProfile() + // "uer" is a substring of "query" + result := tools.Callers(p, "uer", 10) + + assert.Contains(t, result, "main") +} + +func TestCallers_NotFound(t *testing.T) { + p := makeTestProfile() + result := tools.Callers(p, "nonexistent_xyz", 10) + + assert.Contains(t, result, "not found") +} + +func TestCallees_Found(t *testing.T) { + p := makeTestProfile() + result := tools.Callees(p, "query", 10) + + assert.Contains(t, result, "connect") + assert.Contains(t, result, "calls=2") +} + +func TestCallees_NotFound(t *testing.T) { + p := makeTestProfile() + result := tools.Callees(p, "connect", 10) // connect has no callees + + assert.Contains(t, result, "no outgoing calls") +}