email-mcp/internal/mcpserver/server.go
thibaud-leclere 92fc30cb2d fix: tolerate unknown fields in MCP protocol messages and defer credential loading
Claude Code sends extra fields (e.g. "title") in initialize params that
caused the server to reject the request due to DisallowUnknownFields.
Use lenient JSON decoding for protocol messages while keeping strict
validation for tool arguments. Also defer KWallet credential loading
from server startup to tool invocation time, and negotiate protocol
versions per MCP spec instead of rejecting unknown ones.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 15:46:53 +02:00

609 lines
16 KiB
Go

package mcpserver
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"email-mcp/internal/imapclient"
"email-mcp/internal/secretstore"
"email-mcp/internal/secretstore/kwallet"
)
// ErrCredentialsNotConfigured is returned when no credential is stored for
// the default account; its message tells the user which command to run.
var ErrCredentialsNotConfigured = errors.New("credentials not configured; run `email-mcp setup`")

// Protocol and server identity advertised to clients.
const (
	jsonRPCVersion   = "2.0"
	mcpServerName    = "email-mcp"
	mcpServerVersion = "dev"
)

// MCP method names dispatched by the runner.
const (
	mcpMethodInitialize  = "initialize"
	mcpMethodInitialized = "notifications/initialized"
	mcpMethodPing        = "ping"
	mcpMethodToolsList   = "tools/list"
	mcpMethodToolsCall   = "tools/call"
)

// Standard JSON-RPC 2.0 error codes.
const (
	jsonRPCParseErrorCode     = -32700
	jsonRPCInvalidRequestCode = -32600
	jsonRPCMethodNotFoundCode = -32601
	jsonRPCInvalidParamsCode  = -32602
	jsonRPCInternalErrorCode  = -32603
)

// supportedProtocolVersions lists the MCP protocol revisions this server
// accepts, newest first; negotiateProtocolVersion falls back to the first
// entry when the client asks for a revision not listed here.
var supportedProtocolVersions = []string{
	"2025-03-26",
	"2024-11-05",
}
// MailService is the IMAP backend used by Server. Every call receives the
// credential that Server loaded for the current operation.
type MailService interface {
// ListMailboxes backs the list_mailboxes tool.
ListMailboxes(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error)
// ListMessages backs the list_messages tool; the extra arguments are the
// mailbox name and the message limit.
ListMessages(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error)
// GetMessage backs the get_message tool; the extra arguments are the mailbox
// name and the message's IMAP UID.
GetMessage(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error)
}
// Server implements the email tools on top of a credential store and a
// MailService. Either dependency may be nil; operations needing it then
// return a "not configured" error.
type Server struct {
store secretstore.Store
mail MailService
}
// Tool is an MCP tool descriptor as returned from tools/list.
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
// InputSchema is the JSON Schema for the tool's arguments.
// NOTE(review): a nil schema is dropped from the JSON by omitempty, while the
// MCP Tool definition lists inputSchema as required — confirm every Tool sets one.
InputSchema map[string]any `json:"inputSchema,omitempty"`
}
// Runner serves the MCP protocol for a Server: it decodes JSON-RPC messages
// from in and encodes responses to out.
type Runner struct {
server Server
in io.Reader
out io.Writer
// errOut is a secondary output stream; NewRunner replaces a nil out or
// errOut with io.Discard.
errOut io.Writer
}
// toolRequest pairs a tool name with raw JSON arguments.
// NOTE(review): no handler in this file uses it (tools/call decodes
// toolsCallParams instead) — possibly a leftover; confirm before removing.
type toolRequest struct {
Tool string `json:"tool"`
Arguments json.RawMessage `json:"arguments,omitempty"`
}
// rpcRequest is an incoming JSON-RPC 2.0 message. An absent "id" leaves ID
// empty, which the runner treats as a notification. Params stays raw so each
// method can decode it into its own parameter type.
type rpcRequest struct {
JSONRPC string `json:"jsonrpc"`
ID json.RawMessage `json:"id,omitempty"`
Method string `json:"method"`
Params json.RawMessage `json:"params,omitempty"`
}
// rpcResponse is an outgoing JSON-RPC 2.0 response: writeRPCResult fills
// Result and writeRPCError fills Error. Because of the omitempty tags, empty
// members (including an empty ID) are left out of the encoded JSON.
type rpcResponse struct {
JSONRPC string `json:"jsonrpc"`
ID json.RawMessage `json:"id,omitempty"`
Result any `json:"result,omitempty"`
Error *rpcErrorObject `json:"error,omitempty"`
}
// rpcErrorObject is the JSON-RPC 2.0 "error" member; Data is optional.
type rpcErrorObject struct {
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data,omitempty"`
}
// initializeParams holds the MCP initialize parameters. handleInitialize only
// inspects ProtocolVersion. The params are decoded leniently, so members not
// listed here (for example a client-supplied "title") are ignored.
type initializeParams struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities map[string]any `json:"capabilities"`
ClientInfo struct {
Name string `json:"name"`
Version string `json:"version"`
} `json:"clientInfo"`
}
// toolsCallParams holds the tools/call parameters. Arguments stays raw and is
// later decoded strictly (unknown fields rejected) into the selected tool's
// argument struct.
type toolsCallParams struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments,omitempty"`
}
// invalidParamsError marks a failure caused by bad tool arguments or an
// unknown tool name. handleToolsCall detects it with errors.As and reports it
// as a JSON-RPC Invalid Params error rather than a tool-level failure.
type invalidParamsError struct {
	err error
}

