fix: tighten mcp runner input handling

This commit is contained in:
thibaud-leclere 2026-04-10 12:10:42 +02:00
parent 679abbe328
commit 0f622ab9d9
2 changed files with 328 additions and 18 deletions

View file

@ -15,6 +15,11 @@ import (
var ErrCredentialsNotConfigured = errors.New("credentials not configured; run `email-mcp setup`")
const (
defaultListMessagesLimit = 20
maxListMessagesLimit = 50
)
type MailService interface {
ListMailboxes(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error)
ListMessages(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error)
@ -46,12 +51,12 @@ type toolRequest struct {
type listMessagesArguments struct {
Mailbox string `json:"mailbox"`
Limit int `json:"limit,omitempty"`
Limit *int `json:"limit,omitempty"`
}
type getMessageArguments struct {
Mailbox string `json:"mailbox"`
UID uint32 `json:"uid"`
UID *uint32 `json:"uid"`
}
func New(store secretstore.Store, mail MailService) Server {
@ -89,10 +94,19 @@ func (s Server) Tools() []Tool {
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"mailbox": map[string]any{"type": "string"},
"limit": map[string]any{"type": "integer"},
"mailbox": map[string]any{
"type": "string",
"minLength": 1,
},
"limit": map[string]any{
"type": "integer",
"default": defaultListMessagesLimit,
"minimum": 1,
"maximum": maxListMessagesLimit,
},
},
"required": []string{"mailbox"},
"additionalProperties": false,
},
},
{
@ -101,10 +115,17 @@ func (s Server) Tools() []Tool {
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"mailbox": map[string]any{"type": "string"},
"uid": map[string]any{"type": "integer"},
"mailbox": map[string]any{
"type": "string",
"minLength": 1,
},
"uid": map[string]any{
"type": "integer",
"minimum": 1,
},
},
"required": []string{"mailbox", "uid"},
"additionalProperties": false,
},
},
}
@ -140,6 +161,9 @@ func (r Runner) Run(ctx context.Context) error {
return err
}
stopCancelRead := r.closeInputOnCancel(ctx)
defer stopCancelRead()
encoder := json.NewEncoder(r.out)
if err := encoder.Encode(map[string]any{"tools": r.server.Tools()}); err != nil {
return err
@ -156,6 +180,9 @@ func (r Runner) Run(ctx context.Context) error {
var request toolRequest
if err := decoder.Decode(&request); err != nil {
if ctxErr := ctx.Err(); ctxErr != nil {
return ctxErr
}
if errors.Is(err, io.EOF) {
return nil
}
@ -184,13 +211,21 @@ func (s Server) handleTool(ctx context.Context, cred secretstore.Credential, nam
if err := decodeArguments(rawArgs, &args); err != nil {
return nil, err
}
return s.listMessages(ctx, cred, args.Mailbox, args.Limit)
limit, err := normalizeListMessagesLimit(args.Limit)
if err != nil {
return nil, err
}
return s.listMessages(ctx, cred, args.Mailbox, limit)
case "get_message":
var args getMessageArguments
if err := decodeArguments(rawArgs, &args); err != nil {
return nil, err
}
return s.getMessage(ctx, cred, args.Mailbox, args.UID)
uid, err := validateMessageUID(args.UID)
if err != nil {
return nil, err
}
return s.getMessage(ctx, cred, args.Mailbox, uid)
default:
return nil, fmt.Errorf("unknown tool: %s", name)
}
@ -207,9 +242,9 @@ func (s Server) listMessages(ctx context.Context, cred secretstore.Credential, m
if s.mail == nil {
return nil, fmt.Errorf("mail service is not configured")
}
mailbox = strings.TrimSpace(mailbox)
if mailbox == "" {
return nil, fmt.Errorf("mailbox is required")
mailbox, err := validateMailbox(mailbox)
if err != nil {
return nil, err
}
return s.mail.ListMessages(ctx, cred, mailbox, limit)
}
@ -218,9 +253,9 @@ func (s Server) getMessage(ctx context.Context, cred secretstore.Credential, mai
if s.mail == nil {
return imapclient.Message{}, fmt.Errorf("mail service is not configured")
}
mailbox = strings.TrimSpace(mailbox)
if mailbox == "" {
return imapclient.Message{}, fmt.Errorf("mailbox is required")
mailbox, err := validateMailbox(mailbox)
if err != nil {
return imapclient.Message{}, err
}
if uid == 0 {
return imapclient.Message{}, fmt.Errorf("uid must be greater than zero")
@ -250,8 +285,58 @@ func decodeArguments(raw json.RawMessage, dest any) error {
if len(raw) == 0 {
raw = []byte("{}")
}
if err := json.Unmarshal(raw, dest); err != nil {
decoder := json.NewDecoder(strings.NewReader(string(raw)))
decoder.DisallowUnknownFields()
if err := decoder.Decode(dest); err != nil {
return fmt.Errorf("invalid tool arguments: %w", err)
}
return nil
}
func validateMailbox(mailbox string) (string, error) {
mailbox = strings.TrimSpace(mailbox)
if mailbox == "" {
return "", fmt.Errorf("mailbox is required")
}
return mailbox, nil
}
func normalizeListMessagesLimit(limit *int) (int, error) {
if limit == nil {
return defaultListMessagesLimit, nil
}
if *limit < 1 || *limit > maxListMessagesLimit {
return 0, fmt.Errorf("limit must be between 1 and %d", maxListMessagesLimit)
}
return *limit, nil
}
func validateMessageUID(uid *uint32) (uint32, error) {
if uid == nil {
return 0, fmt.Errorf("uid is required")
}
if *uid == 0 {
return 0, fmt.Errorf("uid must be greater than zero")
}
return *uid, nil
}
func (r Runner) closeInputOnCancel(ctx context.Context) func() {
closer, ok := r.in.(io.Closer)
if !ok {
return func() {}
}
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
_ = closer.Close()
case <-done:
}
}()
return func() {
close(done)
}
}

View file

@ -5,7 +5,9 @@ import (
"context"
"encoding/json"
"errors"
"io"
"testing"
"time"
"email-mcp/internal/imapclient"
"email-mcp/internal/secretstore"
@ -248,3 +250,226 @@ func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) {
t.Fatalf("expected missing credential error, got %v", err)
}
}
func TestServerToolsAdvertiseValidatedArgumentContracts(t *testing.T) {
tools := New(&storeStub{}, serviceStub{}).Tools()
if len(tools) != 3 {
t.Fatalf("expected 3 tools, got %d", len(tools))
}
listMessages := tools[1]
if listMessages.Name != "list_messages" {
t.Fatalf("unexpected tool ordering: %#v", tools)
}
if got := listMessages.InputSchema["type"]; got != "object" {
t.Fatalf("expected object schema, got %#v", got)
}
listProps, ok := listMessages.InputSchema["properties"].(map[string]any)
if !ok {
t.Fatalf("expected properties map, got %#v", listMessages.InputSchema["properties"])
}
limitSchema, ok := listProps["limit"].(map[string]any)
if !ok {
t.Fatalf("expected limit schema, got %#v", listProps["limit"])
}
if got := limitSchema["default"]; got != float64(defaultListMessagesLimit) && got != defaultListMessagesLimit {
t.Fatalf("expected limit default %d, got %#v", defaultListMessagesLimit, got)
}
if got := limitSchema["minimum"]; got != float64(1) && got != 1 {
t.Fatalf("expected limit minimum 1, got %#v", got)
}
if got := limitSchema["maximum"]; got != float64(maxListMessagesLimit) && got != maxListMessagesLimit {
t.Fatalf("expected limit maximum %d, got %#v", maxListMessagesLimit, got)
}
getMessage := tools[2]
if getMessage.Name != "get_message" {
t.Fatalf("unexpected tool ordering: %#v", tools)
}
getProps, ok := getMessage.InputSchema["properties"].(map[string]any)
if !ok {
t.Fatalf("expected get_message properties map, got %#v", getMessage.InputSchema["properties"])
}
uidSchema, ok := getProps["uid"].(map[string]any)
if !ok {
t.Fatalf("expected uid schema, got %#v", getProps["uid"])
}
if got := uidSchema["minimum"]; got != float64(1) && got != 1 {
t.Fatalf("expected uid minimum 1, got %#v", got)
}
}
func TestRunnerRunReturnsValidationErrorsForInvalidRequests(t *testing.T) {
store := &storeStub{
credential: secretstore.Credential{
Host: "imap.example.com",
Username: "alice",
Password: "secret",
},
}
input := bytes.NewBufferString("{\"tool\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\",\"limit\":0}}\n{\"tool\":\"get_message\",\"arguments\":{\"mailbox\":\"INBOX\"}}\n")
output := &bytes.Buffer{}
runner := NewRunner(New(store, serviceStub{
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
t.Fatal("ListMailboxes should not be called")
return nil, nil
},
listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) {
t.Fatal("ListMessages should not be called")
return nil, nil
},
getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
t.Fatal("GetMessage should not be called")
return imapclient.Message{}, nil
},
}), input, output, &bytes.Buffer{})
if err := runner.Run(context.Background()); err != nil {
t.Fatalf("Run returned error: %v", err)
}
decoder := json.NewDecoder(output)
if err := decoder.Decode(&struct {
Tools []Tool `json:"tools"`
}{}); err != nil {
t.Fatalf("failed to decode manifest: %v", err)
}
var firstResponse struct {
Error string `json:"error"`
}
if err := decoder.Decode(&firstResponse); err != nil {
t.Fatalf("failed to decode first error response: %v", err)
}
if firstResponse.Error != "limit must be between 1 and 50" {
t.Fatalf("unexpected first error: %#v", firstResponse)
}
var secondResponse struct {
Error string `json:"error"`
}
if err := decoder.Decode(&secondResponse); err != nil {
t.Fatalf("failed to decode second error response: %v", err)
}
if secondResponse.Error != "uid is required" {
t.Fatalf("unexpected second error: %#v", secondResponse)
}
}
func TestRunnerRunAppliesDefaultLimitWhenOmitted(t *testing.T) {
store := &storeStub{
credential: secretstore.Credential{
Host: "imap.example.com",
Username: "alice",
Password: "secret",
},
}
input := bytes.NewBufferString("{\"tool\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\"}}\n")
output := &bytes.Buffer{}
runner := NewRunner(New(store, serviceStub{
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
t.Fatal("ListMailboxes should not be called")
return nil, nil
},
listMessages: func(_ context.Context, cred secretstore.Credential, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
if cred.Host != "imap.example.com" || mailbox != "INBOX" || limit != defaultListMessagesLimit {
t.Fatalf("unexpected call: cred=%#v mailbox=%q limit=%d", cred, mailbox, limit)
}
return []imapclient.MessageSummary{{UID: 42, Subject: "hello", From: "alice@example.com"}}, nil
},
getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
t.Fatal("GetMessage should not be called")
return imapclient.Message{}, nil
},
}), input, output, &bytes.Buffer{})
if err := runner.Run(context.Background()); err != nil {
t.Fatalf("Run returned error: %v", err)
}
}
func TestRunnerRunStopsWhenContextCanceledWhileWaitingForInput(t *testing.T) {
store := &storeStub{
credential: secretstore.Credential{
Host: "imap.example.com",
Username: "alice",
Password: "secret",
},
}
input := newBlockingReadCloser()
runner := NewRunner(New(store, serviceStub{
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
t.Fatal("ListMailboxes should not be called")
return nil, nil
},
listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) {
t.Fatal("ListMessages should not be called")
return nil, nil
},
getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
t.Fatal("GetMessage should not be called")
return imapclient.Message{}, nil
},
}), input, &bytes.Buffer{}, &bytes.Buffer{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan error, 1)
go func() {
done <- runner.Run(ctx)
}()
select {
case <-input.started:
case <-time.After(200 * time.Millisecond):
t.Fatal("runner never started reading input")
}
cancel()
select {
case err := <-done:
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context cancellation, got %v", err)
}
case <-time.After(500 * time.Millisecond):
t.Fatal("runner did not stop after context cancellation")
}
if !input.closed {
t.Fatal("expected runner to close input reader on cancellation")
}
}
type blockingReadCloser struct {
started chan struct{}
closed bool
done chan struct{}
}
func newBlockingReadCloser() *blockingReadCloser {
return &blockingReadCloser{
started: make(chan struct{}),
done: make(chan struct{}),
}
}
func (r *blockingReadCloser) Read(_ []byte) (int, error) {
select {
case <-r.started:
default:
close(r.started)
}
<-r.done
return 0, io.EOF
}
func (r *blockingReadCloser) Close() error {
if !r.closed {
r.closed = true
close(r.done)
}
return nil
}