fix: tighten mcp runner input handling
This commit is contained in:
parent
679abbe328
commit
0f622ab9d9
2 changed files with 328 additions and 18 deletions
|
|
@ -15,6 +15,11 @@ import (
|
|||
|
||||
var ErrCredentialsNotConfigured = errors.New("credentials not configured; run `email-mcp setup`")
|
||||
|
||||
const (
|
||||
defaultListMessagesLimit = 20
|
||||
maxListMessagesLimit = 50
|
||||
)
|
||||
|
||||
type MailService interface {
|
||||
ListMailboxes(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error)
|
||||
ListMessages(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error)
|
||||
|
|
@ -46,12 +51,12 @@ type toolRequest struct {
|
|||
|
||||
type listMessagesArguments struct {
|
||||
Mailbox string `json:"mailbox"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Limit *int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
type getMessageArguments struct {
|
||||
Mailbox string `json:"mailbox"`
|
||||
UID uint32 `json:"uid"`
|
||||
UID *uint32 `json:"uid"`
|
||||
}
|
||||
|
||||
func New(store secretstore.Store, mail MailService) Server {
|
||||
|
|
@ -89,10 +94,19 @@ func (s Server) Tools() []Tool {
|
|||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"mailbox": map[string]any{"type": "string"},
|
||||
"limit": map[string]any{"type": "integer"},
|
||||
"mailbox": map[string]any{
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
},
|
||||
"limit": map[string]any{
|
||||
"type": "integer",
|
||||
"default": defaultListMessagesLimit,
|
||||
"minimum": 1,
|
||||
"maximum": maxListMessagesLimit,
|
||||
},
|
||||
},
|
||||
"required": []string{"mailbox"},
|
||||
"additionalProperties": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
@ -101,10 +115,17 @@ func (s Server) Tools() []Tool {
|
|||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"mailbox": map[string]any{"type": "string"},
|
||||
"uid": map[string]any{"type": "integer"},
|
||||
"mailbox": map[string]any{
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
},
|
||||
"uid": map[string]any{
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": []string{"mailbox", "uid"},
|
||||
"additionalProperties": false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -140,6 +161,9 @@ func (r Runner) Run(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
stopCancelRead := r.closeInputOnCancel(ctx)
|
||||
defer stopCancelRead()
|
||||
|
||||
encoder := json.NewEncoder(r.out)
|
||||
if err := encoder.Encode(map[string]any{"tools": r.server.Tools()}); err != nil {
|
||||
return err
|
||||
|
|
@ -156,6 +180,9 @@ func (r Runner) Run(ctx context.Context) error {
|
|||
|
||||
var request toolRequest
|
||||
if err := decoder.Decode(&request); err != nil {
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
return ctxErr
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -184,13 +211,21 @@ func (s Server) handleTool(ctx context.Context, cred secretstore.Credential, nam
|
|||
if err := decodeArguments(rawArgs, &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.listMessages(ctx, cred, args.Mailbox, args.Limit)
|
||||
limit, err := normalizeListMessagesLimit(args.Limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.listMessages(ctx, cred, args.Mailbox, limit)
|
||||
case "get_message":
|
||||
var args getMessageArguments
|
||||
if err := decodeArguments(rawArgs, &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.getMessage(ctx, cred, args.Mailbox, args.UID)
|
||||
uid, err := validateMessageUID(args.UID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.getMessage(ctx, cred, args.Mailbox, uid)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown tool: %s", name)
|
||||
}
|
||||
|
|
@ -207,9 +242,9 @@ func (s Server) listMessages(ctx context.Context, cred secretstore.Credential, m
|
|||
if s.mail == nil {
|
||||
return nil, fmt.Errorf("mail service is not configured")
|
||||
}
|
||||
mailbox = strings.TrimSpace(mailbox)
|
||||
if mailbox == "" {
|
||||
return nil, fmt.Errorf("mailbox is required")
|
||||
mailbox, err := validateMailbox(mailbox)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.mail.ListMessages(ctx, cred, mailbox, limit)
|
||||
}
|
||||
|
|
@ -218,9 +253,9 @@ func (s Server) getMessage(ctx context.Context, cred secretstore.Credential, mai
|
|||
if s.mail == nil {
|
||||
return imapclient.Message{}, fmt.Errorf("mail service is not configured")
|
||||
}
|
||||
mailbox = strings.TrimSpace(mailbox)
|
||||
if mailbox == "" {
|
||||
return imapclient.Message{}, fmt.Errorf("mailbox is required")
|
||||
mailbox, err := validateMailbox(mailbox)
|
||||
if err != nil {
|
||||
return imapclient.Message{}, err
|
||||
}
|
||||
if uid == 0 {
|
||||
return imapclient.Message{}, fmt.Errorf("uid must be greater than zero")
|
||||
|
|
@ -250,8 +285,58 @@ func decodeArguments(raw json.RawMessage, dest any) error {
|
|||
if len(raw) == 0 {
|
||||
raw = []byte("{}")
|
||||
}
|
||||
if err := json.Unmarshal(raw, dest); err != nil {
|
||||
decoder := json.NewDecoder(strings.NewReader(string(raw)))
|
||||
decoder.DisallowUnknownFields()
|
||||
if err := decoder.Decode(dest); err != nil {
|
||||
return fmt.Errorf("invalid tool arguments: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateMailbox(mailbox string) (string, error) {
|
||||
mailbox = strings.TrimSpace(mailbox)
|
||||
if mailbox == "" {
|
||||
return "", fmt.Errorf("mailbox is required")
|
||||
}
|
||||
return mailbox, nil
|
||||
}
|
||||
|
||||
func normalizeListMessagesLimit(limit *int) (int, error) {
|
||||
if limit == nil {
|
||||
return defaultListMessagesLimit, nil
|
||||
}
|
||||
if *limit < 1 || *limit > maxListMessagesLimit {
|
||||
return 0, fmt.Errorf("limit must be between 1 and %d", maxListMessagesLimit)
|
||||
}
|
||||
return *limit, nil
|
||||
}
|
||||
|
||||
func validateMessageUID(uid *uint32) (uint32, error) {
|
||||
if uid == nil {
|
||||
return 0, fmt.Errorf("uid is required")
|
||||
}
|
||||
if *uid == 0 {
|
||||
return 0, fmt.Errorf("uid must be greater than zero")
|
||||
}
|
||||
return *uid, nil
|
||||
}
|
||||
|
||||
func (r Runner) closeInputOnCancel(ctx context.Context) func() {
|
||||
closer, ok := r.in.(io.Closer)
|
||||
if !ok {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = closer.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
return func() {
|
||||
close(done)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"email-mcp/internal/imapclient"
|
||||
"email-mcp/internal/secretstore"
|
||||
|
|
@ -248,3 +250,226 @@ func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) {
|
|||
t.Fatalf("expected missing credential error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerToolsAdvertiseValidatedArgumentContracts(t *testing.T) {
|
||||
tools := New(&storeStub{}, serviceStub{}).Tools()
|
||||
if len(tools) != 3 {
|
||||
t.Fatalf("expected 3 tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
listMessages := tools[1]
|
||||
if listMessages.Name != "list_messages" {
|
||||
t.Fatalf("unexpected tool ordering: %#v", tools)
|
||||
}
|
||||
if got := listMessages.InputSchema["type"]; got != "object" {
|
||||
t.Fatalf("expected object schema, got %#v", got)
|
||||
}
|
||||
|
||||
listProps, ok := listMessages.InputSchema["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected properties map, got %#v", listMessages.InputSchema["properties"])
|
||||
}
|
||||
limitSchema, ok := listProps["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected limit schema, got %#v", listProps["limit"])
|
||||
}
|
||||
if got := limitSchema["default"]; got != float64(defaultListMessagesLimit) && got != defaultListMessagesLimit {
|
||||
t.Fatalf("expected limit default %d, got %#v", defaultListMessagesLimit, got)
|
||||
}
|
||||
if got := limitSchema["minimum"]; got != float64(1) && got != 1 {
|
||||
t.Fatalf("expected limit minimum 1, got %#v", got)
|
||||
}
|
||||
if got := limitSchema["maximum"]; got != float64(maxListMessagesLimit) && got != maxListMessagesLimit {
|
||||
t.Fatalf("expected limit maximum %d, got %#v", maxListMessagesLimit, got)
|
||||
}
|
||||
|
||||
getMessage := tools[2]
|
||||
if getMessage.Name != "get_message" {
|
||||
t.Fatalf("unexpected tool ordering: %#v", tools)
|
||||
}
|
||||
getProps, ok := getMessage.InputSchema["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected get_message properties map, got %#v", getMessage.InputSchema["properties"])
|
||||
}
|
||||
uidSchema, ok := getProps["uid"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected uid schema, got %#v", getProps["uid"])
|
||||
}
|
||||
if got := uidSchema["minimum"]; got != float64(1) && got != 1 {
|
||||
t.Fatalf("expected uid minimum 1, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunnerRunReturnsValidationErrorsForInvalidRequests(t *testing.T) {
|
||||
store := &storeStub{
|
||||
credential: secretstore.Credential{
|
||||
Host: "imap.example.com",
|
||||
Username: "alice",
|
||||
Password: "secret",
|
||||
},
|
||||
}
|
||||
input := bytes.NewBufferString("{\"tool\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\",\"limit\":0}}\n{\"tool\":\"get_message\",\"arguments\":{\"mailbox\":\"INBOX\"}}\n")
|
||||
output := &bytes.Buffer{}
|
||||
runner := NewRunner(New(store, serviceStub{
|
||||
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
|
||||
t.Fatal("ListMailboxes should not be called")
|
||||
return nil, nil
|
||||
},
|
||||
listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) {
|
||||
t.Fatal("ListMessages should not be called")
|
||||
return nil, nil
|
||||
},
|
||||
getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
|
||||
t.Fatal("GetMessage should not be called")
|
||||
return imapclient.Message{}, nil
|
||||
},
|
||||
}), input, output, &bytes.Buffer{})
|
||||
|
||||
if err := runner.Run(context.Background()); err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(output)
|
||||
if err := decoder.Decode(&struct {
|
||||
Tools []Tool `json:"tools"`
|
||||
}{}); err != nil {
|
||||
t.Fatalf("failed to decode manifest: %v", err)
|
||||
}
|
||||
|
||||
var firstResponse struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := decoder.Decode(&firstResponse); err != nil {
|
||||
t.Fatalf("failed to decode first error response: %v", err)
|
||||
}
|
||||
if firstResponse.Error != "limit must be between 1 and 50" {
|
||||
t.Fatalf("unexpected first error: %#v", firstResponse)
|
||||
}
|
||||
|
||||
var secondResponse struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := decoder.Decode(&secondResponse); err != nil {
|
||||
t.Fatalf("failed to decode second error response: %v", err)
|
||||
}
|
||||
if secondResponse.Error != "uid is required" {
|
||||
t.Fatalf("unexpected second error: %#v", secondResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunnerRunAppliesDefaultLimitWhenOmitted(t *testing.T) {
|
||||
store := &storeStub{
|
||||
credential: secretstore.Credential{
|
||||
Host: "imap.example.com",
|
||||
Username: "alice",
|
||||
Password: "secret",
|
||||
},
|
||||
}
|
||||
input := bytes.NewBufferString("{\"tool\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\"}}\n")
|
||||
output := &bytes.Buffer{}
|
||||
runner := NewRunner(New(store, serviceStub{
|
||||
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
|
||||
t.Fatal("ListMailboxes should not be called")
|
||||
return nil, nil
|
||||
},
|
||||
listMessages: func(_ context.Context, cred secretstore.Credential, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
|
||||
if cred.Host != "imap.example.com" || mailbox != "INBOX" || limit != defaultListMessagesLimit {
|
||||
t.Fatalf("unexpected call: cred=%#v mailbox=%q limit=%d", cred, mailbox, limit)
|
||||
}
|
||||
return []imapclient.MessageSummary{{UID: 42, Subject: "hello", From: "alice@example.com"}}, nil
|
||||
},
|
||||
getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
|
||||
t.Fatal("GetMessage should not be called")
|
||||
return imapclient.Message{}, nil
|
||||
},
|
||||
}), input, output, &bytes.Buffer{})
|
||||
|
||||
if err := runner.Run(context.Background()); err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunnerRunStopsWhenContextCanceledWhileWaitingForInput(t *testing.T) {
|
||||
store := &storeStub{
|
||||
credential: secretstore.Credential{
|
||||
Host: "imap.example.com",
|
||||
Username: "alice",
|
||||
Password: "secret",
|
||||
},
|
||||
}
|
||||
input := newBlockingReadCloser()
|
||||
runner := NewRunner(New(store, serviceStub{
|
||||
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
|
||||
t.Fatal("ListMailboxes should not be called")
|
||||
return nil, nil
|
||||
},
|
||||
listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) {
|
||||
t.Fatal("ListMessages should not be called")
|
||||
return nil, nil
|
||||
},
|
||||
getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
|
||||
t.Fatal("GetMessage should not be called")
|
||||
return imapclient.Message{}, nil
|
||||
},
|
||||
}), input, &bytes.Buffer{}, &bytes.Buffer{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- runner.Run(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-input.started:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("runner never started reading input")
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context cancellation, got %v", err)
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("runner did not stop after context cancellation")
|
||||
}
|
||||
|
||||
if !input.closed {
|
||||
t.Fatal("expected runner to close input reader on cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
type blockingReadCloser struct {
|
||||
started chan struct{}
|
||||
closed bool
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newBlockingReadCloser() *blockingReadCloser {
|
||||
return &blockingReadCloser{
|
||||
started: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *blockingReadCloser) Read(_ []byte) (int, error) {
|
||||
select {
|
||||
case <-r.started:
|
||||
default:
|
||||
close(r.started)
|
||||
}
|
||||
<-r.done
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (r *blockingReadCloser) Close() error {
|
||||
if !r.closed {
|
||||
r.closed = true
|
||||
close(r.done)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue