001/*
002 * Copyright 2002-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.servlet.mvc.method.annotation;
018
019import java.io.InputStream;
020import java.io.Reader;
021import java.security.Principal;
022import java.time.ZoneId;
023import java.util.Locale;
024import java.util.TimeZone;
025import javax.servlet.ServletRequest;
026import javax.servlet.http.HttpServletRequest;
027import javax.servlet.http.HttpSession;
028
029import org.springframework.core.MethodParameter;
030import org.springframework.http.HttpMethod;
031import org.springframework.lang.UsesJava8;
032import org.springframework.web.bind.support.WebDataBinderFactory;
033import org.springframework.web.context.request.NativeWebRequest;
034import org.springframework.web.context.request.WebRequest;
035import org.springframework.web.method.support.HandlerMethodArgumentResolver;
036import org.springframework.web.method.support.ModelAndViewContainer;
037import org.springframework.web.multipart.MultipartRequest;
038import org.springframework.web.servlet.support.RequestContextUtils;
039
040/**
041 * Resolves request-related method argument values of the following types:
042 * <ul>
043 * <li>{@link WebRequest}
044 * <li>{@link ServletRequest}
045 * <li>{@link MultipartRequest}
046 * <li>{@link HttpSession}
047 * <li>{@link Principal}
048 * <li>{@link InputStream}
049 * <li>{@link Reader}
050 * <li>{@link HttpMethod} (as of Spring 4.0)
051 * <li>{@link Locale}
052 * <li>{@link TimeZone} (as of Spring 4.0)
053 * <li>{@link java.time.ZoneId} (as of Spring 4.0 and Java 8)
054 * </ul>
055 *
056 * @author Arjen Poutsma
057 * @author Rossen Stoyanchev
058 * @author Juergen Hoeller
059 * @since 3.1
060 */
061public class ServletRequestMethodArgumentResolver implements HandlerMethodArgumentResolver {
062
063        @Override
064        public boolean supportsParameter(MethodParameter parameter) {
065                Class<?> paramType = parameter.getParameterType();
066                return (WebRequest.class.isAssignableFrom(paramType) ||
067                                ServletRequest.class.isAssignableFrom(paramType) ||
068                                MultipartRequest.class.isAssignableFrom(paramType) ||
069                                HttpSession.class.isAssignableFrom(paramType) ||
070                                Principal.class.isAssignableFrom(paramType) ||
071                                InputStream.class.isAssignableFrom(paramType) ||
072                                Reader.class.isAssignableFrom(paramType) ||
073                                HttpMethod.class == paramType ||
074                                Locale.class == paramType ||
075                                TimeZone.class == paramType ||
076                                "java.time.ZoneId".equals(paramType.getName()));
077        }
078
079        @Override
080        public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer mavContainer,
081                        NativeWebRequest webRequest, WebDataBinderFactory binderFactory) throws Exception {
082
083                Class<?> paramType = parameter.getParameterType();
084                if (WebRequest.class.isAssignableFrom(paramType)) {
085                        if (!paramType.isInstance(webRequest)) {
086                                throw new IllegalStateException(
087                                                "Current request is not of type [" + paramType.getName() + "]: " + webRequest);
088                        }
089                        return webRequest;
090                }
091
092                HttpServletRequest request = webRequest.getNativeRequest(HttpServletRequest.class);
093                if (ServletRequest.class.isAssignableFrom(paramType) || MultipartRequest.class.isAssignableFrom(paramType)) {
094                        Object nativeRequest = webRequest.getNativeRequest(paramType);
095                        if (nativeRequest == null) {
096                                throw new IllegalStateException(
097                                                "Current request is not of type [" + paramType.getName() + "]: " + request);
098                        }
099                        return nativeRequest;
100                }
101                else if (HttpSession.class.isAssignableFrom(paramType)) {
102                        HttpSession session = request.getSession();
103                        if (session != null && !paramType.isInstance(session)) {
104                                throw new IllegalStateException(
105                                                "Current session is not of type [" + paramType.getName() + "]: " + session);
106                        }
107                        return session;
108                }
109                else if (InputStream.class.isAssignableFrom(paramType)) {
110                        InputStream inputStream = request.getInputStream();
111                        if (inputStream != null && !paramType.isInstance(inputStream)) {
112                                throw new IllegalStateException(
113                                                "Request input stream is not of type [" + paramType.getName() + "]: " + inputStream);
114                        }
115                        return inputStream;
116                }
117                else if (Reader.class.isAssignableFrom(paramType)) {
118                        Reader reader = request.getReader();
119                        if (reader != null && !paramType.isInstance(reader)) {
120                                throw new IllegalStateException(
121                                                "Request body reader is not of type [" + paramType.getName() + "]: " + reader);
122                        }
123                        return reader;
124                }
125                else if (Principal.class.isAssignableFrom(paramType)) {
126                        Principal userPrincipal = request.getUserPrincipal();
127                        if (userPrincipal != null && !paramType.isInstance(userPrincipal)) {
128                                throw new IllegalStateException(
129                                                "Current user principal is not of type [" + paramType.getName() + "]: " + userPrincipal);
130                        }
131                        return userPrincipal;
132                }
133                else if (HttpMethod.class == paramType) {
134                        return HttpMethod.resolve(request.getMethod());
135                }
136                else if (Locale.class == paramType) {
137                        return RequestContextUtils.getLocale(request);
138                }
139                else if (TimeZone.class == paramType) {
140                        TimeZone timeZone = RequestContextUtils.getTimeZone(request);
141                        return (timeZone != null ? timeZone : TimeZone.getDefault());
142                }
143                else if ("java.time.ZoneId".equals(paramType.getName())) {
144                        return ZoneIdResolver.resolveZoneId(request);
145                }
146                else {
147                        // Should never happen...
148                        throw new UnsupportedOperationException(
149                                        "Unknown parameter type [" + paramType.getName() + "] in " + parameter.getMethod());
150                }
151        }
152
153
154        /**
155         * Inner class to avoid a hard-coded dependency on Java 8's {@link java.time.ZoneId}.
156         */
157        @UsesJava8
158        private static class ZoneIdResolver {
159
160                public static Object resolveZoneId(HttpServletRequest request) {
161                        TimeZone timeZone = RequestContextUtils.getTimeZone(request);
162                        return (timeZone != null ? timeZone.toZoneId() : ZoneId.systemDefault());
163                }
164        }
165
166}