614 lines
16 KiB
Go
614 lines
16 KiB
Go
package mcpserver
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
|
|
"email-mcp/internal/imapclient"
|
|
"email-mcp/internal/secretstore"
|
|
"email-mcp/internal/secretstore/kwallet"
|
|
)
|
|
|
|
var (
	// ErrCredentialsNotConfigured is returned when the secret store holds no
	// credential for the default account. The message tells the user how to
	// create one.
	ErrCredentialsNotConfigured = errors.New("credentials not configured; run `email-mcp setup`")
)
|
|
|
|
// Protocol identity and the MCP method names dispatched by handleRPCRequest.
const (
	jsonRPCVersion = "2.0"

	mcpServerName    = "email-mcp"
	mcpServerVersion = "dev"

	mcpMethodInitialize  = "initialize"
	mcpMethodInitialized = "notifications/initialized"
	mcpMethodPing        = "ping"
	mcpMethodToolsList   = "tools/list"
	mcpMethodToolsCall   = "tools/call"
)

// Error codes placed in JSON-RPC error responses; the values are the
// standard codes reserved by the JSON-RPC 2.0 specification.
const (
	jsonRPCParseErrorCode     = -32700
	jsonRPCInvalidRequestCode = -32600
	jsonRPCMethodNotFoundCode = -32601
	jsonRPCInvalidParamsCode  = -32602
	jsonRPCInternalErrorCode  = -32603
)
|
|
|
|
// supportedProtocolVersions lists the MCP protocol revisions this server
// accepts, newest first. The first entry is the server's preferred version:
// negotiateProtocolVersion answers with it when the client asks for a
// revision that is not listed here.
var supportedProtocolVersions = []string{"2025-03-26", "2024-11-05"}
|
|
|
|
type MailService interface {
|
|
ListMailboxes(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error)
|
|
ListMessages(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error)
|
|
GetMessage(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error)
|
|
}
|
|
|
|
type Server struct {
|
|
store secretstore.Store
|
|
mail MailService
|
|
}
|
|
|
|
// Tool is one entry of the tools/list result.
type Tool struct {
	// Name is the identifier clients pass as "name" in tools/call.
	Name string `json:"name"`
	// Description is a human-readable summary of what the tool does.
	Description string `json:"description"`
	// InputSchema is a JSON Schema object for the tool's arguments; it is
	// omitted from the JSON output when nil or empty.
	InputSchema map[string]any `json:"inputSchema,omitempty"`
}
|
|
|
|
type Runner struct {
|
|
server Server
|
|
in io.Reader
|
|
out io.Writer
|
|
errOut io.Writer
|
|
}
|
|
|
|
// toolRequest pairs a tool name with its raw JSON arguments.
// NOTE(review): nothing in this file references this type; it may be used
// elsewhere in the package or be left over from an older request format.
type toolRequest struct {
	Tool      string          `json:"tool"`
	Arguments json.RawMessage `json:"arguments,omitempty"` // decoded later per tool
}
|
|
|
|
// rpcRequest is one incoming JSON-RPC 2.0 message. handleRPCRequest treats a
// message with an empty ID as a notification.
type rpcRequest struct {
	// JSONRPC must equal "2.0".
	JSONRPC string `json:"jsonrpc"`
	// ID is kept as raw JSON so it can be echoed back exactly as received.
	ID json.RawMessage `json:"id,omitempty"`
	// Method names the operation, e.g. "initialize" or "tools/call".
	Method string `json:"method"`
	// Params is decoded later into a method-specific struct.
	Params json.RawMessage `json:"params,omitempty"`
}
|
|
|
|
type rpcResponse struct {
|
|
JSONRPC string `json:"jsonrpc"`
|
|
ID json.RawMessage `json:"id,omitempty"`
|
|
Result any `json:"result,omitempty"`
|
|
Error *rpcErrorObject `json:"error,omitempty"`
|
|
}
|
|
|
|
// rpcErrorObject is the "error" member of a JSON-RPC 2.0 response.
type rpcErrorObject struct {
	Code    int    `json:"code"`           // one of the jsonRPC*Code constants
	Message string `json:"message"`        // short human-readable description
	Data    any    `json:"data,omitempty"` // optional extra detail
}
|
|
|
|
// initializeParams holds the params of an MCP "initialize" request.
// Only ProtocolVersion is consulted by handleInitialize; the other fields are
// decoded but currently unused.
type initializeParams struct {
	ProtocolVersion string         `json:"protocolVersion"`
	Capabilities    map[string]any `json:"capabilities"`
	ClientInfo      struct {
		Name    string `json:"name"`
		Version string `json:"version"`
	} `json:"clientInfo"`
}
|
|
|
|
// toolsCallParams holds the params of an MCP "tools/call" request.
type toolsCallParams struct {
	Name      string          `json:"name"`                // tool to invoke
	Arguments json.RawMessage `json:"arguments,omitempty"` // decoded per tool by handleTool
}
|
|
|
|
// invalidParamsError tags an error caused by the client's tool call (bad
// arguments or an unknown tool name). handleToolsCall detects it with
// errors.As and answers with a JSON-RPC Invalid Params error instead of a
// tool-level failure result.
type invalidParamsError struct {
	err error // underlying validation error; must be non-nil
}

