Commit 36c21afe authored by sedmidubsky's avatar sedmidubsky
Browse files

* Testing search/classification applications

parent a1092782
Loading
Loading
Loading
Loading
+176 −0
Original line number Diff line number Diff line
package mcdr.test;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import mcdr.objects.classification.impl.ObjectMultiCategoryClassifier;
import mcdr.objects.classification.impl.ObjectTrainingSampleRatioClassifier;
import mcdr.objects.impl.ObjectFloatVectorNeuralNetworkL2WeightedLength;
import mcdr.objects.impl.ObjectMotionWordNMatches;
import mcdr.objects.impl.ObjectMotionWordSoftAssignment;
import mcdr.sequence.SequenceMocap;
import mcdr.sequence.impl.SequenceMocapPoseCoordsL2DTW;
import mcdr.sequence.impl.SequenceMocapPoseCoordsL2DTWSegments;
import mcdr.sequence.impl.SequenceMocapTrajectoryDistL1;
import mcdr.sequence.impl.SequenceMotionWordsDTW;
import mcdr.sequence.impl.SequenceMotionWordsNGramsJaccard;
import mcdr.sequence.impl.SequenceMotionWordsNMatchesDTW;
import mcdr.sequence.impl.SequenceMotionWordsSoftAssignmentDTW;
import mcdr.sequence.impl.SequenceSegmentFeatureDTW;
import mcdr.test.utils.ObjectCategoryMgmt;
import mcdr.test.utils.ObjectMgmt;
import messif.objects.LocalAbstractObject;
import messif.objects.impl.ObjectFloatVectorL1;
import messif.objects.impl.ObjectFloatVectorL2;
import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2;
import messif.objects.util.RankedAbstractObject;
import messif.operations.RankingQueryOperation;
import messif.operations.query.KNNQueryOperation;

/**
 *
 * @author Jan Sedmidubsky, xsedmid@fi.muni.cz, FI MU Brno, Czech Republic
 */
public class DescriptorTester {

