From 0f622ab9d94df89fe9d1f3fe5b9bf3628cbea624 Mon Sep 17 00:00:00 2001 From: thibaud-leclere Date: Fri, 10 Apr 2026 12:10:42 +0200 Subject: [PATCH] fix: tighten mcp runner input handling --- internal/mcpserver/server.go | 121 +++++++++++++--- internal/mcpserver/server_test.go | 225 ++++++++++++++++++++++++++++++ 2 files changed, 328 insertions(+), 18 deletions(-) diff --git a/internal/mcpserver/server.go b/internal/mcpserver/server.go index 4e76803..19618fb 100644 --- a/internal/mcpserver/server.go +++ b/internal/mcpserver/server.go @@ -15,6 +15,11 @@ import ( var ErrCredentialsNotConfigured = errors.New("credentials not configured; run `email-mcp setup`") +const ( + defaultListMessagesLimit = 20 + maxListMessagesLimit = 50 +) + type MailService interface { ListMailboxes(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) ListMessages(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) @@ -46,12 +51,12 @@ type toolRequest struct { type listMessagesArguments struct { Mailbox string `json:"mailbox"` - Limit int `json:"limit,omitempty"` + Limit *int `json:"limit,omitempty"` } type getMessageArguments struct { - Mailbox string `json:"mailbox"` - UID uint32 `json:"uid"` + Mailbox string `json:"mailbox"` + UID *uint32 `json:"uid"` } func New(store secretstore.Store, mail MailService) Server { @@ -89,10 +94,19 @@ func (s Server) Tools() []Tool { InputSchema: map[string]any{ "type": "object", "properties": map[string]any{ - "mailbox": map[string]any{"type": "string"}, - "limit": map[string]any{"type": "integer"}, + "mailbox": map[string]any{ + "type": "string", + "minLength": 1, + }, + "limit": map[string]any{ + "type": "integer", + "default": defaultListMessagesLimit, + "minimum": 1, + "maximum": maxListMessagesLimit, + }, }, - "required": []string{"mailbox"}, + "required": []string{"mailbox"}, + "additionalProperties": false, }, }, { @@ -101,10 +115,17 @@ func (s Server) Tools() []Tool { InputSchema: map[string]any{ "type": "object", "properties": map[string]any{ - "mailbox": map[string]any{"type": "string"}, - "uid": map[string]any{"type": "integer"}, + "mailbox": map[string]any{ + "type": "string", + "minLength": 1, + }, + "uid": map[string]any{ + "type": "integer", + "minimum": 1, + }, }, - "required": []string{"mailbox", "uid"}, + "required": []string{"mailbox", "uid"}, + "additionalProperties": false, }, }, } @@ -140,6 +161,9 @@ func (r Runner) Run(ctx context.Context) error { return err } + stopCancelRead := r.closeInputOnCancel(ctx) + defer stopCancelRead() + encoder := json.NewEncoder(r.out) if err := encoder.Encode(map[string]any{"tools": r.server.Tools()}); err != nil { return err @@ -156,6 +180,9 @@ func (r Runner) Run(ctx context.Context) error { var request toolRequest if err := decoder.Decode(&request); err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } if errors.Is(err, io.EOF) { return nil } @@ -184,13 +211,21 @@ func (s Server) handleTool(ctx context.Context, cred secretstore.Credential, nam if err := decodeArguments(rawArgs, &args); err != nil { return nil, err } - return s.listMessages(ctx, cred, args.Mailbox, args.Limit) + limit, err := normalizeListMessagesLimit(args.Limit) + if err != nil { + return nil, err + } + return s.listMessages(ctx, cred, args.Mailbox, limit) case "get_message": var args getMessageArguments if err := decodeArguments(rawArgs, &args); err != nil { return nil, err } - return s.getMessage(ctx, cred, args.Mailbox, args.UID) + uid, err := validateMessageUID(args.UID) + if err != nil { + return nil, err + } + return s.getMessage(ctx, cred, args.Mailbox, uid) default: return nil, fmt.Errorf("unknown tool: %s", name) } @@ -207,9 +242,9 @@ func (s Server) listMessages(ctx context.Context, cred secretstore.Credential, m if s.mail == nil { return nil, fmt.Errorf("mail service is not configured") } - mailbox = strings.TrimSpace(mailbox) - if mailbox == "" { - return nil, fmt.Errorf("mailbox is required") + mailbox, err := validateMailbox(mailbox) + if err != nil { + return nil, err } return s.mail.ListMessages(ctx, cred, mailbox, limit) } @@ -218,9 +253,9 @@ func (s Server) getMessage(ctx context.Context, cred secretstore.Credential, mai if s.mail == nil { return imapclient.Message{}, fmt.Errorf("mail service is not configured") } - mailbox = strings.TrimSpace(mailbox) - if mailbox == "" { - return imapclient.Message{}, fmt.Errorf("mailbox is required") + mailbox, err := validateMailbox(mailbox) + if err != nil { + return imapclient.Message{}, err } if uid == 0 { return imapclient.Message{}, fmt.Errorf("uid must be greater than zero") @@ -250,8 +285,58 @@ func decodeArguments(raw json.RawMessage, dest any) error { if len(raw) == 0 { raw = []byte("{}") } - if err := json.Unmarshal(raw, dest); err != nil { + decoder := json.NewDecoder(strings.NewReader(string(raw))) + decoder.DisallowUnknownFields() + if err := decoder.Decode(dest); err != nil { return fmt.Errorf("invalid tool arguments: %w", err) } return nil } + +func validateMailbox(mailbox string) (string, error) { + mailbox = strings.TrimSpace(mailbox) + if mailbox == "" { + return "", fmt.Errorf("mailbox is required") + } + return mailbox, nil +} + +func normalizeListMessagesLimit(limit *int) (int, error) { + if limit == nil { + return defaultListMessagesLimit, nil + } + if *limit < 1 || *limit > maxListMessagesLimit { + return 0, fmt.Errorf("limit must be between 1 and %d", maxListMessagesLimit) + } + return *limit, nil +} + +func validateMessageUID(uid *uint32) (uint32, error) { + if uid == nil { + return 0, fmt.Errorf("uid is required") + } + if *uid == 0 { + return 0, fmt.Errorf("uid must be greater than zero") + } + return *uid, nil +} + +func (r Runner) closeInputOnCancel(ctx context.Context) func() { + closer, ok := r.in.(io.Closer) + if !ok { + return func() {} + } + + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + _ = closer.Close() + case <-done: + } + }() + + return func() { + close(done) + } +} diff --git a/internal/mcpserver/server_test.go b/internal/mcpserver/server_test.go index 37f3be4..8768044 100644 --- a/internal/mcpserver/server_test.go +++ b/internal/mcpserver/server_test.go @@ -5,7 +5,9 @@ import ( "context" "encoding/json" "errors" + "io" "testing" + "time" "email-mcp/internal/imapclient" "email-mcp/internal/secretstore" @@ -248,3 +250,226 @@ func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) { t.Fatalf("expected missing credential error, got %v", err) } } + +func TestServerToolsAdvertiseValidatedArgumentContracts(t *testing.T) { + tools := New(&storeStub{}, serviceStub{}).Tools() + if len(tools) != 3 { + t.Fatalf("expected 3 tools, got %d", len(tools)) + } + + listMessages := tools[1] + if listMessages.Name != "list_messages" { + t.Fatalf("unexpected tool ordering: %#v", tools) + } + if got := listMessages.InputSchema["type"]; got != "object" { + t.Fatalf("expected object schema, got %#v", got) + } + + listProps, ok := listMessages.InputSchema["properties"].(map[string]any) + if !ok { + t.Fatalf("expected properties map, got %#v", listMessages.InputSchema["properties"]) + } + limitSchema, ok := listProps["limit"].(map[string]any) + if !ok { + t.Fatalf("expected limit schema, got %#v", listProps["limit"]) + } + if got := limitSchema["default"]; got != float64(defaultListMessagesLimit) && got != defaultListMessagesLimit { + t.Fatalf("expected limit default %d, got %#v", defaultListMessagesLimit, got) + } + if got := limitSchema["minimum"]; got != float64(1) && got != 1 { + t.Fatalf("expected limit minimum 1, got %#v", got) + } + if got := limitSchema["maximum"]; got != float64(maxListMessagesLimit) && got != maxListMessagesLimit { + t.Fatalf("expected limit maximum %d, got %#v", maxListMessagesLimit, got) + } + + getMessage := tools[2] + if getMessage.Name != "get_message" { + t.Fatalf("unexpected tool ordering: %#v", tools) + } + getProps, ok := getMessage.InputSchema["properties"].(map[string]any) + if !ok { + t.Fatalf("expected get_message properties map, got %#v", getMessage.InputSchema["properties"]) + } + uidSchema, ok := getProps["uid"].(map[string]any) + if !ok { + t.Fatalf("expected uid schema, got %#v", getProps["uid"]) + } + if got := uidSchema["minimum"]; got != float64(1) && got != 1 { + t.Fatalf("expected uid minimum 1, got %#v", got) + } +} + +func TestRunnerRunReturnsValidationErrorsForInvalidRequests(t *testing.T) { + store := &storeStub{ + credential: secretstore.Credential{ + Host: "imap.example.com", + Username: "alice", + Password: "secret", + }, + } + input := bytes.NewBufferString("{\"tool\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\",\"limit\":0}}\n{\"tool\":\"get_message\",\"arguments\":{\"mailbox\":\"INBOX\"}}\n") + output := &bytes.Buffer{} + runner := NewRunner(New(store, serviceStub{ + listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { + t.Fatal("ListMailboxes should not be called") + return nil, nil + }, + listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) { + t.Fatal("ListMessages should not be called") + return nil, nil + }, + getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) { + t.Fatal("GetMessage should not be called") + return imapclient.Message{}, nil + }, + }), input, output, &bytes.Buffer{}) + + if err := runner.Run(context.Background()); err != nil { + t.Fatalf("Run returned error: %v", err) + } + + decoder := json.NewDecoder(output) + if err := decoder.Decode(&struct { + Tools []Tool `json:"tools"` + }{}); err != nil { + t.Fatalf("failed to decode manifest: %v", err) + } + + var firstResponse struct { + Error string `json:"error"` + } + if err := decoder.Decode(&firstResponse); err != nil { + t.Fatalf("failed to decode first error response: %v", err) + } + if firstResponse.Error != "limit must be between 1 and 50" { + t.Fatalf("unexpected first error: %#v", firstResponse) + } + + var secondResponse struct { + Error string `json:"error"` + } + if err := decoder.Decode(&secondResponse); err != nil { + t.Fatalf("failed to decode second error response: %v", err) + } + if secondResponse.Error != "uid is required" { + t.Fatalf("unexpected second error: %#v", secondResponse) + } +} + +func TestRunnerRunAppliesDefaultLimitWhenOmitted(t *testing.T) { + store := &storeStub{ + credential: secretstore.Credential{ + Host: "imap.example.com", + Username: "alice", + Password: "secret", + }, + } + input := bytes.NewBufferString("{\"tool\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\"}}\n") + output := &bytes.Buffer{} + runner := NewRunner(New(store, serviceStub{ + listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { + t.Fatal("ListMailboxes should not be called") + return nil, nil + }, + listMessages: func(_ context.Context, cred secretstore.Credential, mailbox string, limit int) ([]imapclient.MessageSummary, error) { + if cred.Host != "imap.example.com" || mailbox != "INBOX" || limit != defaultListMessagesLimit { + t.Fatalf("unexpected call: cred=%#v mailbox=%q limit=%d", cred, mailbox, limit) + } + return []imapclient.MessageSummary{{UID: 42, Subject: "hello", From: "alice@example.com"}}, nil + }, + getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) { + t.Fatal("GetMessage should not be called") + return imapclient.Message{}, nil + }, + }), input, output, &bytes.Buffer{}) + + if err := runner.Run(context.Background()); err != nil { + t.Fatalf("Run returned error: %v", err) + } +} + +func TestRunnerRunStopsWhenContextCanceledWhileWaitingForInput(t *testing.T) { + store := &storeStub{ + credential: secretstore.Credential{ + Host: "imap.example.com", + Username: "alice", + Password: "secret", + }, + } + input := newBlockingReadCloser() + runner := NewRunner(New(store, serviceStub{ + listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) { + t.Fatal("ListMailboxes should not be called") + return nil, nil + }, + listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) { + t.Fatal("ListMessages should not be called") + return nil, nil + }, + getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) { + t.Fatal("GetMessage should not be called") + return imapclient.Message{}, nil + }, + }), input, &bytes.Buffer{}, &bytes.Buffer{}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- runner.Run(ctx) + }() + + select { + case <-input.started: + case <-time.After(200 * time.Millisecond): + t.Fatal("runner never started reading input") + } + + cancel() + + select { + case err := <-done: + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancellation, got %v", err) + } + case <-time.After(500 * time.Millisecond): + t.Fatal("runner did not stop after context cancellation") + } + + if !input.closed { + t.Fatal("expected runner to close input reader on cancellation") + } +} + +type blockingReadCloser struct { + started chan struct{} + closed bool + done chan struct{} +} + +func newBlockingReadCloser() *blockingReadCloser { + return &blockingReadCloser{ + started: make(chan struct{}), + done: make(chan struct{}), + } +} + +func (r *blockingReadCloser) Read(_ []byte) (int, error) { + select { + case <-r.started: + default: + close(r.started) + } + <-r.done + return 0, io.EOF +} + +func (r *blockingReadCloser) Close() error { + if !r.closed { + r.closed = true + close(r.done) + } + return nil +}