Commit acab277e authored by Michal Balazia's avatar Michal Balazia
Browse files

distance map

parent 07d6abac
package algorithms; package algorithms;
import Jama.EigenvalueDecomposition; public class Classifier1NNNareshKumarMS extends Classifier {
import Jama.Matrix;
import objects.Template; public Classifier1NNNareshKumarMS() {
DecisionKNN decision = new DecisionKNN(1);
public class Classifier1NNNareshKumarMS extends Classifier { decision.setDistanceTemplates(new DistanceTemplatesNareshKumarMS());
setDecision(decision);
public Classifier1NNNareshKumarMS() { }
DecisionKNN decision = new DecisionKNN(1); }
decision.setDistanceTemplates(new DistanceTemplatesNareshKumarMS());
setDecision(decision);
}
public final class DistanceTemplatesNareshKumarMS extends DistanceTemplates {
public DistanceTemplatesNareshKumarMS() {
}
@Override
public String getDescription() {
return "NareshKumarMS";
}
@Override
public double getDistance(Template template1, Template template2) {
double sum = 0.0;
double[] generalizedEigenvalues = new EigenvalueDecomposition(getCovarianceMatrix(template2).inverse().times(getCovarianceMatrix(template1))).getRealEigenvalues();
for (int l = 0; l < generalizedEigenvalues.length; l++) {
sum += Math.pow(Math.log(Math.abs(generalizedEigenvalues[l])), 2);
}
return Math.sqrt(sum);
}
private Matrix getCovarianceMatrix(Template template) {
Matrix matrix = template.getMatrix();
int rows = matrix.getRowDimension();
int columns = matrix.getColumnDimension();
Matrix covarianceMatrix = new Matrix(columns, columns);
for (int r = 0; r < rows; r++) {
Matrix d = matrix.getMatrix(r, r, 0, columns - 1);
Matrix u = getDistanceTemplatesMatrix().getMatrix(r, r, 0, columns - 1);
Matrix diff = d.minus(u);
covarianceMatrix.plusEquals((diff.transpose().times(diff)).times((double) 1 / (rows - 1)));
}
return covarianceMatrix;
}
}
}
...@@ -39,8 +39,15 @@ public class ClassifierTransform1NNMMCMahalanobis extends ClassifierTransform { ...@@ -39,8 +39,15 @@ public class ClassifierTransform1NNMMCMahalanobis extends ClassifierTransform {
double[] column = (getMeanMatrix(samplesOfClass).minus(meanSampleMatrix)).times(Math.sqrt((double) samplesOfClass.size() / numberOfSamples)).getColumnPackedCopy(); double[] column = (getMeanMatrix(samplesOfClass).minus(meanSampleMatrix)).times(Math.sqrt((double) samplesOfClass.size() / numberOfSamples)).getColumnPackedCopy();
Upsilon.setMatrix(0, dimension - 1, c, c, new Matrix(column, column.length)); Upsilon.setMatrix(0, dimension - 1, c, c, new Matrix(column, column.length));
} }
SingularValueDecomposition svdChi = new SingularValueDecomposition(Chi); SingularValueDecomposition svdChi;
Matrix Omega = svdChi.getU(); Matrix Omega;
if (dimension < numberOfSamples) {
svdChi = new SingularValueDecomposition(Chi.transpose());
Omega = svdChi.getV().transpose();
} else {
svdChi = new SingularValueDecomposition(Chi);
Omega = svdChi.getU();
}
Matrix ThetaInverseSquareRoot = svdChi.getS(); Matrix ThetaInverseSquareRoot = svdChi.getS();
int dim = Math.min(ThetaInverseSquareRoot.getRowDimension(), ThetaInverseSquareRoot.getColumnDimension()); int dim = Math.min(ThetaInverseSquareRoot.getRowDimension(), ThetaInverseSquareRoot.getColumnDimension());
for (int d = 0; d < dim; d++) { for (int d = 0; d < dim; d++) {
......
package algorithms; package algorithms;
import java.util.List; import java.util.List;
import objects.Template; import objects.Template;
public abstract class Clustering extends Retriever { public abstract class Clusterer extends Retriever {
public Clustering() { public Clusterer() {
} }
public abstract List<List<Template>> cluster(List<Template> templatesGallery); public abstract List<List<Template>> cluster(List<Template> templatesGallery);
public Template getCentroid(List<Template> templates) { public Template getCentroid(List<Template> templates) {
Template centroid = new Template(null, null); Template centroid = new Template(null, null);
if (templates.isEmpty()) { if (templates.isEmpty()) {
System.out.println("empty class"); System.out.println("empty class");
} else { } else {
double minSum = Double.MAX_VALUE; double minSum = Double.MAX_VALUE;
for (Template centroidCandidate : templates) { for (Template centroidCandidate : templates) {
double sum = 0; double sum = 0;
for (Template template : templates) { for (Template template : templates) {
sum += getDistance(template, centroidCandidate); sum += getDistance(template, centroidCandidate);
} }
if (minSum > sum) { if (minSum > sum) {
minSum = sum; minSum = sum;
centroid = centroidCandidate; centroid = centroidCandidate;
} }
} }
} }
return centroid; return centroid;
} }
} }
package algorithms; package algorithms;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import objects.Template; import objects.Template;
public class ClusteringAgglomerativeHierarchical extends Clustering { public class ClustererAgglomerativeHierarchical extends Clusterer {
private final int k; private final int k;
private final double d; private final double d;
public ClusteringAgglomerativeHierarchical(int k) { public ClustererAgglomerativeHierarchical(int k) {
this.k = k; this.k = k;
this.d = Double.MAX_VALUE; this.d = Double.MAX_VALUE;
} }
public ClusteringAgglomerativeHierarchical(double d) { public ClustererAgglomerativeHierarchical(double d) {
this.k = Integer.MAX_VALUE; this.k = Integer.MAX_VALUE;
this.d = d; this.d = d;
} }
@Override @Override
public List<List<Template>> cluster(List<Template> templatesGallery) { public List<List<Template>> cluster(List<Template> templatesGallery) {
int numberOfTemplates = templatesGallery.size(); int numberOfTemplates = templatesGallery.size();
if (numberOfTemplates < k) { if (numberOfTemplates < k) {
System.out.println(numberOfTemplates + " < " + k); System.out.println(numberOfTemplates + " < " + k);
} }
List<List<Template>> clusters = new ArrayList(); List<List<Template>> clusters = new ArrayList();
List<Template> centroids = new ArrayList(); List<Template> centroids = new ArrayList();
for (Template template : templatesGallery) { for (Template template : templatesGallery) {
List<Template> cluster = new ArrayList(); List<Template> cluster = new ArrayList();
cluster.add(template); cluster.add(template);
clusters.add(cluster); clusters.add(cluster);
centroids.add(template); centroids.add(template);
} }
double minDistance = 0; double minDistance = 0;
while (clusters.size() > k && minDistance < d) { while (clusters.size() > k && minDistance < d) {
int numberOfClusters = clusters.size(); int numberOfClusters = clusters.size();
List<Template> clusterI = clusters.get(0); List<Template> clusterI = clusters.get(0);
List<Template> clusterJ = clusters.get(0); List<Template> clusterJ = clusters.get(0);
Template centroidI = centroids.get(0); Template centroidI = centroids.get(0);
Template centroidJ = centroids.get(0); Template centroidJ = centroids.get(0);
minDistance = Double.MAX_VALUE; minDistance = Double.MAX_VALUE;
for (int i = 0; i < numberOfClusters - 1; i++) { for (int i = 0; i < numberOfClusters - 1; i++) {
for (int j = i + 1; j < numberOfClusters; j++) { for (int j = i + 1; j < numberOfClusters; j++) {
double distance; double distance;
distance = getDistance(centroids.get(i), centroids.get(j)); distance = getDistance(centroids.get(i), centroids.get(j));
if (minDistance > distance) { if (minDistance > distance) {
minDistance = distance; minDistance = distance;
clusterI = clusters.get(i); clusterI = clusters.get(i);
clusterJ = clusters.get(j); clusterJ = clusters.get(j);
centroidI = centroids.get(i); centroidI = centroids.get(i);
centroidJ = centroids.get(j); centroidJ = centroids.get(j);
} }
} }
} }
clusters.remove(clusterI); clusters.remove(clusterI);
clusters.remove(clusterJ); clusters.remove(clusterJ);
centroids.remove(centroidI); centroids.remove(centroidI);
centroids.remove(centroidJ); centroids.remove(centroidJ);
clusterI.addAll(clusterJ); clusterI.addAll(clusterJ);
clusters.add(clusterI); clusters.add(clusterI);
centroids.add(getCentroid(clusterI)); centroids.add(getCentroid(clusterI));
} }
return clusters; return clusters;
} }
} }
package algorithms; package algorithms;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import objects.Template; import objects.Template;
public class ClusteringKmeans extends Clustering { public class ClustererKmeans extends Clusterer {
private final int k; private final int k;
public ClusteringKmeans(int k) { public ClustererKmeans(int k) {
this.k = k; this.k = k;
} }
@Override @Override
public List<List<Template>> cluster(List<Template> templatesGallery) { public List<List<Template>> cluster(List<Template> templatesGallery) {
int numberOfTemplates = templatesGallery.size(); int numberOfTemplates = templatesGallery.size();
if (numberOfTemplates < k) { if (numberOfTemplates < k) {
System.out.println(numberOfTemplates + " < " + k); System.out.println(numberOfTemplates + " < " + k);
return null; return null;
} }
Random random = new Random(); Random random = new Random();
List<Integer> seeds = new ArrayList(); List<Integer> seeds = new ArrayList();
for (int i = 0; i < k; i++) { for (int i = 0; i < k; i++) {
int seed = random.nextInt(numberOfTemplates); int seed = random.nextInt(numberOfTemplates);
while (seeds.contains(seed)) { while (seeds.contains(seed)) {
seed = random.nextInt(numberOfTemplates); seed = random.nextInt(numberOfTemplates);
} }
seeds.add(seed); seeds.add(seed);
} }
List<Template> centroids = new ArrayList(); List<Template> centroids = new ArrayList();
for (int i = 0; i < k; i++) { for (int i = 0; i < k; i++) {
centroids.add(templatesGallery.get(seeds.get(i))); centroids.add(templatesGallery.get(seeds.get(i)));
} }
List<List<Template>> clusters = new ArrayList(); List<List<Template>> clusters = new ArrayList();
boolean done = false; boolean done = false;
while (!done) {//CHANGE while (!done) {//CHANGE
done = true; done = true;
clusters = new ArrayList(); clusters = new ArrayList();
for (int i = 0; i < k; i++) { for (int i = 0; i < k; i++) {
clusters.add(new ArrayList()); clusters.add(new ArrayList());
} }
for (Template template : templatesGallery) { for (Template template : templatesGallery) {
double minDistance = Double.MAX_VALUE; double minDistance = Double.MAX_VALUE;
int closestCentroidIndex = 0; int closestCentroidIndex = 0;
for (int i = 0; i < k; i++) { for (int i = 0; i < k; i++) {
double distance = getDistance(template, centroids.get(i)); double distance = getDistance(template, centroids.get(i));
if (minDistance > distance) { if (minDistance > distance) {
minDistance = distance; minDistance = distance;
closestCentroidIndex = i; closestCentroidIndex = i;
} }
} }
clusters.get(closestCentroidIndex).add(template); clusters.get(closestCentroidIndex).add(template);
} }
for (int i = 0; i < k; i++) { for (int i = 0; i < k; i++) {
Template centroidOld = centroids.get(i); Template centroidOld = centroids.get(i);
Template centroidNew = getCentroid(clusters.get(i)); Template centroidNew = getCentroid(clusters.get(i));
centroids.set(i, centroidNew); centroids.set(i, centroidNew);
if(!centroidOld.equals(centroidNew)){ if(!centroidOld.equals(centroidNew)){
done = false; done = false;
} }
} }
} }
return clusters; return clusters;
} }
} }
package algorithms; package algorithms;
import java.util.List; import java.util.List;
import objects.Template; import objects.Template;
public abstract class Decision extends Retriever { public abstract class Decision extends Retriever {
private List gallery; private List gallery;
public Decision() { public List getGallery() {
} return gallery;
}
public List getGallery() {
return gallery; public void setGallery(List gallery) {
} this.gallery = gallery;
}
public void setGallery(List gallery) {
this.gallery = gallery; public abstract String getDescription();
}
public abstract void importGallery(List<Template> templatesGallery);
public abstract String getDescription();
public abstract String decide(Template template);
public abstract void importGallery(List<Template> templatesGallery); }
public abstract String decide(Template template);
}
package algorithms; package algorithms;
import java.util.ArrayList; import java.util.List;
import java.util.List; import java.util.Random;
import java.util.Random; import objects.Template;
import objects.Template;
public class DecisionRandom extends Decision {
public class DecisionRandom extends Decision {
@Override
@Override public String getDescription() {
public String getDescription() { return "random";
return "random"; }
}
@Override
@Override public void importGallery(List<Template> templatesGallery) {
public void importGallery(List<Template> templatesGallery) { setGallery(templatesGallery);
List<String> gallery = new ArrayList(); }
for (Template template : templatesGallery) {
gallery.add(template.getSubject()); @Override
} public String decide(Template templateProbe) {
setGallery(gallery); List<Template> gallery = getGallery();
} return gallery.get(new Random().nextInt(gallery.size())).getSubject();
}
@Override }
public String decide(Template templateProbe) {
List<String> gallery = getGallery();
return gallery.get(new Random().nextInt(gallery.size()));
}
}
package algorithms;
import java.util.List;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import objects.Template;
public class DecisionSVM extends Decision {
@Override
public String getDescription() {
return "SVM";
}
@Override
public void importGallery(List<Template> templatesGallery) {
setGallery(templatesGallery);
}
@Override
public String decide(Template templateProbe) {
svm svm = new svm();
svm_model svm_model = new svm_model();
svm_problem svm_problem = new svm_problem();
return "";
}
double[][] train = new double[1000][];
double[][] test = new double[10][];
private svm_model svmTrain() {
svm_problem prob = new svm_problem();
int dataCount = train.length;
prob.y = new double[dataCount];
prob.l = dataCount;
prob.x = new svm_node[dataCount][];
for (int i = 0; i < dataCount; i++) {
double[] features = train[i];
prob.x[i] = new svm_node[features.length - 1];
for (int j = 1; j < features.length; j++) {
svm_node node = new svm_node();
node.index = j;
node.value = features[j];
prob.x[i][j - 1] = node;
}
prob.y[i] = features[0];
}
svm_parameter param = new svm_parameter();
param.probability = 1;
param.gamma = 0.5;
param.nu = 0.5;
param.C = 1;
param.svm_type = svm_parameter.C_SVC;
param.kernel_type = svm_parameter.LINEAR;
param.cache_size = 20000;
param.eps = 0.001;
svm_model model = svm.svm_train(prob, param);
return model;
}
}
package algorithms; package algorithms;
import Jama.Matrix; import java.io.Serializable;
import java.io.Serializable; import objects.Template;
import objects.Template;
public abstract class DistanceTemplates implements Serializable {
public abstract class DistanceTemplates implements Serializable {
public abstract String getDescription();
private Matrix distanceTemplatesMatrix;
public abstract double getDistance(Template template1, Template template2);
public DistanceTemplates() { }
}
public Matrix getDistanceTemplatesMatrix() {
return distanceTemplatesMatrix;
}
public void setDistanceTemplatesMatrix(Matrix distanceTemplatesMatrix) {
this.distanceTemplatesMatrix = distanceTemplatesMatrix;
}
public abstract String getDescription();