// Error returns the wrapped error's message unchanged.
func (pe *invalidParamsError) Error() string {
	return pe.err.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (pe *invalidParamsError) Unwrap() error {
	return pe.err
}
|
|
|
|
// listMessagesArguments is the argument object of the list_messages tool.
type listMessagesArguments struct {
	Mailbox string `json:"mailbox"`
	// Limit is a pointer so an absent value (nil) can be told apart from an
	// explicit 0; normalizeListMessagesLimit applies the default for nil.
	Limit *int `json:"limit,omitempty"`
}
|
|
|
|
// getMessageArguments is the argument object of the get_message tool.
type getMessageArguments struct {
	Mailbox string `json:"mailbox"`
	// UID is a pointer so a missing uid (nil) can be reported separately
	// from an explicit 0 by validateMessageUID.
	UID *uint32 `json:"uid"`
}
|
|
|
|
func New(store secretstore.Store, mail MailService) Server {
|
|
return Server{
|
|
store: store,
|
|
mail: mail,
|
|
}
|
|
}
|
|
|
|
func NewRunner(server Server, in io.Reader, out io.Writer, errOut io.Writer) Runner {
|
|
if out == nil {
|
|
out = io.Discard
|
|
}
|
|
if errOut == nil {
|
|
errOut = io.Discard
|
|
}
|
|
|
|
return Runner{
|
|
server: server,
|
|
in: in,
|
|
out: out,
|
|
errOut: errOut,
|
|
}
|
|
}
|
|
|
|
func (s Server) Tools() []Tool {
|
|
return []Tool{
|
|
{
|
|
Name: "list_mailboxes",
|
|
Description: "List visible IMAP mailboxes for the configured account.",
|
|
InputSchema: map[string]any{
|
|
"type": "object",
|
|
},
|
|
},
|
|
{
|
|
Name: "list_messages",
|
|
Description: "List recent messages from a mailbox using IMAP UIDs.",
|
|
InputSchema: map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{
|
|
"mailbox": map[string]any{
|
|
"type": "string",
|
|
"minLength": 1,
|
|
"pattern": "\\S",
|
|
},
|
|
"limit": map[string]any{
|
|
"type": "integer",
|
|
"default": imapclient.DefaultListMessagesLimit,
|
|
"minimum": 1,
|
|
"maximum": imapclient.MaxListMessagesLimit,
|
|
},
|
|
},
|
|
"required": []string{"mailbox"},
|
|
"additionalProperties": false,
|
|
},
|
|
},
|
|
{
|
|
Name: "get_message",
|
|
Description: "Fetch a single message by mailbox and IMAP UID.",
|
|
InputSchema: map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{
|
|
"mailbox": map[string]any{
|
|
"type": "string",
|
|
"minLength": 1,
|
|
"pattern": "\\S",
|
|
},
|
|
"uid": map[string]any{
|
|
"type": "integer",
|
|
"minimum": 1,
|
|
},
|
|
},
|
|
"required": []string{"mailbox", "uid"},
|
|
"additionalProperties": false,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s Server) ListMailboxes(ctx context.Context) ([]imapclient.Mailbox, error) {
|
|
cred, err := s.loadCredential(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s.listMailboxes(ctx, cred)
|
|
}
|
|
|
|
func (s Server) ListMessages(ctx context.Context, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
|
|
cred, err := s.loadCredential(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s.listMessages(ctx, cred, mailbox, limit)
|
|
}
|
|
|
|
func (s Server) GetMessage(ctx context.Context, mailbox string, uid uint32) (imapclient.Message, error) {
|
|
cred, err := s.loadCredential(ctx)
|
|
if err != nil {
|
|
return imapclient.Message{}, err
|
|
}
|
|
return s.getMessage(ctx, cred, mailbox, uid)
|
|
}
|
|
|
|
func (r Runner) Run(ctx context.Context) error {
|
|
stopCancelRead := r.closeInputOnCancel(ctx)
|
|
defer stopCancelRead()
|
|
|
|
encoder := json.NewEncoder(r.out)
|
|
if r.in == nil {
|
|
return nil
|
|
}
|
|
|
|
decoder := json.NewDecoder(r.in)
|
|
session := runnerSession{}
|
|
for {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
var request rpcRequest
|
|
if err := decoder.Decode(&request); err != nil {
|
|
if ctxErr := ctx.Err(); ctxErr != nil {
|
|
return ctxErr
|
|
}
|
|
if errors.Is(err, io.EOF) {
|
|
return nil
|
|
}
|
|
if writeErr := writeRPCError(encoder, nil, jsonRPCParseErrorCode, err.Error(), nil); writeErr != nil {
|
|
return writeErr
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if err := r.handleRPCRequest(ctx, encoder, &session, request); err != nil {
|
|
if errors.Is(err, context.Canceled) {
|
|
return err
|
|
}
|
|
if writeErr := writeRPCError(encoder, request.ID, jsonRPCInternalErrorCode, err.Error(), nil); writeErr != nil {
|
|
return writeErr
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
// runnerSession tracks the MCP lifecycle state of one Run loop.
type runnerSession struct {
	// initialized is set once an initialize request has been answered.
	initialized bool
	// ready is set when the client sends notifications/initialized.
	// NOTE(review): nothing in this file reads it yet.
	ready bool
	// protocolVersion is the version agreed during initialize.
	protocolVersion string
}
|
|
|
|
func (r Runner) handleRPCRequest(ctx context.Context, encoder *json.Encoder, session *runnerSession, request rpcRequest) error {
|
|
if request.JSONRPC != jsonRPCVersion {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "jsonrpc must be 2.0", nil)
|
|
}
|
|
if request.Method == "" {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "method is required", nil)
|
|
}
|
|
|
|
isNotification := len(request.ID) == 0
|
|
if !session.initialized {
|
|
switch request.Method {
|
|
case mcpMethodInitialize:
|
|
if isNotification {
|
|
return writeRPCError(encoder, nil, jsonRPCInvalidRequestCode, "initialize must be a request", nil)
|
|
}
|
|
return r.handleInitialize(encoder, session, request)
|
|
case mcpMethodPing:
|
|
if isNotification {
|
|
return nil
|
|
}
|
|
return writeRPCResult(encoder, request.ID, map[string]any{})
|
|
default:
|
|
if isNotification {
|
|
return nil
|
|
}
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidRequestCode, "server not initialized", nil)
|
|
}
|
|
}
|
|
|
|
switch request.Method {
|
|
case mcpMethodInitialized:
|
|
session.ready = true
|
|
return nil
|
|
case mcpMethodPing:
|
|
if isNotification {
|
|
return nil
|
|
}
|
|
return writeRPCResult(encoder, request.ID, map[string]any{})
|
|
case mcpMethodToolsList:
|
|
if isNotification {
|
|
return nil
|
|
}
|
|
return writeRPCResult(encoder, request.ID, map[string]any{"tools": r.server.Tools()})
|
|
case mcpMethodToolsCall:
|
|
if isNotification {
|
|
return nil
|
|
}
|
|
return r.handleToolsCall(ctx, encoder, request)
|
|
default:
|
|
if isNotification {
|
|
return nil
|
|
}
|
|
return writeRPCError(encoder, request.ID, jsonRPCMethodNotFoundCode, fmt.Sprintf("method not found: %s", request.Method), nil)
|
|
}
|
|
}
|
|
|
|
func (r Runner) handleInitialize(encoder *json.Encoder, session *runnerSession, request rpcRequest) error {
|
|
var params initializeParams
|
|
if err := decodeParams(request.Params, ¶ms); err != nil {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil)
|
|
}
|
|
if params.ProtocolVersion == "" {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "protocolVersion is required", nil)
|
|
}
|
|
|
|
negotiatedVersion := negotiateProtocolVersion(params.ProtocolVersion)
|
|
|
|
session.initialized = true
|
|
session.protocolVersion = negotiatedVersion
|
|
|
|
return writeRPCResult(encoder, request.ID, map[string]any{
|
|
"protocolVersion": negotiatedVersion,
|
|
"capabilities": map[string]any{
|
|
"tools": map[string]any{
|
|
"listChanged": true,
|
|
},
|
|
},
|
|
"serverInfo": map[string]any{
|
|
"name": mcpServerName,
|
|
"version": mcpServerVersion,
|
|
},
|
|
})
|
|
}
|
|
|
|
func (r Runner) handleToolsCall(ctx context.Context, encoder *json.Encoder, request rpcRequest) error {
|
|
var params toolsCallParams
|
|
if err := decodeParams(request.Params, ¶ms); err != nil {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, err.Error(), nil)
|
|
}
|
|
if strings.TrimSpace(params.Name) == "" {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, "name is required", nil)
|
|
}
|
|
|
|
cred, err := r.server.loadCredential(ctx)
|
|
if err != nil {
|
|
return writeRPCResult(encoder, request.ID, map[string]any{
|
|
"content": []map[string]any{
|
|
{
|
|
"type": "text",
|
|
"text": err.Error(),
|
|
},
|
|
},
|
|
"isError": true,
|
|
})
|
|
}
|
|
|
|
result, err := r.server.handleTool(ctx, cred, params.Name, params.Arguments)
|
|
if err != nil {
|
|
var invalidErr *invalidParamsError
|
|
if errors.As(err, &invalidErr) {
|
|
return writeRPCError(encoder, request.ID, jsonRPCInvalidParamsCode, invalidErr.Error(), nil)
|
|
}
|
|
return writeRPCResult(encoder, request.ID, map[string]any{
|
|
"content": []map[string]any{
|
|
{
|
|
"type": "text",
|
|
"text": err.Error(),
|
|
},
|
|
},
|
|
"isError": true,
|
|
})
|
|
}
|
|
|
|
payload, err := json.Marshal(result)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return writeRPCResult(encoder, request.ID, map[string]any{
|
|
"content": []map[string]any{
|
|
{
|
|
"type": "text",
|
|
"text": string(payload),
|
|
},
|
|
},
|
|
"isError": false,
|
|
})
|
|
}
|
|
|
|
func (s Server) handleTool(ctx context.Context, cred secretstore.Credential, name string, rawArgs json.RawMessage) (any, error) {
|
|
switch name {
|
|
case "list_mailboxes":
|
|
return s.listMailboxes(ctx, cred)
|
|
case "list_messages":
|
|
var args listMessagesArguments
|
|
if err := decodeArguments(rawArgs, &args); err != nil {
|
|
return nil, &invalidParamsError{err: err}
|
|
}
|
|
mailbox, err := validateMailbox(args.Mailbox)
|
|
if err != nil {
|
|
return nil, &invalidParamsError{err: err}
|
|
}
|
|
limit, err := normalizeListMessagesLimit(args.Limit)
|
|
if err != nil {
|
|
return nil, &invalidParamsError{err: err}
|
|
}
|
|
return s.listMessages(ctx, cred, mailbox, limit)
|
|
case "get_message":
|
|
var args getMessageArguments
|
|
if err := decodeArguments(rawArgs, &args); err != nil {
|
|
return nil, &invalidParamsError{err: err}
|
|
}
|
|
mailbox, err := validateMailbox(args.Mailbox)
|
|
if err != nil {
|
|
return nil, &invalidParamsError{err: err}
|
|
}
|
|
uid, err := validateMessageUID(args.UID)
|
|
if err != nil {
|
|
return nil, &invalidParamsError{err: err}
|
|
}
|
|
return s.getMessage(ctx, cred, mailbox, uid)
|
|
default:
|
|
return nil, &invalidParamsError{err: fmt.Errorf("unknown tool: %s", name)}
|
|
}
|
|
}
|
|
|
|
func (s Server) listMailboxes(ctx context.Context, cred secretstore.Credential) ([]imapclient.Mailbox, error) {
|
|
if s.mail == nil {
|
|
return nil, fmt.Errorf("mail service is not configured")
|
|
}
|
|
return s.mail.ListMailboxes(ctx, cred)
|
|
}
|
|
|
|
func (s Server) listMessages(ctx context.Context, cred secretstore.Credential, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
|
|
if s.mail == nil {
|
|
return nil, fmt.Errorf("mail service is not configured")
|
|
}
|
|
mailbox, err := validateMailbox(mailbox)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s.mail.ListMessages(ctx, cred, mailbox, limit)
|
|
}
|
|
|
|
func (s Server) getMessage(ctx context.Context, cred secretstore.Credential, mailbox string, uid uint32) (imapclient.Message, error) {
|
|
if s.mail == nil {
|
|
return imapclient.Message{}, fmt.Errorf("mail service is not configured")
|
|
}
|
|
mailbox, err := validateMailbox(mailbox)
|
|
if err != nil {
|
|
return imapclient.Message{}, err
|
|
}
|
|
if uid == 0 {
|
|
return imapclient.Message{}, fmt.Errorf("uid must be greater than zero")
|
|
}
|
|
return s.mail.GetMessage(ctx, cred, mailbox, uid)
|
|
}
|
|
|
|
func (s Server) loadCredential(ctx context.Context) (secretstore.Credential, error) {
|
|
if s.store == nil {
|
|
return secretstore.Credential{}, fmt.Errorf("secret store is not configured")
|
|
}
|
|
|
|
cred, err := s.store.Load(ctx, secretstore.DefaultAccountKey)
|
|
if err != nil {
|
|
if errors.Is(err, kwallet.ErrCredentialNotFound) || errors.Is(err, ErrCredentialsNotConfigured) {
|
|
return secretstore.Credential{}, ErrCredentialsNotConfigured
|
|
}
|
|
return secretstore.Credential{}, err
|
|
}
|
|
if err := cred.Validate(); err != nil {
|
|
return secretstore.Credential{}, fmt.Errorf("default credential is invalid: %w", err)
|
|
}
|
|
return cred, nil
|
|
}
|
|
|
|
// decodeParams unmarshals JSON-RPC params into dest. Missing params are
// treated as an empty object, so dest keeps its zero value. Unknown fields
// are ignored, unlike decodeArguments.
func decodeParams(raw json.RawMessage, dest any) error {
	payload := []byte(raw)
	if len(payload) == 0 {
		payload = []byte("{}")
	}
	return json.Unmarshal(payload, dest)
}
|
|
|
|
func decodeArguments(raw json.RawMessage, dest any) error {
|
|
if len(raw) == 0 {
|
|
raw = []byte("{}")
|
|
}
|
|
decoder := json.NewDecoder(bytes.NewReader(raw))
|
|
decoder.DisallowUnknownFields()
|
|
if err := decoder.Decode(dest); err != nil {
|
|
return normalizeDecodeError(err)
|
|
}
|
|
if decoder.More() {
|
|
return fmt.Errorf("invalid JSON object")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// normalizeDecodeError turns a JSON decoding error into a short client-facing
// message. Syntax errors become "invalid JSON arguments". Type errors name the
// offending field when it is known. Anything else, such as an unknown-field
// error, is wrapped with an "invalid arguments:" prefix.
func normalizeDecodeError(err error) error {
	var syntaxErr *json.SyntaxError
	if errors.As(err, &syntaxErr) {
		return errors.New("invalid JSON arguments")
	}

	var typeErr *json.UnmarshalTypeError
	if !errors.As(err, &typeErr) {
		return fmt.Errorf("invalid arguments: %w", err)
	}
	if typeErr.Field == "" {
		return errors.New("arguments have an invalid type")
	}
	return fmt.Errorf("%s has an invalid type", typeErr.Field)
}
|
|
|
|
// validateMailbox returns mailbox with surrounding whitespace removed, or an
// error if nothing remains.
func validateMailbox(mailbox string) (string, error) {
	if trimmed := strings.TrimSpace(mailbox); trimmed != "" {
		return trimmed, nil
	}
	return "", errors.New("mailbox is required")
}
|
|
|
|
func normalizeListMessagesLimit(limit *int) (int, error) {
|
|
if limit == nil {
|
|
return imapclient.DefaultListMessagesLimit, nil
|
|
}
|
|
if *limit < 1 || *limit > imapclient.MaxListMessagesLimit {
|
|
return 0, fmt.Errorf("limit must be between 1 and %d", imapclient.MaxListMessagesLimit)
|
|
}
|
|
return *limit, nil
|
|
}
|
|
|
|
// validateMessageUID dereferences the optional uid argument, rejecting a
// missing value and the invalid UID 0.
func validateMessageUID(uid *uint32) (uint32, error) {
	switch {
	case uid == nil:
		return 0, errors.New("uid is required")
	case *uid == 0:
		return 0, errors.New("uid must be greater than zero")
	}
	return *uid, nil
}
|
|
|
|
func (r Runner) closeInputOnCancel(ctx context.Context) func() {
|
|
closer, ok := r.in.(io.Closer)
|
|
if !ok {
|
|
return func() {}
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = closer.Close()
|
|
case <-done:
|
|
}
|
|
}()
|
|
|
|
return func() {
|
|
close(done)
|
|
}
|
|
}
|
|
|
|
func negotiateProtocolVersion(requested string) string {
|
|
for _, supported := range supportedProtocolVersions {
|
|
if requested == supported {
|
|
return supported
|
|
}
|
|
}
|
|
return supportedProtocolVersions[0]
|
|
}
|
|
|
|
func writeRPCResult(encoder *json.Encoder, id json.RawMessage, result any) error {
|
|
return encoder.Encode(rpcResponse{
|
|
JSONRPC: jsonRPCVersion,
|
|
ID: id,
|
|
Result: result,
|
|
})
|
|
}
|
|
|
|
func writeRPCError(encoder *json.Encoder, id json.RawMessage, code int, message string, data any) error {
|
|
return encoder.Encode(rpcResponse{
|
|
JSONRPC: jsonRPCVersion,
|
|
ID: id,
|
|
Error: &rpcErrorObject{
|
|
Code: code,
|
|
Message: message,
|
|
Data: data,
|
|
},
|
|
})
|
|
}
|