xdebug-mcp/internal/tools/callers.go
thibaud-leclere 3943bfb8cc feat: add get_callers and get_callees tools
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 09:52:51 +02:00

111 lines
3 KiB
Go

package tools
import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/mark3labs/mcp-go/mcp"
	"github.com/mark3labs/mcp-go/server"

	"forge.lclr.dev/AI/xdebug-mcp/internal/cache"
	"forge.lclr.dev/AI/xdebug-mcp/internal/cachegrind"
)
// CallersTool builds the MCP tool definition registered under the name
// "get_callers". The tool declares two required string arguments
// ("file_path" and "function_name") and one optional numeric argument
// ("top_n").
func CallersTool() mcp.Tool {
	// Options are listed in the order they are handed to mcp.NewTool.
	opts := []mcp.ToolOption{
		mcp.WithDescription("List functions that call a given function in an Xdebug profiling file, sorted by call cost descending."),
		mcp.WithString(
			"file_path",
			mcp.Required(),
			mcp.Description("Absolute or relative path to the cachegrind file"),
		),
		mcp.WithString(
			"function_name",
			mcp.Required(),
			mcp.Description("Exact function name or substring to search for"),
		),
		mcp.WithNumber(
			"top_n",
			mcp.Description("Maximum number of callers to return (default: 10)"),
		),
	}
	return mcp.NewTool("get_callers", opts...)
}
// CallersHandler builds the MCP handler backing the get_callers tool.
//
// The returned handler reads "file_path", "function_name" and "top_n" from
// the request, loads the profile through loadProfile (passing c along), and
// answers with the text produced by Callers. Validation and load failures are
// reported as tool-level error results; the Go error return is always nil.
func CallersHandler(c *cache.Cache) server.ToolHandlerFunc {
	// Used when top_n is absent or not strictly positive.
	const defaultTopN = 10

	handler := func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		path := req.GetString("file_path", "")
		fnName := req.GetString("function_name", "")

		// file_path is checked first so that it wins when both are missing.
		switch {
		case path == "":
			return mcp.NewToolResultError("file_path is required"), nil
		case fnName == "":
			return mcp.NewToolResultError("function_name is required"), nil
		}

		limit := req.GetInt("top_n", defaultTopN)
		if limit < 1 {
			limit = defaultTopN
		}

		profile, loadErr := loadProfile(path, c)
		if loadErr != nil {
			return mcp.NewToolResultError(loadErr.Error()), nil
		}

		report := Callers(profile, fnName, limit)
		return mcp.NewToolResultText(report), nil
	}
	return handler
}
// Callers formats the callers of name in p. Exported for testing.
//
// The functions to report on are resolved by findFunctions; when it returns a
// non-empty message, that message is returned unchanged as the result.
// Otherwise one section is written per matched function: a header with the
// function name and file, then either a "no incoming calls" note or the
// callers as ordered by sortedCalls, truncated to at most topN entries.
//
// The "called by N function(s)" line always reports the total number of
// recorded callers, not the number of entries actually listed.
//
// A negative topN is treated as zero (no entries listed) instead of
// triggering a slice-bounds panic.
func Callers(p *cachegrind.Profile, name string, topN int) string {
	fns, errMsg := findFunctions(p, name)
	if errMsg != "" {
		return errMsg
	}

	// Guard the callers[:topN] slice expression below: a negative upper
	// bound is a run-time panic, and this function is callable directly.
	if topN < 0 {
		topN = 0
	}

	var sb strings.Builder
	for _, fn := range fns {
		fmt.Fprintf(&sb, "Callers of %q [%s]\n", fn.Name, fn.File)
		if len(fn.CalledBy) == 0 {
			sb.WriteString(" no incoming calls recorded\n\n")
			continue
		}

		callers := sortedCalls(fn.CalledBy)
		if topN < len(callers) {
			callers = callers[:topN]
		}

		fmt.Fprintf(&sb, " called by %d function(s):\n\n", len(fn.CalledBy))
		for i, call := range callers {
			fmt.Fprintf(&sb, " %3d. %-60s calls=%d %s\n",
				i+1, call.Caller.Name, call.Count, formatCosts(call.Costs, p.Events))
			fmt.Fprintf(&sb, " %s\n", call.Caller.File)
		}
		// Blank line separating this section from the next one.
		sb.WriteString("\n")
	}
	return sb.String()
}
// sortedCalls leaves calls untouched and hands back a newly allocated slice
// holding the same elements, ordered by sortCallSlice (first cost value,
// highest first).
func sortedCalls(calls []*cachegrind.Call) []*cachegrind.Call {
	// Appending onto an empty, exactly-sized slice yields an independent copy.
	ordered := append(make([]*cachegrind.Call, 0, len(calls)), calls...)
	sortCallSlice(ordered)
	return ordered
}
// sortCallSlice sorts calls in place by their first cost value, highest
// first. A call with an empty Costs slice is ranked as if its cost were 0.
// The sort is stable: calls with equal cost keep their original relative
// order.
func sortCallSlice(calls []*cachegrind.Call) {
	// primary yields the value used for ordering, defaulting to 0 when no
	// cost values were recorded for the call.
	primary := func(c *cachegrind.Call) int64 {
		if len(c.Costs) == 0 {
			return 0
		}
		return c.Costs[0]
	}

	sort.SliceStable(calls, func(i, j int) bool {
		return primary(calls[i]) > primary(calls[j])
	})
}