fix: tighten mcp runner input handling
This commit is contained in:
parent
679abbe328
commit
0f622ab9d9
2 changed files with 328 additions and 18 deletions
|
|
@ -15,6 +15,11 @@ import (
|
||||||
|
|
||||||
var ErrCredentialsNotConfigured = errors.New("credentials not configured; run `email-mcp setup`")
|
var ErrCredentialsNotConfigured = errors.New("credentials not configured; run `email-mcp setup`")
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultListMessagesLimit = 20
|
||||||
|
maxListMessagesLimit = 50
|
||||||
|
)
|
||||||
|
|
||||||
type MailService interface {
|
type MailService interface {
|
||||||
ListMailboxes(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error)
|
ListMailboxes(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error)
|
||||||
ListMessages(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error)
|
ListMessages(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error)
|
||||||
|
|
@ -46,12 +51,12 @@ type toolRequest struct {
|
||||||
|
|
||||||
type listMessagesArguments struct {
|
type listMessagesArguments struct {
|
||||||
Mailbox string `json:"mailbox"`
|
Mailbox string `json:"mailbox"`
|
||||||
Limit int `json:"limit,omitempty"`
|
Limit *int `json:"limit,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type getMessageArguments struct {
|
type getMessageArguments struct {
|
||||||
Mailbox string `json:"mailbox"`
|
Mailbox string `json:"mailbox"`
|
||||||
UID uint32 `json:"uid"`
|
UID *uint32 `json:"uid"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(store secretstore.Store, mail MailService) Server {
|
func New(store secretstore.Store, mail MailService) Server {
|
||||||
|
|
@ -89,10 +94,19 @@ func (s Server) Tools() []Tool {
|
||||||
InputSchema: map[string]any{
|
InputSchema: map[string]any{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": map[string]any{
|
"properties": map[string]any{
|
||||||
"mailbox": map[string]any{"type": "string"},
|
"mailbox": map[string]any{
|
||||||
"limit": map[string]any{"type": "integer"},
|
"type": "string",
|
||||||
|
"minLength": 1,
|
||||||
|
},
|
||||||
|
"limit": map[string]any{
|
||||||
|
"type": "integer",
|
||||||
|
"default": defaultListMessagesLimit,
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": maxListMessagesLimit,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": []string{"mailbox"},
|
"required": []string{"mailbox"},
|
||||||
|
"additionalProperties": false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -101,10 +115,17 @@ func (s Server) Tools() []Tool {
|
||||||
InputSchema: map[string]any{
|
InputSchema: map[string]any{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": map[string]any{
|
"properties": map[string]any{
|
||||||
"mailbox": map[string]any{"type": "string"},
|
"mailbox": map[string]any{
|
||||||
"uid": map[string]any{"type": "integer"},
|
"type": "string",
|
||||||
|
"minLength": 1,
|
||||||
|
},
|
||||||
|
"uid": map[string]any{
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 1,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": []string{"mailbox", "uid"},
|
"required": []string{"mailbox", "uid"},
|
||||||
|
"additionalProperties": false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -140,6 +161,9 @@ func (r Runner) Run(ctx context.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stopCancelRead := r.closeInputOnCancel(ctx)
|
||||||
|
defer stopCancelRead()
|
||||||
|
|
||||||
encoder := json.NewEncoder(r.out)
|
encoder := json.NewEncoder(r.out)
|
||||||
if err := encoder.Encode(map[string]any{"tools": r.server.Tools()}); err != nil {
|
if err := encoder.Encode(map[string]any{"tools": r.server.Tools()}); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -156,6 +180,9 @@ func (r Runner) Run(ctx context.Context) error {
|
||||||
|
|
||||||
var request toolRequest
|
var request toolRequest
|
||||||
if err := decoder.Decode(&request); err != nil {
|
if err := decoder.Decode(&request); err != nil {
|
||||||
|
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||||
|
return ctxErr
|
||||||
|
}
|
||||||
if errors.Is(err, io.EOF) {
|
if errors.Is(err, io.EOF) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -184,13 +211,21 @@ func (s Server) handleTool(ctx context.Context, cred secretstore.Credential, nam
|
||||||
if err := decodeArguments(rawArgs, &args); err != nil {
|
if err := decodeArguments(rawArgs, &args); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return s.listMessages(ctx, cred, args.Mailbox, args.Limit)
|
limit, err := normalizeListMessagesLimit(args.Limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s.listMessages(ctx, cred, args.Mailbox, limit)
|
||||||
case "get_message":
|
case "get_message":
|
||||||
var args getMessageArguments
|
var args getMessageArguments
|
||||||
if err := decodeArguments(rawArgs, &args); err != nil {
|
if err := decodeArguments(rawArgs, &args); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return s.getMessage(ctx, cred, args.Mailbox, args.UID)
|
uid, err := validateMessageUID(args.UID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s.getMessage(ctx, cred, args.Mailbox, uid)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown tool: %s", name)
|
return nil, fmt.Errorf("unknown tool: %s", name)
|
||||||
}
|
}
|
||||||
|
|
@ -207,9 +242,9 @@ func (s Server) listMessages(ctx context.Context, cred secretstore.Credential, m
|
||||||
if s.mail == nil {
|
if s.mail == nil {
|
||||||
return nil, fmt.Errorf("mail service is not configured")
|
return nil, fmt.Errorf("mail service is not configured")
|
||||||
}
|
}
|
||||||
mailbox = strings.TrimSpace(mailbox)
|
mailbox, err := validateMailbox(mailbox)
|
||||||
if mailbox == "" {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("mailbox is required")
|
return nil, err
|
||||||
}
|
}
|
||||||
return s.mail.ListMessages(ctx, cred, mailbox, limit)
|
return s.mail.ListMessages(ctx, cred, mailbox, limit)
|
||||||
}
|
}
|
||||||
|
|
@ -218,9 +253,9 @@ func (s Server) getMessage(ctx context.Context, cred secretstore.Credential, mai
|
||||||
if s.mail == nil {
|
if s.mail == nil {
|
||||||
return imapclient.Message{}, fmt.Errorf("mail service is not configured")
|
return imapclient.Message{}, fmt.Errorf("mail service is not configured")
|
||||||
}
|
}
|
||||||
mailbox = strings.TrimSpace(mailbox)
|
mailbox, err := validateMailbox(mailbox)
|
||||||
if mailbox == "" {
|
if err != nil {
|
||||||
return imapclient.Message{}, fmt.Errorf("mailbox is required")
|
return imapclient.Message{}, err
|
||||||
}
|
}
|
||||||
if uid == 0 {
|
if uid == 0 {
|
||||||
return imapclient.Message{}, fmt.Errorf("uid must be greater than zero")
|
return imapclient.Message{}, fmt.Errorf("uid must be greater than zero")
|
||||||
|
|
@ -250,8 +285,58 @@ func decodeArguments(raw json.RawMessage, dest any) error {
|
||||||
if len(raw) == 0 {
|
if len(raw) == 0 {
|
||||||
raw = []byte("{}")
|
raw = []byte("{}")
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(raw, dest); err != nil {
|
decoder := json.NewDecoder(strings.NewReader(string(raw)))
|
||||||
|
decoder.DisallowUnknownFields()
|
||||||
|
if err := decoder.Decode(dest); err != nil {
|
||||||
return fmt.Errorf("invalid tool arguments: %w", err)
|
return fmt.Errorf("invalid tool arguments: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateMailbox(mailbox string) (string, error) {
|
||||||
|
mailbox = strings.TrimSpace(mailbox)
|
||||||
|
if mailbox == "" {
|
||||||
|
return "", fmt.Errorf("mailbox is required")
|
||||||
|
}
|
||||||
|
return mailbox, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeListMessagesLimit(limit *int) (int, error) {
|
||||||
|
if limit == nil {
|
||||||
|
return defaultListMessagesLimit, nil
|
||||||
|
}
|
||||||
|
if *limit < 1 || *limit > maxListMessagesLimit {
|
||||||
|
return 0, fmt.Errorf("limit must be between 1 and %d", maxListMessagesLimit)
|
||||||
|
}
|
||||||
|
return *limit, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateMessageUID(uid *uint32) (uint32, error) {
|
||||||
|
if uid == nil {
|
||||||
|
return 0, fmt.Errorf("uid is required")
|
||||||
|
}
|
||||||
|
if *uid == 0 {
|
||||||
|
return 0, fmt.Errorf("uid must be greater than zero")
|
||||||
|
}
|
||||||
|
return *uid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Runner) closeInputOnCancel(ctx context.Context) func() {
|
||||||
|
closer, ok := r.in.(io.Closer)
|
||||||
|
if !ok {
|
||||||
|
return func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = closer.Close()
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
close(done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"email-mcp/internal/imapclient"
|
"email-mcp/internal/imapclient"
|
||||||
"email-mcp/internal/secretstore"
|
"email-mcp/internal/secretstore"
|
||||||
|
|
@ -248,3 +250,226 @@ func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) {
|
||||||
t.Fatalf("expected missing credential error, got %v", err)
|
t.Fatalf("expected missing credential error, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerToolsAdvertiseValidatedArgumentContracts(t *testing.T) {
|
||||||
|
tools := New(&storeStub{}, serviceStub{}).Tools()
|
||||||
|
if len(tools) != 3 {
|
||||||
|
t.Fatalf("expected 3 tools, got %d", len(tools))
|
||||||
|
}
|
||||||
|
|
||||||
|
listMessages := tools[1]
|
||||||
|
if listMessages.Name != "list_messages" {
|
||||||
|
t.Fatalf("unexpected tool ordering: %#v", tools)
|
||||||
|
}
|
||||||
|
if got := listMessages.InputSchema["type"]; got != "object" {
|
||||||
|
t.Fatalf("expected object schema, got %#v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
listProps, ok := listMessages.InputSchema["properties"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected properties map, got %#v", listMessages.InputSchema["properties"])
|
||||||
|
}
|
||||||
|
limitSchema, ok := listProps["limit"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected limit schema, got %#v", listProps["limit"])
|
||||||
|
}
|
||||||
|
if got := limitSchema["default"]; got != float64(defaultListMessagesLimit) && got != defaultListMessagesLimit {
|
||||||
|
t.Fatalf("expected limit default %d, got %#v", defaultListMessagesLimit, got)
|
||||||
|
}
|
||||||
|
if got := limitSchema["minimum"]; got != float64(1) && got != 1 {
|
||||||
|
t.Fatalf("expected limit minimum 1, got %#v", got)
|
||||||
|
}
|
||||||
|
if got := limitSchema["maximum"]; got != float64(maxListMessagesLimit) && got != maxListMessagesLimit {
|
||||||
|
t.Fatalf("expected limit maximum %d, got %#v", maxListMessagesLimit, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
getMessage := tools[2]
|
||||||
|
if getMessage.Name != "get_message" {
|
||||||
|
t.Fatalf("unexpected tool ordering: %#v", tools)
|
||||||
|
}
|
||||||
|
getProps, ok := getMessage.InputSchema["properties"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected get_message properties map, got %#v", getMessage.InputSchema["properties"])
|
||||||
|
}
|
||||||
|
uidSchema, ok := getProps["uid"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected uid schema, got %#v", getProps["uid"])
|
||||||
|
}
|
||||||
|
if got := uidSchema["minimum"]; got != float64(1) && got != 1 {
|
||||||
|
t.Fatalf("expected uid minimum 1, got %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunnerRunReturnsValidationErrorsForInvalidRequests(t *testing.T) {
|
||||||
|
store := &storeStub{
|
||||||
|
credential: secretstore.Credential{
|
||||||
|
Host: "imap.example.com",
|
||||||
|
Username: "alice",
|
||||||
|
Password: "secret",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
input := bytes.NewBufferString("{\"tool\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\",\"limit\":0}}\n{\"tool\":\"get_message\",\"arguments\":{\"mailbox\":\"INBOX\"}}\n")
|
||||||
|
output := &bytes.Buffer{}
|
||||||
|
runner := NewRunner(New(store, serviceStub{
|
||||||
|
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
|
||||||
|
t.Fatal("ListMailboxes should not be called")
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) {
|
||||||
|
t.Fatal("ListMessages should not be called")
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
|
||||||
|
t.Fatal("GetMessage should not be called")
|
||||||
|
return imapclient.Message{}, nil
|
||||||
|
},
|
||||||
|
}), input, output, &bytes.Buffer{})
|
||||||
|
|
||||||
|
if err := runner.Run(context.Background()); err != nil {
|
||||||
|
t.Fatalf("Run returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder := json.NewDecoder(output)
|
||||||
|
if err := decoder.Decode(&struct {
|
||||||
|
Tools []Tool `json:"tools"`
|
||||||
|
}{}); err != nil {
|
||||||
|
t.Fatalf("failed to decode manifest: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var firstResponse struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if err := decoder.Decode(&firstResponse); err != nil {
|
||||||
|
t.Fatalf("failed to decode first error response: %v", err)
|
||||||
|
}
|
||||||
|
if firstResponse.Error != "limit must be between 1 and 50" {
|
||||||
|
t.Fatalf("unexpected first error: %#v", firstResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
var secondResponse struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if err := decoder.Decode(&secondResponse); err != nil {
|
||||||
|
t.Fatalf("failed to decode second error response: %v", err)
|
||||||
|
}
|
||||||
|
if secondResponse.Error != "uid is required" {
|
||||||
|
t.Fatalf("unexpected second error: %#v", secondResponse)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunnerRunAppliesDefaultLimitWhenOmitted(t *testing.T) {
|
||||||
|
store := &storeStub{
|
||||||
|
credential: secretstore.Credential{
|
||||||
|
Host: "imap.example.com",
|
||||||
|
Username: "alice",
|
||||||
|
Password: "secret",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
input := bytes.NewBufferString("{\"tool\":\"list_messages\",\"arguments\":{\"mailbox\":\"INBOX\"}}\n")
|
||||||
|
output := &bytes.Buffer{}
|
||||||
|
runner := NewRunner(New(store, serviceStub{
|
||||||
|
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
|
||||||
|
t.Fatal("ListMailboxes should not be called")
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
listMessages: func(_ context.Context, cred secretstore.Credential, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
|
||||||
|
if cred.Host != "imap.example.com" || mailbox != "INBOX" || limit != defaultListMessagesLimit {
|
||||||
|
t.Fatalf("unexpected call: cred=%#v mailbox=%q limit=%d", cred, mailbox, limit)
|
||||||
|
}
|
||||||
|
return []imapclient.MessageSummary{{UID: 42, Subject: "hello", From: "alice@example.com"}}, nil
|
||||||
|
},
|
||||||
|
getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
|
||||||
|
t.Fatal("GetMessage should not be called")
|
||||||
|
return imapclient.Message{}, nil
|
||||||
|
},
|
||||||
|
}), input, output, &bytes.Buffer{})
|
||||||
|
|
||||||
|
if err := runner.Run(context.Background()); err != nil {
|
||||||
|
t.Fatalf("Run returned error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunnerRunStopsWhenContextCanceledWhileWaitingForInput(t *testing.T) {
|
||||||
|
store := &storeStub{
|
||||||
|
credential: secretstore.Credential{
|
||||||
|
Host: "imap.example.com",
|
||||||
|
Username: "alice",
|
||||||
|
Password: "secret",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
input := newBlockingReadCloser()
|
||||||
|
runner := NewRunner(New(store, serviceStub{
|
||||||
|
listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
|
||||||
|
t.Fatal("ListMailboxes should not be called")
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) {
|
||||||
|
t.Fatal("ListMessages should not be called")
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
|
||||||
|
t.Fatal("GetMessage should not be called")
|
||||||
|
return imapclient.Message{}, nil
|
||||||
|
},
|
||||||
|
}), input, &bytes.Buffer{}, &bytes.Buffer{})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- runner.Run(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-input.started:
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("runner never started reading input")
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("expected context cancellation, got %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Fatal("runner did not stop after context cancellation")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !input.closed {
|
||||||
|
t.Fatal("expected runner to close input reader on cancellation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type blockingReadCloser struct {
|
||||||
|
started chan struct{}
|
||||||
|
closed bool
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBlockingReadCloser() *blockingReadCloser {
|
||||||
|
return &blockingReadCloser{
|
||||||
|
started: make(chan struct{}),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *blockingReadCloser) Read(_ []byte) (int, error) {
|
||||||
|
select {
|
||||||
|
case <-r.started:
|
||||||
|
default:
|
||||||
|
close(r.started)
|
||||||
|
}
|
||||||
|
<-r.done
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *blockingReadCloser) Close() error {
|
||||||
|
if !r.closed {
|
||||||
|
r.closed = true
|
||||||
|
close(r.done)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue