diff --git a/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideContextCustomizer.java b/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideContextCustomizer.java index d74740aa8b88..c29f3a787d83 100644 --- a/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideContextCustomizer.java +++ b/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideContextCustomizer.java @@ -17,13 +17,8 @@ package org.springframework.test.context.bean.override; import java.util.Set; -import java.util.function.Consumer; -import org.springframework.beans.factory.config.BeanDefinition; -import org.springframework.beans.factory.config.ConstructorArgumentValues; -import org.springframework.beans.factory.config.RuntimeBeanReference; -import org.springframework.beans.factory.support.BeanDefinitionRegistry; -import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.test.context.ContextCustomizer; import org.springframework.test.context.MergedContextConfiguration; @@ -57,43 +52,18 @@ class BeanOverrideContextCustomizer implements ContextCustomizer { @Override public void customizeContext(ConfigurableApplicationContext context, MergedContextConfiguration mergedConfig) { - if (!(context instanceof BeanDefinitionRegistry registry)) { - throw new IllegalStateException("Cannot process bean overrides with an ApplicationContext " + - "that doesn't implement BeanDefinitionRegistry: " + context.getClass()); - } - registerInfrastructure(registry); + ConfigurableListableBeanFactory beanFactory = context.getBeanFactory(); + BeanOverrideRegistrar beanOverrideRegistrar = new BeanOverrideRegistrar(); + beanOverrideRegistrar.setBeanFactory(beanFactory); + beanFactory.registerSingleton(REGISTRAR_BEAN_NAME, beanOverrideRegistrar); + beanFactory.registerSingleton(EARLY_INFRASTRUCTURE_BEAN_NAME, new WrapEarlyBeanPostProcessor(beanOverrideRegistrar)); + beanFactory.registerSingleton(INFRASTRUCTURE_BEAN_NAME, new BeanOverrideBeanFactoryPostProcessor(this.metadata, beanOverrideRegistrar)); } Set getMetadata() { return this.metadata; } - private void registerInfrastructure(BeanDefinitionRegistry registry) { - addInfrastructureBeanDefinition(registry, BeanOverrideRegistrar.class, REGISTRAR_BEAN_NAME, - constructorArgs -> {}); - - RuntimeBeanReference registrarReference = new RuntimeBeanReference(REGISTRAR_BEAN_NAME); - addInfrastructureBeanDefinition(registry, WrapEarlyBeanPostProcessor.class, EARLY_INFRASTRUCTURE_BEAN_NAME, - constructorArgs -> constructorArgs.addIndexedArgumentValue(0, registrarReference)); - addInfrastructureBeanDefinition(registry, BeanOverrideBeanFactoryPostProcessor.class, INFRASTRUCTURE_BEAN_NAME, - constructorArgs -> { - constructorArgs.addIndexedArgumentValue(0, this.metadata); - constructorArgs.addIndexedArgumentValue(1, registrarReference); - }); - } - - private void addInfrastructureBeanDefinition(BeanDefinitionRegistry registry, - Class clazz, String beanName, Consumer constructorArgumentsConsumer) { - - if (!registry.containsBeanDefinition(beanName)) { - RootBeanDefinition definition = new RootBeanDefinition(clazz); - definition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); - ConstructorArgumentValues constructorArguments = definition.getConstructorArgumentValues(); - constructorArgumentsConsumer.accept(constructorArguments); - registry.registerBeanDefinition(beanName, definition); - } - } - @Override public boolean equals(Object other) { if (other == this) { diff --git a/spring-test/src/test/java/org/springframework/test/context/aot/AotIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/context/aot/AotIntegrationTests.java index 7c23d97ea3ac..ae87abb366c5 100644 --- a/spring-test/src/test/java/org/springframework/test/context/aot/AotIntegrationTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/aot/AotIntegrationTests.java @@ -145,7 +145,6 @@ void endToEndTestsForEntireSpringTestModule() { runEndToEndTests(testClasses, false); } - @Disabled("Comment out to run @TestBean integration tests in AOT mode") @Test void endToEndTestsForTestBeanOverrideTestClasses() { List> testClasses = List.of(