Commit 07d6abac authored by Michal Balazia's avatar Michal Balazia
Browse files

closed/open set evaluation

parent 7ea2c48d
package algorithms;
import Jama.Matrix;
import java.io.Serializable;
import java.util.List;
import objects.DistanceMatrix;
import objects.Template;
public abstract class Classifier implements Serializable {
private DistanceTemplates distanceTemplates;
private DistanceMatrix distanceMatrix;
private Matrix transformationMatrix;
public Classifier() {
}
public DistanceTemplates getDistanceTemplates() {
return distanceTemplates;
}
public void setDistanceTemplates(DistanceTemplates distanceTemplates) {
this.distanceTemplates = distanceTemplates;
}
public DistanceMatrix getDistanceMatrix() {
return distanceMatrix;
}
public void setDistanceMatrix(DistanceMatrix distanceMatrix) {
this.distanceMatrix = distanceMatrix;
}
public Matrix getTransformationMatrix() {
return transformationMatrix;
}
public void setTransformationMatrix(Matrix transformationMatrix) {
this.transformationMatrix = transformationMatrix;
}
public abstract int classify(Template templateQuery, List<Template> templatesGallery);
public double getDistance(Template template1, Template template2) {
return distanceMatrix == null ? distanceTemplates.getDistance(template1, template2) : distanceMatrix.getDistance(template1, template2);
}
}
package algorithms;
import java.util.List;
import objects.Template;
public class Classifier1NN extends Classifier {
@Override
public int classify(Template templateQuery, List<Template> templatesGallery) {
Template closestTemplate = new Template(0, "", null);
double minDistance = Double.MAX_VALUE;
for (Template template : templatesGallery) {
double distance = getDistance(template, templateQuery);
if (minDistance > distance) {
minDistance = distance;
closestTemplate = template;
}
}
return closestTemplate.getSubject();
}
}
package algorithms;
import java.util.ArrayList;
import java.util.List;
import objects.ClassifiableList;
import objects.Template;
public class ClassifierLinear extends Classifier {
@Override
public int classify(Template templateQuery, List<Template> templatesGallery) {
List<Template> centroids = new ArrayList();
List<List<Template>> templatesByClass = new ClassifiableList(templatesGallery).splitClasses();
for (List<Template> templatesOfClass : templatesByClass) {
Template centroid = new Template(0, "", null);
if (templatesOfClass.isEmpty()) {
System.out.println("empty class");
} else {
double minSum = Double.MAX_VALUE;
for (Template centroidCandidate : templatesOfClass) {
double sum = 0;
for (Template template : templatesOfClass) {
sum += getDistance(template, centroidCandidate);
}
if (minSum > sum) {
minSum = sum;
centroid = centroidCandidate;
}
}
}
centroids.add(centroid);
}
Template closestCentroid = new Template(0, "", null);
double minDistance = Double.MAX_VALUE;
for (Template centroid : centroids) {
double distance = getDistance(centroid, templateQuery);
if (minDistance > distance) {
minDistance = distance;
closestCentroid = centroid;
}
}
return closestCentroid.getSubject();
}
}
package algorithms;
import java.util.List;
import java.util.Random;
import objects.Template;
public class ClassifierRandom extends Classifier {
public ClassifierRandom() {
}
@Override
public int classify(Template templateQuery, List<Template> templatesGallery) {
return templatesGallery.get(new Random().nextInt(templatesGallery.size())).getSubject();
}
}
package algorithms;
import java.io.Serializable;
import objects.Template;
public abstract class DistanceTemplates implements Serializable {
public DistanceTemplates() {
}
public abstract double getDistance(Template template1, Template template2);
}
package algorithms;
import objects.Template;
public final class DistanceTemplatesBaseline extends DistanceTemplates {
private DistanceVectors poseDistance;
public DistanceTemplatesBaseline() {
}
public void setDistanceVectors(DistanceVectors poseDistance) {
this.poseDistance = poseDistance;
}
@Override
public double getDistance(Template template1, Template template2) {
int length = template1.getLength();
if (length != template2.getLength()) {
System.out.println("different lengths");
template2.adjust(length);
}
double distance = 0;
for (int l = 0; l < length; l++) {
distance += Math.abs(poseDistance.getDistance(template1.getPose(l), template2.getPose(l)));
}
return distance;
}
}
package algorithms;
import objects.Template;
public final class DistanceTemplatesDTW extends DistanceTemplates {
private DistanceVectors poseDistance;
public DistanceTemplatesDTW() {
}
public void setDistanceVectors(DistanceVectors poseDistance) {
this.poseDistance = poseDistance;
}
@Override
public double getDistance(Template template1, Template template2) {
int n = template1.getLength(), m = template2.getLength();
double cm[][] = new double[n][m]; // accum matrix
// create the accum matrix
cm[0][0] = poseDistance.getDistance(template1.getPose(0), template2.getPose(0));
for (int i = 1; i < n; i++) {
cm[i][0] = poseDistance.getDistance(template1.getPose(i), template2.getPose(0)) + cm[i - 1][0];
}
for (int j = 1; j < m; j++) {
cm[0][j] = poseDistance.getDistance(template1.getPose(0), template2.getPose(j)) + cm[0][j - 1];
}
// Compute the matrix values
for (int i = 1; i < n; i++) {
for (int j = 1; j < m; j++) {
// Decide on the path with minimum distance so far
cm[i][j] = poseDistance.getDistance(template1.getPose(i), template2.getPose(j)) + Math.min(cm[i - 1][j], Math.min(cm[i][j - 1], cm[i - 1][j - 1]));
}
}
return cm[n - 1][m - 1];
}
}
package algorithms;
import java.util.Random;
import objects.Template;
public final class DistanceTemplatesRandom extends DistanceTemplates {
public DistanceTemplatesRandom() {
}
@Override
public double getDistance(Template template1, Template template2) {
return new Random().nextDouble();
}
}
package algorithms;
import java.io.Serializable;
public abstract class DistanceVectors implements Serializable {
public DistanceVectors() {
}
public abstract double getDistance(double[] pose1, double[] pose2);
}
package algorithms;
public final class DistanceVectorsL1 extends DistanceVectors {
public DistanceVectorsL1() {
}
@Override
public double getDistance(double[] pose1, double[] pose2) {
if (pose1.length != pose2.length) {
System.out.println("wrong dimensions");
}
double sum = 0;
for (int d = 0; d < pose1.length; d++) {
sum += Math.abs(pose1[d] - pose2[d]);
}
return sum;
}
}
package algorithms;
public final class DistanceVectorsL2 extends DistanceVectors {
public DistanceVectorsL2() {
}
@Override
public double getDistance(double[] pose1, double[] pose2) {
if (pose1.length != pose2.length) {
System.out.println("wrong dimensions");
}
double sum = 0;
for (int d = 0; d < pose1.length; d++) {
sum += Math.pow(pose1[d] - pose2[d], 2);
}
return Math.sqrt(sum);
}
}
package algorithms;
import Jama.Matrix;
public final class DistanceVectorsMahalanobis extends DistanceVectors {
Matrix matrix;
public DistanceVectorsMahalanobis() {
}
public void setMatrix(Matrix matrix) {
this.matrix = matrix;
}
@Override
public double getDistance(double[] pose1, double[] pose2) {
int dimension = pose1.length;
if (pose2.length != dimension || matrix.getColumnDimension() != dimension || matrix.getRowDimension() != dimension) {
System.out.println("wrong dimensions");
}
Matrix diff = new Matrix(pose1,pose1.length).minus(new Matrix(pose2,pose2.length));
return Math.sqrt(diff.transpose().times(matrix).times(diff).get(0, 0));
}
}
package algorithms;
import java.io.Serializable;
import java.util.List;
import objects.DistanceMatrix;
import objects.Template;
public abstract class Retriever implements Serializable {
private DistanceTemplates distanceTemplates;
private DistanceMatrix distanceMatrix;
public Retriever() {
}
public void setDistanceTemplates(DistanceTemplates distanceTemplates) {
this.distanceTemplates = distanceTemplates;
}
public void setDistanceMatrix(DistanceMatrix distanceMatrix) {
this.distanceMatrix = distanceMatrix;
}
public abstract List<Template> retrieve(Template sampleTest, List<Template> samplesGallery);
public double getDistance(Template template1, Template template2) {
return distanceMatrix == null ? distanceTemplates.getDistance(template1, template2) : distanceMatrix.getDistance(template1, template2);
}
}
package algorithms;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import objects.Template;
public class RetrieverKNNQuery extends Retriever {
private final int k;
public RetrieverKNNQuery(int k) {
this.k = k;
}
@Override
public List<Template> retrieve(Template templateTest, List<Template> templatesGallery) {
Map<Double, Template> kClosestTemplates = new HashMap(k);
double maxDistance = 0;
for (Template template : templatesGallery) {
double distance = getDistance(template, templateTest);
if (kClosestTemplates.size() < k || distance < maxDistance) {
kClosestTemplates.put(distance, template);
if (kClosestTemplates.size() > k) {
kClosestTemplates.remove(maxDistance);
}
}
Iterator<Double> iterator = kClosestTemplates.keySet().iterator();
maxDistance = 0;
while (iterator.hasNext()) {
Double key = iterator.next();
if (maxDistance < key) {
maxDistance = key;
}
}
}
List<Template> templatesRetrieved = new ArrayList();
Iterator<Double> iterator = kClosestTemplates.keySet().iterator();
while (iterator.hasNext()) {
templatesRetrieved.add(kClosestTemplates.get(iterator.next()));
}
return templatesRetrieved;
}
}
package algorithms;
import java.util.ArrayList;
import java.util.List;
import objects.Template;
public class RetrieverRangeQuery extends Retriever {
private double r;
public RetrieverRangeQuery() {
}
public void setR(double r) {
this.r = r;
}
@Override
public List<Template> retrieve(Template templateTest, List<Template> templatesGallery) {
List<Template> retrievedSamples = new ArrayList();
for (Template template : templatesGallery) {
if (getDistance(template, templateTest) <= r) {
retrievedSamples.add(template);
}
}
return retrievedSamples;
}
}
This diff is collapsed.
This diff is collapsed.
package methods;
import algorithms.Classifier1NN;
import algorithms.DistanceTemplatesBaseline;
import algorithms.DistanceVectorsL1;
import executor.Executor;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import objects.Axis;
import objects.Feature;
import objects.Motion;
import objects.MotionJointCoordinates;
import objects.Parameter;
public class MethodAhmed extends Method {
public MethodAhmed() {
}
@Override
public Motion loadMotion(File fileAMC) throws IOException {
return Executor.loadMotionJointCoordinates(fileAMC);
}
@Override
public void learnClassifier(List<Motion> motionsLearning) {
Classifier1NN classifier = new Classifier1NN();
DistanceTemplatesBaseline distanceTemplates = new DistanceTemplatesBaseline();
distanceTemplates.setDistanceVectors(new DistanceVectorsL1());
classifier.setDistanceTemplates(distanceTemplates);
setClassifier(classifier);
}
@Override
public List<Feature> extractFeatures(Motion motion) { // HDF and VDF
MotionJointCoordinates motionMJC = (MotionJointCoordinates) motion;
List<Feature> features = new ArrayList();
features.add(new Feature(new Parameter("meanHDF1", motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "ltibia", "rtibia").getMean())));
features.add(new Feature(new Parameter("meanHDF2", motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "lfemur", "rfemur").getMean())));
features.add(new Feature(new Parameter("meanHDF3", motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "lwrist", "rwrist").getMean())));
features.add(new Feature(new Parameter("meanHDF4", motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "lclavicle", "rclavicle").getMean())));
features.add(new Feature(new Parameter("stdHDF1", motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "ltibia", "rtibia").getMean())));
features.add(new Feature(new Parameter("stdHDF2", motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "lfemur", "rfemur").getMean())));
features.add(new Feature(new Parameter("stdHDF3", motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "lwrist", "rwrist").getMean())));
features.add(new Feature(new Parameter("stdHDF4", motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "lclavicle", "rclavicle").getMean())));
features.add(new Feature(new Parameter("skewHDF1", motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "ltibia", "rtibia").getMean())));
features.add(new Feature(new Parameter("skewHDF2", motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "lfemur", "rfemur").getMean())));
features.add(new Feature(new Parameter("skewHDF3", motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "lwrist", "rwrist").getMean())));
features.add(new Feature(new Parameter("skewHDF4", motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "lclavicle", "rclavicle").getMean())));
features.add(new Feature(new Parameter("meanVDF1", motionMJC.extractJointAxisCoordinateFeature(new Axis("Y"), "head").getMean())));
features.add(new Feature(new Parameter("meanVDF2", motionMJC.extractJointAxisCoordinateFeature(new Axis("Y"), "rwrist").getMean())));
features.add(new Feature(new Parameter("meanVDF3", motionMJC.extractJointAxisCoordinateFeature(new Axis("Y"), "rclavicle").getMean())));
features.add(new Feature(new Parameter("meanVDF4", motionMJC.extractJointAxisCoordinateFeature(new Axis("Y"), "rtibia").getMean())));
features.add(new Feature(new Parameter("meanVDF5", motionMJC.extractJointAxisCoordinateFeature(new Axis("Y"), "ltibia").getMean())));
features.add(new Feature(new Parameter("meanVDF6", 0.5 * motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "lfoot", "rfoot").getMean() * motionMJC.extractJointAxisCoordinateFeature(new Axis("Y"), "root").getMean())));
features.add(new Feature(new Parameter("stdVDF1", motionMJC.extractJointAxisCoordinateFeature(new Axis("Y"), "head").getStD())));
features.add(new Feature(new Parameter("stdVDF2", motionMJC.extractJointAxisCoordinateFeature(new Axis("Y"), "rwrist").getStD())));
features.add(new Feature(new Parameter("stdVDF3", motionMJC.extractJointAxisCoordinateFeature(new Axis("Y"), "rclavicle").getStD())));
features.add(new Feature(new Parameter("stdVDF4", motionMJC.extractJointAxisCoordinateFeature(new Axis("Y"), "rtibia").getStD())));
features.add(new Feature(new Parameter("stdVDF5", motionMJC.extractJointAxisCoordinateFeature(new Axis("Y"), "ltibia").getStD())));
features.add(new Feature(new Parameter("stdVDF6", 0.5 * motionMJC.extractJointAxisDistanceFeature(new Axis("Z"), "lfoot", "rfoot").getStD() * motionMJC.extractJointAxisCoordinateFeature(new Axis("Y"), "root").getStD())));
return features;
}
}
package methods;
import algorithms.Classifier1NN;
import algorithms.DistanceTemplatesBaseline;
import algorithms.DistanceVectorsL1;
import executor.Executor;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import objects.Feature;
import objects.Motion;
import objects.MotionJointCoordinates;
import objects.Parameter;
public class MethodAli extends Method {
public MethodAli() {
}
@Override
public Motion loadMotion(File fileAMC) throws IOException {
return Executor.loadMotionJointCoordinates(fileAMC);
}
@Override
public void learnClassifier(List<Motion> motionsLearning) {
Classifier1NN classifier = new Classifier1NN();
DistanceTemplatesBaseline distanceTemplates = new DistanceTemplatesBaseline();
distanceTemplates.setDistanceVectors(new DistanceVectorsL1());
classifier.setDistanceTemplates(distanceTemplates);
setClassifier(classifier);
}
@Override
public List<Feature> extractFeatures(Motion motion) { // mean of the hip-knee-ankle triangle areas
MotionJointCoordinates motionMJC = (MotionJointCoordinates) motion;
List<Feature> features = new ArrayList();
features.add(new Feature(new Parameter("left", motionMJC.extractTriangleAreaFeature("lhipjoint", "lfemur", "ltibia").getMean())));
features.add(new Feature(new Parameter("right", motionMJC.extractTriangleAreaFeature("rhipjoint", "rfemur", "rtibia").getMean())));
return features;
}
}
package methods;
import algorithms.Classifier1NN;
import algorithms.DistanceTemplatesBaseline;
import algorithms.DistanceVectorsL1;
import executor.Executor;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import objects.Axis;
import objects.Feature;
import objects.FrameJointCoordinates;
import objects.Motion;
import objects.MotionJointCoordinates;
import objects.Parameter;
public class MethodAndersson extends Method {
public MethodAndersson() {
}
@Override
public Motion loadMotion(File fileAMC) throws IOException {
return Executor.loadMotionJointCoordinates(fileAMC);
}
@Override
public void learnClassifier(List<Motion> motionsLearning) {
Classifier1NN classifier = new Classifier1NN();
DistanceTemplatesBaseline distanceTemplates = new DistanceTemplatesBaseline();
distanceTemplates.setDistanceVectors(new DistanceVectorsL1());
classifier.setDistanceTemplates(distanceTemplates);
setClassifier(classifier);
}
@Override
public List<Feature> extractFeatures(Motion motion) { //36 gait attributes: (mean and std) of ((min and max) of (4 pairs of (lower body angles))), step length (max of ltibia-rtibia), stride length (2x step length), cycle time (length / 120), velocity (stride length / cycle time); 32 anthropometric attributes: (mean and std) of bone lengths and height
MotionJointCoordinates motionMJC = (MotionJointCoordinates) motion;
int length = motionMJC.getLength();
List<Feature> features = new ArrayList();
String[][] angleStrings = new String[][]{{"Y", "lhipjoint", "lfemur"}, {"Y", "rhipjoint", "rfemur"}, {"Y", "lfemur", "ltibia"}, {"Y", "rfemur", "rtibia"}, {"Y", "ltibia", "lfoot"}, {"Y", "rtibia", "rfoot"}, {"Z", "ltibia", "lfoot"}, {"Z", "rtibia", "rfoot"}};
for (String[] angleString : angleStrings) {
Parameter[] angleParameters = new Parameter[length];
for (int l = 0; l < length; l++) {
angleParameters[l] = ((FrameJointCoordinates) motionMJC.getFrame(l)).extractBoneAxisAngleParameter(new Axis(angleString[0]), angleString[1], angleString[2]);
}
Feature angleFeature = new Feature(angleParameters);
List<Parameter> angleLocalMins = new ArrayList();
angleLocalMins.add(angleFeature.getParameter(0)); // 0%