Claude Code sends extra fields (e.g. "title") in initialize params, which caused the server to reject the request because protocol messages were decoded with DisallowUnknownFields. Use lenient JSON decoding for protocol messages while keeping strict validation for tool arguments. Also defer KWallet credential loading from server startup to tool invocation time, and negotiate protocol versions per the MCP spec instead of rejecting unknown ones. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
609 lines
16 KiB
Go
609 lines
16 KiB
Go
package mcpserver
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
|
|
"email-mcp/internal/imapclient"
|
|
"email-mcp/internal/secretstore"
|
|
"email-mcp/internal/secretstore/kwallet"
|
|
)
|
|
|
|
// ErrCredentialsNotConfigured reports that no credential is stored for the
// default account. The message points the user at the setup command.
var ErrCredentialsNotConfigured = errors.New("credentials not configured; run `email-mcp setup`")

// Server identity reported to clients in the initialize response.
const (
	mcpServerName    = "email-mcp"
	mcpServerVersion = "dev"
)

// JSON-RPC envelope version and the MCP method names the runner dispatches on.
const (
	jsonRPCVersion = "2.0"

	mcpMethodInitialize  = "initialize"
	mcpMethodInitialized = "notifications/initialized"
	mcpMethodPing        = "ping"
	mcpMethodToolsList   = "tools/list"
	mcpMethodToolsCall   = "tools/call"
)

// Standard JSON-RPC 2.0 error codes used in error responses.
const (
	jsonRPCParseErrorCode     = -32700
	jsonRPCInvalidRequestCode = -32600
	jsonRPCMethodNotFoundCode = -32601
	jsonRPCInvalidParamsCode  = -32602
	jsonRPCInternalErrorCode  = -32603
)

// supportedProtocolVersions lists the MCP protocol revisions this server
// accepts, newest first. When a client requests a revision that is not listed,
// negotiateProtocolVersion answers with the first entry.
var supportedProtocolVersions = []string{
	"2025-03-26",
	"2024-11-05",
}
|
|
|
|
type MailService interface {
|
|
ListMailboxes(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error)
|
|
ListMessages(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error)
|
|
GetMessage(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error)
|
|
}
|
|
|
|
type Server struct {
|
|
store secretstore.Store
|
|
mail MailService
|
|
}
|
|
|
|
type Tool struct {
|
|
Name string `json:"name"`
|
|
Description string `json:"description"`
|
|
InputSchema map[string]any `json:"inputSchema,omitempty"`
|
|
}
|
|
|
|
type Runner struct {
|
|
server Server
|
|
in io.Reader
|
|
out io.Writer
|
|
errOut io.Writer
|
|
}
|
|
|
|
// Protocol-layer wire types. Field order and JSON tags determine the encoded
// shape of messages, so they are part of the protocol surface.
type (
	// toolRequest pairs a tool name with its raw, still-undecoded arguments.
	// NOTE(review): not referenced in this file; presumably kept for other
	// files in the package — confirm before removing.
	toolRequest struct {
		Tool      string          `json:"tool"`
		Arguments json.RawMessage `json:"arguments,omitempty"`
	}

	// rpcRequest is one inbound JSON-RPC 2.0 message. ID and Params are kept
	// raw: ID is echoed back verbatim in the response, and an empty ID marks
	// the message as a notification. Params are decoded per method.
	rpcRequest struct {
		JSONRPC string          `json:"jsonrpc"`
		ID      json.RawMessage `json:"id,omitempty"`
		Method  string          `json:"method"`
		Params  json.RawMessage `json:"params,omitempty"`
	}

	// rpcResponse is one outbound JSON-RPC 2.0 message. The writer helpers set
	// either Result or Error, never both.
	rpcResponse struct {
		JSONRPC string          `json:"jsonrpc"`
		ID      json.RawMessage `json:"id,omitempty"`
		Result  any             `json:"result,omitempty"`
		Error   *rpcErrorObject `json:"error,omitempty"`
	}

	// rpcErrorObject is the JSON-RPC 2.0 "error" member of a response.
	rpcErrorObject struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Data    any    `json:"data,omitempty"`
	}

	// initializeParams is the subset of MCP initialize params the server reads.
	// It is decoded leniently, so extra fields sent by clients are ignored.
	initializeParams struct {
		ProtocolVersion string         `json:"protocolVersion"`
		Capabilities    map[string]any `json:"capabilities"`
		ClientInfo      struct {
			Name    string `json:"name"`
			Version string `json:"version"`
		} `json:"clientInfo"`
	}

	// toolsCallParams is the tools/call envelope. Arguments stay raw so each
	// tool can decode them strictly into its own argument struct.
	toolsCallParams struct {
		Name      string          `json:"name"`
		Arguments json.RawMessage `json:"arguments,omitempty"`
	}
)
|
|
|
|
// invalidParamsError marks a tools/call failure caused by the caller's input,
// such as bad arguments or an unknown tool name. handleToolsCall detects it
// with errors.As and reports it as a JSON-RPC Invalid Params error instead of
// a tool-level error result.
type invalidParamsError struct {
	err error
}

