001/*
002 * Copyright 2002-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.sockjs.client;
018
019import java.io.ByteArrayOutputStream;
020import java.io.IOException;
021import java.io.InputStream;
022import java.net.URI;
023
024import org.springframework.core.task.SimpleAsyncTaskExecutor;
025import org.springframework.core.task.TaskExecutor;
026import org.springframework.http.HttpHeaders;
027import org.springframework.http.HttpMethod;
028import org.springframework.http.HttpStatus;
029import org.springframework.http.ResponseEntity;
030import org.springframework.http.client.ClientHttpRequest;
031import org.springframework.http.client.ClientHttpResponse;
032import org.springframework.util.Assert;
033import org.springframework.util.StreamUtils;
034import org.springframework.util.concurrent.SettableListenableFuture;
035import org.springframework.web.client.HttpServerErrorException;
036import org.springframework.web.client.RequestCallback;
037import org.springframework.web.client.ResponseExtractor;
038import org.springframework.web.client.RestOperations;
039import org.springframework.web.client.RestTemplate;
040import org.springframework.web.client.UnknownHttpStatusCodeException;
041import org.springframework.web.socket.CloseStatus;
042import org.springframework.web.socket.TextMessage;
043import org.springframework.web.socket.WebSocketHandler;
044import org.springframework.web.socket.WebSocketSession;
045import org.springframework.web.socket.sockjs.frame.SockJsFrame;
046
047/**
048 * An {@code XhrTransport} implementation that uses a
049 * {@link org.springframework.web.client.RestTemplate RestTemplate}.
050 *
051 * @author Rossen Stoyanchev
052 * @since 4.1
053 */
054public class RestTemplateXhrTransport extends AbstractXhrTransport {
055
056        private final RestOperations restTemplate;
057
058        private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor();
059
060
061        public RestTemplateXhrTransport() {
062                this(new RestTemplate());
063        }
064
065        public RestTemplateXhrTransport(RestOperations restTemplate) {
066                Assert.notNull(restTemplate, "'restTemplate' is required");
067                this.restTemplate = restTemplate;
068        }
069
070
071        /**
072         * Return the configured {@code RestTemplate}.
073         */
074        public RestOperations getRestTemplate() {
075                return this.restTemplate;
076        }
077
078        /**
079         * Configure the {@code TaskExecutor} to use to execute XHR receive requests.
080         * <p>By default {@link org.springframework.core.task.SimpleAsyncTaskExecutor
081         * SimpleAsyncTaskExecutor} is configured which creates a new thread every
082         * time the transports connects.
083         */
084        public void setTaskExecutor(TaskExecutor taskExecutor) {
085                Assert.notNull(taskExecutor, "TaskExecutor must not be null");
086                this.taskExecutor = taskExecutor;
087        }
088
089        /**
090         * Return the configured {@code TaskExecutor}.
091         */
092        public TaskExecutor getTaskExecutor() {
093                return this.taskExecutor;
094        }
095
096
097        @Override
098        protected void connectInternal(final TransportRequest transportRequest, final WebSocketHandler handler,
099                        final URI receiveUrl, final HttpHeaders handshakeHeaders, final XhrClientSockJsSession session,
100                        final SettableListenableFuture<WebSocketSession> connectFuture) {
101
102                getTaskExecutor().execute(new Runnable() {
103                        @Override
104                        public void run() {
105                                HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders();
106                                XhrRequestCallback requestCallback = new XhrRequestCallback(handshakeHeaders);
107                                XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(httpHeaders);
108                                XhrReceiveExtractor responseExtractor = new XhrReceiveExtractor(session);
109                                while (true) {
110                                        if (session.isDisconnected()) {
111                                                session.afterTransportClosed(null);
112                                                break;
113                                        }
114                                        try {
115                                                if (logger.isTraceEnabled()) {
116                                                        logger.trace("Starting XHR receive request, url=" + receiveUrl);
117                                                }
118                                                getRestTemplate().execute(receiveUrl, HttpMethod.POST, requestCallback, responseExtractor);
119                                                requestCallback = requestCallbackAfterHandshake;
120                                        }
121                                        catch (Throwable ex) {
122                                                if (!connectFuture.isDone()) {
123                                                        connectFuture.setException(ex);
124                                                }
125                                                else {
126                                                        session.handleTransportError(ex);
127                                                        session.afterTransportClosed(new CloseStatus(1006, ex.getMessage()));
128                                                }
129                                                break;
130                                        }
131                                }
132                        }
133                });
134        }
135
136        @Override
137        protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
138                RequestCallback requestCallback = new XhrRequestCallback(headers);
139                return this.restTemplate.execute(infoUrl, HttpMethod.GET, requestCallback, textResponseExtractor);
140        }
141
142        @Override
143        public ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) {
144                RequestCallback requestCallback = new XhrRequestCallback(headers, message.getPayload());
145                return this.restTemplate.execute(url, HttpMethod.POST, requestCallback, textResponseExtractor);
146        }
147
148
149        /**
150         * A simple ResponseExtractor that reads the body into a String.
151         */
152        private final static ResponseExtractor<ResponseEntity<String>> textResponseExtractor =
153                        new ResponseExtractor<ResponseEntity<String>>() {
154                                @Override
155                                public ResponseEntity<String> extractData(ClientHttpResponse response) throws IOException {
156                                        if (response.getBody() == null) {
157                                                return new ResponseEntity<String>(response.getHeaders(), response.getStatusCode());
158                                        }
159                                        else {
160                                                String body = StreamUtils.copyToString(response.getBody(), SockJsFrame.CHARSET);
161                                                return new ResponseEntity<String>(body, response.getHeaders(), response.getStatusCode());
162                                        }
163                                }
164                        };
165
166        /**
167         * A RequestCallback to add the headers and (optionally) String content.
168         */
169        private static class XhrRequestCallback implements RequestCallback {
170
171                private final HttpHeaders headers;
172
173                private final String body;
174
175                public XhrRequestCallback(HttpHeaders headers) {
176                        this(headers, null);
177                }
178
179                public XhrRequestCallback(HttpHeaders headers, String body) {
180                        this.headers = headers;
181                        this.body = body;
182                }
183
184                @Override
185                public void doWithRequest(ClientHttpRequest request) throws IOException {
186                        if (this.headers != null) {
187                                request.getHeaders().putAll(this.headers);
188                        }
189                        if (this.body != null) {
190                                StreamUtils.copy(this.body, SockJsFrame.CHARSET, request.getBody());
191                        }
192                }
193        }
194
195        /**
196         * Splits the body of an HTTP response into SockJS frames and delegates those
197         * to an {@link XhrClientSockJsSession}.
198         */
199        private class XhrReceiveExtractor implements ResponseExtractor<Object> {
200
201                private final XhrClientSockJsSession sockJsSession;
202
203                public XhrReceiveExtractor(XhrClientSockJsSession sockJsSession) {
204                        this.sockJsSession = sockJsSession;
205                }
206
207                @Override
208                public Object extractData(ClientHttpResponse response) throws IOException {
209                        try {
210                                if (!HttpStatus.OK.equals(response.getStatusCode())) {
211                                        throw new HttpServerErrorException(response.getStatusCode());
212                                }
213                        }
214                        catch (IllegalArgumentException ex) {
215                                throw new UnknownHttpStatusCodeException(
216                                                response.getRawStatusCode(), response.getStatusText(), response.getHeaders(), null, null);
217                        }
218
219                        if (logger.isTraceEnabled()) {
220                                logger.trace("XHR receive headers: " + response.getHeaders());
221                        }
222                        InputStream is = response.getBody();
223                        ByteArrayOutputStream os = new ByteArrayOutputStream();
224
225                        while (true) {
226                                if (this.sockJsSession.isDisconnected()) {
227                                        if (logger.isDebugEnabled()) {
228                                                logger.debug("SockJS sockJsSession closed, closing response.");
229                                        }
230                                        response.close();
231                                        break;
232                                }
233                                int b = is.read();
234                                if (b == -1) {
235                                        if (os.size() > 0) {
236                                                handleFrame(os);
237                                        }
238                                        if (logger.isTraceEnabled()) {
239                                                logger.trace("XHR receive completed");
240                                        }
241                                        break;
242                                }
243                                if (b == '\n') {
244                                        handleFrame(os);
245                                }
246                                else {
247                                        os.write(b);
248                                }
249                        }
250                        return null;
251                }
252
253                private void handleFrame(ByteArrayOutputStream os) {
254                        byte[] bytes = os.toByteArray();
255                        os.reset();
256                        String content = new String(bytes, SockJsFrame.CHARSET);
257                        if (logger.isTraceEnabled()) {
258                                logger.trace("XHR receive content: " + content);
259                        }
260                        if (!PRELUDE.equals(content)) {
261                                this.sockJsSession.handleFrame(new String(bytes, SockJsFrame.CHARSET));
262                        }
263                }
264        }
265
266}