001/*
002 * Copyright 2006-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.springframework.batch.item.database.support;
017
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import javax.sql.DataSource;

import org.springframework.batch.support.DatabaseType;
import org.springframework.jdbc.support.incrementer.DB2MainframeSequenceMaxValueIncrementer;
import org.springframework.jdbc.support.incrementer.DB2SequenceMaxValueIncrementer;
import org.springframework.jdbc.support.incrementer.DataFieldMaxValueIncrementer;
import org.springframework.jdbc.support.incrementer.DerbyMaxValueIncrementer;
import org.springframework.jdbc.support.incrementer.H2SequenceMaxValueIncrementer;
import org.springframework.jdbc.support.incrementer.HsqlMaxValueIncrementer;
import org.springframework.jdbc.support.incrementer.MySQLMaxValueIncrementer;
import org.springframework.jdbc.support.incrementer.OracleSequenceMaxValueIncrementer;
import org.springframework.jdbc.support.incrementer.PostgreSQLSequenceMaxValueIncrementer;
import org.springframework.jdbc.support.incrementer.SqlServerMaxValueIncrementer;
import org.springframework.jdbc.support.incrementer.SybaseMaxValueIncrementer;

import static org.springframework.batch.support.DatabaseType.DB2;
import static org.springframework.batch.support.DatabaseType.DB2AS400;
import static org.springframework.batch.support.DatabaseType.DB2ZOS;
import static org.springframework.batch.support.DatabaseType.DERBY;
import static org.springframework.batch.support.DatabaseType.H2;
import static org.springframework.batch.support.DatabaseType.HSQL;
import static org.springframework.batch.support.DatabaseType.MYSQL;
import static org.springframework.batch.support.DatabaseType.ORACLE;
import static org.springframework.batch.support.DatabaseType.POSTGRES;
import static org.springframework.batch.support.DatabaseType.SQLITE;
import static org.springframework.batch.support.DatabaseType.SQLSERVER;
import static org.springframework.batch.support.DatabaseType.SYBASE;
047
048/**
049 * Default implementation of the {@link DataFieldMaxValueIncrementerFactory}
050 * interface. Valid database types are given by the {@link DatabaseType} enum.
051 *
052 * Note: For MySql databases, the
053 * {@link MySQLMaxValueIncrementer#setUseNewConnection(boolean)} will be set to true.
054 * 
055 * @author Lucas Ward
056 * @author Michael Minella
057 * @see DatabaseType
058 */
059public class DefaultDataFieldMaxValueIncrementerFactory implements DataFieldMaxValueIncrementerFactory {
060
061        private DataSource dataSource;
062
063        private String incrementerColumnName = "ID";
064
065        /**
066-        * Public setter for the column name (defaults to "ID") in the incrementer.
067-        * Only used by some platforms (Derby, HSQL, MySQL, SQL Server and Sybase),
068-        * and should be fine for use with Spring Batch meta data as long as the
069         * default batch schema hasn't been changed.
070         * 
071         * @param incrementerColumnName the primary key column name to set
072         */
073        public void setIncrementerColumnName(String incrementerColumnName) {
074                this.incrementerColumnName = incrementerColumnName;
075        }
076
077        public DefaultDataFieldMaxValueIncrementerFactory(DataSource dataSource) {
078                this.dataSource = dataSource;
079        }
080
081        @Override
082        public DataFieldMaxValueIncrementer getIncrementer(String incrementerType, String incrementerName) {
083                DatabaseType databaseType = DatabaseType.valueOf(incrementerType.toUpperCase());
084
085                if (databaseType == DB2 || databaseType == DB2AS400) {
086                        return new DB2SequenceMaxValueIncrementer(dataSource, incrementerName);
087                }
088                else if (databaseType == DB2ZOS) {
089                        return new DB2MainframeSequenceMaxValueIncrementer(dataSource, incrementerName);
090                }
091                else if (databaseType == DERBY) {
092                        return new DerbyMaxValueIncrementer(dataSource, incrementerName, incrementerColumnName);
093                }
094                else if (databaseType == HSQL) {
095                        return new HsqlMaxValueIncrementer(dataSource, incrementerName, incrementerColumnName);
096                }
097                else if (databaseType == H2) {
098                        return new H2SequenceMaxValueIncrementer(dataSource, incrementerName);
099                }
100                else if (databaseType == MYSQL) {
101                        MySQLMaxValueIncrementer mySQLMaxValueIncrementer = new MySQLMaxValueIncrementer(dataSource, incrementerName, incrementerColumnName);
102                        mySQLMaxValueIncrementer.setUseNewConnection(true);
103                        return mySQLMaxValueIncrementer;
104                }
105                else if (databaseType == ORACLE) {
106                        return new OracleSequenceMaxValueIncrementer(dataSource, incrementerName);
107                }
108                else if (databaseType == POSTGRES) {
109                        return new PostgreSQLSequenceMaxValueIncrementer(dataSource, incrementerName);
110                }
111                else if (databaseType == SQLITE) {
112                        return new SqliteMaxValueIncrementer(dataSource, incrementerName, incrementerColumnName);
113                }
114                else if (databaseType == SQLSERVER) {
115                        return new SqlServerMaxValueIncrementer(dataSource, incrementerName, incrementerColumnName);
116                }
117                else if (databaseType == SYBASE) {
118                        return new SybaseMaxValueIncrementer(dataSource, incrementerName, incrementerColumnName);
119                }
120                throw new IllegalArgumentException("databaseType argument was not on the approved list");
121        }
122        
123    @Override
124        public boolean isSupportedIncrementerType(String incrementerType) {
125                for (DatabaseType type : DatabaseType.values()) {
126                        if (type.name().equals(incrementerType.toUpperCase())) {
127                                return true;
128                        }
129                }
130
131                return false;
132        }
133
134    @Override
135        public String[] getSupportedIncrementerTypes() {
136
137                List<String> types = new ArrayList<>();
138
139                for (DatabaseType type : DatabaseType.values()) {
140                        types.add(type.name());
141                }
142
143                return types.toArray(new String[types.size()]);
144        }
145}