email-mcp/internal/mcpserver/server_test.go
2026-04-10 12:10:42 +02:00

475 lines
15 KiB
Go

package mcpserver
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"io"
	"sync"
	"testing"
	"time"

	"email-mcp/internal/imapclient"
	"email-mcp/internal/secretstore"
	"email-mcp/internal/secretstore/kwallet"
)
// storeStub is an in-memory stand-in for the credential store used by the
// server under test. Save is a no-op; Load records how it was invoked (call
// count and most recent key) and then yields either loadErr or credential.
type storeStub struct {
	credential secretstore.Credential // returned by Load when loadErr is nil
	loadErr    error                  // when non-nil, Load fails with this error

	loadCalls int    // number of Load invocations observed so far
	loadKey   string // key passed to the most recent Load
}

// Save ignores all of its arguments and always reports success.
func (s *storeStub) Save(_ context.Context, _ string, _ secretstore.Credential) error {
	return nil
}

// Load bumps the call counter and remembers the requested key before
// answering with the configured error or, absent one, the configured
// credential.
func (s *storeStub) Load(_ context.Context, key string) (secretstore.Credential, error) {
	s.loadCalls, s.loadKey = s.loadCalls+1, key
	if err := s.loadErr; err != nil {
		return secretstore.Credential{}, err
	}
	return s.credential, nil
}
// serviceStub is a function-field fake for the IMAP service dependency.
// Every method hands its arguments, untouched, to the hook of the same name,
// letting each test script the service per call. Invoking a method whose
// hook was left nil panics, because Go panics on a call through a nil func.
type serviceStub struct {
	listMailboxes func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error)
	listMessages  func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error)
	getMessage    func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error)
}

// ListMailboxes forwards to the listMailboxes hook.
func (s serviceStub) ListMailboxes(ctx context.Context, cred secretstore.Credential) ([]imapclient.Mailbox, error) {
	hook := s.listMailboxes
	return hook(ctx, cred)
}

// ListMessages forwards to the listMessages hook.
func (s serviceStub) ListMessages(ctx context.Context, cred secretstore.Credential, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
	hook := s.listMessages
	return hook(ctx, cred, mailbox, limit)
}

// GetMessage forwards to the getMessage hook.
func (s serviceStub) GetMessage(ctx context.Context, cred secretstore.Credential, mailbox string, uid uint32) (imapclient.Message, error) {
	hook := s.getMessage
	return hook(ctx, cred, mailbox, uid)
}
func TestServerListMailboxesLoadsCredentialAndDelegates(t *testing.T) {
store := &storeStub{
credential: secretstore.Credential{
Host: "imap.example.com",
Username: "alice",
Password: "secret",
},
}
server := New(store, serviceStub{
listMailboxes: func(_ context.Context, cred secretstore.Credential) ([]imapclient.Mailbox, error) {
if cred.Host != "imap.example.com" || cred.Username != "alice" || cred.Password != "secret" {
t.Fatalf("unexpected credential: %#v", cred)
}
return []imapclient.Mailbox{{Name: "INBOX"}}, nil
},
listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) {
t.Fatal("ListMessages should not be called")
return nil, nil
},
getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
t.Fatal("GetMessage should not be called")
return imapclient.Message{}, nil
},
})
result, err := server.ListMailboxes(context.Background())
if err != nil {
t.Fatalf("ListMailboxes returned error: %v", err)
}
if store.loadCalls != 1 {
t.Fatalf("expected credential to be loaded once, got %d", store.loadCalls)
}
if store.loadKey != secretstore.DefaultAccountKey {
t.Fatalf("expected load key %q, got %q", secretstore.DefaultAccountKey, store.loadKey)
}
if len(result) != 1 || result[0].Name != "INBOX" {
t.Fatalf("unexpected result: %#v", result)
}
}
// TestServerListMessagesLoadsCredentialAndDelegates checks that
// Server.ListMessages loads the default account's credential exactly once
// and forwards it, together with the mailbox and limit, to the IMAP service,
// returning the service's summaries.
func TestServerListMessagesLoadsCredentialAndDelegates(t *testing.T) {
	store := &storeStub{
		credential: secretstore.Credential{
			Host:     "imap.example.com",
			Username: "alice",
			Password: "secret",
		},
	}
	server := New(store, serviceStub{
		listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
			t.Fatal("ListMailboxes should not be called")
			return nil, nil
		},
		listMessages: func(_ context.Context, cred secretstore.Credential, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
			// Verify the whole credential, not just the host, so a server that
			// mixed up fields would be caught (matches the ListMailboxes test).
			if cred.Host != "imap.example.com" || cred.Username != "alice" || cred.Password != "secret" ||
				mailbox != "INBOX" || limit != 5 {
				t.Fatalf("unexpected call: cred=%#v mailbox=%q limit=%d", cred, mailbox, limit)
			}
			return []imapclient.MessageSummary{{UID: 42, Subject: "hello", From: "alice@example.com"}}, nil
		},
		getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
			t.Fatal("GetMessage should not be called")
			return imapclient.Message{}, nil
		},
	})
	result, err := server.ListMessages(context.Background(), "INBOX", 5)
	if err != nil {
		t.Fatalf("ListMessages returned error: %v", err)
	}
	// The test name promises a credential load; assert it the same way the
	// ListMailboxes test does.
	if store.loadCalls != 1 {
		t.Fatalf("expected credential to be loaded once, got %d", store.loadCalls)
	}
	if store.loadKey != secretstore.DefaultAccountKey {
		t.Fatalf("expected load key %q, got %q", secretstore.DefaultAccountKey, store.loadKey)
	}
	if len(result) != 1 || result[0].UID != 42 || result[0].Subject != "hello" || result[0].From != "alice@example.com" {
		t.Fatalf("unexpected result: %#v", result)
	}
}
// TestServerGetMessageUsesUIDContract checks that Server.GetMessage loads the
// default account's credential exactly once and forwards the mailbox and UID
// unchanged to the IMAP service, returning the service's message.
func TestServerGetMessageUsesUIDContract(t *testing.T) {
	store := &storeStub{
		credential: secretstore.Credential{
			Host:     "imap.example.com",
			Username: "alice",
			Password: "secret",
		},
	}
	server := New(store, serviceStub{
		listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
			t.Fatal("ListMailboxes should not be called")
			return nil, nil
		},
		listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) {
			t.Fatal("ListMessages should not be called")
			return nil, nil
		},
		getMessage: func(_ context.Context, cred secretstore.Credential, mailbox string, uid uint32) (imapclient.Message, error) {
			// Check every credential field, consistent with the ListMailboxes test.
			if cred.Host != "imap.example.com" || cred.Username != "alice" || cred.Password != "secret" ||
				mailbox != "INBOX" || uid != 42 {
				t.Fatalf("unexpected call: cred=%#v mailbox=%q uid=%d", cred, mailbox, uid)
			}
			return imapclient.Message{
				UID:     42,
				Mailbox: "INBOX",
				Body:    "body",
			}, nil
		},
	})
	message, err := server.GetMessage(context.Background(), "INBOX", 42)
	if err != nil {
		t.Fatalf("GetMessage returned error: %v", err)
	}
	// The server must obtain the credential from the store under the
	// default account key, exactly once for this call.
	if store.loadCalls != 1 {
		t.Fatalf("expected credential to be loaded once, got %d", store.loadCalls)
	}
	if store.loadKey != secretstore.DefaultAccountKey {
		t.Fatalf("expected load key %q, got %q", secretstore.DefaultAccountKey, store.loadKey)
	}
	if message.UID != 42 || message.Mailbox != "INBOX" {
		t.Fatalf("unexpected message: %#v", message)
	}
}
// TestRunnerRunWritesToolManifestAndHandlesRequests feeds the runner a single
// list_messages request and expects two JSON documents on the output stream:
// the tool manifest first, then the result of that request.
func TestRunnerRunWritesToolManifestAndHandlesRequests(t *testing.T) {
	store := &storeStub{
		credential: secretstore.Credential{
			Host:     "imap.example.com",
			Username: "alice",
			Password: "secret",
		},
	}
	request := `{"tool":"list_messages","arguments":{"mailbox":"INBOX","limit":5}}` + "\n"
	input := bytes.NewBufferString(request)
	var output bytes.Buffer

	svc := serviceStub{
		listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
			t.Fatal("ListMailboxes should not be called")
			return nil, nil
		},
		listMessages: func(_ context.Context, cred secretstore.Credential, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
			if cred.Host != "imap.example.com" || mailbox != "INBOX" || limit != 5 {
				t.Fatalf("unexpected call: cred=%#v mailbox=%q limit=%d", cred, mailbox, limit)
			}
			summary := imapclient.MessageSummary{UID: 42, Subject: "hello", From: "alice@example.com"}
			return []imapclient.MessageSummary{summary}, nil
		},
		getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
			t.Fatal("GetMessage should not be called")
			return imapclient.Message{}, nil
		},
	}
	runner := NewRunner(New(store, svc), input, &output, &bytes.Buffer{})

	if err := runner.Run(context.Background()); err != nil {
		t.Fatalf("Run returned error: %v", err)
	}
	if store.loadCalls != 1 {
		t.Fatalf("expected credential preload once, got %d", store.loadCalls)
	}

	decoder := json.NewDecoder(&output)

	// First document: the advertised tools, in a fixed order.
	var manifest struct {
		Tools []struct {
			Name string `json:"name"`
		} `json:"tools"`
	}
	if err := decoder.Decode(&manifest); err != nil {
		t.Fatalf("failed to decode manifest: %v", err)
	}
	wantNames := []string{"list_mailboxes", "list_messages", "get_message"}
	if len(manifest.Tools) != len(wantNames) {
		t.Fatalf("expected 3 tools, got %#v", manifest.Tools)
	}
	for i, want := range wantNames {
		if manifest.Tools[i].Name != want {
			t.Fatalf("unexpected tool manifest: %#v", manifest.Tools)
		}
	}

	// Second document: the list_messages response.
	var response struct {
		Result []imapclient.MessageSummary `json:"result"`
	}
	if err := decoder.Decode(&response); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}
	if got := response.Result; len(got) != 1 || got[0].UID != 42 {
		t.Fatalf("unexpected response: %#v", got)
	}
}
// TestRunnerRunReturnsFriendlyMissingCredentialError makes the store report
// that no credential exists and expects Run to fail with an error matching
// ErrCredentialsNotConfigured, without touching the IMAP service.
func TestRunnerRunReturnsFriendlyMissingCredentialError(t *testing.T) {
	store := &storeStub{loadErr: kwallet.ErrCredentialNotFound}

	untouched := serviceStub{
		listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
			t.Fatal("ListMailboxes should not be called")
			return nil, nil
		},
		listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) {
			t.Fatal("ListMessages should not be called")
			return nil, nil
		},
		getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
			t.Fatal("GetMessage should not be called")
			return imapclient.Message{}, nil
		},
	}

	// Empty stdin plus throwaway stdout/stderr buffers.
	var stdin, stdout, stderr bytes.Buffer
	runner := NewRunner(New(store, untouched), &stdin, &stdout, &stderr)

	if err := runner.Run(context.Background()); !errors.Is(err, ErrCredentialsNotConfigured) {
		t.Fatalf("expected missing credential error, got %v", err)
	}
}
// TestServerToolsAdvertiseValidatedArgumentContracts inspects the JSON-schema
// fragments published by Server.Tools. It checks that list_messages declares
// a bounded, defaulted "limit" and that get_message requires a positive "uid".
// Numeric schema values are accepted either as float64 (JSON-decoded form) or
// as the Go constant's own type.
func TestServerToolsAdvertiseValidatedArgumentContracts(t *testing.T) {
	tools := New(&storeStub{}, serviceStub{}).Tools()
	if n := len(tools); n != 3 {
		t.Fatalf("expected 3 tools, got %d", n)
	}

	// --- list_messages ---
	list := tools[1]
	if list.Name != "list_messages" {
		t.Fatalf("unexpected tool ordering: %#v", tools)
	}
	listSchema := list.InputSchema
	if kind := listSchema["type"]; kind != "object" {
		t.Fatalf("expected object schema, got %#v", kind)
	}
	listProps, ok := listSchema["properties"].(map[string]any)
	if !ok {
		t.Fatalf("expected properties map, got %#v", listSchema["properties"])
	}
	limit, ok := listProps["limit"].(map[string]any)
	if !ok {
		t.Fatalf("expected limit schema, got %#v", listProps["limit"])
	}
	if v := limit["default"]; v != float64(defaultListMessagesLimit) && v != defaultListMessagesLimit {
		t.Fatalf("expected limit default %d, got %#v", defaultListMessagesLimit, v)
	}
	if v := limit["minimum"]; v != float64(1) && v != 1 {
		t.Fatalf("expected limit minimum 1, got %#v", v)
	}
	if v := limit["maximum"]; v != float64(maxListMessagesLimit) && v != maxListMessagesLimit {
		t.Fatalf("expected limit maximum %d, got %#v", maxListMessagesLimit, v)
	}

	// --- get_message ---
	get := tools[2]
	if get.Name != "get_message" {
		t.Fatalf("unexpected tool ordering: %#v", tools)
	}
	getSchema := get.InputSchema
	getProps, ok := getSchema["properties"].(map[string]any)
	if !ok {
		t.Fatalf("expected get_message properties map, got %#v", getSchema["properties"])
	}
	uid, ok := getProps["uid"].(map[string]any)
	if !ok {
		t.Fatalf("expected uid schema, got %#v", getProps["uid"])
	}
	if v := uid["minimum"]; v != float64(1) && v != 1 {
		t.Fatalf("expected uid minimum 1, got %#v", v)
	}
}
// TestRunnerRunReturnsValidationErrorsForInvalidRequests sends two malformed
// requests (limit of 0, missing uid). It expects Run to finish cleanly, emit
// the manifest, and then emit one error document per request, in order, while
// never reaching the IMAP service.
func TestRunnerRunReturnsValidationErrorsForInvalidRequests(t *testing.T) {
	store := &storeStub{
		credential: secretstore.Credential{
			Host:     "imap.example.com",
			Username: "alice",
			Password: "secret",
		},
	}
	requests := `{"tool":"list_messages","arguments":{"mailbox":"INBOX","limit":0}}` + "\n" +
		`{"tool":"get_message","arguments":{"mailbox":"INBOX"}}` + "\n"
	input := bytes.NewBufferString(requests)
	var output bytes.Buffer

	untouched := serviceStub{
		listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
			t.Fatal("ListMailboxes should not be called")
			return nil, nil
		},
		listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) {
			t.Fatal("ListMessages should not be called")
			return nil, nil
		},
		getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
			t.Fatal("GetMessage should not be called")
			return imapclient.Message{}, nil
		},
	}
	runner := NewRunner(New(store, untouched), input, &output, &bytes.Buffer{})
	if err := runner.Run(context.Background()); err != nil {
		t.Fatalf("Run returned error: %v", err)
	}

	decoder := json.NewDecoder(&output)

	// Skip past the manifest; its contents are covered by other tests.
	var manifest struct {
		Tools []Tool `json:"tools"`
	}
	if err := decoder.Decode(&manifest); err != nil {
		t.Fatalf("failed to decode manifest: %v", err)
	}

	// One error document per request, in the order the requests were sent.
	expected := []struct {
		ordinal string
		message string
	}{
		{ordinal: "first", message: "limit must be between 1 and 50"},
		{ordinal: "second", message: "uid is required"},
	}
	for _, want := range expected {
		var response struct {
			Error string `json:"error"`
		}
		if err := decoder.Decode(&response); err != nil {
			t.Fatalf("failed to decode %s error response: %v", want.ordinal, err)
		}
		if response.Error != want.message {
			t.Fatalf("unexpected %s error: %#v", want.ordinal, response)
		}
	}
}
// TestRunnerRunAppliesDefaultLimitWhenOmitted sends list_messages with no
// "limit" argument and expects the service to be called with
// defaultListMessagesLimit.
func TestRunnerRunAppliesDefaultLimitWhenOmitted(t *testing.T) {
	store := &storeStub{
		credential: secretstore.Credential{
			Host:     "imap.example.com",
			Username: "alice",
			Password: "secret",
		},
	}
	input := bytes.NewBufferString(`{"tool":"list_messages","arguments":{"mailbox":"INBOX"}}` + "\n")

	svc := serviceStub{
		listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
			t.Fatal("ListMailboxes should not be called")
			return nil, nil
		},
		listMessages: func(_ context.Context, cred secretstore.Credential, mailbox string, limit int) ([]imapclient.MessageSummary, error) {
			badCall := cred.Host != "imap.example.com" || mailbox != "INBOX" || limit != defaultListMessagesLimit
			if badCall {
				t.Fatalf("unexpected call: cred=%#v mailbox=%q limit=%d", cred, mailbox, limit)
			}
			return []imapclient.MessageSummary{{UID: 42, Subject: "hello", From: "alice@example.com"}}, nil
		},
		getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
			t.Fatal("GetMessage should not be called")
			return imapclient.Message{}, nil
		},
	}

	// Output is not inspected here; the assertion lives in the listMessages hook.
	runner := NewRunner(New(store, svc), input, new(bytes.Buffer), new(bytes.Buffer))
	if err := runner.Run(context.Background()); err != nil {
		t.Fatalf("Run returned error: %v", err)
	}
}
// TestRunnerRunStopsWhenContextCanceledWhileWaitingForInput starts Run
// against an input that never yields data and waits until a read is in
// progress. It then cancels the context and expects Run to return
// context.Canceled promptly and to have closed the input reader.
func TestRunnerRunStopsWhenContextCanceledWhileWaitingForInput(t *testing.T) {
	// Generous bound: it only matters when the test is failing. Tight bounds
	// flake under -race or on a loaded CI machine.
	const waitTimeout = 2 * time.Second

	store := &storeStub{
		credential: secretstore.Credential{
			Host:     "imap.example.com",
			Username: "alice",
			Password: "secret",
		},
	}
	input := newBlockingReadCloser()
	// Run executes on a goroutine spawned below, so these hooks would be
	// invoked off the test goroutine. t.Fatal (FailNow) must only be called
	// from the test goroutine. t.Error reports the failure without killing the
	// runner goroutine, which would otherwise show up as a misleading timeout.
	runner := NewRunner(New(store, serviceStub{
		listMailboxes: func(context.Context, secretstore.Credential) ([]imapclient.Mailbox, error) {
			t.Error("ListMailboxes should not be called")
			return nil, nil
		},
		listMessages: func(context.Context, secretstore.Credential, string, int) ([]imapclient.MessageSummary, error) {
			t.Error("ListMessages should not be called")
			return nil, nil
		},
		getMessage: func(context.Context, secretstore.Credential, string, uint32) (imapclient.Message, error) {
			t.Error("GetMessage should not be called")
			return imapclient.Message{}, nil
		},
	}), input, &bytes.Buffer{}, &bytes.Buffer{})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Buffered so the runner goroutine can always deliver its result and exit,
	// even if the test has already given up waiting.
	done := make(chan error, 1)
	go func() {
		done <- runner.Run(ctx)
	}()

	select {
	case <-input.started:
	case <-time.After(waitTimeout):
		t.Fatal("runner never started reading input")
	}

	cancel()

	select {
	case err := <-done:
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("expected context cancellation, got %v", err)
		}
	case <-time.After(waitTimeout):
		t.Fatal("runner did not stop after context cancellation")
	}
	// Reading closed here relies on Run having closed the input before it
	// returned. Receiving from done provides the happens-before edge.
	if !input.closed {
		t.Fatal("expected runner to close input reader on cancellation")
	}
}
// blockingReadCloser is an io.ReadCloser that simulates an input stream that
// never produces data. Read parks until Close is called and then reports
// io.EOF. The first Read closes started, so a test can observe that its
// consumer is blocked in a read.
//
// Read and Close are meant to be called from different goroutines (Close is
// what releases a parked Read). Both signal channels are therefore closed
// through sync.Once, which makes repeated or concurrent calls safe instead of
// panicking with "close of closed channel". The closed field is written only
// inside Close. Read it only after a happens-before edge with the goroutine
// that called Close.
type blockingReadCloser struct {
	started chan struct{} // closed by the first Read
	closed  bool          // set once Close has run
	done    chan struct{} // closed by Close to release parked readers

	startOnce sync.Once // guards close(started)
	closeOnce sync.Once // guards closed and close(done)
}

var _ io.ReadCloser = (*blockingReadCloser)(nil)

// newBlockingReadCloser returns a reader whose started and done channels are
// both still open.
func newBlockingReadCloser() *blockingReadCloser {
	return &blockingReadCloser{
		started: make(chan struct{}),
		done:    make(chan struct{}),
	}
}

// Read signals started on its first call, then blocks until Close and
// returns (0, io.EOF). It never writes to the supplied buffer. Once the
// reader is closed, Read returns immediately.
func (r *blockingReadCloser) Read(_ []byte) (int, error) {
	r.startOnce.Do(func() { close(r.started) })
	<-r.done
	return 0, io.EOF
}

// Close marks the reader closed and releases any parked Read. It is
// idempotent, safe for concurrent use, and always returns nil.
func (r *blockingReadCloser) Close() error {
	r.closeOnce.Do(func() {
		r.closed = true
		close(r.done)
	})
	return nil
}