fix: tolerate unknown fields in MCP protocol messages and defer credential loading
Claude Code sends extra fields (e.g. "title") in initialize params that caused the server to reject the request due to DisallowUnknownFields. Use lenient JSON decoding for protocol messages while keeping strict validation for tool arguments. Also defer KWallet credential loading from server startup to tool invocation time, and negotiate protocol versions per MCP spec instead of rejecting unknown ones. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
5f9e0e3a5a
commit
92fc30cb2d
3 changed files with 115 additions and 42 deletions
|
|
@ -69,17 +69,26 @@ func TestExecuteSetupWritesWalletGuidanceAndReturnsExitCodeOne(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExecuteMCPWritesMissingCredentialGuidanceAndReturnsExitCodeOne(t *testing.T) {
|
||||
func TestExecuteMCPReturnsMissingCredentialErrorOnToolCall(t *testing.T) {
|
||||
store := &entrypointStoreStub{loadErr: kwallet.ErrCredentialNotFound}
|
||||
mail := entrypointMailServiceStub{}
|
||||
runner := mcpserver.NewRunner(mcpserver.New(store, mail), nil, &bytes.Buffer{}, &bytes.Buffer{})
|
||||
input := bytes.NewBufferString(
|
||||
"{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test\",\"version\":\"1.0.0\"}}}\n" +
|
||||
"{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n" +
|
||||
"{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_mailboxes\"}}\n",
|
||||
)
|
||||
output := &bytes.Buffer{}
|
||||
runner := mcpserver.NewRunner(mcpserver.New(store, mail), input, output, &bytes.Buffer{})
|
||||
app := NewAppWithDependencies(nil, store, runner, nil)
|
||||
|
||||
stderr := &bytes.Buffer{}
|
||||
if code := Execute(app, []string{"mcp"}, stderr); code != 1 {
|
||||
t.Fatalf("expected exit code 1, got %d", code)
|
||||
if code := Execute(app, []string{"mcp"}, stderr); code != 0 {
|
||||
t.Fatalf("expected exit code 0, got %d; stderr: %s", code, stderr.String())
|
||||
}
|
||||
if got := stderr.String(); got != "credentials not configured; run `email-mcp setup`\n" {
|
||||
t.Fatalf("unexpected stderr: %q", got)
|
||||
|
||||
// Verify the credential error appears in the tool call response
|
||||
got := output.String()
|
||||
if !bytes.Contains([]byte(got), []byte("credentials not configured")) {
|
||||
t.Fatalf("expected credential error in output, got %q", got)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -221,11 +221,6 @@ func (s Server) GetMessage(ctx context.Context, mailbox string, uid uint32) (ima
|
|||
}
|
||||
|
||||
func (r Runner) Run(ctx context.Context) error {
|
||||
cred, err := r.server.loadCredential(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stopCancelRead := r.closeInputOnCancel(ctx)
|
||||
defer stopCancelRead()
|
||||
|
||||
|
|
@ -255,7 +250,7 @@ func (r Runner) Run(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
if err := r.handleRPCRequest(ctx, encoder, &session, cred, request); err != nil {
|
||||
if err := r.handleRPCRequest(ctx, encoder, &session, request); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return err
|
||||
}
|
||||
|
|
@ -273,7 +268,7 @@ type runnerSession struct {
|
|||
protocolVersion string
|
||||
}
|
||||
|
||||
func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, cred secretstore.Credential, request rpcRequest) error {
|
||||
func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, request rpcRequest) error {
|
||||
if request.JSONRPC != jsonRPCVersion {
|
||||
return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "jsonrpc must be 2.0", nil)
|
||||
}
|
||||
|
|
@ -320,7 +315,7 @@ func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, ses
|
|||
if isNotification {
|
||||
return nil
|
||||
}
|
||||
return r.handleToolsCall(ctx, encoder, cred, request)
|
||||
return r.handleToolsCall(ctx, encoder, request)
|
||||
default:
|
||||
if isNotification {
|
||||
return nil
|
||||
|
|
@ -331,20 +326,14 @@ func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, ses
|
|||
|
||||
func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession, request rpcRequest) error {
|
||||
var params initializeParams
|
||||
if err := decodeArguments(request.Params, ¶ms); err != nil {
|
||||
if err := decodeParams(request.Params, ¶ms); err != nil {
|
||||
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil)
|
||||
}
|
||||
if params.ProtocolVersion == "" {
|
||||
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "protocolVersion is required", nil)
|
||||
}
|
||||
|
||||
negotiatedVersion, ok := negotiateProtocolVersion(params.ProtocolVersion)
|
||||
if !ok {
|
||||
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "Unsupported protocol version", map[string]any{
|
||||
"supported": supportedProtocolVersions,
|
||||
"requested": params.ProtocolVersion,
|
||||
})
|
||||
}
|
||||
negotiatedVersion := negotiateProtocolVersion(params.ProtocolVersion)
|
||||
|
||||
session.initialized = true
|
||||
session.protocolVersion = negotiatedVersion
|
||||
|
|
@ -361,15 +350,28 @@ func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession,
|
|||
})
|
||||
}
|
||||
|
||||
func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, cred secretstore.Credential, request rpcRequest) error {
|
||||
func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, request rpcRequest) error {
|
||||
var params toolsCallParams
|
||||
if err := decodeArguments(request.Params, ¶ms); err != nil {
|
||||
if err := decodeParams(request.Params, ¶ms); err != nil {
|
||||
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil)
|
||||
}
|
||||
if strings.TrimSpace(params.Name) == "" {
|
||||
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "name is required", nil)
|
||||
}
|
||||
|
||||
cred, err := r.server.loadCredential(ctx)
|
||||
if err != nil {
|
||||
return writeRPCResult(encoder, request.ID, map[string]any{
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": err.Error(),
|
||||
},
|
||||
},
|
||||
"isError": true,
|
||||
})
|
||||
}
|
||||
|
||||
result, err := r.server.handleTool(ctx, cred, params.Name, params.Arguments)
|
||||
if err != nil {
|
||||
var invalidErr *invalidParamsError
|
||||
|
|
@ -490,6 +492,13 @@ func (s Server) loadCredential(ctx context.Context) (secretstore.Credential, err
|
|||
return cred, nil
|
||||
}
|
||||
|
||||
func decodeParams(raw json.RawMessage, dest any) error {
|
||||
if len(raw) == 0 {
|
||||
raw = []byte("{}")
|
||||
}
|
||||
return json.Unmarshal(raw, dest)
|
||||
}
|
||||
|
||||
func decodeArguments(raw json.RawMessage, dest any) error {
|
||||
if len(raw) == 0 {
|
||||
raw = []byte("{}")
|
||||
|
|
@ -570,16 +579,13 @@ func (r Runner) closeInputOnCancel(ctx context.Context) func() {
|
|||
}
|
||||
}
|
||||
|
||||
func negotiateProtocolVersion(requested string) (string, bool) {
|
||||
func negotiateProtocolVersion(requested string) string {
|
||||
for _, supported := range supportedProtocolVersions {
|
||||
if requested == supported {
|
||||
return supported, true
|
||||
return supported
|
||||
}
|
||||
}
|
||||
if len(supportedProtocolVersions) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return supportedProtocolVersions[0], false
|
||||
return supportedProtocolVersions[0]
|
||||
}
|
||||
|
||||
func writeRPCResult(encoder *json.Encoder, id json.RawMessage, result any) error {
|
||||
|
|
|
|||
|
|
@ -345,6 +345,11 @@ func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) {
|
|||
store := &storeStub{
|
||||
loadErr: kwallet.ErrCredentialNotFound,
|
||||
}
|
||||
input := bytes.NewBufferString(
|
||||
"{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test\",\"version\":\"1.0.0\"}}}\n" +
|
||||
"{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n" +
|
||||
"{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_mailboxes\"}}\n",
|
||||
)
|
||||
output := &bytes.Buffer{}
|
||||
runner := NewRunner(New(store, serviceStub{
|
||||
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
|
||||
|
|
@ -359,14 +364,38 @@ func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) {
|
|||
t.Fatal("GetMessage should not be called")
|
||||
return imapclient.Message{}, nil
|
||||
},
|
||||
}), bytes.NewBuffer(nil), output, &bytes.Buffer{})
|
||||
}), input, output, &bytes.Buffer{})
|
||||
|
||||
err := runner.Run(context.Background())
|
||||
if !errors.Is(err, ErrCredentialsNotConfigured) {
|
||||
t.Fatalf("expected missing credential error, got %v", err)
|
||||
if err := runner.Run(context.Background()); err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
if output.Len() != 0 {
|
||||
t.Fatalf("expected no output when credentials are missing, got %q", output.String())
|
||||
|
||||
decoder := json.NewDecoder(output)
|
||||
|
||||
// Skip initialize response
|
||||
var initResp json.RawMessage
|
||||
if err := decoder.Decode(&initResp); err != nil {
|
||||
t.Fatalf("failed to decode initialize response: %v", err)
|
||||
}
|
||||
|
||||
// Check tool call response contains credential error
|
||||
var toolResp struct {
|
||||
ID int `json:"id"`
|
||||
Result struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
IsError bool `json:"isError"`
|
||||
} `json:"result"`
|
||||
}
|
||||
if err := decoder.Decode(&toolResp); err != nil {
|
||||
t.Fatalf("failed to decode tool call response: %v", err)
|
||||
}
|
||||
if !toolResp.Result.IsError {
|
||||
t.Fatal("expected isError true for missing credentials")
|
||||
}
|
||||
if len(toolResp.Result.Content) == 0 || toolResp.Result.Content[0].Text != ErrCredentialsNotConfigured.Error() {
|
||||
t.Fatalf("expected credential error message, got %#v", toolResp.Result)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -374,6 +403,11 @@ func TestRunnerRunReturnsFriendlyMissingCredentialErrorWhenStoreAlreadyTranslate
|
|||
store := &storeStub{
|
||||
loadErr: ErrCredentialsNotConfigured,
|
||||
}
|
||||
input := bytes.NewBufferString(
|
||||
"{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test\",\"version\":\"1.0.0\"}}}\n" +
|
||||
"{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n" +
|
||||
"{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_mailboxes\"}}\n",
|
||||
)
|
||||
output := &bytes.Buffer{}
|
||||
runner := NewRunner(New(store, serviceStub{
|
||||
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
|
||||
|
|
@ -388,14 +422,38 @@ func TestRunnerRunReturnsFriendlyMissingCredentialErrorWhenStoreAlreadyTranslate
|
|||
t.Fatal("GetMessage should not be called")
|
||||
return imapclient.Message{}, nil
|
||||
},
|
||||
}), bytes.NewBuffer(nil), output, &bytes.Buffer{})
|
||||
}), input, output, &bytes.Buffer{})
|
||||
|
||||
err := runner.Run(context.Background())
|
||||
if !errors.Is(err, ErrCredentialsNotConfigured) {
|
||||
t.Fatalf("expected missing credential error, got %v", err)
|
||||
if err := runner.Run(context.Background()); err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
if output.Len() != 0 {
|
||||
t.Fatalf("expected no output when credentials are missing, got %q", output.String())
|
||||
|
||||
decoder := json.NewDecoder(output)
|
||||
|
||||
// Skip initialize response
|
||||
var initResp json.RawMessage
|
||||
if err := decoder.Decode(&initResp); err != nil {
|
||||
t.Fatalf("failed to decode initialize response: %v", err)
|
||||
}
|
||||
|
||||
// Check tool call response contains credential error
|
||||
var toolResp struct {
|
||||
ID int `json:"id"`
|
||||
Result struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
IsError bool `json:"isError"`
|
||||
} `json:"result"`
|
||||
}
|
||||
if err := decoder.Decode(&toolResp); err != nil {
|
||||
t.Fatalf("failed to decode tool call response: %v", err)
|
||||
}
|
||||
if !toolResp.Result.IsError {
|
||||
t.Fatal("expected isError true for missing credentials")
|
||||
}
|
||||
if len(toolResp.Result.Content) == 0 || toolResp.Result.Content[0].Text != ErrCredentialsNotConfigured.Error() {
|
||||
t.Fatalf("expected credential error message, got %#v", toolResp.Result)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue