001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.rsocket.annotation.support;
018
019import java.util.List;
020
021import io.rsocket.Payload;
022import reactor.core.publisher.Flux;
023import reactor.core.publisher.Mono;
024import reactor.core.publisher.MonoProcessor;
025
026import org.springframework.core.MethodParameter;
027import org.springframework.core.ReactiveAdapterRegistry;
028import org.springframework.core.codec.Encoder;
029import org.springframework.core.io.buffer.DataBuffer;
030import org.springframework.lang.Nullable;
031import org.springframework.messaging.Message;
032import org.springframework.messaging.handler.invocation.reactive.AbstractEncoderMethodReturnValueHandler;
033import org.springframework.messaging.rsocket.PayloadUtils;
034import org.springframework.util.Assert;
035
036/**
037 * Extension of {@link AbstractEncoderMethodReturnValueHandler} that
038 * {@link #handleEncodedContent handles} encoded content by wrapping data buffers
039 * as RSocket payloads and by passing those to the {@link MonoProcessor}
040 * from the {@link #RESPONSE_HEADER} header.
041 *
042 * @author Rossen Stoyanchev
043 * @since 5.2
044 */
045public class RSocketPayloadReturnValueHandler extends AbstractEncoderMethodReturnValueHandler {
046
047        /**
048         * Message header name that is expected to have a {@link MonoProcessor}
049         * which will receive the {@code Flux<Payload>} that represents the response.
050         */
051        public static final String RESPONSE_HEADER = "rsocketResponse";
052
053
054        public RSocketPayloadReturnValueHandler(List<Encoder<?>> encoders, ReactiveAdapterRegistry registry) {
055                super(encoders, registry);
056        }
057
058
059        @Override
060        @SuppressWarnings("unchecked")
061        protected Mono<Void> handleEncodedContent(
062                        Flux<DataBuffer> encodedContent, MethodParameter returnType, Message<?> message) {
063
064                MonoProcessor<Flux<Payload>> replyMono = getReplyMono(message);
065                Assert.notNull(replyMono, "Missing '" + RESPONSE_HEADER + "'");
066                replyMono.onNext(encodedContent.map(PayloadUtils::createPayload));
067                replyMono.onComplete();
068                return Mono.empty();
069        }
070
071        @Override
072        protected Mono<Void> handleNoContent(MethodParameter returnType, Message<?> message) {
073                MonoProcessor<Flux<Payload>> replyMono = getReplyMono(message);
074                if (replyMono != null) {
075                        replyMono.onComplete();
076                }
077                return Mono.empty();
078        }
079
080        @Nullable
081        @SuppressWarnings("unchecked")
082        private MonoProcessor<Flux<Payload>> getReplyMono(Message<?> message) {
083                Object headerValue = message.getHeaders().get(RESPONSE_HEADER);
084                Assert.state(headerValue == null || headerValue instanceof MonoProcessor, "Expected MonoProcessor");
085                return (MonoProcessor<Flux<Payload>>) headerValue;
086        }
087
088}