fix: tolerate unknown fields in MCP protocol messages and defer credential loading

Claude Code sends extra fields (e.g. "title") in initialize params that
caused the server to reject the request due to DisallowUnknownFields.
Use lenient JSON decoding for protocol messages while keeping strict
validation for tool arguments. Also defer KWallet credential loading
from server startup to tool invocation time, and negotiate protocol
versions per MCP spec instead of rejecting unknown ones.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
thibaud-leclere 2026-04-10 15:46:53 +02:00
parent 5f9e0e3a5a
commit 92fc30cb2d
3 changed files with 115 additions and 42 deletions

View file

@ -69,17 +69,26 @@ func TestExecuteSetupWritesWalletGuidanceAndReturnsExitCodeOne(t *testing.T) {
} }
} }
func TestExecuteMCPWritesMissingCredentialGuidanceAndReturnsExitCodeOne(t *testing.T) { func TestExecuteMCPReturnsMissingCredentialErrorOnToolCall(t *testing.T) {
store := &entrypointStoreStub{loadErr: kwallet.ErrCredentialNotFound} store := &entrypointStoreStub{loadErr: kwallet.ErrCredentialNotFound}
mail := entrypointMailServiceStub{} mail := entrypointMailServiceStub{}
runner := mcpserver.NewRunner(mcpserver.New(store, mail), nil, &bytes.Buffer{}, &bytes.Buffer{}) input := bytes.NewBufferString(
"{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test\",\"version\":\"1.0.0\"}}}\n" +
"{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n" +
"{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_mailboxes\"}}\n",
)
output := &bytes.Buffer{}
runner := mcpserver.NewRunner(mcpserver.New(store, mail), input, output, &bytes.Buffer{})
app := NewAppWithDependencies(nil, store, runner, nil) app := NewAppWithDependencies(nil, store, runner, nil)
stderr := &bytes.Buffer{} stderr := &bytes.Buffer{}
if code := Execute(app, []string{"mcp"}, stderr); code != 1 { if code := Execute(app, []string{"mcp"}, stderr); code != 0 {
t.Fatalf("expected exit code 1, got %d", code) t.Fatalf("expected exit code 0, got %d; stderr: %s", code, stderr.String())
} }
if got := stderr.String(); got != "credentials not configured; run `email-mcp setup`\n" {
t.Fatalf("unexpected stderr: %q", got) // Verify the credential error appears in the tool call response
got := output.String()
if !bytes.Contains([]byte(got), []byte("credentials not configured")) {
t.Fatalf("expected credential error in output, got %q", got)
} }
} }

View file

