001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.mock.web;
018
019import java.io.IOException;
020import java.util.ArrayList;
021import java.util.List;
022
023import javax.servlet.AsyncContext;
024import javax.servlet.AsyncEvent;
025import javax.servlet.AsyncListener;
026import javax.servlet.ServletContext;
027import javax.servlet.ServletException;
028import javax.servlet.ServletRequest;
029import javax.servlet.ServletResponse;
030import javax.servlet.http.HttpServletRequest;
031import javax.servlet.http.HttpServletResponse;
032
033import org.springframework.beans.BeanUtils;
034import org.springframework.lang.Nullable;
035import org.springframework.util.Assert;
036import org.springframework.web.util.WebUtils;
037
038/**
039 * Mock implementation of the {@link AsyncContext} interface.
040 *
041 * @author Rossen Stoyanchev
042 * @since 3.2
043 */
044public class MockAsyncContext implements AsyncContext {
045
046        private final HttpServletRequest request;
047
048        @Nullable
049        private final HttpServletResponse response;
050
051        private final List<AsyncListener> listeners = new ArrayList<>();
052
053        @Nullable
054        private String dispatchedPath;
055
056        private long timeout = 10 * 1000L;
057
058        private final List<Runnable> dispatchHandlers = new ArrayList<>();
059
060
061        public MockAsyncContext(ServletRequest request, @Nullable ServletResponse response) {
062                this.request = (HttpServletRequest) request;
063                this.response = (HttpServletResponse) response;
064        }
065
066
067        public void addDispatchHandler(Runnable handler) {
068                Assert.notNull(handler, "Dispatch handler must not be null");
069                synchronized (this) {
070                        if (this.dispatchedPath == null) {
071                                this.dispatchHandlers.add(handler);
072                        }
073                        else {
074                                handler.run();
075                        }
076                }
077        }
078
079        @Override
080        public ServletRequest getRequest() {
081                return this.request;
082        }
083
084        @Override
085        @Nullable
086        public ServletResponse getResponse() {
087                return this.response;
088        }
089
090        @Override
091        public boolean hasOriginalRequestAndResponse() {
092                return (this.request instanceof MockHttpServletRequest && this.response instanceof MockHttpServletResponse);
093        }
094
095        @Override
096        public void dispatch() {
097                dispatch(this.request.getRequestURI());
098        }
099
100        @Override
101        public void dispatch(String path) {
102                dispatch(null, path);
103        }
104
105        @Override
106        public void dispatch(@Nullable ServletContext context, String path) {
107                synchronized (this) {
108                        this.dispatchedPath = path;
109                        this.dispatchHandlers.forEach(Runnable::run);
110                }
111        }
112
113        @Nullable
114        public String getDispatchedPath() {
115                return this.dispatchedPath;
116        }
117
118        @Override
119        public void complete() {
120                MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(this.request, MockHttpServletRequest.class);
121                if (mockRequest != null) {
122                        mockRequest.setAsyncStarted(false);
123                }
124                for (AsyncListener listener : this.listeners) {
125                        try {
126                                listener.onComplete(new AsyncEvent(this, this.request, this.response));
127                        }
128                        catch (IOException ex) {
129                                throw new IllegalStateException("AsyncListener failure", ex);
130                        }
131                }
132        }
133
134        @Override
135        public void start(Runnable runnable) {
136                runnable.run();
137        }
138
139        @Override
140        public void addListener(AsyncListener listener) {
141                this.listeners.add(listener);
142        }
143
144        @Override
145        public void addListener(AsyncListener listener, ServletRequest request, ServletResponse response) {
146                this.listeners.add(listener);
147        }
148
149        public List<AsyncListener> getListeners() {
150                return this.listeners;
151        }
152
153        @Override
154        public <T extends AsyncListener> T createListener(Class<T> clazz) throws ServletException {
155                return BeanUtils.instantiateClass(clazz);
156        }
157
158        /**
159         * By default this is set to 10000 (10 seconds) even though the Servlet API
160         * specifies a default async request timeout of 30 seconds. Keep in mind the
161         * timeout could further be impacted by global configuration through the MVC
162         * Java config or the XML namespace, as well as be overridden per request on
163         * {@link org.springframework.web.context.request.async.DeferredResult DeferredResult}
164         * or on
165         * {@link org.springframework.web.servlet.mvc.method.annotation.SseEmitter SseEmitter}.
166         * @param timeout the timeout value to use.
167         * @see AsyncContext#setTimeout(long)
168         */
169        @Override
170        public void setTimeout(long timeout) {
171                this.timeout = timeout;
172        }
173
174        @Override
175        public long getTimeout() {
176                return this.timeout;
177        }
178
179}