Skip to content
Snippets Groups Projects
Verified Commit 56035f59 authored by David Procházka's avatar David Procházka
Browse files

ADD: new type of recall measurement, statistics calculation, mtree

parent 3820ccde
No related branches found
No related tags found
No related merge requests found
......@@ -5,6 +5,7 @@
<element id="module-output" name="mh-tree" />
<element id="extracted-dir" path="$PROJECT_DIR$/jars/messif.jar" path-in-jar="/" />
<element id="extracted-dir" path="$PROJECT_DIR$/jars/similarityoperators.jar" path-in-jar="/" />
<element id="extracted-dir" path="$PROJECT_DIR$/jars/mtree.jar" path-in-jar="/" />
</root>
</artifact>
</component>
\ No newline at end of file
......@@ -9,5 +9,6 @@
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="library" name="messif" level="project" />
<orderEntry type="library" name="similarityoperators" level="project" />
<orderEntry type="library" name="mtree" level="project" />
</component>
</module>
\ No newline at end of file
......@@ -8,14 +8,14 @@ import java.util.List;
public class MHTreeConfig {
public final int leafCapacity;
public final int numberOfChildren;
public final int nodeDegree;
public final InsertType insertType;
public final ObjectToNodeDistance objectToNodeDistance;
public List<LocalAbstractObject> objects;
MHTreeConfig(int leafCapacity, int numberOfChildren, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) {
MHTreeConfig(int leafCapacity, int nodeDegree, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) {
this.leafCapacity = leafCapacity;
this.numberOfChildren = numberOfChildren;
this.nodeDegree = nodeDegree;
this.insertType = insertType;
this.objectToNodeDistance = objectToNodeDistance;
}
......@@ -28,8 +28,9 @@ public class MHTreeConfig {
public String toString() {
return "MHTreeConfig{" +
"leafCapacity=" + leafCapacity +
", numberOfChildren=" + numberOfChildren +
", numberOfChildren=" + nodeDegree +
", insertType=" + insertType +
", objectToNodeDistance=" + objectToNodeDistance +
'}';
}
}
......@@ -71,4 +71,31 @@ public class PerformanceMeasures {
return trueCount / (double) knnQueryOperation.getAnswerCount();
}
/**
 * Measures the recall of an approximate kNN answer against a precomputed exact answer.
 *
 * <p>The comparison is based on distances, not on object identity: an object in the approximate
 * answer counts as a hit when its distance to the query equals the distance of a not-yet-matched
 * object among the first {@code k} precomputed neighbours. Each precomputed distance can be
 * matched at most as many times as it occurs, so duplicates are not over-counted.
 *
 * <p>The ground truth is the first {@code min(k, list size)} entries of the list stored under the
 * query object's locator URI, and the hit count is divided by that number. When the list holds at
 * least {@code k} entries the denominator is exactly {@code k}.
 *
 * @param approxKNNQueryOperation the already executed approximate query whose answer is evaluated
 * @param kNNs precomputed neighbours keyed by the locator URI of the query object
 * @return the fraction of ground-truth distances found in the approximate answer, or {@code 0}
 *     when the approximate answer is empty or there is no ground truth to compare against
 * @throws IllegalArgumentException if {@code kNNs} has no entry for the query object's URI
 * @throws NoSuchMethodException declared for compatibility with existing callers
 * @throws AlgorithmMethodException declared for compatibility with existing callers
 */
public static double measureRecall(ApproxKNNQueryOperation approxKNNQueryOperation, HashMap<String, List<RankedAbstractObject>> kNNs) throws NoSuchMethodException, AlgorithmMethodException {
    if (approxKNNQueryOperation.getAnswerCount() == 0) {
        return 0d;
    }

    String queryUri = approxKNNQueryOperation.getQueryObject().getLocatorURI();
    List<RankedAbstractObject> exactNeighbours = kNNs.get(queryUri);
    if (exactNeighbours == null) {
        throw new IllegalArgumentException("no precomputed kNN answer for query object " + queryUri);
    }

    // Clamp to the available ground truth so a short list does not make subList throw.
    int expectedCount = Math.min(approxKNNQueryOperation.getK(), exactNeighbours.size());
    if (expectedCount <= 0) {
        return 0d;
    }

    // Multiset of ground-truth distances: distance -> number of unmatched occurrences.
    Map<Float, Long> unmatched = new HashMap<>();
    for (RankedAbstractObject exactObject : exactNeighbours.subList(0, expectedCount)) {
        unmatched.merge(exactObject.getDistance(), 1L, Long::sum);
    }

    long trueCount = 0;
    for (RankedAbstractObject approxObject : approxKNNQueryOperation) {
        Float distance = approxObject.getDistance();
        Long occurrences = unmatched.get(distance);
        if (occurrences == null) {
            continue;
        }

        // Consume one occurrence so the same ground-truth distance is not matched twice.
        if (occurrences == 1L) {
            unmatched.remove(distance);
        } else {
            unmatched.put(distance, occurrences - 1);
        }
        trueCount++;
    }

    return trueCount / (double) expectedCount;
}
}
......@@ -3,102 +3,181 @@ package mhtree.benchmarking;
import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation;
import messif.algorithms.AlgorithmMethodException;
import messif.buckets.BucketStorageException;
import messif.buckets.impl.MemoryStorageBucket;
import messif.objects.LocalAbstractObject;
import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2;
import messif.objects.util.AbstractObjectList;
import messif.objects.util.RankedAbstractObject;
import messif.objects.util.StreamGenericAbstractObjectIterator;
import messif.operations.data.BulkInsertOperation;
import messif.operations.Approximate;
import messif.operations.query.ApproxKNNQueryOperation;
import messif.operations.query.KNNQueryOperation;
import mhtree.InsertType;
import mhtree.MHTree;
import mhtree.ObjectToNodeDistance;
import mtree.MTree;
import mhtree.SearchState;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;
public class RunBenchmark {
private static final String dataset = "./data/middle_c.txt";
public static void main(String[] args) throws IOException {
List<LocalAbstractObject> objects = loadDataset(args.length == 1 ? args[0] : RunBenchmark.dataset);
public static void main(String[] args) throws IOException, NoSuchMethodException, AlgorithmMethodException, BucketStorageException, ClassNotFoundException {
if (!(args.length == 0 || args.length == 5)) {
throw new IllegalArgumentException("unexpected number of params");
}
AbstractRepresentation.PrecomputedDistances.COMPUTATION_THREADS = 16;
new BenchmarkConfig(
Arrays.asList(
new MHTreeConfig(50, 10, InsertType.GREEDY, ObjectToNodeDistance.NEAREST)
List<LocalAbstractObject> objects;
if (args.length != 5) {
objects = loadDataset(RunBenchmark.dataset);
DCtoRecall(
new MHTreeConfig(
10,
2,
InsertType.GREEDY,
ObjectToNodeDistance.NEAREST
),
objects,
new int[]{1, 50, 100}
);
return;
}
objects = loadDataset(args[0]);
int leafCapacity = Integer.parseInt(args[1]);
int nodeDegree = Integer.parseInt(args[2]);
InsertType insertType = args[3].equals("INCREMENTAL") ? InsertType.INCREMENTAL : InsertType.GREEDY;
ObjectToNodeDistance objectToNodeDistance;
switch (args[4]) {
case "FURTHEST":
objectToNodeDistance = ObjectToNodeDistance.FURTHEST;
break;
case "AVERAGE":
objectToNodeDistance = ObjectToNodeDistance.AVERAGE;
break;
default:
objectToNodeDistance = ObjectToNodeDistance.NEAREST;
break;
}
// Collections.shuffle(objects);
DCtoRecall(new MHTreeConfig(
leafCapacity,
nodeDegree,
insertType,
objectToNodeDistance
),
Arrays.asList(
(MHTreeConfig config) -> finalBench(config, new int[]{10}, 2)
)).execute(objects);
objects,
new int[]{1, 50}
);
}
private static void finalBench(MHTreeConfig config, int[] ks, int buildInsertRatio) {
List<LocalAbstractObject> objects = config.objects;
// int buildObjectsCount = objects.size() - objects.size() / buildInsertRatio;
//
// List<LocalAbstractObject> buildObjects = objects.subList(0, buildObjectsCount);
// List<LocalAbstractObject> insertObjects = objects.subList(buildObjectsCount, objects.size());
//
// System.out.println(objects.size() + " objects, build on " + buildObjects.size() + ", inserted " + insertObjects.size());
MHTree mhtree = measure("MH-Tree build", (Supplier<MHTree>) () -> {
try {
return new MHTree.Builder(objects, config.leafCapacity, config.numberOfChildren)
.insertType(config.insertType)
.objectToNodeDistance(config.objectToNodeDistance)
.bucketDispatcher(MemoryStorageBucket.class, null)
.build();
} catch (BucketStorageException e) {
e.printStackTrace();
return null;
}
});
mhtree.printStatistics();
System.out.println();
BulkInsertOperation bulkInsertOperation = new BulkInsertOperation(objects);
MTree mtree = measure("M-Tree build", (Supplier<MTree>) () -> {
try {
MTree mTree = new MTree(config.numberOfChildren,
2L * config.leafCapacity,
MemoryStorageBucket.class,
null);
mTree.insert(bulkInsertOperation);
return mTree;
} catch (AlgorithmMethodException | InstantiationException e) {
e.printStackTrace();
return null;
}
});
private static void DCtoRecall(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws NoSuchMethodException, AlgorithmMethodException, BucketStorageException, RuntimeException {
MHTree mhTree = new MHTree.Builder(objects, config.leafCapacity, config.nodeDegree)
.insertType(config.insertType)
.objectToNodeDistance(config.objectToNodeDistance)
.build();
mtree.printStatistics();
}
mhTree.printStatistics();
private static <T extends Supplier<U>, U> U measure(String what, T f) {
U result;
int maxK = Arrays.stream(ks)
.max()
.orElseThrow(() -> new RuntimeException("bad max k"));
long timeStart = System.currentTimeMillis();
result = f.get();
long timeEnd = System.currentTimeMillis();
// Precomputed NNs for each object
HashMap<String, List<RankedAbstractObject>> URIToKNNs = new HashMap<>(objects.size());
System.out.println(what + " took " + (timeEnd - timeStart) + " ms");
for (LocalAbstractObject object : objects) {
KNNQueryOperation knnQueryOperation = new KNNQueryOperation(object, maxK);
mhTree.executeOperation(knnQueryOperation);
return result;
}
List<RankedAbstractObject> rankedKNNObjects = new ArrayList<>(knnQueryOperation.getAnswerCount());
for (RankedAbstractObject knnObj : knnQueryOperation) {
rankedKNNObjects.add(knnObj);
}
URIToKNNs.put(object.getLocatorURI(), rankedKNNObjects);
}
System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,distanceComputations,recall (min),recall (avg),recall (med),recall (max)");
double minRecall = 0;
long distanceComputations = 0;
for (int k : ks) {
List<ApproxKNNQueryOperation> operations = objects
.stream()
.map(object -> new ApproxKNNQueryOperation(
object,
k,
1000,
Approximate.LocalSearchType.ABS_DC_COUNT,
LocalAbstractObject.UNKNOWN_DISTANCE))
.collect(Collectors.toList());
operations.forEach(operation -> operation.suppData = new SearchState(null, null, false));
System.out.println(
String.join(",",
String.valueOf(config.leafCapacity),
String.valueOf(config.nodeDegree),
String.valueOf(config.objectToNodeDistance),
String.valueOf(k),
String.valueOf(distanceComputations),
String.format("%.2f,%.2f,%.2f,%.2f", 0f, 0f, 0f, 0f)
)
);
distanceComputations += 1000;
List<Double> recalls = new ArrayList<>(operations.size());
for (int i = 0; i < operations.size(); i++) {
recalls.add(0d);
}
private static <T extends Runnable> void measure(String what, T f) {
long timeStart = System.currentTimeMillis();
f.run();
long timeEnd = System.currentTimeMillis();
while (minRecall != 1) {
for (int i = 0; i < operations.size(); i++) {
if (recalls.get(i) != 1d) {
ApproxKNNQueryOperation operation = operations.get(i);
mhTree.executeOperation(operation);
recalls.set(i, PerformanceMeasures.measureRecall(operation, URIToKNNs));
}
}
Stats recallStats = new Stats(new ArrayList<>(recalls));
System.out.println(String.join(",",
String.valueOf(config.leafCapacity),
String.valueOf(config.nodeDegree),
String.valueOf(config.objectToNodeDistance),
String.valueOf(k),
String.valueOf(distanceComputations),
String.format("%.2f,%.2f,%.2f,%.2f",
recallStats.getMin(),
recallStats.getAverage(),
recallStats.getMedian(),
recallStats.getMax())));
minRecall = recallStats.getMin();
distanceComputations += 1000;
}
System.out.println(what + " took " + (timeEnd - timeStart) + " ms");
minRecall = 0;
distanceComputations = 0;
}
}
private static List<LocalAbstractObject> loadDataset(String path) throws IOException {
......
package mhtree.benchmarking;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Simple summary statistics (minimum, average, lower median, maximum) over a list of values.
 *
 * <p>The values are copied and sorted once at construction time; the list passed to the
 * constructor is left untouched. The positional accessors ({@link #getMin()},
 * {@link #getMedian()}, {@link #getMax()}) throw {@link IndexOutOfBoundsException} when the
 * statistics were built from an empty list, and {@link #getAverage()} returns {@code NaN}.
 */
public class Stats {
    /** Private sorted (ascending) copy of the input values. */
    final List<Double> sortedData;

    /**
     * Creates statistics over the given values.
     *
     * @param data the values to summarise; not modified by this class
     * @throws NullPointerException if {@code data} is {@code null} or contains {@code null}
     */
    Stats(List<Double> data) {
        // Sort a copy so the caller's list keeps its original order.
        List<Double> copy = new ArrayList<>(data);
        Collections.sort(copy);
        sortedData = copy;
    }

    /**
     * @return the smallest value
     */
    public double getMin() {
        return sortedData.get(0);
    }

    /**
     * @return the arithmetic mean of all values
     */
    public double getAverage() {
        return sortedData.stream().mapToDouble(Double::doubleValue).sum() / sortedData.size();
    }

    /**
     * @return Lower median, i.e. for an even number of values the smaller of the two middle ones
     */
    public double getMedian() {
        return sortedData.get((sortedData.size() + 1) / 2 - 1);
    }

    /**
     * @return the largest value
     */
    public double getMax() {
        return sortedData.get(sortedData.size() - 1);
    }

    /**
     * @return the statistics formatted as {@code "min average(median) max"} with two decimals
     */
    @Override
    public String toString() {
        return String.format("%.2f %.2f(%.2f) %.2f", getMin(), getAverage(), getMedian(), getMax());
    }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment