Commit d4c74ee6 authored by Michal Balazia's avatar Michal Balazia
Browse files

change evaluation for class separability and classification metrics

parent 0ac9b0dd
......@@ -39,7 +39,9 @@ public abstract class Classifier implements Serializable {
this.transformationMatrix = transformationMatrix;
}
public abstract int classify(Template templateQuery, List<Template> templatesGallery);
public abstract void importTemplatesGallery(List<Template> templatesGallery);
public abstract String classify(Template templateProbe);
public double getDistance(Template template1, Template template2) {
return distanceMatrix == null ? distanceTemplates.getDistance(template1, template2) : distanceMatrix.getDistance(template1, template2);
......
......@@ -4,13 +4,20 @@ import java.util.List;
import objects.Template;
public class Classifier1NN extends Classifier {
private List<Template> templatesGallery;
@Override
public int classify(Template templateQuery, List<Template> templatesGallery) {
Template closestTemplate = new Template(0, "", null);
public void importTemplatesGallery(List<Template> templatesGallery) {
this.templatesGallery = templatesGallery;
}
@Override
public String classify(Template templateProbe) {
Template closestTemplate = templatesGallery.get(0);
double minDistance = Double.MAX_VALUE;
for (Template template : templatesGallery) {
double distance = getDistance(template, templateQuery);
double distance = getDistance(template, templateProbe);
if (minDistance > distance) {
minDistance = distance;
closestTemplate = template;
......
......@@ -7,15 +7,17 @@ import objects.Template;
public class ClassifierLinear extends Classifier {
private List<Template> centroids;
@Override
public int classify(Template templateQuery, List<Template> templatesGallery) {
List<Template> centroids = new ArrayList();
List<List<Template>> templatesByClass = new ClassifiableList(templatesGallery).splitClasses();
public void importTemplatesGallery(List<Template> templatesGallery) {
centroids = new ArrayList();
List<List<Template>> templatesByClass = new ClassifiableList(templatesGallery).splitBySubject();
for (List<Template> templatesOfClass : templatesByClass) {
Template centroid = new Template(0, "", null);
if (templatesOfClass.isEmpty()) {
System.out.println("empty class");
} else {
Template centroid = templatesOfClass.get(0);
double minSum = Double.MAX_VALUE;
for (Template centroidCandidate : templatesOfClass) {
double sum = 0;
......@@ -27,13 +29,17 @@ public class ClassifierLinear extends Classifier {
centroid = centroidCandidate;
}
}
centroids.add(centroid);
}
centroids.add(centroid);
}
Template closestCentroid = new Template(0, "", null);
}
@Override
public String classify(Template templateProbe) {
Template closestCentroid = centroids.get(0);
double minDistance = Double.MAX_VALUE;
for (Template centroid : centroids) {
double distance = getDistance(centroid, templateQuery);
double distance = getDistance(centroid, templateProbe);
if (minDistance > distance) {
minDistance = distance;
closestCentroid = centroid;
......
......@@ -5,12 +5,16 @@ import java.util.Random;
import objects.Template;
public class ClassifierRandom extends Classifier {
private List<Template> templatesGallery;
public ClassifierRandom() {
@Override
public void importTemplatesGallery(List<Template> templatesGallery) {
this.templatesGallery = templatesGallery;
}
@Override
public int classify(Template templateQuery, List<Template> templatesGallery) {
public String classify(Template templateProbe) {
return templatesGallery.get(new Random().nextInt(templatesGallery.size())).getSubject();
}
}
......@@ -2,9 +2,9 @@ package algorithms;
import java.io.Serializable;
public abstract class DistanceVectors implements Serializable {
public abstract class DistancePoses implements Serializable {
public DistanceVectors() {
public DistancePoses() {
}
public abstract double getDistance(double[] pose1, double[] pose2);
......
package algorithms;
public final class DistanceVectorsL1 extends DistanceVectors {
public final class DistancePosesL1 extends DistancePoses {
public DistanceVectorsL1() {
public DistancePosesL1() {
}
@Override
public double getDistance(double[] pose1, double[] pose2) {
if (pose1.length != pose2.length) {
System.out.println("wrong dimensions");
return 0;
}
double sum = 0;
for (int d = 0; d < pose1.length; d++) {
......
......@@ -4,13 +4,13 @@ import objects.Template;
public final class DistanceTemplatesDTW extends DistanceTemplates {
private DistanceVectors poseDistance;
private DistancePoses distancePoses;
public DistanceTemplatesDTW() {
}
public void setDistanceVectors(DistanceVectors poseDistance) {
this.poseDistance = poseDistance;
public void setDistancePoses(DistancePoses distancePoses) {
this.distancePoses = distancePoses;
}
@Override
......@@ -18,18 +18,18 @@ public final class DistanceTemplatesDTW extends DistanceTemplates {
int n = template1.getLength(), m = template2.getLength();
double cm[][] = new double[n][m]; // accum matrix
// create the accum matrix
cm[0][0] = poseDistance.getDistance(template1.getPose(0), template2.getPose(0));
cm[0][0] = distancePoses.getDistance(template1.getPose(0), template2.getPose(0));
for (int i = 1; i < n; i++) {
cm[i][0] = poseDistance.getDistance(template1.getPose(i), template2.getPose(0)) + cm[i - 1][0];
cm[i][0] = distancePoses.getDistance(template1.getPose(i), template2.getPose(0)) + cm[i - 1][0];
}
for (int j = 1; j < m; j++) {
cm[0][j] = poseDistance.getDistance(template1.getPose(0), template2.getPose(j)) + cm[0][j - 1];
cm[0][j] = distancePoses.getDistance(template1.getPose(0), template2.getPose(j)) + cm[0][j - 1];
}
// Compute the matrix values
for (int i = 1; i < n; i++) {
for (int j = 1; j < m; j++) {
// Decide on the path with minimum distance so far
cm[i][j] = poseDistance.getDistance(template1.getPose(i), template2.getPose(j)) + Math.min(cm[i - 1][j], Math.min(cm[i][j - 1], cm[i - 1][j - 1]));
cm[i][j] = distancePoses.getDistance(template1.getPose(i), template2.getPose(j)) + Math.min(cm[i - 1][j], Math.min(cm[i][j - 1], cm[i - 1][j - 1]));
}
}
return cm[n - 1][m - 1];
......
......@@ -2,15 +2,15 @@ package algorithms;
import objects.Template;
public final class DistanceTemplatesBaseline extends DistanceTemplates {
public final class DistanceTemplatesL1 extends DistanceTemplates {
private DistanceVectors poseDistance;
private DistancePoses distancePoses;
public DistanceTemplatesBaseline() {
public DistanceTemplatesL1() {
}
public void setDistanceVectors(DistanceVectors poseDistance) {
this.poseDistance = poseDistance;
public void setDistancePoses(DistancePoses distancePoses) {
this.distancePoses = distancePoses;
}
@Override
......@@ -18,11 +18,11 @@ public final class DistanceTemplatesBaseline extends DistanceTemplates {
int length = template1.getLength();
if (length != template2.getLength()) {
System.out.println("different lengths");
template2.adjust(length);
return 0;
}
double distance = 0;
for (int l = 0; l < length; l++) {
distance += Math.abs(poseDistance.getDistance(template1.getPose(l), template2.getPose(l)));
distance += Math.abs(distancePoses.getDistance(template1.getPose(l), template2.getPose(l)));
}
return distance;
}
......
package algorithms;
import Jama.Matrix;
import objects.Template;
public final class DistanceVectorsMahalanobis extends DistanceVectors {
public final class DistanceTemplatesMahalanobis extends DistanceTemplates {
Matrix matrix;
public DistanceVectorsMahalanobis() {
public DistanceTemplatesMahalanobis() {
}
public void setMatrix(Matrix matrix) {
......@@ -14,12 +15,8 @@ public final class DistanceVectorsMahalanobis extends DistanceVectors {
}
@Override
public double getDistance(double[] pose1, double[] pose2) {
int dimension = pose1.length;
if (pose2.length != dimension || matrix.getColumnDimension() != dimension || matrix.getRowDimension() != dimension) {
System.out.println("wrong dimensions");
}
Matrix diff = new Matrix(pose1,pose1.length).minus(new Matrix(pose2,pose2.length));
public double getDistance(Template template1, Template template2) {
Matrix diff = template1.getMatrix().minus(template2.getMatrix());
return Math.sqrt(diff.transpose().times(matrix).times(diff).get(0, 0));
}
}
package algorithms;
import java.util.Random;
import objects.Template;
public final class DistanceTemplatesRandom extends DistanceTemplates {
public DistanceTemplatesRandom() {
}
@Override
public double getDistance(Template template1, Template template2) {
return new Random().nextDouble();
}
}
package algorithms;
public final class DistanceVectorsL2 extends DistanceVectors {
public DistanceVectorsL2() {
}
@Override
public double getDistance(double[] pose1, double[] pose2) {
if (pose1.length != pose2.length) {
System.out.println("wrong dimensions");
}
double sum = 0;
for (int d = 0; d < pose1.length; d++) {
sum += Math.pow(pose1[d] - pose2[d], 2);
}
return Math.sqrt(sum);
}
}
package algorithms;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import objects.Motion;
public abstract class MotionLoader implements Serializable {
public MotionLoader() {
}
public abstract Motion loadMotion(File fileAMC) throws IOException;
}
package algorithms;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import objects.PoseBoneRotations;
import objects.Motion;
import objects.MotionBoneRotations;
public class MotionLoaderBoneRotations extends MotionLoader {
@Override
public Motion loadMotion(File fileAMC) throws IOException {
String[] excludedBones = new String[]{"root"};
List<PoseBoneRotations> poses = new ArrayList();
BufferedReader brAMC = new BufferedReader(new FileReader(fileAMC));
String lineAMC = brAMC.readLine();
while (!lineAMC.equals("1")) {
lineAMC = brAMC.readLine();
}
lineAMC = brAMC.readLine();
while (lineAMC != null) {
List<Double> boneRotationsList = new ArrayList();
while (lineAMC != null && lineAMC.split(" ").length != 1) {
String[] split = lineAMC.split(" ");
boolean notExcludedBone = true;
for (String excludedBone : excludedBones) {
if (split[0].equals(excludedBone)) {
notExcludedBone = false;
}
}
if (notExcludedBone) {
for (int i = 1; i < split.length; i++) {
boneRotationsList.add(Double.parseDouble(split[i]));
}
}
lineAMC = brAMC.readLine();
}
double[] boneRotations = new double[boneRotationsList.size()];
for (int i = 0; i < boneRotationsList.size(); i++) {
boneRotations[i] = boneRotationsList.get(i);
}
poses.add(new PoseBoneRotations(boneRotations));
lineAMC = brAMC.readLine();
}
return new MotionBoneRotations(fileAMC.getName().split("\\.")[0], poses);
}
}
package algorithms;
import executor.Constants;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import objects.Bone;
import objects.PoseJointCoordinates;
import objects.Joint;
import objects.MotionJointCoordinates;
import objects.Skeleton;
import objects.Trajectory;
import objects.Trajectory3d;
import objects.Triple;
public class MotionLoaderJointCoordinates extends MotionLoader {
@Override
public MotionJointCoordinates loadMotion(File fileAMC) throws IOException {
Skeleton skeleton = new Skeleton();
skeleton.readASF(Constants.prototypicalFileASF.getPath());
skeleton.readAMC(fileAMC.getPath());
List<Trajectory> trajectories = new ArrayList();
for (Bone bone : skeleton.getBones()) {
Trajectory3d trajectory3d = skeleton.getTrajectory3d(bone.getName());
List<Triple> positions3d = trajectory3d.getPositions();
List<Triple> positions = new ArrayList();
for (int l = 0; l < trajectory3d.getPositions().size(); l++) {
positions.add(positions3d.get(l));
}
trajectories.add(new Trajectory(trajectory3d.getName(), positions));
}
List<PoseJointCoordinates> poses = new ArrayList();
for (int l = 0; l < skeleton.getNumberOfPoses() - 1; l++) {
List<Joint> joints = new ArrayList();
joints.add(new Joint("root", skeleton.getRoot().getPositions().get(l)));
for (Bone bone : skeleton.getBones()) {
Trajectory trajectory = trajectories.get(0);
for (Trajectory traj : trajectories) {
if (traj.getName().equals(bone.getName())) {
trajectory = traj;
}
}
joints.add(new Joint(bone.getName(), trajectory.getPosition(l)));
}
poses.add(new PoseJointCoordinates(joints));
}
return new MotionJointCoordinates(fileAMC.getName().split("\\.")[0], poses);
}
}
package executor;
import java.io.File;
import objects.Axis;
public class Constants {
public static final File prototypicalFileASF = new File("skeleton.asf");
public static final File prototypicalFileAMC = new File("gaitcycle.amc");
public static final File customProbeFileAMC = new File("customProbe.amc");
public static final File customGalleryDirectory = new File("customGallery/");
public static final File customClassifier = new File("customClassifier.classifier");
public static final Axis[] axes = new Axis[]{new Axis("X"), new Axis("Y"), new Axis("Z")};
}
This diff is collapsed.
This diff is collapsed.
package methods;
import algorithms.Classifier1NN;
import algorithms.DistanceTemplatesBaseline;
import algorithms.DistanceVectorsL1;
import executor.Executor;
import java.io.File;
import java.io.IOException;
import algorithms.DistanceTemplatesL1;
import algorithms.DistancePosesL1;
import algorithms.MotionLoader;
import java.util.ArrayList;
import java.util.List;
import objects.Axis;
......@@ -16,19 +14,15 @@ import objects.Parameter;
public class MethodAhmed extends Method {
public MethodAhmed() {
public MethodAhmed(MotionLoader motionLoader) {
super(motionLoader);
}
@Override
public Motion loadMotion(File fileAMC) throws IOException {
return Executor.loadMotionJointCoordinates(fileAMC);
}
@Override
public void learnClassifier(List<Motion> motionsLearning) {
Classifier1NN classifier = new Classifier1NN();
DistanceTemplatesBaseline distanceTemplates = new DistanceTemplatesBaseline();
distanceTemplates.setDistanceVectors(new DistanceVectorsL1());
DistanceTemplatesL1 distanceTemplates = new DistanceTemplatesL1();
distanceTemplates.setDistancePoses(new DistancePosesL1());
classifier.setDistanceTemplates(distanceTemplates);
setClassifier(classifier);
}
......
package methods;
import algorithms.Classifier1NN;
import algorithms.DistanceTemplatesBaseline;
import algorithms.DistanceVectorsL1;
import executor.Executor;
import java.io.File;
import java.io.IOException;
import algorithms.DistanceTemplatesL1;
import algorithms.DistancePosesL1;
import algorithms.MotionLoader;
import java.util.ArrayList;
import java.util.List;
import objects.Feature;
......@@ -15,19 +13,15 @@ import objects.Parameter;
public class MethodAli extends Method {
public MethodAli() {
}
@Override
public Motion loadMotion(File fileAMC) throws IOException {
return Executor.loadMotionJointCoordinates(fileAMC);
public MethodAli(MotionLoader motionLoader) {
super(motionLoader);
}
@Override
public void learnClassifier(List<Motion> motionsLearning) {
Classifier1NN classifier = new Classifier1NN();
DistanceTemplatesBaseline distanceTemplates = new DistanceTemplatesBaseline();
distanceTemplates.setDistanceVectors(new DistanceVectorsL1());
DistanceTemplatesL1 distanceTemplates = new DistanceTemplatesL1();
distanceTemplates.setDistancePoses(new DistancePosesL1());
classifier.setDistanceTemplates(distanceTemplates);
setClassifier(classifier);
}
......
package methods;
import algorithms.Classifier1NN;
import algorithms.DistanceTemplatesBaseline;
import algorithms.DistanceVectorsL1;
import executor.Executor;
import java.io.File;
import java.io.IOException;
import algorithms.DistanceTemplatesL1;
import algorithms.DistancePosesL1;
import algorithms.MotionLoader;
import java.util.ArrayList;
import java.util.List;
import objects.Axis;
import objects.Feature;
import objects.FrameJointCoordinates;
import objects.PoseJointCoordinates;
import objects.Motion;
import objects.MotionJointCoordinates;
import objects.Parameter;
public class MethodAndersson extends Method {
public MethodAndersson() {
}
@Override
public Motion loadMotion(File fileAMC) throws IOException {
return Executor.loadMotionJointCoordinates(fileAMC);
public MethodAndersson(MotionLoader motionLoader) {
super(motionLoader);
}
@Override
public void learnClassifier(List<Motion> motionsLearning) {
Classifier1NN classifier = new Classifier1NN();
DistanceTemplatesBaseline distanceTemplates = new DistanceTemplatesBaseline();
distanceTemplates.setDistanceVectors(new DistanceVectorsL1());
DistanceTemplatesL1 distanceTemplates = new DistanceTemplatesL1();
distanceTemplates.setDistancePoses(new DistancePosesL1());
classifier.setDistanceTemplates(distanceTemplates);
setClassifier(classifier);
}
......@@ -43,7 +37,7 @@ public class MethodAndersson extends Method {
for (String[] angleString : angleStrings) {
Parameter[] angleParameters = new Parameter[length];
for (int l = 0; l < length; l++) {
angleParameters[l] = ((FrameJointCoordinates) motionMJC.getFrame(l)).extractBoneAxisAngleParameter(new Axis(angleString[0]), angleString[1], angleString[2]);
angleParameters[l] = ((PoseJointCoordinates) motionMJC.getPose(l)).extractBoneAxisAngleParameter(new Axis(angleString[0]), angleString[1], angleString[2]);
}
Feature angleFeature = new Feature(angleParameters);
List<Parameter> angleLocalMins = new ArrayList();
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment