diff --git a/internal/secretstore/kwallet/client.go b/internal/secretstore/kwallet/client.go index 55b7d54..f1228cb 100644 --- a/internal/secretstore/kwallet/client.go +++ b/internal/secretstore/kwallet/client.go @@ -184,7 +184,7 @@ func (s *walletSession) writeEntry(ctx context.Context, key string, value []byte code, err := s.callInt32(ctx, object, "writeEntry", s.handle, kwalletFolderName, key, value, kwalletAppID) if err != nil { - return wrapUnavailable("kwallet write failed", err) + return s.wrapUnavailable("kwallet write failed", err) } if code != 0 { return fmt.Errorf("kwallet write failed with code %d", code) @@ -204,7 +204,7 @@ func (s *walletSession) readEntry(ctx context.Context, key string) ([]byte, erro hasEntry, err := s.callBool(ctx, object, "hasEntry", s.handle, kwalletFolderName, key, kwalletAppID) if err != nil { - return nil, wrapUnavailable("kwallet entry lookup failed", err) + return nil, s.wrapUnavailable("kwallet entry lookup failed", err) } if !hasEntry { return nil, fmt.Errorf("%w: key %q", ErrCredentialNotFound, key) @@ -212,7 +212,7 @@ func (s *walletSession) readEntry(ctx context.Context, key string) ([]byte, erro value, err := s.callBytes(ctx, object, "readEntry", s.handle, kwalletFolderName, key, kwalletAppID) if err != nil { - return nil, wrapUnavailable("kwallet read failed", err) + return nil, s.wrapUnavailable("kwallet read failed", err) } return value, nil } @@ -305,6 +305,12 @@ func (s *walletSession) callBytes(ctx context.Context, object dbusObject, name s return value, nil } +func (s *walletSession) reset() { + s.object = nil + s.handle = 0 + s.opened = false +} + func kwalletMethod(name string) string { return kwalletInterface + "." + name } @@ -322,8 +328,12 @@ func (e *typedError) Unwrap() []error { return e.errs } -func wrapUnavailable(message string, err error) error { - if err == nil || errors.Is(err, ErrKWalletUnavailable) { +func (s *walletSession) wrapUnavailable(message string, err error) error { + if err == nil { + return err + } + s.reset() + if errors.Is(err, ErrKWalletUnavailable) { return err } return &typedError{ diff --git a/internal/secretstore/kwallet/client_test.go b/internal/secretstore/kwallet/client_test.go index e2e846e..f3ce8dc 100644 --- a/internal/secretstore/kwallet/client_test.go +++ b/internal/secretstore/kwallet/client_test.go @@ -45,6 +45,22 @@ func (c *stubConnection) Object(dest string, path dbus.ObjectPath) dbusObject { return c.objects[dest+"|"+string(path)] } +type rotatingConnection struct { + objects []dbusObject + index int +} + +func (c *rotatingConnection) Object(string, dbus.ObjectPath) dbusObject { + if len(c.objects) == 0 { + return nil + } + object := c.objects[c.index] + if c.index < len(c.objects)-1 { + c.index++ + } + return object +} + func TestClientIsAvailableReturnsErrorWhenServiceIsMissing(t *testing.T) { client := newClientImpl(newWalletSession(func() (dbusConnection, error) { return &stubConnection{ @@ -294,3 +310,107 @@ func TestClientReadEntryMapsTransportFailuresToUnavailable(t *testing.T) { t.Fatalf("expected wrapped transport error, got %v", err) } } + +func TestClientWriteEntryReopensAfterMappedTransportFailure(t *testing.T) { + firstObject := &stubObject{ + responses: map[string][]stubCall{ + kwalletMethod("isEnabled"): {{body: []any{true}}}, + kwalletMethod("networkWallet"): {{body: []any{"kdewallet"}}}, + kwalletMethod("open"): {{body: []any{int32(42)}}}, + kwalletMethod("hasFolder"): {{body: []any{true}}}, + kwalletMethod("writeEntry"): {{err: errors.New("transport closed")}}, + }, + } + secondObject := &stubObject{ + responses: map[string][]stubCall{ + kwalletMethod("isEnabled"): {{body: []any{true}}}, + kwalletMethod("networkWallet"): {{body: []any{"kdewallet"}}}, + kwalletMethod("open"): {{body: []any{int32(43)}}}, + kwalletMethod("hasFolder"): {{body: []any{true}}}, + kwalletMethod("writeEntry"): {{body: []any{int32(0)}}}, + }, + } + client := newClientImpl(newWalletSession(func() (dbusConnection, error) { + return &rotatingConnection{objects: []dbusObject{firstObject, secondObject}}, nil + })) + + err := client.WriteEntry(context.Background(), "default", []byte("payload")) + if !errors.Is(err, ErrKWalletUnavailable) { + t.Fatalf("expected unavailable error, got %v", err) + } + + if err := client.WriteEntry(context.Background(), "default", []byte("payload")); err != nil { + t.Fatalf("expected retry to succeed, got %v", err) + } + + firstOpenCalls := 0 + for _, call := range firstObject.calls { + if call.method == kwalletMethod("open") { + firstOpenCalls++ + } + } + secondOpenCalls := 0 + for _, call := range secondObject.calls { + if call.method == kwalletMethod("open") { + secondOpenCalls++ + } + } + if firstOpenCalls != 1 || secondOpenCalls != 1 { + t.Fatalf("expected reopen on retry, got first=%d second=%d", firstOpenCalls, secondOpenCalls) + } +} + +func TestClientReadEntryReopensAfterMappedTransportFailure(t *testing.T) { + firstObject := &stubObject{ + responses: map[string][]stubCall{ + kwalletMethod("isEnabled"): {{body: []any{true}}}, + kwalletMethod("networkWallet"): {{body: []any{"kdewallet"}}}, + kwalletMethod("open"): {{body: []any{int32(42)}}}, + kwalletMethod("hasFolder"): {{body: []any{true}}}, + kwalletMethod("hasEntry"): {{body: []any{true}}}, + kwalletMethod("readEntry"): {{err: errors.New("transport closed")}}, + }, + } + secondObject := &stubObject{ + responses: map[string][]stubCall{ + kwalletMethod("isEnabled"): {{body: []any{true}}}, + kwalletMethod("networkWallet"): {{body: []any{"kdewallet"}}}, + kwalletMethod("open"): {{body: []any{int32(43)}}}, + kwalletMethod("hasFolder"): {{body: []any{true}}}, + kwalletMethod("hasEntry"): {{body: []any{true}}}, + kwalletMethod("readEntry"): {{body: []any{[]byte("payload")}}}, + }, + } + client := newClientImpl(newWalletSession(func() (dbusConnection, error) { + return &rotatingConnection{objects: []dbusObject{firstObject, secondObject}}, nil + })) + + _, err := client.ReadEntry(context.Background(), "default") + if !errors.Is(err, ErrKWalletUnavailable) { + t.Fatalf("expected unavailable error, got %v", err) + } + + value, err := client.ReadEntry(context.Background(), "default") + if err != nil { + t.Fatalf("expected retry to succeed, got %v", err) + } + if string(value) != "payload" { + t.Fatalf("unexpected retry payload: %q", value) + } + + firstOpenCalls := 0 + for _, call := range firstObject.calls { + if call.method == kwalletMethod("open") { + firstOpenCalls++ + } + } + secondOpenCalls := 0 + for _, call := range secondObject.calls { + if call.method == kwalletMethod("open") { + secondOpenCalls++ + } + } + if firstOpenCalls != 1 || secondOpenCalls != 1 { + t.Fatalf("expected reopen on retry, got first=%d second=%d", firstOpenCalls, secondOpenCalls) + } +}