001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.mock.web;
018
019import java.io.IOException;
020import java.util.ArrayList;
021import java.util.List;
022import javax.servlet.AsyncContext;
023import javax.servlet.AsyncEvent;
024import javax.servlet.AsyncListener;
025import javax.servlet.ServletContext;
026import javax.servlet.ServletException;
027import javax.servlet.ServletRequest;
028import javax.servlet.ServletResponse;
029import javax.servlet.http.HttpServletRequest;
030import javax.servlet.http.HttpServletResponse;
031
032import org.springframework.beans.BeanUtils;
033import org.springframework.util.Assert;
034import org.springframework.web.util.WebUtils;
035
036/**
037 * Mock implementation of the {@link AsyncContext} interface.
038 *
039 * @author Rossen Stoyanchev
040 * @since 3.2
041 */
042public class MockAsyncContext implements AsyncContext {
043
044        private final HttpServletRequest request;
045
046        private final HttpServletResponse response;
047
048        private final List<AsyncListener> listeners = new ArrayList<AsyncListener>();
049
050        private String dispatchedPath;
051
052        private long timeout = 10 * 1000L;      // 10 seconds is Tomcat's default
053
054        private final List<Runnable> dispatchHandlers = new ArrayList<Runnable>();
055
056
057        public MockAsyncContext(ServletRequest request, ServletResponse response) {
058                this.request = (HttpServletRequest) request;
059                this.response = (HttpServletResponse) response;
060        }
061
062
063        public void addDispatchHandler(Runnable handler) {
064                Assert.notNull(handler, "Dispatch handler must not be null");
065                synchronized (this) {
066                        if (this.dispatchedPath == null) {
067                                this.dispatchHandlers.add(handler);
068                        }
069                        else {
070                                handler.run();
071                        }
072                }
073        }
074
075        @Override
076        public ServletRequest getRequest() {
077                return this.request;
078        }
079
080        @Override
081        public ServletResponse getResponse() {
082                return this.response;
083        }
084
085        @Override
086        public boolean hasOriginalRequestAndResponse() {
087                return (this.request instanceof MockHttpServletRequest && this.response instanceof MockHttpServletResponse);
088        }
089
090        @Override
091        public void dispatch() {
092                dispatch(this.request.getRequestURI());
093        }
094
095        @Override
096        public void dispatch(String path) {
097                dispatch(null, path);
098        }
099
100        @Override
101        public void dispatch(ServletContext context, String path) {
102                synchronized (this) {
103                        this.dispatchedPath = path;
104                        for (Runnable r : this.dispatchHandlers) {
105                                r.run();
106                        }
107                }
108        }
109
110        public String getDispatchedPath() {
111                return this.dispatchedPath;
112        }
113
114        @Override
115        public void complete() {
116                MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class);
117                if (mockRequest != null) {
118                        mockRequest.setAsyncStarted(false);
119                }
120                for (AsyncListener listener : this.listeners) {
121                        try {
122                                listener.onComplete(new AsyncEvent(this, this.request, this.response));
123                        }
124                        catch (IOException ex) {
125                                throw new IllegalStateException("AsyncListener failure", ex);
126                        }
127                }
128        }
129
130        @Override
131        public void start(Runnable runnable) {
132                runnable.run();
133        }
134
135        @Override
136        public void addListener(AsyncListener listener) {
137                this.listeners.add(listener);
138        }
139
140        @Override
141        public void addListener(AsyncListener listener, ServletRequest request, ServletResponse response) {
142                this.listeners.add(listener);
143        }
144
145        public List<AsyncListener> getListeners() {
146                return this.listeners;
147        }
148
149        @Override
150        public <T extends AsyncListener> T createListener(Class<T> clazz) throws ServletException {
151                return BeanUtils.instantiateClass(clazz);
152        }
153
154        @Override
155        public void setTimeout(long timeout) {
156                this.timeout = timeout;
157        }
158
159        @Override
160        public long getTimeout() {
161                return this.timeout;
162        }
163
164}