diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java index f553cd974d87203908ef6a18733372ebfd885b33..cf0e98cef4259471616fc395f637ecf6199deafc 100644 --- a/src/mhtree/BuildTree.java +++ b/src/mhtree/BuildTree.java @@ -6,6 +6,7 @@ import messif.buckets.BucketStorageException; import messif.objects.LocalAbstractObject; import java.util.*; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -24,16 +25,21 @@ class BuildTree { private final InsertType insertType; private final ObjectToNodeDistance objectToNodeDistance; private final BucketDispatcher bucketDispatcher; + private final BiFunction<Float, Float, Float> distanceSelector; private Node root; - BuildTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, BucketDispatcher bucketDispatcher) throws BucketStorageException { + BuildTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, NodeToNodeDistance nodeToNodeDistance, BucketDispatcher bucketDispatcher) throws BucketStorageException { this.arity = arity; this.leafCapacity = leafCapacity; this.insertType = insertType; this.objectToNodeDistance = objectToNodeDistance; this.bucketDispatcher = bucketDispatcher; + // Set distance selector function for two hull objects + boolean selectNearest = nodeToNodeDistance == NodeToNodeDistance.NEAREST_HULL_OBJECTS; + distanceSelector = selectNearest ? Math::min : Math::max; + nodes = new Node[objects.size() / leafCapacity]; validNodeIndices = new BitSet(nodes.length); @@ -50,7 +56,7 @@ class BuildTree { createLeafNodes(); - precomputeNodeDistances(); + precomputeLeafNodeDistances(); buildTree(); } @@ -175,20 +181,20 @@ class BuildTree { }).collect(Collectors.toSet()); } - private void precomputeNodeDistances() { + private void precomputeLeafNodeDistances() { for (int i = 0; i < nodes.length; i++) { for (int j = i + 1; j < nodes.length; j++) { - float minDistance = Float.MAX_VALUE; + float distance = Float.MAX_VALUE; - for (LocalAbstractObject firstPoint : nodes[i].getObjects()) - for (LocalAbstractObject secondPoint : nodes[j].getObjects()) - minDistance = Math.min(minDistance, objectDistances.getDistance(firstPoint, secondPoint)); + for (LocalAbstractObject firstHullObject : nodes[i].getHullObjects()) + for (LocalAbstractObject secondHullObject : nodes[j].getHullObjects()) + distance = distanceSelector.apply(distance, objectDistances.getDistance(firstHullObject, secondHullObject)); - if (minDistance == 0f) + if (distance == 0f) throw new RuntimeException("Zero distance between " + nodes[i].toString() + " and " + nodes[j].toString()); - nodeDistances[i][j] = minDistance; - nodeDistances[j][i] = minDistance; + nodeDistances[i][j] = distance; + nodeDistances[j][i] = distance; } } } @@ -219,13 +225,13 @@ class BuildTree { if (nodeIndices.size() == 0) return; validNodeIndices.stream().forEach(i -> { - float minDistance = nodeDistances[baseNodeIndex][i]; + float distance = nodeDistances[baseNodeIndex][i]; for (int index : nodeIndices) - minDistance = Math.min(minDistance, nodeDistances[index][i]); + distance = distanceSelector.apply(distance, nodeDistances[index][i]); - nodeDistances[baseNodeIndex][i] = minDistance; - nodeDistances[i][baseNodeIndex] = minDistance; + nodeDistances[baseNodeIndex][i] = distance; + nodeDistances[i][baseNodeIndex] = distance; }); } } diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java index a053329a68b35a07dcfec58ae4d923589613e235..00a43ca9ae3f3399bbad21a1ba7ffc1b857e3db6 100644 --- a/src/mhtree/MHTree.java +++ b/src/mhtree/MHTree.java @@ -29,6 +29,7 @@ public class MHTree extends Algorithm implements Serializable { private final BucketDispatcher bucketDispatcher; private final InsertType insertType; private final ObjectToNodeDistance objectToNodeDistance; + private final NodeToNodeDistance nodeToNodeDistance; private final StatisticCounter statVisitedLeaves = StatisticCounter.getStatistics("Node.Leaf.Visited"); @@ -42,17 +43,18 @@ public class MHTree extends Algorithm implements Serializable { "storage class for buckets", "storage class parameters" }) - public MHTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, Class<? extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) throws BucketStorageException { + public MHTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, NodeToNodeDistance nodeToNodeDistance, Class<? extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) throws BucketStorageException { super("MH-Tree"); this.leafCapacity = leafCapacity; this.arity = arity; this.insertType = insertType; this.objectToNodeDistance = objectToNodeDistance; + this.nodeToNodeDistance = nodeToNodeDistance; bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams); - root = new BuildTree(objects, leafCapacity, arity, insertType, objectToNodeDistance, bucketDispatcher).getRoot(); + root = new BuildTree(objects, leafCapacity, arity, insertType, objectToNodeDistance, nodeToNodeDistance, bucketDispatcher).getRoot(); } public void approxKNN(ApproxKNNQueryOperation operation, double coefficient) { @@ -175,6 +177,7 @@ public class MHTree extends Algorithm implements Serializable { ", root=" + root + ", insertType" + insertType + ", objectToNodeDistance" + objectToNodeDistance + + ", nodeToNodeDistance" + nodeToNodeDistance + '}'; } } diff --git a/src/mhtree/NodeToNodeDistance.java b/src/mhtree/NodeToNodeDistance.java new file mode 100644 index 0000000000000000000000000000000000000000..391d9a1778d4a065182111699d9f279bcbccdaf4 --- /dev/null +++ b/src/mhtree/NodeToNodeDistance.java @@ -0,0 +1,6 @@ +package mhtree; + +public enum NodeToNodeDistance { + NEAREST_HULL_OBJECTS, + FURTHEST_HULL_OBJECTS, +}