001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.sockjs.client;
018
019import java.io.ByteArrayOutputStream;
020import java.io.IOException;
021import java.io.InputStream;
022import java.net.URI;
023
024import org.springframework.core.task.SimpleAsyncTaskExecutor;
025import org.springframework.core.task.TaskExecutor;
026import org.springframework.http.HttpHeaders;
027import org.springframework.http.HttpMethod;
028import org.springframework.http.HttpStatus;
029import org.springframework.http.ResponseEntity;
030import org.springframework.http.StreamingHttpOutputMessage;
031import org.springframework.http.client.ClientHttpRequest;
032import org.springframework.http.client.ClientHttpResponse;
033import org.springframework.lang.Nullable;
034import org.springframework.util.Assert;
035import org.springframework.util.StreamUtils;
036import org.springframework.util.concurrent.SettableListenableFuture;
037import org.springframework.web.client.HttpServerErrorException;
038import org.springframework.web.client.RequestCallback;
039import org.springframework.web.client.ResponseExtractor;
040import org.springframework.web.client.RestOperations;
041import org.springframework.web.client.RestTemplate;
042import org.springframework.web.client.UnknownHttpStatusCodeException;
043import org.springframework.web.socket.CloseStatus;
044import org.springframework.web.socket.TextMessage;
045import org.springframework.web.socket.WebSocketHandler;
046import org.springframework.web.socket.WebSocketSession;
047import org.springframework.web.socket.sockjs.frame.SockJsFrame;
048
049/**
050 * An {@code XhrTransport} implementation that uses a
051 * {@link org.springframework.web.client.RestTemplate RestTemplate}.
052 *
053 * @author Rossen Stoyanchev
054 * @since 4.1
055 */
056public class RestTemplateXhrTransport extends AbstractXhrTransport {
057
058        private final RestOperations restTemplate;
059
060        private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor();
061
062
063        public RestTemplateXhrTransport() {
064                this(new RestTemplate());
065        }
066
067        public RestTemplateXhrTransport(RestOperations restTemplate) {
068                Assert.notNull(restTemplate, "'restTemplate' is required");
069                this.restTemplate = restTemplate;
070        }
071
072
073        /**
074         * Return the configured {@code RestTemplate}.
075         */
076        public RestOperations getRestTemplate() {
077                return this.restTemplate;
078        }
079
080        /**
081         * Configure the {@code TaskExecutor} to use to execute XHR receive requests.
082         * <p>By default {@link org.springframework.core.task.SimpleAsyncTaskExecutor
083         * SimpleAsyncTaskExecutor} is configured which creates a new thread every
084         * time the transports connects.
085         */
086        public void setTaskExecutor(TaskExecutor taskExecutor) {
087                Assert.notNull(taskExecutor, "TaskExecutor must not be null");
088                this.taskExecutor = taskExecutor;
089        }
090
091        /**
092         * Return the configured {@code TaskExecutor}.
093         */
094        public TaskExecutor getTaskExecutor() {
095                return this.taskExecutor;
096        }
097
098
099        @Override
100        protected void connectInternal(final TransportRequest transportRequest, final WebSocketHandler handler,
101                        final URI receiveUrl, final HttpHeaders handshakeHeaders, final XhrClientSockJsSession session,
102                        final SettableListenableFuture<WebSocketSession> connectFuture) {
103
104                getTaskExecutor().execute(() -> {
105                        HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders();
106                        XhrRequestCallback requestCallback = new XhrRequestCallback(handshakeHeaders);
107                        XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(httpHeaders);
108                        XhrReceiveExtractor responseExtractor = new XhrReceiveExtractor(session);
109                        while (true) {
110                                if (session.isDisconnected()) {
111                                        session.afterTransportClosed(null);
112                                        break;
113                                }
114                                try {
115                                        if (logger.isTraceEnabled()) {
116                                                logger.trace("Starting XHR receive request, url=" + receiveUrl);
117                                        }
118                                        getRestTemplate().execute(receiveUrl, HttpMethod.POST, requestCallback, responseExtractor);
119                                        requestCallback = requestCallbackAfterHandshake;
120                                }
121                                catch (Exception ex) {
122                                        if (!connectFuture.isDone()) {
123                                                connectFuture.setException(ex);
124                                        }
125                                        else {
126                                                session.handleTransportError(ex);
127                                                session.afterTransportClosed(new CloseStatus(1006, ex.getMessage()));
128                                        }
129                                        break;
130                                }
131                        }
132                });
133        }
134
135        @Override
136        protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
137                RequestCallback requestCallback = new XhrRequestCallback(headers);
138                return nonNull(this.restTemplate.execute(infoUrl, HttpMethod.GET, requestCallback, textResponseExtractor));
139        }
140
141        @Override
142        public ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) {
143                RequestCallback requestCallback = new XhrRequestCallback(headers, message.getPayload());
144                return nonNull(this.restTemplate.execute(url, HttpMethod.POST, requestCallback, textResponseExtractor));
145        }
146
147        private static <T> T nonNull(@Nullable T result) {
148                Assert.state(result != null, "No result");
149                return result;
150        }
151
152
153        /**
154         * A simple ResponseExtractor that reads the body into a String.
155         */
156        private static final ResponseExtractor<ResponseEntity<String>> textResponseExtractor =
157                        response -> {
158                                String body = StreamUtils.copyToString(response.getBody(), SockJsFrame.CHARSET);
159                                return ResponseEntity.status(response.getRawStatusCode()).headers(response.getHeaders()).body(body);
160                        };
161
162
163        /**
164         * A RequestCallback to add the headers and (optionally) String content.
165         */
166        private static class XhrRequestCallback implements RequestCallback {
167
168                private final HttpHeaders headers;
169
170                @Nullable
171                private final String body;
172
173                public XhrRequestCallback(HttpHeaders headers) {
174                        this(headers, null);
175                }
176
177                public XhrRequestCallback(HttpHeaders headers, @Nullable String body) {
178                        this.headers = headers;
179                        this.body = body;
180                }
181
182                @Override
183                public void doWithRequest(ClientHttpRequest request) throws IOException {
184                        request.getHeaders().putAll(this.headers);
185                        if (this.body != null) {
186                                if (request instanceof StreamingHttpOutputMessage) {
187                                        ((StreamingHttpOutputMessage) request).setBody(outputStream ->
188                                                        StreamUtils.copy(this.body, SockJsFrame.CHARSET, outputStream));
189                                }
190                                else {
191                                        StreamUtils.copy(this.body, SockJsFrame.CHARSET, request.getBody());
192                                }
193                        }
194                }
195        }
196
197        /**
198         * Splits the body of an HTTP response into SockJS frames and delegates those
199         * to an {@link XhrClientSockJsSession}.
200         */
201        private class XhrReceiveExtractor implements ResponseExtractor<Object> {
202
203                private final XhrClientSockJsSession sockJsSession;
204
205                public XhrReceiveExtractor(XhrClientSockJsSession sockJsSession) {
206                        this.sockJsSession = sockJsSession;
207                }
208
209                @Override
210                public Object extractData(ClientHttpResponse response) throws IOException {
211                        HttpStatus httpStatus = HttpStatus.resolve(response.getRawStatusCode());
212                        if (httpStatus == null) {
213                                throw new UnknownHttpStatusCodeException(
214                                                response.getRawStatusCode(), response.getStatusText(), response.getHeaders(), null, null);
215                        }
216                        if (httpStatus != HttpStatus.OK) {
217                                throw new HttpServerErrorException(
218                                                httpStatus, response.getStatusText(), response.getHeaders(), null, null);
219                        }
220
221                        if (logger.isTraceEnabled()) {
222                                logger.trace("XHR receive headers: " + response.getHeaders());
223                        }
224                        InputStream is = response.getBody();
225                        ByteArrayOutputStream os = new ByteArrayOutputStream();
226
227                        while (true) {
228                                if (this.sockJsSession.isDisconnected()) {
229                                        if (logger.isDebugEnabled()) {
230                                                logger.debug("SockJS sockJsSession closed, closing response.");
231                                        }
232                                        response.close();
233                                        break;
234                                }
235                                int b = is.read();
236                                if (b == -1) {
237                                        if (os.size() > 0) {
238                                                handleFrame(os);
239                                        }
240                                        if (logger.isTraceEnabled()) {
241                                                logger.trace("XHR receive completed");
242                                        }
243                                        break;
244                                }
245                                if (b == '\n') {
246                                        handleFrame(os);
247                                }
248                                else {
249                                        os.write(b);
250                                }
251                        }
252                        return null;
253                }
254
255                private void handleFrame(ByteArrayOutputStream os) throws IOException {
256                        String content = os.toString(SockJsFrame.CHARSET.name());
257                        os.reset();
258                        if (logger.isTraceEnabled()) {
259                                logger.trace("XHR receive content: " + content);
260                        }
261                        if (!PRELUDE.equals(content)) {
262                                this.sockJsSession.handleFrame(content);
263                        }
264                }
265        }
266
267}