/*
 * Copyright 2002-2019 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.test.context.jdbc;

import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import javax.sql.DataSource;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.jdbc.Sql.ExecutionPhase;
import org.springframework.test.context.jdbc.SqlConfig.ErrorMode;
import org.springframework.test.context.jdbc.SqlConfig.TransactionMode;
import org.springframework.test.context.jdbc.SqlMergeMode.MergeMode;
import org.springframework.test.context.support.AbstractTestExecutionListener;
import org.springframework.test.context.transaction.TestContextTransactionUtils;
import org.springframework.test.context.util.TestContextResourceUtils;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.interceptor.DefaultTransactionAttribute;
import org.springframework.transaction.interceptor.TransactionAttribute;
import org.springframework.transaction.support.TransactionTemplate;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.ResourceUtils;
import org.springframework.util.StringUtils;

/**
 * {@code TestExecutionListener} that provides support for executing SQL
 * {@link Sql#scripts scripts} and inlined {@link Sql#statements statements}
 * configured via the {@link Sql @Sql} annotation.
 *
 * <p>Scripts and inlined statements will be executed {@linkplain #beforeTestMethod(TestContext) before}
 * or {@linkplain #afterTestMethod(TestContext) after} execution of the corresponding
 * {@linkplain java.lang.reflect.Method test method}, depending on the configured
 * value of the {@link Sql#executionPhase executionPhase} flag.
 *
 * <p>Scripts and inlined statements will be executed without a transaction,
 * within an existing Spring-managed transaction, or within an isolated transaction,
 * depending on the configured value of {@link SqlConfig#transactionMode} and the
 * presence of a transaction manager.
 *
 * <h3>Script Resources</h3>
 * <p>For details on default script detection and how script resource locations
 * are interpreted, see {@link Sql#scripts}.
 *
 * <h3>Required Spring Beans</h3>
 * <p>A {@link PlatformTransactionManager} <em>and</em> a {@link DataSource},
 * just a {@link PlatformTransactionManager}, or just a {@link DataSource}
 * must be defined as beans in the Spring {@link ApplicationContext} for the
 * corresponding test. Consult the javadocs for {@link SqlConfig#transactionMode},
 * {@link SqlConfig#transactionManager}, {@link SqlConfig#dataSource},
 * {@link TestContextTransactionUtils#retrieveDataSource}, and
 * {@link TestContextTransactionUtils#retrieveTransactionManager} for details
 * on permissible configuration constellations and on the algorithms used to
 * locate these beans.
 *
 * @author Sam Brannen
 * @author Dmitry Semukhin
 * @since 4.1
 * @see Sql
 * @see SqlConfig
 * @see SqlGroup
 * @see org.springframework.test.context.transaction.TestContextTransactionUtils
 * @see org.springframework.test.context.transaction.TransactionalTestExecutionListener
 * @see org.springframework.jdbc.datasource.init.ResourceDatabasePopulator
 * @see org.springframework.jdbc.datasource.init.ScriptUtils
 */
