001/*
002 * Copyright 2006-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.batch.item.database;
018
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.CopyOnWriteArrayList;
029
030import javax.sql.DataSource;
031
032import org.springframework.batch.item.ExecutionContext;
033import org.springframework.batch.item.ItemStreamException;
034import org.springframework.beans.factory.InitializingBean;
035import org.springframework.jdbc.core.JdbcTemplate;
036import org.springframework.jdbc.core.RowMapper;
037import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
038import org.springframework.util.Assert;
039import org.springframework.util.ClassUtils;
040
041/**
042 * <p>
043 * {@link org.springframework.batch.item.ItemReader} for reading database
044 * records using JDBC in a paging fashion.
045 * </p>
046 * 
047 * <p>
048 * It executes the SQL built by the {@link PagingQueryProvider} to retrieve
049 * requested data. The query is executed using paged requests of a size
 * specified in {@link #setPageSize(int)}. Additional pages are requested when
 * needed as the {@link #read()} method is called, returning an object corresponding
 * to the current position. On restart it uses the last sort key value to locate the
053 * first page to read (so it doesn't matter if the successfully processed items
054 * have been removed or modified). It is important to have a unique key constraint
055 * on the sort key to guarantee that no data is lost between executions.
056 * </p>
057 * 
058 * <p>
059 * The performance of the paging depends on the database specific features
060 * available to limit the number of returned rows. Setting a fairly large page
061 * size and using a commit interval that matches the page size should provide
062 * better performance.
063 * </p>
064 * 
065 * <p>
066 * The implementation is thread-safe in between calls to
067 * {@link #open(ExecutionContext)}, but remember to use
068 * <code>saveState=false</code> if used in a multi-threaded client (no restart
069 * available).
070 * </p>
071 * 
072 * @author Thomas Risberg
073 * @author Dave Syer
074 * @author Michael Minella
075 * @author Mahmoud Ben Hassine
076 * @since 2.0
077 */
078public class JdbcPagingItemReader<T> extends AbstractPagingItemReader<T> implements InitializingBean {
079        private static final String START_AFTER_VALUE = "start.after";
080
081        public static final int VALUE_NOT_SET = -1;
082
083        private DataSource dataSource;
084
085        private PagingQueryProvider queryProvider;
086
087        private Map<String, Object> parameterValues;
088
089        private NamedParameterJdbcTemplate namedParameterJdbcTemplate;
090
091        private RowMapper<T> rowMapper;
092
093        private String firstPageSql;
094
095        private String remainingPagesSql;
096
097        private Map<String, Object> startAfterValues;
098        
099        private Map<String, Object> previousStartAfterValues;
100
101        private int fetchSize = VALUE_NOT_SET;
102
103        public JdbcPagingItemReader() {
104                setName(ClassUtils.getShortName(JdbcPagingItemReader.class));
105        }
106
107        public void setDataSource(DataSource dataSource) {
108                this.dataSource = dataSource;
109        }
110
111        /**
112         * Gives the JDBC driver a hint as to the number of rows that should be
113         * fetched from the database when more rows are needed for this
114         * <code>ResultSet</code> object. If the fetch size specified is zero, the
115         * JDBC driver ignores the value.
116         * 
117         * @param fetchSize the number of rows to fetch
118         * @see ResultSet#setFetchSize(int)
119         */
120        public void setFetchSize(int fetchSize) {
121                this.fetchSize = fetchSize;
122        }
123
124        /**
125         * A {@link PagingQueryProvider}. Supplies all the platform dependent query
126         * generation capabilities needed by the reader.
127         * 
128         * @param queryProvider the {@link PagingQueryProvider} to use
129         */
130        public void setQueryProvider(PagingQueryProvider queryProvider) {
131                this.queryProvider = queryProvider;
132        }
133
134        /**
135         * The row mapper implementation to be used by this reader. The row mapper
136         * is used to convert result set rows into objects, which are then returned
137         * by the reader.
138         * 
139         * @param rowMapper a
140         * {@link RowMapper}
141         * implementation
142         */
143        public void setRowMapper(RowMapper<T> rowMapper) {
144                this.rowMapper = rowMapper;
145        }
146
147        /**
148         * The parameter values to be used for the query execution. If you use named
149         * parameters then the key should be the name used in the query clause. If
150         * you use "?" placeholders then the key should be the relative index that
151         * the parameter appears in the query string built using the select, from
152         * and where clauses specified.
153         * 
154         * @param parameterValues the values keyed by the parameter named/index used
155         * in the query string.
156         */
157        public void setParameterValues(Map<String, Object> parameterValues) {
158                this.parameterValues = parameterValues;
159        }
160
161        /**
162         * Check mandatory properties.
163         * @see org.springframework.beans.factory.InitializingBean#afterPropertiesSet()
164         */
165        @Override
166        public void afterPropertiesSet() throws Exception {
167                super.afterPropertiesSet();
168                Assert.notNull(dataSource, "DataSource may not be null");
169                JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
170                if (fetchSize != VALUE_NOT_SET) {
171                        jdbcTemplate.setFetchSize(fetchSize);
172                }
173                jdbcTemplate.setMaxRows(getPageSize());
174                namedParameterJdbcTemplate = new NamedParameterJdbcTemplate(jdbcTemplate);
175                Assert.notNull(queryProvider, "QueryProvider may not be null");
176                queryProvider.init(dataSource);
177                this.firstPageSql = queryProvider.generateFirstPageQuery(getPageSize());
178                this.remainingPagesSql = queryProvider.generateRemainingPagesQuery(getPageSize());
179        }
180
181        @Override
182        @SuppressWarnings("unchecked")
183        protected void doReadPage() {
184                if (results == null) {
185                        results = new CopyOnWriteArrayList<>();
186                }
187                else {
188                        results.clear();
189                }
190
191                PagingRowMapper rowCallback = new PagingRowMapper();
192
193                List<?> query;
194
195                if (getPage() == 0) {
196                        if (logger.isDebugEnabled()) {
197                                logger.debug("SQL used for reading first page: [" + firstPageSql + "]");
198                        }
199                        if (parameterValues != null && parameterValues.size() > 0) {
200                                if (this.queryProvider.isUsingNamedParameters()) {
201                                        query = namedParameterJdbcTemplate.query(firstPageSql,
202                                                        getParameterMap(parameterValues, null), rowCallback);
203                                }
204                                else {
205                                        query = getJdbcTemplate().query(firstPageSql,
206                                                        getParameterList(parameterValues, null).toArray(), rowCallback);
207                                }
208                        }
209                        else {
210                                query = getJdbcTemplate().query(firstPageSql, rowCallback);
211                        }
212
213                }
214                else {
215                        previousStartAfterValues = startAfterValues;
216                        if (logger.isDebugEnabled()) {
217                                logger.debug("SQL used for reading remaining pages: [" + remainingPagesSql + "]");
218                        }
219                        if (this.queryProvider.isUsingNamedParameters()) {
220                                query = namedParameterJdbcTemplate.query(remainingPagesSql,
221                                                getParameterMap(parameterValues, startAfterValues), rowCallback);
222                        }
223                        else {
224                                query = getJdbcTemplate().query(remainingPagesSql,
225                                                getParameterList(parameterValues, startAfterValues).toArray(), rowCallback);
226                        }
227                }
228
229                Collection<T> result = (Collection<T>) query;
230                results.addAll(result);
231        }
232
233        @Override
234        public void update(ExecutionContext executionContext) throws ItemStreamException {
235                super.update(executionContext);
236                if (isSaveState()) {
237                        if (isAtEndOfPage() && startAfterValues != null) {
238                                // restart on next page
239                                executionContext.put(getExecutionContextKey(START_AFTER_VALUE), startAfterValues);      
240                        } else if (previousStartAfterValues != null) {
241                                // restart on current page
242                                executionContext.put(getExecutionContextKey(START_AFTER_VALUE), previousStartAfterValues);
243                        }
244                }
245        }
246        
247        private boolean isAtEndOfPage() {
248                return getCurrentItemCount() % getPageSize() == 0;
249        }
250
251        @Override
252        @SuppressWarnings("unchecked")
253        public void open(ExecutionContext executionContext) {
254                if (isSaveState()) {
255                        startAfterValues = (Map<String, Object>) executionContext.get(getExecutionContextKey(START_AFTER_VALUE));
256
257                        if(startAfterValues == null) {
258                                startAfterValues = new LinkedHashMap<>();
259                        }
260                }
261
262                super.open(executionContext);
263        }
264
265        @Override
266        protected void doJumpToPage(int itemIndex) {
267                /*
268                 * Normally this would be false (the startAfterValue is enough
269                 * information to restart from.
270                 */
271                // TODO: this is dead code, startAfterValues is never null - see #open(ExecutionContext)
272                if (startAfterValues == null && getPage() > 0) {
273
274                        String jumpToItemSql = queryProvider.generateJumpToItemQuery(itemIndex, getPageSize());
275
276                        if (logger.isDebugEnabled()) {
277                                logger.debug("SQL used for jumping: [" + jumpToItemSql + "]");
278                        }
279                        
280                        if (this.queryProvider.isUsingNamedParameters()) {
281                                startAfterValues = namedParameterJdbcTemplate.queryForMap(jumpToItemSql, getParameterMap(parameterValues, null));
282                        }
283                        else {
284                                startAfterValues = getJdbcTemplate().queryForMap(jumpToItemSql, getParameterList(parameterValues, null).toArray());
285                        }
286                }
287        }
288
289        private Map<String, Object> getParameterMap(Map<String, Object> values, Map<String, Object> sortKeyValues) {
290                Map<String, Object> parameterMap = new LinkedHashMap<>();
291                if (values != null) {
292                        parameterMap.putAll(values);
293                }
294                if (sortKeyValues != null && !sortKeyValues.isEmpty()) {
295                        for (Map.Entry<String, Object> sortKey : sortKeyValues.entrySet()) {
296                                parameterMap.put("_" + sortKey.getKey(), sortKey.getValue());
297                        }
298                }
299                if (logger.isDebugEnabled()) {
300                        logger.debug("Using parameterMap:" + parameterMap);
301                }
302                return parameterMap;
303        }
304
305        private List<Object> getParameterList(Map<String, Object> values, Map<String, Object> sortKeyValue) {
306                SortedMap<String, Object> sm = new TreeMap<>();
307                if (values != null) {
308                        sm.putAll(values);
309                }
310                List<Object> parameterList = new ArrayList<>();
311                parameterList.addAll(sm.values());
312                if (sortKeyValue != null && sortKeyValue.size() > 0) {
313                        List<Map.Entry<String, Object>> keys = new ArrayList<>(sortKeyValue.entrySet());
314
315                        for(int i = 0; i < keys.size(); i++) {
316                                for(int j = 0; j < i; j++) {
317                                        parameterList.add(keys.get(j).getValue());
318                                }
319
320                                parameterList.add(keys.get(i).getValue());
321                        }
322                }
323
324                if (logger.isDebugEnabled()) {
325                        logger.debug("Using parameterList:" + parameterList);
326                }
327                return parameterList;
328        }
329
330        private class PagingRowMapper implements RowMapper<T> {
331                @Override
332                public T mapRow(ResultSet rs, int rowNum) throws SQLException {
333                        startAfterValues = new LinkedHashMap<>();
334                        for (Map.Entry<String, Order> sortKey : queryProvider.getSortKeys().entrySet()) {
335                                startAfterValues.put(sortKey.getKey(), rs.getObject(sortKey.getKey()));
336                        }
337
338                        return rowMapper.mapRow(rs, rowNum);
339                }
340        }
341
342        private JdbcTemplate getJdbcTemplate() {
343                return (JdbcTemplate) namedParameterJdbcTemplate.getJdbcOperations();
344        }
345}