001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.test.context.transaction;
018
019import java.util.Map;
020
021import javax.sql.DataSource;
022
023import org.apache.commons.logging.Log;
024import org.apache.commons.logging.LogFactory;
025
026import org.springframework.beans.BeansException;
027import org.springframework.beans.factory.BeanFactory;
028import org.springframework.beans.factory.BeanFactoryUtils;
029import org.springframework.beans.factory.ListableBeanFactory;
030import org.springframework.lang.Nullable;
031import org.springframework.test.context.TestContext;
032import org.springframework.transaction.PlatformTransactionManager;
033import org.springframework.transaction.TransactionManager;
034import org.springframework.transaction.annotation.TransactionManagementConfigurer;
035import org.springframework.transaction.interceptor.DelegatingTransactionAttribute;
036import org.springframework.transaction.interceptor.TransactionAttribute;
037import org.springframework.util.Assert;
038import org.springframework.util.ClassUtils;
039import org.springframework.util.StringUtils;
040
041/**
042 * Utility methods for working with transactions and data access related beans
043 * within the <em>Spring TestContext Framework</em>.
044 *
045 * <p>Mainly for internal use within the framework.
046 *
047 * @author Sam Brannen
048 * @author Juergen Hoeller
049 * @since 4.1
050 */
051public abstract class TestContextTransactionUtils {
052
053        /**
054         * Default bean name for a {@link DataSource}: {@code "dataSource"}.
055         */
056        public static final String DEFAULT_DATA_SOURCE_NAME = "dataSource";
057
058        /**
059         * Default bean name for a {@link PlatformTransactionManager}:
060         * {@code "transactionManager"}.
061         */
062        public static final String DEFAULT_TRANSACTION_MANAGER_NAME = "transactionManager";
063
064
065        private static final Log logger = LogFactory.getLog(TestContextTransactionUtils.class);
066
067
068        /**
069         * Retrieve the {@link DataSource} to use for the supplied {@linkplain TestContext
070         * test context}.
071         * <p>The following algorithm is used to retrieve the {@code DataSource} from
072         * the {@link org.springframework.context.ApplicationContext ApplicationContext}
073         * of the supplied test context:
074         * <ol>
075         * <li>Look up the {@code DataSource} by type and name, if the supplied
076         * {@code name} is non-empty, throwing a {@link BeansException} if the named
077         * {@code DataSource} does not exist.
078         * <li>Attempt to look up the single {@code DataSource} by type.
079         * <li>Attempt to look up the <em>primary</em> {@code DataSource} by type.
080         * <li>Attempt to look up the {@code DataSource} by type and the
081         * {@linkplain #DEFAULT_DATA_SOURCE_NAME default data source name}.
082         * </ol>
083         * @param testContext the test context for which the {@code DataSource}
084         * should be retrieved; never {@code null}
085         * @param name the name of the {@code DataSource} to retrieve
086         * (may be {@code null} or <em>empty</em>)
087         * @return the {@code DataSource} to use, or {@code null} if not found
088         * @throws BeansException if an error occurs while retrieving an explicitly
089         * named {@code DataSource}
090         */
091        @Nullable
092        public static DataSource retrieveDataSource(TestContext testContext, @Nullable String name) {
093                Assert.notNull(testContext, "TestContext must not be null");
094                BeanFactory bf = testContext.getApplicationContext().getAutowireCapableBeanFactory();
095
096                try {
097                        // Look up by type and explicit name
098                        if (StringUtils.hasText(name)) {
099                                return bf.getBean(name, DataSource.class);
100                        }
101                }
102                catch (BeansException ex) {
103                        logger.error(String.format("Failed to retrieve DataSource named '%s' for test context %s",
104                                        name, testContext), ex);
105                        throw ex;
106                }
107
108                try {
109                        if (bf instanceof ListableBeanFactory) {
110                                ListableBeanFactory lbf = (ListableBeanFactory) bf;
111
112                                // Look up single bean by type
113                                Map<String, DataSource> dataSources =
114                                                BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, DataSource.class);
115                                if (dataSources.size() == 1) {
116                                        return dataSources.values().iterator().next();
117                                }
118
119                                try {
120                                        // look up single bean by type, with support for 'primary' beans
121                                        return bf.getBean(DataSource.class);
122                                }
123                                catch (BeansException ex) {
124                                        logBeansException(testContext, ex, PlatformTransactionManager.class);
125                                }
126                        }
127
128                        // look up by type and default name
129                        return bf.getBean(DEFAULT_DATA_SOURCE_NAME, DataSource.class);
130                }
131                catch (BeansException ex) {
132                        logBeansException(testContext, ex, DataSource.class);
133                        return null;
134                }
135        }
136
137        /**
138         * Retrieve the {@linkplain PlatformTransactionManager transaction manager}
139         * to use for the supplied {@linkplain TestContext test context}.
140         * <p>The following algorithm is used to retrieve the transaction manager
141         * from the {@link org.springframework.context.ApplicationContext ApplicationContext}
142         * of the supplied test context:
143         * <ol>
144         * <li>Look up the transaction manager by type and explicit name, if the supplied
145         * {@code name} is non-empty, throwing a {@link BeansException} if the named
146         * transaction manager does not exist.
147         * <li>Attempt to look up the single transaction manager by type.
148         * <li>Attempt to look up the <em>primary</em> transaction manager by type.
149         * <li>Attempt to look up the transaction manager via a
150         * {@link TransactionManagementConfigurer}, if present.
151         * <li>Attempt to look up the transaction manager by type and the
152         * {@linkplain #DEFAULT_TRANSACTION_MANAGER_NAME default transaction manager
153         * name}.
154         * </ol>
155         * @param testContext the test context for which the transaction manager
156         * should be retrieved; never {@code null}
157         * @param name the name of the transaction manager to retrieve
158         * (may be {@code null} or <em>empty</em>)
159         * @return the transaction manager to use, or {@code null} if not found
160         * @throws BeansException if an error occurs while retrieving an explicitly
161         * named transaction manager
162         * @throws IllegalStateException if more than one TransactionManagementConfigurer
163         * exists in the ApplicationContext
164         */
165        @Nullable
166        public static PlatformTransactionManager retrieveTransactionManager(TestContext testContext, @Nullable String name) {
167                Assert.notNull(testContext, "TestContext must not be null");
168                BeanFactory bf = testContext.getApplicationContext().getAutowireCapableBeanFactory();
169
170                try {
171                        // Look up by type and explicit name
172                        if (StringUtils.hasText(name)) {
173                                return bf.getBean(name, PlatformTransactionManager.class);
174                        }
175                }
176                catch (BeansException ex) {
177                        logger.error(String.format("Failed to retrieve transaction manager named '%s' for test context %s",
178                                        name, testContext), ex);
179                        throw ex;
180                }
181
182                try {
183                        if (bf instanceof ListableBeanFactory) {
184                                ListableBeanFactory lbf = (ListableBeanFactory) bf;
185
186                                // Look up single bean by type
187                                Map<String, PlatformTransactionManager> txMgrs =
188                                                BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, PlatformTransactionManager.class);
189                                if (txMgrs.size() == 1) {
190                                        return txMgrs.values().iterator().next();
191                                }
192
193                                try {
194                                        // Look up single bean by type, with support for 'primary' beans
195                                        return bf.getBean(PlatformTransactionManager.class);
196                                }
197                                catch (BeansException ex) {
198                                        logBeansException(testContext, ex, PlatformTransactionManager.class);
199                                }
200
201                                // Look up single TransactionManagementConfigurer
202                                Map<String, TransactionManagementConfigurer> configurers =
203                                                BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, TransactionManagementConfigurer.class);
204                                Assert.state(configurers.size() <= 1,
205                                                "Only one TransactionManagementConfigurer may exist in the ApplicationContext");
206                                if (configurers.size() == 1) {
207                                        TransactionManager tm = configurers.values().iterator().next().annotationDrivenTransactionManager();
208                                        Assert.state(tm instanceof PlatformTransactionManager, () ->
209                                                "Transaction manager specified via TransactionManagementConfigurer " +
210                                                "is not a PlatformTransactionManager: " + tm);
211                                        return (PlatformTransactionManager) tm;
212                                }
213                        }
214
215                        // look up by type and default name
216                        return bf.getBean(DEFAULT_TRANSACTION_MANAGER_NAME, PlatformTransactionManager.class);
217                }
218                catch (BeansException ex) {
219                        logBeansException(testContext, ex, PlatformTransactionManager.class);
220                        return null;
221                }
222        }
223
224        private static void logBeansException(TestContext testContext, BeansException ex, Class<?> beanType) {
225                if (logger.isDebugEnabled()) {
226                        logger.debug(String.format("Caught exception while retrieving %s for test context %s",
227                                beanType.getSimpleName(), testContext), ex);
228                }
229        }
230
231        /**
232         * Create a delegating {@link TransactionAttribute} for the supplied target
233         * {@link TransactionAttribute} and {@link TestContext}, using the names of
234         * the test class and test method to build the name of the transaction.
235         * @param testContext the {@code TestContext} upon which to base the name
236         * @param targetAttribute the {@code TransactionAttribute} to delegate to
237         * @return the delegating {@code TransactionAttribute}
238         */
239        public static TransactionAttribute createDelegatingTransactionAttribute(
240                        TestContext testContext, TransactionAttribute targetAttribute) {
241
242                Assert.notNull(testContext, "TestContext must not be null");
243                Assert.notNull(targetAttribute, "Target TransactionAttribute must not be null");
244                return new TestContextTransactionAttribute(targetAttribute, testContext);
245        }
246
247
248        @SuppressWarnings("serial")
249        private static class TestContextTransactionAttribute extends DelegatingTransactionAttribute {
250
251                private final String name;
252
253                public TestContextTransactionAttribute(TransactionAttribute targetAttribute, TestContext testContext) {
254                        super(targetAttribute);
255                        this.name = ClassUtils.getQualifiedMethodName(testContext.getTestMethod(), testContext.getTestClass());
256                }
257
258                @Override
259                @Nullable
260                public String getName() {
261                        return this.name;
262                }
263        }
264
265}