001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.test.context.transaction;
018
019import java.util.Map;
020import javax.sql.DataSource;
021
022import org.apache.commons.logging.Log;
023import org.apache.commons.logging.LogFactory;
024
025import org.springframework.beans.BeansException;
026import org.springframework.beans.factory.BeanFactory;
027import org.springframework.beans.factory.BeanFactoryUtils;
028import org.springframework.beans.factory.ListableBeanFactory;
029import org.springframework.test.context.TestContext;
030import org.springframework.transaction.PlatformTransactionManager;
031import org.springframework.transaction.annotation.TransactionManagementConfigurer;
032import org.springframework.transaction.interceptor.DelegatingTransactionAttribute;
033import org.springframework.transaction.interceptor.TransactionAttribute;
034import org.springframework.util.Assert;
035import org.springframework.util.ClassUtils;
036import org.springframework.util.StringUtils;
037
038/**
039 * Utility methods for working with transactions and data access related beans
040 * within the <em>Spring TestContext Framework</em>.
041 *
042 * <p>Mainly for internal use within the framework.
043 *
044 * @author Sam Brannen
045 * @author Juergen Hoeller
046 * @since 4.1
047 */
048public abstract class TestContextTransactionUtils {
049
050        /**
051         * Default bean name for a {@link DataSource}: {@code "dataSource"}.
052         */
053        public static final String DEFAULT_DATA_SOURCE_NAME = "dataSource";
054
055        /**
056         * Default bean name for a {@link PlatformTransactionManager}:
057         * {@code "transactionManager"}.
058         */
059        public static final String DEFAULT_TRANSACTION_MANAGER_NAME = "transactionManager";
060
061
062        private static final Log logger = LogFactory.getLog(TestContextTransactionUtils.class);
063
064
065        /**
066         * Retrieve the {@link DataSource} to use for the supplied {@linkplain TestContext
067         * test context}.
068         * <p>The following algorithm is used to retrieve the {@code DataSource} from
069         * the {@link org.springframework.context.ApplicationContext ApplicationContext}
070         * of the supplied test context:
071         * <ol>
072         * <li>Look up the {@code DataSource} by type and name, if the supplied
073         * {@code name} is non-empty, throwing a {@link BeansException} if the named
074         * {@code DataSource} does not exist.
075         * <li>Attempt to look up the single {@code DataSource} by type.
076         * <li>Attempt to look up the <em>primary</em> {@code DataSource} by type.
077         * <li>Attempt to look up the {@code DataSource} by type and the
078         * {@linkplain #DEFAULT_DATA_SOURCE_NAME default data source name}.
079         * @param testContext the test context for which the {@code DataSource}
080         * should be retrieved; never {@code null}
081         * @param name the name of the {@code DataSource} to retrieve
082         * (may be {@code null} or <em>empty</em>)
083         * @return the {@code DataSource} to use, or {@code null} if not found
084         * @throws BeansException if an error occurs while retrieving an explicitly
085         * named {@code DataSource}
086         */
087        public static DataSource retrieveDataSource(TestContext testContext, String name) {
088                Assert.notNull(testContext, "TestContext must not be null");
089                BeanFactory bf = testContext.getApplicationContext().getAutowireCapableBeanFactory();
090
091                try {
092                        // Look up by type and explicit name
093                        if (StringUtils.hasText(name)) {
094                                return bf.getBean(name, DataSource.class);
095                        }
096                }
097                catch (BeansException ex) {
098                        logger.error(String.format("Failed to retrieve DataSource named '%s' for test context %s",
099                                        name, testContext), ex);
100                        throw ex;
101                }
102
103                try {
104                        if (bf instanceof ListableBeanFactory) {
105                                ListableBeanFactory lbf = (ListableBeanFactory) bf;
106
107                                // Look up single bean by type
108                                Map<String, DataSource> dataSources =
109                                                BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, DataSource.class);
110                                if (dataSources.size() == 1) {
111                                        return dataSources.values().iterator().next();
112                                }
113
114                                try {
115                                        // look up single bean by type, with support for 'primary' beans
116                                        return bf.getBean(DataSource.class);
117                                }
118                                catch (BeansException ex) {
119                                        logBeansException(testContext, ex, PlatformTransactionManager.class);
120                                }
121                        }
122
123                        // look up by type and default name
124                        return bf.getBean(DEFAULT_DATA_SOURCE_NAME, DataSource.class);
125                }
126                catch (BeansException ex) {
127                        logBeansException(testContext, ex, DataSource.class);
128                        return null;
129                }
130        }
131
132        /**
133         * Retrieve the {@linkplain PlatformTransactionManager transaction manager}
134         * to use for the supplied {@linkplain TestContext test context}.
135         * <p>The following algorithm is used to retrieve the transaction manager
136         * from the {@link org.springframework.context.ApplicationContext ApplicationContext}
137         * of the supplied test context:
138         * <ol>
139         * <li>Look up the transaction manager by type and explicit name, if the supplied
140         * {@code name} is non-empty, throwing a {@link BeansException} if the named
141         * transaction manager does not exist.
142         * <li>Attempt to look up the single transaction manager by type.
143         * <li>Attempt to look up the <em>primary</em> transaction manager by type.
144         * <li>Attempt to look up the transaction manager via a
145         * {@link TransactionManagementConfigurer}, if present.
146         * <li>Attempt to look up the transaction manager by type and the
147         * {@linkplain #DEFAULT_TRANSACTION_MANAGER_NAME default transaction manager
148         * name}.
149         * @param testContext the test context for which the transaction manager
150         * should be retrieved; never {@code null}
151         * @param name the name of the transaction manager to retrieve
152         * (may be {@code null} or <em>empty</em>)
153         * @return the transaction manager to use, or {@code null} if not found
154         * @throws BeansException if an error occurs while retrieving an explicitly
155         * named transaction manager
156         * @throws IllegalStateException if more than one TransactionManagementConfigurer
157         * exists in the ApplicationContext
158         */
159        public static PlatformTransactionManager retrieveTransactionManager(TestContext testContext, String name) {
160                Assert.notNull(testContext, "TestContext must not be null");
161                BeanFactory bf = testContext.getApplicationContext().getAutowireCapableBeanFactory();
162
163                try {
164                        // Look up by type and explicit name
165                        if (StringUtils.hasText(name)) {
166                                return bf.getBean(name, PlatformTransactionManager.class);
167                        }
168                }
169                catch (BeansException ex) {
170                        logger.error(String.format("Failed to retrieve transaction manager named '%s' for test context %s",
171                                        name, testContext), ex);
172                        throw ex;
173                }
174
175                try {
176                        if (bf instanceof ListableBeanFactory) {
177                                ListableBeanFactory lbf = (ListableBeanFactory) bf;
178
179                                // Look up single bean by type
180                                Map<String, PlatformTransactionManager> txMgrs =
181                                                BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, PlatformTransactionManager.class);
182                                if (txMgrs.size() == 1) {
183                                        return txMgrs.values().iterator().next();
184                                }
185
186                                try {
187                                        // Look up single bean by type, with support for 'primary' beans
188                                        return bf.getBean(PlatformTransactionManager.class);
189                                }
190                                catch (BeansException ex) {
191                                        logBeansException(testContext, ex, PlatformTransactionManager.class);
192                                }
193
194                                // Look up single TransactionManagementConfigurer
195                                Map<String, TransactionManagementConfigurer> configurers =
196                                                BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, TransactionManagementConfigurer.class);
197                                Assert.state(configurers.size() <= 1,
198                                                "Only one TransactionManagementConfigurer may exist in the ApplicationContext");
199                                if (configurers.size() == 1) {
200                                        return configurers.values().iterator().next().annotationDrivenTransactionManager();
201                                }
202                        }
203
204                        // look up by type and default name
205                        return bf.getBean(DEFAULT_TRANSACTION_MANAGER_NAME, PlatformTransactionManager.class);
206                }
207                catch (BeansException ex) {
208                        logBeansException(testContext, ex, PlatformTransactionManager.class);
209                        return null;
210                }
211        }
212
213        private static void logBeansException(TestContext testContext, BeansException ex, Class<?> beanType) {
214                if (logger.isDebugEnabled()) {
215                        logger.debug(String.format("Caught exception while retrieving %s for test context %s",
216                                beanType.getSimpleName(), testContext), ex);
217                }
218        }
219
220        /**
221         * Create a delegating {@link TransactionAttribute} for the supplied target
222         * {@link TransactionAttribute} and {@link TestContext}, using the names of
223         * the test class and test method to build the name of the transaction.
224         * @param testContext the {@code TestContext} upon which to base the name
225         * @param targetAttribute the {@code TransactionAttribute} to delegate to
226         * @return the delegating {@code TransactionAttribute}
227         */
228        public static TransactionAttribute createDelegatingTransactionAttribute(
229                        TestContext testContext, TransactionAttribute targetAttribute) {
230
231                Assert.notNull(testContext, "TestContext must not be null");
232                Assert.notNull(targetAttribute, "Target TransactionAttribute must not be null");
233                return new TestContextTransactionAttribute(targetAttribute, testContext);
234        }
235
236
237        @SuppressWarnings("serial")
238        private static class TestContextTransactionAttribute extends DelegatingTransactionAttribute {
239
240                private final String name;
241
242                public TestContextTransactionAttribute(TransactionAttribute targetAttribute, TestContext testContext) {
243                        super(targetAttribute);
244                        this.name = ClassUtils.getQualifiedMethodName(testContext.getTestMethod(), testContext.getTestClass());
245                }
246
247                @Override
248                public String getName() {
249                        return this.name;
250                }
251        }
252
253}