Commit 6dff6935 authored by Michal Balazia's avatar Michal Balazia
Browse files

NEW Added all source code with libs

parents
package algorithms;
import Jama.Matrix;
import java.io.Serializable;
import java.util.List;
import objects.Template;
public abstract class Classifier implements Serializable {
private DistanceFunction distanceFunction;
private Matrix transformationMatrix;
public Classifier() {
}
public DistanceFunction getDistanceFunction() {
return distanceFunction;
}
public void setDistanceFunction(DistanceFunction distanceFunction) {
this.distanceFunction = distanceFunction;
}
public Matrix getTransformationMatrix() {
return transformationMatrix;
}
public void setTransformationMatrix(Matrix transformationMatrix) {
this.transformationMatrix = transformationMatrix;
}
public abstract int classify(Template templateQuery, List<Template> templatesGallery);
}
package algorithms;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import objects.Template;
public class ClassifierKNNMV extends Classifier {
private final int k;
public ClassifierKNNMV(int k) {
this.k = k;
}
@Override
public int classify(Template templateQuery, List<Template> templatesGallery) {
// get k closest descriptors
Map<Double, Template> kClosestTemplates = new HashMap(k);
DistanceFunction distanceFunction = getDistanceFunction();
double maxDistance = 0.0;
for (Template template : templatesGallery) {
double distance = distanceFunction.getDistance(templateQuery, template);
if (kClosestTemplates.size() < k || distance < maxDistance) {
kClosestTemplates.put(distance, template);
if (kClosestTemplates.size() > k) {
kClosestTemplates.remove(maxDistance);
}
}
Iterator<Double> iterator = kClosestTemplates.keySet().iterator();
maxDistance = 0.0;
while (iterator.hasNext()) {
Double key = iterator.next();
if (maxDistance < key) {
maxDistance = key;
}
}
}
// majority voting
List<List<Integer>> votes = new ArrayList();
Iterator<Double> iterator = kClosestTemplates.keySet().iterator();
while (iterator.hasNext()) {
Template template = kClosestTemplates.get(iterator.next());
boolean voteIsInVotes = false;
for (List<Integer> vote : votes) {
if (vote.get(0) == template.getSubject()) {
voteIsInVotes = true;
vote.set(1, vote.get(1) + 1);
break;
}
}
if (!voteIsInVotes) {
List<Integer> vote = new ArrayList();
vote.add(template.getSubject());
vote.add(1);
votes.add(vote);
}
}
int maxNumberVotes = 0;
int winSubject = 0;
for (List<Integer> vote : votes) {
int numberVotes = vote.get(1);
if (maxNumberVotes < numberVotes) {
maxNumberVotes = numberVotes;
winSubject = vote.get(0);
}
}
return winSubject;
}
}
package algorithms;
import java.util.List;
import java.util.Random;
import objects.Template;
public class ClassifierRandom extends Classifier {
public ClassifierRandom() {
}
@Override
public int classify(Template templateQuery, List<Template> templatesGallery) {
return templatesGallery.get(new Random().nextInt(templatesGallery.size())).getSubject();
}
}
package algorithms;
import java.io.Serializable;
import objects.Template;
public abstract class DistanceFunction implements Serializable {
public DistanceFunction() {
}
public abstract double getDistance(Template template1, Template template2);
}
package algorithms;
import objects.Template;
public final class DistanceFunctionBaseline extends DistanceFunction {
private PoseDistance poseDistance;
public DistanceFunctionBaseline() {
}
public void setPoseDistance(PoseDistance poseDistance) {
this.poseDistance = poseDistance;
}
@Override
public double getDistance(Template sample1, Template sample2) {
int length = sample1.getLength();
if (length != sample2.getLength()) {
System.out.println("different lengths");
sample2.adjust(length);
}
double distance = 0f;
for (int l = 0; l < length; l++) {
distance += Math.abs(poseDistance.getDistance(sample1.getPose(l), sample2.getPose(l)));
}
return distance;
}
}
package algorithms;
import objects.Template;
public final class DistanceFunctionDTW extends DistanceFunction {
private PoseDistance poseDistance;
public DistanceFunctionDTW() {
}
public void setPoseDistance(PoseDistance poseDistance) {
this.poseDistance = poseDistance;
}
@Override
public double getDistance(Template sample1, Template sample2) {
int n = sample1.getLength(), m = sample2.getLength();
double cm[][] = new double[n][m]; // accum matrix
// create the accum matrix
cm[0][0] = poseDistance.getDistance(sample1.getPose(0), sample2.getPose(0));
for (int i = 1; i < n; i++) {
cm[i][0] = poseDistance.getDistance(sample1.getPose(i), sample2.getPose(0)) + cm[i - 1][0];
}
for (int j = 1; j < m; j++) {
cm[0][j] = poseDistance.getDistance(sample1.getPose(0), sample2.getPose(j)) + cm[0][j - 1];
}
// Compute the matrix values
for (int i = 1; i < n; i++) {
for (int j = 1; j < m; j++) {
// Decide on the path with minimum distance so far
cm[i][j] = poseDistance.getDistance(sample1.getPose(i), sample2.getPose(j)) + Math.min(cm[i - 1][j], Math.min(cm[i][j - 1], cm[i - 1][j - 1]));
}
}
return cm[n - 1][m - 1];
}
}
package algorithms;
import java.util.Random;
import objects.Template;
public final class DistanceFunctionRandom extends DistanceFunction {
public DistanceFunctionRandom() {
}
@Override
public double getDistance(Template template1, Template template2) {
return new Random().nextDouble();
}
}
package algorithms;
import java.io.Serializable;
public abstract class PoseDistance implements Serializable {
public PoseDistance() {
}
public abstract double getDistance(double[] pose1, double[] pose2);
}
package algorithms;
public final class PoseDistanceL1 extends PoseDistance {
public PoseDistanceL1() {
}
@Override
public double getDistance(double[] pose1, double[] pose2) {
float distance = 0f;
for (int d = 0; d < pose1.length; d++) {
distance += Math.abs(pose1[d] - pose2[d]);
}
return distance;
}
}
package algorithms;
import Jama.Matrix;
public final class PoseDistanceMahalanobis extends PoseDistance {
Matrix matrix;
public PoseDistanceMahalanobis() {
}
public void setMatrix(Matrix matrix) {
this.matrix = matrix;
}
@Override
public double getDistance(double[] pose1, double[] pose2) {
int dimension = pose1.length;
if (pose2.length != dimension || matrix.getColumnDimension() != dimension || matrix.getRowDimension() != dimension) {
System.out.println("wrong dimensions");
}
Matrix diff = new Matrix(pose1,pose1.length).minus(new Matrix(pose2,pose2.length));
return Math.sqrt(diff.transpose().times(matrix).times(diff).get(0, 0));
}
}
package algorithms;
import java.io.Serializable;
import java.util.List;
import objects.Template;
public abstract class Retriever implements Serializable {
private DistanceFunction distanceFunction;
public Retriever() {
}
public DistanceFunction getDistanceFunction() {
return distanceFunction;
}
public void setDistanceFunction(DistanceFunction distanceFunction) {
this.distanceFunction = distanceFunction;
}
public abstract List<Template> retrieve(Template sampleTest, List<Template> samplesGallery);
}
package algorithms;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import objects.Template;
public class RetrieverKNNQuery extends Retriever {
private final int k;
public RetrieverKNNQuery(int k) {
this.k = k;
}
@Override
public List<Template> retrieve(Template templateTest, List<Template> templatesGallery) {
DistanceFunction distanceFunction = getDistanceFunction();
Map<Double, Template> kClosestTemplates = new HashMap(k);
double maxDistance = 0.0;
for (Template template : templatesGallery) {
double distance = distanceFunction.getDistance(template, templateTest);
if (kClosestTemplates.size() < k || distance < maxDistance) {
kClosestTemplates.put(distance, template);
if (kClosestTemplates.size() > k) {
kClosestTemplates.remove(maxDistance);
}
}
Iterator<Double> iterator = kClosestTemplates.keySet().iterator();
maxDistance = 0.0;
while (iterator.hasNext()) {
Double key = iterator.next();
if (maxDistance < key) {
maxDistance = key;
}
}
}
List<Template> templatesRetrieved = new ArrayList();
Iterator<Double> iterator = kClosestTemplates.keySet().iterator();
while (iterator.hasNext()) {
templatesRetrieved.add(kClosestTemplates.get(iterator.next()));
}
return templatesRetrieved;
}
}
package algorithms;
import java.util.ArrayList;
import java.util.List;
import objects.Template;
public class RetrieverRangeQuery extends Retriever {
private double r;
public RetrieverRangeQuery() {
}
public void setR(double r) {
this.r = r;
}
@Override
public List<Template> retrieve(Template templateTest, List<Template> templatesGallery) {
DistanceFunction distanceFunction = getDistanceFunction();
List<Template> retrievedSamples = new ArrayList();
for (Template template : templatesGallery) {
double distance = distanceFunction.getDistance(template, templateTest);
if (distance <= r) {
retrievedSamples.add(template);
}
}
return retrievedSamples;
}
}
package executor;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import objects.FrameJointCoordinates;
import objects.MotionJointCoordinates;
import objects.FrameBoneRotations;
import objects.MotionBoneRotations;
import objects.Joint;
import objects.Trajectory;
import objects.Bone;
import objects.Feature;
import objects.Motion;
import objects.Skeleton;
import objects.Trajectory3d;
import objects.Triple;
import methods.Method;
import methods.MethodAhmed;
import methods.MethodAli;
import methods.MethodAndersson;
import methods.MethodBall;
import methods.MethodDikovski;
import methods.MethodGavrilova;
import methods.MethodJiang;
import methods.MethodKrzeszowski;
import methods.MethodKumar;
import methods.MethodKwolek;
import methods.MethodPreis;
import methods.MethodSedmidubsky;
import methods.MethodSinha;
import methods.Method_MMC;
import methods.Method_PCALDA;
import methods.Method_Random;
import methods.Method_Raw;
import objects.ClassifiableList;
import objects.Parameter;
import objects.Template;
public class Executor {
/*
File locations
*/
private static final File prototypicalFileASF = new File("skeleton.asf");
private static final File prototypicalFileAMC = new File("gaitcycle.amc");
private static final File customQueryFileAMC = new File("query.amc");
private static final File customGalleryDirectory = new File("gallery/");
private static final File customClassifier = new File("classifiers/_MMC-302.0.classifier");
/*
Select distance threshold
Each value refers to a particular sub-database
*/
private static final double distanceThreshold = 302.0; // 56.3(2,35), 59.4(4,67), 63.3(8,130), 73.7(16,302), 173.3(32,2047), 302.0(54,3843), 495.3(64,5923)
/*
Select actions
*/
private static final boolean extractDatabase = false;
private static final boolean learnClassifiers = true;
private static final boolean performClassifications = false;
private static final boolean evaluateMethods = false;
/*
Select evaluation set-ups at evaluateMethods()
*/
private static final boolean evaluateHomogeneous = true;
private static final boolean evaluateHeterogeneous = false;
/*
All implemented MoCap-based gait recognition methods
*/
private static final Method[] methods = new Method[]{
new MethodAhmed(),//0 OK
new MethodAli(),//1 OK
new MethodAndersson(),//2 OK
new MethodBall(),//3 OK
new MethodDikovski(),//4 OK
new MethodGavrilova(),//5 SLOW
new MethodJiang(),//6 SLOW
new MethodKrzeszowski(),//7 SLOW
new MethodKumar(),//8 SLOW
new MethodKwolek(),//9 OK
new MethodPreis(),//10 OK
new MethodSedmidubsky(),//11 SLOW
new MethodSinha(),//12 OK
new Method_MMC(),//13 OK
new Method_PCALDA(),//14 OK
new Method_Random(),//15 OK
new Method_Raw(),//16 SLOW
};
/*
Main method
Executes the actions selected above
*/
public static void main(String[] args) throws IOException, ClassNotFoundException {
if (extractDatabase) {
extractDatabase();
}
if (learnClassifiers) {
learnClassifiers();
}
if (performClassifications) {
performClassifications();
}
if (evaluateMethods) {
evaluateMethods();
}
}
/*
Normalize the whole original CMU database in amcOriginal/ directory with respect to person’s position and walk direction and saves it to amc/ directory
Select gait cycles as sub-motions with similarity under selected distanceThreshold to prototypicalFileAMC and save them to separate AMC files {subject}_{id}[{from}-{to}].amc in amc{distanceThreshold}/ directory
Only subjects of at least 10 gait cycles are kept.
*/
private static void extractDatabase() throws IOException {
normalizeDatabase();
extractGaitCycles();
}
/*
Extracts classifiers for all methods learned on the sub-database determined by distanceThreshold
Classifiers are saved as separate files {distanceThreshold}-{methodName}.classifier in classifiers/ directory
*/
private static void learnClassifiers() throws IOException {
new File("classifiers/").mkdirs();
for (Method method : methods) {
long beginning = System.currentTimeMillis();
method.learnClassifier(method.loadMotions(new File("amc" + distanceThreshold + "/")));
method.saveClassifier(new File("classifiers/" + method.getName() + "-" + distanceThreshold + ".classifier"));
System.out.print("time" + method + "=" + printTime(beginning) + "\r\n");
}
}
/*
Performs classifications of customQueryFileAMC over customGalleryDirectory for all methods
Instead of learning classifiers methods load their extracted classifiers {distanceThreshold}-{methodName}.classifier in classifiers/ directory that are learned on the sub-database determined by distanceThreshold
*/
private static void performClassifications() throws IOException, ClassNotFoundException {
for (Method method : methods) {
method.loadClassifier(customClassifier);
Template templateQuery = method.extractTemplate(method.loadMotion(customQueryFileAMC));
List<Template> templatesGallery = method.extractTemplates(method.loadMotions(customGalleryDirectory));
System.out.print(method.getName() + ": query subject classified as " + method.recognize(templateQuery, templatesGallery) + "\r\n");
}
}
/*
Performs evaluations for all methods
*/
private static void evaluateMethods() throws IOException {
long beginningTotal = System.currentTimeMillis();
for (Method method : methods) {
System.out.print(method.getName() + "," + distanceThreshold + "\r\n");
List<Motion> motions = method.loadMotions(new File("amc" + distanceThreshold + "/"));
int numberOfClasses = new ClassifiableList(motions).splitClasses().size();
method.setNFoldsEvaluation(10); // number of folds in evaluation data
method.setFineness(30); // fineness of FAR/FRR, GAR/IAR, RCL/PCN
if (evaluateHomogeneous) {
// for (int i = 2; i <= 27; i++) {//method.getNumberOfClasses()
int nClasses = numberOfClasses; // number of learning and evaluation classes
int nFoldsLearning = 3; // number of learning folds
int attempts = 1; // number of evaluation attempts
long beginningHomogeneous = System.currentTimeMillis();
// System.out.print("i=" + i + ",");//
method.evaluateHomogeneous(motions, nClasses, nFoldsLearning, attempts);
System.out.print("timeHomogeneous=" + printTime(beginningHomogeneous) + "\r\n\r\n");
// }//
}
if (evaluateHeterogeneous) {
// for (int i = 2; i <= method.getNumberOfClasses() - 2; i++) {//method.getNumberOfClasses() - 2
int nClassesLearning = 2; // number of learning classes
int nClassesEvaluation = numberOfClasses - 2; // number of evaluation classes
int attempts = 1; // number of evaluation attempts
long beginningHeterogeneous = System.currentTimeMillis();
// System.out.print("i=" + i + ",");//
method.evaluateHeterogeneous(motions, nClassesLearning, nClassesEvaluation, attempts);
System.out.print("timeHeterogeneous=" + printTime(beginningHeterogeneous) + "\r\n\r\n");
// }//
}
}
System.out.print("timeTotal=" + printTime(beginningTotal));
}
private static void normalizeDatabase() throws IOException {
new File("amc/").mkdirs();
for (File fileAMC : new File("amcOriginal/").listFiles()) {
BufferedReader brAMC = new BufferedReader(new FileReader(fileAMC));
BufferedWriter wrAMC = new BufferedWriter(new FileWriter("amc/" + fileAMC.getName()));
String lineAMC = brAMC.readLine();
while (lineAMC != null) {
if (lineAMC.split(" ")[0].equals("root")) {
wrAMC.write("root 0 0 0 0 0 0\r\n");
} else {
wrAMC.write(lineAMC + "\r\n");
}
lineAMC = brAMC.readLine();
}
wrAMC.flush();
wrAMC.close();