diff --git a/Sources/WalletConnectSign/Engine/Common/SessionEngine.swift b/Sources/WalletConnectSign/Engine/Common/SessionEngine.swift index 435f34f98..08cd1d0f6 100644 --- a/Sources/WalletConnectSign/Engine/Common/SessionEngine.swift +++ b/Sources/WalletConnectSign/Engine/Common/SessionEngine.swift @@ -26,6 +26,7 @@ final class SessionEngine { private var publishers = [AnyCancellable]() private let logger: ConsoleLogging private let sessionRequestsProvider: SessionRequestsProvider + private let invalidRequestsSanitiser: InvalidRequestsSanitiser init( networkingInteractor: NetworkInteracting, @@ -35,7 +36,8 @@ final class SessionEngine { kms: KeyManagementServiceProtocol, sessionStore: WCSessionStorage, logger: ConsoleLogging, - sessionRequestsProvider: SessionRequestsProvider + sessionRequestsProvider: SessionRequestsProvider, + invalidRequestsSanitiser: InvalidRequestsSanitiser ) { self.networkingInteractor = networkingInteractor self.historyService = historyService @@ -45,6 +47,7 @@ final class SessionEngine { self.sessionStore = sessionStore self.logger = logger self.sessionRequestsProvider = sessionRequestsProvider + self.invalidRequestsSanitiser = invalidRequestsSanitiser setupConnectionSubscriptions() setupRequestSubscriptions() @@ -52,8 +55,15 @@ final class SessionEngine { setupUpdateSubscriptions() setupExpirationSubscriptions() DispatchQueue.main.asyncAfter(deadline: .now() + 1) { [weak self] in - sessionRequestsProvider.emitRequestIfPending() + self?.sessionRequestsProvider.emitRequestIfPending() } + + removeInvalidSessionRequests() + } + + private func removeInvalidSessionRequests() { + let sessionTopics = Set(sessionStore.getAll().map(\.topic)) + invalidRequestsSanitiser.removeInvalidSessionRequests(validSessionTopics: sessionTopics) } func hasSession(for topic: String) -> Bool { diff --git a/Sources/WalletConnectSign/Services/HistoryService.swift b/Sources/WalletConnectSign/Services/HistoryService.swift index 6a3b8a01b..ddb44c1ab 100644 --- a/Sources/WalletConnectSign/Services/HistoryService.swift +++ b/Sources/WalletConnectSign/Services/HistoryService.swift @@ -1,6 +1,10 @@ import Foundation -final class HistoryService { +protocol HistoryServiceProtocol { + func getPendingRequests() -> [(request: Request, context: VerifyContext?)] +} + +final class HistoryService: HistoryServiceProtocol { private let history: RPCHistory private let verifyContextStore: CodableStore @@ -25,6 +29,8 @@ final class HistoryService { getPendingRequestsSortedByTimestamp() } + + func getPendingRequestsSortedByTimestamp() -> [(request: Request, context: VerifyContext?)] { let requests = history.getPending() .compactMap { mapRequestRecord($0) } @@ -80,3 +86,13 @@ private extension HistoryService { return (mappedRequest, record.id, record.timestamp) } } + +#if DEBUG +class MockHistoryService: HistoryServiceProtocol { + var pendingRequests: [(request: Request, context: VerifyContext?)] = [] + + func getPendingRequests() -> [(request: Request, context: VerifyContext?)] { + return pendingRequests + } +} +#endif diff --git a/Sources/WalletConnectSign/Services/InvalidRequestsSanitiser.swift b/Sources/WalletConnectSign/Services/InvalidRequestsSanitiser.swift new file mode 100644 index 000000000..31793769b --- /dev/null +++ b/Sources/WalletConnectSign/Services/InvalidRequestsSanitiser.swift @@ -0,0 +1,20 @@ + +import Foundation + +final class InvalidRequestsSanitiser { + private let historyService: HistoryServiceProtocol + private let history: RPCHistoryProtocol + + init(historyService: HistoryServiceProtocol, history: RPCHistoryProtocol) { + self.historyService = historyService + self.history = history + } + + func removeInvalidSessionRequests(validSessionTopics: Set) { + let pendingRequests = historyService.getPendingRequests() + let invalidTopics = Set(pendingRequests.map { $0.request.topic }).subtracting(validSessionTopics) + if !invalidTopics.isEmpty { + history.deleteAll(forTopics: Array(invalidTopics)) + } + } +} diff --git a/Sources/WalletConnectSign/Sign/SignClientFactory.swift b/Sources/WalletConnectSign/Sign/SignClientFactory.swift index bfbdcc6b9..33143fb43 100644 --- a/Sources/WalletConnectSign/Sign/SignClientFactory.swift +++ b/Sources/WalletConnectSign/Sign/SignClientFactory.swift @@ -61,7 +61,8 @@ public struct SignClientFactory { let historyService = HistoryService(history: rpcHistory, verifyContextStore: verifyContextStore) let verifyClient = VerifyClientFactory.create() let sessionRequestsProvider = SessionRequestsProvider(historyService: historyService) - let sessionEngine = SessionEngine(networkingInteractor: networkingClient, historyService: historyService, verifyContextStore: verifyContextStore, verifyClient: verifyClient, kms: kms, sessionStore: sessionStore, logger: logger, sessionRequestsProvider: sessionRequestsProvider) + let invalidRequestsSanitiser = InvalidRequestsSanitiser(historyService: historyService, history: rpcHistory) + let sessionEngine = SessionEngine(networkingInteractor: networkingClient, historyService: historyService, verifyContextStore: verifyContextStore, verifyClient: verifyClient, kms: kms, sessionStore: sessionStore, logger: logger, sessionRequestsProvider: sessionRequestsProvider, invalidRequestsSanitiser: invalidRequestsSanitiser) let nonControllerSessionStateMachine = NonControllerSessionStateMachine(networkingInteractor: networkingClient, kms: kms, sessionStore: sessionStore, logger: logger) let controllerSessionStateMachine = ControllerSessionStateMachine(networkingInteractor: networkingClient, kms: kms, sessionStore: sessionStore, logger: logger) let sessionExtendRequester = SessionExtendRequester(sessionStore: sessionStore, networkingInteractor: networkingClient) diff --git a/Sources/WalletConnectUtils/RPCHistory/RPCHistory.swift b/Sources/WalletConnectUtils/RPCHistory/RPCHistory.swift index ff6bda280..7060c89ae 100644 --- a/Sources/WalletConnectUtils/RPCHistory/RPCHistory.swift +++ b/Sources/WalletConnectUtils/RPCHistory/RPCHistory.swift @@ -1,6 +1,10 @@ import Foundation -public final class RPCHistory { +public protocol RPCHistoryProtocol { + func deleteAll(forTopics topics: [String]) +} + +public final class RPCHistory: RPCHistoryProtocol { public struct Record: Codable { public enum Origin: String, Codable { @@ -144,3 +148,13 @@ extension RPCHistory { } } } + +#if DEBUG +class MockRPCHistory: RPCHistoryProtocol { + var deletedTopics: [String] = [] + + func deleteAll(forTopics topics: [String]) { + deletedTopics.append(contentsOf: topics) + } +} +#endif diff --git a/Tests/WalletConnectSignTests/InvalidRequestsSanitiserTests.swift b/Tests/WalletConnectSignTests/InvalidRequestsSanitiserTests.swift new file mode 100644 index 000000000..dda89b064 --- /dev/null +++ b/Tests/WalletConnectSignTests/InvalidRequestsSanitiserTests.swift @@ -0,0 +1,69 @@ +import XCTest +@testable import WalletConnectSign +@testable import WalletConnectUtils + +class InvalidRequestsSanitiserTests: XCTestCase { + var sanitiser: InvalidRequestsSanitiser! + var mockHistoryService: MockHistoryService! + var mockRPCHistory: MockRPCHistory! + + override func setUp() { + super.setUp() + mockHistoryService = MockHistoryService() + mockRPCHistory = MockRPCHistory() + sanitiser = InvalidRequestsSanitiser(historyService: mockHistoryService, history: mockRPCHistory) + } + + override func tearDown() { + sanitiser = nil + mockHistoryService = nil + mockRPCHistory = nil + super.tearDown() + } + + func testRemoveInvalidSessionRequests_noPendingRequests() { + let validSessionTopics: Set = ["validTopic1", "validTopic2"] + + sanitiser.removeInvalidSessionRequests(validSessionTopics: validSessionTopics) + + XCTAssertTrue(mockRPCHistory.deletedTopics.isEmpty) + } + + func testRemoveInvalidSessionRequests_allRequestsValid() { + let validSessionTopics: Set = ["validTopic1", "validTopic2"] + mockHistoryService.pendingRequests = [ + (request: try! Request(topic: "validTopic1", method: "method1", params: AnyCodable("params1"), chainId: Blockchain("eip155:1")!), context: nil), + (request: try! Request(topic: "validTopic2", method: "method2", params: AnyCodable("params2"), chainId: Blockchain("eip155:1")!), context: nil) + ] + + sanitiser.removeInvalidSessionRequests(validSessionTopics: validSessionTopics) + + XCTAssertTrue(mockRPCHistory.deletedTopics.isEmpty) + } + + func testRemoveInvalidSessionRequests_someRequestsInvalid() { + let validSessionTopics: Set = ["validTopic1", "validTopic2"] + mockHistoryService.pendingRequests = [ + (request: try! Request(topic: "validTopic1", method: "method1", params: AnyCodable("params1"), chainId: Blockchain("eip155:1")!), context: nil), + (request: try! Request(topic: "invalidTopic1", method: "method2", params: AnyCodable("params2"), chainId: Blockchain("eip155:1")!), context: nil), + (request: try! Request(topic: "invalidTopic2", method: "method3", params: AnyCodable("params3"), chainId: Blockchain("eip155:1")!), context: nil) + ] + + sanitiser.removeInvalidSessionRequests(validSessionTopics: validSessionTopics) + + XCTAssertEqual(mockRPCHistory.deletedTopics.sorted(), ["invalidTopic1", "invalidTopic2"]) + } + + func testRemoveInvalidSessionRequests_withEmptyValidSessionTopics() { + let validSessionTopics: Set = [] + + mockHistoryService.pendingRequests = [ + (request: try! Request(topic: "invalidTopic1", method: "method1", params: AnyCodable("params1"), chainId: Blockchain("eip155:1")!), context: nil), + (request: try! Request(topic: "invalidTopic2", method: "method2", params: AnyCodable("params2"), chainId: Blockchain("eip155:1")!), context: nil) + ] + + sanitiser.removeInvalidSessionRequests(validSessionTopics: validSessionTopics) + + XCTAssertEqual(mockRPCHistory.deletedTopics.sorted(), ["invalidTopic1", "invalidTopic2"]) + } +} diff --git a/Tests/WalletConnectSignTests/SessionEngineTests.swift b/Tests/WalletConnectSignTests/SessionEngineTests.swift index e3a42232e..a6854c5bc 100644 --- a/Tests/WalletConnectSignTests/SessionEngineTests.swift +++ b/Tests/WalletConnectSignTests/SessionEngineTests.swift @@ -13,33 +13,29 @@ final class SessionEngineTests: XCTestCase { override func setUp() { networkingInteractor = NetworkingInteractorMock() sessionStorage = WCSessionStorageMock() + let defaults = RuntimeKeyValueStorage() + let rpcHistory = RPCHistory( + keyValueStore: .init( + defaults: defaults, + identifier: "" + ) + ) verifyContextStore = CodableStore(defaults: RuntimeKeyValueStorage(), identifier: "") + let historyService = HistoryService( + history: rpcHistory, + verifyContextStore: verifyContextStore + ) engine = SessionEngine( networkingInteractor: networkingInteractor, - historyService: HistoryService( - history: RPCHistory( - keyValueStore: .init( - defaults: RuntimeKeyValueStorage(), - identifier: "" - ) - ), - verifyContextStore: verifyContextStore - ), + historyService: historyService, verifyContextStore: verifyContextStore, verifyClient: VerifyClientMock(), kms: KeyManagementServiceMock(), sessionStore: sessionStorage, logger: ConsoleLoggerMock(), sessionRequestsProvider: SessionRequestsProvider( - historyService: HistoryService( - history: RPCHistory( - keyValueStore: .init( - defaults: RuntimeKeyValueStorage(), - identifier: "" - ) - ), - verifyContextStore: verifyContextStore - )) + historyService: historyService), + invalidRequestsSanitiser: InvalidRequestsSanitiser(historyService: historyService, history: rpcHistory) ) }