public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener {

	private static final Log logger = LogFactory.getLog(SqlScriptsTestExecutionListener.class);


	/**
	 * Returns {@code 5000}.
	 */
	@Override
	public final int getOrder() {
		return 5000;
	}

	/**
	 * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
	 * {@link TestContext} <em>before</em> the current test method.
	 */
	@Override
	public void beforeTestMethod(TestContext testContext) {
		executeSqlScripts(testContext, ExecutionPhase.BEFORE_TEST_METHOD);
	}

	/**
	 * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
	 * {@link TestContext} <em>after</em> the current test method.
	 */
	@Override
	public void afterTestMethod(TestContext testContext) {
		executeSqlScripts(testContext, ExecutionPhase.AFTER_TEST_METHOD);
	}

	/**
	 * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
	 * {@link TestContext} and {@link ExecutionPhase}.
	 * <p>In {@link MergeMode#MERGE MERGE} mode, class-level annotations are processed
	 * first, followed by method-level annotations. Otherwise, method-level annotations
	 * are used if present and class-level annotations only as a fallback.
	 */
	private void executeSqlScripts(TestContext testContext, ExecutionPhase executionPhase) {
		Method testMethod = testContext.getTestMethod();
		Class<?> testClass = testContext.getTestClass();

		if (mergeSqlAnnotations(testContext)) {
			executeSqlScripts(getSqlAnnotationsFor(testClass), testContext, executionPhase, true);
			executeSqlScripts(getSqlAnnotationsFor(testMethod), testContext, executionPhase, false);
		}
		else {
			Set<Sql> methodLevelSqlAnnotations = getSqlAnnotationsFor(testMethod);
			if (!methodLevelSqlAnnotations.isEmpty()) {
				executeSqlScripts(methodLevelSqlAnnotations, testContext, executionPhase, false);
			}
			else {
				executeSqlScripts(getSqlAnnotationsFor(testClass), testContext, executionPhase, true);
			}
		}
	}

	/**
	 * Determine if method-level {@code @Sql} annotations should be merged with
	 * class-level {@code @Sql} annotations.
	 * <p>A method-level {@code @SqlMergeMode} takes precedence over a class-level one;
	 * if neither is present, annotations are not merged.
	 */
	private boolean mergeSqlAnnotations(TestContext testContext) {
		SqlMergeMode sqlMergeMode = getSqlMergeModeFor(testContext.getTestMethod());
		if (sqlMergeMode == null) {
			sqlMergeMode = getSqlMergeModeFor(testContext.getTestClass());
		}
		return (sqlMergeMode != null && sqlMergeMode.value() == MergeMode.MERGE);
	}

	/**
	 * Get the {@code @SqlMergeMode} annotation declared on the supplied {@code element}.
	 * @return the merged annotation, or {@code null} if none is present
	 */
	@Nullable
	private SqlMergeMode getSqlMergeModeFor(AnnotatedElement element) {
		return AnnotatedElementUtils.findMergedAnnotation(element, SqlMergeMode.class);
	}

	/**
	 * Get the {@code @Sql} annotations declared on the supplied {@code element},
	 * including those nested within {@link SqlGroup @SqlGroup}.
	 */
	private Set<Sql> getSqlAnnotationsFor(AnnotatedElement element) {
		return AnnotatedElementUtils.getMergedRepeatableAnnotations(element, Sql.class, SqlGroup.class);
	}

	/**
	 * Execute SQL scripts for the supplied {@link Sql @Sql} annotations.
	 */
	private void executeSqlScripts(
			Set<Sql> sqlAnnotations, TestContext testContext, ExecutionPhase executionPhase, boolean classLevel) {

		sqlAnnotations.forEach(sql -> executeSqlScripts(sql, executionPhase, testContext, classLevel));
	}

	/**
	 * Execute the SQL scripts configured via the supplied {@link Sql @Sql}
	 * annotation for the given {@link ExecutionPhase} and {@link TestContext}.
	 * <p>Special care must be taken in order to properly support the configured
	 * {@link SqlConfig#transactionMode}.
	 * @param sql the {@code @Sql} annotation to parse
	 * @param executionPhase the current execution phase
	 * @param testContext the current {@code TestContext}
	 * @param classLevel {@code true} if {@link Sql @Sql} was declared at the class level
	 * @throws IllegalStateException if neither a usable {@code DataSource} nor a
	 * {@code PlatformTransactionManager} is available, if an isolated transaction is
	 * required without a transaction manager, or if the configured {@code DataSource}
	 * does not match the one associated with the transaction manager
	 */
	private void executeSqlScripts(
			Sql sql, ExecutionPhase executionPhase, TestContext testContext, boolean classLevel) {

		if (executionPhase != sql.executionPhase()) {
			return;
		}

		MergedSqlConfig mergedSqlConfig = new MergedSqlConfig(sql.config(), testContext.getTestClass());
		if (logger.isDebugEnabled()) {
			logger.debug(String.format("Processing %s for execution phase [%s] and test context %s.",
					mergedSqlConfig, executionPhase, testContext));
		}

		String[] scripts = getScripts(sql, testContext, classLevel);
		scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts);
		// Copy defensively: inlined statements are appended below, so do not depend on
		// the list returned by convertToResourceList() being mutable.
		List<Resource> scriptResources = new ArrayList<>(TestContextResourceUtils.convertToResourceList(
				testContext.getApplicationContext(), scripts));
		// Encode inlined statements with the same encoding that is passed to the populator
		// via setSqlScriptEncoding(), rather than the platform default charset.
		Charset statementCharset = getStatementCharset(mergedSqlConfig);
		for (String stmt : sql.statements()) {
			if (StringUtils.hasText(stmt)) {
				stmt = stmt.trim();
				scriptResources.add(new ByteArrayResource(stmt.getBytes(statementCharset),
						"from inlined SQL statement: " + stmt));
			}
		}

		ResourceDatabasePopulator populator = createDatabasePopulator(mergedSqlConfig);
		populator.setScripts(scriptResources.toArray(new Resource[0]));
		if (logger.isDebugEnabled()) {
			logger.debug("Executing SQL scripts: " + ObjectUtils.nullSafeToString(scriptResources));
		}

		String dsName = mergedSqlConfig.getDataSource();
		String tmName = mergedSqlConfig.getTransactionManager();
		DataSource dataSource = TestContextTransactionUtils.retrieveDataSource(testContext, dsName);
		PlatformTransactionManager txMgr = TestContextTransactionUtils.retrieveTransactionManager(testContext, tmName);
		boolean newTxRequired = (mergedSqlConfig.getTransactionMode() == TransactionMode.ISOLATED);

		if (txMgr == null) {
			Assert.state(!newTxRequired, () -> String.format("Failed to execute SQL scripts for test context %s: " +
					"cannot execute SQL scripts using Transaction Mode " +
					"[%s] without a PlatformTransactionManager.", testContext, TransactionMode.ISOLATED));
			Assert.state(dataSource != null, () -> String.format("Failed to execute SQL scripts for test context %s: " +
					"supply at least a DataSource or PlatformTransactionManager.", testContext));
			// Execute scripts directly against the DataSource
			populator.execute(dataSource);
		}
		else {
			DataSource dataSourceFromTxMgr = getDataSourceFromTransactionManager(txMgr);
			// Ensure user configured an appropriate DataSource/TransactionManager pair.
			if (dataSource != null && dataSourceFromTxMgr != null && !dataSource.equals(dataSourceFromTxMgr)) {
				throw new IllegalStateException(String.format("Failed to execute SQL scripts for test context %s: " +
						"the configured DataSource [%s] (named '%s') is not the one associated with " +
						"transaction manager [%s] (named '%s').", testContext, dataSource.getClass().getName(),
						dsName, txMgr.getClass().getName(), tmName));
			}
			if (dataSource == null) {
				dataSource = dataSourceFromTxMgr;
				Assert.state(dataSource != null, () -> String.format("Failed to execute SQL scripts for " +
						"test context %s: could not obtain DataSource from transaction manager [%s] (named '%s').",
						testContext, txMgr.getClass().getName(), tmName));
			}
			final DataSource finalDataSource = dataSource;
			// ISOLATED mode suspends any existing transaction; otherwise join it (or start one).
			int propagation = (newTxRequired ? TransactionDefinition.PROPAGATION_REQUIRES_NEW :
					TransactionDefinition.PROPAGATION_REQUIRED);
			TransactionAttribute txAttr = TestContextTransactionUtils.createDelegatingTransactionAttribute(
					testContext, new DefaultTransactionAttribute(propagation));
			new TransactionTemplate(txMgr, txAttr).executeWithoutResult(s -> populator.execute(finalDataSource));
		}
	}

	/**
	 * Create a {@link ResourceDatabasePopulator} configured from the supplied
	 * {@link MergedSqlConfig}: script encoding, statement separator, comment
	 * prefixes, block comment delimiters, and error mode.
	 */
	@NonNull
	private ResourceDatabasePopulator createDatabasePopulator(MergedSqlConfig mergedSqlConfig) {
		ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
		populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
		populator.setSeparator(mergedSqlConfig.getSeparator());
		populator.setCommentPrefixes(mergedSqlConfig.getCommentPrefixes());
		populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
		populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
		populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
		populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
		return populator;
	}

	/**
	 * Determine the {@link Charset} used to convert inlined SQL statements into bytes.
	 * <p>Uses the encoding configured in the supplied {@link MergedSqlConfig}, which is
	 * the same value handed to {@link ResourceDatabasePopulator#setSqlScriptEncoding}.
	 * Falls back to the JVM default charset if no encoding is configured. It also falls
	 * back if the configured name is illegal or unsupported, which leaves reporting of
	 * that misconfiguration to the populator, as before.
	 */
	private static Charset getStatementCharset(MergedSqlConfig mergedSqlConfig) {
		String encoding = mergedSqlConfig.getEncoding();
		if (StringUtils.hasText(encoding)) {
			try {
				return Charset.forName(encoding);
			}
			catch (IllegalArgumentException ignored) {
				// IllegalCharsetNameException / UnsupportedCharsetException: do not fail here;
				// the populator is configured with the same encoding and handles it.
			}
		}
		return Charset.defaultCharset();
	}

	/**
	 * Best-effort lookup of the {@link DataSource} associated with the supplied
	 * transaction manager, via a reflective call to a public {@code getDataSource()}
	 * method if the transaction manager's class declares one.
	 * @return the associated {@code DataSource}, or {@code null} if the method does not
	 * exist, cannot be invoked, or does not return a {@code DataSource}
	 */
	@Nullable
	private DataSource getDataSourceFromTransactionManager(PlatformTransactionManager transactionManager) {
		try {
			Method getDataSourceMethod = transactionManager.getClass().getMethod("getDataSource");
			Object obj = ReflectionUtils.invokeMethod(getDataSourceMethod, transactionManager);
			if (obj instanceof DataSource) {
				return (DataSource) obj;
			}
		}
		catch (Exception ex) {
			// ignore: not every transaction manager exposes a DataSource
		}
		return null;
	}

	/**
	 * Get the script paths declared via {@link Sql#scripts}. If neither scripts nor
	 * inlined statements are declared, fall back to the default script detected for
	 * the current test class or method.
	 */
	private String[] getScripts(Sql sql, TestContext testContext, boolean classLevel) {
		String[] scripts = sql.scripts();
		if (ObjectUtils.isEmpty(scripts) && ObjectUtils.isEmpty(sql.statements())) {
			scripts = new String[] {detectDefaultScript(testContext, classLevel)};
		}
		return scripts;
	}

	/**
	 * Detect a default SQL script by implementing the algorithm defined in
	 * {@link Sql#scripts}.
	 * <p>The candidate is {@code classpath:} plus the test class name converted to a
	 * resource path, optionally followed by {@code .<methodName>} for method-level
	 * declarations, plus {@code .sql}.
	 * @throws IllegalStateException if the default script does not exist
	 */
	private String detectDefaultScript(TestContext testContext, boolean classLevel) {
		Class<?> clazz = testContext.getTestClass();
		Method method = testContext.getTestMethod();
		String elementType = (classLevel ? "class" : "method");
		String elementName = (classLevel ? clazz.getName() : method.toString());

		String resourcePath = ClassUtils.convertClassNameToResourcePath(clazz.getName());
		if (!classLevel) {
			resourcePath += "." + method.getName();
		}
		resourcePath += ".sql";

		String prefixedResourcePath = ResourceUtils.CLASSPATH_URL_PREFIX + resourcePath;
		ClassPathResource classPathResource = new ClassPathResource(resourcePath);

		if (classPathResource.exists()) {
			if (logger.isInfoEnabled()) {
				logger.info(String.format("Detected default SQL script \"%s\" for test %s [%s]",
						prefixedResourcePath, elementType, elementName));
			}
			return prefixedResourcePath;
		}
		else {
			String msg = String.format("Could not detect default SQL script for test %s [%s]: " +
					"%s does not exist. Either declare statements or scripts via @Sql or make the " +
					"default SQL script available.", elementType, elementName, classPathResource);
			logger.error(msg);
			throw new IllegalStateException(msg);
		}
	}

}