Commit 51abdc6e authored by Michal Balazia's avatar Michal Balazia
Browse files

distance matrix

parent 6dff6935
......@@ -3,24 +3,35 @@ package algorithms;
import Jama.Matrix;
import java.io.Serializable;
import java.util.List;
import objects.DistanceMatrix;
import objects.Template;
public abstract class Classifier implements Serializable {
private DistanceFunction distanceFunction;
private DistanceTemplates distanceFunction;
private DistanceMatrix distanceMatrix;
private Matrix transformationMatrix;
public Classifier() {
}
public DistanceFunction getDistanceFunction() {
public DistanceTemplates getDistanceFunction() {
return distanceFunction;
}
public void setDistanceFunction(DistanceFunction distanceFunction) {
public void setDistanceTemplates(DistanceTemplates distanceFunction) {
this.distanceFunction = distanceFunction;
}
public DistanceMatrix getDistanceMatrix() {
return distanceMatrix;
}
public void setDistanceMatrix(DistanceMatrix distanceMatrix) {
this.distanceMatrix = distanceMatrix;
}
public Matrix getTransformationMatrix() {
return transformationMatrix;
}
......
......@@ -5,6 +5,7 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import objects.DistanceMatrix;
import objects.Template;
public class ClassifierKNNMV extends Classifier {
......@@ -17,13 +18,14 @@ public class ClassifierKNNMV extends Classifier {
@Override
public int classify(Template templateQuery, List<Template> templatesGallery) {
// get k closest descriptors
Map<Double, Template> kClosestTemplates = new HashMap(k);
DistanceFunction distanceFunction = getDistanceFunction();
double maxDistance = 0.0;
DistanceTemplates distanceFunction = getDistanceFunction();
DistanceMatrix distanceMatrix = getDistanceMatrix();
double maxDistance = 0;
for (Template template : templatesGallery) {
double distance = distanceFunction.getDistance(templateQuery, template);
double distance = getDistanceMatrix() == null ? distanceFunction.getDistance(template, templateQuery) : distanceMatrix.getDistance(template, templateQuery);
if (kClosestTemplates.size() < k || distance < maxDistance) {
kClosestTemplates.put(distance, template);
if (kClosestTemplates.size() > k) {
......@@ -31,7 +33,7 @@ public class ClassifierKNNMV extends Classifier {
}
}
Iterator<Double> iterator = kClosestTemplates.keySet().iterator();
maxDistance = 0.0;
maxDistance = 0;
while (iterator.hasNext()) {
Double key = iterator.next();
if (maxDistance < key) {
......
package algorithms;
import java.io.Serializable;
import objects.Template;
public abstract class DistanceFunction implements Serializable {
public DistanceFunction() {
}
public abstract double getDistance(Template template1, Template template2);
}
package algorithms;
import objects.Template;
public final class DistanceFunctionBaseline extends DistanceFunction {
private PoseDistance poseDistance;
public DistanceFunctionBaseline() {
}
public void setPoseDistance(PoseDistance poseDistance) {
this.poseDistance = poseDistance;
}
@Override
public double getDistance(Template sample1, Template sample2) {
int length = sample1.getLength();
if (length != sample2.getLength()) {
System.out.println("different lengths");
sample2.adjust(length);
}
double distance = 0f;
for (int l = 0; l < length; l++) {
distance += Math.abs(poseDistance.getDistance(sample1.getPose(l), sample2.getPose(l)));
}
return distance;
}
}
package algorithms;
import objects.Template;
public final class DistanceFunctionDTW extends DistanceFunction {
private PoseDistance poseDistance;
public DistanceFunctionDTW() {
}
public void setPoseDistance(PoseDistance poseDistance) {
this.poseDistance = poseDistance;
}
@Override
public double getDistance(Template sample1, Template sample2) {
int n = sample1.getLength(), m = sample2.getLength();
double cm[][] = new double[n][m]; // accum matrix
// create the accum matrix
cm[0][0] = poseDistance.getDistance(sample1.getPose(0), sample2.getPose(0));
for (int i = 1; i < n; i++) {
cm[i][0] = poseDistance.getDistance(sample1.getPose(i), sample2.getPose(0)) + cm[i - 1][0];
}
for (int j = 1; j < m; j++) {
cm[0][j] = poseDistance.getDistance(sample1.getPose(0), sample2.getPose(j)) + cm[0][j - 1];
}
// Compute the matrix values
for (int i = 1; i < n; i++) {
for (int j = 1; j < m; j++) {
// Decide on the path with minimum distance so far
cm[i][j] = poseDistance.getDistance(sample1.getPose(i), sample2.getPose(j)) + Math.min(cm[i - 1][j], Math.min(cm[i][j - 1], cm[i - 1][j - 1]));
}
}
return cm[n - 1][m - 1];
}
}
package algorithms;
import java.util.Random;
import objects.Template;
public final class DistanceFunctionRandom extends DistanceFunction {
public DistanceFunctionRandom() {
}
@Override
public double getDistance(Template template1, Template template2) {
return new Random().nextDouble();
}
}
package algorithms;
import java.io.Serializable;
public abstract class PoseDistance implements Serializable {
public PoseDistance() {
}
public abstract double getDistance(double[] pose1, double[] pose2);
}
package algorithms;
public final class PoseDistanceL1 extends PoseDistance {
public PoseDistanceL1() {
}
@Override
public double getDistance(double[] pose1, double[] pose2) {
float distance = 0f;
for (int d = 0; d < pose1.length; d++) {
distance += Math.abs(pose1[d] - pose2[d]);
}
return distance;
}
}
package algorithms;
import Jama.Matrix;
public final class PoseDistanceMahalanobis extends PoseDistance {
Matrix matrix;
public PoseDistanceMahalanobis() {
}
public void setMatrix(Matrix matrix) {
this.matrix = matrix;
}
@Override
public double getDistance(double[] pose1, double[] pose2) {
int dimension = pose1.length;
if (pose2.length != dimension || matrix.getColumnDimension() != dimension || matrix.getRowDimension() != dimension) {
System.out.println("wrong dimensions");
}
Matrix diff = new Matrix(pose1,pose1.length).minus(new Matrix(pose2,pose2.length));
return Math.sqrt(diff.transpose().times(matrix).times(diff).get(0, 0));
}
}
......@@ -2,22 +2,32 @@ package algorithms;
import java.io.Serializable;
import java.util.List;
import objects.DistanceMatrix;
import objects.Template;
public abstract class Retriever implements Serializable {
private DistanceFunction distanceFunction;
private DistanceTemplates distanceFunction;
private DistanceMatrix distanceMatrix;
public Retriever() {
}
public DistanceFunction getDistanceFunction() {
public DistanceTemplates getDistanceFunction() {
return distanceFunction;
}
public void setDistanceFunction(DistanceFunction distanceFunction) {
public void setDistanceFunction(DistanceTemplates distanceFunction) {
this.distanceFunction = distanceFunction;
}
public DistanceMatrix getDistanceMatrix() {
return distanceMatrix;
}
public void setDistanceMatrix(DistanceMatrix distanceMatrix) {
this.distanceMatrix = distanceMatrix;
}
public abstract List<Template> retrieve(Template sampleTest, List<Template> samplesGallery);
}
......@@ -5,6 +5,7 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import objects.DistanceMatrix;
import objects.Template;
public class RetrieverKNNQuery extends Retriever {
......@@ -17,11 +18,12 @@ public class RetrieverKNNQuery extends Retriever {
@Override
public List<Template> retrieve(Template templateTest, List<Template> templatesGallery) {
DistanceFunction distanceFunction = getDistanceFunction();
DistanceTemplates distanceFunction = getDistanceFunction();
DistanceMatrix distanceMatrix = getDistanceMatrix();
Map<Double, Template> kClosestTemplates = new HashMap(k);
double maxDistance = 0.0;
double maxDistance = 0;
for (Template template : templatesGallery) {
double distance = distanceFunction.getDistance(template, templateTest);
double distance = getDistanceMatrix() == null ? distanceFunction.getDistance(template, templateTest) : distanceMatrix.getDistance(template, templateTest);
if (kClosestTemplates.size() < k || distance < maxDistance) {
kClosestTemplates.put(distance, template);
if (kClosestTemplates.size() > k) {
......@@ -29,7 +31,7 @@ public class RetrieverKNNQuery extends Retriever {
}
}
Iterator<Double> iterator = kClosestTemplates.keySet().iterator();
maxDistance = 0.0;
maxDistance = 0;
while (iterator.hasNext()) {
Double key = iterator.next();
if (maxDistance < key) {
......
......@@ -2,6 +2,7 @@ package algorithms;
import java.util.ArrayList;
import java.util.List;
import objects.DistanceMatrix;
import objects.Template;
public class RetrieverRangeQuery extends Retriever {
......@@ -17,10 +18,11 @@ public class RetrieverRangeQuery extends Retriever {
@Override
public List<Template> retrieve(Template templateTest, List<Template> templatesGallery) {
DistanceFunction distanceFunction = getDistanceFunction();
DistanceTemplates distanceFunction = getDistanceFunction();
DistanceMatrix distanceMatrix = getDistanceMatrix();
List<Template> retrievedSamples = new ArrayList();
for (Template template : templatesGallery) {
double distance = distanceFunction.getDistance(template, templateTest);
for (Template template : templatesGallery) {
double distance = getDistanceMatrix() == null ? distanceFunction.getDistance(template, templateTest) : distanceMatrix.getDistance(template, templateTest);
if (distance <= r) {
retrievedSamples.add(template);
}
......
......@@ -64,7 +64,7 @@ public class Executor {
Select actions
*/
private static final boolean extractDatabase = false;
private static final boolean learnClassifiers = true;
private static final boolean learnClassifiers = false;
private static final boolean performClassifications = false;
private static final boolean evaluateMethods = false;
......@@ -128,15 +128,13 @@ public class Executor {
/*
Extracts classifiers for all methods learned on the sub-database determined by distanceThreshold
Classifiers are saved as separate files {distanceThreshold}-{methodName}.classifier in classifiers/ directory
Classifiers are saved as separate files {methodName}-{distanceThreshold}.classifier in classifiers/ directory
*/
private static void learnClassifiers() throws IOException {
new File("classifiers/").mkdirs();
for (Method method : methods) {
long beginning = System.currentTimeMillis();
method.learnClassifier(method.loadMotions(new File("amc" + distanceThreshold + "/")));
method.saveClassifier(new File("classifiers/" + method.getName() + "-" + distanceThreshold + ".classifier"));
System.out.print("time" + method + "=" + printTime(beginning) + "\r\n");
}
}
......@@ -149,7 +147,7 @@ public class Executor {
method.loadClassifier(customClassifier);
Template templateQuery = method.extractTemplate(method.loadMotion(customQueryFileAMC));
List<Template> templatesGallery = method.extractTemplates(method.loadMotions(customGalleryDirectory));
System.out.print(method.getName() + ": query subject classified as " + method.recognize(templateQuery, templatesGallery) + "\r\n");
System.out.print(method.getName() + ": query subject classified as " + method.classify(templateQuery, templatesGallery) + "\r\n");
}
}
......@@ -169,10 +167,10 @@ public class Executor {
int nClasses = numberOfClasses; // number of learning and evaluation classes
int nFoldsLearning = 3; // number of learning folds
int attempts = 1; // number of evaluation attempts
long beginningHomogeneous = System.currentTimeMillis();
long beginning = System.currentTimeMillis();
// System.out.print("i=" + i + ",");//
method.evaluateHomogeneous(motions, nClasses, nFoldsLearning, attempts);
System.out.print("timeHomogeneous=" + printTime(beginningHomogeneous) + "\r\n\r\n");
System.out.print("timeHomogeneous=" + printTime(System.currentTimeMillis() - beginning) + "\r\n\r\n");
// }//
}
if (evaluateHeterogeneous) {
......@@ -180,14 +178,14 @@ public class Executor {
int nClassesLearning = 2; // number of learning classes
int nClassesEvaluation = numberOfClasses - 2; // number of evaluation classes
int attempts = 1; // number of evaluation attempts
long beginningHeterogeneous = System.currentTimeMillis();
long beginning = System.currentTimeMillis();
// System.out.print("i=" + i + ",");//
method.evaluateHeterogeneous(motions, nClassesLearning, nClassesEvaluation, attempts);
System.out.print("timeHeterogeneous=" + printTime(beginningHeterogeneous) + "\r\n\r\n");
System.out.print("timeHeterogeneous=" + printTime(System.currentTimeMillis() - beginning) + "\r\n\r\n");
// }//
}
}
System.out.print("timeTotal=" + printTime(beginningTotal));
System.out.print("timeTotal=" + printTime(System.currentTimeMillis() - beginningTotal));
}
private static void normalizeDatabase() throws IOException {
......@@ -213,7 +211,7 @@ public class Executor {
MotionJointCoordinates motionProto = convertASFAMCtoMJC(prototypicalFileASF, prototypicalFileAMC);
Feature featureProto = motionProto.extractFeetDistanceSideSensitiveFeature();
int lengthProto = featureProto.getLength();
double sigma = 2.0;
double sigma = 2;
int minLength = (int) Math.round(lengthProto / sigma);
int maxLength = (int) Math.round(lengthProto * sigma);
List<Motion> gaitCycles = new ArrayList();
......@@ -234,7 +232,7 @@ public class Executor {
subFeature.adjust(lengthProto);
double[] subFeatureParameterValues = subFeature.getParameterValues();
double[] featureProtoParameterValues = featureProto.getParameterValues();
double distance = 0.0;
double distance = 0;
for (int l = 0; l < lengthProto; l++) {
distance += Math.abs(subFeatureParameterValues[l] - featureProtoParameterValues[l]);
}
......@@ -395,8 +393,7 @@ public class Executor {
return new MotionBoneRotations(subject, id, frames);
}
private static String printTime(long beginning) {
double time = System.currentTimeMillis() - beginning;
public static String printTime(double time) {
String unit = "ms";
if (time > 1000) {
time /= 1000;
......
This diff is collapsed.
package methods;
import algorithms.ClassifierKNNMV;
import algorithms.DistanceFunctionBaseline;
import algorithms.PoseDistanceL1;
import algorithms.Classifier1NN;
import algorithms.DistanceTemplatesBaseline;
import algorithms.DistanceVectorsL1;
import executor.Executor;
import java.io.File;
import java.io.IOException;
......@@ -26,10 +26,10 @@ public class MethodAhmed extends Method {
@Override
public void learnClassifier(List<Motion> motionsLearning) {
ClassifierKNNMV classifier = new ClassifierKNNMV(1);
DistanceFunctionBaseline distanceFunction = new DistanceFunctionBaseline();
distanceFunction.setPoseDistance(new PoseDistanceL1());
classifier.setDistanceFunction(distanceFunction);
Classifier1NN classifier = new Classifier1NN();
DistanceTemplatesBaseline distanceTemplates = new DistanceTemplatesBaseline();
distanceTemplates.setDistanceVectors(new DistanceVectorsL1());
classifier.setDistanceTemplates(distanceTemplates);
setClassifier(classifier);
}
......
package methods;
import algorithms.ClassifierKNNMV;
import algorithms.DistanceFunctionBaseline;
import algorithms.PoseDistanceL1;
import algorithms.Classifier1NN;
import algorithms.DistanceTemplatesBaseline;
import algorithms.DistanceVectorsL1;
import executor.Executor;
import java.io.File;
import java.io.IOException;
......@@ -25,10 +25,10 @@ public class MethodAli extends Method {
@Override
public void learnClassifier(List<Motion> motionsLearning) {
ClassifierKNNMV classifier = new ClassifierKNNMV(1);
DistanceFunctionBaseline distanceFunction = new DistanceFunctionBaseline();
distanceFunction.setPoseDistance(new PoseDistanceL1());
classifier.setDistanceFunction(distanceFunction);
Classifier1NN classifier = new Classifier1NN();
DistanceTemplatesBaseline distanceTemplates = new DistanceTemplatesBaseline();
distanceTemplates.setDistanceVectors(new DistanceVectorsL1());
classifier.setDistanceTemplates(distanceTemplates);
setClassifier(classifier);
}
......
package methods;
import algorithms.ClassifierKNNMV;
import algorithms.DistanceFunctionBaseline;
import algorithms.PoseDistanceL1;
import algorithms.Classifier1NN;
import algorithms.DistanceTemplatesBaseline;
import algorithms.DistanceVectorsL1;
import executor.Executor;
import java.io.File;
import java.io.IOException;
......@@ -27,10 +27,10 @@ public class MethodAndersson extends Method {
@Override
public void learnClassifier(List<Motion> motionsLearning) {
ClassifierKNNMV classifier = new ClassifierKNNMV(1);
DistanceFunctionBaseline distanceFunction = new DistanceFunctionBaseline();
distanceFunction.setPoseDistance(new PoseDistanceL1());
classifier.setDistanceFunction(distanceFunction);
Classifier1NN classifier = new Classifier1NN();
DistanceTemplatesBaseline distanceTemplates = new DistanceTemplatesBaseline();
distanceTemplates.setDistanceVectors(new DistanceVectorsL1());
classifier.setDistanceTemplates(distanceTemplates);
setClassifier(classifier);
}
......@@ -59,7 +59,7 @@ public class MethodAndersson extends Method {
}
}
List<Parameter> angleLocalMaxs = new ArrayList();
double maxValue = 0.0;
double maxValue = 0;
begin = 0;
Parameter angleLocalMax = angleFeature.getParameter(begin);
for (int l = begin; l < begin + length / 2; l++) {
......
package methods;
import algorithms.ClassifierKNNMV;
import algorithms.DistanceFunctionBaseline;
import algorithms.PoseDistanceL1;
import algorithms.Classifier1NN;
import algorithms.DistanceTemplatesBaseline;
import algorithms.DistanceVectorsL1;
import executor.Executor;
import java.io.File;
import java.io.IOException;
......@@ -26,10 +26,10 @@ public class MethodBall extends Method {
@Override
public void learnClassifier(List<Motion> motionsLearning) {
ClassifierKNNMV classifier = new ClassifierKNNMV(1);
DistanceFunctionBaseline distanceFunction = new DistanceFunctionBaseline();
distanceFunction.setPoseDistance(new PoseDistanceL1());
classifier.setDistanceFunction(distanceFunction);
Classifier1NN classifier = new Classifier1NN();
DistanceTemplatesBaseline distanceTemplates = new DistanceTemplatesBaseline();
distanceTemplates.setDistanceVectors(new DistanceVectorsL1());
classifier.setDistanceTemplates(distanceTemplates);
setClassifier(classifier);
}
......
package methods;
import algorithms.ClassifierKNNMV;
import algorithms.DistanceFunctionBaseline;
import algorithms.PoseDistanceL1;
import algorithms.Classifier1NN;
import algorithms.DistanceTemplatesBaseline;
import algorithms.DistanceVectorsL1;
import executor.Executor;
import java.io.File;
import java.io.IOException;
......@@ -26,10 +26,10 @@ public class MethodDikovski extends Method {
@Override
public void learnClassifier(List<Motion> motionsLearning) {
ClassifierKNNMV classifier = new ClassifierKNNMV(1);
DistanceFunctionBaseline distanceFunction = new DistanceFunctionBaseline();
distanceFunction.setPoseDistance(new PoseDistanceL1());
classifier.setDistanceFunction(distanceFunction);
Classifier1NN classifier = new Classifier1NN();
DistanceTemplatesBaseline distanceTemplates = new DistanceTemplatesBaseline();
distanceTemplates.setDistanceVectors(new DistanceVectorsL1());
classifier.setDistanceTemplates(distanceTemplates);
setClassifier(classifier);
}
......@@ -63,7 +63,7 @@ public class MethodDikovski extends Method {
features.add(new Feature(new Parameter("steplength", motionMJC.extractInterJointDistanceFeature("ltibia", "rtibia").getMean())));
// 1 height
double height = 0.0;
double height = 0;
height += motionMJC.extractInterJointDistanceFeature("ltibia", "lfemur").getMean() / 2;
height += motionMJC.extractInterJointDistanceFeature("rtibia", "rfemur").getMean() / 2;
height += motionMJC.extractInterJointDistanceFeature("lfemur", "lhipjoint").getMean() / 2;
......
......@@ -2,9 +2,9 @@ package methods;
import algorithms.Classifier;
import algorithms.ClassifierKNNMV;
import algorithms.DistanceFunctionDTW;
import algorithms.PoseDistance;
import algorithms.PoseDistanceL1;
import algorithms.DistanceTemplatesDTW;
import algorithms.DistanceVectors;
import algorithms.DistanceVectorsL1;
import algorithms.RetrieverKNNQuery;
import executor.Executor;
import java.io.File;
......@@ -26,21 +26,20 @@ public class MethodGavrilova extends Method {
return Executor.loadMotionJointCoordinates(fileAMC);
}
@Override
public void learnClassifier(List<Motion> motionsLearning) {
ClassifierGavrilova classifier = new ClassifierGavrilova();
DistanceFunctionDTW distanceFunctionJRD = new DistanceFunctionDTW();
DistanceFunctionDTW distanceFunctionJRA = new DistanceFunctionDTW();
PoseDistanceGavrilovaJRD poseDistanceJRD = new PoseDistanceGavrilovaJRD();
PoseDistanceGavrilovaJRA poseDistanceJRA = new PoseDistanceGavrilovaJRA();