Skip to content
Snippets Groups Projects
Verified Commit dd40600b authored by David Procházka's avatar David Procházka
Browse files

ADD: M-Tree benchmark logic

parent 9d974aac
No related branches found
No related tags found
No related merge requests found
...@@ -18,8 +18,8 @@ public abstract class Node implements Serializable { ...@@ -18,8 +18,8 @@ public abstract class Node implements Serializable {
*/ */
private static final long serialVersionUID = 420L; private static final long serialVersionUID = 420L;
private final InsertType INSERT_TYPE; private final InsertType insertType;
private final ObjectToNodeDistance OBJECT_TO_NODE_DISTANCE; private final ObjectToNodeDistance objectToNodeDistance;
private HullOptimizedRepresentationV3 hull; private HullOptimizedRepresentationV3 hull;
...@@ -27,8 +27,8 @@ public abstract class Node implements Serializable { ...@@ -27,8 +27,8 @@ public abstract class Node implements Serializable {
this.hull = new HullOptimizedRepresentationV3(distances); this.hull = new HullOptimizedRepresentationV3(distances);
this.hull.build(); this.hull.build();
this.INSERT_TYPE = insertType; this.insertType = insertType;
this.OBJECT_TO_NODE_DISTANCE = objectToNodeDistance; this.objectToNodeDistance = objectToNodeDistance;
} }
protected static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, MergeType mergeType) { protected static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, MergeType mergeType) {
...@@ -70,11 +70,11 @@ public abstract class Node implements Serializable { ...@@ -70,11 +70,11 @@ public abstract class Node implements Serializable {
} }
protected double getDistance(LocalAbstractObject object) { protected double getDistance(LocalAbstractObject object) {
return OBJECT_TO_NODE_DISTANCE.getDistance(object, this); return objectToNodeDistance.getDistance(object, this);
} }
protected double getDistance(LocalAbstractObject object, PrecomputedDistances distances) { protected double getDistance(LocalAbstractObject object, PrecomputedDistances distances) {
return OBJECT_TO_NODE_DISTANCE.getDistance(object, this, distances); return objectToNodeDistance.getDistance(object, this, distances);
} }
protected double getDistanceToNearest(LocalAbstractObject object) { protected double getDistanceToNearest(LocalAbstractObject object) {
...@@ -100,7 +100,7 @@ public abstract class Node implements Serializable { ...@@ -100,7 +100,7 @@ public abstract class Node implements Serializable {
protected void addObjectIntoHull(LocalAbstractObject object, PrecomputedDistances distances) { protected void addObjectIntoHull(LocalAbstractObject object, PrecomputedDistances distances) {
if (isCovered(object, distances)) return; if (isCovered(object, distances)) return;
if (INSERT_TYPE == InsertType.INCREMENTAL) { if (insertType == InsertType.INCREMENTAL) {
hull.addHullObject(object); hull.addHullObject(object);
return; return;
} }
......
package mhtree.benchmarking; package mhtree.benchmarking;
import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation;
import messif.algorithms.AlgorithmMethodException;
import messif.buckets.BucketStorageException; import messif.buckets.BucketStorageException;
import messif.objects.LocalAbstractObject; import messif.objects.LocalAbstractObject;
import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2; import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2;
...@@ -8,6 +9,7 @@ import messif.objects.util.AbstractObjectList; ...@@ -8,6 +9,7 @@ import messif.objects.util.AbstractObjectList;
import messif.objects.util.RankedAbstractObject; import messif.objects.util.RankedAbstractObject;
import messif.objects.util.StreamGenericAbstractObjectIterator; import messif.objects.util.StreamGenericAbstractObjectIterator;
import messif.operations.Approximate; import messif.operations.Approximate;
import messif.operations.data.BulkInsertOperation;
import messif.operations.query.ApproxKNNQueryOperation; import messif.operations.query.ApproxKNNQueryOperation;
import messif.operations.query.KNNQueryOperation; import messif.operations.query.KNNQueryOperation;
import messif.statistics.Statistics; import messif.statistics.Statistics;
...@@ -15,16 +17,18 @@ import mhtree.InsertType; ...@@ -15,16 +17,18 @@ import mhtree.InsertType;
import mhtree.MHTree; import mhtree.MHTree;
import mhtree.MergeType; import mhtree.MergeType;
import mhtree.ObjectToNodeDistance; import mhtree.ObjectToNodeDistance;
import mtree.MTree;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class RunBenchmark { public class RunBenchmark {
public static void main(String[] args) throws IOException, BucketStorageException { public static void main(String[] args) throws IOException, AlgorithmMethodException, InstantiationException, NoSuchMethodException, BucketStorageException {
if (args.length != 5) { if (args.length != 5) {
throw new IllegalArgumentException("Unexpected number of params"); throw new IllegalArgumentException("Unexpected number of params");
} }
...@@ -50,7 +54,7 @@ public class RunBenchmark { ...@@ -50,7 +54,7 @@ public class RunBenchmark {
break; break;
} }
percentageToRecall(new MHTreeConfig( percentageToRecallMHTree(new MHTreeConfig(
leafCapacity, leafCapacity,
nodeDegree, nodeDegree,
insertType, insertType,
...@@ -61,7 +65,7 @@ public class RunBenchmark { ...@@ -61,7 +65,7 @@ public class RunBenchmark {
); );
} }
private static void percentageToRecall(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws BucketStorageException, RuntimeException { private static void percentageToRecallMHTree(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws BucketStorageException, RuntimeException {
MHTree mTree = new MHTree.Builder(objects, config.leafCapacity, config.nodeDegree) MHTree mTree = new MHTree.Builder(objects, config.leafCapacity, config.nodeDegree)
.objectToNodeDistance(config.objectToNodeDistance) .objectToNodeDistance(config.objectToNodeDistance)
.mergeType(MergeType.REPRESENTATION_BASED) .mergeType(MergeType.REPRESENTATION_BASED)
...@@ -145,6 +149,72 @@ public class RunBenchmark { ...@@ -145,6 +149,72 @@ public class RunBenchmark {
} }
} }
private static void percentageToRecallMTree(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws NoSuchMethodException, AlgorithmMethodException, RuntimeException, InstantiationException {
int numberOfObjects = objects.size();
MTree mTree = new MTree(config.nodeDegree, config.leafCapacity);
Collections.shuffle(objects);
BulkInsertOperation op = new BulkInsertOperation(objects);
mTree.insert(op);
mTree.printStatistics();
System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max)");
double minimalRecall = 0;
int percentage = 0;
int percentageStep = 5;
for (int k : ks) {
List<Double> recalls = new ArrayList<>(numberOfObjects);
for (int i = 0; i < numberOfObjects; i++) {
recalls.add(0.0);
}
while (minimalRecall != 1.0) {
for (int i = 0; i < numberOfObjects; i++) {
if (recalls.get(i) != 1.0) {
ApproxKNNQueryOperation operation = new ApproxKNNQueryOperation(
objects.get(i),
k,
percentage,
Approximate.LocalSearchType.PERCENTAGE,
LocalAbstractObject.UNKNOWN_DISTANCE
);
mTree.executeOperation(operation);
recalls.set(i, PerformanceMeasures.measureRecall(operation, mTree));
}
}
Stats recallStats = new Stats(new ArrayList<>(recalls));
System.out.println(String.join(",",
String.valueOf(config.leafCapacity),
String.valueOf(config.nodeDegree),
String.valueOf(config.objectToNodeDistance),
String.valueOf(k),
String.valueOf(percentage),
String.format("%.2f,%.2f,%.2f,%.2f",
recallStats.getMin(),
recallStats.getAverage(),
recallStats.getMedian(),
recallStats.getMax())));
minimalRecall = recallStats.getMin();
percentage += percentageStep;
}
minimalRecall = 0;
percentage = 0;
}
}
private static List<LocalAbstractObject> loadDataset(String path) throws IOException { private static List<LocalAbstractObject> loadDataset(String path) throws IOException {
return new AbstractObjectList<>(new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, path)); return new AbstractObjectList<>(new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, path));
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment