diff --git a/internal/cli/integration_test.go b/internal/cli/integration_test.go index 8e341c8..5824782 100644 --- a/internal/cli/integration_test.go +++ b/internal/cli/integration_test.go @@ -69,17 +69,26 @@ func TestExecuteSetupWritesWalletGuidanceAndReturnsExitCodeOne(t *testing.T) { } } -func TestExecuteMCPWritesMissingCredentialGuidanceAndReturnsExitCodeOne(t *testing.T) { +func TestExecuteMCPReturnsMissingCredentialErrorOnToolCall(t *testing.T) { store := &entrypointStoreStub{loadErr: kwallet.ErrCredentialNotFound} mail := entrypointMailServiceStub{} - runner := mcpserver.NewRunner(mcpserver.New(store, mail), nil, &bytes.Buffer{}, &bytes.Buffer{}) + input := bytes.NewBufferString( + "{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test\",\"version\":\"1.0.0\"}}}\n" + + "{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n" + + "{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_mailboxes\"}}\n", + ) + output := &bytes.Buffer{} + runner := mcpserver.NewRunner(mcpserver.New(store, mail), input, output, &bytes.Buffer{}) app := NewAppWithDependencies(nil, store, runner, nil) stderr := &bytes.Buffer{} - if code := Execute(app, []string{"mcp"}, stderr); code != 1 { - t.Fatalf("expected exit code 1, got %d", code) + if code := Execute(app, []string{"mcp"}, stderr); code != 0 { + t.Fatalf("expected exit code 0, got %d; stderr: %s", code, stderr.String()) } - if got := stderr.String(); got != "credentials not configured; run `email-mcp setup`\n" { - t.Fatalf("unexpected stderr: %q", got) + + // Verify the credential error appears in the tool call response + got := output.String() + if !bytes.Contains([]byte(got), []byte("credentials not configured")) { + t.Fatalf("expected credential error in output, got %q", got) } } diff --git a/internal/mcpserver/server.go b/internal/mcpserver/server.go index 8223f59..9d49c84 100644 --- a/internal/mcpserver/server.go +++ b/internal/mcpserver/server.go @@ -221,11 +221,6 @@ func (s Server) GetMessage(ctx context.Context, mailbox string, uid uint32) (ima } func (r Runner) Run(ctx context.Context) error { - cred, err := r.server.loadCredential(ctx) - if err != nil { - return err - } - stopCancelRead := r.closeInputOnCancel(ctx) defer stopCancelRead() @@ -255,7 +250,7 @@ func (r Runner) Run(ctx context.Context) error { return nil } - if err := r.handleRPCRequest(ctx, encoder, &session, cred, request); err != nil { + if err := r.handleRPCRequest(ctx, encoder, &session, request); err != nil { if errors.Is(err, context.Canceled) { return err } @@ -273,7 +268,7 @@ type runnerSession struct { protocolVersion string } -func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, cred secretstore.Credential, request rpcRequest) error { +func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, request rpcRequest) error { if request.JSONRPC != jsonRPCVersion { return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "jsonrpc must be 2.0", nil) } @@ -320,7 +315,7 @@ func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, ses if isNotification { return nil } - return r.handleToolsCall(ctx, encoder, cred, request) + return r.handleToolsCall(ctx, encoder, request) default: if isNotification { return nil @@ -331,20 +326,14 @@ func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, ses func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession, request rpcRequest) error { var params initializeParams - if err := decodeArguments(request.Params, ¶ms); err != nil { + if err := decodeParams(request.Params, ¶ms); err != nil { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil) } if params.ProtocolVersion == "" { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "protocolVersion is required", nil) } - negotiatedVersion, ok := negotiateProtocolVersion(params.ProtocolVersion) - if !ok { - return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "Unsupported protocol version", map[string]any{ - "supported": supportedProtocolVersions, - "requested": params.ProtocolVersion, - }) - } + negotiatedVersion := negotiateProtocolVersion(params.ProtocolVersion) session.initialized = true session.protocolVersion = negotiatedVersion @@ -361,15 +350,28 @@ func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession, }) } -func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, cred secretstore.Credential, request rpcRequest) error { +func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, request rpcRequest) error { var params toolsCallParams - if err := decodeArguments(request.Params, ¶ms); err != nil { + if err := decodeParams(request.Params, ¶ms); err != nil { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil) } if strings.TrimSpace(params.Name) == "" { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "name is required", nil) } + cred, err := r.server.loadCredential(ctx) + if err != nil { + return writeRPCResult(encoder, request.ID, map[string]any{ + "content": []map[string]any{ + { + "type": "text", + "text": err.Error(), + }, + }, + "isError": true, + }) + } + result, err := r.server.handleTool(ctx, cred, params.Name, params.Arguments) if err != nil { var invalidErr *invalidParamsError @@ -490,6 +492,13 @@ func (s Server) loadCredential(ctx context.Context) (secretstore.Credential, err return cred, nil } +func decodeParams(raw json.RawMessage, dest any) error { + if len(raw) == 0 { + raw = []byte("{}") + } + return json.Unmarshal(raw, dest) +} + func decodeArguments(raw json.RawMessage, dest any) error { if len(raw) == 0 { raw = []byte("{}") @@ -570,16 +579,13 @@ func (r Runner) closeInputOnCancel(ctx context.Context) func() { } } -func negotiateProtocolVersion(requested string) (string, bool) { +func negotiateProtocolVersion(requested string) string { for _, supported := range supportedProtocolVersions { if requested == supported { - return supported, true + return supported } } - if len(supportedProtocolVersions) == 0 { - return "", false - } - return supportedProtocolVersions[0], false + return supportedProtocolVersions[0] } func writeRPCResult(encoder *json.Encoder, id json.RawMessage, result any) error { diff --git a/internal/mcpserver/server_test.go b/internal/mcpserver/server_test.go index 523ed8f..1039920 100644 --- a/internal/mcpserver/server_test.go +++ b/internal/mcpserver/server_test.go @@ -345,6 +345,11 @@ func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) { store := &storeStub{ loadErr: kwallet.ErrCredentialNotFound, } + input := bytes.NewBufferString( + "{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test\",\"version\":\"1.0.0\"}}}\n" + + "{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n" + + "{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_mailboxes\"}}\n", + ) output := &bytes.Buffer{} runner := NewRunner(New(store, serviceStub{ listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { @@ -359,14 +364,38 @@ func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) { t.Fatal("GetMessage should not be called") return imapclient.Message{}, nil }, - }), bytes.NewBuffer(nil), output, &bytes.Buffer{}) + }), input, output, &bytes.Buffer{}) - err := runner.Run(context.Background()) - if !errors.Is(err, ErrCredentialsNotConfigured) { - t.Fatalf("expected missing credential error, got %v", err) + if err := runner.Run(context.Background()); err != nil { + t.Fatalf("Run returned error: %v", err) } - if output.Len() != 0 { - t.Fatalf("expected no output when credentials are missing, got %q", output.String()) + + decoder := json.NewDecoder(output) + + // Skip initialize response + var initResp json.RawMessage + if err := decoder.Decode(&initResp); err != nil { + t.Fatalf("failed to decode initialize response: %v", err) + } + + // Check tool call response contains credential error + var toolResp struct { + ID int `json:"id"` + Result struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + IsError bool `json:"isError"` + } `json:"result"` + } + if err := decoder.Decode(&toolResp); err != nil { + t.Fatalf("failed to decode tool call response: %v", err) + } + if !toolResp.Result.IsError { + t.Fatal("expected isError true for missing credentials") + } + if len(toolResp.Result.Content) == 0 || toolResp.Result.Content[0].Text != ErrCredentialsNotConfigured.Error() { + t.Fatalf("expected credential error message, got %#v", toolResp.Result) } } @@ -374,6 +403,11 @@ func TestRunnerRunReturnsFriendlyMissingCredentialErrorWhenStoreAlreadyTranslate store := &storeStub{ loadErr: ErrCredentialsNotConfigured, } + input := bytes.NewBufferString( + "{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test\",\"version\":\"1.0.0\"}}}\n" + + "{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n" + + "{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_mailboxes\"}}\n", + ) output := &bytes.Buffer{} runner := NewRunner(New(store, serviceStub{ listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { @@ -388,14 +422,38 @@ func TestRunnerRunReturnsFriendlyMissingCredentialErrorWhenStoreAlreadyTranslate t.Fatal("GetMessage should not be called") return imapclient.Message{}, nil }, - }), bytes.NewBuffer(nil), output, &bytes.Buffer{}) + }), input, output, &bytes.Buffer{}) - err := runner.Run(context.Background()) - if !errors.Is(err, ErrCredentialsNotConfigured) { - t.Fatalf("expected missing credential error, got %v", err) + if err := runner.Run(context.Background()); err != nil { + t.Fatalf("Run returned error: %v", err) } - if output.Len() != 0 { - t.Fatalf("expected no output when credentials are missing, got %q", output.String()) + + decoder := json.NewDecoder(output) + + // Skip initialize response + var initResp json.RawMessage + if err := decoder.Decode(&initResp); err != nil { + t.Fatalf("failed to decode initialize response: %v", err) + } + + // Check tool call response contains credential error + var toolResp struct { + ID int `json:"id"` + Result struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + IsError bool `json:"isError"` + } `json:"result"` + } + if err := decoder.Decode(&toolResp); err != nil { + t.Fatalf("failed to decode tool call response: %v", err) + } + if !toolResp.Result.IsError { + t.Fatal("expected isError true for missing credentials") + } + if len(toolResp.Result.Content) == 0 || toolResp.Result.Content[0].Text != ErrCredentialsNotConfigured.Error() { + t.Fatalf("expected credential error message, got %#v", toolResp.Result) } }