001/* 002 * Copyright 2002-2020 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.sockjs.client; 018 019import java.io.ByteArrayOutputStream; 020import java.io.IOException; 021import java.io.InputStream; 022import java.net.URI; 023 024import org.springframework.core.task.SimpleAsyncTaskExecutor; 025import org.springframework.core.task.TaskExecutor; 026import org.springframework.http.HttpHeaders; 027import org.springframework.http.HttpMethod; 028import org.springframework.http.HttpStatus; 029import org.springframework.http.ResponseEntity; 030import org.springframework.http.StreamingHttpOutputMessage; 031import org.springframework.http.client.ClientHttpRequest; 032import org.springframework.http.client.ClientHttpResponse; 033import org.springframework.lang.Nullable; 034import org.springframework.util.Assert; 035import org.springframework.util.StreamUtils; 036import org.springframework.util.concurrent.SettableListenableFuture; 037import org.springframework.web.client.HttpServerErrorException; 038import org.springframework.web.client.RequestCallback; 039import org.springframework.web.client.ResponseExtractor; 040import org.springframework.web.client.RestOperations; 041import org.springframework.web.client.RestTemplate; 042import org.springframework.web.client.UnknownHttpStatusCodeException; 043import org.springframework.web.socket.CloseStatus; 044import 
org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.frame.SockJsFrame;

/**
 * An {@code XhrTransport} implementation that uses a
 * {@link org.springframework.web.client.RestTemplate RestTemplate}.
 *
 * @author Rossen Stoyanchev
 * @since 4.1
 */
public class RestTemplateXhrTransport extends AbstractXhrTransport {

	private final RestOperations restTemplate;

	private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor();


	/**
	 * Create a transport backed by a new, default {@link RestTemplate}.
	 */
	public RestTemplateXhrTransport() {
		this(new RestTemplate());
	}

	/**
	 * Create a transport backed by the given {@code RestOperations}.
	 * @param restTemplate the client used for the info, send, and receive
	 * requests; must not be {@code null}
	 */
	public RestTemplateXhrTransport(RestOperations restTemplate) {
		Assert.notNull(restTemplate, "'restTemplate' is required");
		this.restTemplate = restTemplate;
	}


	/**
	 * Return the configured {@code RestTemplate}.
	 */
	public RestOperations getRestTemplate() {
		return this.restTemplate;
	}

	/**
	 * Configure the {@code TaskExecutor} to use to execute XHR receive requests.
	 * <p>By default {@link org.springframework.core.task.SimpleAsyncTaskExecutor
	 * SimpleAsyncTaskExecutor} is configured which creates a new thread every
	 * time the transport connects.
	 * @param taskExecutor the executor to use; must not be {@code null}
	 */
	public void setTaskExecutor(TaskExecutor taskExecutor) {
		Assert.notNull(taskExecutor, "TaskExecutor must not be null");
		this.taskExecutor = taskExecutor;
	}

	/**
	 * Return the configured {@code TaskExecutor}.
	 */
	public TaskExecutor getTaskExecutor() {
		return this.taskExecutor;
	}


	/**
	 * Start the receive loop on the configured {@link TaskExecutor}.
	 * <p>The loop repeatedly POSTs to {@code receiveUrl} and feeds the response
	 * body to the session until the session reports itself disconnected or a
	 * request fails. The first request carries {@code handshakeHeaders}; every
	 * later request carries the headers from {@code transportRequest}.
	 * <p>A failure before {@code connectFuture} completes is reported through the
	 * future. A failure after that is reported to the session, which is then
	 * closed with status code 1006.
	 */
	@Override
	protected void connectInternal(final TransportRequest transportRequest, final WebSocketHandler handler,
			final URI receiveUrl, final HttpHeaders handshakeHeaders, final XhrClientSockJsSession session,
			final SettableListenableFuture<WebSocketSession> connectFuture) {

		getTaskExecutor().execute(() -> {
			HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders();
			XhrRequestCallback requestCallback = new XhrRequestCallback(handshakeHeaders);
			XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(httpHeaders);
			XhrReceiveExtractor responseExtractor = new XhrReceiveExtractor(session);
			while (true) {
				if (session.isDisconnected()) {
					session.afterTransportClosed(null);
					break;
				}
				try {
					if (logger.isTraceEnabled()) {
						logger.trace("Starting XHR receive request, url=" + receiveUrl);
					}
					getRestTemplate().execute(receiveUrl, HttpMethod.POST, requestCallback, responseExtractor);
					// Only the first receive request uses the handshake headers.
					requestCallback = requestCallbackAfterHandshake;
				}
				catch (Exception ex) {
					if (!connectFuture.isDone()) {
						// Still connecting: report the failure to whoever awaits the connection.
						connectFuture.setException(ex);
					}
					else {
						session.handleTransportError(ex);
						session.afterTransportClosed(new CloseStatus(1006, ex.getMessage()));
					}
					break;
				}
			}
		});
	}

