// Package mcpserver exposes a mail service as MCP tools over a JSON-RPC 2.0
// request/response stream.
package mcpserver

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"

	"email-mcp/internal/imapclient"
	"email-mcp/internal/secretstore"
	"email-mcp/internal/secretstore/kwallet"
)

// ErrCredentialsNotConfigured is returned when no credential is stored for the
// default account.
var ErrCredentialsNotConfigured = errors.New("credentials not configured; run `email-mcp setup`")

// Protocol identity reported to clients.
const (
	jsonRPCVersion   = "2.0"
	mcpServerName    = "email-mcp"
	mcpServerVersion = "dev"
)

// Method names handled by the runner.
const (
	mcpMethodInitialize  = "initialize"
	mcpMethodInitialized = "notifications/initialized"
	mcpMethodPing        = "ping"
	mcpMethodToolsList   = "tools/list"
	mcpMethodToolsCall   = "tools/call"
)

// JSON-RPC 2.0 error codes used in error responses.
const (
	jsonRPCParseErrorCode     = -32700
	jsonRPCInvalidRequestCode = -32600
	jsonRPCMethodNotFoundCode = -32601
	jsonRPCInvalidParamsCode  = -32602
	jsonRPCInternalErrorCode  = -32603
)

// supportedProtocolVersions lists the protocol versions the server accepts.
// The first entry is the fallback offered when a client requests a version
// that is not in the list.
var supportedProtocolVersions = []string{
	"2025-03-26",
	"2024-11-05",
}

// MailService is the mail backend the server delegates to. Every call
// receives the credential to use for that call.
type MailService interface {
	ListMailboxes(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error)
	ListMessages(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error)
	GetMessage(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error)
}

// Server combines a credential store with a mail backend.
type Server struct {
	store secretstore.Store
	mail  MailService
}

// Tool describes one callable tool as serialized in a tools/list result.
type Tool struct {
	Name        string         `json:"name"`
	Description string         `json:"description"`
	InputSchema map[string]any `json:"inputSchema,omitempty"`
}

// Runner drives a Server over a stream: requests are read from in and
// responses are written to out.
type Runner struct {
	server Server
	in     io.Reader
	out    io.Writer
	errOut io.Writer
}

// toolRequest pairs a tool name with its raw, not yet decoded, arguments.
type toolRequest struct {
	Tool      string          `json:"tool"`
	Arguments json.RawMessage `json:"arguments,omitempty"`
}

// rpcRequest is an incoming JSON-RPC message. ID and Params stay raw so they
// can be echoed back or decoded per method.
type rpcRequest struct {
	JSONRPC string          `json:"jsonrpc"`
	ID      json.RawMessage `json:"id,omitempty"`
	Method  string          `json:"method"`
	Params  json.RawMessage `json:"params,omitempty"`
}

// rpcResponse is an outgoing JSON-RPC message carrying either Result or Error.
type rpcResponse struct {
	JSONRPC string          `json:"jsonrpc"`
	ID      json.RawMessage `json:"id,omitempty"`
	Result  any             `json:"result,omitempty"`
	Error   *rpcErrorObject `json:"error,omitempty"`
}

