001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.client.standard;
018
019import javax.websocket.ContainerProvider;
020import javax.websocket.Session;
021import javax.websocket.WebSocketContainer;
022
023import org.springframework.beans.BeansException;
024import org.springframework.beans.factory.BeanFactory;
025import org.springframework.beans.factory.BeanFactoryAware;
026import org.springframework.core.task.SimpleAsyncTaskExecutor;
027import org.springframework.core.task.TaskExecutor;
028import org.springframework.lang.Nullable;
029import org.springframework.util.Assert;
030import org.springframework.web.socket.client.ConnectionManagerSupport;
031import org.springframework.web.socket.handler.BeanCreatingHandlerProvider;
032
033/**
034 * A WebSocket connection manager that is given a URI, a
035 * {@link javax.websocket.ClientEndpoint}-annotated endpoint, connects to a
036 * WebSocket server through the {@link #start()} and {@link #stop()} methods.
037 * If {@link #setAutoStartup(boolean)} is set to {@code true} this will be
038 * done automatically when the Spring ApplicationContext is refreshed.
039 *
040 * @author Rossen Stoyanchev
041 * @since 4.0
042 */
043public class AnnotatedEndpointConnectionManager extends ConnectionManagerSupport implements BeanFactoryAware {
044
045        @Nullable
046        private final Object endpoint;
047
048        @Nullable
049        private final BeanCreatingHandlerProvider<Object> endpointProvider;
050
051        private WebSocketContainer webSocketContainer = ContainerProvider.getWebSocketContainer();
052
053        private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("AnnotatedEndpointConnectionManager-");
054
055        @Nullable
056        private volatile Session session;
057
058
059        public AnnotatedEndpointConnectionManager(Object endpoint, String uriTemplate, Object... uriVariables) {
060                super(uriTemplate, uriVariables);
061                this.endpoint = endpoint;
062                this.endpointProvider = null;
063        }
064
065        public AnnotatedEndpointConnectionManager(Class<?> endpointClass, String uriTemplate, Object... uriVariables) {
066                super(uriTemplate, uriVariables);
067                this.endpoint = null;
068                this.endpointProvider = new BeanCreatingHandlerProvider<>(endpointClass);
069        }
070
071
072        public void setWebSocketContainer(WebSocketContainer webSocketContainer) {
073                this.webSocketContainer = webSocketContainer;
074        }
075
076        public WebSocketContainer getWebSocketContainer() {
077                return this.webSocketContainer;
078        }
079
080        @Override
081        public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
082                if (this.endpointProvider != null) {
083                        this.endpointProvider.setBeanFactory(beanFactory);
084                }
085        }
086
087        /**
088         * Set a {@link TaskExecutor} to use to open the connection.
089         * By default {@link SimpleAsyncTaskExecutor} is used.
090         */
091        public void setTaskExecutor(TaskExecutor taskExecutor) {
092                Assert.notNull(taskExecutor, "TaskExecutor must not be null");
093                this.taskExecutor = taskExecutor;
094        }
095
096        /**
097         * Return the configured {@link TaskExecutor}.
098         */
099        public TaskExecutor getTaskExecutor() {
100                return this.taskExecutor;
101        }
102
103
104        @Override
105        protected void openConnection() {
106                this.taskExecutor.execute(() -> {
107                        try {
108                                if (logger.isInfoEnabled()) {
109                                        logger.info("Connecting to WebSocket at " + getUri());
110                                }
111                                Object endpointToUse = this.endpoint;
112                                if (endpointToUse == null) {
113                                        Assert.state(this.endpointProvider != null, "No endpoint set");
114                                        endpointToUse = this.endpointProvider.getHandler();
115                                }
116                                this.session = this.webSocketContainer.connectToServer(endpointToUse, getUri());
117                                logger.info("Successfully connected to WebSocket");
118                        }
119                        catch (Throwable ex) {
120                                logger.error("Failed to connect to WebSocket", ex);
121                        }
122                });
123        }
124
125        @Override
126        protected void closeConnection() throws Exception {
127                try {
128                        Session session = this.session;
129                        if (session != null && session.isOpen()) {
130                                session.close();
131                        }
132                }
133                finally {
134                        this.session = null;
135                }
136        }
137
138        @Override
139        protected boolean isConnected() {
140                Session session = this.session;
141                return (session != null && session.isOpen());
142        }
143
144}