diff --git a/libraries/apollo-normalized-cache/src/commonMain/kotlin/com/apollographql/apollo3/cache/normalized/internal/ApolloCacheInterceptor.kt b/libraries/apollo-normalized-cache/src/commonMain/kotlin/com/apollographql/apollo3/cache/normalized/internal/ApolloCacheInterceptor.kt index c6e0727863f..02774e0015c 100644 --- a/libraries/apollo-normalized-cache/src/commonMain/kotlin/com/apollographql/apollo3/cache/normalized/internal/ApolloCacheInterceptor.kt +++ b/libraries/apollo-normalized-cache/src/commonMain/kotlin/com/apollographql/apollo3/cache/normalized/internal/ApolloCacheInterceptor.kt @@ -28,20 +28,24 @@ import com.apollographql.apollo3.exception.apolloExceptionHandler import com.apollographql.apollo3.interceptor.ApolloInterceptor import com.apollographql.apollo3.interceptor.ApolloInterceptorChain import com.apollographql.apollo3.mpp.currentTimeMillis +import kotlinx.coroutines.Job import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.emitAll +import kotlinx.coroutines.flow.flatMapMerge import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch internal class ApolloCacheInterceptor( val store: ApolloStore, ) : ApolloInterceptor, ApolloStoreInterceptor { - private suspend fun maybeAsync(request: ApolloRequest, block: suspend () -> Unit) { + private suspend fun maybeAsync(request: ApolloRequest, block: suspend () -> Unit): Job { if (request.writeToCacheAsynchronously) { val scope = request.executionContext[ConcurrencyInfo]!!.coroutineScope - scope.launch { + return scope.launch { try { block() } catch (e: Throwable) { @@ -50,6 +54,7 @@ internal class ApolloCacheInterceptor( } } else { block() + return CompletedJob() } } @@ -61,18 +66,18 @@ internal class ApolloCacheInterceptor( response: ApolloResponse, customScalarAdapters: CustomScalarAdapters, extraKeys: Set = emptySet(), - ) { + ): Job { if (request.doNotStore) { - return + return CompletedJob() } if (response.data == null) { - return + return CompletedJob() } if (response.hasErrors() && !request.storePartialResponses) { - return + return CompletedJob() } - maybeAsync(request) { + return maybeAsync(request) { val cacheKeys = if (response.data != null) { var cacheHeaders = request.cacheHeaders + response.cacheHeaders if (request.storeReceiveDate) { @@ -118,9 +123,12 @@ internal class ApolloCacheInterceptor( chain: ApolloInterceptorChain, ): Flow> { val customScalarAdapters = request.customScalarAdapters + val cacheWrites = mutableListOf() return chain.proceed(request).onEach { maybeWriteToCache(request, it, customScalarAdapters) + }.onCompletion { + cacheWrites.joinAll() } } @@ -169,8 +177,9 @@ internal class ApolloCacheInterceptor( emptySet() } - maybeWriteToCache(request, response, customScalarAdapters, optimisticKeys!!) + val cacheWriteJob = maybeWriteToCache(request, response, customScalarAdapters, optimisticKeys!!) emit(response) + cacheWriteJob.join() } if (networkException != null) { @@ -253,8 +262,10 @@ internal class ApolloCacheInterceptor( customScalarAdapters: CustomScalarAdapters, ): Flow> { val startMillis = currentTimeMillis() + val cacheWrites = mutableListOf() + return chain.proceed(request).onEach { - maybeWriteToCache(request, it, customScalarAdapters) + cacheWrites += maybeWriteToCache(request, it, customScalarAdapters) }.map { networkResponse -> networkResponse.newBuilder() .cacheInfo( @@ -264,6 +275,8 @@ internal class ApolloCacheInterceptor( .networkException(networkResponse.exception) .build() ).build() + }.onCompletion { + cacheWrites.joinAll() } } @@ -271,5 +284,7 @@ internal class ApolloCacheInterceptor( private fun nowDateCacheHeaders(): CacheHeaders { return CacheHeaders.Builder().addHeader(ApolloCacheHeaders.DATE, (currentTimeMillis() / 1000).toString()).build() } + + private fun CompletedJob() = Job().apply { complete() } } }