// Error returns the wrapped error's message unchanged.
func (e *invalidParamsError) Error() string { return e.err.Error() }

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *invalidParamsError) Unwrap() error { return e.err }
|
|
|
|
// Argument shapes for the tools that take input. decodeArguments rejects any
// field not declared here. Pointer fields distinguish an omitted value from an
// explicit zero.
type (
	// listMessagesArguments are the list_messages arguments; a nil Limit
	// selects the default limit.
	listMessagesArguments struct {
		Mailbox string `json:"mailbox"`
		Limit   *int   `json:"limit,omitempty"`
	}

	// getMessageArguments are the get_message arguments; a nil UID is
	// rejected as missing.
	getMessageArguments struct {
		Mailbox string  `json:"mailbox"`
		UID     *uint32 `json:"uid"`
	}
)
|
|
|
|
func New(store secretstore.Store, mail MailService) Server {
|
|
return Server{
|
|
store: store,
|
|
mail: mail,
|
|
}
|
|
}
|
|
|
|
func NewRunner(server Server, in io.Reader, out io.Writer, errOut io.Writer) Runner {
|
|
if out == nil {
|
|
out = io.Discard
|
|
}
|
|
if errOut == nil {
|
|
errOut = io.Discard
|
|
}
|
|
|
|
return Runner{
|
|
server: server,
|
|
in: in,
|
|
out: out,
|
|
errOut: errOut,
|
|
}
|
|
}
|
|
|
|
func (s Server) Tools() []Tool {
|
|
return []Tool{
|
|
{
|
|
Name: "list_mailboxes",
|
|
Description: "List visible IMAP mailboxes for the configured account.",
|
|
},
|
|
{
|
|
Name: "list_messages",
|
|
Description: "List recent messages from a mailbox using IMAP UIDs.",
|
|
InputSchema: map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{
|
|
"mailbox": map[string]any{
|
|
"type": "string",
|
|
"minLength": 1,
|
|
"pattern": "\\S",
|
|
},
|
|
"limit": map[string]any{
|
|
"type": "integer",
|
|
"default": imapclient.DefaultListMessagesLimit,
|
|
"minimum": 1,
|
|
"maximum": imapclient.MaxListMessagesLimit,
|
|
},
|
|
},
|
|
"required": []string{"mailbox"},
|
|
"additionalProperties": false,
|
|
},
|
|
},
|
|
{
|
|
Name: "get_message",
|
|
Description: "Fetch a single message by mailbox and IMAP UID.",
|
|
InputSchema: map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{
|
|
"mailbox": map[string]any{
|
|
"type": "string",
|
|
"minLength": 1,
|
|
"pattern": "\\S",
|
|
},
|
|
"uid": map[string]any{
|
|
"type": "integer",
|
|
"minimum": 1,
|
|
},
|
|
},
|
|
"required": []string{"mailbox", "uid"},
|
|
"additionalProperties": false,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s Server) ListMailboxes(ctx context.Context) ([]imapclient.Mailbox, error) {
|
|
cred, err := s.loadCredential(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s.listMailboxes(ctx, cred)
|
|
}
|
|
|
|
func (s Server) ListMessages(ctx context.Context, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
|
|
cred, err := s.loadCredential(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s.listMessages(ctx, cred, mailbox, limit)
|
|
}
|
|
|
|
func (s Server) GetMessage(ctx context.Context, mailbox string, uid uint32) (imapclient.Message, error) {
|
|
cred, err := s.loadCredential(ctx)
|
|
if err != nil {
|
|
return imapclient.Message{}, err
|
|
}
|
|
return s.getMessage(ctx, cred, mailbox, uid)
|
|
}
|
|
|
|
func (r Runner) Run(ctx context.Context) error {
|
|
stopCancelRead := r.closeInputOnCancel(ctx)
|
|
defer stopCancelRead()
|
|
|
|
encoder := json.NewEncoder(r.out)
|
|
if r.in == nil {
|
|
return nil
|
|
}
|
|
|
|
decoder := json.NewDecoder(r.in)
|
|
session := runnerSession{}
|
|
for {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
var request rpcRequest
|
|
if err := decoder.Decode(&request); err != nil {
|
|
if ctxErr := ctx.Err(); ctxErr != nil {
|
|
return ctxErr
|
|
}
|
|
if errors.Is(err, io.EOF) {
|
|
return nil
|
|
}
|
|
if writeErr := writeRPCError(encoder, nil, jsonRPCParseErrorCode, err.Error(), nil); writeErr != nil {
|
|
return writeErr
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if err := r.handleRPCRequest(ctx, encoder, &session, request); err != nil {
|
|
if errors.Is(err, context.Canceled) {
|
|
return err
|
|
}
|
|
if writeErr := writeRPCError(encoder, request.ID, jsonRPCInternalErrorCode, err.Error(), nil); writeErr != nil {
|
|
return writeErr
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
// runnerSession holds MCP lifecycle state for a single Run invocation.
type runnerSession struct {
	// initialized is set by handleInitialize once the initialize params
	// validate. Until then only initialize and ping are served.
	initialized bool

	// ready records that notifications/initialized arrived after
	// initialization.
	// NOTE(review): nothing in this file reads it yet.
	ready bool

	// protocolVersion is the MCP revision negotiated during initialize.
	protocolVersion string
}
|
|
|
|
func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, request rpcRequest) error {
|
|
if request.JSONRPC != jsonRPCVersion {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "jsonrpc must be 2.0", nil)
|
|
}
|
|
if request.Method == "" {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "method is required", nil)
|
|
}
|
|
|
|
isNotification := len(request.ID) == 0
|
|
if !session.initialized {
|
|
switch request.Method {
|
|
case mcpMethodInitialize:
|
|
if isNotification {
|
|
return writeRPCError(encoder, nil, jsonRPCInvalidRequestCode, "initialize must be a request", nil)
|
|
}
|
|
return r.handleInitialize(encoder, session, request)
|
|
case mcpMethodPing:
|
|
if isNotification {
|
|
return nil
|
|
}
|
|
return writeRPCResult(encoder, request.ID, map[string]any{})
|
|
default:
|
|
if isNotification {
|
|
return nil
|
|
}
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "server not initialized", nil)
|
|
}
|
|
}
|
|
|
|
switch request.Method {
|
|
case mcpMethodInitialized:
|
|
session.ready = true
|
|
return nil
|
|
case mcpMethodPing:
|
|
if isNotification {
|
|
return nil
|
|
}
|
|
return writeRPCResult(encoder, request.ID, map[string]any{})
|
|
case mcpMethodToolsList:
|
|
if isNotification {
|
|
return nil
|
|
}
|
|
return writeRPCResult(encoder, request.ID, map[string]any{"tools": r.server.Tools()})
|
|
case mcpMethodToolsCall:
|
|
if isNotification {
|
|
return nil
|
|
}
|
|
return r.handleToolsCall(ctx, encoder, request)
|
|
default:
|
|
if isNotification {
|
|
return nil
|
|
}
|
|
return writeRPCError(encoder, request.ID, jsonRPCMethodNotFoundCode, fmt.Sprintf("method not found: %s", request.Method), nil)
|
|
}
|
|
}
|
|
|
|
func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession, request rpcRequest) error {
|
|
var params initializeParams
|
|
if err := decodeParams(request.Params, ¶ms); err != nil {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil)
|
|
}
|
|
if params.ProtocolVersion == "" {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "protocolVersion is required", nil)
|
|
}
|
|
|
|
negotiatedVersion := negotiateProtocolVersion(params.ProtocolVersion)
|
|
|
|
session.initialized = true
|
|
session.protocolVersion = negotiatedVersion
|
|
|
|
return writeRPCResult(encoder, request.ID, map[string]any{
|
|
"protocolVersion": negotiatedVersion,
|
|
"capabilities": map[string]any{
|
|
"tools": map[string]any{},
|
|
},
|
|
"serverInfo": map[string]any{
|
|
"name": mcpServerName,
|
|
"version": mcpServerVersion,
|
|
},
|
|
})
|
|
}
|
|
|
|
func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, request rpcRequest) error {
|
|
var params toolsCallParams
|
|
if err := decodeParams(request.Params, ¶ms); err != nil {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil)
|
|
}
|
|
if strings.TrimSpace(params.Name) == "" {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "name is required", nil)
|
|
}
|
|
|
|
cred, err := r.server.loadCredential(ctx)
|
|
if err != nil {
|
|
return writeRPCResult(encoder, request.ID, map[string]any{
|
|
"content": []map[string]any{
|
|
{
|
|
"type": "text",
|
|
"text": err.Error(),
|
|
},
|
|
},
|
|
"isError": true,
|
|
})
|
|
}
|
|
|
|
result, err := r.server.handleTool(ctx, cred, params.Name, params.Arguments)
|
|
if err != nil {
|
|
var invalidErr *invalidParamsError
|
|
if errors.As(err, &invalidErr) {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, invalidErr.Error(), nil)
|
|
}
|
|
return writeRPCResult(encoder, request.ID, map[string]any{
|
|
"content": []map[string]any{
|
|
{
|
|
"type": "text",
|
|
"text": err.Error(),
|
|
},
|
|
},
|
|
"isError": true,
|
|
})
|
|
}
|
|
|
|
payload, err := json.Marshal(result)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return writeRPCResult(encoder, request.ID, map[string]any{
|
|
"content": []map[string]any{
|
|
{
|
|
"type": "text",
|
|
"text": string(payload),
|
|
},
|
|
},
|
|
"isError": false,
|
|
})
|
|
}
|
|
|
|
func (s Server) handleTool(ctx context.Context, cred secretstore.Credential, name string, rawArgs json.RawMessage) (any, error) {
|
|
switch name {
|
|
case "list_mailboxes":
|
|
return s.listMailboxes(ctx, cred)
|
|
case "list_messages":
|
|
var args listMessagesArguments
|
|
if err := decodeArguments(rawArgs, &args); err != nil {
|
|
return nil, &invalidParamsError{err: err}
|
|
}
|
|
mailbox, err := validateMailbox(args.Mailbox)
|
|
if err != nil {
|
|
return nil, &invalidParamsError{err: err}
|
|
}
|
|
limit, err := normalizeListMessagesLimit(args.Limit)
|
|
if err != nil {
|
|
return nil, &invalidParamsError{err: err}
|
|
}
|
|
return s.listMessages(ctx, cred, mailbox, limit)
|
|
case "get_message":
|
|
var args getMessageArguments
|
|
if err := decodeArguments(rawArgs, &args); err != nil {
|
|
return nil, &invalidParamsError{err: err}
|
|
}
|
|
mailbox, err := validateMailbox(args.Mailbox)
|
|
if err != nil {
|
|
return nil, &invalidParamsError{err: err}
|
|
}
|
|
uid, err := validateMessageUID(args.UID)
|
|
if err != nil {
|
|
return nil, &invalidParamsError{err: err}
|
|
}
|
|
return s.getMessage(ctx, cred, mailbox, uid)
|
|
default:
|
|
return nil, &invalidParamsError{err: fmt.Errorf("unknown tool: %s", name)}
|
|
}
|
|
}
|
|
|
|
func (s Server) listMailboxes(ctx context.Context, cred secretstore.Credential) ([]imapclient.Mailbox, error) {
|
|
if s.mail == nil {
|
|
return nil, fmt.Errorf("mail service is not configured")
|
|
}
|
|
return s.mail.ListMailboxes(ctx, cred)
|
|
}
|
|
|
|
func (s Server) listMessages(ctx context.Context, cred secretstore.Credential, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
|
|
if s.mail == nil {
|
|
return nil, fmt.Errorf("mail service is not configured")
|
|
}
|
|
mailbox, err := validateMailbox(mailbox)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s.mail.ListMessages(ctx, cred, mailbox, limit)
|
|
}
|
|
|
|
func (s Server) getMessage(ctx context.Context, cred secretstore.Credential, mailbox string, uid uint32) (imapclient.Message, error) {
|
|
if s.mail == nil {
|
|
return imapclient.Message{}, fmt.Errorf("mail service is not configured")
|
|
}
|
|
mailbox, err := validateMailbox(mailbox)
|
|
if err != nil {
|
|
return imapclient.Message{}, err
|
|
}
|
|
if uid == 0 {
|
|
return imapclient.Message{}, fmt.Errorf("uid must be greater than zero")
|
|
}
|
|
return s.mail.GetMessage(ctx, cred, mailbox, uid)
|
|
}
|
|
|
|
func (s Server) loadCredential(ctx context.Context) (secretstore.Credential, error) {
|
|
if s.store == nil {
|
|
return secretstore.Credential{}, fmt.Errorf("secret store is not configured")
|
|
}
|
|
|
|
cred, err := s.store.Load(ctx, secretstore.DefaultAccountKey)
|
|
if err != nil {
|
|
if errors.Is(err, kwallet.ErrCredentialNotFound) || errors.Is(err, ErrCredentialsNotConfigured) {
|
|
return secretstore.Credential{}, ErrCredentialsNotConfigured
|
|
}
|
|
return secretstore.Credential{}, err
|
|
}
|
|
if err := cred.Validate(); err != nil {
|
|
return secretstore.Credential{}, fmt.Errorf("default credential is invalid: %w", err)
|
|
}
|
|
return cred, nil
|
|
}
|
|
|
|
// decodeParams decodes JSON-RPC params for protocol-level methods (initialize
// and the tools/call envelope). Decoding is deliberately lenient: unknown
// fields are ignored, so clients that send extra metadata (for example
// "title") are not rejected. Absent params decode as an empty object.
func decodeParams(raw json.RawMessage, dest any) error {
	if len(raw) == 0 {
		raw = []byte("{}")
	}
	return json.Unmarshal(raw, dest)
}

