001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.jdbc;
018
019import javax.annotation.PostConstruct;
020import javax.sql.DataSource;
021
022import org.springframework.core.io.ResourceLoader;
023import org.springframework.jdbc.datasource.init.DatabasePopulatorUtils;
024import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator;
025import org.springframework.jdbc.support.JdbcUtils;
026import org.springframework.jdbc.support.MetaDataAccessException;
027import org.springframework.util.Assert;
028
029/**
030 * Base class used for {@link DataSource} initialization.
031 *
032 * @author Vedran Pavic
033 * @author Stephane Nicoll
034 * @since 1.5.0
035 */
036public abstract class AbstractDataSourceInitializer {
037
038        private static final String PLATFORM_PLACEHOLDER = "@@platform@@";
039
040        private final DataSource dataSource;
041
042        private final ResourceLoader resourceLoader;
043
044        protected AbstractDataSourceInitializer(DataSource dataSource,
045                        ResourceLoader resourceLoader) {
046                Assert.notNull(dataSource, "DataSource must not be null");
047                Assert.notNull(resourceLoader, "ResourceLoader must not be null");
048                this.dataSource = dataSource;
049                this.resourceLoader = resourceLoader;
050        }
051
052        @PostConstruct
053        protected void initialize() {
054                if (!isEnabled()) {
055                        return;
056                }
057                ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
058                String schemaLocation = getSchemaLocation();
059                if (schemaLocation.contains(PLATFORM_PLACEHOLDER)) {
060                        String platform = getDatabaseName();
061                        schemaLocation = schemaLocation.replace(PLATFORM_PLACEHOLDER, platform);
062                }
063                populator.addScript(this.resourceLoader.getResource(schemaLocation));
064                populator.setContinueOnError(true);
065                customize(populator);
066                DatabasePopulatorUtils.execute(populator, this.dataSource);
067        }
068
069        private boolean isEnabled() {
070                if (getMode() == DataSourceInitializationMode.NEVER) {
071                        return false;
072                }
073                if (getMode() == DataSourceInitializationMode.EMBEDDED
074                                && !EmbeddedDatabaseConnection.isEmbedded(this.dataSource)) {
075                        return false;
076                }
077                return true;
078        }
079
080        /**
081         * Customize the {@link ResourceDatabasePopulator}.
082         * @param populator the configured database populator
083         */
084        protected void customize(ResourceDatabasePopulator populator) {
085        }
086
087        protected abstract DataSourceInitializationMode getMode();
088
089        protected abstract String getSchemaLocation();
090
091        protected String getDatabaseName() {
092                try {
093                        String productName = JdbcUtils.commonDatabaseName(JdbcUtils
094                                        .extractDatabaseMetaData(this.dataSource, "getDatabaseProductName")
095                                        .toString());
096                        DatabaseDriver databaseDriver = DatabaseDriver.fromProductName(productName);
097                        if (databaseDriver == DatabaseDriver.UNKNOWN) {
098                                throw new IllegalStateException("Unable to detect database type");
099                        }
100                        return databaseDriver.getId();
101                }
102                catch (MetaDataAccessException ex) {
103                        throw new IllegalStateException("Unable to detect database type", ex);
104                }
105        }
106
107}