001/*
002 * Copyright 2002-2016 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.test.context.jdbc;
018
019import java.lang.reflect.Method;
020import java.util.List;
021import java.util.Set;
022import javax.sql.DataSource;
023
024import org.apache.commons.logging.Log;
025import org.apache.commons.logging.LogFactory;
026
027import org.springframework.context.ApplicationContext;
028import org.springframework.core.annotation.AnnotatedElementUtils;
029import org.springframework.core.io.ByteArrayResource;
030import org.springframework.core.io.ClassPathResource;
031import org.springframework.core.io.Resource;
032import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator;
033import org.springframework.test.context.TestContext;
034import org.springframework.test.context.jdbc.Sql.ExecutionPhase;
035import org.springframework.test.context.jdbc.SqlConfig.ErrorMode;
036import org.springframework.test.context.jdbc.SqlConfig.TransactionMode;
037import org.springframework.test.context.support.AbstractTestExecutionListener;
038import org.springframework.test.context.transaction.TestContextTransactionUtils;
039import org.springframework.test.context.util.TestContextResourceUtils;
040import org.springframework.transaction.PlatformTransactionManager;
041import org.springframework.transaction.TransactionDefinition;
042import org.springframework.transaction.TransactionStatus;
043import org.springframework.transaction.interceptor.DefaultTransactionAttribute;
044import org.springframework.transaction.interceptor.TransactionAttribute;
045import org.springframework.transaction.support.TransactionCallbackWithoutResult;
046import org.springframework.transaction.support.TransactionTemplate;
047import org.springframework.util.ClassUtils;
048import org.springframework.util.ObjectUtils;
049import org.springframework.util.ReflectionUtils;
050import org.springframework.util.ResourceUtils;
051import org.springframework.util.StringUtils;
052
053/**
054 * {@code TestExecutionListener} that provides support for executing SQL
055 * {@link Sql#scripts scripts} and inlined {@link Sql#statements statements}
056 * configured via the {@link Sql @Sql} annotation.
057 *
058 * <p>Scripts and inlined statements will be executed {@linkplain #beforeTestMethod(TestContext) before}
059 * or {@linkplain #afterTestMethod(TestContext) after} execution of the corresponding
060 * {@linkplain java.lang.reflect.Method test method}, depending on the configured
061 * value of the {@link Sql#executionPhase executionPhase} flag.
062 *
063 * <p>Scripts and inlined statements will be executed without a transaction,
064 * within an existing Spring-managed transaction, or within an isolated transaction,
065 * depending on the configured value of {@link SqlConfig#transactionMode} and the
066 * presence of a transaction manager.
067 *
068 * <h3>Script Resources</h3>
069 * <p>For details on default script detection and how script resource locations
070 * are interpreted, see {@link Sql#scripts}.
071 *
072 * <h3>Required Spring Beans</h3>
073 * <p>A {@link PlatformTransactionManager} <em>and</em> a {@link DataSource},
074 * just a {@link PlatformTransactionManager}, or just a {@link DataSource}
075 * must be defined as beans in the Spring {@link ApplicationContext} for the
076 * corresponding test. Consult the javadocs for {@link SqlConfig#transactionMode},
077 * {@link SqlConfig#transactionManager}, {@link SqlConfig#dataSource},
078 * {@link TestContextTransactionUtils#retrieveDataSource}, and
079 * {@link TestContextTransactionUtils#retrieveTransactionManager} for details
080 * on permissible configuration constellations and on the algorithms used to
081 * locate these beans.
082 *
083 * @author Sam Brannen
084 * @since 4.1
085 * @see Sql
086 * @see SqlConfig
087 * @see SqlGroup
088 * @see org.springframework.test.context.transaction.TestContextTransactionUtils
089 * @see org.springframework.test.context.transaction.TransactionalTestExecutionListener
090 * @see org.springframework.jdbc.datasource.init.ResourceDatabasePopulator
091 * @see org.springframework.jdbc.datasource.init.ScriptUtils
092 */
093public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener {
094
095        private static final Log logger = LogFactory.getLog(SqlScriptsTestExecutionListener.class);
096
097
098        /**
099         * Returns {@code 5000}.
100         */
101        @Override
102        public final int getOrder() {
103                return 5000;
104        }
105
106        /**
107         * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
108         * {@link TestContext} <em>before</em> the current test method.
109         */
110        @Override
111        public void beforeTestMethod(TestContext testContext) throws Exception {
112                executeSqlScripts(testContext, ExecutionPhase.BEFORE_TEST_METHOD);
113        }
114
115        /**
116         * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
117         * {@link TestContext} <em>after</em> the current test method.
118         */
119        @Override
120        public void afterTestMethod(TestContext testContext) throws Exception {
121                executeSqlScripts(testContext, ExecutionPhase.AFTER_TEST_METHOD);
122        }
123
124        /**
125         * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
126         * {@link TestContext} and {@link ExecutionPhase}.
127         */
128        private void executeSqlScripts(TestContext testContext, ExecutionPhase executionPhase) throws Exception {
129                boolean classLevel = false;
130
131                Set<Sql> sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(
132                                testContext.getTestMethod(), Sql.class, SqlGroup.class);
133                if (sqlAnnotations.isEmpty()) {
134                        sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(
135                                        testContext.getTestClass(), Sql.class, SqlGroup.class);
136                        if (!sqlAnnotations.isEmpty()) {
137                                classLevel = true;
138                        }
139                }
140
141                for (Sql sql : sqlAnnotations) {
142                        executeSqlScripts(sql, executionPhase, testContext, classLevel);
143                }
144        }
145
146        /**
147         * Execute the SQL scripts configured via the supplied {@link Sql @Sql}
148         * annotation for the given {@link ExecutionPhase} and {@link TestContext}.
149         * <p>Special care must be taken in order to properly support the configured
150         * {@link SqlConfig#transactionMode}.
151         * @param sql the {@code @Sql} annotation to parse
152         * @param executionPhase the current execution phase
153         * @param testContext the current {@code TestContext}
154         * @param classLevel {@code true} if {@link Sql @Sql} was declared at the class level
155         */
156        private void executeSqlScripts(Sql sql, ExecutionPhase executionPhase, TestContext testContext, boolean classLevel)
157                        throws Exception {
158
159                if (executionPhase != sql.executionPhase()) {
160                        return;
161                }
162
163                MergedSqlConfig mergedSqlConfig = new MergedSqlConfig(sql.config(), testContext.getTestClass());
164                if (logger.isDebugEnabled()) {
165                        logger.debug(String.format("Processing %s for execution phase [%s] and test context %s.",
166                                        mergedSqlConfig, executionPhase, testContext));
167                }
168
169                final ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
170                populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
171                populator.setSeparator(mergedSqlConfig.getSeparator());
172                populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix());
173                populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
174                populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
175                populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
176                populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
177
178                String[] scripts = getScripts(sql, testContext, classLevel);
179                scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts);
180                List<Resource> scriptResources = TestContextResourceUtils.convertToResourceList(
181                                testContext.getApplicationContext(), scripts);
182                for (String stmt : sql.statements()) {
183                        if (StringUtils.hasText(stmt)) {
184                                stmt = stmt.trim();
185                                scriptResources.add(new ByteArrayResource(stmt.getBytes(), "from inlined SQL statement: " + stmt));
186                        }
187                }
188                populator.setScripts(scriptResources.toArray(new Resource[scriptResources.size()]));
189                if (logger.isDebugEnabled()) {
190                        logger.debug("Executing SQL scripts: " + ObjectUtils.nullSafeToString(scriptResources));
191                }
192
193                String dsName = mergedSqlConfig.getDataSource();
194                String tmName = mergedSqlConfig.getTransactionManager();
195                DataSource dataSource = TestContextTransactionUtils.retrieveDataSource(testContext, dsName);
196                PlatformTransactionManager txMgr = TestContextTransactionUtils.retrieveTransactionManager(testContext, tmName);
197                boolean newTxRequired = (mergedSqlConfig.getTransactionMode() == TransactionMode.ISOLATED);
198
199                if (txMgr == null) {
200                        if (newTxRequired) {
201                                throw new IllegalStateException(String.format("Failed to execute SQL scripts for test context %s: " +
202                                                "cannot execute SQL scripts using Transaction Mode [%s] without a PlatformTransactionManager.",
203                                                testContext, TransactionMode.ISOLATED));
204                        }
205                        if (dataSource == null) {
206                                throw new IllegalStateException(String.format("Failed to execute SQL scripts for test context %s: " +
207                                                "supply at least a DataSource or PlatformTransactionManager.", testContext));
208                        }
209                        // Execute scripts directly against the DataSource
210                        populator.execute(dataSource);
211                }
212                else {
213                        DataSource dataSourceFromTxMgr = getDataSourceFromTransactionManager(txMgr);
214                        // Ensure user configured an appropriate DataSource/TransactionManager pair.
215                        if (dataSource != null && dataSourceFromTxMgr != null && !dataSource.equals(dataSourceFromTxMgr)) {
216                                throw new IllegalStateException(String.format("Failed to execute SQL scripts for test context %s: " +
217                                                "the configured DataSource [%s] (named '%s') is not the one associated with " +
218                                                "transaction manager [%s] (named '%s').", testContext, dataSource.getClass().getName(),
219                                                dsName, txMgr.getClass().getName(), tmName));
220                        }
221                        if (dataSource == null) {
222                                dataSource = dataSourceFromTxMgr;
223                                if (dataSource == null) {
224                                        throw new IllegalStateException(String.format("Failed to execute SQL scripts for " +
225                                                        "test context %s: could not obtain DataSource from transaction manager [%s] (named '%s').",
226                                                        testContext, txMgr.getClass().getName(), tmName));
227                                }
228                        }
229                        final DataSource finalDataSource = dataSource;
230                        int propagation = (newTxRequired ? TransactionDefinition.PROPAGATION_REQUIRES_NEW :
231                                        TransactionDefinition.PROPAGATION_REQUIRED);
232                        TransactionAttribute txAttr = TestContextTransactionUtils.createDelegatingTransactionAttribute(
233                                        testContext, new DefaultTransactionAttribute(propagation));
234                        new TransactionTemplate(txMgr, txAttr).execute(new TransactionCallbackWithoutResult() {
235                                @Override
236                                public void doInTransactionWithoutResult(TransactionStatus status) {
237                                        populator.execute(finalDataSource);
238                                }
239                        });
240                }
241        }
242
243        private DataSource getDataSourceFromTransactionManager(PlatformTransactionManager transactionManager) {
244                try {
245                        Method getDataSourceMethod = transactionManager.getClass().getMethod("getDataSource");
246                        Object obj = ReflectionUtils.invokeMethod(getDataSourceMethod, transactionManager);
247                        if (obj instanceof DataSource) {
248                                return (DataSource) obj;
249                        }
250                }
251                catch (Exception ex) {
252                        // ignore
253                }
254                return null;
255        }
256
257        private String[] getScripts(Sql sql, TestContext testContext, boolean classLevel) {
258                String[] scripts = sql.scripts();
259                if (ObjectUtils.isEmpty(scripts) && ObjectUtils.isEmpty(sql.statements())) {
260                        scripts = new String[] {detectDefaultScript(testContext, classLevel)};
261                }
262                return scripts;
263        }
264
265        /**
266         * Detect a default SQL script by implementing the algorithm defined in
267         * {@link Sql#scripts}.
268         */
269        private String detectDefaultScript(TestContext testContext, boolean classLevel) {
270                Class<?> clazz = testContext.getTestClass();
271                Method method = testContext.getTestMethod();
272                String elementType = (classLevel ? "class" : "method");
273                String elementName = (classLevel ? clazz.getName() : method.toString());
274
275                String resourcePath = ClassUtils.convertClassNameToResourcePath(clazz.getName());
276                if (!classLevel) {
277                        resourcePath += "." + method.getName();
278                }
279                resourcePath += ".sql";
280
                String prefixedResourcePath = ResourceUtils.CLASSPATH_URL_PREFIX + resourcePath;
282                ClassPathResource classPathResource = new ClassPathResource(resourcePath);
283
284                if (classPathResource.exists()) {
285                        if (logger.isInfoEnabled()) {
286                                logger.info(String.format("Detected default SQL script \"%s\" for test %s [%s]",
287                                                prefixedResourcePath, elementType, elementName));
288                        }
289                        return prefixedResourcePath;
290                }
291                else {
292                        String msg = String.format("Could not detect default SQL script for test %s [%s]: " +
293                                        "%s does not exist. Either declare statements or scripts via @Sql or make the " +
294                                        "default SQL script available.", elementType, elementName, classPathResource);
295                        logger.error(msg);
296                        throw new IllegalStateException(msg);
297                }
298        }
299
300}