001package org.springframework.batch.integration.partition;
002
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
011
012import javax.sql.DataSource;
013
014import org.apache.commons.logging.Log;
015import org.apache.commons.logging.LogFactory;
016
017import org.springframework.batch.core.Step;
018import org.springframework.batch.core.StepExecution;
019import org.springframework.batch.core.explore.JobExplorer;
020import org.springframework.batch.core.explore.support.JobExplorerFactoryBean;
021import org.springframework.batch.core.partition.PartitionHandler;
022import org.springframework.batch.core.partition.StepExecutionSplitter;
023import org.springframework.batch.core.repository.JobRepository;
024import org.springframework.batch.poller.DirectPoller;
025import org.springframework.batch.poller.Poller;
026import org.springframework.beans.factory.InitializingBean;
027import org.springframework.integration.MessageTimeoutException;
028import org.springframework.integration.annotation.Aggregator;
029import org.springframework.integration.annotation.MessageEndpoint;
030import org.springframework.integration.annotation.Payloads;
031import org.springframework.integration.channel.QueueChannel;
032import org.springframework.integration.core.MessagingTemplate;
033import org.springframework.integration.support.MessageBuilder;
034import org.springframework.messaging.Message;
035import org.springframework.messaging.MessageChannel;
036import org.springframework.messaging.PollableChannel;
037import org.springframework.util.Assert;
038import org.springframework.util.CollectionUtils;
039
040/**
041 * A {@link PartitionHandler} that uses {@link MessageChannel} instances to send instructions to remote workers and
042 * receive their responses. The {@link MessageChannel} provides a nice abstraction so that the location of the workers
043 * and the transport used to communicate with them can be changed at run time. The communication with the remote workers
044 * does not need to be transactional or have guaranteed delivery, so a local thread pool based implementation works as
045 * well as a remote web service or JMS implementation. If a remote worker fails, the job will fail and can be restarted
046 * to pick up missing messages and processing. The remote workers need access to the Spring Batch {@link JobRepository}
047 * so that the shared state across those restarts can be managed centrally.
048 *
049 * While a {@link org.springframework.messaging.MessageChannel} is used for sending the requests to the workers, the
050 * worker's responses can be obtained in one of two ways:
051 * <ul>
052 *     <li>A reply channel - Slaves will respond with messages that will be aggregated via this component.</li>
053 *     <li>Polling the job repository - Since the state of each slave is maintained independently within the job
054 *     repository, we can poll the store to determine the state without the need of the slaves to formally respond.</li>
055 * </ul>
056 *
057 * Note: The reply channel for this is instance based.  Sharing this component across
058 * multiple step instances may result in the crossing of messages.  It's recommended that
059 * this component be step or job scoped.
060 *
061 * @author Dave Syer
062 * @author Will Schipp
063 * @author Michael Minella
064 * @author Mahmoud Ben Hassine
065 *
066 */
067@MessageEndpoint
068public class MessageChannelPartitionHandler implements PartitionHandler, InitializingBean {
069
070        private static Log logger = LogFactory.getLog(MessageChannelPartitionHandler.class);
071
072        private int gridSize = 1;
073
074        private MessagingTemplate messagingGateway;
075
076        private String stepName;
077
078        private long pollInterval = 10000;
079
080        private JobExplorer jobExplorer;
081
082        private boolean pollRepositoryForResults = false;
083
084        private long timeout = -1;
085
086        private DataSource dataSource;
087
088        /**
089         * pollable channel for the replies
090         */
091        private PollableChannel replyChannel;
092
093        @Override
094        public void afterPropertiesSet() throws Exception {
095                Assert.notNull(stepName, "A step name must be provided for the remote workers.");
096                Assert.state(messagingGateway != null, "The MessagingOperations must be set");
097
098                pollRepositoryForResults = !(dataSource == null && jobExplorer == null);
099
100                if(pollRepositoryForResults) {
101                        logger.debug("MessageChannelPartitionHandler is configured to poll the job repository for slave results");
102                }
103
104                if(dataSource != null && jobExplorer == null) {
105                        JobExplorerFactoryBean jobExplorerFactoryBean = new JobExplorerFactoryBean();
106                        jobExplorerFactoryBean.setDataSource(dataSource);
107                        jobExplorerFactoryBean.afterPropertiesSet();
108                        jobExplorer = jobExplorerFactoryBean.getObject();
109                }
110
111                if (!pollRepositoryForResults && replyChannel == null) {
112                        replyChannel = new QueueChannel();
113                }//end if
114
115        }
116
117        /**
118         * When using job repository polling, the time limit to wait.
119         *
120         * @param timeout milliseconds to wait, defaults to -1 (no timeout).
121         */
122        public void setTimeout(long timeout) {
123                this.timeout = timeout;
124        }
125
126        /**
127         * {@link org.springframework.batch.core.explore.JobExplorer} to use to query the job repository.  Either this or
128         * a {@link javax.sql.DataSource} is required when using job repository polling.
129         *
130         * @param jobExplorer {@link org.springframework.batch.core.explore.JobExplorer} to use for lookups
131         */
132        public void setJobExplorer(JobExplorer jobExplorer) {
133                this.jobExplorer = jobExplorer;
134        }
135
136        /**
137         * How often to poll the job repository for the status of the slaves.
138         *
139         * @param pollInterval milliseconds between polls, defaults to 10000 (10 seconds).
140         */
141        public void setPollInterval(long pollInterval) {
142                this.pollInterval = pollInterval;
143        }
144
145        /**
146         * {@link javax.sql.DataSource} pointing to the job repository
147         *
148         * @param dataSource {@link javax.sql.DataSource} that points to the job repository's store
149         */
150        public void setDataSource(DataSource dataSource) {
151                this.dataSource = dataSource;
152        }
153
154        /**
155         * A pre-configured gateway for sending and receiving messages to the remote workers. Using this property allows a
156         * large degree of control over the timeouts and other properties of the send. It should have channels set up
157         * internally: <ul> <li>request channel capable of accepting {@link StepExecutionRequest} payloads</li> <li>reply
158         * channel that returns a list of {@link StepExecution} results</li> </ul> The timeout for the reply should be set
159         * sufficiently long that the remote steps have time to complete.
160         *
161         * @param messagingGateway the {@link org.springframework.integration.core.MessagingTemplate} to set
162         */
163        public void setMessagingOperations(MessagingTemplate messagingGateway) {
164                this.messagingGateway = messagingGateway;
165        }
166
167        /**
168         * Passed to the {@link StepExecutionSplitter} in the {@link #handle(StepExecutionSplitter, StepExecution)} method,
169         * instructing it how many {@link StepExecution} instances are required, ideally. The {@link StepExecutionSplitter}
170         * is allowed to ignore the grid size in the case of a restart, since the input data partitions must be preserved.
171         *
172         * @param gridSize the number of step executions that will be created
173         */
174        public void setGridSize(int gridSize) {
175                this.gridSize = gridSize;
176        }
177
178        /**
179         * The name of the {@link Step} that will be used to execute the partitioned {@link StepExecution}. This is a
180         * regular Spring Batch step, with all the business logic required to complete an execution based on the input
181         * parameters in its {@link StepExecution} context. The name will be translated into a {@link Step} instance by the
182         * remote worker.
183         *
184         * @param stepName the name of the {@link Step} instance to execute business logic
185         */
186        public void setStepName(String stepName) {
187                this.stepName = stepName;
188        }
189
190        /**
191         * @param messages the messages to be aggregated
192         * @return the list as it was passed in
193         */
194        @Aggregator(sendPartialResultsOnExpiry = "true")
195        public List<?> aggregate(@Payloads List<?> messages) {
196                return messages;
197        }
198
199        public void setReplyChannel(PollableChannel replyChannel) {
200                this.replyChannel = replyChannel;
201        }
202
203        /**
204         * Sends {@link StepExecutionRequest} objects to the request channel of the {@link MessagingTemplate}, and then
205         * receives the result back as a list of {@link StepExecution} on a reply channel. Use the {@link #aggregate(List)}
206         * method as an aggregator of the individual remote replies. The receive timeout needs to be set realistically in
207         * the {@link MessagingTemplate} <b>and</b> the aggregator, so that there is a good chance of all work being done.
208         *
209         * @see PartitionHandler#handle(StepExecutionSplitter, StepExecution)
210         */
211        public Collection<StepExecution> handle(StepExecutionSplitter stepExecutionSplitter,
212                        final StepExecution masterStepExecution) throws Exception {
213
214                final Set<StepExecution> split = stepExecutionSplitter.split(masterStepExecution, gridSize);
215
216                if(CollectionUtils.isEmpty(split)) {
217                        return split;
218                }
219
220                int count = 0;
221
222                for (StepExecution stepExecution : split) {
223                        Message<StepExecutionRequest> request = createMessage(count++, split.size(), new StepExecutionRequest(
224                                        stepName, stepExecution.getJobExecutionId(), stepExecution.getId()), replyChannel);
225                        if (logger.isDebugEnabled()) {
226                                logger.debug("Sending request: " + request);
227                        }
228                        messagingGateway.send(request);
229                }
230
231                if(!pollRepositoryForResults) {
232                        return receiveReplies(replyChannel);
233                }
234                else {
235                        return pollReplies(masterStepExecution, split);
236                }
237        }
238
239        private Collection<StepExecution> pollReplies(final StepExecution masterStepExecution, final Set<StepExecution> split) throws Exception {
240                final Collection<StepExecution> result = new ArrayList<StepExecution>(split.size());
241
242                Callable<Collection<StepExecution>> callback = new Callable<Collection<StepExecution>>() {
243                        @Override
244                        public Collection<StepExecution> call() throws Exception {
245
246                                for(Iterator<StepExecution> stepExecutionIterator = split.iterator(); stepExecutionIterator.hasNext(); ) {
247                                        StepExecution curStepExecution = stepExecutionIterator.next();
248
249                                        if(!result.contains(curStepExecution)) {
250                                                StepExecution partitionStepExecution =
251                                                                jobExplorer.getStepExecution(masterStepExecution.getJobExecutionId(), curStepExecution.getId());
252
253                                                if(!partitionStepExecution.getStatus().isRunning()) {
254                                                        result.add(partitionStepExecution);
255                                                }
256                                        }
257                                }
258
259                                if(logger.isDebugEnabled()) {
260                                        logger.debug(String.format("Currently waiting on %s partitions to finish", split.size()));
261                                }
262
263                                if(result.size() == split.size()) {
264                                        return result;
265                                }
266                                else {
267                                        return null;
268                                }
269                        }
270                };
271
272                Poller<Collection<StepExecution>> poller = new DirectPoller<Collection<StepExecution>>(pollInterval);
273                Future<Collection<StepExecution>> resultsFuture = poller.poll(callback);
274
275                if(timeout >= 0) {
276                        return resultsFuture.get(timeout, TimeUnit.MILLISECONDS);
277                }
278                else {
279                        return resultsFuture.get();
280                }
281        }
282
283        private Collection<StepExecution> receiveReplies(PollableChannel currentReplyChannel) {
284                @SuppressWarnings("unchecked")
285                Message<Collection<StepExecution>> message = (Message<Collection<StepExecution>>) messagingGateway.receive(currentReplyChannel);
286
287                if(message == null) {
288                        throw new MessageTimeoutException("Timeout occurred before all partitions returned");
289                } else if (logger.isDebugEnabled()) {
290                        logger.debug("Received replies: " + message);
291                }
292
293                return message.getPayload();
294        }
295
296        private Message<StepExecutionRequest> createMessage(int sequenceNumber, int sequenceSize,
297                        StepExecutionRequest stepExecutionRequest, PollableChannel replyChannel) {
298                return MessageBuilder.withPayload(stepExecutionRequest).setSequenceNumber(sequenceNumber)
299                                .setSequenceSize(sequenceSize)
300                                .setCorrelationId(stepExecutionRequest.getJobExecutionId() + ":" + stepExecutionRequest.getStepName())
301                                .setReplyChannel(replyChannel)
302                                .build();
303        }
304}