feat: add get_callers and get_callees tools
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
8fafc4a69e
commit
3943bfb8cc
3 changed files with 234 additions and 0 deletions
84
internal/tools/callees.go
Normal file
84
internal/tools/callees.go
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"forge.lclr.dev/AI/xdebug-mcp/internal/cache"
|
||||
"forge.lclr.dev/AI/xdebug-mcp/internal/cachegrind"
|
||||
)
|
||||
|
||||
// CalleesTool returns the MCP tool definition for get_callees.
|
||||
func CalleesTool() mcp.Tool {
|
||||
return mcp.NewTool("get_callees",
|
||||
mcp.WithDescription("List functions called by a given function in an Xdebug profiling file, sorted by call cost descending."),
|
||||
mcp.WithString("file_path",
|
||||
mcp.Required(),
|
||||
mcp.Description("Absolute or relative path to the cachegrind file"),
|
||||
),
|
||||
mcp.WithString("function_name",
|
||||
mcp.Required(),
|
||||
mcp.Description("Exact function name or substring to search for"),
|
||||
),
|
||||
mcp.WithNumber("top_n",
|
||||
mcp.Description("Maximum number of callees to return (default: 10)"),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// CalleesHandler returns the MCP handler for get_callees.
|
||||
func CalleesHandler(c *cache.Cache) server.ToolHandlerFunc {
|
||||
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
filePath := req.GetString("file_path", "")
|
||||
if filePath == "" {
|
||||
return mcp.NewToolResultError("file_path is required"), nil
|
||||
}
|
||||
name := req.GetString("function_name", "")
|
||||
if name == "" {
|
||||
return mcp.NewToolResultError("function_name is required"), nil
|
||||
}
|
||||
topN := req.GetInt("top_n", 10)
|
||||
if topN <= 0 {
|
||||
topN = 10
|
||||
}
|
||||
p, err := loadProfile(filePath, c)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
return mcp.NewToolResultText(Callees(p, name, topN)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Callees formats the callees of name in p. Exported for testing.
|
||||
func Callees(p *cachegrind.Profile, name string, topN int) string {
|
||||
fns, errMsg := findFunctions(p, name)
|
||||
if errMsg != "" {
|
||||
return errMsg
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for _, fn := range fns {
|
||||
fmt.Fprintf(&sb, "Callees of %q [%s]\n", fn.Name, fn.File)
|
||||
if len(fn.Calls) == 0 {
|
||||
fmt.Fprintf(&sb, " no outgoing calls recorded\n\n")
|
||||
continue
|
||||
}
|
||||
|
||||
callees := sortedCalls(fn.Calls)
|
||||
if topN < len(callees) {
|
||||
callees = callees[:topN]
|
||||
}
|
||||
fmt.Fprintf(&sb, " calls %d function(s):\n\n", len(fn.Calls))
|
||||
for i, call := range callees {
|
||||
fmt.Fprintf(&sb, " %3d. %-60s calls=%d %s\n",
|
||||
i+1, call.Callee.Name, call.Count, formatCosts(call.Costs, p.Events))
|
||||
fmt.Fprintf(&sb, " %s\n", call.Callee.File)
|
||||
}
|
||||
fmt.Fprintln(&sb)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
111
internal/tools/callers.go
Normal file
111
internal/tools/callers.go
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"forge.lclr.dev/AI/xdebug-mcp/internal/cache"
|
||||
"forge.lclr.dev/AI/xdebug-mcp/internal/cachegrind"
|
||||
)
|
||||
|
||||
// CallersTool returns the MCP tool definition for get_callers.
|
||||
func CallersTool() mcp.Tool {
|
||||
return mcp.NewTool("get_callers",
|
||||
mcp.WithDescription("List functions that call a given function in an Xdebug profiling file, sorted by call cost descending."),
|
||||
mcp.WithString("file_path",
|
||||
mcp.Required(),
|
||||
mcp.Description("Absolute or relative path to the cachegrind file"),
|
||||
),
|
||||
mcp.WithString("function_name",
|
||||
mcp.Required(),
|
||||
mcp.Description("Exact function name or substring to search for"),
|
||||
),
|
||||
mcp.WithNumber("top_n",
|
||||
mcp.Description("Maximum number of callers to return (default: 10)"),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// CallersHandler returns the MCP handler for get_callers.
|
||||
func CallersHandler(c *cache.Cache) server.ToolHandlerFunc {
|
||||
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
filePath := req.GetString("file_path", "")
|
||||
if filePath == "" {
|
||||
return mcp.NewToolResultError("file_path is required"), nil
|
||||
}
|
||||
name := req.GetString("function_name", "")
|
||||
if name == "" {
|
||||
return mcp.NewToolResultError("function_name is required"), nil
|
||||
}
|
||||
topN := req.GetInt("top_n", 10)
|
||||
if topN <= 0 {
|
||||
topN = 10
|
||||
}
|
||||
p, err := loadProfile(filePath, c)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
return mcp.NewToolResultText(Callers(p, name, topN)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Callers formats the callers of name in p. Exported for testing.
|
||||
func Callers(p *cachegrind.Profile, name string, topN int) string {
|
||||
fns, errMsg := findFunctions(p, name)
|
||||
if errMsg != "" {
|
||||
return errMsg
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for _, fn := range fns {
|
||||
fmt.Fprintf(&sb, "Callers of %q [%s]\n", fn.Name, fn.File)
|
||||
if len(fn.CalledBy) == 0 {
|
||||
fmt.Fprintf(&sb, " no incoming calls recorded\n\n")
|
||||
continue
|
||||
}
|
||||
|
||||
callers := sortedCalls(fn.CalledBy)
|
||||
if topN < len(callers) {
|
||||
callers = callers[:topN]
|
||||
}
|
||||
fmt.Fprintf(&sb, " called by %d function(s):\n\n", len(fn.CalledBy))
|
||||
for i, call := range callers {
|
||||
fmt.Fprintf(&sb, " %3d. %-60s calls=%d %s\n",
|
||||
i+1, call.Caller.Name, call.Count, formatCosts(call.Costs, p.Events))
|
||||
fmt.Fprintf(&sb, " %s\n", call.Caller.File)
|
||||
}
|
||||
fmt.Fprintln(&sb)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// sortedCalls returns a copy of calls sorted by Costs[0] descending.
|
||||
func sortedCalls(calls []*cachegrind.Call) []*cachegrind.Call {
|
||||
out := make([]*cachegrind.Call, len(calls))
|
||||
copy(out, calls)
|
||||
sortCallSlice(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func sortCallSlice(calls []*cachegrind.Call) {
|
||||
for i := 1; i < len(calls); i++ {
|
||||
for j := i; j > 0; j-- {
|
||||
ci, cj := int64(0), int64(0)
|
||||
if len(calls[j].Costs) > 0 {
|
||||
ci = calls[j].Costs[0]
|
||||
}
|
||||
if len(calls[j-1].Costs) > 0 {
|
||||
cj = calls[j-1].Costs[0]
|
||||
}
|
||||
if ci > cj {
|
||||
calls[j], calls[j-1] = calls[j-1], calls[j]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -59,3 +59,42 @@ func TestAnalyze_TopNLimited(t *testing.T) {
|
|||
|
||||
// require is used to avoid unused import error if tests change
|
||||
var _ = require.New
|
||||
|
||||
func TestCallers_Found(t *testing.T) {
|
||||
p := makeTestProfile()
|
||||
result := tools.Callers(p, "query", 10)
|
||||
|
||||
assert.Contains(t, result, "main")
|
||||
assert.Contains(t, result, "calls=1")
|
||||
require.NotContains(t, result, "not found")
|
||||
}
|
||||
|
||||
func TestCallers_ContainsFallback(t *testing.T) {
|
||||
p := makeTestProfile()
|
||||
// "uer" is a substring of "query"
|
||||
result := tools.Callers(p, "uer", 10)
|
||||
|
||||
assert.Contains(t, result, "main")
|
||||
}
|
||||
|
||||
func TestCallers_NotFound(t *testing.T) {
|
||||
p := makeTestProfile()
|
||||
result := tools.Callers(p, "nonexistent_xyz", 10)
|
||||
|
||||
assert.Contains(t, result, "not found")
|
||||
}
|
||||
|
||||
func TestCallees_Found(t *testing.T) {
|
||||
p := makeTestProfile()
|
||||
result := tools.Callees(p, "query", 10)
|
||||
|
||||
assert.Contains(t, result, "connect")
|
||||
assert.Contains(t, result, "calls=2")
|
||||
}
|
||||
|
||||
func TestCallees_NotFound(t *testing.T) {
|
||||
p := makeTestProfile()
|
||||
result := tools.Callees(p, "connect", 10) // connect has no callees
|
||||
|
||||
assert.Contains(t, result, "no outgoing calls")
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue