From 3f75357c890678062d89c0511b0a83951481ac4f Mon Sep 17 00:00:00 2001 From: thibaud-leclere Date: Fri, 10 Apr 2026 10:31:24 +0200 Subject: [PATCH] fix: harden kwallet store boundary --- internal/secretstore/kwallet/store.go | 18 ++- internal/secretstore/kwallet/store_test.go | 123 ++++++++++++++++++++- 2 files changed, 132 insertions(+), 9 deletions(-) diff --git a/internal/secretstore/kwallet/store.go b/internal/secretstore/kwallet/store.go index 141dc87..3d8f974 100644 --- a/internal/secretstore/kwallet/store.go +++ b/internal/secretstore/kwallet/store.go @@ -2,6 +2,7 @@ package kwallet import ( "context" + "errors" "email-mcp/internal/secretstore" ) @@ -19,11 +20,19 @@ type Store struct { var _ secretstore.Store = (*Store)(nil) +var errNilClient = errors.New("kwallet client is nil") + func NewStore(client Client) *Store { return &Store{client: client} } func (s *Store) Save(ctx context.Context, key string, cred secretstore.Credential) error { + if err := cred.Validate(); err != nil { + return err + } + if s.client == nil { + return errNilClient + } if err := s.client.IsAvailable(ctx); err != nil { return err } @@ -31,15 +40,14 @@ func (s *Store) Save(ctx context.Context, key string, cred secretstore.Credentia return err } - data, err := secretstore.MarshalCredential(cred) - if err != nil { - return err - } - + data, _ := secretstore.MarshalCredential(cred) return s.client.WriteEntry(ctx, key, data) } func (s *Store) Load(ctx context.Context, key string) (secretstore.Credential, error) { + if s.client == nil { + return secretstore.Credential{}, errNilClient + } if err := s.client.IsAvailable(ctx); err != nil { return secretstore.Credential{}, err } diff --git a/internal/secretstore/kwallet/store_test.go b/internal/secretstore/kwallet/store_test.go index 5493388..bdbedd4 100644 --- a/internal/secretstore/kwallet/store_test.go +++ b/internal/secretstore/kwallet/store_test.go @@ -15,13 +15,17 @@ type walletClientStub struct { readErr error readValue []byte - openCalled bool - writeKey string - writeValue []byte - readKey string + isAvailableCalled bool + openCalled bool + writeCalled bool + readCalled bool + writeKey string + writeValue []byte + readKey string } func (c *walletClientStub) IsAvailable(context.Context) error { + c.isAvailableCalled = true return c.availableErr } @@ -31,16 +35,40 @@ func (c *walletClientStub) Open(context.Context) error { } func (c *walletClientStub) WriteEntry(_ context.Context, key string, value []byte) error { + c.writeCalled = true c.writeKey = key c.writeValue = value return c.writeErr } func (c *walletClientStub) ReadEntry(_ context.Context, key string) ([]byte, error) { + c.readCalled = true c.readKey = key return c.readValue, c.readErr } +func TestStoreSaveReturnsErrorWhenClientIsNil(t *testing.T) { + store := NewStore(nil) + + err := store.Save(context.Background(), secretstore.DefaultAccountKey, secretstore.Credential{ + Host: "imap.example.com", + Username: "alice", + Password: "secret", + }) + if err == nil { + t.Fatal("expected error when client is nil") + } +} + +func TestStoreLoadReturnsErrorWhenClientIsNil(t *testing.T) { + store := NewStore(nil) + + _, err := store.Load(context.Background(), secretstore.DefaultAccountKey) + if err == nil { + t.Fatal("expected error when client is nil") + } +} + func TestStoreSaveWritesSerializedCredential(t *testing.T) { client := &walletClientStub{} store := NewStore(client) @@ -84,6 +112,58 @@ func TestStoreSaveReturnsAvailabilityErrorWithoutOpeningWallet(t *testing.T) { } } +func TestStoreSaveValidatesBeforeCheckingWalletAvailability(t *testing.T) { + client := &walletClientStub{} + store := NewStore(client) + + err := store.Save(context.Background(), secretstore.DefaultAccountKey, secretstore.Credential{ + Host: "imap.example.com", + Username: "alice", + }) + if err == nil { + t.Fatal("expected validation error") + } + if client.isAvailableCalled { + t.Fatal("did not expect availability check for invalid credential") + } + if client.openCalled || client.writeCalled { + t.Fatal("did not expect wallet calls for invalid credential") + } +} + +func TestStoreSaveReturnsOpenError(t *testing.T) { + wantErr := errors.New("open failed") + client := &walletClientStub{openErr: wantErr} + store := NewStore(client) + + err := store.Save(context.Background(), secretstore.DefaultAccountKey, secretstore.Credential{ + Host: "imap.example.com", + Username: "alice", + Password: "secret", + }) + if !errors.Is(err, wantErr) { + t.Fatalf("expected open error, got %v", err) + } + if client.writeCalled { + t.Fatal("did not expect write when open fails") + } +} + +func TestStoreSaveReturnsWriteError(t *testing.T) { + wantErr := errors.New("write failed") + client := &walletClientStub{writeErr: wantErr} + store := NewStore(client) + + err := store.Save(context.Background(), secretstore.DefaultAccountKey, secretstore.Credential{ + Host: "imap.example.com", + Username: "alice", + Password: "secret", + }) + if !errors.Is(err, wantErr) { + t.Fatalf("expected write error, got %v", err) + } +} + func TestStoreLoadReadsAndDecodesCredential(t *testing.T) { payload, err := secretstore.MarshalCredential(secretstore.Credential{ Host: "imap.example.com", @@ -112,3 +192,38 @@ func TestStoreLoadReadsAndDecodesCredential(t *testing.T) { t.Fatalf("unexpected credential: %#v", cred) } } + +func TestStoreLoadReturnsOpenError(t *testing.T) { + wantErr := errors.New("open failed") + client := &walletClientStub{openErr: wantErr} + store := NewStore(client) + + _, err := store.Load(context.Background(), secretstore.DefaultAccountKey) + if !errors.Is(err, wantErr) { + t.Fatalf("expected open error, got %v", err) + } + if client.readCalled { + t.Fatal("did not expect read when open fails") + } +} + +func TestStoreLoadReturnsReadError(t *testing.T) { + wantErr := errors.New("read failed") + client := &walletClientStub{readErr: wantErr} + store := NewStore(client) + + _, err := store.Load(context.Background(), secretstore.DefaultAccountKey) + if !errors.Is(err, wantErr) { + t.Fatalf("expected read error, got %v", err) + } +} + +func TestStoreLoadReturnsDecodeError(t *testing.T) { + client := &walletClientStub{readValue: []byte("not-json")} + store := NewStore(client) + + _, err := store.Load(context.Background(), secretstore.DefaultAccountKey) + if err == nil { + t.Fatal("expected decode error") + } +}