	/**
	 * Execute the SockJS info request as an HTTP GET.
	 * @return the response status, headers, and body decoded as a String
	 */
	@Override
	protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
		RequestCallback requestCallback = new XhrRequestCallback(headers);
		return nonNull(this.restTemplate.execute(infoUrl, HttpMethod.GET, requestCallback, textResponseExtractor));
	}

	/**
	 * Execute a send request as an HTTP POST whose body is the message payload.
	 * @return the response status, headers, and body decoded as a String
	 */
	@Override
	public ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) {
		RequestCallback requestCallback = new XhrRequestCallback(headers, message.getPayload());
		return nonNull(this.restTemplate.execute(url, HttpMethod.POST, requestCallback, textResponseExtractor));
	}

	/**
	 * Return the given result, failing with {@code IllegalStateException} if it is {@code null}.
	 */
	private static <T> T nonNull(@Nullable T result) {
		Assert.state(result != null, "No result");
		return result;
	}


	/**
	 * A simple ResponseExtractor that reads the body into a String using
	 * {@link SockJsFrame#CHARSET}, and keeps the raw status code and headers.
	 */
	private static final ResponseExtractor<ResponseEntity<String>> textResponseExtractor =
			response -> {
				String body = StreamUtils.copyToString(response.getBody(), SockJsFrame.CHARSET);
				return ResponseEntity.status(response.getRawStatusCode()).headers(response.getHeaders()).body(body);
			};


	/**
	 * A RequestCallback to add the headers and (optionally) String content.
	 */
	private static class XhrRequestCallback implements RequestCallback {

		private final HttpHeaders headers;

		@Nullable
		private final String body;

		public XhrRequestCallback(HttpHeaders headers) {
			this(headers, null);
		}

		public XhrRequestCallback(HttpHeaders headers, @Nullable String body) {
			this.headers = headers;
			this.body = body;
		}

		/**
		 * Copy the configured headers onto the request and, if a body is set,
		 * write it using {@link SockJsFrame#CHARSET}.
		 */
		@Override
		public void doWithRequest(ClientHttpRequest request) throws IOException {
			request.getHeaders().putAll(this.headers);
			if (this.body != null) {
				if (request instanceof StreamingHttpOutputMessage) {
					// Streaming requests take a body callback rather than exposing a stream.
					((StreamingHttpOutputMessage) request).setBody(outputStream ->
							StreamUtils.copy(this.body, SockJsFrame.CHARSET, outputStream));
				}
				else {
					StreamUtils.copy(this.body, SockJsFrame.CHARSET, request.getBody());
				}
			}
		}
	}

	/**
	 * Splits the body of an HTTP response into SockJS frames and delegates those
	 * to an {@link XhrClientSockJsSession}.
	 */
	private class XhrReceiveExtractor implements ResponseExtractor<Object> {

		private final XhrClientSockJsSession sockJsSession;

		public XhrReceiveExtractor(XhrClientSockJsSession sockJsSession) {
			this.sockJsSession = sockJsSession;
		}

		/**
		 * Read the response body byte by byte. Each newline-terminated line is
		 * passed to the session as one frame, and any trailing content at end of
		 * stream is passed as a final frame.
		 * <p>Empty lines are skipped in both cases, because an empty payload is
		 * not a valid SockJS frame.
		 * <p>Reading stops early, and the response is closed, once the session
		 * reports itself disconnected.
		 * @return always {@code null}
		 * @throws UnknownHttpStatusCodeException if the status code is not a
		 * known {@link HttpStatus}
		 * @throws HttpServerErrorException if the status is anything other than 200 OK
		 */
		@Override
		public Object extractData(ClientHttpResponse response) throws IOException {
			HttpStatus httpStatus = HttpStatus.resolve(response.getRawStatusCode());
			if (httpStatus == null) {
				throw new UnknownHttpStatusCodeException(
						response.getRawStatusCode(), response.getStatusText(), response.getHeaders(), null, null);
			}
			if (httpStatus != HttpStatus.OK) {
				throw new HttpServerErrorException(
						httpStatus, response.getStatusText(), response.getHeaders(), null, null);
			}

			if (logger.isTraceEnabled()) {
				logger.trace("XHR receive headers: " + response.getHeaders());
			}
			InputStream is = response.getBody();
			ByteArrayOutputStream os = new ByteArrayOutputStream();

			while (true) {
				if (this.sockJsSession.isDisconnected()) {
					if (logger.isDebugEnabled()) {
						logger.debug("SockJS sockJsSession closed, closing response.");
					}
					response.close();
					break;
				}
				int b = is.read();
				if (b == -1) {
					if (os.size() > 0) {
						handleFrame(os);
					}
					if (logger.isTraceEnabled()) {
						logger.trace("XHR receive completed");
					}
					break;
				}
				if (b == '\n') {
					// Skip blank lines, as the end-of-stream branch does.
					if (os.size() > 0) {
						handleFrame(os);
					}
				}
				else {
					os.write(b);
				}
			}
			return null;
		}

		/**
		 * Decode the buffered bytes as one frame and clear the buffer.
		 * <p>The frame is passed to the session unless it equals {@code PRELUDE}.
		 */
		private void handleFrame(ByteArrayOutputStream os) throws IOException {
			String content = os.toString(SockJsFrame.CHARSET.name());
			os.reset();
			if (logger.isTraceEnabled()) {
				logger.trace("XHR receive content: " + content);
			}
			if (!PRELUDE.equals(content)) {
				this.sockJsSession.handleFrame(content);
			}
		}
	}

}