001/* 002 * Copyright 2002-2017 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.sockjs.client; 018 019import java.io.ByteArrayOutputStream; 020import java.io.IOException; 021import java.io.InputStream; 022import java.net.URI; 023 024import org.springframework.core.task.SimpleAsyncTaskExecutor; 025import org.springframework.core.task.TaskExecutor; 026import org.springframework.http.HttpHeaders; 027import org.springframework.http.HttpMethod; 028import org.springframework.http.HttpStatus; 029import org.springframework.http.ResponseEntity; 030import org.springframework.http.client.ClientHttpRequest; 031import org.springframework.http.client.ClientHttpResponse; 032import org.springframework.util.Assert; 033import org.springframework.util.StreamUtils; 034import org.springframework.util.concurrent.SettableListenableFuture; 035import org.springframework.web.client.HttpServerErrorException; 036import org.springframework.web.client.RequestCallback; 037import org.springframework.web.client.ResponseExtractor; 038import org.springframework.web.client.RestOperations; 039import org.springframework.web.client.RestTemplate; 040import org.springframework.web.client.UnknownHttpStatusCodeException; 041import org.springframework.web.socket.CloseStatus; 042import org.springframework.web.socket.TextMessage; 043import org.springframework.web.socket.WebSocketHandler; 044import org.springframework.web.socket.WebSocketSession; 045import org.springframework.web.socket.sockjs.frame.SockJsFrame; 046 047/** 048 * An {@code XhrTransport} implementation that uses a 049 * {@link org.springframework.web.client.RestTemplate RestTemplate}. 050 * 051 * @author Rossen Stoyanchev 052 * @since 4.1 053 */ 054public class RestTemplateXhrTransport extends AbstractXhrTransport { 055 056 private final RestOperations restTemplate; 057 058 private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); 059 060 061 public RestTemplateXhrTransport() { 062 this(new RestTemplate()); 063 } 064 065 public RestTemplateXhrTransport(RestOperations restTemplate) { 066 Assert.notNull(restTemplate, "'restTemplate' is required"); 067 this.restTemplate = restTemplate; 068 } 069 070 071 /** 072 * Return the configured {@code RestTemplate}. 073 */ 074 public RestOperations getRestTemplate() { 075 return this.restTemplate; 076 } 077 078 /** 079 * Configure the {@code TaskExecutor} to use to execute XHR receive requests. 080 * <p>By default {@link org.springframework.core.task.SimpleAsyncTaskExecutor 081 * SimpleAsyncTaskExecutor} is configured which creates a new thread every 082 * time the transports connects. 083 */ 084 public void setTaskExecutor(TaskExecutor taskExecutor) { 085 Assert.notNull(taskExecutor, "TaskExecutor must not be null"); 086 this.taskExecutor = taskExecutor; 087 } 088 089 /** 090 * Return the configured {@code TaskExecutor}. 091 */ 092 public TaskExecutor getTaskExecutor() { 093 return this.taskExecutor; 094 } 095 096 097 @Override 098 protected void connectInternal(final TransportRequest transportRequest, final WebSocketHandler handler, 099 final URI receiveUrl, final HttpHeaders handshakeHeaders, final XhrClientSockJsSession session, 100 final SettableListenableFuture<WebSocketSession> connectFuture) { 101 102 getTaskExecutor().execute(new Runnable() { 103 @Override 104 public void run() { 105 HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders(); 106 XhrRequestCallback requestCallback = new XhrRequestCallback(handshakeHeaders); 107 XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(httpHeaders); 108 XhrReceiveExtractor responseExtractor = new XhrReceiveExtractor(session); 109 while (true) { 110 if (session.isDisconnected()) { 111 session.afterTransportClosed(null); 112 break; 113 } 114 try { 115 if (logger.isTraceEnabled()) { 116 logger.trace("Starting XHR receive request, url=" + receiveUrl); 117 } 118 getRestTemplate().execute(receiveUrl, HttpMethod.POST, requestCallback, responseExtractor); 119 requestCallback = requestCallbackAfterHandshake; 120 } 121 catch (Throwable ex) { 122 if (!connectFuture.isDone()) { 123 connectFuture.setException(ex); 124 } 125 else { 126 session.handleTransportError(ex); 127 session.afterTransportClosed(new CloseStatus(1006, ex.getMessage())); 128 } 129 break; 130 } 131 } 132 } 133 }); 134 } 135 136 @Override 137 protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) { 138 RequestCallback requestCallback = new XhrRequestCallback(headers); 139 return this.restTemplate.execute(infoUrl, HttpMethod.GET, requestCallback, textResponseExtractor); 140 } 141 142 @Override 143 public ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) { 144 RequestCallback requestCallback = new XhrRequestCallback(headers, message.getPayload()); 145 return this.restTemplate.execute(url, HttpMethod.POST, requestCallback, textResponseExtractor); 146 } 147 148 149 /** 150 * A simple ResponseExtractor that reads the body into a String. 151 */ 152 private final static ResponseExtractor<ResponseEntity<String>> textResponseExtractor = 153 new ResponseExtractor<ResponseEntity<String>>() { 154 @Override 155 public ResponseEntity<String> extractData(ClientHttpResponse response) throws IOException { 156 if (response.getBody() == null) { 157 return new ResponseEntity<String>(response.getHeaders(), response.getStatusCode()); 158 } 159 else { 160 String body = StreamUtils.copyToString(response.getBody(), SockJsFrame.CHARSET); 161 return new ResponseEntity<String>(body, response.getHeaders(), response.getStatusCode()); 162 } 163 } 164 }; 165 166 /** 167 * A RequestCallback to add the headers and (optionally) String content. 168 */ 169 private static class XhrRequestCallback implements RequestCallback { 170 171 private final HttpHeaders headers; 172 173 private final String body; 174 175 public XhrRequestCallback(HttpHeaders headers) { 176 this(headers, null); 177 } 178 179 public XhrRequestCallback(HttpHeaders headers, String body) { 180 this.headers = headers; 181 this.body = body; 182 } 183 184 @Override 185 public void doWithRequest(ClientHttpRequest request) throws IOException { 186 if (this.headers != null) { 187 request.getHeaders().putAll(this.headers); 188 } 189 if (this.body != null) { 190 StreamUtils.copy(this.body, SockJsFrame.CHARSET, request.getBody()); 191 } 192 } 193 } 194 195 /** 196 * Splits the body of an HTTP response into SockJS frames and delegates those 197 * to an {@link XhrClientSockJsSession}. 198 */ 199 private class XhrReceiveExtractor implements ResponseExtractor<Object> { 200 201 private final XhrClientSockJsSession sockJsSession; 202 203 public XhrReceiveExtractor(XhrClientSockJsSession sockJsSession) { 204 this.sockJsSession = sockJsSession; 205 } 206 207 @Override 208 public Object extractData(ClientHttpResponse response) throws IOException { 209 try { 210 if (!HttpStatus.OK.equals(response.getStatusCode())) { 211 throw new HttpServerErrorException(response.getStatusCode()); 212 } 213 } 214 catch (IllegalArgumentException ex) { 215 throw new UnknownHttpStatusCodeException( 216 response.getRawStatusCode(), response.getStatusText(), response.getHeaders(), null, null); 217 } 218 219 if (logger.isTraceEnabled()) { 220 logger.trace("XHR receive headers: " + response.getHeaders()); 221 } 222 InputStream is = response.getBody(); 223 ByteArrayOutputStream os = new ByteArrayOutputStream(); 224 225 while (true) { 226 if (this.sockJsSession.isDisconnected()) { 227 if (logger.isDebugEnabled()) { 228 logger.debug("SockJS sockJsSession closed, closing response."); 229 } 230 response.close(); 231 break; 232 } 233 int b = is.read(); 234 if (b == -1) { 235 if (os.size() > 0) { 236 handleFrame(os); 237 } 238 if (logger.isTraceEnabled()) { 239 logger.trace("XHR receive completed"); 240 } 241 break; 242 } 243 if (b == '\n') { 244 handleFrame(os); 245 } 246 else { 247 os.write(b); 248 } 249 } 250 return null; 251 } 252 253 private void handleFrame(ByteArrayOutputStream os) { 254 byte[] bytes = os.toByteArray(); 255 os.reset(); 256 String content = new String(bytes, SockJsFrame.CHARSET); 257 if (logger.isTraceEnabled()) { 258 logger.trace("XHR receive content: " + content); 259 } 260 if (!PRELUDE.equals(content)) { 261 this.sockJsSession.handleFrame(new String(bytes, SockJsFrame.CHARSET)); 262 } 263 } 264 } 265 266}