From 5dbc073e5c38ea21857b3c4237dff7a8f96fd611 Mon Sep 17 00:00:00 2001 From: thibaud-leclere Date: Fri, 10 Apr 2026 15:24:37 +0200 Subject: [PATCH] fix: implement MCP JSON-RPC protocol --- README.md | 4 +- internal/mcpserver/server.go | 305 ++++++++++++++++++++++++++++-- internal/mcpserver/server_test.go | 198 ++++++++++++++++--- 3 files changed, 457 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index f1b2f1c..b210bd8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # email-mcp -Serveur MCP local pour lire une boîte mail via IMAP. Le projet expose trois outils : +Serveur MCP local pour lire une boîte mail via IMAP. Le serveur parle le protocole MCP standard sur `stdio`, avec des messages **JSON-RPC 2.0** (`initialize`, `notifications/initialized`, `tools/list`, `tools/call`). Le projet expose trois outils : - **`list_mailboxes`** — lister les boîtes IMAP visibles - **`list_messages`** — lister les messages récents d'une boîte @@ -53,7 +53,7 @@ Si KDE Wallet n'est pas disponible, le setup échoue explicitement et n'écrit r ### Étape 2 : lancer le serveur MCP -Le serveur MCP s'exécute sur `stdin/stdout` : +Le serveur MCP s'exécute sur `stdin/stdout` avec le handshake MCP standard : ```sh ./email-mcp mcp diff --git a/internal/mcpserver/server.go b/internal/mcpserver/server.go index c086646..3ef8f22 100644 --- a/internal/mcpserver/server.go +++ b/internal/mcpserver/server.go @@ -1,6 +1,7 @@ package mcpserver import ( + "bytes" "context" "encoding/json" "errors" @@ -15,6 +16,26 @@ import ( var ErrCredentialsNotConfigured = errors.New("credentials not configured; run `email-mcp setup`") +const ( + jsonRPCVersion = "2.0" + mcpServerName = "email-mcp" + mcpServerVersion = "dev" + mcpMethodInitialize = "initialize" + mcpMethodInitialized = "notifications/initialized" + mcpMethodPing = "ping" + mcpMethodToolsList = "tools/list" + mcpMethodToolsCall = "tools/call" + jsonRPCParseErrorCode = -32700 + jsonRPCInvalidRequestCode = -32600 + jsonRPCMethodNotFoundCode = -32601 + jsonRPCInvalidParamsCode = -32602 + jsonRPCInternalErrorCode = -32603 +) + +var supportedProtocolVersions = []string{ + "2025-03-26", +} + type MailService interface { ListMailboxes(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) ListMessages(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) @@ -29,7 +50,7 @@ type Server struct { type Tool struct { Name string `json:"name"` Description string `json:"description"` - InputSchema map[string]any `json:"input_schema,omitempty"` + InputSchema map[string]any `json:"inputSchema,omitempty"` } type Runner struct { @@ -44,6 +65,52 @@ type toolRequest struct { Arguments json.RawMessage `json:"arguments,omitempty"` } +type rpcRequest struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +type rpcResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Result any `json:"result,omitempty"` + Error *rpcErrorObject `json:"error,omitempty"` +} + +type rpcErrorObject struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` +} + +type initializeParams struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]any `json:"capabilities"` + ClientInfo struct { + Name string `json:"name"` + Version string `json:"version"` + } `json:"clientInfo"` +} + +type toolsCallParams struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments,omitempty"` +} + +type invalidParamsError struct { + err error +} + +func (e *invalidParamsError) Error() string { + return e.err.Error() +} + +func (e *invalidParamsError) Unwrap() error { + return e.err +} + type listMessagesArguments struct { Mailbox string `json:"mailbox"` Limit *int `json:"limit,omitempty"` @@ -162,20 +229,18 @@ func (r Runner) Run(ctx context.Context) error { defer stopCancelRead() encoder := json.NewEncoder(r.out) - if err := encoder.Encode(map[string]any{"tools": r.server.Tools()}); err != nil { - return err - } if r.in == nil { return nil } decoder := json.NewDecoder(r.in) + session := runnerSession{} for { if err := ctx.Err(); err != nil { return err } - var request toolRequest + var request rpcRequest if err := decoder.Decode(&request); err != nil { if ctxErr := ctx.Err(); ctxErr != nil { return ctxErr @@ -183,20 +248,158 @@ func (r Runner) Run(ctx context.Context) error { if errors.Is(err, io.EOF) { return nil } - return err + if writeErr := writeRPCError(encoder, nil, jsonRPCParseErrorCode, err.Error(), nil); writeErr != nil { + return writeErr + } + return nil } - result, err := r.server.handleTool(ctx, cred, request.Tool, request.Arguments) - if err != nil { - if encodeErr := encoder.Encode(map[string]any{"error": err.Error()}); encodeErr != nil { - return encodeErr + if err := r.handleRPCRequest(ctx, encoder, &session, cred, request); err != nil { + if errors.Is(err, context.Canceled) { + return err + } + if writeErr := writeRPCError(encoder, request.ID, jsonRPCInternalErrorCode, err.Error(), nil); writeErr != nil { + return writeErr } continue } - if err := encoder.Encode(map[string]any{"result": result}); err != nil { - return err + } +} + +type runnerSession struct { + initialized bool + ready bool + protocolVersion string +} + +func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, cred secretstore.Credential, request rpcRequest) error { + if request.JSONRPC != jsonRPCVersion { + return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "jsonrpc must be 2.0", nil) + } + if request.Method == "" { + return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "method is required", nil) + } + + isNotification := len(request.ID) == 0 + if !session.initialized { + switch request.Method { + case mcpMethodInitialize: + if isNotification { + return writeRPCError(encoder, nil, jsonRPCInvalidRequestCode, "initialize must be a request", nil) + } + return r.handleInitialize(encoder, session, request) + case mcpMethodPing: + if isNotification { + return nil + } + return writeRPCResult(encoder, request.ID, map[string]any{}) + default: + if isNotification { + return nil + } + return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "server not initialized", nil) } } + + switch request.Method { + case mcpMethodInitialized: + session.ready = true + return nil + case mcpMethodPing: + if isNotification { + return nil + } + return writeRPCResult(encoder, request.ID, map[string]any{}) + case mcpMethodToolsList: + if isNotification { + return nil + } + return writeRPCResult(encoder, request.ID, map[string]any{"tools": r.server.Tools()}) + case mcpMethodToolsCall: + if isNotification { + return nil + } + return r.handleToolsCall(ctx, encoder, cred, request) + default: + if isNotification { + return nil + } + return writeRPCError(encoder, request.ID, jsonRPCMethodNotFoundCode, fmt.Sprintf("method not found: %s", request.Method), nil) + } +} + +func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession, request rpcRequest) error { + var params initializeParams + if err := decodeArguments(request.Params, ¶ms); err != nil { + return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil) + } + if params.ProtocolVersion == "" { + return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "protocolVersion is required", nil) + } + + negotiatedVersion, ok := negotiateProtocolVersion(params.ProtocolVersion) + if !ok { + return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "Unsupported protocol version", map[string]any{ + "supported": supportedProtocolVersions, + "requested": params.ProtocolVersion, + }) + } + + session.initialized = true + session.protocolVersion = negotiatedVersion + + return writeRPCResult(encoder, request.ID, map[string]any{ + "protocolVersion": negotiatedVersion, + "capabilities": map[string]any{ + "tools": map[string]any{}, + }, + "serverInfo": map[string]any{ + "name": mcpServerName, + "version": mcpServerVersion, + }, + }) +} + +func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, cred secretstore.Credential, request rpcRequest) error { + var params toolsCallParams + if err := decodeArguments(request.Params, ¶ms); err != nil { + return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil) + } + if strings.TrimSpace(params.Name) == "" { + return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "name is required", nil) + } + + result, err := r.server.handleTool(ctx, cred, params.Name, params.Arguments) + if err != nil { + var invalidErr *invalidParamsError + if errors.As(err, &invalidErr) { + return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, invalidErr.Error(), nil) + } + return writeRPCResult(encoder, request.ID, map[string]any{ + "content": []map[string]any{ + { + "type": "text", + "text": err.Error(), + }, + }, + "isError": true, + }) + } + + payload, err := json.Marshal(result) + if err != nil { + return err + } + + return writeRPCResult(encoder, request.ID, map[string]any{ + "content": []map[string]any{ + { + "type": "text", + "text": string(payload), + }, + }, + "isError": false, + }) } func (s Server) handleTool(ctx context.Context, cred secretstore.Credential, name string, rawArgs json.RawMessage) (any, error) { @@ -206,25 +409,33 @@ func (s Server) handleTool(ctx context.Context, cred secretstore.Credential, nam case "list_messages": var args listMessagesArguments if err := decodeArguments(rawArgs, &args); err != nil { - return nil, err + return nil, &invalidParamsError{err: err} + } + mailbox, err := validateMailbox(args.Mailbox) + if err != nil { + return nil, &invalidParamsError{err: err} } limit, err := normalizeListMessagesLimit(args.Limit) if err != nil { - return nil, err + return nil, &invalidParamsError{err: err} } - return s.listMessages(ctx, cred, args.Mailbox, limit) + return s.listMessages(ctx, cred, mailbox, limit) case "get_message": var args getMessageArguments if err := decodeArguments(rawArgs, &args); err != nil { - return nil, err + return nil, &invalidParamsError{err: err} + } + mailbox, err := validateMailbox(args.Mailbox) + if err != nil { + return nil, &invalidParamsError{err: err} } uid, err := validateMessageUID(args.UID) if err != nil { - return nil, err + return nil, &invalidParamsError{err: err} } - return s.getMessage(ctx, cred, args.Mailbox, uid) + return s.getMessage(ctx, cred, mailbox, uid) default: - return nil, fmt.Errorf("unknown tool: %s", name) + return nil, &invalidParamsError{err: fmt.Errorf("unknown tool: %s", name)} } } @@ -282,14 +493,34 @@ func decodeArguments(raw json.RawMessage, dest any) error { if len(raw) == 0 { raw = []byte("{}") } - decoder := json.NewDecoder(strings.NewReader(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) decoder.DisallowUnknownFields() if err := decoder.Decode(dest); err != nil { - return fmt.Errorf("invalid tool arguments: %w", err) + return normalizeDecodeError(err) + } + if decoder.More() { + return fmt.Errorf("invalid JSON object") } return nil } +func normalizeDecodeError(err error) error { + var syntaxErr *json.SyntaxError + var typeErr *json.UnmarshalTypeError + + switch { + case errors.As(err, &syntaxErr): + return fmt.Errorf("invalid JSON arguments") + case errors.As(err, &typeErr): + if typeErr.Field != "" { + return fmt.Errorf("%s has an invalid type", typeErr.Field) + } + return fmt.Errorf("arguments have an invalid type") + default: + return fmt.Errorf("invalid arguments: %w", err) + } +} + func validateMailbox(mailbox string) (string, error) { mailbox = strings.TrimSpace(mailbox) if mailbox == "" { @@ -337,3 +568,35 @@ func (r Runner) closeInputOnCancel(ctx context.Context) func() { close(done) } } + +func negotiateProtocolVersion(requested string) (string, bool) { + for _, supported := range supportedProtocolVersions { + if requested == supported { + return supported, true + } + } + if len(supportedProtocolVersions) == 0 { + return "", false + } + return supportedProtocolVersions[0], false +} + +func writeRPCResult(encoder *json.Encoder, id json.RawMessage, result any) error { + return encoder.Encode(rpcResponse{ + JSONRPC: jsonRPCVersion, + ID: id, + Result: result, + }) +} + +func writeRPCError(encoder *json.Encoder, id json.RawMessage, code int, message string, data any) error { + return encoder.Encode(rpcResponse{ + JSONRPC: jsonRPCVersion, + ID: id, + Error: &rpcErrorObject{ + Code: code, + Message: message, + Data: data, + }, + }) +} diff --git a/internal/mcpserver/server_test.go b/internal/mcpserver/server_test.go index ee9ee7a..429cfa4 100644 --- a/internal/mcpserver/server_test.go +++ b/internal/mcpserver/server_test.go @@ -172,7 +172,7 @@ func TestRunnerRunWritesToolManifestAndHandlesRequests(t *testing.T) { Password: "secret", }, } - input := bytes.NewBufferString("{\"tool\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\",\"limit\":5}}\n") + input := bytes.NewBufferString("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test-client\",\"version\":\"1.0.0\"}}}\n{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/list\",\"params\":{}}\n{\"jsonrpc\":\"2.0\",\"id\":3,\"method\":\"tools/call\",\"params\":{\"name\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\",\"limit\":5}}}\n") output := &bytes.Buffer{} runner := NewRunner(New(store, serviceStub{ listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { @@ -200,29 +200,90 @@ func TestRunnerRunWritesToolManifestAndHandlesRequests(t *testing.T) { decoder := json.NewDecoder(output) - var manifest struct { - Tools []struct { - Name string `json:"name"` - } `json:"tools"` + var initializeResponse struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Result struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]any `json:"capabilities"` + ServerInfo struct { + Name string `json:"name"` + Version string `json:"version"` + } `json:"serverInfo"` + } `json:"result"` } - if err := decoder.Decode(&manifest); err != nil { - t.Fatalf("failed to decode manifest: %v", err) + if err := decoder.Decode(&initializeResponse); err != nil { + t.Fatalf("failed to decode initialize response: %v", err) } - if len(manifest.Tools) != 3 { - t.Fatalf("expected 3 tools, got %#v", manifest.Tools) + if initializeResponse.JSONRPC != "2.0" { + t.Fatalf("expected jsonrpc 2.0, got %#v", initializeResponse) } - if manifest.Tools[0].Name != "list_mailboxes" || manifest.Tools[1].Name != "list_messages" || manifest.Tools[2].Name != "get_message" { - t.Fatalf("unexpected tool manifest: %#v", manifest.Tools) + if initializeResponse.ID != 1 { + t.Fatalf("expected initialize response id 1, got %#v", initializeResponse) + } + if initializeResponse.Result.ProtocolVersion != "2025-03-26" { + t.Fatalf("expected negotiated protocol version, got %#v", initializeResponse.Result) + } + if _, ok := initializeResponse.Result.Capabilities["tools"]; !ok { + t.Fatalf("expected tools capability, got %#v", initializeResponse.Result.Capabilities) + } + if initializeResponse.Result.ServerInfo.Name == "" || initializeResponse.Result.ServerInfo.Version == "" { + t.Fatalf("expected server info, got %#v", initializeResponse.Result.ServerInfo) + } + + var listResponse struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Result struct { + Tools []map[string]any `json:"tools"` + } `json:"result"` + } + if err := decoder.Decode(&listResponse); err != nil { + t.Fatalf("failed to decode tools/list response: %v", err) + } + if listResponse.JSONRPC != "2.0" || listResponse.ID != 2 { + t.Fatalf("unexpected tools/list response envelope: %#v", listResponse) + } + if len(listResponse.Result.Tools) != 3 { + t.Fatalf("expected 3 tools, got %#v", listResponse.Result.Tools) + } + if listResponse.Result.Tools[0]["name"] != "list_mailboxes" || listResponse.Result.Tools[1]["name"] != "list_messages" || listResponse.Result.Tools[2]["name"] != "get_message" { + t.Fatalf("unexpected tool manifest: %#v", listResponse.Result.Tools) + } + if _, ok := listResponse.Result.Tools[1]["inputSchema"]; !ok { + t.Fatalf("expected inputSchema field in tools/list response, got %#v", listResponse.Result.Tools[1]) } var response struct { - Result []imapclient.MessageSummary `json:"result"` + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Result struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + IsError bool `json:"isError"` + } `json:"result"` } if err := decoder.Decode(&response); err != nil { t.Fatalf("failed to decode response: %v", err) } - if len(response.Result) != 1 || response.Result[0].UID != 42 { - t.Fatalf("unexpected response: %#v", response.Result) + if response.JSONRPC != "2.0" || response.ID != 3 { + t.Fatalf("unexpected tools/call response envelope: %#v", response) + } + if response.Result.IsError { + t.Fatalf("expected successful tools/call result, got %#v", response.Result) + } + if len(response.Result.Content) != 1 || response.Result.Content[0].Type != "text" { + t.Fatalf("unexpected tools/call content: %#v", response.Result.Content) + } + + var messages []imapclient.MessageSummary + if err := json.Unmarshal([]byte(response.Result.Content[0].Text), &messages); err != nil { + t.Fatalf("failed to decode tools/call text payload: %v", err) + } + if len(messages) != 1 || messages[0].UID != 42 { + t.Fatalf("unexpected response: %#v", messages) } } @@ -355,7 +416,7 @@ func TestRunnerRunReturnsValidationErrorsForInvalidRequests(t *testing.T) { Password: "secret", }, } - input := bytes.NewBufferString("{\"tool\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\",\"limit\":0}}\n{\"tool\":\"get_message\",\"arguments\":{\"mailbox\":\"INBOX\"}}\n") + input := bytes.NewBufferString("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test-client\",\"version\":\"1.0.0\"}}}\n{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\",\"limit\":0}}}\n{\"jsonrpc\":\"2.0\",\"id\":3,\"method\":\"tools/call\",\"params\":{\"name\":\"get_message\",\"arguments\":{\"mailbox\":\"INBOX\"}}}\n") output := &bytes.Buffer{} runner := NewRunner(New(store, serviceStub{ listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { @@ -378,28 +439,51 @@ func TestRunnerRunReturnsValidationErrorsForInvalidRequests(t *testing.T) { decoder := json.NewDecoder(output) if err := decoder.Decode(&struct { - Tools []Tool `json:"tools"` + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` }{}); err != nil { - t.Fatalf("failed to decode manifest: %v", err) + t.Fatalf("failed to decode initialize response: %v", err) } var firstResponse struct { - Error string `json:"error"` + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Error struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` } if err := decoder.Decode(&firstResponse); err != nil { t.Fatalf("failed to decode first error response: %v", err) } - if firstResponse.Error != "limit must be between 1 and 50" { + if firstResponse.JSONRPC != "2.0" || firstResponse.ID != 2 { + t.Fatalf("unexpected first error envelope: %#v", firstResponse) + } + if firstResponse.Error.Code != -32602 { + t.Fatalf("expected invalid params code, got %#v", firstResponse) + } + if firstResponse.Error.Message != "limit must be between 1 and 50" { t.Fatalf("unexpected first error: %#v", firstResponse) } var secondResponse struct { - Error string `json:"error"` + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Error struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` } if err := decoder.Decode(&secondResponse); err != nil { t.Fatalf("failed to decode second error response: %v", err) } - if secondResponse.Error != "uid is required" { + if secondResponse.JSONRPC != "2.0" || secondResponse.ID != 3 { + t.Fatalf("unexpected second error envelope: %#v", secondResponse) + } + if secondResponse.Error.Code != -32602 { + t.Fatalf("expected invalid params code, got %#v", secondResponse) + } + if secondResponse.Error.Message != "uid is required" { t.Fatalf("unexpected second error: %#v", secondResponse) } } @@ -412,7 +496,7 @@ func TestRunnerRunRejectsWhitespaceOnlyMailboxValues(t *testing.T) { Password: "secret", }, } - input := bytes.NewBufferString("{\"tool\":\"list_messages\",\"arguments\":{\"mailbox\":\" \"}}\n") + input := bytes.NewBufferString("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test-client\",\"version\":\"1.0.0\"}}}\n{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_messages\",\"arguments\":{\"mailbox\":\" \"}}}\n") output := &bytes.Buffer{} runner := NewRunner(New(store, serviceStub{ listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { @@ -435,18 +519,30 @@ func TestRunnerRunRejectsWhitespaceOnlyMailboxValues(t *testing.T) { decoder := json.NewDecoder(output) if err := decoder.Decode(&struct { - Tools []Tool `json:"tools"` + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` }{}); err != nil { - t.Fatalf("failed to decode manifest: %v", err) + t.Fatalf("failed to decode initialize response: %v", err) } var response struct { - Error string `json:"error"` + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Error struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` } if err := decoder.Decode(&response); err != nil { t.Fatalf("failed to decode error response: %v", err) } - if response.Error != "mailbox is required" { + if response.JSONRPC != "2.0" || response.ID != 2 { + t.Fatalf("unexpected error envelope: %#v", response) + } + if response.Error.Code != -32602 { + t.Fatalf("expected invalid params code, got %#v", response) + } + if response.Error.Message != "mailbox is required" { t.Fatalf("unexpected error: %#v", response) } } @@ -459,7 +555,7 @@ func TestRunnerRunAppliesDefaultLimitWhenOmitted(t *testing.T) { Password: "secret", }, } - input := bytes.NewBufferString("{\"tool\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\"}}\n") + input := bytes.NewBufferString("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test-client\",\"version\":\"1.0.0\"}}}\n{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/call\",\"params\":{\"name\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\"}}}\n") output := &bytes.Buffer{} runner := NewRunner(New(store, serviceStub{ listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { @@ -483,6 +579,54 @@ func TestRunnerRunAppliesDefaultLimitWhenOmitted(t *testing.T) { } } +func TestRunnerRunRejectsRequestsBeforeInitialize(t *testing.T) { + store := &storeStub{ + credential: secretstore.Credential{ + Host: "imap.example.com", + Username: "alice", + Password: "secret", + }, + } + input := bytes.NewBufferString("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/list\",\"params\":{}}\n") + output := &bytes.Buffer{} + runner := NewRunner(New(store, serviceStub{ + listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { + t.Fatal("ListMailboxes should not be called") + return nil, nil + }, + listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) { + t.Fatal("ListMessages should not be called") + return nil, nil + }, + getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) { + t.Fatal("GetMessage should not be called") + return imapclient.Message{}, nil + }, + }), input, output, &bytes.Buffer{}) + + if err := runner.Run(context.Background()); err != nil { + t.Fatalf("Run returned error: %v", err) + } + + var response struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Error struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + if err := json.NewDecoder(output).Decode(&response); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if response.JSONRPC != "2.0" || response.ID != 1 { + t.Fatalf("unexpected response envelope: %#v", response) + } + if response.Error.Code == 0 || response.Error.Message == "" { + t.Fatalf("expected protocol error before initialization, got %#v", response) + } +} + func TestRunnerRunStopsWhenContextCanceledWhileWaitingForInput(t *testing.T) { store := &storeStub{ credential: secretstore.Credential{