001package org.junit.experimental.theories.internal;
002
003import java.lang.reflect.Array;
004import java.lang.reflect.Field;
005import java.util.ArrayList;
006import java.util.Collection;
007import java.util.Iterator;
008import java.util.List;
009
010import org.junit.Assume;
011import org.junit.experimental.theories.DataPoint;
012import org.junit.experimental.theories.DataPoints;
013import org.junit.experimental.theories.ParameterSignature;
014import org.junit.experimental.theories.ParameterSupplier;
015import org.junit.experimental.theories.PotentialAssignment;
016import org.junit.runners.model.FrameworkField;
017import org.junit.runners.model.FrameworkMethod;
018import org.junit.runners.model.TestClass;
019
020/**
021 * Supplies Theory parameters based on all public members of the target class.
022 */
023public class AllMembersSupplier extends ParameterSupplier {
024    static class MethodParameterValue extends PotentialAssignment {
025        private final FrameworkMethod method;
026
027        private MethodParameterValue(FrameworkMethod dataPointMethod) {
028            method = dataPointMethod;
029        }
030
031        @Override
032        public Object getValue() throws CouldNotGenerateValueException {
033            try {
034                return method.invokeExplosively(null);
035            } catch (IllegalArgumentException e) {
036                throw new RuntimeException(
037                        "unexpected: argument length is checked");
038            } catch (IllegalAccessException e) {
039                throw new RuntimeException(
040                        "unexpected: getMethods returned an inaccessible method");
041            } catch (Throwable throwable) {
042                DataPoint annotation = method.getAnnotation(DataPoint.class);
043                Assume.assumeTrue(annotation == null || !isAssignableToAnyOf(annotation.ignoredExceptions(), throwable));
044                
045                throw new CouldNotGenerateValueException(throwable);
046            }
047        }
048
049        @Override
050        public String getDescription() throws CouldNotGenerateValueException {
051            return method.getName();
052        }
053    }
054    
055    private final TestClass clazz;
056
057    /**
058     * Constructs a new supplier for {@code type}
059     */
060    public AllMembersSupplier(TestClass type) {
061        clazz = type;
062    }
063
064    @Override
065    public List<PotentialAssignment> getValueSources(ParameterSignature sig) throws Throwable {
066        List<PotentialAssignment> list = new ArrayList<PotentialAssignment>();
067
068        addSinglePointFields(sig, list);
069        addMultiPointFields(sig, list);
070        addSinglePointMethods(sig, list);
071        addMultiPointMethods(sig, list);
072
073        return list;
074    }
075
076    private void addMultiPointMethods(ParameterSignature sig, List<PotentialAssignment> list) throws Throwable {
077        for (FrameworkMethod dataPointsMethod : getDataPointsMethods(sig)) {
078            Class<?> returnType = dataPointsMethod.getReturnType();
079            
080            if ((returnType.isArray() && sig.canPotentiallyAcceptType(returnType.getComponentType())) ||
081                    Iterable.class.isAssignableFrom(returnType)) {
082                try {
083                    addDataPointsValues(returnType, sig, dataPointsMethod.getName(), list, 
084                            dataPointsMethod.invokeExplosively(null));
085                } catch (Throwable throwable) {
086                    DataPoints annotation = dataPointsMethod.getAnnotation(DataPoints.class);
087                    if (annotation != null && isAssignableToAnyOf(annotation.ignoredExceptions(), throwable)) {
088                        return;
089                    } else {
090                        throw throwable;
091                    }
092                }
093            }
094        }
095    }
096
097    private void addSinglePointMethods(ParameterSignature sig, List<PotentialAssignment> list) {
098        for (FrameworkMethod dataPointMethod : getSingleDataPointMethods(sig)) {
099            if (sig.canAcceptType(dataPointMethod.getType())) {
100                list.add(new MethodParameterValue(dataPointMethod));
101            }
102        }
103    }
104    
105    private void addMultiPointFields(ParameterSignature sig, List<PotentialAssignment> list) {
106        for (final Field field : getDataPointsFields(sig)) {
107            Class<?> type = field.getType();
108            addDataPointsValues(type, sig, field.getName(), list, getStaticFieldValue(field));
109        }
110    }
111
112    private void addSinglePointFields(ParameterSignature sig, List<PotentialAssignment> list) {
113        for (final Field field : getSingleDataPointFields(sig)) {
114            Object value = getStaticFieldValue(field);
115            
116            if (sig.canAcceptValue(value)) {
117                list.add(PotentialAssignment.forValue(field.getName(), value));
118            }
119        }
120    }
121    
122    private void addDataPointsValues(Class<?> type, ParameterSignature sig, String name, 
123            List<PotentialAssignment> list, Object value) {
124        if (type.isArray()) {
125            addArrayValues(sig, name, list, value);
126        }
127        else if (Iterable.class.isAssignableFrom(type)) {
128            addIterableValues(sig, name, list, (Iterable<?>) value);
129        }
130    }
131
132    private void addArrayValues(ParameterSignature sig, String name, List<PotentialAssignment> list, Object array) {
133        for (int i = 0; i < Array.getLength(array); i++) {
134            Object value = Array.get(array, i);
135            if (sig.canAcceptValue(value)) {
136                list.add(PotentialAssignment.forValue(name + "[" + i + "]", value));
137            }
138        }
139    }
140    
141    private void addIterableValues(ParameterSignature sig, String name, List<PotentialAssignment> list, Iterable<?> iterable) {
142        Iterator<?> iterator = iterable.iterator();
143        int i = 0;
144        while (iterator.hasNext()) {
145            Object value = iterator.next();
146            if (sig.canAcceptValue(value)) {
147                list.add(PotentialAssignment.forValue(name + "[" + i + "]", value));
148            }
149            i += 1;
150        }
151    }
152
153    private Object getStaticFieldValue(final Field field) {
154        try {
155            return field.get(null);
156        } catch (IllegalArgumentException e) {
157            throw new RuntimeException(
158                    "unexpected: field from getClass doesn't exist on object");
159        } catch (IllegalAccessException e) {
160            throw new RuntimeException(
161                    "unexpected: getFields returned an inaccessible field");
162        }
163    }
164    
165    private static boolean isAssignableToAnyOf(Class<?>[] typeArray, Object target) {
166        for (Class<?> type : typeArray) {
167            if (type.isAssignableFrom(target.getClass())) {
168                return true;
169            }
170        }
171        return false;
172    }
173
174    protected Collection<FrameworkMethod> getDataPointsMethods(ParameterSignature sig) {
175        return clazz.getAnnotatedMethods(DataPoints.class);
176    }
177    
178    protected Collection<Field> getSingleDataPointFields(ParameterSignature sig) {
179        List<FrameworkField> fields = clazz.getAnnotatedFields(DataPoint.class);
180        Collection<Field> validFields = new ArrayList<Field>();
181
182        for (FrameworkField frameworkField : fields) {
183            validFields.add(frameworkField.getField());
184        }
185
186        return validFields;
187    }
188    
189    protected Collection<Field> getDataPointsFields(ParameterSignature sig) {
190        List<FrameworkField> fields = clazz.getAnnotatedFields(DataPoints.class);
191        Collection<Field> validFields = new ArrayList<Field>();
192
193        for (FrameworkField frameworkField : fields) {
194            validFields.add(frameworkField.getField());
195        }
196
197        return validFields;
198    }
199    
200    protected Collection<FrameworkMethod> getSingleDataPointMethods(ParameterSignature sig) {
201        return clazz.getAnnotatedMethods(DataPoint.class);
202    }
203
204}