From 92fc30cb2dce20a19ab56eace7e4ce28f9c260b4 Mon Sep 17 00:00:00 2001 From: thibaud-leclere Date: Fri, 10 Apr 2026 15:46:53 +0200 Subject: [PATCH] fix: tolerate unknown fields in MCP protocol messages and defer credential loading Claude Code sends extra fields (e.g. "title") in initialize params that caused the server to reject the request due to DisallowUnknownFields. Use lenient JSON decoding for protocol messages while keeping strict validation for tool arguments. Also defer KWallet credential loading from server startup to tool invocation time, and negotiate protocol versions per MCP spec instead of rejecting unknown ones. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/cli/integration_test.go | 21 +++++--- internal/mcpserver/server.go | 54 +++++++++++--------- internal/mcpserver/server_test.go | 82 ++++++++++++++++++++++++++----- 3 files changed, 115 insertions(+), 42 deletions(-) diff --git a/internal/cli/integration_test.go b/internal/cli/integration_test.go index 8e341c8..5824782 100644 --- a/internal/cli/integration_test.go +++ b/internal/cli/integration_test.go @@ -69,17 +69,26 @@ func TestExecuteSetupWritesWalletGuidanceAndReturnsExitCodeOne(t *testing.T) { } } -func TestExecuteMCPWritesMissingCredentialGuidanceAndReturnsExitCodeOne(t *testing.T) { +func TestExecuteMCPReturnsMissingCredentialErrorOnToolCall(t *testing.T) { store := &entrypointStoreStub{loadErr: kwallet.ErrCredentialNotFound} mail := entrypointMailServiceStub{} - runner := mcpserver.NewRunner(mcpserver.New(store, mail), nil, &bytes.Buffer{}, &bytes.Buffer{}) + input := bytes.NewBufferString( + "{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test\",\"version\":\"1.0.0\"}}}\n" + + "{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n" + + "{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_mailboxes\"}}\n", + ) + output := &bytes.Buffer{} + runner := mcpserver.NewRunner(mcpserver.New(store, mail), input, output, &bytes.Buffer{}) app := NewAppWithDependencies(nil, store, runner, nil) stderr := &bytes.Buffer{} - if code := Execute(app, []string{"mcp"}, stderr); code != 1 { - t.Fatalf("expected exit code 1, got %d", code) + if code := Execute(app, []string{"mcp"}, stderr); code != 0 { + t.Fatalf("expected exit code 0, got %d; stderr: %s", code, stderr.String()) } - if got := stderr.String(); got != "credentials not configured; run `email-mcp setup`\n" { - t.Fatalf("unexpected stderr: %q", got) + + // Verify the credential error appears in the tool call response + got := output.String() + if !bytes.Contains([]byte(got), []byte("credentials not configured")) { + t.Fatalf("expected credential error in output, got %q", got) } } diff --git a/internal/mcpserver/server.go b/internal/mcpserver/server.go index 8223f59..9d49c84 100644 --- a/internal/mcpserver/server.go +++ b/internal/mcpserver/server.go @@ -221,11 +221,6 @@ func (s Server) GetMessage(ctx context.Context, mailbox string, uid uint32) (ima } func (r Runner) Run(ctx context.Context) error { - cred, err := r.server.loadCredential(ctx) - if err != nil { - return err - } - stopCancelRead := r.closeInputOnCancel(ctx) defer stopCancelRead() @@ -255,7 +250,7 @@ func (r Runner) Run(ctx context.Context) error { return nil } - if err := r.handleRPCRequest(ctx, encoder, &session, cred, request); err != nil { + if err := r.handleRPCRequest(ctx, encoder, &session, request); err != nil { if errors.Is(err, context.Canceled) { return err } @@ -273,7 +268,7 @@ type runnerSession struct { protocolVersion string } -func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, cred secretstore.Credential, request rpcRequest) error { +func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, request rpcRequest) error { if request.JSONRPC != jsonRPCVersion { return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "jsonrpc must be 2.0", nil) } @@ -320,7 +315,7 @@ func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, ses if isNotification { return nil } - return r.handleToolsCall(ctx, encoder, cred, request) + return r.handleToolsCall(ctx, encoder, request) default: if isNotification { return nil @@ -331,20 +326,14 @@ func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, ses func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession, request rpcRequest) error { var params initializeParams - if err := decodeArguments(request.Params, ¶ms); err != nil { + if err := decodeParams(request.Params, ¶ms); err != nil { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil) } if params.ProtocolVersion == "" { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "protocolVersion is required", nil) } - negotiatedVersion, ok := negotiateProtocolVersion(params.ProtocolVersion) - if !ok { - return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "Unsupported protocol version", map[string]any{ - "supported": supportedProtocolVersions, - "requested": params.ProtocolVersion, - }) - } + negotiatedVersion := negotiateProtocolVersion(params.ProtocolVersion) session.initialized = true session.protocolVersion = negotiatedVersion @@ -361,15 +350,28 @@ func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession, }) } -func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, cred secretstore.Credential, request rpcRequest) error { +func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, request rpcRequest) error { var params toolsCallParams - if err := decodeArguments(request.Params, ¶ms); err != nil { + if err := decodeParams(request.Params, ¶ms); err != nil { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil) } if strings.TrimSpace(params.Name) == "" { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "name is required", nil) } + cred, err := r.server.loadCredential(ctx) + if err != nil { + return writeRPCResult(encoder, request.ID, map[string]any{ + "content": []map[string]any{ + { + "type": "text", + "text": err.Error(), + }, + }, + "isError": true, + }) + } + result, err := r.server.handleTool(ctx, cred, params.Name, params.Arguments) if err != nil { var invalidErr *invalidParamsError @@ -490,6 +492,13 @@ func (s Server) loadCredential(ctx context.Context) (secretstore.Credential, err return cred, nil } +func decodeParams(raw json.RawMessage, dest any) error { + if len(raw) == 0 { + raw = []byte("{}") + } + return json.Unmarshal(raw, dest) +} + func decodeArguments(raw json.RawMessage, dest any) error { if len(raw) == 0 { raw = []byte("{}") @@ -570,16 +579,13 @@ func (r Runner) closeInputOnCancel(ctx context.Context) func() { } } -func negotiateProtocolVersion(requested string) (string, bool) { +func negotiateProtocolVersion(requested string) string { for _, supported := range supportedProtocolVersions { if requested == supported { - return supported, true + return supported } } - if len(supportedProtocolVersions) == 0 { - return "", false - } - return supportedProtocolVersions[0], false + return supportedProtocolVersions[0] } func writeRPCResult(encoder *json.Encoder, id json.RawMessage, result any) error { diff --git a/internal/mcpserver/server_test.go b/internal/mcpserver/server_test.go index 523ed8f..1039920 100644 --- a/internal/mcpserver/server_test.go +++ b/internal/mcpserver/server_test.go @@ -345,6 +345,11 @@ func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) { store := &storeStub{ loadErr: kwallet.ErrCredentialNotFound, } + input := bytes.NewBufferString( + "{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test\",\"version\":\"1.0.0\"}}}\n" + + "{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n" + + "{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_mailboxes\"}}\n", + ) output := &bytes.Buffer{} runner := NewRunner(New(store, serviceStub{ listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { @@ -359,14 +364,38 @@ func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) { t.Fatal("GetMessage should not be called") return imapclient.Message{}, nil }, - }), bytes.NewBuffer(nil), output, &bytes.Buffer{}) + }), input, output, &bytes.Buffer{}) - err := runner.Run(context.Background()) - if !errors.Is(err, ErrCredentialsNotConfigured) { - t.Fatalf("expected missing credential error, got %v", err) + if err := runner.Run(context.Background()); err != nil { + t.Fatalf("Run returned error: %v", err) } - if output.Len() != 0 { - t.Fatalf("expected no output when credentials are missing, got %q", output.String()) + + decoder := json.NewDecoder(output) + + // Skip initialize response + var initResp json.RawMessage + if err := decoder.Decode(&initResp); err != nil { + t.Fatalf("failed to decode initialize response: %v", err) + } + + // Check tool call response contains credential error + var toolResp struct { + ID int `json:"id"` + Result struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + IsError bool `json:"isError"` + } `json:"result"` + } + if err := decoder.Decode(&toolResp); err != nil { + t.Fatalf("failed to decode tool call response: %v", err) + } + if !toolResp.Result.IsError { + t.Fatal("expected isError true for missing credentials") + } + if len(toolResp.Result.Content) == 0 || toolResp.Result.Content[0].Text != ErrCredentialsNotConfigured.Error() { + t.Fatalf("expected credential error message, got %#v", toolResp.Result) } } @@ -374,6 +403,11 @@ func TestRunnerRunReturnsFriendlyMissingCredentialErrorWhenStoreAlreadyTranslate store := &storeStub{ loadErr: ErrCredentialsNotConfigured, } + input := bytes.NewBufferString( + "{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test\",\"version\":\"1.0.0\"}}}\n" + + "{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n" + + "{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_mailboxes\"}}\n", + ) output := &bytes.Buffer{} runner := NewRunner(New(store, serviceStub{ listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { @@ -388,14 +422,38 @@ func TestRunnerRunReturnsFriendlyMissingCredentialErrorWhenStoreAlreadyTranslate t.Fatal("GetMessage should not be called") return imapclient.Message{}, nil }, - }), bytes.NewBuffer(nil), output, &bytes.Buffer{}) + }), input, output, &bytes.Buffer{}) - err := runner.Run(context.Background()) - if !errors.Is(err, ErrCredentialsNotConfigured) { - t.Fatalf("expected missing credential error, got %v", err) + if err := runner.Run(context.Background()); err != nil { + t.Fatalf("Run returned error: %v", err) } - if output.Len() != 0 { - t.Fatalf("expected no output when credentials are missing, got %q", output.String()) + + decoder := json.NewDecoder(output) + + // Skip initialize response + var initResp json.RawMessage + if err := decoder.Decode(&initResp); err != nil { + t.Fatalf("failed to decode initialize response: %v", err) + } + + // Check tool call response contains credential error + var toolResp struct { + ID int `json:"id"` + Result struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + IsError bool `json:"isError"` + } `json:"result"` + } + if err := decoder.Decode(&toolResp); err != nil { + t.Fatalf("failed to decode tool call response: %v", err) + } + if !toolResp.Result.IsError { + t.Fatal("expected isError true for missing credentials") + } + if len(toolResp.Result.Content) == 0 || toolResp.Result.Content[0].Text != ErrCredentialsNotConfigured.Error() { + t.Fatalf("expected credential error message, got %#v", toolResp.Result) } }