diff --git a/Client/SafeBrowsing/GSafeBrowsing.swift b/Client/SafeBrowsing/GSafeBrowsing.swift index 10db9770bc7..f9de3365602 100644 --- a/Client/SafeBrowsing/GSafeBrowsing.swift +++ b/Client/SafeBrowsing/GSafeBrowsing.swift @@ -104,17 +104,23 @@ class SafeBrowsingClient { do { let body = FetchRequest(client: clientInfo, listUpdateRequests: lists) let request = try encode(.post, endpoint: .fetch, body: body) - executeRequest(request, type: FetchResponse.self) { response, error in + executeRequest(request, type: FetchResponse.self) { [weak self] response, error in + guard let self = self else { return } + if let error = error { return completion(error) } if let response = response { + var didError = false self.database.update(response, completion: { if let error = $0 { log.error("Safe-Browsing: Error Updating Database: \(error)") + didError = true } }) + + return completion(didError ? SafeBrowsingError("Safe-Browsing: Error Updating Database") : nil) } completion(nil) diff --git a/Client/SafeBrowsing/GSafeBrowsingDatabase.swift b/Client/SafeBrowsing/GSafeBrowsingDatabase.swift index 994993e0070..d66e3faf060 100644 --- a/Client/SafeBrowsing/GSafeBrowsingDatabase.swift +++ b/Client/SafeBrowsing/GSafeBrowsingDatabase.swift @@ -82,7 +82,10 @@ class SafeBrowsingDatabase { }() private func save() { - let context = mainContext + saveContext(mainContext) + } + + private func saveContext(_ context: NSManagedObjectContext) { if context.hasChanges { do { try context.save() @@ -95,10 +98,14 @@ class SafeBrowsingDatabase { func getState(_ type: ThreatType) -> String { dbLock.lock(); defer { dbLock.unlock() } - let request: NSFetchRequest = Threat.fetchRequest() - request.fetchLimit = 1 - request.predicate = NSPredicate(format: "threatType == %@", type.rawValue) - return (try? mainContext.fetch(request))?.first?.state ?? "" + var state = "" + backgroundContext.performAndWait { + let request: NSFetchRequest = Threat.fetchRequest() + request.fetchLimit = 1 + request.predicate = NSPredicate(format: "threatType == %@", type.rawValue) + state = (try? backgroundContext.fetch(request))?.first?.state ?? "" + } + return state } func find(_ hash: String) -> [String] { @@ -110,14 +117,16 @@ class SafeBrowsingDatabase { return [] //ERROR Hash must be a full hash.. } - let request: NSFetchRequest = ThreatHash.fetchRequest() - let threatHashes = (try? mainContext.fetch(request)) ?? [] - - for threat in threatHashes { - let n = threat.matches(data) - if n > 0 { - let hashPrefix = data.subdata(in: 0.. = ThreatHash.fetchRequest() + let threatHashes = (try? backgroundContext.fetch(request)) ?? [] + + for threat in threatHashes { + let n = threat.matches(data) + if n > 0 { + let hashPrefix = data.subdata(in: 0.. = Threat.fetchRequest() - let count = (try? mainContext.count(for: request)) ?? 0 + var count = 0 + backgroundContext.performAndWait { + let request: NSFetchRequest = Threat.fetchRequest() + count = (try? backgroundContext.count(for: request)) ?? 0 + } + if count == 0 { return completion(SafeBrowsingError("Partial Update received for non-existent table")) } @@ -149,31 +162,48 @@ class SafeBrowsingDatabase { return completion(SafeBrowsingError("Indices to be removed included in a Full Update")) } - let request: NSFetchRequest = Threat.fetchRequest() - request.predicate = NSPredicate(format: "threatType == %@", response.threatType.rawValue) - (try? mainContext.fetch(request))?.forEach({ - self.mainContext.delete($0) - }) - - if mainContext.hasChanges { - self.save() + backgroundContext.performAndWait { + let request = { () -> NSBatchDeleteRequest in + let request: NSFetchRequest = Threat.fetchRequest() + request.predicate = NSPredicate(format: "threatType == %@", response.threatType.rawValue) + + let deleteRequest = NSBatchDeleteRequest(fetchRequest: request) + deleteRequest.resultType = .resultTypeObjectIDs + return deleteRequest + }() + + if let result = (try? backgroundContext.execute(request)) as? NSBatchDeleteResult { + if let deletedObjects = result.result as? [NSManagedObjectID] { + NSManagedObjectContext.mergeChanges( + fromRemoteContextSave: [NSDeletedObjectsKey: deletedObjects], + into: [mainContext, backgroundContext] + ) + } + } else { + let request: NSFetchRequest = Threat.fetchRequest() + request.predicate = NSPredicate(format: "threatType == %@", response.threatType.rawValue) + + (try? backgroundContext.fetch(request))?.forEach({ + backgroundContext.delete($0) + }) + } + + saveContext(backgroundContext) + backgroundContext.reset() } default: return completion(SafeBrowsingError("Unknown Response Type")) } - self.backgroundContext.performAndWait { [weak self] in - guard let self = self else { return } - let context = self.backgroundContext - + backgroundContext.performAndWait { let request: NSFetchRequest = Threat.fetchRequest() request.predicate = NSPredicate(format: "threatType == %@", response.threatType.rawValue) request.fetchLimit = 1 - var threat: Threat! = (try? context.fetch(request))?.first + var threat: Threat! = (try? backgroundContext.fetch(request))?.first if threat == nil { - threat = Threat(context: context) + threat = Threat(context: backgroundContext) } threat.threatType = response.threatType.rawValue @@ -217,12 +247,12 @@ class SafeBrowsingDatabase { threat.hashes?.forEach({ if let hash = $0 as? ThreatHash { - context.delete(hash) + backgroundContext.delete(hash) } }) threat.hashes = NSSet(array: hashes.map({ - let hash = ThreatHash(context: context) + let hash = ThreatHash(context: backgroundContext) hash.hashData = $0 return hash })) @@ -232,19 +262,12 @@ class SafeBrowsingDatabase { } if !validate(response.checksum.sha256, hashes) { - context.rollback() + backgroundContext.rollback() return completion(SafeBrowsingError("Threat List Checksum Mismatch")) } - if context.hasChanges { - do { - try context.save() - } catch { - completion(error) - } - - context.reset() - } + saveContext(backgroundContext) + backgroundContext.reset() } }) }