001/* 002 * Copyright 2002-2020 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.reactive.socket.client; 018 019import java.net.URI; 020 021import org.apache.commons.logging.Log; 022import org.apache.commons.logging.LogFactory; 023import reactor.core.publisher.Mono; 024import reactor.netty.http.client.HttpClient; 025import reactor.netty.http.websocket.WebsocketInbound; 026 027import org.springframework.core.io.buffer.NettyDataBufferFactory; 028import org.springframework.http.HttpHeaders; 029import org.springframework.util.Assert; 030import org.springframework.util.StringUtils; 031import org.springframework.web.reactive.socket.HandshakeInfo; 032import org.springframework.web.reactive.socket.WebSocketHandler; 033import org.springframework.web.reactive.socket.WebSocketSession; 034import org.springframework.web.reactive.socket.adapter.NettyWebSocketSessionSupport; 035import org.springframework.web.reactive.socket.adapter.ReactorNettyWebSocketSession; 036 037/** 038 * {@link WebSocketClient} implementation for use with Reactor Netty. 039 * 040 * @author Rossen Stoyanchev 041 * @since 5.0 042 */ 043public class ReactorNettyWebSocketClient implements WebSocketClient { 044 045 private static final Log logger = LogFactory.getLog(ReactorNettyWebSocketClient.class); 046 047 048 private final HttpClient httpClient; 049 050 private int maxFramePayloadLength = NettyWebSocketSessionSupport.DEFAULT_FRAME_MAX_SIZE; 051 052 private boolean handlePing; 053 054 055 /** 056 * Default constructor. 057 */ 058 public ReactorNettyWebSocketClient() { 059 this(HttpClient.create()); 060 } 061 062 /** 063 * Constructor that accepts an existing {@link HttpClient} builder. 064 * @since 5.1 065 */ 066 public ReactorNettyWebSocketClient(HttpClient httpClient) { 067 Assert.notNull(httpClient, "HttpClient is required"); 068 this.httpClient = httpClient; 069 } 070 071 072 /** 073 * Return the configured {@link HttpClient}. 074 */ 075 public HttpClient getHttpClient() { 076 return this.httpClient; 077 } 078 079 /** 080 * Configure the maximum allowable frame payload length. Setting this value 081 * to your application's requirement may reduce denial of service attacks 082 * using long data frames. 083 * <p>Corresponds to the argument with the same name in the constructor of 084 * {@link io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory 085 * WebSocketServerHandshakerFactory} in Netty. 086 * <p>By default set to 65536 (64K). 087 * @param maxFramePayloadLength the max length for frames. 088 * @since 5.2 089 */ 090 public void setMaxFramePayloadLength(int maxFramePayloadLength) { 091 this.maxFramePayloadLength = maxFramePayloadLength; 092 } 093 094 /** 095 * Return the configured {@link #setMaxFramePayloadLength(int) maxFramePayloadLength}. 096 * @since 5.2 097 */ 098 public int getMaxFramePayloadLength() { 099 return this.maxFramePayloadLength; 100 } 101 102 /** 103 * Configure whether to let ping frames through to be handled by the 104 * {@link WebSocketHandler} given to the execute method. By default, Reactor 105 * Netty automatically replies with pong frames in response to pings. This is 106 * useful in a proxy for allowing ping and pong frames through. 107 * <p>By default this is set to {@code false} in which case ping frames are 108 * handled automatically by Reactor Netty. If set to {@code true}, ping 109 * frames will be passed through to the {@link WebSocketHandler}. 110 * @param handlePing whether to let Ping frames through for handling 111 * @since 5.2.4 112 */ 113 public void setHandlePing(boolean handlePing) { 114 this.handlePing = handlePing; 115 } 116 117 /** 118 * Return the configured {@link #setHandlePing(boolean)}. 119 * @since 5.2.4 120 */ 121 public boolean getHandlePing() { 122 return this.handlePing; 123 } 124 125 @Override 126 public Mono<Void> execute(URI url, WebSocketHandler handler) { 127 return execute(url, new HttpHeaders(), handler); 128 } 129 130 @Override 131 @SuppressWarnings("deprecation") 132 public Mono<Void> execute(URI url, HttpHeaders requestHeaders, WebSocketHandler handler) { 133 String protocols = StringUtils.collectionToCommaDelimitedString(handler.getSubProtocols()); 134 return getHttpClient() 135 .headers(nettyHeaders -> setNettyHeaders(requestHeaders, nettyHeaders)) 136 .websocket(protocols, getMaxFramePayloadLength(), this.handlePing) 137 .uri(url.toString()) 138 .handle((inbound, outbound) -> { 139 HttpHeaders responseHeaders = toHttpHeaders(inbound); 140 String protocol = responseHeaders.getFirst("Sec-WebSocket-Protocol"); 141 HandshakeInfo info = new HandshakeInfo(url, responseHeaders, Mono.empty(), protocol); 142 NettyDataBufferFactory factory = new NettyDataBufferFactory(outbound.alloc()); 143 WebSocketSession session = new ReactorNettyWebSocketSession( 144 inbound, outbound, info, factory, getMaxFramePayloadLength()); 145 if (logger.isDebugEnabled()) { 146 logger.debug("Started session '" + session.getId() + "' for " + url); 147 } 148 return handler.handle(session).checkpoint(url + " [ReactorNettyWebSocketClient]"); 149 }) 150 .doOnRequest(n -> { 151 if (logger.isDebugEnabled()) { 152 logger.debug("Connecting to " + url); 153 } 154 }) 155 .next(); 156 } 157 158 private void setNettyHeaders(HttpHeaders httpHeaders, io.netty.handler.codec.http.HttpHeaders nettyHeaders) { 159 httpHeaders.forEach(nettyHeaders::set); 160 } 161 162 private HttpHeaders toHttpHeaders(WebsocketInbound inbound) { 163 HttpHeaders headers = new HttpHeaders(); 164 io.netty.handler.codec.http.HttpHeaders nettyHeaders = inbound.headers(); 165 nettyHeaders.forEach(entry -> { 166 String name = entry.getKey(); 167 headers.put(name, nettyHeaders.getAll(name)); 168 }); 169 return headers; 170 } 171 172}