package org.springframework.batch.integration.partition;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import javax.sql.DataSource;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.batch.core.Step;
import org.springframework.batch.core.StepExecution;
import org.springframework.batch.core.explore.JobExplorer;
import org.springframework.batch.core.explore.support.JobExplorerFactoryBean;
import org.springframework.batch.core.partition.PartitionHandler;
import org.springframework.batch.core.partition.StepExecutionSplitter;
import org.springframework.batch.core.repository.JobRepository;
import org.springframework.batch.poller.DirectPoller;
import org.springframework.batch.poller.Poller;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.integration.MessageTimeoutException;
import org.springframework.integration.annotation.Aggregator;
import org.springframework.integration.annotation.MessageEndpoint;
import org.springframework.integration.annotation.Payloads;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.core.MessagingTemplate;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.PollableChannel;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
 * A {@link PartitionHandler} that uses {@link MessageChannel} instances to send instructions to remote workers and
 * receive their responses. The {@link MessageChannel} provides a nice abstraction so that the location of the workers
 * and the transport used to communicate with them can be changed at run time. The communication with the remote workers
 * does not need to be transactional or have guaranteed delivery, so a local thread pool based implementation works as
 * well as a remote web service or JMS implementation. If a remote worker fails, the job will fail and can be restarted
 * to pick up missing messages and processing. The remote workers need access to the Spring Batch {@link JobRepository}
 * so that the shared state across those restarts can be managed centrally.
 *
 * <p>While a {@link org.springframework.messaging.MessageChannel} is used for sending the requests to the workers, the
 * worker's responses can be obtained in one of two ways:
 * <ul>
 * <li>A reply channel - Slaves will respond with messages that will be aggregated via this component.</li>
 * <li>Polling the job repository - Since the state of each slave is maintained independently within the job
 * repository, we can poll the store to determine the state without the need of the slaves to formally respond.</li>
 * </ul>
 *
 * <p>Note: The reply channel for this is instance based. Sharing this component across
 * multiple step instances may result in the crossing of messages. It's recommended that
 * this component be step or job scoped.
 *
 * @author Dave Syer
 * @author Will Schipp
 * @author Michael Minella
 * @author Mahmoud Ben Hassine
 *
 */
@MessageEndpoint
public class MessageChannelPartitionHandler implements PartitionHandler, InitializingBean {

	private static final Log logger = LogFactory.getLog(MessageChannelPartitionHandler.class);

	private int gridSize = 1;

	private MessagingTemplate messagingGateway;

	private String stepName;

	private long pollInterval = 10000;

	private JobExplorer jobExplorer;

	private boolean pollRepositoryForResults = false;

	private long timeout = -1;

	private DataSource dataSource;

	/**
	 * pollable channel for the replies
	 */
	private PollableChannel replyChannel;

	/**
	 * Validates the mandatory properties and decides how worker results are collected: if a
	 * {@link DataSource} or a {@link JobExplorer} has been supplied the job repository is polled, otherwise
	 * replies are read from the reply channel (a {@link QueueChannel} is created when none was set).
	 *
	 * @throws Exception if the assertions fail or the {@link JobExplorer} cannot be created from the data source
	 */
	@Override
	public void afterPropertiesSet() throws Exception {
		Assert.notNull(stepName, "A step name must be provided for the remote workers.");
		Assert.state(messagingGateway != null, "The MessagingOperations must be set");

		// Supplying either collaborator switches this handler into repository-polling mode.
		pollRepositoryForResults = !(dataSource == null && jobExplorer == null);

		if (pollRepositoryForResults) {
			logger.debug("MessageChannelPartitionHandler is configured to poll the job repository for slave results");
		}

		// Only a data source was given: build the explorer that will be used for polling.
		if (dataSource != null && jobExplorer == null) {
			JobExplorerFactoryBean jobExplorerFactoryBean = new JobExplorerFactoryBean();
			jobExplorerFactoryBean.setDataSource(dataSource);
			jobExplorerFactoryBean.afterPropertiesSet();
			jobExplorer = jobExplorerFactoryBean.getObject();
		}

		// Reply-channel mode needs somewhere to receive the aggregated result.
		if (!pollRepositoryForResults && replyChannel == null) {
			replyChannel = new QueueChannel();
		}

	}

	/**
	 * When using job repository polling, the time limit to wait.
	 *
	 * @param timeout milliseconds to wait, defaults to -1 (no timeout).
	 */
	public void setTimeout(long timeout) {
		this.timeout = timeout;
	}

