001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.util;
018
019import java.lang.reflect.Method;
020import java.util.Collection;
021import java.util.Optional;
022import java.util.function.Consumer;
023import java.util.function.Function;
024import java.util.function.Predicate;
025import java.util.function.Supplier;
026import java.util.stream.Stream;
027
028import org.apache.commons.logging.Log;
029import org.apache.commons.logging.LogFactory;
030
031import org.springframework.core.ResolvableType;
032import org.springframework.util.Assert;
033import org.springframework.util.ClassUtils;
034import org.springframework.util.ReflectionUtils;
035
036/**
 * Utility that can be used to invoke lambdas in a safe way. Primarily designed to help
 * support generically typed callbacks where {@link ClassCastException class cast
 * exceptions} need to be dealt with due to type erasure.
040 *
041 * @author Phillip Webb
042 * @since 2.0.0
043 */
044public final class LambdaSafe {
045
046        private static final Method CLASS_GET_MODULE;
047
048        private static final Method MODULE_GET_NAME;
049
050        static {
051                CLASS_GET_MODULE = ReflectionUtils.findMethod(Class.class, "getModule");
052                MODULE_GET_NAME = (CLASS_GET_MODULE != null)
053                                ? ReflectionUtils.findMethod(CLASS_GET_MODULE.getReturnType(), "getName")
054                                : null;
055        }
056
057        private LambdaSafe() {
058        }
059
060        /**
061         * Start a call to a single callback instance, dealing with common generic type
062         * concerns and exceptions.
063         * @param callbackType the callback type (a {@link FunctionalInterface functional
064         * interface})
065         * @param callbackInstance the callback instance (may be a lambda)
066         * @param argument the primary argument passed to the callback
067         * @param additionalArguments any additional arguments passed to the callback
068         * @param <C> the callback type
069         * @param <A> the primary argument type
070         * @return a {@link Callback} instance that can be invoked.
071         */
072        public static <C, A> Callback<C, A> callback(Class<C> callbackType,
073                        C callbackInstance, A argument, Object... additionalArguments) {
074                Assert.notNull(callbackType, "CallbackType must not be null");
075                Assert.notNull(callbackInstance, "CallbackInstance must not be null");
076                return new Callback<>(callbackType, callbackInstance, argument,
077                                additionalArguments);
078        }
079
080        /**
081         * Start a call to callback instances, dealing with common generic type concerns and
082         * exceptions.
083         * @param callbackType the callback type (a {@link FunctionalInterface functional
084         * interface})
085         * @param callbackInstances the callback instances (elements may be lambdas)
086         * @param argument the primary argument passed to the callbacks
087         * @param additionalArguments any additional arguments passed to the callbacks
088         * @param <C> the callback type
089         * @param <A> the primary argument type
090         * @return a {@link Callbacks} instance that can be invoked.
091         */
092        public static <C, A> Callbacks<C, A> callbacks(Class<C> callbackType,
093                        Collection<? extends C> callbackInstances, A argument,
094                        Object... additionalArguments) {
095                Assert.notNull(callbackType, "CallbackType must not be null");
096                Assert.notNull(callbackInstances, "CallbackInstances must not be null");
097                return new Callbacks<>(callbackType, callbackInstances, argument,
098                                additionalArguments);
099        }
100
101        /**
102         * Abstract base class for lambda safe callbacks.
103         */
104        private abstract static class LambdaSafeCallback<C, A, SELF extends LambdaSafeCallback<C, A, SELF>> {
105
106                private final Class<C> callbackType;
107
108                private final A argument;
109
110                private final Object[] additionalArguments;
111
112                private Log logger;
113
114                private Filter<C, A> filter = new GenericTypeFilter<>();
115
116                protected LambdaSafeCallback(Class<C> callbackType, A argument,
117                                Object[] additionalArguments) {
118                        this.callbackType = callbackType;
119                        this.argument = argument;
120                        this.additionalArguments = additionalArguments;
121                        this.logger = LogFactory.getLog(callbackType);
122                }
123
124                /**
125                 * Use the specified logger source to report any lambda failures.
126                 * @param loggerSource the logger source to use
127                 * @return this instance
128                 */
129                public SELF withLogger(Class<?> loggerSource) {
130                        return withLogger(LogFactory.getLog(loggerSource));
131                }
132
133                /**
134                 * Use the specified logger to report any lambda failures.
135                 * @param logger the logger to use
136                 * @return this instance
137                 */
138                public SELF withLogger(Log logger) {
139                        Assert.notNull(logger, "Logger must not be null");
140                        this.logger = logger;
141                        return self();
142                }
143
144                /**
145                 * Use a specific filter to determine when a callback should apply. If no explicit
146                 * filter is set filter will be attempted using the generic type on the callback
147                 * type.
148                 * @param filter the filter to use
149                 * @return this instance
150                 */
151                public SELF withFilter(Filter<C, A> filter) {
152                        Assert.notNull(filter, "Filter must not be null");
153                        this.filter = filter;
154                        return self();
155                }
156
157                protected final <R> InvocationResult<R> invoke(C callbackInstance,
158                                Supplier<R> supplier) {
159                        if (this.filter.match(this.callbackType, callbackInstance, this.argument,
160                                        this.additionalArguments)) {
161                                try {
162                                        return InvocationResult.of(supplier.get());
163                                }
164                                catch (ClassCastException ex) {
165                                        if (!isLambdaGenericProblem(ex)) {
166                                                throw ex;
167                                        }
168                                        logNonMatchingType(callbackInstance, ex);
169                                }
170                        }
171                        return InvocationResult.noResult();
172                }
173
174                private boolean isLambdaGenericProblem(ClassCastException ex) {
175                        return (ex.getMessage() == null
176                                        || startsWithArgumentClassName(ex.getMessage()));
177                }
178
179                private boolean startsWithArgumentClassName(String message) {
180                        Predicate<Object> startsWith = (argument) -> startsWithArgumentClassName(
181                                        message, argument);
182                        return startsWith.test(this.argument)
183                                        || Stream.of(this.additionalArguments).anyMatch(startsWith);
184                }
185
186                private boolean startsWithArgumentClassName(String message, Object argument) {
187                        if (argument == null) {
188                                return false;
189                        }
190                        Class<?> argumentType = argument.getClass();
191                        // On Java 8, the message starts with the class name: "java.lang.String cannot
192                        // be cast..."
193                        if (message.startsWith(argumentType.getName())) {
194                                return true;
195                        }
196                        // On Java 11, the message starts with "class ..." a.k.a. Class.toString()
197                        if (message.startsWith(argumentType.toString())) {
198                                return true;
199                        }
200                        // On Java 9, the message used to contain the module name:
201                        // "java.base/java.lang.String cannot be cast..."
202                        int moduleSeparatorIndex = message.indexOf('/');
203                        if (moduleSeparatorIndex != -1 && message.startsWith(argumentType.getName(),
204                                        moduleSeparatorIndex + 1)) {
205                                return true;
206                        }
207                        if (CLASS_GET_MODULE != null) {
208                                Object module = ReflectionUtils.invokeMethod(CLASS_GET_MODULE,
209                                                argumentType);
210                                Object moduleName = ReflectionUtils.invokeMethod(MODULE_GET_NAME, module);
211                                return message.startsWith(moduleName + "/" + argumentType.getName());
212                        }
213                        return false;
214                }
215
216                private void logNonMatchingType(C callback, ClassCastException ex) {
217                        if (this.logger.isDebugEnabled()) {
218                                Class<?> expectedType = ResolvableType.forClass(this.callbackType)
219                                                .resolveGeneric();
220                                String expectedTypeName = (expectedType != null)
221                                                ? ClassUtils.getShortName(expectedType) + " type" : "type";
222                                String message = "Non-matching " + expectedTypeName + " for callback "
223                                                + ClassUtils.getShortName(this.callbackType) + ": " + callback;
224                                this.logger.debug(message, ex);
225                        }
226                }
227
228                @SuppressWarnings("unchecked")
229                private SELF self() {
230                        return (SELF) this;
231                }
232
233        }
234
235        /**
236         * Represents a single callback that can be invoked in a lambda safe way.
237         *
238         * @param <C> the callback type
239         * @param <A> the primary argument type
240         */
241        public static final class Callback<C, A>
242                        extends LambdaSafeCallback<C, A, Callback<C, A>> {
243
244                private final C callbackInstance;
245
246                private Callback(Class<C> callbackType, C callbackInstance, A argument,
247                                Object[] additionalArguments) {
248                        super(callbackType, argument, additionalArguments);
249                        this.callbackInstance = callbackInstance;
250                }
251
252                /**
253                 * Invoke the callback instance where the callback method returns void.
254                 * @param invoker the invoker used to invoke the callback
255                 */
256                public void invoke(Consumer<C> invoker) {
257                        invoke(this.callbackInstance, () -> {
258                                invoker.accept(this.callbackInstance);
259                                return null;
260                        });
261                }
262
263                /**
264                 * Invoke the callback instance where the callback method returns a result.
265                 * @param invoker the invoker used to invoke the callback
266                 * @param <R> the result type
267                 * @return the result of the invocation (may be {@link InvocationResult#noResult}
268                 * if the callback was not invoked)
269                 */
270                public <R> InvocationResult<R> invokeAnd(Function<C, R> invoker) {
271                        return invoke(this.callbackInstance,
272                                        () -> invoker.apply(this.callbackInstance));
273                }
274
275        }
276
277        /**
278         * Represents a collection of callbacks that can be invoked in a lambda safe way.
279         *
280         * @param <C> the callback type
281         * @param <A> the primary argument type
282         */
283        public static final class Callbacks<C, A>
284                        extends LambdaSafeCallback<C, A, Callbacks<C, A>> {
285
286                private final Collection<? extends C> callbackInstances;
287
288                private Callbacks(Class<C> callbackType,
289                                Collection<? extends C> callbackInstances, A argument,
290                                Object[] additionalArguments) {
291                        super(callbackType, argument, additionalArguments);
292                        this.callbackInstances = callbackInstances;
293                }
294
295                /**
296                 * Invoke the callback instances where the callback method returns void.
297                 * @param invoker the invoker used to invoke the callback
298                 */
299                public void invoke(Consumer<C> invoker) {
300                        this.callbackInstances.forEach((callbackInstance) -> {
301                                invoke(callbackInstance, () -> {
302                                        invoker.accept(callbackInstance);
303                                        return null;
304                                });
305                        });
306                }
307
308                /**
309                 * Invoke the callback instances where the callback method returns a result.
310                 * @param invoker the invoker used to invoke the callback
311                 * @param <R> the result type
312                 * @return the results of the invocation (may be an empty stream if no callbacks
313                 * could be called)
314                 */
315                public <R> Stream<R> invokeAnd(Function<C, R> invoker) {
316                        Function<C, InvocationResult<R>> mapper = (callbackInstance) -> invoke(
317                                        callbackInstance, () -> invoker.apply(callbackInstance));
318                        return this.callbackInstances.stream().map(mapper)
319                                        .filter(InvocationResult::hasResult).map(InvocationResult::get);
320                }
321
322        }
323
324        /**
325         * A filter that can be used to restrict when a callback is used.
326         *
327         * @param <C> the callback type
328         * @param <A> the primary argument type
329         */
330        @FunctionalInterface
331        interface Filter<C, A> {
332
333                /**
334                 * Determine if the given callback matches and should be invoked.
335                 * @param callbackType the callback type (the functional interface)
336                 * @param callbackInstance the callback instance (the implementation)
337                 * @param argument the primary argument
338                 * @param additionalArguments any additional arguments
339                 * @return if the callback matches and should be invoked
340                 */
341                boolean match(Class<C> callbackType, C callbackInstance, A argument,
342                                Object[] additionalArguments);
343
344                /**
345                 * Return a {@link Filter} that allows all callbacks to be invoked.
346                 * @param <C> the callback type
347                 * @param <A> the primary argument type
348                 * @return an "allow all" filter
349                 */
350                static <C, A> Filter<C, A> allowAll() {
351                        return (callbackType, callbackInstance, argument,
352                                        additionalArguments) -> true;
353                }
354
355        }
356
357        /**
358         * {@link Filter} that matches when the callback has a single generic and primary
359         * argument is an instance of it.
360         */
361        private static class GenericTypeFilter<C, A> implements Filter<C, A> {
362
363                @Override
364                public boolean match(Class<C> callbackType, C callbackInstance, A argument,
365                                Object[] additionalArguments) {
366                        ResolvableType type = ResolvableType.forClass(callbackType,
367                                        callbackInstance.getClass());
368                        if (type.getGenerics().length == 1 && type.resolveGeneric() != null) {
369                                return type.resolveGeneric().isInstance(argument);
370                        }
371
372                        return true;
373                }
374
375        }
376
377        /**
378         * The result of a callback which may be a value, {@code null} or absent entirely if
379         * the callback wasn't suitable. Similar in design to {@link Optional} but allows for
380         * {@code null} as a valid value.
381         *
382         * @param <R> the result type
383         */
384        public static final class InvocationResult<R> {
385
386                private static final InvocationResult<?> NONE = new InvocationResult<>(null);
387
388                private final R value;
389
390                private InvocationResult(R value) {
391                        this.value = value;
392                }
393
394                /**
395                 * Return true if a result in present.
396                 * @return if a result is present
397                 */
398                public boolean hasResult() {
399                        return this != NONE;
400                }
401
402                /**
403                 * Return the result of the invocation or {@code null} if the callback wasn't
404                 * suitable.
405                 * @return the result of the invocation or {@code null}
406                 */
407                public R get() {
408                        return this.value;
409                }
410
411                /**
412                 * Return the result of the invocation or the given fallback if the callback
413                 * wasn't suitable.
414                 * @param fallback the fallback to use when there is no result
415                 * @return the result of the invocation or the fallback
416                 */
417                public R get(R fallback) {
418                        return (this != NONE) ? this.value : fallback;
419                }
420
421                /**
422                 * Create a new {@link InvocationResult} instance with the specified value.
423                 * @param value the value (may be {@code null})
424                 * @param <R> the result type
425                 * @return an {@link InvocationResult}
426                 */
427                public static <R> InvocationResult<R> of(R value) {
428                        return new InvocationResult<>(value);
429                }
430
431                /**
432                 * Return an {@link InvocationResult} instance representing no result.
433                 * @param <R> the result type
434                 * @return an {@link InvocationResult}
435                 */
436                @SuppressWarnings("unchecked")
437                public static <R> InvocationResult<R> noResult() {
438                        return (InvocationResult<R>) NONE;
439                }
440
441        }
442
443}