diff --git a/mvrx/src/main/kotlin/com/airbnb/mvrx/MvRxExtensions.kt b/mvrx/src/main/kotlin/com/airbnb/mvrx/MvRxExtensions.kt index dc43a4863..327cf86fc 100644 --- a/mvrx/src/main/kotlin/com/airbnb/mvrx/MvRxExtensions.kt +++ b/mvrx/src/main/kotlin/com/airbnb/mvrx/MvRxExtensions.kt @@ -1,10 +1,10 @@ package com.airbnb.mvrx -import androidx.lifecycle.ViewModelProviders import androidx.annotation.RestrictTo import androidx.annotation.RestrictTo.Scope.LIBRARY import androidx.fragment.app.Fragment import androidx.fragment.app.FragmentActivity +import androidx.lifecycle.ViewModelProviders import kotlin.properties.ReadOnlyProperty import kotlin.reflect.KClass import kotlin.reflect.KProperty @@ -29,6 +29,60 @@ inline fun , reified S : MvRxState> T.fragm .apply { subscribe(this@fragmentViewModel, subscriber = { postInvalidate() }) } } +/** + * Gets or creates a ViewModel scoped to a parent fragment. This delegate will walk up the parentFragment hierarchy + * until it finds a Fragment that can provide the correct ViewModel. If no parent fragments can provide the ViewModel, + * a new one will be created in top-most parent Fragment. + */ +inline fun , reified S : MvRxState> T.parentFragmentViewModel( + viewModelClass: KClass = VM::class, + crossinline keyFactory: () -> String = { viewModelClass.java.name } +): Lazy where T : Fragment, T : MvRxView = lifecycleAwareLazy(this) { + requireNotNull(parentFragment) { "There is no parent fragment for ${this::class.java.simpleName}!" } + val notFoundMessage by lazy { "There is no ViewModel of type ${VM::class.java.simpleName} for this Fragment!" } + val factory = MvRxFactory { error(notFoundMessage) } + var fragment: Fragment? = parentFragment + val key = keyFactory() + while (fragment != null) { + try { + return@lifecycleAwareLazy ViewModelProviders.of(fragment, factory).get(key, viewModelClass.java) + .apply { subscribe(this@parentFragmentViewModel, subscriber = { postInvalidate() }) } + } catch (e: java.lang.IllegalStateException) { + if (e.message == notFoundMessage) { + fragment = fragment.parentFragment + } else { + throw e + } + } + } + + // ViewModel was not found. Create a new one in the top-most parent. + var topParentFragment = parentFragment + while (topParentFragment?.parentFragment != null) { + topParentFragment = topParentFragment.parentFragment + } + val viewModelContext = FragmentViewModelContext(this.requireActivity(), _fragmentArgsProvider(), topParentFragment!!) + return@lifecycleAwareLazy MvRxViewModelProvider.get(viewModelClass.java, S::class.java, viewModelContext, keyFactory()) + .apply { subscribe(this@parentFragmentViewModel, subscriber = { postInvalidate() }) } +} + +/** + * Gets or creates a ViewModel scoped to a target fragment. Throws [IllegalStateException] if there is no target fragment. + */ +inline fun , reified S : MvRxState> T.targetFragmentViewModel( + viewModelClass: KClass = VM::class, + crossinline keyFactory: () -> String = { viewModelClass.java.name } +): Lazy where T : Fragment, T : MvRxView = lifecycleAwareLazy(this) { + val targetFragment = requireNotNull(targetFragment) { "There is no target fragment for ${this::class.java.simpleName}!" } + MvRxViewModelProvider.get( + viewModelClass.java, + S::class.java, + FragmentViewModelContext(this.requireActivity(), targetFragment._fragmentArgsProvider(), targetFragment), + keyFactory() + ) + .apply { subscribe(this@targetFragmentViewModel, subscriber = { postInvalidate() }) } +} + /** * [activityViewModel] except it will throw [IllegalStateException] if the ViewModel doesn't already exist. * Use this for screens in the middle of a flow that cannot reasonably be an entrypoint to the flow. diff --git a/mvrx/src/test/kotlin/com/airbnb/mvrx/ViewSubscriberTest.kt b/mvrx/src/test/kotlin/com/airbnb/mvrx/FragmentSubscriberTest.kt similarity index 76% rename from mvrx/src/test/kotlin/com/airbnb/mvrx/ViewSubscriberTest.kt rename to mvrx/src/test/kotlin/com/airbnb/mvrx/FragmentSubscriberTest.kt index 7d1823f65..7e0f1e880 100644 --- a/mvrx/src/test/kotlin/com/airbnb/mvrx/ViewSubscriberTest.kt +++ b/mvrx/src/test/kotlin/com/airbnb/mvrx/FragmentSubscriberTest.kt @@ -5,6 +5,7 @@ import android.os.Bundle import android.view.LayoutInflater import android.view.View import android.view.ViewGroup +import android.widget.FrameLayout import androidx.fragment.app.Fragment import org.junit.Assert.assertEquals import org.junit.Test @@ -463,4 +464,122 @@ class FragmentSubscriberTest : BaseTest() { fun duplicateUniqueOnlySubscribeThrowIllegalStateException() { createFragment(containerId = CONTAINER_ID) } + + class ParentFragment : BaseMvRxFragment() { + + val viewModel: ViewSubscriberViewModel by fragmentViewModel() + + override fun onCreateView(inflater: LayoutInflater, container: ViewGroup?, savedInstanceState: Bundle?) = FrameLayout(requireContext()) + + override fun invalidate() { + } + } + + class ChildFragmentWithParentViewModel : BaseMvRxFragment() { + + val viewModel: ViewSubscriberViewModel by parentFragmentViewModel() + + override fun invalidate() { + } + } + + @Test + fun testParentFragment() { + val (_, parentFragment) = createFragment(containerId = CONTAINER_ID) + val childFragment = ChildFragmentWithParentViewModel() + parentFragment.childFragmentManager.beginTransaction().add(childFragment, "child").commit() + assertEquals(parentFragment.viewModel, childFragment.viewModel) + } + + class ParentFragmentWithoutViewModel : BaseMvRxFragment() { + + override fun onCreateView(inflater: LayoutInflater, container: ViewGroup?, savedInstanceState: Bundle?) = FrameLayout(requireContext()) + + override fun onViewCreated(view: View, savedInstanceState: Bundle?) { + childFragmentManager.beginTransaction() + .add(ChildFragmentWithParentViewModel(), "child1") + .commit() + childFragmentManager.beginTransaction() + .add(ChildFragmentWithParentViewModel(), "child2") + .commit() + } + + override fun invalidate() { + } + } + + @Test + fun testChildFragmentsCanShareViewModelWithoutParent() { + val (_, parentFragment) = createFragment(containerId = CONTAINER_ID) + val childFragment1 = parentFragment.childFragmentManager.findFragmentByTag("child1") as ChildFragmentWithParentViewModel + val childFragment2 = parentFragment.childFragmentManager.findFragmentByTag("child2") as ChildFragmentWithParentViewModel + assertEquals(childFragment1.viewModel, childFragment2.viewModel) + } + + class EmptyMvRxFragment : BaseMvRxFragment() { + override fun invalidate() { + } + } + + @Test + fun testCreatesViewModelInTopMostFragment() { + val (_, parentFragment) = createFragment(containerId = CONTAINER_ID) + val middleFragment = Fragment() + parentFragment.childFragmentManager.beginTransaction().add(middleFragment, "middle").commitNow() + val childFragment1 = ChildFragmentWithParentViewModel() + middleFragment.childFragmentManager.beginTransaction().add(childFragment1, "child1").commitNow() + + val childFragment2 = ChildFragmentWithParentViewModel() + parentFragment.childFragmentManager.beginTransaction().add(childFragment2, "child2").commitNow() + + assertEquals(childFragment1.viewModel, childFragment2.viewModel) + } + + class FragmentWithTarget : BaseMvRxFragment() { + val viewModel: ViewSubscriberViewModel by targetFragmentViewModel() + + var invalidateCount = 0 + + override fun invalidate() { + invalidateCount++ + } + } + + @Test + fun testTargetFragment() { + val (_, parentFragment) = createFragment(containerId = CONTAINER_ID) + val targetFragment = EmptyMvRxFragment() + parentFragment.childFragmentManager.beginTransaction().add(targetFragment, "target").commitNow() + val fragmentWithTarget = FragmentWithTarget() + fragmentWithTarget.setTargetFragment(targetFragment, 123) + parentFragment.childFragmentManager.beginTransaction().add(fragmentWithTarget, "fragment-with-target").commitNow() + // Make sure subscribe works. + assertEquals(1, fragmentWithTarget.invalidateCount) + fragmentWithTarget.viewModel.setFoo(1) + assertEquals(2, fragmentWithTarget.invalidateCount) + } + + @Test + fun testTargetFragmentsShareViewModel() { + val (_, parentFragment) = createFragment(containerId = CONTAINER_ID) + val targetFragment = EmptyMvRxFragment() + parentFragment.childFragmentManager.beginTransaction().add(targetFragment, "target").commitNow() + val fragmentWithTarget1 = FragmentWithTarget() + fragmentWithTarget1.setTargetFragment(targetFragment, 123) + parentFragment.childFragmentManager.beginTransaction().add(fragmentWithTarget1, "fragment-with-target1").commitNow() + val fragmentWithTarget2 = FragmentWithTarget() + fragmentWithTarget2.setTargetFragment(targetFragment, 123) + parentFragment.childFragmentManager.beginTransaction().add(fragmentWithTarget2, "fragment-with-target2").commitNow() + assertEquals(fragmentWithTarget1.viewModel, fragmentWithTarget2.viewModel) + } + + /** + * This would be [IllegalStateException] except it fails during the Fragment transaction so it's a RuntimeException. + */ + @Test(expected = RuntimeException::class) + fun testTargetFragmentWithoutTargetCrashes() { + val (_, parentFragment) = createFragment(containerId = CONTAINER_ID) + val fragmentWithTarget = FragmentWithTarget() + parentFragment.childFragmentManager.beginTransaction().add(fragmentWithTarget, "fragment-with-target").commitNow() + } } diff --git a/sample/src/main/java/com/airbnb/mvrx/sample/MainFragment.kt b/sample/src/main/java/com/airbnb/mvrx/sample/MainFragment.kt index 044f68cd5..e20aed837 100644 --- a/sample/src/main/java/com/airbnb/mvrx/sample/MainFragment.kt +++ b/sample/src/main/java/com/airbnb/mvrx/sample/MainFragment.kt @@ -28,6 +28,13 @@ class MainFragment : BaseFragment() { clickListener { _ -> navigateTo(R.id.action_main_to_helloWorldEpoxyFragment) } } + basicRow { + id("parent_fragments") + title("Parent/Child ViewModel") + subtitle(demonstrates("parentFragmentViewModel")) + clickListener { _ -> navigateTo(R.id.action_main_to_parentFragment) } + } + basicRow { id("random_dad_joke") title("Random Dad Joke") diff --git a/sample/src/main/java/com/airbnb/mvrx/sample/features/parentfragment/ChildFragment.kt b/sample/src/main/java/com/airbnb/mvrx/sample/features/parentfragment/ChildFragment.kt new file mode 100644 index 000000000..30e3b2314 --- /dev/null +++ b/sample/src/main/java/com/airbnb/mvrx/sample/features/parentfragment/ChildFragment.kt @@ -0,0 +1,30 @@ +package com.airbnb.mvrx.sample.features.parentfragment + +import android.os.Bundle +import android.view.LayoutInflater +import android.view.View +import android.view.ViewGroup +import com.airbnb.mvrx.BaseMvRxFragment +import com.airbnb.mvrx.parentFragmentViewModel +import com.airbnb.mvrx.sample.R +import com.airbnb.mvrx.withState +import kotlinx.android.synthetic.main.fragment_parent.textView + +class ChildFragment : BaseMvRxFragment() { + + private val viewModel: CounterViewModel by parentFragmentViewModel() + + override fun onCreateView(inflater: LayoutInflater, container: ViewGroup?, savedInstanceState: Bundle?): View? { + return inflater.inflate(R.layout.fragment_child, container, false) + } + + override fun onViewCreated(view: View, savedInstanceState: Bundle?) { + textView.setOnClickListener { + viewModel.incrementCount() + } + } + + override fun invalidate() = withState(viewModel) { state -> + textView.text = "ChildFragment: Count: ${state.count}" + } +} \ No newline at end of file diff --git a/sample/src/main/java/com/airbnb/mvrx/sample/features/parentfragment/ParentFragment.kt b/sample/src/main/java/com/airbnb/mvrx/sample/features/parentfragment/ParentFragment.kt new file mode 100644 index 000000000..3ca9472d8 --- /dev/null +++ b/sample/src/main/java/com/airbnb/mvrx/sample/features/parentfragment/ParentFragment.kt @@ -0,0 +1,44 @@ +package com.airbnb.mvrx.sample.features.parentfragment + +import android.os.Bundle +import android.view.LayoutInflater +import android.view.View +import android.view.ViewGroup +import androidx.navigation.fragment.findNavController +import androidx.navigation.ui.setupWithNavController +import com.airbnb.mvrx.BaseMvRxFragment +import com.airbnb.mvrx.MvRxState +import com.airbnb.mvrx.fragmentViewModel +import com.airbnb.mvrx.sample.R +import com.airbnb.mvrx.sample.core.MvRxViewModel +import com.airbnb.mvrx.withState +import kotlinx.android.synthetic.main.fragment_parent.textView +import kotlinx.android.synthetic.main.fragment_parent.toolbar + +data class CounterState(val count: Int = 0) : MvRxState +class CounterViewModel(state: CounterState) : MvRxViewModel(state) { + fun incrementCount() = setState { copy(count = count + 1) } +} + +class ParentFragment : BaseMvRxFragment() { + + private val viewModel: CounterViewModel by fragmentViewModel() + + override fun onCreateView(inflater: LayoutInflater, container: ViewGroup?, savedInstanceState: Bundle?): View? { + return inflater.inflate(R.layout.fragment_parent, container, false) + } + + override fun onViewCreated(view: View, savedInstanceState: Bundle?) { + toolbar.setupWithNavController(findNavController()) + textView.setOnClickListener { + viewModel.incrementCount() + } + childFragmentManager.beginTransaction() + .replace(R.id.childContainer, ChildFragment()) + .commit() + } + + override fun invalidate() = withState(viewModel) { state -> + textView.text = "ParentFragment: Count: ${state.count}" + } +} \ No newline at end of file diff --git a/sample/src/main/res/layout/fragment_child.xml b/sample/src/main/res/layout/fragment_child.xml new file mode 100644 index 000000000..5254c0eb8 --- /dev/null +++ b/sample/src/main/res/layout/fragment_child.xml @@ -0,0 +1,12 @@ + + + + + + \ No newline at end of file diff --git a/sample/src/main/res/layout/fragment_parent.xml b/sample/src/main/res/layout/fragment_parent.xml new file mode 100644 index 000000000..475f50411 --- /dev/null +++ b/sample/src/main/res/layout/fragment_parent.xml @@ -0,0 +1,25 @@ + + + + + + + \ No newline at end of file diff --git a/sample/src/main/res/navigation/nav_graph.xml b/sample/src/main/res/navigation/nav_graph.xml index 14dcf7c7f..dca10da00 100644 --- a/sample/src/main/res/navigation/nav_graph.xml +++ b/sample/src/main/res/navigation/nav_graph.xml @@ -40,6 +40,9 @@ app:exitAnim="@anim/anim_exit" app:popEnterAnim="@anim/anim_pop_enter" app:popExitAnim="@anim/anim_pop_exit" /> + + \ No newline at end of file