	/**
	 * {@link org.springframework.batch.core.explore.JobExplorer} to use to query the job repository. Either this or
	 * a {@link javax.sql.DataSource} is required when using job repository polling.
	 *
	 * @param jobExplorer {@link org.springframework.batch.core.explore.JobExplorer} to use for lookups
	 */
	public void setJobExplorer(JobExplorer jobExplorer) {
		this.jobExplorer = jobExplorer;
	}

	/**
	 * How often to poll the job repository for the status of the slaves.
	 *
	 * @param pollInterval milliseconds between polls, defaults to 10000 (10 seconds).
	 */
	public void setPollInterval(long pollInterval) {
		this.pollInterval = pollInterval;
	}

	/**
	 * {@link javax.sql.DataSource} pointing to the job repository
	 *
	 * @param dataSource {@link javax.sql.DataSource} that points to the job repository's store
	 */
	public void setDataSource(DataSource dataSource) {
		this.dataSource = dataSource;
	}

	/**
	 * A pre-configured gateway for sending and receiving messages to the remote workers. Using this property allows a
	 * large degree of control over the timeouts and other properties of the send. It should have channels set up
	 * internally: <ul> <li>request channel capable of accepting {@link StepExecutionRequest} payloads</li> <li>reply
	 * channel that returns a list of {@link StepExecution} results</li> </ul> The timeout for the reply should be set
	 * sufficiently long that the remote steps have time to complete.
	 *
	 * @param messagingGateway the {@link org.springframework.integration.core.MessagingTemplate} to set
	 */
	public void setMessagingOperations(MessagingTemplate messagingGateway) {
		this.messagingGateway = messagingGateway;
	}

	/**
	 * Passed to the {@link StepExecutionSplitter} in the {@link #handle(StepExecutionSplitter, StepExecution)} method,
	 * instructing it how many {@link StepExecution} instances are required, ideally. The {@link StepExecutionSplitter}
	 * is allowed to ignore the grid size in the case of a restart, since the input data partitions must be preserved.
	 *
	 * @param gridSize the number of step executions that will be created
	 */
	public void setGridSize(int gridSize) {
		this.gridSize = gridSize;
	}

	/**
	 * The name of the {@link Step} that will be used to execute the partitioned {@link StepExecution}. This is a
	 * regular Spring Batch step, with all the business logic required to complete an execution based on the input
	 * parameters in its {@link StepExecution} context. The name will be translated into a {@link Step} instance by the
	 * remote worker.
	 *
	 * @param stepName the name of the {@link Step} instance to execute business logic
	 */
	public void setStepName(String stepName) {
		this.stepName = stepName;
	}

	/**
	 * @param messages the messages to be aggregated
	 * @return the list as it was passed in
	 */
	@Aggregator(sendPartialResultsOnExpiry = "true")
	public List<?> aggregate(@Payloads List<?> messages) {
		return messages;
	}

	/**
	 * The channel on which the aggregated worker replies are received when the job repository is not being
	 * polled. It is also set as the reply channel header of every request message that is sent.
	 *
	 * @param replyChannel the {@link PollableChannel} to receive replies from
	 */
	public void setReplyChannel(PollableChannel replyChannel) {
		this.replyChannel = replyChannel;
	}

	/**
	 * Sends {@link StepExecutionRequest} objects to the request channel of the {@link MessagingTemplate}, and then
	 * receives the result back as a list of {@link StepExecution} on a reply channel. Use the {@link #aggregate(List)}
	 * method as an aggregator of the individual remote replies. The receive timeout needs to be set realistically in
	 * the {@link MessagingTemplate} <b>and</b> the aggregator, so that there is a good chance of all work being done.
	 *
	 * @param stepExecutionSplitter splitter used to create one {@link StepExecution} per partition
	 * @param masterStepExecution the step execution of the partitioned (master) step
	 * @return the partition step executions; the (empty) split itself when there is nothing to execute
	 * @throws Exception if splitting, sending, or waiting for the results fails
	 * @see PartitionHandler#handle(StepExecutionSplitter, StepExecution)
	 */
	@Override
	public Collection<StepExecution> handle(StepExecutionSplitter stepExecutionSplitter,
			final StepExecution masterStepExecution) throws Exception {

		final Set<StepExecution> split = stepExecutionSplitter.split(masterStepExecution, gridSize);

		// Nothing to dispatch, so there is nothing to wait for either.
		if (CollectionUtils.isEmpty(split)) {
			return split;
		}

		int count = 0;

		for (StepExecution stepExecution : split) {
			Message<StepExecutionRequest> request = createMessage(count++, split.size(), new StepExecutionRequest(
					stepName, stepExecution.getJobExecutionId(), stepExecution.getId()), replyChannel);
			if (logger.isDebugEnabled()) {
				logger.debug("Sending request: " + request);
			}
			messagingGateway.send(request);
		}

		if (!pollRepositoryForResults) {
			return receiveReplies(replyChannel);
		}
		else {
			return pollReplies(masterStepExecution, split);
		}
	}

