diff --git a/core/src/com/google/inject/internal/Errors.java b/core/src/com/google/inject/internal/Errors.java index 2c584a8c2e..6046a526fe 100644 --- a/core/src/com/google/inject/internal/Errors.java +++ b/core/src/com/google/inject/internal/Errors.java @@ -499,6 +499,10 @@ public Errors errorCheckingDuplicateBinding(Key key, Object source, Throwable t); } + public Errors errorCacheAlreadyLoading(Object key) { + return addMessage(ErrorId.OTHER, "%s was already loading.", key); + } + public Errors errorNotifyingTypeListener( TypeListenerBinding listener, TypeLiteral type, Throwable cause) { return errorInUserCode( diff --git a/core/src/com/google/inject/internal/FailableCache.java b/core/src/com/google/inject/internal/FailableCache.java index def5d980dd..b664d9204c 100644 --- a/core/src/com/google/inject/internal/FailableCache.java +++ b/core/src/com/google/inject/internal/FailableCache.java @@ -22,6 +22,8 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; /** * Lazily creates (and caches) values for keys. If creating the value fails (with errors), an @@ -31,6 +33,8 @@ */ public abstract class FailableCache { + private final Set loading = ConcurrentHashMap.newKeySet(); + private final LoadingCache delegate = CacheBuilder.newBuilder() .build( @@ -39,11 +43,13 @@ public abstract class FailableCache { public Object load(K key) { Errors errors = new Errors(); V result = null; + loading.add(key); try { result = FailableCache.this.create(key, errors); } catch (ErrorsException e) { errors.merge(e.getErrors()); } + loading.remove(key); return errors.hasErrors() ? errors : result; } }); @@ -51,6 +57,10 @@ public Object load(K key) { protected abstract V create(K key, Errors errors) throws ErrorsException; public V get(K key, Errors errors) throws ErrorsException { + if (loading.contains(key)) { + errors.errorCacheAlreadyLoading(key); + throw errors.toException(); + } Object resultOrError = delegate.getUnchecked(key); if (resultOrError instanceof Errors) { errors.merge((Errors) resultOrError); diff --git a/core/test/com/google/inject/RecursiveLoadTest.java b/core/test/com/google/inject/RecursiveLoadTest.java new file mode 100644 index 0000000000..1adf2cd1b5 --- /dev/null +++ b/core/test/com/google/inject/RecursiveLoadTest.java @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2006 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.inject; + +import com.google.inject.spi.Message; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; +import junit.framework.TestCase; +import org.junit.Assert; + +public class RecursiveLoadTest extends TestCase { + + public void testRecursiveLoad() { + Injector injector = + Guice.createInjector( + new AbstractModule() { + @Override + protected void configure() {} + }); + assertBothFailures(injector, A.class); + assertBothFailures(injector, B.class); + assertNoImplementationFailure(injector, C.class); + assertRecursiveFailure(injector, D.class); + assertNoImplementationFailure(injector, E.class); + } + + private static void assertFailure( + Injector injector, Class clazz, Consumer> checks) { + try { + injector.getBinding(clazz); + fail("Shouldn't have been able to get binding of: " + clazz); + } catch (ConfigurationException expected) { + List errorMessages = new ArrayList<>(expected.getErrorMessages()); + checks.accept(errorMessages); + } + } + + private static void assertBothFailures(Injector injector, Class clazz) { + assertFailure( + injector, + clazz, + errorMessages -> { + Assert.assertEquals(2, errorMessages.size()); + + Message msg1 = errorMessages.get(0); + Asserts.assertContains( + msg1.getMessage(), + "com.google.inject.RecursiveLoadTest$B.() was already loading."); + + Message msg2 = errorMessages.get(1); + Asserts.assertContains( + msg2.getMessage(), + "No implementation for com.google.inject.RecursiveLoadTest$Unresolved was bound."); + }); + } + + private static void assertRecursiveFailure(Injector injector, Class clazz) { + assertFailure( + injector, + clazz, + errorMessages -> { + Assert.assertEquals(1, errorMessages.size()); + + Message msg = errorMessages.get(0); + Asserts.assertContains( + msg.getMessage(), + "com.google.inject.RecursiveLoadTest$B.() was already loading."); + }); + } + + private static void assertNoImplementationFailure(Injector injector, Class clazz) { + assertFailure( + injector, + clazz, + errorMessages -> { + Assert.assertEquals(1, errorMessages.size()); + + Message msg = errorMessages.get(0); + Asserts.assertContains( + msg.getMessage(), + "No implementation for com.google.inject.RecursiveLoadTest$Unresolved was bound."); + }); + } + + static class A { + @Inject B b; + } + + static class B { + @Inject C c; + @Inject D d; + } + + static class C { + @Inject E e; + @Inject B b; + } + + static class D { + @Inject B b; + } + + static class E { + @Inject Unresolved unresolved; + } + + interface Unresolved {} +}