001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.reactive.socket.client;
018
019import java.net.URI;
020import java.util.List;
021import java.util.Map;
022
023import javax.websocket.ClientEndpointConfig;
024import javax.websocket.ClientEndpointConfig.Configurator;
025import javax.websocket.ContainerProvider;
026import javax.websocket.Endpoint;
027import javax.websocket.HandshakeResponse;
028import javax.websocket.Session;
029import javax.websocket.WebSocketContainer;
030
031import org.apache.commons.logging.Log;
032import org.apache.commons.logging.LogFactory;
033import reactor.core.publisher.Mono;
034import reactor.core.publisher.MonoProcessor;
035import reactor.core.scheduler.Schedulers;
036
037import org.springframework.core.io.buffer.DataBufferFactory;
038import org.springframework.core.io.buffer.DefaultDataBufferFactory;
039import org.springframework.http.HttpHeaders;
040import org.springframework.web.reactive.socket.HandshakeInfo;
041import org.springframework.web.reactive.socket.WebSocketHandler;
042import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter;
043import org.springframework.web.reactive.socket.adapter.StandardWebSocketSession;
044
045/**
046 * {@link WebSocketClient} implementation for use with the Java WebSocket API.
047 *
048 * @author Violeta Georgieva
049 * @author Rossen Stoyanchev
050 * @since 5.0
051 * @see <a href="https://www.jcp.org/en/jsr/detail?id=356">https://www.jcp.org/en/jsr/detail?id=356</a>
052 */
053public class StandardWebSocketClient implements WebSocketClient {
054
055        private static final Log logger = LogFactory.getLog(StandardWebSocketClient.class);
056
057
058        private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory();
059
060        private final WebSocketContainer webSocketContainer;
061
062
063        /**
064         * Default constructor that calls
065         * {@code ContainerProvider.getWebSocketContainer()} to obtain a (new)
066         * {@link WebSocketContainer} instance.
067         */
068        public StandardWebSocketClient() {
069                this(ContainerProvider.getWebSocketContainer());
070        }
071
072        /**
073         * Constructor accepting an existing {@link WebSocketContainer} instance.
074         * @param webSocketContainer a web socket container
075         */
076        public StandardWebSocketClient(WebSocketContainer webSocketContainer) {
077                this.webSocketContainer = webSocketContainer;
078        }
079
080
081        /**
082         * Return the configured {@link WebSocketContainer} to use.
083         */
084        public WebSocketContainer getWebSocketContainer() {
085                return this.webSocketContainer;
086        }
087
088
089        @Override
090        public Mono<Void> execute(URI url, WebSocketHandler handler) {
091                return execute(url, new HttpHeaders(), handler);
092        }
093
094        @Override
095        public Mono<Void> execute(URI url, HttpHeaders headers, WebSocketHandler handler) {
096                return executeInternal(url, headers, handler);
097        }
098
099        private Mono<Void> executeInternal(URI url, HttpHeaders requestHeaders, WebSocketHandler handler) {
100                MonoProcessor<Void> completionMono = MonoProcessor.create();
101                return Mono.fromCallable(
102                                () -> {
103                                        if (logger.isDebugEnabled()) {
104                                                logger.debug("Connecting to " + url);
105                                        }
106                                        List<String> protocols = handler.getSubProtocols();
107                                        DefaultConfigurator configurator = new DefaultConfigurator(requestHeaders);
108                                        Endpoint endpoint = createEndpoint(url, handler, completionMono, configurator);
109                                        ClientEndpointConfig config = createEndpointConfig(configurator, protocols);
110                                        return this.webSocketContainer.connectToServer(endpoint, config, url);
111                                })
112                                .subscribeOn(Schedulers.boundedElastic()) // connectToServer is blocking
113                                .then(completionMono);
114        }
115
116        private StandardWebSocketHandlerAdapter createEndpoint(URI url, WebSocketHandler handler,
117                        MonoProcessor<Void> completion, DefaultConfigurator configurator) {
118
119                return new StandardWebSocketHandlerAdapter(handler, session ->
120                                createWebSocketSession(session, createHandshakeInfo(url, configurator), completion));
121        }
122
123        private HandshakeInfo createHandshakeInfo(URI url, DefaultConfigurator configurator) {
124                HttpHeaders responseHeaders = configurator.getResponseHeaders();
125                String protocol = responseHeaders.getFirst("Sec-WebSocket-Protocol");
126                return new HandshakeInfo(url, responseHeaders, Mono.empty(), protocol);
127        }
128
129        protected StandardWebSocketSession createWebSocketSession(Session session, HandshakeInfo info,
130                        MonoProcessor<Void> completion) {
131
132                return new StandardWebSocketSession(session, info, this.bufferFactory, completion);
133        }
134
135        private ClientEndpointConfig createEndpointConfig(Configurator configurator, List<String> subProtocols) {
136                return ClientEndpointConfig.Builder.create()
137                                .configurator(configurator)
138                                .preferredSubprotocols(subProtocols)
139                                .build();
140        }
141
142        protected DataBufferFactory bufferFactory() {
143                return this.bufferFactory;
144        }
145
146
147        private static final class DefaultConfigurator extends Configurator {
148
149                private final HttpHeaders requestHeaders;
150
151                private final HttpHeaders responseHeaders = new HttpHeaders();
152
153                public DefaultConfigurator(HttpHeaders requestHeaders) {
154                        this.requestHeaders = requestHeaders;
155                }
156
157                public HttpHeaders getResponseHeaders() {
158                        return this.responseHeaders;
159                }
160
161                @Override
162                public void beforeRequest(Map<String, List<String>> requestHeaders) {
163                        requestHeaders.putAll(this.requestHeaders);
164                }
165
166                @Override
167                public void afterResponse(HandshakeResponse response) {
168                        response.getHeaders().forEach(this.responseHeaders::put);
169                }
170        }
171
172}