001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.servlet.support;
018
019import java.util.Collections;
020import java.util.LinkedList;
021import java.util.List;
022import java.util.concurrent.CopyOnWriteArrayList;
023import javax.servlet.http.HttpServletRequest;
024import javax.servlet.http.HttpServletResponse;
025
026import org.apache.commons.logging.Log;
027import org.apache.commons.logging.LogFactory;
028
029import org.springframework.util.Assert;
030import org.springframework.util.CollectionUtils;
031import org.springframework.util.MultiValueMap;
032import org.springframework.util.StringUtils;
033import org.springframework.web.servlet.FlashMap;
034import org.springframework.web.servlet.FlashMapManager;
035import org.springframework.web.util.UrlPathHelper;
036
037/**
038 * A base class for {@link FlashMapManager} implementations.
039 *
040 * @author Rossen Stoyanchev
041 * @author Juergen Hoeller
042 * @author Sam Brannen
043 * @since 3.1.1
044 */
045public abstract class AbstractFlashMapManager implements FlashMapManager {
046
047        private static final Object DEFAULT_FLASH_MAPS_MUTEX = new Object();
048
049
050        protected final Log logger = LogFactory.getLog(getClass());
051
052        private int flashMapTimeout = 180;
053
054        private UrlPathHelper urlPathHelper = UrlPathHelper.defaultInstance;
055
056
057        /**
058         * Set the amount of time in seconds after a {@link FlashMap} is saved
059         * (at request completion) and before it expires.
060         * <p>The default value is 180 seconds.
061         */
062        public void setFlashMapTimeout(int flashMapTimeout) {
063                this.flashMapTimeout = flashMapTimeout;
064        }
065
066        /**
067         * Return the amount of time in seconds before a FlashMap expires.
068         */
069        public int getFlashMapTimeout() {
070                return this.flashMapTimeout;
071        }
072
073        /**
074         * Set the UrlPathHelper to use to match FlashMap instances to requests.
075         */
076        public void setUrlPathHelper(UrlPathHelper urlPathHelper) {
077                Assert.notNull(urlPathHelper, "UrlPathHelper must not be null");
078                this.urlPathHelper = urlPathHelper;
079        }
080
081        /**
082         * Return the UrlPathHelper implementation to use.
083         */
084        public UrlPathHelper getUrlPathHelper() {
085                return this.urlPathHelper;
086        }
087
088
089        @Override
090        public final FlashMap retrieveAndUpdate(HttpServletRequest request, HttpServletResponse response) {
091                List<FlashMap> allFlashMaps = retrieveFlashMaps(request);
092                if (CollectionUtils.isEmpty(allFlashMaps)) {
093                        return null;
094                }
095
096                if (logger.isDebugEnabled()) {
097                        logger.debug("Retrieved FlashMap(s): " + allFlashMaps);
098                }
099                List<FlashMap> mapsToRemove = getExpiredFlashMaps(allFlashMaps);
100                FlashMap match = getMatchingFlashMap(allFlashMaps, request);
101                if (match != null) {
102                        mapsToRemove.add(match);
103                }
104
105                if (!mapsToRemove.isEmpty()) {
106                        if (logger.isDebugEnabled()) {
107                                logger.debug("Removing FlashMap(s): " + mapsToRemove);
108                        }
109                        Object mutex = getFlashMapsMutex(request);
110                        if (mutex != null) {
111                                synchronized (mutex) {
112                                        allFlashMaps = retrieveFlashMaps(request);
113                                        if (allFlashMaps != null) {
114                                                allFlashMaps.removeAll(mapsToRemove);
115                                                updateFlashMaps(allFlashMaps, request, response);
116                                        }
117                                }
118                        }
119                        else {
120                                allFlashMaps.removeAll(mapsToRemove);
121                                updateFlashMaps(allFlashMaps, request, response);
122                        }
123                }
124
125                return match;
126        }
127
128        /**
129         * Return a list of expired FlashMap instances contained in the given list.
130         */
131        private List<FlashMap> getExpiredFlashMaps(List<FlashMap> allMaps) {
132                List<FlashMap> result = new LinkedList<FlashMap>();
133                for (FlashMap map : allMaps) {
134                        if (map.isExpired()) {
135                                result.add(map);
136                        }
137                }
138                return result;
139        }
140
141        /**
142         * Return a FlashMap contained in the given list that matches the request.
143         * @return a matching FlashMap or {@code null}
144         */
145        private FlashMap getMatchingFlashMap(List<FlashMap> allMaps, HttpServletRequest request) {
146                List<FlashMap> result = new LinkedList<FlashMap>();
147                for (FlashMap flashMap : allMaps) {
148                        if (isFlashMapForRequest(flashMap, request)) {
149                                result.add(flashMap);
150                        }
151                }
152                if (!result.isEmpty()) {
153                        Collections.sort(result);
154                        if (logger.isDebugEnabled()) {
155                                logger.debug("Found matching FlashMap(s): " + result);
156                        }
157                        return result.get(0);
158                }
159                return null;
160        }
161
162        /**
163         * Whether the given FlashMap matches the current request.
164         * Uses the expected request path and query parameters saved in the FlashMap.
165         */
166        protected boolean isFlashMapForRequest(FlashMap flashMap, HttpServletRequest request) {
167                String expectedPath = flashMap.getTargetRequestPath();
168                if (expectedPath != null) {
169                        String requestUri = getUrlPathHelper().getOriginatingRequestUri(request);
170                        if (!requestUri.equals(expectedPath) && !requestUri.equals(expectedPath + "/")) {
171                                return false;
172                        }
173                }
174                MultiValueMap<String, String> actualParams = getOriginatingRequestParams(request);
175                MultiValueMap<String, String> expectedParams = flashMap.getTargetRequestParams();
176                for (String expectedName : expectedParams.keySet()) {
177                        List<String> actualValues = actualParams.get(expectedName);
178                        if (actualValues == null) {
179                                return false;
180                        }
181                        for (String expectedValue : expectedParams.get(expectedName)) {
182                                if (!actualValues.contains(expectedValue)) {
183                                        return false;
184                                }
185                        }
186                }
187                return true;
188        }
189
190        private MultiValueMap<String, String> getOriginatingRequestParams(HttpServletRequest request) {
191                String query = getUrlPathHelper().getOriginatingQueryString(request);
192                return ServletUriComponentsBuilder.fromPath("/").query(query).build().getQueryParams();
193        }
194
195        @Override
196        public final void saveOutputFlashMap(FlashMap flashMap, HttpServletRequest request, HttpServletResponse response) {
197                if (CollectionUtils.isEmpty(flashMap)) {
198                        return;
199                }
200
201                String path = decodeAndNormalizePath(flashMap.getTargetRequestPath(), request);
202                flashMap.setTargetRequestPath(path);
203
204                if (logger.isDebugEnabled()) {
205                        logger.debug("Saving FlashMap=" + flashMap);
206                }
207                flashMap.startExpirationPeriod(getFlashMapTimeout());
208
209                Object mutex = getFlashMapsMutex(request);
210                if (mutex != null) {
211                        synchronized (mutex) {
212                                List<FlashMap> allFlashMaps = retrieveFlashMaps(request);
213                                allFlashMaps = (allFlashMaps != null ? allFlashMaps : new CopyOnWriteArrayList<FlashMap>());
214                                allFlashMaps.add(flashMap);
215                                updateFlashMaps(allFlashMaps, request, response);
216                        }
217                }
218                else {
219                        List<FlashMap> allFlashMaps = retrieveFlashMaps(request);
220                        allFlashMaps = (allFlashMaps != null ? allFlashMaps : new LinkedList<FlashMap>());
221                        allFlashMaps.add(flashMap);
222                        updateFlashMaps(allFlashMaps, request, response);
223                }
224        }
225
226        private String decodeAndNormalizePath(String path, HttpServletRequest request) {
227                if (path != null && !path.isEmpty()) {
228                        path = getUrlPathHelper().decodeRequestString(request, path);
229                        if (path.charAt(0) != '/') {
230                                String requestUri = getUrlPathHelper().getRequestUri(request);
231                                path = requestUri.substring(0, requestUri.lastIndexOf('/') + 1) + path;
232                                path = StringUtils.cleanPath(path);
233                        }
234                }
235                return path;
236        }
237
238        /**
239         * Retrieve saved FlashMap instances from the underlying storage.
240         * @param request the current request
241         * @return a List with FlashMap instances, or {@code null} if none found
242         */
243        protected abstract List<FlashMap> retrieveFlashMaps(HttpServletRequest request);
244
245        /**
246         * Update the FlashMap instances in the underlying storage.
247         * @param flashMaps a (potentially empty) list of FlashMap instances to save
248         * @param request the current request
249         * @param response the current response
250         */
251        protected abstract void updateFlashMaps(
252                        List<FlashMap> flashMaps, HttpServletRequest request, HttpServletResponse response);
253
254        /**
255         * Obtain a mutex for modifying the FlashMap List as handled by
256         * {@link #retrieveFlashMaps} and {@link #updateFlashMaps},
257         * <p>The default implementation returns a shared static mutex.
258         * Subclasses are encouraged to return a more specific mutex, or
259         * {@code null} to indicate that no synchronization is necessary.
260         * @param request the current request
261         * @return the mutex to use (may be {@code null} if none applicable)
262         * @since 4.0.3
263         */
264        protected Object getFlashMapsMutex(HttpServletRequest request) {
265                return DEFAULT_FLASH_MAPS_MUTEX;
266        }
267
268}