// rpcErrorObject is the "error" member of a JSON-RPC error response.
type rpcErrorObject struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Data    any    `json:"data,omitempty"`
}
type initializeParams struct { ProtocolVersion string `json:"protocolVersion"` Capabilities map[string]any `json:"capabilities"` ClientInfo struct { Name string `json:"name"` Version string `json:"version"` } `json:"clientInfo"` } type toolsCallParams struct { Name string `json:"name"` Arguments json.RawMessage `json:"arguments,omitempty"` } type invalidParamsError struct { err error } func (e *invalidParamsError) Error() string { return e.err.Error() } func (e *invalidParamsError) Unwrap() error { return e.err } type listMessagesArguments struct { Mailbox string `json:"mailbox"` Limit *int `json:"limit,omitempty"` } type getMessageArguments struct { Mailbox string `json:"mailbox"` UID *uint32 `json:"uid"` } func New(store secretstore.Store, mail MailService) Server { return Server{ store: store, mail: mail, } } func NewRunner(server Server, in io.Reader, out io.Writer, errOut io.Writer) Runner { if out == nil { out = io.Discard } if errOut == nil { errOut = io.Discard } return Runner{ server: server, in: in, out: out, errOut: errOut, } } func (s Server) Tools() []Tool { return []Tool{ { Name: "list_mailboxes", Description: "List visible IMAP mailboxes for the configured account.", }, { Name: "list_messages", Description: "List recent messages from a mailbox using IMAP UIDs.", InputSchema: map[string]any{ "type": "object", "properties": map[string]any{ "mailbox": map[string]any{ "type": "string", "minLength": 1, "pattern": "\\S", }, "limit": map[string]any{ "type": "integer", "default": imapclient.DefaultListMessagesLimit, "minimum": 1, "maximum": imapclient.MaxListMessagesLimit, }, }, "required": []string{"mailbox"}, "additionalProperties": false, }, }, { Name: "get_message", Description: "Fetch a single message by mailbox and IMAP UID.", InputSchema: map[string]any{ "type": "object", "properties": map[string]any{ "mailbox": map[string]any{ "type": "string", "minLength": 1, "pattern": "\\S", }, "uid": map[string]any{ "type": "integer", "minimum": 1, }, }, 
"required": []string{"mailbox", "uid"}, "additionalProperties": false, }, }, } } func (s Server) ListMailboxes(ctx context.Context) ([]imapclient.Mailbox, error) { cred, err := s.loadCredential(ctx) if err != nil { return nil, err } return s.listMailboxes(ctx, cred) } func (s Server) ListMessages(ctx context.Context, mailbox string, limit int) ([]imapclient.MessageSummary, error) { cred, err := s.loadCredential(ctx) if err != nil { return nil, err } return s.listMessages(ctx, cred, mailbox, limit) } func (s Server) GetMessage(ctx context.Context, mailbox string, uid uint32) (imapclient.Message, error) { cred, err := s.loadCredential(ctx) if err != nil { return imapclient.Message{}, err } return s.getMessage(ctx, cred, mailbox, uid) } func (r Runner) Run(ctx context.Context) error { stopCancelRead := r.closeInputOnCancel(ctx) defer stopCancelRead() encoder := json.NewEncoder(r.out) if r.in == nil { return nil } decoder := json.NewDecoder(r.in) session := runnerSession{} for { if err := ctx.Err(); err != nil { return err } var request rpcRequest if err := decoder.Decode(&request); err != nil { if ctxErr := ctx.Err(); ctxErr != nil { return ctxErr } if errors.Is(err, io.EOF) { return nil } if writeErr := writeRPCError(encoder, nil, jsonRPCParseErrorCode, err.Error(), nil); writeErr != nil { return writeErr } return nil } if err := r.handleRPCRequest(ctx, encoder, &session, request); err != nil { if errors.Is(err, context.Canceled) { return err } if writeErr := writeRPCError(encoder, request.ID, jsonRPCInternalErrorCode, err.Error(), nil); writeErr != nil { return writeErr } continue } } } type runnerSession struct { initialized bool ready bool protocolVersion string } func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, request rpcRequest) error { if request.JSONRPC != jsonRPCVersion { return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "jsonrpc must be 2.0", nil) } if request.Method == "" { return 
writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "method is required", nil) } isNotification := len(request.ID) == 0 if !session.initialized { switch request.Method { case mcpMethodInitialize: if isNotification { return writeRPCError(encoder, nil, jsonRPCInvalidRequestCode, "initialize must be a request", nil) } return r.handleInitialize(encoder, session, request) case mcpMethodPing: if isNotification { return nil } return writeRPCResult(encoder, request.ID, map[string]any{}) default: if isNotification { return nil } return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "server not initialized", nil) } } switch request.Method { case mcpMethodInitialized: session.ready = true return nil case mcpMethodPing: if isNotification { return nil } return writeRPCResult(encoder, request.ID, map[string]any{}) case mcpMethodToolsList: if isNotification { return nil } return writeRPCResult(encoder, request.ID, map[string]any{"tools": r.server.Tools()}) case mcpMethodToolsCall: if isNotification { return nil } return r.handleToolsCall(ctx, encoder, request) default: if isNotification { return nil } return writeRPCError(encoder, request.ID, jsonRPCMethodNotFoundCode, fmt.Sprintf("method not found: %s", request.Method), nil) } } func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession, request rpcRequest) error { var params initializeParams if err := decodeParams(request.Params, ¶ms); err != nil { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil) } if params.ProtocolVersion == "" { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "protocolVersion is required", nil) } negotiatedVersion := negotiateProtocolVersion(params.ProtocolVersion) session.initialized = true session.protocolVersion = negotiatedVersion return writeRPCResult(encoder, request.ID, map[string]any{ "protocolVersion": negotiatedVersion, "capabilities": map[string]any{ "tools": map[string]any{}, }, "serverInfo": 
map[string]any{ "name": mcpServerName, "version": mcpServerVersion, }, }) } func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, request rpcRequest) error { var params toolsCallParams if err := decodeParams(request.Params, ¶ms); err != nil { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil) } if strings.TrimSpace(params.Name) == "" { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "name is required", nil) } cred, err := r.server.loadCredential(ctx) if err != nil { return writeRPCResult(encoder, request.ID, map[string]any{ "content": []map[string]any{ { "type": "text", "text": err.Error(), }, }, "isError": true, }) } result, err := r.server.handleTool(ctx, cred, params.Name, params.Arguments) if err != nil { var invalidErr *invalidParamsError if errors.As(err, &invalidErr) { return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, invalidErr.Error(), nil) } return writeRPCResult(encoder, request.ID, map[string]any{ "content": []map[string]any{ { "type": "text", "text": err.Error(), }, }, "isError": true, }) } payload, err := json.Marshal(result) if err != nil { return err } return writeRPCResult(encoder, request.ID, map[string]any{ "content": []map[string]any{ { "type": "text", "text": string(payload), }, }, "isError": false, }) } func (s Server) handleTool(ctx context.Context, cred secretstore.Credential, name string, rawArgs json.RawMessage) (any, error) { switch name { case "list_mailboxes": return s.listMailboxes(ctx, cred) case "list_messages": var args listMessagesArguments if err := decodeArguments(rawArgs, &args); err != nil { return nil, &invalidParamsError{err: err} } mailbox, err := validateMailbox(args.Mailbox) if err != nil { return nil, &invalidParamsError{err: err} } limit, err := normalizeListMessagesLimit(args.Limit) if err != nil { return nil, &invalidParamsError{err: err} } return s.listMessages(ctx, cred, mailbox, limit) case "get_message": var args 
getMessageArguments if err := decodeArguments(rawArgs, &args); err != nil { return nil, &invalidParamsError{err: err} } mailbox, err := validateMailbox(args.Mailbox) if err != nil { return nil, &invalidParamsError{err: err} } uid, err := validateMessageUID(args.UID) if err != nil { return nil, &invalidParamsError{err: err} } return s.getMessage(ctx, cred, mailbox, uid) default: return nil, &invalidParamsError{err: fmt.Errorf("unknown tool: %s", name)} } } func (s Server) listMailboxes(ctx context.Context, cred secretstore.Credential) ([]imapclient.Mailbox, error) { if s.mail == nil { return nil, fmt.Errorf("mail service is not configured") } return s.mail.ListMailboxes(ctx, cred) } func (s Server) listMessages(ctx context.Context, cred secretstore.Credential, mailbox string, limit int) ([]imapclient.MessageSummary, error) { if s.mail == nil { return nil, fmt.Errorf("mail service is not configured") } mailbox, err := validateMailbox(mailbox) if err != nil { return nil, err } return s.mail.ListMessages(ctx, cred, mailbox, limit) } func (s Server) getMessage(ctx context.Context, cred secretstore.Credential, mailbox string, uid uint32) (imapclient.Message, error) { if s.mail == nil { return imapclient.Message{}, fmt.Errorf("mail service is not configured") } mailbox, err := validateMailbox(mailbox) if err != nil { return imapclient.Message{}, err } if uid == 0 { return imapclient.Message{}, fmt.Errorf("uid must be greater than zero") } return s.mail.GetMessage(ctx, cred, mailbox, uid) } func (s Server) loadCredential(ctx context.Context) (secretstore.Credential, error) { if s.store == nil { return secretstore.Credential{}, fmt.Errorf("secret store is not configured") } cred, err := s.store.Load(ctx, secretstore.DefaultAccountKey) if err != nil { if errors.Is(err, kwallet.ErrCredentialNotFound) || errors.Is(err, ErrCredentialsNotConfigured) { return secretstore.Credential{}, ErrCredentialsNotConfigured } return secretstore.Credential{}, err } if err := cred.Validate(); 
err != nil { return secretstore.Credential{}, fmt.Errorf("default credential is invalid: %w", err) } return cred, nil } func decodeParams(raw json.RawMessage, dest any) error { if len(raw) == 0 { raw = []byte("{}") } return json.Unmarshal(raw, dest) } func decodeArguments(raw json.RawMessage, dest any) error { if len(raw) == 0 { raw = []byte("{}") } decoder := json.NewDecoder(bytes.NewReader(raw)) decoder.DisallowUnknownFields() if err := decoder.Decode(dest); err != nil { return normalizeDecodeError(err) } if decoder.More() { return fmt.Errorf("invalid JSON object") } return nil } func normalizeDecodeError(err error) error { var syntaxErr *json.SyntaxError var typeErr *json.UnmarshalTypeError switch { case errors.As(err, &syntaxErr): return fmt.Errorf("invalid JSON arguments") case errors.As(err, &typeErr): if typeErr.Field != "" { return fmt.Errorf("%s has an invalid type", typeErr.Field) } return fmt.Errorf("arguments have an invalid type") default: return fmt.Errorf("invalid arguments: %w", err) } } func validateMailbox(mailbox string) (string, error) { mailbox = strings.TrimSpace(mailbox) if mailbox == "" { return "", fmt.Errorf("mailbox is required") } return mailbox, nil } func normalizeListMessagesLimit(limit *int) (int, error) { if limit == nil { return imapclient.DefaultListMessagesLimit, nil } if *limit < 1 || *limit > imapclient.MaxListMessagesLimit { return 0, fmt.Errorf("limit must be between 1 and %d", imapclient.MaxListMessagesLimit) } return *limit, nil } func validateMessageUID(uid *uint32) (uint32, error) { if uid == nil { return 0, fmt.Errorf("uid is required") } if *uid == 0 { return 0, fmt.Errorf("uid must be greater than zero") } return *uid, nil } func (r Runner) closeInputOnCancel(ctx context.Context) func() { closer, ok := r.in.(io.Closer) if !ok { return func() {} } done := make(chan struct{}) go func() { select { case <-ctx.Done(): _ = closer.Close() case <-done: } }() return func() { close(done) } } func 
negotiateProtocolVersion(requested string) string { for _, supported := range supportedProtocolVersions { if requested == supported { return supported } } return supportedProtocolVersions[0] } func writeRPCResult(encoder *json.Encoder, id json.RawMessage, result any) error { return encoder.Encode(rpcResponse{ JSONRPC: jsonRPCVersion, ID: id, Result: result, }) } func writeRPCError(encoder *json.Encoder, id json.RawMessage, code int, message string, data any) error { return encoder.Encode(rpcResponse{ JSONRPC: jsonRPCVersion, ID: id, Error: &rpcErrorObject{ Code: code, Message: message, Data: data, }, }) }