001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.test.context.jdbc;
018
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Set;

import javax.sql.DataSource;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.jdbc.Sql.ExecutionPhase;
import org.springframework.test.context.jdbc.SqlConfig.ErrorMode;
import org.springframework.test.context.jdbc.SqlConfig.TransactionMode;
import org.springframework.test.context.jdbc.SqlMergeMode.MergeMode;
import org.springframework.test.context.support.AbstractTestExecutionListener;
import org.springframework.test.context.transaction.TestContextTransactionUtils;
import org.springframework.test.context.util.TestContextResourceUtils;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.interceptor.DefaultTransactionAttribute;
import org.springframework.transaction.interceptor.TransactionAttribute;
import org.springframework.transaction.support.TransactionTemplate;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.ResourceUtils;
import org.springframework.util.StringUtils;
056
057/**
058 * {@code TestExecutionListener} that provides support for executing SQL
059 * {@link Sql#scripts scripts} and inlined {@link Sql#statements statements}
060 * configured via the {@link Sql @Sql} annotation.
061 *
062 * <p>Scripts and inlined statements will be executed {@linkplain #beforeTestMethod(TestContext) before}
063 * or {@linkplain #afterTestMethod(TestContext) after} execution of the corresponding
064 * {@linkplain java.lang.reflect.Method test method}, depending on the configured
065 * value of the {@link Sql#executionPhase executionPhase} flag.
066 *
067 * <p>Scripts and inlined statements will be executed without a transaction,
068 * within an existing Spring-managed transaction, or within an isolated transaction,
069 * depending on the configured value of {@link SqlConfig#transactionMode} and the
070 * presence of a transaction manager.
071 *
072 * <h3>Script Resources</h3>
073 * <p>For details on default script detection and how script resource locations
074 * are interpreted, see {@link Sql#scripts}.
075 *
076 * <h3>Required Spring Beans</h3>
077 * <p>A {@link PlatformTransactionManager} <em>and</em> a {@link DataSource},
078 * just a {@link PlatformTransactionManager}, or just a {@link DataSource}
079 * must be defined as beans in the Spring {@link ApplicationContext} for the
080 * corresponding test. Consult the javadocs for {@link SqlConfig#transactionMode},
081 * {@link SqlConfig#transactionManager}, {@link SqlConfig#dataSource},
082 * {@link TestContextTransactionUtils#retrieveDataSource}, and
083 * {@link TestContextTransactionUtils#retrieveTransactionManager} for details
084 * on permissible configuration constellations and on the algorithms used to
085 * locate these beans.
086 *
087 * @author Sam Brannen
088 * @author Dmitry Semukhin
089 * @since 4.1
090 * @see Sql
091 * @see SqlConfig
092 * @see SqlGroup
093 * @see org.springframework.test.context.transaction.TestContextTransactionUtils
094 * @see org.springframework.test.context.transaction.TransactionalTestExecutionListener
095 * @see org.springframework.jdbc.datasource.init.ResourceDatabasePopulator
096 * @see org.springframework.jdbc.datasource.init.ScriptUtils
097 */
098public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener {
099
100        private static final Log logger = LogFactory.getLog(SqlScriptsTestExecutionListener.class);
101
102
103        /**
104         * Returns {@code 5000}.
105         */
106        @Override
107        public final int getOrder() {
108                return 5000;
109        }
110
111        /**
112         * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
113         * {@link TestContext} <em>before</em> the current test method.
114         */
115        @Override
116        public void beforeTestMethod(TestContext testContext) {
117                executeSqlScripts(testContext, ExecutionPhase.BEFORE_TEST_METHOD);
118        }
119
120        /**
121         * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
122         * {@link TestContext} <em>after</em> the current test method.
123         */
124        @Override
125        public void afterTestMethod(TestContext testContext) {
126                executeSqlScripts(testContext, ExecutionPhase.AFTER_TEST_METHOD);
127        }
128
129        /**
130         * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
131         * {@link TestContext} and {@link ExecutionPhase}.
132         */
133        private void executeSqlScripts(TestContext testContext, ExecutionPhase executionPhase) {
134                Method testMethod = testContext.getTestMethod();
135                Class<?> testClass = testContext.getTestClass();
136
137                if (mergeSqlAnnotations(testContext)) {
138                        executeSqlScripts(getSqlAnnotationsFor(testClass), testContext, executionPhase, true);
139                        executeSqlScripts(getSqlAnnotationsFor(testMethod), testContext, executionPhase, false);
140                }
141                else {
142                        Set<Sql> methodLevelSqlAnnotations = getSqlAnnotationsFor(testMethod);
143                        if (!methodLevelSqlAnnotations.isEmpty()) {
144                                executeSqlScripts(methodLevelSqlAnnotations, testContext, executionPhase, false);
145                        }
146                        else {
147                                executeSqlScripts(getSqlAnnotationsFor(testClass), testContext, executionPhase, true);
148                        }
149                }
150        }
151
152        /**
153         * Determine if method-level {@code @Sql} annotations should be merged with
154         * class-level {@code @Sql} annotations.
155         */
156        private boolean mergeSqlAnnotations(TestContext testContext) {
157                SqlMergeMode sqlMergeMode = getSqlMergeModeFor(testContext.getTestMethod());
158                if (sqlMergeMode == null) {
159                        sqlMergeMode = getSqlMergeModeFor(testContext.getTestClass());
160                }
161                return (sqlMergeMode != null && sqlMergeMode.value() == MergeMode.MERGE);
162        }
163
164        /**
165         * Get the {@code @SqlMergeMode} annotation declared on the supplied {@code element}.
166         */
167        @Nullable
168        private SqlMergeMode getSqlMergeModeFor(AnnotatedElement element) {
169                return AnnotatedElementUtils.findMergedAnnotation(element, SqlMergeMode.class);
170        }
171
172        /**
173         * Get the {@code @Sql} annotations declared on the supplied {@code element}.
174         */
175        private Set<Sql> getSqlAnnotationsFor(AnnotatedElement element) {
176                return AnnotatedElementUtils.getMergedRepeatableAnnotations(element, Sql.class, SqlGroup.class);
177        }
178
179        /**
180         * Execute SQL scripts for the supplied {@link Sql @Sql} annotations.
181         */
182        private void executeSqlScripts(
183                        Set<Sql> sqlAnnotations, TestContext testContext, ExecutionPhase executionPhase, boolean classLevel) {
184
185                sqlAnnotations.forEach(sql -> executeSqlScripts(sql, executionPhase, testContext, classLevel));
186        }
187
188        /**
189         * Execute the SQL scripts configured via the supplied {@link Sql @Sql}
190         * annotation for the given {@link ExecutionPhase} and {@link TestContext}.
191         * <p>Special care must be taken in order to properly support the configured
192         * {@link SqlConfig#transactionMode}.
193         * @param sql the {@code @Sql} annotation to parse
194         * @param executionPhase the current execution phase
195         * @param testContext the current {@code TestContext}
196         * @param classLevel {@code true} if {@link Sql @Sql} was declared at the class level
197         */
198        private void executeSqlScripts(
199                        Sql sql, ExecutionPhase executionPhase, TestContext testContext, boolean classLevel) {
200
201                if (executionPhase != sql.executionPhase()) {
202                        return;
203                }
204
205                MergedSqlConfig mergedSqlConfig = new MergedSqlConfig(sql.config(), testContext.getTestClass());
206                if (logger.isDebugEnabled()) {
207                        logger.debug(String.format("Processing %s for execution phase [%s] and test context %s.",
208                                        mergedSqlConfig, executionPhase, testContext));
209                }
210
211                String[] scripts = getScripts(sql, testContext, classLevel);
212                scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts);
213                List<Resource> scriptResources = TestContextResourceUtils.convertToResourceList(
214                                testContext.getApplicationContext(), scripts);
215                for (String stmt : sql.statements()) {
216                        if (StringUtils.hasText(stmt)) {
217                                stmt = stmt.trim();
218                                scriptResources.add(new ByteArrayResource(stmt.getBytes(), "from inlined SQL statement: " + stmt));
219                        }
220                }
221
222                ResourceDatabasePopulator populator = createDatabasePopulator(mergedSqlConfig);
223                populator.setScripts(scriptResources.toArray(new Resource[0]));
224                if (logger.isDebugEnabled()) {
225                        logger.debug("Executing SQL scripts: " + ObjectUtils.nullSafeToString(scriptResources));
226                }
227
228                String dsName = mergedSqlConfig.getDataSource();
229                String tmName = mergedSqlConfig.getTransactionManager();
230                DataSource dataSource = TestContextTransactionUtils.retrieveDataSource(testContext, dsName);
231                PlatformTransactionManager txMgr = TestContextTransactionUtils.retrieveTransactionManager(testContext, tmName);
232                boolean newTxRequired = (mergedSqlConfig.getTransactionMode() == TransactionMode.ISOLATED);
233
234                if (txMgr == null) {
235                        Assert.state(!newTxRequired, () -> String.format("Failed to execute SQL scripts for test context %s: " +
236                                        "cannot execute SQL scripts using Transaction Mode " +
237                                        "[%s] without a PlatformTransactionManager.", testContext, TransactionMode.ISOLATED));
238                        Assert.state(dataSource != null, () -> String.format("Failed to execute SQL scripts for test context %s: " +
239                                        "supply at least a DataSource or PlatformTransactionManager.", testContext));
240                        // Execute scripts directly against the DataSource
241                        populator.execute(dataSource);
242                }
243                else {
244                        DataSource dataSourceFromTxMgr = getDataSourceFromTransactionManager(txMgr);
245                        // Ensure user configured an appropriate DataSource/TransactionManager pair.
246                        if (dataSource != null && dataSourceFromTxMgr != null && !dataSource.equals(dataSourceFromTxMgr)) {
247                                throw new IllegalStateException(String.format("Failed to execute SQL scripts for test context %s: " +
248                                                "the configured DataSource [%s] (named '%s') is not the one associated with " +
249                                                "transaction manager [%s] (named '%s').", testContext, dataSource.getClass().getName(),
250                                                dsName, txMgr.getClass().getName(), tmName));
251                        }
252                        if (dataSource == null) {
253                                dataSource = dataSourceFromTxMgr;
254                                Assert.state(dataSource != null, () -> String.format("Failed to execute SQL scripts for " +
255                                                "test context %s: could not obtain DataSource from transaction manager [%s] (named '%s').",
256                                                testContext, txMgr.getClass().getName(), tmName));
257                        }
258                        final DataSource finalDataSource = dataSource;
259                        int propagation = (newTxRequired ? TransactionDefinition.PROPAGATION_REQUIRES_NEW :
260                                        TransactionDefinition.PROPAGATION_REQUIRED);
261                        TransactionAttribute txAttr = TestContextTransactionUtils.createDelegatingTransactionAttribute(
262                                        testContext, new DefaultTransactionAttribute(propagation));
263                        new TransactionTemplate(txMgr, txAttr).executeWithoutResult(s -> populator.execute(finalDataSource));
264                }
265        }
266
267        @NonNull
268        private ResourceDatabasePopulator createDatabasePopulator(MergedSqlConfig mergedSqlConfig) {
269                ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
270                populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
271                populator.setSeparator(mergedSqlConfig.getSeparator());
272                populator.setCommentPrefixes(mergedSqlConfig.getCommentPrefixes());
273                populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
274                populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
275                populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
276                populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
277                return populator;
278        }
279
280        @Nullable
281        private DataSource getDataSourceFromTransactionManager(PlatformTransactionManager transactionManager) {
282                try {
                        Method getDataSourceMethod = transactionManager.getClass().getMethod("getDataSource");
284                        Object obj = ReflectionUtils.invokeMethod(getDataSourceMethod, transactionManager);
285                        if (obj instanceof DataSource) {
286                                return (DataSource) obj;
287                        }
288                }
289                catch (Exception ex) {
290                        // ignore
291                }
292                return null;
293        }
294
295        private String[] getScripts(Sql sql, TestContext testContext, boolean classLevel) {
296                String[] scripts = sql.scripts();
297                if (ObjectUtils.isEmpty(scripts) && ObjectUtils.isEmpty(sql.statements())) {
298                        scripts = new String[] {detectDefaultScript(testContext, classLevel)};
299                }
300                return scripts;
301        }
302
303        /**
304         * Detect a default SQL script by implementing the algorithm defined in
305         * {@link Sql#scripts}.
306         */
307        private String detectDefaultScript(TestContext testContext, boolean classLevel) {
308                Class<?> clazz = testContext.getTestClass();
309                Method method = testContext.getTestMethod();
310                String elementType = (classLevel ? "class" : "method");
311                String elementName = (classLevel ? clazz.getName() : method.toString());
312
313                String resourcePath = ClassUtils.convertClassNameToResourcePath(clazz.getName());
314                if (!classLevel) {
315                        resourcePath += "." + method.getName();
316                }
317                resourcePath += ".sql";
318
319                String prefixedResourcePath = ResourceUtils.CLASSPATH_URL_PREFIX + resourcePath;
320                ClassPathResource classPathResource = new ClassPathResource(resourcePath);
321
322                if (classPathResource.exists()) {
323                        if (logger.isInfoEnabled()) {
324                                logger.info(String.format("Detected default SQL script \"%s\" for test %s [%s]",
325                                                prefixedResourcePath, elementType, elementName));
326                        }
327                        return prefixedResourcePath;
328                }
329                else {
330                        String msg = String.format("Could not detect default SQL script for test %s [%s]: " +
331                                        "%s does not exist. Either declare statements or scripts via @Sql or make the " +
332                                        "default SQL script available.", elementType, elementName, classPathResource);
333                        logger.error(msg);
334                        throw new IllegalStateException(msg);
335                }
336        }
337
338}