// Error returns the wrapped error's message unchanged.
func (ipe *invalidParamsError) Error() string {
	return ipe.err.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (ipe *invalidParamsError) Unwrap() error {
	return ipe.err
}
// listMessagesArguments are the list_messages tool arguments. Limit is a
// pointer so that an omitted limit (nil) can fall back to the default.
type listMessagesArguments struct {
Mailbox string `json:"mailbox"`
Limit *int `json:"limit,omitempty"`
}
// getMessageArguments are the get_message tool arguments. UID is a pointer so
// a missing uid is reported as "uid is required" instead of decoding as zero.
type getMessageArguments struct {
Mailbox string `json:"mailbox"`
UID *uint32 `json:"uid"`
}
// New returns a Server that reads credentials from store and performs mail
// operations through mail. Either argument may be nil; operations that need
// it then fail with a "not configured" error.
func New(store secretstore.Store, mail MailService) Server {
	var s Server
	s.store, s.mail = store, mail
	return s
}
// NewRunner returns a Runner that serves server over in and out.
// A nil out or errOut is replaced with io.Discard. in is stored as given;
// Run returns immediately when it is nil.
func NewRunner(server Server, in io.Reader, out io.Writer, errOut io.Writer) Runner {
	orDiscard := func(w io.Writer) io.Writer {
		if w == nil {
			return io.Discard
		}
		return w
	}
	return Runner{
		server: server,
		in:     in,
		out:    orDiscard(out),
		errOut: orDiscard(errOut),
	}
}
// Tools returns the MCP tool descriptors advertised via tools/list.
//
// Every descriptor carries an input schema. The MCP Tool definition makes
// inputSchema mandatory, and Tool.InputSchema is tagged omitempty, so a nil
// schema would drop the member from the JSON entirely. list_mailboxes takes
// no arguments, so it advertises an empty object schema.
//
// The schemas describe the argument structs that handleTool decodes:
//   - mailbox must contain a non-whitespace character.
//   - limit (list_messages) is an optional integer in
//     [1, imapclient.MaxListMessagesLimit].
//   - uid (get_message) is a positive integer.
func (s Server) Tools() []Tool {
	return []Tool{
		{
			Name:        "list_mailboxes",
			Description: "List visible IMAP mailboxes for the configured account.",
			InputSchema: map[string]any{
				"type":                 "object",
				"properties":           map[string]any{},
				"additionalProperties": false,
			},
		},
		{
			Name:        "list_messages",
			Description: "List recent messages from a mailbox using IMAP UIDs.",
			InputSchema: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"mailbox": map[string]any{
						"type":      "string",
						"minLength": 1,
						"pattern":   "\\S",
					},
					"limit": map[string]any{
						"type":    "integer",
						"default": imapclient.DefaultListMessagesLimit,
						"minimum": 1,
						"maximum": imapclient.MaxListMessagesLimit,
					},
				},
				"required":             []string{"mailbox"},
				"additionalProperties": false,
			},
		},
		{
			Name:        "get_message",
			Description: "Fetch a single message by mailbox and IMAP UID.",
			InputSchema: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"mailbox": map[string]any{
						"type":      "string",
						"minLength": 1,
						"pattern":   "\\S",
					},
					"uid": map[string]any{
						"type":    "integer",
						"minimum": 1,
					},
				},
				"required":             []string{"mailbox", "uid"},
				"additionalProperties": false,
			},
		},
	}
}
// ListMailboxes loads the default account credential and returns the
// mailboxes reported by the mail service.
func (s Server) ListMailboxes(ctx context.Context) ([]imapclient.Mailbox, error) {
	cred, err := s.loadCredential(ctx)
	if err == nil {
		return s.listMailboxes(ctx, cred)
	}
	return nil, err
}
// ListMessages loads the default account credential and lists messages from
// mailbox. The mailbox name is trimmed and must not be blank; limit is passed
// to the mail service unchanged.
func (s Server) ListMessages(ctx context.Context, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
	cred, err := s.loadCredential(ctx)
	if err == nil {
		return s.listMessages(ctx, cred, mailbox, limit)
	}
	return nil, err
}
// GetMessage loads the default account credential and fetches the message
// with the given IMAP UID from mailbox. A blank mailbox or a zero uid is
// rejected before the mail service is called.
func (s Server) GetMessage(ctx context.Context, mailbox string, uid uint32) (imapclient.Message, error) {
	cred, err := s.loadCredential(ctx)
	if err == nil {
		return s.getMessage(ctx, cred, mailbox, uid)
	}
	return imapclient.Message{}, err
}
// Run serves MCP over JSON-RPC 2.0. It decodes messages from r.in, dispatches
// them, and writes responses to r.out. Run stops and returns:
//   - nil when the input ends, or immediately when r.in is nil;
//   - ctx.Err() when the context is done;
//   - the write error when a response cannot be written.
//
// Cancelling ctx closes r.in when it implements io.Closer, which unblocks a
// pending read.
//
// Per-message error handling:
//   - Malformed JSON is answered with a Parse error (-32700) and ends the
//     session, because json.Decoder cannot resynchronise after a syntax error.
//   - Well-formed JSON that does not fit the request shape (for example a
//     batch array or a non-string "method") is answered with Invalid Request
//     (-32600). The loop continues, since the decoder has already consumed
//     the whole value.
//   - An internal failure while handling a request is answered with an
//     Internal error (-32603). Notifications never get a reply (JSON-RPC
//     forbids it), so their failures are written to r.errOut instead.
func (r Runner) Run(ctx context.Context) error {
	stopCancelRead := r.closeInputOnCancel(ctx)
	defer stopCancelRead()
	if r.in == nil {
		return nil
	}
	encoder := json.NewEncoder(r.out)
	decoder := json.NewDecoder(r.in)
	session := runnerSession{}
	for {
		if err := ctx.Err(); err != nil {
			return err
		}
		var request rpcRequest
		if err := decoder.Decode(&request); err != nil {
			// A read failing because the input was closed on cancellation is
			// not a protocol error.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			if errors.Is(err, io.EOF) {
				return nil
			}
			var typeErr *json.UnmarshalTypeError
			if errors.As(err, &typeErr) {
				// Valid JSON, wrong shape: Decode has advanced past the value,
				// so the stream is still usable. request.ID holds the id when
				// it could be decoded.
				if writeErr := writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, fmt.Sprintf("invalid request: %v", err), nil); writeErr != nil {
					return writeErr
				}
				continue
			}
			if writeErr := writeRPCError(encoder, nil, jsonRPCParseErrorCode, err.Error(), nil); writeErr != nil {
				return writeErr
			}
			return nil
		}
		if err := r.handleRPCRequest(ctx, encoder, &session, request); err != nil {
			if errors.Is(err, context.Canceled) {
				return err
			}
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			if len(request.ID) == 0 {
				// Notifications must not be answered; surface the failure locally.
				if r.errOut != nil {
					// Best effort: a failing diagnostics stream must not stop serving.
					_, _ = fmt.Fprintf(r.errOut, "%s: handling notification %q: %v\n", mcpServerName, request.Method, err)
				}
				continue
			}
			if writeErr := writeRPCError(encoder, request.ID, jsonRPCInternalErrorCode, err.Error(), nil); writeErr != nil {
				return writeErr
			}
		}
	}
}
// runnerSession tracks MCP lifecycle state for a single Run invocation.
type runnerSession struct {
// initialized is set once an initialize request has been accepted; until
// then only initialize and ping are served.
initialized bool
// ready is set when the client sends notifications/initialized.
// NOTE(review): no handler in this file reads it, so tool calls are accepted
// before that notification arrives — confirm this is intended.
ready bool
// protocolVersion is the protocol version negotiated during initialize.
protocolVersion string
}
// handleRPCRequest validates the JSON-RPC envelope of one message and
// dispatches it according to the session state.
//
// Before initialization only initialize and ping are served. Other requests
// get an Invalid Request error, and other notifications are dropped.
// After initialization the tool methods become available. A repeated
// initialize request is rejected as Invalid Request ("server already
// initialized") rather than Method Not Found, because the method exists.
// Notifications (messages without an id) never receive a result.
//
// The returned error is non-nil only when a response could not be produced
// or written. Protocol-level problems are reported to the client instead.
func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, request rpcRequest) error {
	if request.JSONRPC != jsonRPCVersion {
		return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "jsonrpc must be 2.0", nil)
	}
	if request.Method == "" {
		return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "method is required", nil)
	}
	isNotification := len(request.ID) == 0

	if !session.initialized {
		switch request.Method {
		case mcpMethodInitialize:
			if isNotification {
				return writeRPCError(encoder, nil, jsonRPCInvalidRequestCode, "initialize must be a request", nil)
			}
			return r.handleInitialize(encoder, session, request)
		case mcpMethodPing:
			if isNotification {
				return nil
			}
			return writeRPCResult(encoder, request.ID, map[string]any{})
		default:
			if isNotification {
				return nil
			}
			return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "server not initialized", nil)
		}
	}

	switch request.Method {
	case mcpMethodInitialize:
		// Re-negotiating an established session is not supported.
		if isNotification {
			return nil
		}
		return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "server already initialized", nil)
	case mcpMethodInitialized:
		session.ready = true
		return nil
	case mcpMethodPing:
		if isNotification {
			return nil
		}
		return writeRPCResult(encoder, request.ID, map[string]any{})
	case mcpMethodToolsList:
		if isNotification {
			return nil
		}
		return writeRPCResult(encoder, request.ID, map[string]any{"tools": r.server.Tools()})
	case mcpMethodToolsCall:
		if isNotification {
			return nil
		}
		return r.handleToolsCall(ctx, encoder, request)
	default:
		if isNotification {
			return nil
		}
		return writeRPCError(encoder, request.ID, jsonRPCMethodNotFoundCode, fmt.Sprintf("method not found: %s", request.Method), nil)
	}
}
// handleInitialize answers the MCP initialize request.
//
// The params are decoded leniently, so unknown members are ignored; only
// protocolVersion is required. On success the session is marked initialized
// with the negotiated version before the response is written. The response
// carries that version, the tools capability, and serverInfo.
func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession, request rpcRequest) error {
	var params initializeParams
	err := decodeParams(request.Params, &params)
	switch {
	case err != nil:
		return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil)
	case params.ProtocolVersion == "":
		return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "protocolVersion is required", nil)
	}

	version := negotiateProtocolVersion(params.ProtocolVersion)
	session.initialized = true
	session.protocolVersion = version

	capabilities := map[string]any{
		"tools": map[string]any{},
	}
	serverInfo := map[string]any{
		"name":    mcpServerName,
		"version": mcpServerVersion,
	}
	return writeRPCResult(encoder, request.ID, map[string]any{
		"protocolVersion": version,
		"capabilities":    capabilities,
		"serverInfo":      serverInfo,
	})
}
// handleToolsCall serves a tools/call request.
//
// The following are reported as JSON-RPC Invalid Params errors:
//   - malformed params;
//   - a blank tool name;
//   - an unknown tool;
//   - invalid tool arguments.
//
// Failures while loading credentials or running the tool are returned as a
// normal result with isError set, so the message reaches the model. On
// success the tool's return value is JSON-encoded into a single text content
// item.
//
// The tool name is checked against Tools() before credentials are loaded.
// An unknown tool is therefore rejected without touching the secret store,
// and it is not masked by a "credentials not configured" result.
func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, request rpcRequest) error {
	// toolError reports a tool-level failure as an MCP result with isError.
	toolError := func(message string) error {
		return writeRPCResult(encoder, request.ID, map[string]any{
			"content": []map[string]any{
				{
					"type": "text",
					"text": message,
				},
			},
			"isError": true,
		})
	}

	var params toolsCallParams
	if err := decodeParams(request.Params, &params); err != nil {
		return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil)
	}
	if strings.TrimSpace(params.Name) == "" {
		return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "name is required", nil)
	}

	known := false
	for _, tool := range r.server.Tools() {
		if tool.Name == params.Name {
			known = true
			break
		}
	}
	if !known {
		return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, fmt.Sprintf("unknown tool: %s", params.Name), nil)
	}

	// NOTE(review): argument validation still happens inside handleTool, i.e.
	// after credentials are loaded. Moving it earlier would require handleTool
	// to accept a credential loader instead of a loaded credential.
	cred, err := r.server.loadCredential(ctx)
	if err != nil {
		return toolError(err.Error())
	}
	result, err := r.server.handleTool(ctx, cred, params.Name, params.Arguments)
	if err != nil {
		var invalidErr *invalidParamsError
		if errors.As(err, &invalidErr) {
			return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, invalidErr.Error(), nil)
		}
		return toolError(err.Error())
	}
	payload, err := json.Marshal(result)
	if err != nil {
		return err
	}
	return writeRPCResult(encoder, request.ID, map[string]any{
		"content": []map[string]any{
			{
				"type": "text",
				"text": string(payload),
			},
		},
		"isError": false,
	})
}
// handleTool runs the named tool with the given credential and raw JSON
// arguments, and returns its result for JSON encoding.
//
// Unknown tool names and argument problems are returned wrapped in
// *invalidParamsError so the caller can report them as Invalid Params.
// Argument problems are malformed JSON, unknown fields, a blank mailbox, a
// limit out of range, and a missing or zero uid. Any other error comes from
// listMailboxes, listMessages, or getMessage.
func (s Server) handleTool(ctx context.Context, cred secretstore.Credential, name string, rawArgs json.RawMessage) (any, error) {
	// invalid wraps an argument or tool-name failure for the caller.
	invalid := func(err error) (any, error) {
		return nil, &invalidParamsError{err: err}
	}

	switch name {
	case "list_mailboxes":
		// This tool takes no arguments; rawArgs is ignored.
		return s.listMailboxes(ctx, cred)

	case "list_messages":
		var args listMessagesArguments
		if err := decodeArguments(rawArgs, &args); err != nil {
			return invalid(err)
		}
		mailbox, err := validateMailbox(args.Mailbox)
		if err != nil {
			return invalid(err)
		}
		limit, err := normalizeListMessagesLimit(args.Limit)
		if err != nil {
			return invalid(err)
		}
		return s.listMessages(ctx, cred, mailbox, limit)

	case "get_message":
		var args getMessageArguments
		if err := decodeArguments(rawArgs, &args); err != nil {
			return invalid(err)
		}
		mailbox, err := validateMailbox(args.Mailbox)
		if err != nil {
			return invalid(err)
		}
		uid, err := validateMessageUID(args.UID)
		if err != nil {
			return invalid(err)
		}
		return s.getMessage(ctx, cred, mailbox, uid)
	}
	return invalid(fmt.Errorf("unknown tool: %s", name))
}
// listMailboxes forwards to the mail service using an already loaded
// credential, failing when no mail service is configured.
func (s Server) listMailboxes(ctx context.Context, cred secretstore.Credential) ([]imapclient.Mailbox, error) {
	if mail := s.mail; mail != nil {
		return mail.ListMailboxes(ctx, cred)
	}
	return nil, fmt.Errorf("mail service is not configured")
}
// listMessages forwards to the mail service using an already loaded
// credential. The mailbox name is trimmed and must be non-blank; limit is
// passed through unchanged.
func (s Server) listMessages(ctx context.Context, cred secretstore.Credential, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
	mail := s.mail
	if mail == nil {
		return nil, fmt.Errorf("mail service is not configured")
	}
	trimmed, err := validateMailbox(mailbox)
	if err != nil {
		return nil, err
	}
	return mail.ListMessages(ctx, cred, trimmed, limit)
}
// getMessage forwards to the mail service using an already loaded
// credential. The mailbox name is trimmed and must be non-blank, and the uid
// must be non-zero.
func (s Server) getMessage(ctx context.Context, cred secretstore.Credential, mailbox string, uid uint32) (imapclient.Message, error) {
	var none imapclient.Message
	mail := s.mail
	if mail == nil {
		return none, fmt.Errorf("mail service is not configured")
	}
	trimmed, err := validateMailbox(mailbox)
	if err != nil {
		return none, err
	}
	if uid < 1 {
		return none, fmt.Errorf("uid must be greater than zero")
	}
	return mail.GetMessage(ctx, cred, trimmed, uid)
}
// loadCredential reads the default account's credential from the store and
// validates it.
//
// A "not found" error from KWallet, or an ErrCredentialsNotConfigured from
// the store, is normalized to ErrCredentialsNotConfigured. Other load errors
// are returned unchanged, and validation failures are wrapped.
func (s Server) loadCredential(ctx context.Context) (secretstore.Credential, error) {
	if s.store == nil {
		return secretstore.Credential{}, fmt.Errorf("secret store is not configured")
	}
	cred, err := s.store.Load(ctx, secretstore.DefaultAccountKey)
	switch {
	case err == nil:
		// Loaded; validated below.
	case errors.Is(err, kwallet.ErrCredentialNotFound) || errors.Is(err, ErrCredentialsNotConfigured):
		return secretstore.Credential{}, ErrCredentialsNotConfigured
	default:
		return secretstore.Credential{}, err
	}
	if validateErr := cred.Validate(); validateErr != nil {
		return secretstore.Credential{}, fmt.Errorf("default credential is invalid: %w", validateErr)
	}
	return cred, nil
}
// decodeParams leniently decodes JSON-RPC params into dest. Missing params
// decode as an empty object, and unknown members are ignored, so
// client-specific extras do not cause failures.
func decodeParams(raw json.RawMessage, dest any) error {
	payload := []byte(raw)
	if len(payload) == 0 {
		payload = []byte("{}")
	}
	return json.Unmarshal(payload, dest)
}
// decodeArguments strictly decodes tool arguments into dest.
//
// Unknown fields are rejected so that misspelled arguments surface as errors,
// and missing arguments are treated as an empty object. The input must hold
// exactly one JSON value. Anything other than whitespace after it, including
// a stray closing brace or bracket, yields "invalid JSON object". Decode
// errors are mapped through normalizeDecodeError.
func decodeArguments(raw json.RawMessage, dest any) error {
	if len(raw) == 0 {
		raw = []byte("{}")
	}
	decoder := json.NewDecoder(bytes.NewReader(raw))
	decoder.DisallowUnknownFields()
	if err := decoder.Decode(dest); err != nil {
		return normalizeDecodeError(err)
	}
	// Decoder.More reports false when the next byte is '}' or ']', so it
	// misses such trailing bytes. Token returns io.EOF only when nothing but
	// whitespace remains.
	if _, err := decoder.Token(); !errors.Is(err, io.EOF) {
		return fmt.Errorf("invalid JSON object")
	}
	return nil
}
// normalizeDecodeError maps a tool-argument decoding error to a short,
// client-facing message:
//   - JSON syntax errors become "invalid JSON arguments".
//   - Type mismatches name the offending field when known.
//   - Anything else is wrapped under "invalid arguments: ".
func normalizeDecodeError(err error) error {
	var syntaxErr *json.SyntaxError
	if errors.As(err, &syntaxErr) {
		return fmt.Errorf("invalid JSON arguments")
	}
	var typeErr *json.UnmarshalTypeError
	if !errors.As(err, &typeErr) {
		return fmt.Errorf("invalid arguments: %w", err)
	}
	if typeErr.Field == "" {
		return fmt.Errorf("arguments have an invalid type")
	}
	return fmt.Errorf("%s has an invalid type", typeErr.Field)
}
// validateMailbox trims surrounding whitespace from a mailbox name and
// returns it, or reports "mailbox is required" when nothing remains.
func validateMailbox(mailbox string) (string, error) {
	if trimmed := strings.TrimSpace(mailbox); trimmed != "" {
		return trimmed, nil
	}
	return "", fmt.Errorf("mailbox is required")
}
// normalizeListMessagesLimit resolves the optional list_messages limit. An
// absent value selects imapclient.DefaultListMessagesLimit; a present value
// must lie in [1, imapclient.MaxListMessagesLimit].
func normalizeListMessagesLimit(limit *int) (int, error) {
	switch {
	case limit == nil:
		return imapclient.DefaultListMessagesLimit, nil
	case 1 <= *limit && *limit <= imapclient.MaxListMessagesLimit:
		return *limit, nil
	default:
		return 0, fmt.Errorf("limit must be between 1 and %d", imapclient.MaxListMessagesLimit)
	}
}
// validateMessageUID checks the get_message uid argument: it must be present
// and non-zero. On success the dereferenced value is returned.
func validateMessageUID(uid *uint32) (uint32, error) {
	switch {
	case uid == nil:
		return 0, fmt.Errorf("uid is required")
	case *uid == 0:
		return 0, fmt.Errorf("uid must be greater than zero")
	}
	return *uid, nil
}
// closeInputOnCancel arranges for r.in to be closed when ctx is done, so a
// blocked read in Run returns. It only does this when r.in implements
// io.Closer. The returned function stops the watcher goroutine and must be
// called once; for inputs that cannot be closed it is a no-op.
func (r Runner) closeInputOnCancel(ctx context.Context) func() {
	closer, canClose := r.in.(io.Closer)
	if !canClose {
		return func() {}
	}
	stop := make(chan struct{})
	go func() {
		select {
		case <-stop:
			return
		case <-ctx.Done():
			// Best effort: the close error is irrelevant because the
			// reader is being abandoned.
			_ = closer.Close()
		}
	}()
	return func() { close(stop) }
}
// negotiateProtocolVersion echoes the client's requested protocol version
// when the server supports it. Otherwise it offers the server's preferred
// version, the first entry of supportedProtocolVersions.
func negotiateProtocolVersion(requested string) string {
	preferred := supportedProtocolVersions[0]
	for i := range supportedProtocolVersions {
		if supportedProtocolVersions[i] == requested {
			return requested
		}
	}
	return preferred
}
// writeRPCResult encodes a successful JSON-RPC 2.0 response carrying result
// for the request identified by id, and returns any encoding or write error.
func writeRPCResult(encoder *json.Encoder, id json.RawMessage, result any) error {
	response := rpcResponse{JSONRPC: jsonRPCVersion}
	response.ID = id
	response.Result = result
	return encoder.Encode(response)
}
// writeRPCError encodes a JSON-RPC 2.0 error response with the given code,
// message, and optional data, and returns any encoding or write error.
//
// JSON-RPC 2.0 requires the id member on every response. It must be null
// when the request id could not be determined (for example on a Parse error
// or Invalid Request). rpcResponse.ID is tagged omitempty, so an empty id
// would drop the member entirely; it is therefore replaced with a literal
// null.
func writeRPCError(encoder *json.Encoder, id json.RawMessage, code int, message string, data any) error {
	if len(id) == 0 {
		id = json.RawMessage("null")
	}
	return encoder.Encode(rpcResponse{
		JSONRPC: jsonRPCVersion,
		ID:      id,
		Error: &rpcErrorObject{
			Code:    code,
			Message: message,
			Data:    data,
		},
	})
}