// decodeArguments strictly decodes tool arguments into dest.
//
// Unknown fields are rejected, and exactly one JSON value may be present;
// anything other than trailing whitespace after it is an error. Absent
// arguments decode as an empty object. Errors are normalised into short,
// user-facing messages by normalizeDecodeError.
func decodeArguments(raw json.RawMessage, dest any) error {
	if len(raw) == 0 {
		raw = []byte("{}")
	}

	decoder := json.NewDecoder(bytes.NewReader(raw))
	decoder.DisallowUnknownFields()
	if err := decoder.Decode(dest); err != nil {
		return normalizeDecodeError(err)
	}

	// Only whitespace may follow the value. Decoder.More is not enough here:
	// it reports false for a stray ']' or '}', so input such as
	// `{"mailbox":"INBOX"}}` would slip through. Token reports io.EOF only
	// when the input is truly exhausted.
	if _, err := decoder.Token(); !errors.Is(err, io.EOF) {
		return fmt.Errorf("invalid JSON object")
	}
	return nil
}

// normalizeDecodeError maps encoding/json errors onto stable messages that are
// safe to return to the client.
//
// Malformed or truncated JSON becomes "invalid JSON arguments". A type
// mismatch names the offending field when known. Anything else, for example
// an unknown field, is wrapped with an "invalid arguments" prefix.
func normalizeDecodeError(err error) error {
	var syntaxErr *json.SyntaxError
	var typeErr *json.UnmarshalTypeError

	switch {
	case errors.As(err, &syntaxErr), errors.Is(err, io.ErrUnexpectedEOF):
		// Truncated input surfaces from json.Decoder as io.ErrUnexpectedEOF
		// rather than a *json.SyntaxError; both mean the input is not JSON.
		return fmt.Errorf("invalid JSON arguments")
	case errors.As(err, &typeErr):
		if typeErr.Field != "" {
			return fmt.Errorf("%s has an invalid type", typeErr.Field)
		}
		return fmt.Errorf("arguments have an invalid type")
	default:
		return fmt.Errorf("invalid arguments: %w", err)
	}
}
|
|
|
|
// validateMailbox trims surrounding whitespace from mailbox and rejects a
// result that is empty.
func validateMailbox(mailbox string) (string, error) {
	if trimmed := strings.TrimSpace(mailbox); trimmed != "" {
		return trimmed, nil
	}
	return "", fmt.Errorf("mailbox is required")
}
|
|
|
|
func normalizeListMessagesLimit(limit *int) (int, error) {
|
|
if limit == nil {
|
|
return imapclient.DefaultListMessagesLimit, nil
|
|
}
|
|
if *limit < 1 || *limit > imapclient.MaxListMessagesLimit {
|
|
return 0, fmt.Errorf("limit must be between 1 and %d", imapclient.MaxListMessagesLimit)
|
|
}
|
|
return *limit, nil
|
|
}
|
|
|
|
// validateMessageUID requires uid to be present and non-zero, and returns its
// value.
func validateMessageUID(uid *uint32) (uint32, error) {
	switch {
	case uid == nil:
		return 0, fmt.Errorf("uid is required")
	case *uid == 0:
		return 0, fmt.Errorf("uid must be greater than zero")
	default:
		return *uid, nil
	}
}
|
|
|
|
func (r Runner) closeInputOnCancel(ctx context.Context) func() {
|
|
closer, ok := r.in.(io.Closer)
|
|
if !ok {
|
|
return func() {}
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = closer.Close()
|
|
case <-done:
|
|
}
|
|
}()
|
|
|
|
return func() {
|
|
close(done)
|
|
}
|
|
}
|
|
|
|
func negotiateProtocolVersion(requested string) string {
|
|
for _, supported := range supportedProtocolVersions {
|
|
if requested == supported {
|
|
return supported
|
|
}
|
|
}
|
|
return supportedProtocolVersions[0]
|
|
}
|
|
|
|
func writeRPCResult(encoder *json.Encoder, id json.RawMessage, result any) error {
|
|
return encoder.Encode(rpcResponse{
|
|
JSONRPC: jsonRPCVersion,
|
|
ID: id,
|
|
Result: result,
|
|
})
|
|
}
|
|
|
|
func writeRPCError(encoder *json.Encoder, id json.RawMessage, code int, message string, data any) error {
|
|
return encoder.Encode(rpcResponse{
|
|
JSONRPC: jsonRPCVersion,
|
|
ID: id,
|
|
Error: &rpcErrorObject{
|
|
Code: code,
|
|
Message: message,
|
|
Data: data,
|
|
},
|
|
})
|
|
}
|