    /**
     * @param args the command line arguments
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {

//        Class<? extends SequenceMocap<?>> objectClass = SequenceMocapPoseCoordsL2DTW.class; // coords
//        Class<? extends LocalAbstractObject> objectClass = SequenceSegmentFeatureDTW.class;
//        Class<? extends SequenceMocap<?>> objectClass = SequenceMocapPoseCoordsL2DTWSegments.class;
//        Class<? extends LocalAbstractObject> objectClass = SequenceMotionWordsDTW.class;
//        Class<? extends LocalAbstractObject> objectClass = SequenceMotionWordsNMatchesDTW.class;
        Class<? extends LocalAbstractObject> objectClass = SequenceMotionWordsSoftAssignmentDTW.class;
//        Class<? extends LocalAbstractObject> objectClass = SequenceMotionWordsNGramsJaccard.class;

        ObjectMotionWordNMatches.nMatches = 1;
        ObjectMotionWordNMatches.maxPartsToMatch = 4;
        ObjectMotionWordSoftAssignment.maxPartsToMatch = 6;
        SequenceMotionWordsNGramsJaccard.nGramSize = 1;
        SequenceMocapPoseCoordsL2DTWSegments.INDEPENDENT_SEGMENT_SHIFT_LEVELS = false;
        SequenceMocapPoseCoordsL2DTWSegments.TRIM_LAST_SEGMENT = false;

        // retrieval params
        final boolean includeExactMatchInResult = false;
        // values of k to be evaluated
        int[] fixedKsToEvaluate = new int[]{};
//        int[] fixedKsToEvaluate = new int[]{4};
//        int[] fixedKsToEvaluate = IntStream.range(1, 21).toArray();
//        int[] fixedKsToEvaluate = new int[]{1, 3, 5, 10, 15, 20, 30, 40, 50, 60, 70, 80, 90, 100};
        final String[] ignoredCategoryIds = null;
//        final String[] ignoredCategoryIds = new String[]{"56", "57", "58", "59", "60", "61", "138", "139"}; // HDM05-122
        final boolean includeMatchFromTheSameSequenceInResult = true;
        final boolean evaluateQueriesIndependently = true; // indicates whether each query is evaluated independently, or only one multi-object query is constructed for each category
        final boolean restrictDataObjectsByQueries = false;
        final boolean parseDataCategoriesFromOverlappingQueries = false;

        // Printing global variables
        System.out.println("===== GLOBAL PARAMS =====");
        printStaticClassVariables(objectClass);
        System.out.println();

//        final String batchNamePrefix = "e:/datasets/mocap/hdm05/";
//        final String batchNamePrefix = "y:/datasets/mocap/PKU-MMD/skeleton2D/actions/";
//        final String batchNamePrefix = "y:/datasets/mocap/PKU-MMD/skeleton3D/actions/single-subject/messif/";
        final String batchNamePrefix = "y:/datasets/mocap/hdm05/motion_words/hulls/hull-optimized-center-on-kmedoids100/";
//        final String batchNamePrefix = "y:/datasets/mocap/hdm05/motion_words/quantized/";
        for (String batchNameFile : new String[]{
            "class130-actions-coords_normPOS-fps12.data"
//            "hdm05-annotations_specific-segment80_shift16-coords_normPOS-fps12-quantized-pivots-kmedoids-350.data"
//            "hdm05-annotations_specific-segment80_shift16-coords_normPOS-fps12-quantized-overlays5-pivots-kmedoids-350.data"
//            "hdm05-annotations_specific-segment80_shift16-coords_normPOS-fps12-quantized-pivots-kmedoids-350-softassign-D20K6.data"
//                "hdm05-annotations_specific-segment80_shift16-coords_normPOS-fps12-quantized-pivots-kmedoids-350-softassign-D20K6.data"
//                "hdm05-annotations_specific-segment80_shift16-coords_normPOS-fps12-quantized-pivots1000-maxlvl1-leaf240-random-filtered.data"
        }) {
            final String batchName = batchNamePrefix + batchNameFile;
            System.out.println("===== NEW EXPERIMENT: " + batchName + " =====");

            final String queryFile = batchName;
            final String dataFile = batchName;
//            final String queryFile = batchName + "actions-single-subject-test-CV-P.data";
//            final String dataFile = batchName + "actions-single-subject-train-CV-P.data";
//            final String dataFile = "d:/temp/2D-CS-P-randomActionClassSelection10.data";

            // structures
            ObjectCategoryMgmt categoryMgmt = new ObjectCategoryMgmt("y:/datasets/mocap/hdm05/meta/category_description_short.txt");
            ObjectMgmt queryMgmt = new ObjectMgmt(categoryMgmt);
            ObjectMgmt dataMgmt = new ObjectMgmt(categoryMgmt);

            // queries
            System.out.println("Queries:");
            queryMgmt.read(objectClass, queryFile, null, ignoredCategoryIds, null, null, true);

            // data
            System.out.println("Data:");
            dataMgmt.read(objectClass, dataFile, null, ignoredCategoryIds, (!restrictDataObjectsByQueries) ? null : queryMgmt.getParentSequenceIds(), (parseDataCategoriesFromOverlappingQueries) ? queryMgmt : null, true);
//            dataMgmt.storeRandomObjects("d:/temp/2D-CS-P-randomActionClassSelection10.data", 10);

            // Querying
            Integer maxK = (fixedKsToEvaluate.length == 0) ? null : Arrays.stream(fixedKsToEvaluate).summaryStatistics().getMax();
            System.out.println("maxK = " + maxK);
            long startTime = System.currentTimeMillis();
            Map<ObjectCategoryMgmt.Category, List<RankingQueryOperation>> origCategoryOperationsMap = dataMgmt.executeKNNQueries(queryMgmt, maxK, includeExactMatchInResult, includeMatchFromTheSameSequenceInResult);
            System.out.println("Querying time: " + ((System.currentTimeMillis() - startTime) / 1000f) + " s");

            // Classifier
            ObjectMultiCategoryClassifier objectClassifier = new ObjectMultiCategoryClassifier(true);
//              ObjectMultiCategoryClassifier objectClassifier = new ObjectTrainingSampleRatioClassifier(dataMgmt);

            // Evaluation
            if (maxK == null) {
                dataMgmt.evaluateRetrieval(origCategoryOperationsMap, evaluateQueriesIndependently, false, false);
                dataMgmt.evaluateClassification(objectClassifier, origCategoryOperationsMap, 1, false, false);
            } else {
                for (int k : fixedKsToEvaluate) {
                    System.out.println("Search evaluation (k=" + k + "):");

                    Map<ObjectCategoryMgmt.Category, List<RankingQueryOperation>> categoryOperationsMap = ObjectMgmt.cloneCategorizedRankingOperations(origCategoryOperationsMap, k);
                    dataMgmt.evaluateRetrieval(categoryOperationsMap, evaluateQueriesIndependently, false, false);

                    // evaluation of classification
                    float[][] confMatrix = dataMgmt.evaluateClassification(objectClassifier, categoryOperationsMap, 1, false, false);
//                dataMgmt.saveConfusionMatrixToFile(confMatrix, confMatrixFile, true);
                }
            }
        }
    }

    public static void printStaticClassVariables(Class<?> clazz) throws IllegalArgumentException, IllegalAccessException {
        System.out.println("Static params of class " + clazz.getName() + ":");
        for (Field field : clazz.getFields()) {
            if (java.lang.reflect.Modifier.isStatic(field.getModifiers())) {
                System.out.println("  " + field.getName() + " = " + field.get(null).toString());
            }
        }
    }

    private static Map<ObjectCategoryMgmt.Category, List<RankingQueryOperation>> mergeAnswers(Map<ObjectCategoryMgmt.Category, List<RankingQueryOperation>> answerFrom, Map<ObjectCategoryMgmt.Category, List<RankingQueryOperation>> answerTo) {
        for (Map.Entry<ObjectCategoryMgmt.Category, List<RankingQueryOperation>> opsEntry : answerTo.entrySet()) {
            for (RankingQueryOperation op : opsEntry.getValue()) {
                RankingQueryOperation opFrom = null;
                String locTo = ((KNNQueryOperation) op).getQueryObject().getLocatorURI();
                for (RankingQueryOperation opTmp : answerFrom.get(opsEntry.getKey())) {
                    String locFrom = ((KNNQueryOperation) opTmp).getQueryObject().getLocatorURI();
                    if (locTo.equals(locFrom)) {
                        opFrom = opTmp;
                        break;
                    }
                }
                Iterator<RankedAbstractObject> answerIt = opFrom.getAnswer();
                while (answerIt.hasNext()) {
                    RankedAbstractObject rao = answerIt.next();
                    op.addToAnswer(rao.getObject(), rao.getDistance(), null);
                }
            }
        }
        return answerTo;
    }

}