	/**
	 * Repeatedly looks up every partition's {@link StepExecution} through the {@link JobExplorer} until none of
	 * them reports a running status, waiting at most {@code timeout} milliseconds when the timeout is not negative.
	 *
	 * @param masterStepExecution the master step execution, used for its job execution id
	 * @param split the partition step executions that were dispatched
	 * @return the step executions as read back from the job repository
	 * @throws Exception if waiting on the poller's {@link Future} fails or times out
	 */
	private Collection<StepExecution> pollReplies(final StepExecution masterStepExecution,
			final Set<StepExecution> split) throws Exception {
		// Accumulates finished partitions across successive invocations of the callback.
		final Collection<StepExecution> result = new ArrayList<StepExecution>(split.size());

		Callable<Collection<StepExecution>> callback = new Callable<Collection<StepExecution>>() {
			@Override
			public Collection<StepExecution> call() throws Exception {

				for (Iterator<StepExecution> stepExecutionIterator = split.iterator(); stepExecutionIterator.hasNext(); ) {
					StepExecution curStepExecution = stepExecutionIterator.next();

					// Partitions already collected on an earlier poll are not queried again.
					if (!result.contains(curStepExecution)) {
						StepExecution partitionStepExecution =
								jobExplorer.getStepExecution(masterStepExecution.getJobExecutionId(),
										curStepExecution.getId());

						if (!partitionStepExecution.getStatus().isRunning()) {
							result.add(partitionStepExecution);
						}
					}
				}

				int remaining = split.size() - result.size();

				if (remaining == 0) {
					return result;
				}

				// Report the partitions still outstanding, not the total number of partitions.
				if (logger.isDebugEnabled()) {
					logger.debug(String.format("Currently waiting on %s partitions to finish", remaining));
				}

				// Not finished yet; null is presumably the poller's "call again" signal - confirm against Poller.
				return null;
			}
		};

		Poller<Collection<StepExecution>> poller = new DirectPoller<Collection<StepExecution>>(pollInterval);
		Future<Collection<StepExecution>> resultsFuture = poller.poll(callback);

		if (timeout >= 0) {
			return resultsFuture.get(timeout, TimeUnit.MILLISECONDS);
		}
		else {
			return resultsFuture.get();
		}
	}

	/**
	 * Receives the single aggregated reply message from the given channel through the messaging gateway.
	 *
	 * @param currentReplyChannel the channel to receive the reply from
	 * @return the payload of the reply message
	 * @throws MessageTimeoutException if the gateway returned no message
	 */
	private Collection<StepExecution> receiveReplies(PollableChannel currentReplyChannel) {
		@SuppressWarnings("unchecked")
		Message<Collection<StepExecution>> message = (Message<Collection<StepExecution>>) messagingGateway.receive(currentReplyChannel);

		if (message == null) {
			throw new MessageTimeoutException("Timeout occurred before all partitions returned");
		}
		else if (logger.isDebugEnabled()) {
			logger.debug("Received replies: " + message);
		}

		return message.getPayload();
	}

	/**
	 * Builds the request message for one partition. The sequence number and size describe the partition's position
	 * within the split, and the correlation id is {@code jobExecutionId:stepName}, so every request of one
	 * partitioned step shares the same correlation id.
	 *
	 * @param sequenceNumber position of this request within the split
	 * @param sequenceSize total number of requests in the split
	 * @param stepExecutionRequest the payload to send
	 * @param replyChannel the channel set as the message's reply channel
	 * @return the message to send to the workers
	 */
	private Message<StepExecutionRequest> createMessage(int sequenceNumber, int sequenceSize,
			StepExecutionRequest stepExecutionRequest, PollableChannel replyChannel) {
		return MessageBuilder.withPayload(stepExecutionRequest).setSequenceNumber(sequenceNumber)
				.setSequenceSize(sequenceSize)
				.setCorrelationId(stepExecutionRequest.getJobExecutionId() + ":" + stepExecutionRequest.getStepName())
				.setReplyChannel(replyChannel)
				.build();
	}
}