001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.mock.web;
018
019import java.io.File;
020import java.io.IOException;
021import java.io.InputStream;
022import java.net.MalformedURLException;
023import java.net.URL;
024import java.nio.file.InvalidPathException;
025import java.util.Collections;
026import java.util.Enumeration;
027import java.util.EventListener;
028import java.util.HashMap;
029import java.util.LinkedHashMap;
030import java.util.LinkedHashSet;
031import java.util.Map;
032import java.util.Set;
033
034import javax.servlet.Filter;
035import javax.servlet.FilterRegistration;
036import javax.servlet.RequestDispatcher;
037import javax.servlet.Servlet;
038import javax.servlet.ServletContext;
039import javax.servlet.ServletException;
040import javax.servlet.ServletRegistration;
041import javax.servlet.SessionCookieConfig;
042import javax.servlet.SessionTrackingMode;
043import javax.servlet.descriptor.JspConfigDescriptor;
044
045import org.apache.commons.logging.Log;
046import org.apache.commons.logging.LogFactory;
047
048import org.springframework.core.io.DefaultResourceLoader;
049import org.springframework.core.io.Resource;
050import org.springframework.core.io.ResourceLoader;
051import org.springframework.http.MediaType;
052import org.springframework.http.MediaTypeFactory;
053import org.springframework.lang.Nullable;
054import org.springframework.util.Assert;
055import org.springframework.util.ClassUtils;
056import org.springframework.util.MimeType;
057import org.springframework.util.ObjectUtils;
058import org.springframework.util.StringUtils;
059import org.springframework.web.util.WebUtils;
060
061/**
062 * Mock implementation of the {@link javax.servlet.ServletContext} interface.
063 *
064 * <p>As of Spring 5.0, this set of mocks is designed on a Servlet 4.0 baseline.
065 *
066 * <p>Compatible with Servlet 3.1 but can be configured to expose a specific version
067 * through {@link #setMajorVersion}/{@link #setMinorVersion}; default is 3.1.
068 * Note that Servlet 3.1 support is limited: servlet, filter and listener
069 * registration methods are not supported; neither is JSP configuration.
070 * We generally do not recommend to unit test your ServletContainerInitializers and
071 * WebApplicationInitializers which is where those registration methods would be used.
072 *
073 * <p>For setting up a full {@code WebApplicationContext} in a test environment, you can
074 * use {@code AnnotationConfigWebApplicationContext}, {@code XmlWebApplicationContext},
075 * or {@code GenericWebApplicationContext}, passing in a corresponding
076 * {@code MockServletContext} instance. Consider configuring your
077 * {@code MockServletContext} with a {@code FileSystemResourceLoader} in order to
078 * interpret resource paths as relative filesystem locations.
079 *
080 * @author Rod Johnson
081 * @author Juergen Hoeller
082 * @author Sam Brannen
083 * @since 1.0.2
084 * @see #MockServletContext(org.springframework.core.io.ResourceLoader)
085 * @see org.springframework.web.context.support.AnnotationConfigWebApplicationContext
086 * @see org.springframework.web.context.support.XmlWebApplicationContext
087 * @see org.springframework.web.context.support.GenericWebApplicationContext
088 */
089public class MockServletContext implements ServletContext {
090
091        /** Default Servlet name used by Tomcat, Jetty, JBoss, and GlassFish: {@value}. */
092        private static final String COMMON_DEFAULT_SERVLET_NAME = "default";
093
094        private static final String TEMP_DIR_SYSTEM_PROPERTY = "java.io.tmpdir";
095
096        private static final Set<SessionTrackingMode> DEFAULT_SESSION_TRACKING_MODES = new LinkedHashSet<>(4);
097
098        static {
099                DEFAULT_SESSION_TRACKING_MODES.add(SessionTrackingMode.COOKIE);
100                DEFAULT_SESSION_TRACKING_MODES.add(SessionTrackingMode.URL);
101                DEFAULT_SESSION_TRACKING_MODES.add(SessionTrackingMode.SSL);
102        }
103
104
105        private final Log logger = LogFactory.getLog(getClass());
106
107        private final ResourceLoader resourceLoader;
108
109        private final String resourceBasePath;
110
111        private String contextPath = "";
112
113        private final Map<String, ServletContext> contexts = new HashMap<>();
114
115        private int majorVersion = 3;
116
117        private int minorVersion = 1;
118
119        private int effectiveMajorVersion = 3;
120
121        private int effectiveMinorVersion = 1;
122
123        private final Map<String, RequestDispatcher> namedRequestDispatchers = new HashMap<>();
124
125        private String defaultServletName = COMMON_DEFAULT_SERVLET_NAME;
126
127        private final Map<String, String> initParameters = new LinkedHashMap<>();
128
129        private final Map<String, Object> attributes = new LinkedHashMap<>();
130
131        private String servletContextName = "MockServletContext";
132
133        private final Set<String> declaredRoles = new LinkedHashSet<>();
134
135        @Nullable
136        private Set<SessionTrackingMode> sessionTrackingModes;
137
138        private final SessionCookieConfig sessionCookieConfig = new MockSessionCookieConfig();
139
140        private int sessionTimeout;
141
142        @Nullable
143        private String requestCharacterEncoding;
144
145        @Nullable
146        private String responseCharacterEncoding;
147
148        private final Map<String, MediaType> mimeTypes = new LinkedHashMap<>();
149
150
151        /**
152         * Create a new {@code MockServletContext}, using no base path and a
153         * {@link DefaultResourceLoader} (i.e. the classpath root as WAR root).
154         * @see org.springframework.core.io.DefaultResourceLoader
155         */
156        public MockServletContext() {
157                this("", null);
158        }
159
160        /**
161         * Create a new {@code MockServletContext}, using a {@link DefaultResourceLoader}.
162         * @param resourceBasePath the root directory of the WAR (should not end with a slash)
163         * @see org.springframework.core.io.DefaultResourceLoader
164         */
165        public MockServletContext(String resourceBasePath) {
166                this(resourceBasePath, null);
167        }
168
169        /**
170         * Create a new {@code MockServletContext}, using the specified {@link ResourceLoader}
171         * and no base path.
172         * @param resourceLoader the ResourceLoader to use (or null for the default)
173         */
174        public MockServletContext(@Nullable ResourceLoader resourceLoader) {
175                this("", resourceLoader);
176        }
177
178        /**
179         * Create a new {@code MockServletContext} using the supplied resource base
180         * path and resource loader.
181         * <p>Registers a {@link MockRequestDispatcher} for the Servlet named
182         * {@literal 'default'}.
183         * @param resourceBasePath the root directory of the WAR (should not end with a slash)
184         * @param resourceLoader the ResourceLoader to use (or null for the default)
185         * @see #registerNamedDispatcher
186         */
187        public MockServletContext(String resourceBasePath, @Nullable ResourceLoader resourceLoader) {
188                this.resourceLoader = (resourceLoader != null ? resourceLoader : new DefaultResourceLoader());
189                this.resourceBasePath = resourceBasePath;
190
191                // Use JVM temp dir as ServletContext temp dir.
192                String tempDir = System.getProperty(TEMP_DIR_SYSTEM_PROPERTY);
193                if (tempDir != null) {
194                        this.attributes.put(WebUtils.TEMP_DIR_CONTEXT_ATTRIBUTE, new File(tempDir));
195                }
196
197                registerNamedDispatcher(this.defaultServletName, new MockRequestDispatcher(this.defaultServletName));
198        }
199
200        /**
201         * Build a full resource location for the given path, prepending the resource
202         * base path of this {@code MockServletContext}.
203         * @param path the path as specified
204         * @return the full resource path
205         */
206        protected String getResourceLocation(String path) {
207                if (!path.startsWith("/")) {
208                        path = "/" + path;
209                }
210                return this.resourceBasePath + path;
211        }
212
213        public void setContextPath(String contextPath) {
214                this.contextPath = contextPath;
215        }
216
217        @Override
218        public String getContextPath() {
219                return this.contextPath;
220        }
221
222        public void registerContext(String contextPath, ServletContext context) {
223                this.contexts.put(contextPath, context);
224        }
225
226        @Override
227        public ServletContext getContext(String contextPath) {
228                if (this.contextPath.equals(contextPath)) {
229                        return this;
230                }
231                return this.contexts.get(contextPath);
232        }
233
234        public void setMajorVersion(int majorVersion) {
235                this.majorVersion = majorVersion;
236        }
237
238        @Override
239        public int getMajorVersion() {
240                return this.majorVersion;
241        }
242
243        public void setMinorVersion(int minorVersion) {
244                this.minorVersion = minorVersion;
245        }
246
247        @Override
248        public int getMinorVersion() {
249                return this.minorVersion;
250        }
251
252        public void setEffectiveMajorVersion(int effectiveMajorVersion) {
253                this.effectiveMajorVersion = effectiveMajorVersion;
254        }
255
256        @Override
257        public int getEffectiveMajorVersion() {
258                return this.effectiveMajorVersion;
259        }
260
261        public void setEffectiveMinorVersion(int effectiveMinorVersion) {
262                this.effectiveMinorVersion = effectiveMinorVersion;
263        }
264
265        @Override
266        public int getEffectiveMinorVersion() {
267                return this.effectiveMinorVersion;
268        }
269
270        @Override
271        @Nullable
272        public String getMimeType(String filePath) {
273                String extension = StringUtils.getFilenameExtension(filePath);
274                if (this.mimeTypes.containsKey(extension)) {
275                        return this.mimeTypes.get(extension).toString();
276                }
277                else {
278                        return MediaTypeFactory.getMediaType(filePath).
279                                        map(MimeType::toString)
280                                        .orElse(null);
281                }
282        }
283
284        /**
285         * Adds a mime type mapping for use by {@link #getMimeType(String)}.
286         * @param fileExtension a file extension, such as {@code txt}, {@code gif}
287         * @param mimeType the mime type
288         */
289        public void addMimeType(String fileExtension, MediaType mimeType) {
290                Assert.notNull(fileExtension, "'fileExtension' must not be null");
291                this.mimeTypes.put(fileExtension, mimeType);
292        }
293
294        @Override
295        @Nullable
296        public Set<String> getResourcePaths(String path) {
297                String actualPath = (path.endsWith("/") ? path : path + "/");
298                String resourceLocation = getResourceLocation(actualPath);
299                Resource resource = null;
300                try {
301                        resource = this.resourceLoader.getResource(resourceLocation);
302                        File file = resource.getFile();
303                        String[] fileList = file.list();
304                        if (ObjectUtils.isEmpty(fileList)) {
305                                return null;
306                        }
307                        Set<String> resourcePaths = new LinkedHashSet<>(fileList.length);
308                        for (String fileEntry : fileList) {
309                                String resultPath = actualPath + fileEntry;
310                                if (resource.createRelative(fileEntry).getFile().isDirectory()) {
311                                        resultPath += "/";
312                                }
313                                resourcePaths.add(resultPath);
314                        }
315                        return resourcePaths;
316                }
317                catch (InvalidPathException | IOException ex ) {
318                        if (logger.isWarnEnabled()) {
319                                logger.warn("Could not get resource paths for " +
320                                                (resource != null ? resource : resourceLocation), ex);
321                        }
322                        return null;
323                }
324        }
325
326        @Override
327        @Nullable
328        public URL getResource(String path) throws MalformedURLException {
329                String resourceLocation = getResourceLocation(path);
330                Resource resource = null;
331                try {
332                        resource = this.resourceLoader.getResource(resourceLocation);
333                        if (!resource.exists()) {
334                                return null;
335                        }
336                        return resource.getURL();
337                }
338                catch (MalformedURLException ex) {
339                        throw ex;
340                }
341                catch (InvalidPathException | IOException ex) {
342                        if (logger.isWarnEnabled()) {
343                                logger.warn("Could not get URL for resource " +
344                                                (resource != null ? resource : resourceLocation), ex);
345                        }
346                        return null;
347                }
348        }
349
350        @Override
351        @Nullable
352        public InputStream getResourceAsStream(String path) {
353                String resourceLocation = getResourceLocation(path);
354                Resource resource = null;
355                try {
356                        resource = this.resourceLoader.getResource(resourceLocation);
357                        if (!resource.exists()) {
358                                return null;
359                        }
360                        return resource.getInputStream();
361                }
362                catch (InvalidPathException | IOException ex) {
363                        if (logger.isWarnEnabled()) {
364                                logger.warn("Could not open InputStream for resource " +
365                                                (resource != null ? resource : resourceLocation), ex);
366                        }
367                        return null;
368                }
369        }
370
371        @Override
372        public RequestDispatcher getRequestDispatcher(String path) {
373                Assert.isTrue(path.startsWith("/"),
374                                () -> "RequestDispatcher path [" + path + "] at ServletContext level must start with '/'");
375                return new MockRequestDispatcher(path);
376        }
377
378        @Override
379        public RequestDispatcher getNamedDispatcher(String path) {
380                return this.namedRequestDispatchers.get(path);
381        }
382
383        /**
384         * Register a {@link RequestDispatcher} (typically a {@link MockRequestDispatcher})
385         * that acts as a wrapper for the named Servlet.
386         * @param name the name of the wrapped Servlet
387         * @param requestDispatcher the dispatcher that wraps the named Servlet
388         * @see #getNamedDispatcher
389         * @see #unregisterNamedDispatcher
390         */
391        public void registerNamedDispatcher(String name, RequestDispatcher requestDispatcher) {
392                Assert.notNull(name, "RequestDispatcher name must not be null");
393                Assert.notNull(requestDispatcher, "RequestDispatcher must not be null");
394                this.namedRequestDispatchers.put(name, requestDispatcher);
395        }
396
397        /**
398         * Unregister the {@link RequestDispatcher} with the given name.
399         * @param name the name of the dispatcher to unregister
400         * @see #getNamedDispatcher
401         * @see #registerNamedDispatcher
402         */
403        public void unregisterNamedDispatcher(String name) {
404                Assert.notNull(name, "RequestDispatcher name must not be null");
405                this.namedRequestDispatchers.remove(name);
406        }
407
408        /**
409         * Get the name of the <em>default</em> {@code Servlet}.
410         * <p>Defaults to {@literal 'default'}.
411         * @see #setDefaultServletName
412         */
413        public String getDefaultServletName() {
414                return this.defaultServletName;
415        }
416
417        /**
418         * Set the name of the <em>default</em> {@code Servlet}.
419         * <p>Also {@link #unregisterNamedDispatcher unregisters} the current default
420         * {@link RequestDispatcher} and {@link #registerNamedDispatcher replaces}
421         * it with a {@link MockRequestDispatcher} for the provided
422         * {@code defaultServletName}.
423         * @param defaultServletName the name of the <em>default</em> {@code Servlet};
424         * never {@code null} or empty
425         * @see #getDefaultServletName
426         */
427        public void setDefaultServletName(String defaultServletName) {
428                Assert.hasText(defaultServletName, "defaultServletName must not be null or empty");
429                unregisterNamedDispatcher(this.defaultServletName);
430                this.defaultServletName = defaultServletName;
431                registerNamedDispatcher(this.defaultServletName, new MockRequestDispatcher(this.defaultServletName));
432        }
433
434        @Deprecated
435        @Override
436        @Nullable
437        public Servlet getServlet(String name) {
438                return null;
439        }
440
441        @Override
442        @Deprecated
443        public Enumeration<Servlet> getServlets() {
444                return Collections.enumeration(Collections.emptySet());
445        }
446
447        @Override
448        @Deprecated
449        public Enumeration<String> getServletNames() {
450                return Collections.enumeration(Collections.emptySet());
451        }
452
453        @Override
454        public void log(String message) {
455                logger.info(message);
456        }
457
458        @Override
459        @Deprecated
460        public void log(Exception ex, String message) {
461                logger.info(message, ex);
462        }
463
464        @Override
465        public void log(String message, Throwable ex) {
466                logger.info(message, ex);
467        }
468
469        @Override
470        @Nullable
471        public String getRealPath(String path) {
472                String resourceLocation = getResourceLocation(path);
473                Resource resource = null;
474                try {
475                        resource = this.resourceLoader.getResource(resourceLocation);
476                        return resource.getFile().getAbsolutePath();
477                }
478                catch (InvalidPathException | IOException ex) {
479                        if (logger.isWarnEnabled()) {
480                                logger.warn("Could not determine real path of resource " +
481                                                (resource != null ? resource : resourceLocation), ex);
482                        }
483                        return null;
484                }
485        }
486
487        @Override
488        public String getServerInfo() {
489                return "MockServletContext";
490        }
491
492        @Override
493        public String getInitParameter(String name) {
494                Assert.notNull(name, "Parameter name must not be null");
495                return this.initParameters.get(name);
496        }
497
498        @Override
499        public Enumeration<String> getInitParameterNames() {
500                return Collections.enumeration(this.initParameters.keySet());
501        }
502
503        @Override
504        public boolean setInitParameter(String name, String value) {
505                Assert.notNull(name, "Parameter name must not be null");
506                if (this.initParameters.containsKey(name)) {
507                        return false;
508                }
509                this.initParameters.put(name, value);
510                return true;
511        }
512
513        public void addInitParameter(String name, String value) {
514                Assert.notNull(name, "Parameter name must not be null");
515                this.initParameters.put(name, value);
516        }
517
518        @Override
519        @Nullable
520        public Object getAttribute(String name) {
521                Assert.notNull(name, "Attribute name must not be null");
522                return this.attributes.get(name);
523        }
524
525        @Override
526        public Enumeration<String> getAttributeNames() {
527                return Collections.enumeration(new LinkedHashSet<>(this.attributes.keySet()));
528        }
529
530        @Override
531        public void setAttribute(String name, @Nullable Object value) {
532                Assert.notNull(name, "Attribute name must not be null");
533                if (value != null) {
534                        this.attributes.put(name, value);
535                }
536                else {
537                        this.attributes.remove(name);
538                }
539        }
540
541        @Override
542        public void removeAttribute(String name) {
543                Assert.notNull(name, "Attribute name must not be null");
544                this.attributes.remove(name);
545        }
546
547        public void setServletContextName(String servletContextName) {
548                this.servletContextName = servletContextName;
549        }
550
551        @Override
552        public String getServletContextName() {
553                return this.servletContextName;
554        }
555
556        @Override
557        @Nullable
558        public ClassLoader getClassLoader() {
559                return ClassUtils.getDefaultClassLoader();
560        }
561
562        @Override
563        public void declareRoles(String... roleNames) {
564                Assert.notNull(roleNames, "Role names array must not be null");
565                for (String roleName : roleNames) {
566                        Assert.hasLength(roleName, "Role name must not be empty");
567                        this.declaredRoles.add(roleName);
568                }
569        }
570
571        public Set<String> getDeclaredRoles() {
572                return Collections.unmodifiableSet(this.declaredRoles);
573        }
574
575        @Override
576        public void setSessionTrackingModes(Set<SessionTrackingMode> sessionTrackingModes)
577                        throws IllegalStateException, IllegalArgumentException {
578                this.sessionTrackingModes = sessionTrackingModes;
579        }
580
581        @Override
582        public Set<SessionTrackingMode> getDefaultSessionTrackingModes() {
583                return DEFAULT_SESSION_TRACKING_MODES;
584        }
585
586        @Override
587        public Set<SessionTrackingMode> getEffectiveSessionTrackingModes() {
588                return (this.sessionTrackingModes != null ?
589                                Collections.unmodifiableSet(this.sessionTrackingModes) : DEFAULT_SESSION_TRACKING_MODES);
590        }
591
592        @Override
593        public SessionCookieConfig getSessionCookieConfig() {
594                return this.sessionCookieConfig;
595        }
596
597        @Override  // on Servlet 4.0
598        public void setSessionTimeout(int sessionTimeout) {
599                this.sessionTimeout = sessionTimeout;
600        }
601
602        @Override  // on Servlet 4.0
603        public int getSessionTimeout() {
604                return this.sessionTimeout;
605        }
606
607        @Override  // on Servlet 4.0
608        public void setRequestCharacterEncoding(@Nullable String requestCharacterEncoding) {
609                this.requestCharacterEncoding = requestCharacterEncoding;
610        }
611
612        @Override  // on Servlet 4.0
613        @Nullable
614        public String getRequestCharacterEncoding() {
615                return this.requestCharacterEncoding;
616        }
617
618        @Override  // on Servlet 4.0
619        public void setResponseCharacterEncoding(@Nullable String responseCharacterEncoding) {
620                this.responseCharacterEncoding = responseCharacterEncoding;
621        }
622
623        @Override  // on Servlet 4.0
624        @Nullable
625        public String getResponseCharacterEncoding() {
626                return this.responseCharacterEncoding;
627        }
628
629
630        //---------------------------------------------------------------------
631        // Unsupported Servlet 3.0 registration methods
632        //---------------------------------------------------------------------
633
634        @Override
635        public JspConfigDescriptor getJspConfigDescriptor() {
636                throw new UnsupportedOperationException();
637        }
638
639        @Override  // on Servlet 4.0
640        public ServletRegistration.Dynamic addJspFile(String servletName, String jspFile) {
641                throw new UnsupportedOperationException();
642        }
643
644        @Override
645        public ServletRegistration.Dynamic addServlet(String servletName, String className) {
646                throw new UnsupportedOperationException();
647        }
648
649        @Override
650        public ServletRegistration.Dynamic addServlet(String servletName, Servlet servlet) {
651                throw new UnsupportedOperationException();
652        }
653
654        @Override
655        public ServletRegistration.Dynamic addServlet(String servletName, Class<? extends Servlet> servletClass) {
656                throw new UnsupportedOperationException();
657        }
658
659        @Override
660        public <T extends Servlet> T createServlet(Class<T> c) throws ServletException {
661                throw new UnsupportedOperationException();
662        }
663
664        /**
665         * This method always returns {@code null}.
666         * @see javax.servlet.ServletContext#getServletRegistration(java.lang.String)
667         */
668        @Override
669        @Nullable
670        public ServletRegistration getServletRegistration(String servletName) {
671                return null;
672        }
673
674        /**
675         * This method always returns an {@linkplain Collections#emptyMap empty map}.
676         * @see javax.servlet.ServletContext#getServletRegistrations()
677         */
678        @Override
679        public Map<String, ? extends ServletRegistration> getServletRegistrations() {
680                return Collections.emptyMap();
681        }
682
683        @Override
684        public FilterRegistration.Dynamic addFilter(String filterName, String className) {
685                throw new UnsupportedOperationException();
686        }
687
688        @Override
689        public FilterRegistration.Dynamic addFilter(String filterName, Filter filter) {
690                throw new UnsupportedOperationException();
691        }
692
693        @Override
694        public FilterRegistration.Dynamic addFilter(String filterName, Class<? extends Filter> filterClass) {
695                throw new UnsupportedOperationException();
696        }
697
698        @Override
699        public <T extends Filter> T createFilter(Class<T> c) throws ServletException {
700                throw new UnsupportedOperationException();
701        }
702
703        /**
704         * This method always returns {@code null}.
705         * @see javax.servlet.ServletContext#getFilterRegistration(java.lang.String)
706         */
707        @Override
708        @Nullable
709        public FilterRegistration getFilterRegistration(String filterName) {
710                return null;
711        }
712
713        /**
714         * This method always returns an {@linkplain Collections#emptyMap empty map}.
715         * @see javax.servlet.ServletContext#getFilterRegistrations()
716         */
717        @Override
718        public Map<String, ? extends FilterRegistration> getFilterRegistrations() {
719                return Collections.emptyMap();
720        }
721
722        @Override
723        public void addListener(Class<? extends EventListener> listenerClass) {
724                throw new UnsupportedOperationException();
725        }
726
727        @Override
728        public void addListener(String className) {
729                throw new UnsupportedOperationException();
730        }
731
732        @Override
733        public <T extends EventListener> void addListener(T t) {
734                throw new UnsupportedOperationException();
735        }
736
737        @Override
738        public <T extends EventListener> T createListener(Class<T> c) throws ServletException {
739                throw new UnsupportedOperationException();
740        }
741
742        @Override
743        public String getVirtualServerName() {
744                throw new UnsupportedOperationException();
745        }
746
747}