001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.reactive.socket.adapter; 018 019import java.nio.charset.StandardCharsets; 020import java.util.Map; 021import java.util.concurrent.ConcurrentHashMap; 022import java.util.function.Function; 023 024import org.apache.commons.logging.Log; 025import org.apache.commons.logging.LogFactory; 026import org.reactivestreams.Publisher; 027import reactor.core.publisher.Flux; 028import reactor.core.publisher.Mono; 029 030import org.springframework.core.io.buffer.DataBuffer; 031import org.springframework.core.io.buffer.DataBufferFactory; 032import org.springframework.util.Assert; 033import org.springframework.web.reactive.socket.HandshakeInfo; 034import org.springframework.web.reactive.socket.WebSocketMessage; 035import org.springframework.web.reactive.socket.WebSocketSession; 036 037/** 038 * Convenient base class for {@link WebSocketSession} implementations that 039 * holds common fields and exposes accessors. Also implements the 040 * {@code WebSocketMessage} factory methods. 041 * 042 * @author Rossen Stoyanchev 043 * @since 5.0 044 * @param <T> the native delegate type 045 */ 046public abstract class AbstractWebSocketSession<T> implements WebSocketSession { 047 048 protected final Log logger = LogFactory.getLog(getClass()); 049 050 private final T delegate; 051 052 private final String id; 053 054 private final HandshakeInfo handshakeInfo; 055 056 private final DataBufferFactory bufferFactory; 057 058 private final Map<String, Object> attributes = new ConcurrentHashMap<>(); 059 060 private final String logPrefix; 061 062 063 /** 064 * Create a new WebSocket session. 065 */ 066 protected AbstractWebSocketSession(T delegate, String id, HandshakeInfo info, DataBufferFactory bufferFactory) { 067 Assert.notNull(delegate, "Native session is required."); 068 Assert.notNull(id, "Session id is required."); 069 Assert.notNull(info, "HandshakeInfo is required."); 070 Assert.notNull(bufferFactory, "DataBuffer factory is required."); 071 072 this.delegate = delegate; 073 this.id = id; 074 this.handshakeInfo = info; 075 this.bufferFactory = bufferFactory; 076 this.attributes.putAll(info.getAttributes()); 077 this.logPrefix = initLogPrefix(info, id); 078 079 if (logger.isDebugEnabled()) { 080 logger.debug(getLogPrefix() + "Session id \"" + getId() + "\" for " + getHandshakeInfo().getUri()); 081 } 082 } 083 084 private static String initLogPrefix(HandshakeInfo info, String id) { 085 return info.getLogPrefix() != null ? info.getLogPrefix() : "[" + id + "] "; 086 } 087 088 089 protected T getDelegate() { 090 return this.delegate; 091 } 092 093 @Override 094 public String getId() { 095 return this.id; 096 } 097 098 @Override 099 public HandshakeInfo getHandshakeInfo() { 100 return this.handshakeInfo; 101 } 102 103 @Override 104 public DataBufferFactory bufferFactory() { 105 return this.bufferFactory; 106 } 107 108 @Override 109 public Map<String, Object> getAttributes() { 110 return this.attributes; 111 } 112 113 protected String getLogPrefix() { 114 return this.logPrefix; 115 } 116 117 118 @Override 119 public abstract Flux<WebSocketMessage> receive(); 120 121 @Override 122 public abstract Mono<Void> send(Publisher<WebSocketMessage> messages); 123 124 125 // WebSocketMessage factory methods 126 127 @Override 128 public WebSocketMessage textMessage(String payload) { 129 byte[] bytes = payload.getBytes(StandardCharsets.UTF_8); 130 DataBuffer buffer = bufferFactory().wrap(bytes); 131 return new WebSocketMessage(WebSocketMessage.Type.TEXT, buffer); 132 } 133 134 @Override 135 public WebSocketMessage binaryMessage(Function<DataBufferFactory, DataBuffer> payloadFactory) { 136 DataBuffer payload = payloadFactory.apply(bufferFactory()); 137 return new WebSocketMessage(WebSocketMessage.Type.BINARY, payload); 138 } 139 140 @Override 141 public WebSocketMessage pingMessage(Function<DataBufferFactory, DataBuffer> payloadFactory) { 142 DataBuffer payload = payloadFactory.apply(bufferFactory()); 143 return new WebSocketMessage(WebSocketMessage.Type.PING, payload); 144 } 145 146 @Override 147 public WebSocketMessage pongMessage(Function<DataBufferFactory, DataBuffer> payloadFactory) { 148 DataBuffer payload = payloadFactory.apply(bufferFactory()); 149 return new WebSocketMessage(WebSocketMessage.Type.PONG, payload); 150 } 151 152 153 @Override 154 public String toString() { 155 return getClass().getSimpleName() + "[id=" + getId() + ", uri=" + getHandshakeInfo().getUri() + "]"; 156 } 157 158}