001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.test.autoconfigure.web.servlet;
018
019import java.io.PrintStream;
020import java.io.PrintWriter;
021import java.io.StringWriter;
022import java.util.ArrayList;
023import java.util.Collection;
024import java.util.List;
025
026import javax.servlet.Filter;
027
028import org.apache.commons.logging.Log;
029import org.apache.commons.logging.LogFactory;
030
031import org.springframework.beans.factory.ListableBeanFactory;
032import org.springframework.beans.factory.NoSuchBeanDefinitionException;
033import org.springframework.boot.web.servlet.AbstractFilterRegistrationBean;
034import org.springframework.boot.web.servlet.DelegatingFilterProxyRegistrationBean;
035import org.springframework.boot.web.servlet.FilterRegistrationBean;
036import org.springframework.boot.web.servlet.RegistrationBean;
037import org.springframework.boot.web.servlet.ServletContextInitializerBeans;
038import org.springframework.context.ApplicationContext;
039import org.springframework.context.ConfigurableApplicationContext;
040import org.springframework.test.web.servlet.MvcResult;
041import org.springframework.test.web.servlet.ResultHandler;
042import org.springframework.test.web.servlet.result.PrintingResultHandler;
043import org.springframework.test.web.servlet.setup.ConfigurableMockMvcBuilder;
044import org.springframework.util.Assert;
045import org.springframework.util.CollectionUtils;
046import org.springframework.util.StringUtils;
047import org.springframework.web.context.WebApplicationContext;
048
049/**
050 * {@link MockMvcBuilderCustomizer} for a typical Spring Boot application. Usually applied
051 * automatically via {@link AutoConfigureMockMvc @AutoConfigureMockMvc}, but may also be
052 * used directly.
053 *
054 * @author Phillip Webb
055 * @author Andy Wilkinson
056 * @since 1.4.0
057 */
058public class SpringBootMockMvcBuilderCustomizer implements MockMvcBuilderCustomizer {
059
060        private final WebApplicationContext context;
061
062        private boolean addFilters = true;
063
064        private MockMvcPrint print = MockMvcPrint.DEFAULT;
065
066        private boolean printOnlyOnFailure = true;
067
068        /**
069         * Create a new {@link SpringBootMockMvcBuilderCustomizer} instance.
070         * @param context the source application context
071         */
072        public SpringBootMockMvcBuilderCustomizer(WebApplicationContext context) {
073                Assert.notNull(context, "Context must not be null");
074                this.context = context;
075        }
076
077        @Override
078        public void customize(ConfigurableMockMvcBuilder<?> builder) {
079                if (this.addFilters) {
080                        addFilters(builder);
081                }
082                ResultHandler printHandler = getPrintHandler();
083                if (printHandler != null) {
084                        builder.alwaysDo(printHandler);
085                }
086        }
087
088        private ResultHandler getPrintHandler() {
089                LinesWriter writer = getLinesWriter();
090                if (writer == null) {
091                        return null;
092                }
093                if (this.printOnlyOnFailure) {
094                        writer = new DeferredLinesWriter(this.context, writer);
095                }
096                return new LinesWritingResultHandler(writer);
097        }
098
099        private LinesWriter getLinesWriter() {
100                if (this.print == MockMvcPrint.NONE) {
101                        return null;
102                }
103                if (this.print == MockMvcPrint.LOG_DEBUG) {
104                        return new LoggingLinesWriter();
105                }
106                return new SystemLinesWriter(this.print);
107
108        }
109
110        private void addFilters(ConfigurableMockMvcBuilder<?> builder) {
111                FilterRegistrationBeans registrations = new FilterRegistrationBeans(this.context);
112                registrations.stream().map(AbstractFilterRegistrationBean.class::cast)
113                                .filter(AbstractFilterRegistrationBean::isEnabled)
114                                .forEach((registration) -> addFilter(builder, registration));
115        }
116
117        private void addFilter(ConfigurableMockMvcBuilder<?> builder,
118                        AbstractFilterRegistrationBean<?> registration) {
119                Filter filter = registration.getFilter();
120                Collection<String> urls = registration.getUrlPatterns();
121                if (urls.isEmpty()) {
122                        builder.addFilters(filter);
123                }
124                else {
125                        builder.addFilter(filter, StringUtils.toStringArray(urls));
126                }
127        }
128
129        public void setAddFilters(boolean addFilters) {
130                this.addFilters = addFilters;
131        }
132
133        public boolean isAddFilters() {
134                return this.addFilters;
135        }
136
137        public void setPrint(MockMvcPrint print) {
138                this.print = print;
139        }
140
141        public MockMvcPrint getPrint() {
142                return this.print;
143        }
144
145        public void setPrintOnlyOnFailure(boolean printOnlyOnFailure) {
146                this.printOnlyOnFailure = printOnlyOnFailure;
147        }
148
149        public boolean isPrintOnlyOnFailure() {
150                return this.printOnlyOnFailure;
151        }
152
153        /**
154         * {@link ResultHandler} that prints {@link MvcResult} details to a given
155         * {@link LinesWriter}.
156         */
157        private static class LinesWritingResultHandler implements ResultHandler {
158
159                private final LinesWriter writer;
160
161                LinesWritingResultHandler(LinesWriter writer) {
162                        this.writer = writer;
163                }
164
165                @Override
166                public void handle(MvcResult result) throws Exception {
167                        LinesPrintingResultHandler delegate = new LinesPrintingResultHandler();
168                        delegate.handle(result);
169                        delegate.write(this.writer);
170                }
171
172                private static class LinesPrintingResultHandler extends PrintingResultHandler {
173
174                        protected LinesPrintingResultHandler() {
175                                super(new Printer());
176                        }
177
178                        public void write(LinesWriter writer) {
179                                writer.write(((Printer) getPrinter()).getLines());
180                        }
181
182                        private static class Printer implements ResultValuePrinter {
183
184                                private final List<String> lines = new ArrayList<>();
185
186                                @Override
187                                public void printHeading(String heading) {
188                                        this.lines.add("");
189                                        this.lines.add(String.format("%s:", heading));
190                                }
191
192                                @Override
193                                public void printValue(String label, Object value) {
194                                        if (value != null && value.getClass().isArray()) {
195                                                value = CollectionUtils.arrayToList(value);
196                                        }
197                                        this.lines.add(String.format("%17s = %s", label, value));
198                                }
199
200                                public List<String> getLines() {
201                                        return this.lines;
202                                }
203
204                        }
205
206                }
207
208        }
209
210        /**
211         * Strategy interface to write MVC result lines.
212         */
213        interface LinesWriter {
214
215                void write(List<String> lines);
216
217        }
218
219        /**
220         * {@link LinesWriter} used to defer writing until errors are detected.
221         *
222         * @see MockMvcPrintOnlyOnFailureTestExecutionListener
223         */
224        static class DeferredLinesWriter implements LinesWriter {
225
226                private static final String BEAN_NAME = DeferredLinesWriter.class.getName();
227
228                private final LinesWriter delegate;
229
230                private final List<String> lines = new ArrayList<>();
231
232                DeferredLinesWriter(WebApplicationContext context, LinesWriter delegate) {
233                        Assert.state(context instanceof ConfigurableApplicationContext,
234                                        "A ConfigurableApplicationContext is required for printOnlyOnFailure");
235                        ((ConfigurableApplicationContext) context).getBeanFactory()
236                                        .registerSingleton(BEAN_NAME, this);
237                        this.delegate = delegate;
238                }
239
240                @Override
241                public void write(List<String> lines) {
242                        this.lines.addAll(lines);
243                }
244
245                public void writeDeferredResult() {
246                        this.delegate.write(this.lines);
247                }
248
249                public static DeferredLinesWriter get(ApplicationContext applicationContext) {
250                        try {
251                                return applicationContext.getBean(BEAN_NAME, DeferredLinesWriter.class);
252                        }
253                        catch (NoSuchBeanDefinitionException ex) {
254                                return null;
255                        }
256                }
257
258        }
259
260        /**
261         * {@link LinesWriter} to output results to the log.
262         */
263        private static class LoggingLinesWriter implements LinesWriter {
264
265                private static final Log logger = LogFactory
266                                .getLog("org.springframework.test.web.servlet.result");
267
268                @Override
269                public void write(List<String> lines) {
270                        if (logger.isDebugEnabled()) {
271                                StringWriter stringWriter = new StringWriter();
272                                PrintWriter printWriter = new PrintWriter(stringWriter);
273                                for (String line : lines) {
274                                        printWriter.println(line);
275                                }
276                                logger.debug("MvcResult details:\n" + stringWriter);
277                        }
278                }
279
280        }
281
282        /**
283         * {@link LinesWriter} to output results to {@code System.out} or {@code System.err}.
284         */
285        private static class SystemLinesWriter implements LinesWriter {
286
287                private final MockMvcPrint print;
288
289                SystemLinesWriter(MockMvcPrint print) {
290                        this.print = print;
291                }
292
293                @Override
294                public void write(List<String> lines) {
295                        PrintStream printStream = getPrintStream();
296                        for (String line : lines) {
297                                printStream.println(line);
298                        }
299                }
300
301                private PrintStream getPrintStream() {
302                        if (this.print == MockMvcPrint.SYSTEM_ERR) {
303                                return System.err;
304                        }
305                        return System.out;
306                }
307
308        }
309
310        private static class FilterRegistrationBeans extends ServletContextInitializerBeans {
311
312                FilterRegistrationBeans(ListableBeanFactory beanFactory) {
313                        super(beanFactory, FilterRegistrationBean.class,
314                                        DelegatingFilterProxyRegistrationBean.class);
315                }
316
317                @Override
318                protected void addAdaptableBeans(ListableBeanFactory beanFactory) {
319                        addAsRegistrationBean(beanFactory, Filter.class,
320                                        new FilterRegistrationBeanAdapter());
321                }
322
323                private static class FilterRegistrationBeanAdapter
324                                implements RegistrationBeanAdapter<Filter> {
325
326                        @Override
327                        public RegistrationBean createRegistrationBean(String name, Filter source,
328                                        int totalNumberOfSourceBeans) {
329                                FilterRegistrationBean<Filter> bean = new FilterRegistrationBean<>(
330                                                source);
331                                bean.setName(name);
332                                return bean;
333                        }
334
335                }
336
337        }
338
339}