001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.rsocket.annotation.support;
018
019import java.lang.reflect.AnnotatedElement;
020import java.util.ArrayList;
021import java.util.List;
022import java.util.Set;
023import java.util.function.Predicate;
024import java.util.stream.Collectors;
025
026import io.rsocket.ConnectionSetupPayload;
027import io.rsocket.RSocket;
028import io.rsocket.SocketAcceptor;
029import io.rsocket.frame.FrameType;
030import io.rsocket.metadata.WellKnownMimeType;
031import reactor.core.publisher.Mono;
032
033import org.springframework.beans.BeanUtils;
034import org.springframework.core.MethodParameter;
035import org.springframework.core.ReactiveAdapter;
036import org.springframework.core.ReactiveAdapterRegistry;
037import org.springframework.core.annotation.AnnotatedElementUtils;
038import org.springframework.core.codec.Decoder;
039import org.springframework.core.codec.Encoder;
040import org.springframework.lang.Nullable;
041import org.springframework.messaging.Message;
042import org.springframework.messaging.MessageDeliveryException;
043import org.springframework.messaging.handler.CompositeMessageCondition;
044import org.springframework.messaging.handler.DestinationPatternsMessageCondition;
045import org.springframework.messaging.handler.HandlerMethod;
046import org.springframework.messaging.handler.MessageCondition;
047import org.springframework.messaging.handler.annotation.MessageMapping;
048import org.springframework.messaging.handler.annotation.reactive.MessageMappingMessageHandler;
049import org.springframework.messaging.handler.annotation.reactive.PayloadMethodArgumentResolver;
050import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler;
051import org.springframework.messaging.rsocket.MetadataExtractor;
052import org.springframework.messaging.rsocket.RSocketRequester;
053import org.springframework.messaging.rsocket.RSocketStrategies;
054import org.springframework.messaging.rsocket.annotation.ConnectMapping;
055import org.springframework.util.Assert;
056import org.springframework.util.MimeType;
057import org.springframework.util.MimeTypeUtils;
058import org.springframework.util.RouteMatcher;
059import org.springframework.util.StringUtils;
060
061/**
062 * Extension of {@link MessageMappingMessageHandler} for handling RSocket
063 * requests with {@link ConnectMapping @ConnectMapping} and
064 * {@link MessageMapping @MessageMapping} methods.
065 *
066 * <p>For server scenarios this class can be declared as a bean in Spring
067 * configuration and that would detect {@code @MessageMapping} methods in
068 * {@code @Controller} beans. What beans are checked can be changed through a
069 * {@link #setHandlerPredicate(Predicate) handlerPredicate}. Given an instance
070 * of this class, you can then use {@link #responder()} to obtain a
071 * {@link SocketAcceptor} adapter to register with the
072 * {@link io.rsocket.core.RSocketServer}.
073 *
074 * <p>For a client, possibly in the same process as a server, consider
075 * consider using the static factory method
076 * {@link #responder(RSocketStrategies, Object...)} to obtain a client
077 * responder to be registered via
078 * {@link org.springframework.messaging.rsocket.RSocketRequester.Builder#rsocketConnector
079 * RSocketRequester.Builder}.
080 *
081 * <p>For {@code @MessageMapping} methods, this class automatically determines
082 * the RSocket interaction type based on the input and output cardinality of the
083 * method. See the
084 * <a href="https://docs.spring.io/spring/docs/current/spring-framework-reference/web-reactive.html#rsocket-annot-responders">
085 * "Annotated Responders"</a> section of the Spring Framework reference for more details.
086 *
087 * @author Rossen Stoyanchev
088 * @since 5.2
089 */
090public class RSocketMessageHandler extends MessageMappingMessageHandler {
091
092        private final List<Encoder<?>> encoders = new ArrayList<>();
093
094        private RSocketStrategies strategies = RSocketStrategies.create();
095
096        @Nullable
097        private MimeType defaultDataMimeType;
098
099        private MimeType defaultMetadataMimeType = MimeTypeUtils.parseMimeType(
100                        WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString());
101
102
103        public RSocketMessageHandler() {
104                setRSocketStrategies(this.strategies);
105        }
106
107
108        /**
109         * Configure the encoders to use for encoding handler method return values.
110         * <p>When {@link #setRSocketStrategies(RSocketStrategies) rsocketStrategies}
111         * is set, this property is re-initialized with the encoders in it, and
112         * likewise when this property is set the {@code RSocketStrategies} are
113         * mutated to change the encoders in it.
114         * <p>By default this is set to the
115         * {@linkplain org.springframework.messaging.rsocket.RSocketStrategies.Builder#encoder(Encoder[]) defaults}
116         * from {@code RSocketStrategies}.
117         */
118        public void setEncoders(List<? extends Encoder<?>> encoders) {
119                this.encoders.clear();
120                this.encoders.addAll(encoders);
121                this.strategies = this.strategies.mutate()
122                                .encoders(list -> {
123                                        list.clear();
124                                        list.addAll(encoders);
125                                })
126                                .build();
127        }
128
129        /**
130         * Return the configured {@link #setEncoders(List) encoders}.
131         */
132        public List<? extends Encoder<?>> getEncoders() {
133                return this.encoders;
134        }
135
136        /**
137         * {@inheritDoc}
138         * <p>When {@link #setRSocketStrategies(RSocketStrategies) rsocketStrategies}
139         * is set, this property is re-initialized with the decoders in it, and
140         * likewise when this property is set the {@code RSocketStrategies} are
141         * mutated to change the decoders in them.
142         * <p>By default this is set to the
143         * {@linkplain org.springframework.messaging.rsocket.RSocketStrategies.Builder#decoder(Decoder[]) defaults}
144         * from {@code RSocketStrategies}.
145         */
146        @Override
147        public void setDecoders(List<? extends Decoder<?>> decoders) {
148                super.setDecoders(decoders);
149                this.strategies = this.strategies.mutate()
150                                .decoders(list -> {
151                                        list.clear();
152                                        list.addAll(decoders);
153                                })
154                                .build();
155        }
156
157        /**
158         * {@inheritDoc}
159         * <p>When {@link #setRSocketStrategies(RSocketStrategies) rsocketStrategies}
160         * is set, this property is re-initialized with the route matcher in it, and
161         * likewise when this property is set the {@code RSocketStrategies} are
162         * mutated to change the matcher in it.
163         * <p>By default this is set to the
164         * {@linkplain org.springframework.messaging.rsocket.RSocketStrategies.Builder#routeMatcher(RouteMatcher) defaults}
165         * from {@code RSocketStrategies}.
166         */
167        @Override
168        public void setRouteMatcher(@Nullable RouteMatcher routeMatcher) {
169                super.setRouteMatcher(routeMatcher);
170                this.strategies = this.strategies.mutate().routeMatcher(routeMatcher).build();
171        }
172
173        /**
174         * Configure the registry for adapting various reactive types.
175         * <p>When {@link #setRSocketStrategies(RSocketStrategies) rsocketStrategies}
176         * is set, this property is re-initialized with the registry in it, and
177         * likewise when this property is set the {@code RSocketStrategies} are
178         * mutated to change the registry in it.
179         * <p>By default this is set to the
180         * {@link org.springframework.messaging.rsocket.RSocketStrategies.Builder#reactiveAdapterStrategy(ReactiveAdapterRegistry) defaults}
181         * from {@code RSocketStrategies}.
182         */
183        @Override
184        public void setReactiveAdapterRegistry(ReactiveAdapterRegistry registry) {
185                super.setReactiveAdapterRegistry(registry);
186                this.strategies = this.strategies.mutate().reactiveAdapterStrategy(registry).build();
187        }
188
189        /**
190         * Configure a {@link MetadataExtractor} to extract the route along with
191         * other metadata.
192         * <p>When {@link #setRSocketStrategies(RSocketStrategies) rsocketStrategies}
193         * is set, this property is re-initialized with the extractor in it, and
194         * likewise when this property is set the {@code RSocketStrategies} are
195         * mutated to change the extractor in it.
196         * <p>By default this is set to the
197         * {@link org.springframework.messaging.rsocket.RSocketStrategies.Builder#metadataExtractor(MetadataExtractor)} defaults}
198         * from {@code RSocketStrategies}.
199         * @param extractor the extractor to use
200         */
201        public void setMetadataExtractor(MetadataExtractor extractor) {
202                this.strategies = this.strategies.mutate().metadataExtractor(extractor).build();
203        }
204
205        /**
206         * Return the configured {@link #setMetadataExtractor MetadataExtractor}.
207         */
208        public MetadataExtractor getMetadataExtractor() {
209                return this.strategies.metadataExtractor();
210        }
211
212        /**
213         * Configure this handler through an {@link RSocketStrategies} instance which
214         * can be re-used to initialize a client-side {@link RSocketRequester}.
215         * <p>When this property is set, in turn it sets the following:
216         * <ul>
217         * <li>{@link #setDecoders(List)}
218         * <li>{@link #setEncoders(List)}
219         * <li>{@link #setRouteMatcher(RouteMatcher)}
220         * <li>{@link #setMetadataExtractor(MetadataExtractor)}
221         * <li>{@link #setReactiveAdapterRegistry(ReactiveAdapterRegistry)}
222         * </ul>
223         * <p>By default this is set to {@link RSocketStrategies#create()} which in
224         * turn sets default settings for all related properties.
225         */
226        public void setRSocketStrategies(RSocketStrategies rsocketStrategies) {
227                this.strategies = rsocketStrategies;
228                this.encoders.clear();
229                this.encoders.addAll(this.strategies.encoders());
230                super.setDecoders(this.strategies.decoders());
231                super.setRouteMatcher(this.strategies.routeMatcher());
232                super.setReactiveAdapterRegistry(this.strategies.reactiveAdapterRegistry());
233        }
234
235        /**
236         * Return the {@link #setRSocketStrategies configured} {@code RSocketStrategies}.
237         */
238        public RSocketStrategies getRSocketStrategies() {
239                return this.strategies;
240        }
241
242        /**
243         * Configure the default content type to use for data payloads if the
244         * {@code SETUP} frame did not specify one.
245         * <p>By default this is not set.
246         * @param mimeType the MimeType to use
247         */
248        public void setDefaultDataMimeType(@Nullable MimeType mimeType) {
249                this.defaultDataMimeType = mimeType;
250        }
251
252        /**
253         * Return the configured
254         * {@link #setDefaultDataMimeType defaultDataMimeType}, or {@code null}.
255         */
256        @Nullable
257        public MimeType getDefaultDataMimeType() {
258                return this.defaultDataMimeType;
259        }
260
261        /**
262         * Configure the default {@code MimeType} for payload data if the
263         * {@code SETUP} frame did not specify one.
264         * <p>By default this is set to {@code "message/x.rsocket.composite-metadata.v0"}
265         * @param mimeType the MimeType to use
266         */
267        public void setDefaultMetadataMimeType(MimeType mimeType) {
268                Assert.notNull(mimeType, "'metadataMimeType' is required");
269                this.defaultMetadataMimeType = mimeType;
270        }
271
272        /**
273         * Return the configured
274         * {@link #setDefaultMetadataMimeType defaultMetadataMimeType}.
275         */
276        public MimeType getDefaultMetadataMimeType() {
277                return this.defaultMetadataMimeType;
278        }
279
280
281        @Override
282        public void afterPropertiesSet() {
283
284                // Add argument resolver before parent initializes argument resolution
285                getArgumentResolverConfigurer().addCustomResolver(new RSocketRequesterMethodArgumentResolver());
286
287                super.afterPropertiesSet();
288
289                getHandlerMethods().forEach((composite, handler) -> {
290                        if (composite.getMessageConditions().contains(RSocketFrameTypeMessageCondition.CONNECT_CONDITION)) {
291                                MethodParameter returnType = handler.getReturnType();
292                                if (getCardinality(returnType) > 0) {
293                                        throw new IllegalStateException(
294                                                        "Invalid @ConnectMapping method. " +
295                                                                        "Return type must be void or a void async type: " + handler);
296                                }
297                        }
298                });
299        }
300
301        @Override
302        protected List<? extends HandlerMethodReturnValueHandler> initReturnValueHandlers() {
303                List<HandlerMethodReturnValueHandler> handlers = new ArrayList<>();
304                handlers.add(new RSocketPayloadReturnValueHandler(this.encoders, getReactiveAdapterRegistry()));
305                handlers.addAll(getReturnValueHandlerConfigurer().getCustomHandlers());
306                return handlers;
307        }
308
309
310        @Override
311        @Nullable
312        protected CompositeMessageCondition getCondition(AnnotatedElement element) {
313                MessageMapping ann1 = AnnotatedElementUtils.findMergedAnnotation(element, MessageMapping.class);
314                if (ann1 != null && ann1.value().length > 0) {
315                        return new CompositeMessageCondition(
316                                        RSocketFrameTypeMessageCondition.EMPTY_CONDITION,
317                                        new DestinationPatternsMessageCondition(processDestinations(ann1.value()), obtainRouteMatcher()));
318                }
319                ConnectMapping ann2 = AnnotatedElementUtils.findMergedAnnotation(element, ConnectMapping.class);
320                if (ann2 != null) {
321                        String[] patterns = processDestinations(ann2.value());
322                        return new CompositeMessageCondition(
323                                        RSocketFrameTypeMessageCondition.CONNECT_CONDITION,
324                                        new DestinationPatternsMessageCondition(patterns, obtainRouteMatcher()));
325                }
326                return null;
327        }
328
329        @Override
330        protected CompositeMessageCondition extendMapping(CompositeMessageCondition composite, HandlerMethod handler) {
331
332                List<MessageCondition<?>> conditions = composite.getMessageConditions();
333                Assert.isTrue(conditions.size() == 2 &&
334                                conditions.get(0) instanceof RSocketFrameTypeMessageCondition &&
335                                conditions.get(1) instanceof DestinationPatternsMessageCondition,
336                                "Unexpected message condition types");
337
338                if (conditions.get(0) != RSocketFrameTypeMessageCondition.EMPTY_CONDITION) {
339                        return composite;
340                }
341
342                int responseCardinality = getCardinality(handler.getReturnType());
343                int requestCardinality = 0;
344                for (MethodParameter parameter : handler.getMethodParameters()) {
345                        if (getArgumentResolvers().getArgumentResolver(parameter) instanceof PayloadMethodArgumentResolver) {
346                                requestCardinality = getCardinality(parameter);
347                        }
348                }
349
350                return new CompositeMessageCondition(
351                                RSocketFrameTypeMessageCondition.getCondition(requestCardinality, responseCardinality),
352                                conditions.get(1));
353        }
354
355        private int getCardinality(MethodParameter parameter) {
356                Class<?> clazz = parameter.getParameterType();
357                ReactiveAdapter adapter = getReactiveAdapterRegistry().getAdapter(clazz);
358                if (adapter == null) {
359                        return clazz.equals(void.class) ? 0 : 1;
360                }
361                else if (parameter.nested().getNestedParameterType().equals(Void.class)) {
362                        return 0;
363                }
364                else {
365                        return adapter.isMultiValue() ? 2 : 1;
366                }
367        }
368
369        @Override
370        protected void handleNoMatch(@Nullable RouteMatcher.Route destination, Message<?> message) {
371                FrameType frameType = RSocketFrameTypeMessageCondition.getFrameType(message);
372                if (frameType == FrameType.SETUP || frameType == FrameType.METADATA_PUSH) {
373                        return;  // optional handling
374                }
375                if (frameType == FrameType.REQUEST_FNF) {
376                        // Can't propagate error to client, so just log
377                        logger.warn("No handler for fireAndForget to '" + destination + "'");
378                        return;
379                }
380
381                Set<FrameType> frameTypes = getHandlerMethods().keySet().stream()
382                                .map(CompositeMessageCondition::getMessageConditions)
383                                .filter(conditions -> conditions.get(1).getMatchingCondition(message) != null)
384                                .map(conditions -> (RSocketFrameTypeMessageCondition) conditions.get(0))
385                                .flatMap(condition -> condition.getFrameTypes().stream())
386                                .collect(Collectors.toSet());
387
388                throw new MessageDeliveryException(frameTypes.isEmpty() ?
389                                "No handler for destination '" + destination + "'" :
390                                "Destination '" + destination + "' does not support " + frameType + ". " +
391                                                "Supported interaction(s): " + frameTypes);
392        }
393
394        /**
395         * Return an RSocket {@link SocketAcceptor} backed by this
396         * {@code RSocketMessageHandler} instance that can be plugged in as a
397         * {@link io.rsocket.core.RSocketConnector#acceptor(SocketAcceptor) client} or
398         * {@link io.rsocket.core.RSocketServer#acceptor(SocketAcceptor) server}
399         * RSocket responder.
400         * <p>The initial {@link ConnectionSetupPayload} is handled through
401         * {@link ConnectMapping @ConnectionMapping} methods that can be asynchronous
402         * and return {@code Mono<Void>} with an error signal preventing the
403         * connection. Such a method can also start requests to the client but that
404         * must be done decoupled from handling and from the current thread.
405         * <p>Subsequent requests on the connection can be handled with
406         * {@link MessageMapping MessageMapping} methods.
407         */
408        public SocketAcceptor responder() {
409                return (setupPayload, sendingRSocket) -> {
410                        MessagingRSocket responder;
411                        try {
412                                responder = createResponder(setupPayload, sendingRSocket);
413                        }
414                        catch (Throwable ex) {
415                                return Mono.error(ex);
416                        }
417                        return responder.handleConnectionSetupPayload(setupPayload).then(Mono.just(responder));
418                };
419        }
420
421        private MessagingRSocket createResponder(ConnectionSetupPayload setupPayload, RSocket rsocket) {
422                String str = setupPayload.dataMimeType();
423                MimeType dataMimeType = StringUtils.hasText(str) ? MimeTypeUtils.parseMimeType(str) : this.defaultDataMimeType;
424                Assert.notNull(dataMimeType, "No `dataMimeType` in ConnectionSetupPayload and no default value");
425                Assert.isTrue(isDataMimeTypeSupported(dataMimeType), "Data MimeType '" + dataMimeType + "' not supported");
426
427                str = setupPayload.metadataMimeType();
428                MimeType metaMimeType = StringUtils.hasText(str) ? MimeTypeUtils.parseMimeType(str) : this.defaultMetadataMimeType;
429                Assert.notNull(metaMimeType, "No `metadataMimeType` in ConnectionSetupPayload and no default value");
430
431                RSocketRequester requester = RSocketRequester.wrap(rsocket, dataMimeType, metaMimeType, this.strategies);
432                return new MessagingRSocket(dataMimeType, metaMimeType, getMetadataExtractor(),
433                                requester, this, obtainRouteMatcher(), this.strategies);
434        }
435
436        private boolean isDataMimeTypeSupported(MimeType dataMimeType) {
437                for (Encoder<?> encoder : getEncoders()) {
438                        for (MimeType encodable : encoder.getEncodableMimeTypes()) {
439                                if (encodable.isCompatibleWith(dataMimeType)) {
440                                        return true;
441                                }
442                        }
443                }
444                return false;
445        }
446
447        /**
448         * Static factory method to create an RSocket {@link SocketAcceptor}
449         * backed by handlers with annotated methods. Effectively a shortcut for:
450         * <pre class="code">
451         * RSocketMessageHandler handler = new RSocketMessageHandler();
452         * handler.setHandlers(handlers);
453         * handler.setRSocketStrategies(strategies);
454         * handler.afterPropertiesSet();
455         *
456         * SocketAcceptor acceptor = handler.responder();
457         * </pre>
458         * <p>This is intended for programmatic creation and registration of a
459         * client-side responder. For example:
460         * <pre class="code">
461         * SocketAcceptor responder =
462         *         RSocketMessageHandler.responder(strategies, new ClientHandler());
463         *
464         * RSocketRequester.builder()
465         *         .rsocketConnector(connector -> connector.acceptor(responder))
466         *         .connectTcp("localhost", server.address().getPort());
467         * </pre>
468         *
469         * <p>Note that the given handlers do not need to have any stereotype
470         * annotations such as {@code @Controller} which helps to avoid overlap with
471         * server side handlers that may be used in the same application. However,
472         * for more advanced scenarios, e.g. discovering handlers through a custom
473         * stereotype annotation, consider declaring {@code RSocketMessageHandler}
474         * as a bean, and then obtain the responder from it.
475         *
476         * @param strategies the strategies to set on the created
477         * {@code RSocketMessageHandler}
478         * @param candidateHandlers a list of Objects and/or Classes with annotated
479         * handler methods; used to call {@link #setHandlers(List)} with
480         * on the created {@code RSocketMessageHandler}
481         * @return a configurer that may be passed into
482         * {@link org.springframework.messaging.rsocket.RSocketRequester.Builder#rsocketConnector}
483         * @since 5.2.6
484         */
485        public static SocketAcceptor responder(RSocketStrategies strategies, Object... candidateHandlers) {
486                Assert.notEmpty(candidateHandlers, "No handlers");
487                List<Object> handlers = new ArrayList<>(candidateHandlers.length);
488                for (Object obj : candidateHandlers) {
489                        handlers.add(obj instanceof Class ? BeanUtils.instantiateClass((Class<?>) obj) : obj);
490                }
491                RSocketMessageHandler handler = new RSocketMessageHandler();
492                handler.setHandlers(handlers);
493                handler.setRSocketStrategies(strategies);
494                handler.afterPropertiesSet();
495                return handler.responder();
496        }
497
498        /**
499         * Static factory method for a configurer of a client side responder with
500         * annotated handler methods. This is intended to be passed into
501         * {@link org.springframework.messaging.rsocket.RSocketRequester.Builder#rsocketFactory}.
502         * <p>In effect a shortcut to create and initialize
503         * {@code RSocketMessageHandler} with the given strategies and handlers,
504         * use {@link #responder()} to obtain the responder, and plug that into
505         * {@link io.rsocket.RSocketFactory.ClientRSocketFactory ClientRSocketFactory}.
506         * For more advanced scenarios, e.g. discovering handlers through a custom
507         * stereotype annotation, consider declaring {@code RSocketMessageHandler}
508         * as a bean, and then obtain the responder from it.
509         * @param strategies the strategies to set on the created
510         * {@code RSocketMessageHandler}
511         * @param candidateHandlers a list of Objects and/or Classes with annotated
512         * handler methods; used to call {@link #setHandlers(List)} with
513         * on the created {@code RSocketMessageHandler}
514         * @return a configurer that may be passed into
515         * {@link org.springframework.messaging.rsocket.RSocketRequester.Builder#rsocketFactory}
516         * @deprecated as of 5.2.6  following the deprecation of
517         * {@link io.rsocket.RSocketFactory.ClientRSocketFactory RSocketFactory.ClientRSocketFactory}
518         * in RSocket 1.0 RC7.
519         */
520        @Deprecated
521        public static org.springframework.messaging.rsocket.ClientRSocketFactoryConfigurer clientResponder(
522                        RSocketStrategies strategies, Object... candidateHandlers) {
523
524                Assert.notEmpty(candidateHandlers, "No handlers");
525                List<Object> handlers = new ArrayList<>(candidateHandlers.length);
526                for (Object obj : candidateHandlers) {
527                        handlers.add(obj instanceof Class ? BeanUtils.instantiateClass((Class<?>) obj) : obj);
528                }
529
530                return factory -> {
531                        RSocketMessageHandler handler = new RSocketMessageHandler();
532                        handler.setHandlers(handlers);
533                        handler.setRSocketStrategies(strategies);
534                        handler.afterPropertiesSet();
535                        factory.acceptor(handler.responder());
536                };
537        }
538}
539
540