@ -221,11 +221,6 @@ func (s Server) GetMessage(ctx context.Context, mailbox string, uid uint32) (ima
} }
func (r Runner) Run(ctx context.Context) error { func (r Runner) Run(ctx context.Context) error {
cred, err := r.server.loadCredential(ctx)
if err != nil {
return err
}
stopCancelRead := r.closeInputOnCancel(ctx) stopCancelRead := r.closeInputOnCancel(ctx)
defer stopCancelRead() defer stopCancelRead()
@ -255,7 +250,7 @@ func (r Runner) Run(ctx context.Context) error {
return nil return nil
} }
if err := r.handleRPCRequest(ctx, encoder, &session, cred, request); err != nil { if err := r.handleRPCRequest(ctx, encoder, &session, request); err != nil {
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
return err return err
} }
@ -273,7 +268,7 @@ type runnerSession struct {
protocolVersion string protocolVersion string
} }
func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, cred secretstore.Credential, request rpcRequest) error { func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, request rpcRequest) error {
if request.JSONRPC != jsonRPCVersion { if request.JSONRPC != jsonRPCVersion {
return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "jsonrpc must be 2.0", nil) return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "jsonrpc must be 2.0", nil)
} }
@ -320,7 +315,7 @@ func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, ses
if isNotification { if isNotification {
return nil return nil
} }
return r.handleToolsCall(ctx, encoder, cred, request) return r.handleToolsCall(ctx, encoder, request)
default: default:
if isNotification { if isNotification {
return nil return nil
@ -331,20 +326,14 @@ func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, ses
func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession, request rpcRequest) error { func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession, request rpcRequest) error {
var params initializeParams var params initializeParams
if err := decodeArguments(request.Params, &params); err != nil { if err := decodeParams(request.Params, &params); err != nil {
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil) return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil)
} }
if params.ProtocolVersion == "" { if params.ProtocolVersion == "" {
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "protocolVersion is required", nil) return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "protocolVersion is required", nil)
} }
negotiatedVersion, ok := negotiateProtocolVersion(params.ProtocolVersion) negotiatedVersion := negotiateProtocolVersion(params.ProtocolVersion)
if !ok {
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "Unsupported protocol version", map[string]any{
"supported": supportedProtocolVersions,
"requested": params.ProtocolVersion,
})
}
session.initialized = true session.initialized = true
session.protocolVersion = negotiatedVersion session.protocolVersion = negotiatedVersion
@ -361,15 +350,28 @@ func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession,
}) })
} }
func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, cred secretstore.Credential, request rpcRequest) error { func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, request rpcRequest) error {
var params toolsCallParams var params toolsCallParams
if err := decodeArguments(request.Params, &params); err != nil { if err := decodeParams(request.Params, &params); err != nil {
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil) return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil)
} }
if strings.TrimSpace(params.Name) == "" { if strings.TrimSpace(params.Name) == "" {
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "name is required", nil) return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "name is required", nil)
} }
cred, err := r.server.loadCredential(ctx)
if err != nil {
return writeRPCResult(encoder, request.ID, map[string]any{
"content": []map[string]any{
{
"type": "text",
"text": err.Error(),
},
},
"isError": true,
})
}
result, err := r.server.handleTool(ctx, cred, params.Name, params.Arguments) result, err := r.server.handleTool(ctx, cred, params.Name, params.Arguments)
if err != nil { if err != nil {
var invalidErr *invalidParamsError var invalidErr *invalidParamsError
@ -490,6 +492,13 @@ func (s Server) loadCredential(ctx context.Context) (secretstore.Credential, err
return cred, nil return cred, nil
} }
func decodeParams(raw json.RawMessage, dest any) error {
if len(raw) == 0 {
raw = []byte("{}")
}
return json.Unmarshal(raw, dest)
}
func decodeArguments(raw json.RawMessage, dest any) error { func decodeArguments(raw json.RawMessage, dest any) error {
if len(raw) == 0 { if len(raw) == 0 {
raw = []byte("{}") raw = []byte("{}")
@ -570,16 +579,13 @@ func (r Runner) closeInputOnCancel(ctx context.Context) func() {
} }
} }
func negotiateProtocolVersion(requested string) (string, bool) { func negotiateProtocolVersion(requested string) string {
for _, supported := range supportedProtocolVersions { for _, supported := range supportedProtocolVersions {
if requested == supported { if requested == supported {
return supported, true return supported
} }
} }
if len(supportedProtocolVersions) == 0 { return supportedProtocolVersions[0]
return "", false
}
return supportedProtocolVersions[0], false
} }
func writeRPCResult(encoder *json.Encoder, id json.RawMessage, result any) error { func writeRPCResult(encoder *json.Encoder, id json.RawMessage, result any) error {

View file

@ -345,6 +345,11 @@ func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) {
store := &storeStub{ store := &storeStub{
loadErr: kwallet.ErrCredentialNotFound, loadErr: kwallet.ErrCredentialNotFound,
} }
input := bytes.NewBufferString(
"{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test\",\"version\":\"1.0.0\"}}}\n" +
"{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n" +
"{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_mailboxes\"}}\n",
)
output := &bytes.Buffer{} output := &bytes.Buffer{}
runner := NewRunner(New(store, serviceStub{ runner := NewRunner(New(store, serviceStub{
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
@ -359,14 +364,38 @@ func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) {
t.Fatal("GetMessage should not be called") t.Fatal("GetMessage should not be called")
return imapclient.Message{}, nil return imapclient.Message{}, nil
}, },
}), bytes.NewBuffer(nil), output, &bytes.Buffer{}) }), input, output, &bytes.Buffer{})
err := runner.Run(context.Background()) if err := runner.Run(context.Background()); err != nil {
if !errors.Is(err, ErrCredentialsNotConfigured) { t.Fatalf("Run returned error: %v", err)
t.Fatalf("expected missing credential error, got %v", err)
} }
if output.Len() != 0 {
t.Fatalf("expected no output when credentials are missing, got %q", output.String()) decoder := json.NewDecoder(output)
// Skip initialize response
var initResp json.RawMessage
if err := decoder.Decode(&initResp); err != nil {
t.Fatalf("failed to decode initialize response: %v", err)
}
// Check tool call response contains credential error
var toolResp struct {
ID int `json:"id"`
Result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
IsError bool `json:"isError"`
} `json:"result"`
}
if err := decoder.Decode(&toolResp); err != nil {
t.Fatalf("failed to decode tool call response: %v", err)
}
if !toolResp.Result.IsError {
t.Fatal("expected isError true for missing credentials")
}
if len(toolResp.Result.Content) == 0 || toolResp.Result.Content[0].Text != ErrCredentialsNotConfigured.Error() {
t.Fatalf("expected credential error message, got %#v", toolResp.Result)
} }
} }
@ -374,6 +403,11 @@ func TestRunnerRunReturnsFriendlyMissingCredentialErrorWhenStoreAlreadyTranslate
store := &storeStub{ store := &storeStub{
loadErr: ErrCredentialsNotConfigured, loadErr: ErrCredentialsNotConfigured,
} }
input := bytes.NewBufferString(
"{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test\",\"version\":\"1.0.0\"}}}\n" +
"{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n" +
"{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_mailboxes\"}}\n",
)
output := &bytes.Buffer{} output := &bytes.Buffer{}
runner := NewRunner(New(store, serviceStub{ runner := NewRunner(New(store, serviceStub{
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
@ -388,14 +422,38 @@ func TestRunnerRunReturnsFriendlyMissingCredentialErrorWhenStoreAlreadyTranslate
t.Fatal("GetMessage should not be called") t.Fatal("GetMessage should not be called")
return imapclient.Message{}, nil return imapclient.Message{}, nil
}, },
}), bytes.NewBuffer(nil), output, &bytes.Buffer{}) }), input, output, &bytes.Buffer{})
err := runner.Run(context.Background()) if err := runner.Run(context.Background()); err != nil {
if !errors.Is(err, ErrCredentialsNotConfigured) { t.Fatalf("Run returned error: %v", err)
t.Fatalf("expected missing credential error, got %v", err)
} }
if output.Len() != 0 {
t.Fatalf("expected no output when credentials are missing, got %q", output.String()) decoder := json.NewDecoder(output)
// Skip initialize response
var initResp json.RawMessage
if err := decoder.Decode(&initResp); err != nil {
t.Fatalf("failed to decode initialize response: %v", err)
}
// Check tool call response contains credential error
var toolResp struct {
ID int `json:"id"`
Result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
IsError bool `json:"isError"`
} `json:"result"`
}
if err := decoder.Decode(&toolResp); err != nil {
t.Fatalf("failed to decode tool call response: %v", err)
}
if !toolResp.Result.IsError {
t.Fatal("expected isError true for missing credentials")
}
if len(toolResp.Result.Content) == 0 || toolResp.Result.Content[0].Text != ErrCredentialsNotConfigured.Error() {
t.Fatalf("expected credential error message, got %#v", toolResp.Result)
} }
} }