From 74f6858c653952b3835d70eff5502bbf99268d4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Mon, 15 Mar 2021 17:24:29 +0100 Subject: [PATCH] ADD: precomputed node distance javadoc, builder pattern in MH-Tree --- src/mhtree/MHTree.java | 522 ++++++++++++++++++++++++++++++++++------- src/mhtree/Node.java | 76 +++--- src/mhtree/Utils.java | 95 +++++--- 3 files changed, 536 insertions(+), 157 deletions(-) diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java index 5bf4ae3..80db3a7 100644 --- a/src/mhtree/MHTree.java +++ b/src/mhtree/MHTree.java @@ -1,19 +1,25 @@ package mhtree; +import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; import messif.algorithms.Algorithm; import messif.buckets.BucketDispatcher; import messif.buckets.BucketErrorCode; import messif.buckets.BucketStorageException; import messif.buckets.LocalBucket; +import messif.buckets.impl.MemoryStorageBucket; import messif.objects.LocalAbstractObject; -import messif.operations.Approximate; import messif.operations.data.InsertOperation; import messif.operations.query.ApproxKNNQueryOperation; -import messif.operations.query.KNNQueryOperation; -import messif.statistics.StatisticCounter; import java.io.Serializable; -import java.util.*; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.IntSummaryStatistics; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.function.BiFunction; +import java.util.stream.Collectors; public class MHTree extends Algorithm implements Serializable { @@ -22,64 +28,56 @@ public class MHTree extends Algorithm implements Serializable { */ private static final long serialVersionUID = 42L; - private final int leafCapacity; - private final int arity; + /** + * Minimal number of objects in leaf node's bucket. + */ + private final int LEAF_CAPACITY; + + /** + * Maximal degree of internal node. 
+ */ + private final int NODE_DEGREE; private final Node root; - private final BucketDispatcher bucketDispatcher; private final InsertType insertType; private final ObjectToNodeDistance objectToNodeDistance; - private final NodeToNodeDistance nodeToNodeDistance; - - private final StatisticCounter statVisitedLeaves = StatisticCounter.getStatistics("Node.Leaf.Visited"); + private final BucketDispatcher bucketDispatcher; @AlgorithmConstructor(description = "MH-Tree", arguments = { - "list of objects", - "number of objects in a leaf node", - "arity", - "insert type", - "object to node distance type", - "node to node distance type", - "storage class for buckets", - "storage class parameters" + "MH-Tree builder object", }) - public MHTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, NodeToNodeDistance nodeToNodeDistance, Class<? extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) throws BucketStorageException { + private MHTree(Builder builder) { super("MH-Tree"); - this.leafCapacity = leafCapacity; - this.arity = arity; - this.insertType = insertType; - this.objectToNodeDistance = objectToNodeDistance; - this.nodeToNodeDistance = nodeToNodeDistance; + LEAF_CAPACITY = builder.LEAF_CAPACITY; + NODE_DEGREE = builder.NODE_DEGREE; - bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams); + bucketDispatcher = builder.bucketDispatcher; + insertType = builder.insertType; + objectToNodeDistance = builder.objectToNodeDistance; - root = new BuildTree(objects, leafCapacity, arity, insertType, objectToNodeDistance, nodeToNodeDistance, bucketDispatcher).getRoot(); + root = builder.root; } - public void approxKNN(ApproxKNNQueryOperation operation, double coefficient) { + public void approxKNN(ApproxKNNQueryOperation operation) { LocalAbstractObject queryObject = operation.getQueryObject(); - 
boolean limitVisitedLeaves = operation.getLocalSearchType() == Approximate.LocalSearchType.DATA_PARTITIONS; - PriorityQueue<ObjectToNodeDistanceRank> queue = new PriorityQueue<>(); queue.add(new ObjectToNodeDistanceRank(queryObject, root)); while (!queue.isEmpty()) { Node node = queue.poll().getNode(); - if (node.isLeaf()) { - if (limitVisitedLeaves && statVisitedLeaves.get() == operation.getLocalSearchParam()) - break; - - statVisitedLeaves.add(); + if (operation.isAnswerFull() && isPrunable(node, queryObject, operation)) + continue; + if (node.isLeaf()) { for (LocalAbstractObject object : node.getObjects()) if (!operation.isAnswerFull() || queryObject.getDistance(object) < operation.getAnswerDistance()) operation.addToAnswer(object); } else { for (Node child : ((InternalNode) node).getChildren()) - if (!operation.isAnswerFull() || !isPrunable(child, queryObject, operation, coefficient)) + if (!operation.isAnswerFull() || !isPrunable(child, queryObject, operation)) queue.add(new ObjectToNodeDistanceRank(queryObject, child)); } } @@ -87,8 +85,8 @@ public class MHTree extends Algorithm implements Serializable { operation.endOperation(); } - private boolean isPrunable(Node child, LocalAbstractObject queryObject, ApproxKNNQueryOperation operation, double coefficient) { - return operation.getAnswerDistance() * coefficient < child.getNearestDistance(queryObject); + private boolean isPrunable(Node child, LocalAbstractObject queryObject, ApproxKNNQueryOperation operation) { + return operation.getAnswerDistance() < child.getDistanceToNearest(queryObject); } public void insert(InsertOperation operation) throws BucketStorageException { @@ -107,76 +105,430 @@ public class MHTree extends Algorithm implements Serializable { operation.endOperation(BucketErrorCode.OBJECT_INSERTED); } + private List<Node> getNodes() { + List<Node> nodes = new ArrayList<>(); + root.gatherNodes(nodes); + return nodes; + } + + private int getNumberOfInternalNodes() { + if (root.isLeaf()) return 0; + 
+ return ((InternalNode) root).getInternalNodesCount(); + } + + /** + * Returns a list of leaf nodes. + * + * @return a list of leaf nodes + */ + private List<LeafNode> getLeafNodes() { + List<LeafNode> leafNodes = new ArrayList<>(); + root.gatherLeafNodes(leafNodes); + return leafNodes; + } + public void printStatistics() { - IntSummaryStatistics leafStatistics = bucketDispatcher - .getAllBuckets() + IntSummaryStatistics nodeHullObjects = getNodes() .stream() - .mapToInt(LocalBucket::getObjectCount) + .mapToInt(Node::getHullObjectCount) .summaryStatistics(); - System.out.println("--- STATISTICS ---"); + IntSummaryStatistics leafNodeObjects = getLeafNodes() + .stream() + .mapToInt(LeafNode::getObjectCount) + .summaryStatistics(); - System.out.println("- MH-Tree -"); + int numberOfObjects = bucketDispatcher + .getAllBuckets() + .stream() + .mapToInt(LocalBucket::getObjectCount) + .sum(); System.out.println("Insert type: " + insertType); System.out.println("Height: " + root.getHeight()); - System.out.println("Arity: " + arity); - System.out.println("Leaf capacity: " + leafCapacity + " objects"); + System.out.println("Node degree: " + NODE_DEGREE); + System.out.println("Leaf object capacity: " + LEAF_CAPACITY); + + System.out.println("Number of objects: " + numberOfObjects); + System.out.println("Number of nodes: " + nodeHullObjects.getCount()); + System.out.println("Number of internal nodes: " + getNumberOfInternalNodes()); + System.out.println("Number of leaf nodes: " + leafNodeObjects.getCount()); + + System.out.printf("Number of hull objects per node - min: %d, avg: %.2f, max: %d, sum: %d\n", + nodeHullObjects.getMin(), + nodeHullObjects.getAverage(), + nodeHullObjects.getMax(), + nodeHullObjects.getSum()); + + System.out.printf("Number of stored objects per leaf node - min: %d, avg: %.2f, max: %d\n", + leafNodeObjects.getMin(), + leafNodeObjects.getAverage(), + leafNodeObjects.getMax()); + } + + @Override + public String toString() { + return "MHTree{" + 
"leafCapacity=" + LEAF_CAPACITY + + ", nodeDegree=" + NODE_DEGREE + + ", insertType=" + insertType + + ", objectToNodeDistance=" + objectToNodeDistance + + '}'; + } - System.out.println("Number of nodes in each level:"); - for (int level = 1; level <= root.getHeight() + 1; level++) { - System.out.println("- Level " + level + " -> " + root.getNodesOnLevel(level).size()); + public static class Builder { + + /** + * List of objects used to build the MH-Tree. + */ + private final List<LocalAbstractObject> objects; + + /** + * Minimal number of objects in leaf node's bucket. + */ + private final int LEAF_CAPACITY; + + /** + * Maximal degree of internal node. + */ + private final int NODE_DEGREE; + + /** + * Specifies which method to use when adding a new object. + */ + private InsertType insertType; + + /** + * Specifies how to measure distance between an object and a node. + */ + private ObjectToNodeDistance objectToNodeDistance; + + /** + * A dispatcher for maintaining a set of local buckets. + */ + private BucketDispatcher bucketDispatcher; + + /** + * Precomputed distances between the objects in {@code objects}. + */ + private AbstractRepresentation.PrecomputedDistances objectDistances; + + /** + * Stores the leaf nodes; internal nodes created by merging later overwrite entries in place. + */ + private Node[] nodes; + + /** + * Identifies which indices in {@code nodes} currently hold a node that has not been merged into a parent. + */ + private BitSet validNodeIndices; + + /** + * Precomputed distances between the nodes in {@code nodes}. + */ + private PrecomputedNodeDistances nodeDistances; + + /** + * Root of MH-Tree. 
+ */ + private Node root; + + public Builder(List<LocalAbstractObject> objects, int leafCapacity, int nodeDegree) { + this.objects = objects; + this.LEAF_CAPACITY = leafCapacity; + this.NODE_DEGREE = nodeDegree; + + this.insertType = InsertType.GREEDY; + this.objectToNodeDistance = ObjectToNodeDistance.NEAREST; + this.bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, LEAF_CAPACITY, 0, false, MemoryStorageBucket.class, null); } - System.out.println("Number of hull objects in each level:"); - for (int level = 1; level <= root.getHeight() + 1; level++) { - List<Node> levelNodes = root.getNodesOnLevel(level); + public Builder insertType(InsertType insertType) { + this.insertType = insertType; + return this; + } - System.out.println("- Level " + level + " -> " + levelNodes.stream().mapToInt(n -> n.getHullObjects().size()).summaryStatistics()); + public Builder objectToNodeDistance(ObjectToNodeDistance objectToNodeDistance) { + this.objectToNodeDistance = objectToNodeDistance; + return this; } - System.out.println("\nHistogram of covered objects per level: "); - System.out.println(Histogram.generate(root)); + public Builder bucketDispatcher(BucketDispatcher bucketDispatcher) { + this.bucketDispatcher = bucketDispatcher; + return this; + } - System.out.println("- LeafNodes -"); + public Builder bucketDispatcher(Class<? 
extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) { + this.bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, LEAF_CAPACITY, 0, false, defaultBucketClass, bucketClassParams); + return this; + } - System.out.println("Count: " + leafStatistics.getCount()); - System.out.println("Minimum number of objects: " + leafStatistics.getMin()); - System.out.println("Average number of objects: " + String.format("%.2f", leafStatistics.getAverage())); - System.out.println("Maximum number of objects: " + leafStatistics.getMax()); - } + public MHTree build() throws BucketStorageException { + nodes = new Node[objects.size() / LEAF_CAPACITY]; - public long getVisitedLeaves() { - return statVisitedLeaves.get(); - } + validNodeIndices = new BitSet(nodes.length); + validNodeIndices.set(0, nodes.length); - public void reset() { - statVisitedLeaves.reset(); - } + objectDistances = new AbstractRepresentation.PrecomputedDistances(objects); - public double measureRecall(ApproxKNNQueryOperation approxKNNQueryOperation) { - LocalAbstractObject queryObject = approxKNNQueryOperation.getQueryObject(); - int k = approxKNNQueryOperation.getK(); + // Every object is stored in the root + if (objectDistances.getObjectCount() <= LEAF_CAPACITY) { + root = new LeafNode(objectDistances, bucketDispatcher.createBucket(), insertType, objectToNodeDistance); + return new MHTree(this); + } - KNNQueryOperation knnQueryOperation = new KNNQueryOperation(queryObject, k); - KNN(knnQueryOperation); + createLeafNodes(LEAF_CAPACITY); - return Utils.measureRecall(approxKNNQueryOperation, knnQueryOperation); - } + nodeDistances = new PrecomputedNodeDistances(); - private void KNN(KNNQueryOperation operation) { - root.getObjects().forEach(operation::addToAnswer); - operation.endOperation(); - } + root = createRoot(NODE_DEGREE); - @Override - public String toString() { - return "MHTree{" + - "leafCapacity=" + leafCapacity + - ", arity=" + arity + - ", insertType=" + 
insertType + - ", objectToNodeDistance=" + objectToNodeDistance + - ", nodeToNodeDistance=" + nodeToNodeDistance + - '}'; + return new MHTree(this); + } + + private Node createRoot(int arity) { + while (validNodeIndices.cardinality() != 1) { + BitSet notProcessedNodeIndices = (BitSet) validNodeIndices.clone(); + + while (!notProcessedNodeIndices.isEmpty()) { + if (notProcessedNodeIndices.cardinality() <= arity) { + mergeNodes(notProcessedNodeIndices); + break; + } + + int furthestNodeIndex = nodeDistances.getFurthestIndex(notProcessedNodeIndices); + notProcessedNodeIndices.clear(furthestNodeIndex); + + mergeNodes(furthestNodeIndex, findClosestItems(this::findClosestNodeIndex, furthestNodeIndex, arity - 1, notProcessedNodeIndices)); + } + } + + return nodes[validNodeIndices.nextSetBit(0)]; + } + + private void createLeafNodes(int leafCapacity) throws BucketStorageException { + BitSet notProcessedObjectIndices = new BitSet(objectDistances.getObjectCount()); + notProcessedObjectIndices.set(0, objectDistances.getObjectCount()); + + for (int nodeIndex = 0; !notProcessedObjectIndices.isEmpty(); nodeIndex++) { + if (notProcessedObjectIndices.cardinality() < leafCapacity) { + for (int i = notProcessedObjectIndices.nextSetBit(0); i >= 0; i = notProcessedObjectIndices.nextSetBit(i + 1)) { + LocalAbstractObject object = objectDistances.getObject(i); + + nodes[getClosestNodeIndex(object)].addObject(object); + } + + return; + } + + List<Integer> objectIndices = new ArrayList<>(leafCapacity); + + // Select a base object + int furthestIndex = Utils.maxDistanceIndex(objectDistances.getDistances(), notProcessedObjectIndices); + notProcessedObjectIndices.clear(furthestIndex); + objectIndices.add(furthestIndex); + + // Select the rest of the objects up to the total of leafCapacity + objectIndices.addAll(findClosestItems(this::findClosestObjectIndex, furthestIndex, leafCapacity - 1, notProcessedObjectIndices)); + + List<LocalAbstractObject> objects = objectIndices + .stream() + 
.map(objectDistances::getObject) + .collect(Collectors.toList()); + + nodes[nodeIndex] = new LeafNode(objectDistances.getSubset(objects), bucketDispatcher.createBucket(), insertType, objectToNodeDistance); + } + } + + private int getClosestNodeIndex(LocalAbstractObject object) { + double minDistance = Double.MAX_VALUE; + int closestNodeIndex = -1; + + for (int candidateIndex = 0; candidateIndex < nodes.length; candidateIndex++) { + double distance = nodes[candidateIndex].getDistance(object); + + if (distance < minDistance) { + minDistance = distance; + closestNodeIndex = candidateIndex; + } + } + + return closestNodeIndex; + } + + private List<Integer> findClosestItems(BiFunction<List<Integer>, BitSet, Integer> findClosestItemIndex, int itemIndex, int numberOfItems, BitSet notProcessedItemIndices) { + List<Integer> itemIndices = new ArrayList<>(1 + numberOfItems); + itemIndices.add(itemIndex); + + List<Integer> resultItemsIndices = new ArrayList<>(numberOfItems); + + while (resultItemsIndices.size() != numberOfItems) { + int index = findClosestItemIndex.apply(itemIndices, notProcessedItemIndices); + + itemIndices.add(index); + resultItemsIndices.add(index); + + notProcessedItemIndices.clear(index); + } + + return resultItemsIndices; + } + + private int findClosestNodeIndex(List<Integer> indices, BitSet validNodeIndices) { + double minDistance = Double.MAX_VALUE; + int closestNodeIndex = -1; + + for (int index : indices) { + int candidateIndex = nodeDistances.getClosestIndex(index, validNodeIndices); + float distance = nodeDistances.getDistance(index, candidateIndex); + + if (distance < minDistance) { + minDistance = distance; + closestNodeIndex = candidateIndex; + } + } + + return closestNodeIndex; + } + + private int findClosestObjectIndex(List<Integer> indices, BitSet validObjectIndices) { + double minDistance = Double.MAX_VALUE; + int closestObjectIndex = -1; + + for (int index : indices) { + int candidateIndex = 
objectDistances.minDistInArray(objectDistances.getDistances(index), validObjectIndices); + double distance = indices + .stream() + .mapToDouble(i -> objectDistances.getDistance(i, candidateIndex)) + .sum(); + + if (distance < minDistance) { + minDistance = distance; + closestObjectIndex = candidateIndex; + } + } + + return closestObjectIndex; + } + + /** + * Merges the nodes whose indices are set in {@code nodeIndices} into one new node. + * The new node is placed at the first set index in {@code nodeIndices}. + * + * @param nodeIndices a bit set whose set bits are the indices in {@code nodes} of the nodes to merge + */ + private void mergeNodes(BitSet nodeIndices) { + List<Integer> indices = nodeIndices + .stream() + .boxed() + .collect(Collectors.toList()); + int parentNodeIndex = indices.remove(0); + mergeNodes(parentNodeIndex, indices); + } + + /** + * Merges the specified nodes into one and places the new node at index {@code parentNodeIndex} in {@code nodes}; does nothing when {@code nodeIndices} is empty. + * + * @param parentNodeIndex the index where the new node is placed; the node currently stored there becomes one of the merged children + * @param nodeIndices the indices of the nodes that are merged with the node at {@code parentNodeIndex}; this list is modified ({@code parentNodeIndex} is appended to it) + */ + private void mergeNodes(int parentNodeIndex, List<Integer> nodeIndices) { + if (nodeIndices.size() == 0) return; + + nodeIndices.add(parentNodeIndex); + + List<Node> children = nodeIndices + .stream() + .map(i -> this.nodes[i]) + .collect(Collectors.toList()); + + InternalNode parent = Node.createParent(children, objectDistances, insertType, objectToNodeDistance); + + nodeIndices.forEach(index -> { + validNodeIndices.clear(index); + this.nodes[index] = null; + }); + + this.nodes[parentNodeIndex] = parent; + validNodeIndices.set(parentNodeIndex); + + nodeDistances.updateNodeDistances(parentNodeIndex); + } + + /** + * {@code PrecomputedNodeDistances} contains methods for computing, updating, + * and retrieving distance between nodes stored in {@code nodes}. 
+ */ + private class PrecomputedNodeDistances { + private final float[][] distances; + + PrecomputedNodeDistances() { + distances = new float[nodes.length][nodes.length]; + + computeNodeDistances(); + } + + /** + * Returns precomputed distance between nodes on indices i and j in {@code nodes}. + * + * @param i an index of node in {@code nodes} + * @param j an index of node in {@code nodes} + * @return the distance between nodes on indices i and j in {@code nodes} + */ + private float getDistance(int i, int j) { + return distances[i][j]; + } + + private void updateNodeDistances(int nodeIndex) { + validNodeIndices + .stream() + .forEach(index -> { + float distance = computeDistanceBetweenNodes(nodeIndex, index); + + distances[nodeIndex][index] = distance; + distances[index][nodeIndex] = distance; + }); + } + + private int getClosestIndex(int nodeIndex, BitSet notUsedIndexes) { + return Utils.minDistanceIndex(distances[nodeIndex], notUsedIndexes); + } + + private int getFurthestIndex(BitSet validIndices) { + return Utils.maxDistanceIndex(distances, validIndices); + } + + /** + * Computes distances between nodes in {@code nodes}, storing the result in {@code distances}. + */ + private void computeNodeDistances() { + for (int i = 0; i < nodes.length; i++) { + for (int j = i + 1; j < nodes.length; j++) { + float distance = computeDistanceBetweenNodes(i, j); + + distances[i][j] = distance; + distances[j][i] = distance; + } + } + } + + /** + * Computes and returns the distance between nodes on indices i and j in {@code nodes}. + * + * @param i an index of node in {@code nodes} + * @param j an index of node in {@code nodes} + * @return the distance between nodes on indices i and j in {@code nodes}. 
+ */ + private float computeDistanceBetweenNodes(int i, int j) { + float distance = Float.MAX_VALUE; + + for (LocalAbstractObject firstHullObject : nodes[i].getHullObjects()) + for (LocalAbstractObject secondHullObject : nodes[j].getHullObjects()) + distance = Math.min(distance, objectDistances.getDistance(firstHullObject, secondHullObject)); + + return distance; + } + } } } diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java index 95904be..b89d9bc 100644 --- a/src/mhtree/Node.java +++ b/src/mhtree/Node.java @@ -9,78 +9,78 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.List; -import java.util.Set; import java.util.stream.Collectors; -public abstract class Node implements Serializable { +abstract class Node implements Serializable { /** * Serialization ID */ private static final long serialVersionUID = 420L; - protected final InsertType insertType; - private final ObjectToNodeDistance objectToNodeDistance; - protected HullOptimizedRepresentationV3 hull; - protected Node parent; + + private final InsertType INSERT_TYPE; + private final ObjectToNodeDistance OBJECT_TO_NODE_DISTANCE; + + private HullOptimizedRepresentationV3 hull; Node(PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) { this.hull = new HullOptimizedRepresentationV3(distances); this.hull.build(); - this.insertType = insertType; - this.objectToNodeDistance = objectToNodeDistance; + + this.INSERT_TYPE = insertType; + this.OBJECT_TO_NODE_DISTANCE = objectToNodeDistance; } - public static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) { - List<LocalAbstractObject> objects = nodes.stream() + @Override + public String toString() { + return "Node{hull=" + hull + '}'; + } + + static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance 
objectToNodeDistance) { + List<LocalAbstractObject> objects = nodes + .stream() .map(Node::getObjects) .flatMap(Collection::stream) .collect(Collectors.toList()); - return new InternalNode(distances.getSubset(objects), insertType, objectToNodeDistance); - } - - public double getDistance(LocalAbstractObject object) { - return objectToNodeDistance.getDistance(object, this); + return new InternalNode(distances.getSubset(objects), insertType, objectToNodeDistance, nodes); } - public double getNearestDistance(LocalAbstractObject object) { - return ObjectToNodeDistance.NEAREST_HULL_OBJECT.getDistance(object, this); + double getDistance(LocalAbstractObject object) { + return OBJECT_TO_NODE_DISTANCE.getDistance(object, this); } - public boolean isCovered(LocalAbstractObject object) { - return hull.isExternalCovered(object); + double getDistanceToNearest(LocalAbstractObject object) { + return ObjectToNodeDistance.NEAREST.getDistance(object, this); } - public boolean isLeaf() { + boolean isLeaf() { return (this instanceof LeafNode); } - public void setParent(Node parent) { - this.parent = parent; + boolean isInternal() { + return !isLeaf(); } - public List<LocalAbstractObject> getHullObjects() { + List<LocalAbstractObject> getHullObjects() { return hull.getHull(); } - public int getLevel() { - return parent == null ? 
1 : parent.getLevel() + 1; + int getHullObjectCount() { + return hull.getRepresentativesCount(); } - @Override - public String toString() { - return "Node{hull=" + hull + '}'; - } + abstract void addObject(LocalAbstractObject object) throws BucketStorageException; - public abstract void addObject(LocalAbstractObject object) throws BucketStorageException; + abstract List<LocalAbstractObject> getObjects(); - public abstract Set<LocalAbstractObject> getObjects(); + abstract int getHeight(); - public abstract int getHeight(); + abstract void gatherNodes(List<Node> nodes); - public abstract List<Node> getNodesOnLevel(int level); + abstract void gatherLeafNodes(List<LeafNode> leafNodes); - protected void rebuildHull(LocalAbstractObject object) { + private void rebuildHull(LocalAbstractObject object) { List<LocalAbstractObject> objects = new ArrayList<>(getObjects()); objects.add(object); @@ -88,14 +88,18 @@ public abstract class Node implements Serializable { hull.build(); } - protected void addNewObject(LocalAbstractObject object) { + void addObjectIntoHull(LocalAbstractObject object) { if (isCovered(object)) return; - if (insertType == InsertType.INCREMENTAL) { + if (INSERT_TYPE == InsertType.INCREMENTAL) { hull.addHullObject(object); return; } rebuildHull(object); } + + private boolean isCovered(LocalAbstractObject object) { + return hull.isExternalCovered(object); + } } diff --git a/src/mhtree/Utils.java b/src/mhtree/Utils.java index b453edb..2cd8b8b 100644 --- a/src/mhtree/Utils.java +++ b/src/mhtree/Utils.java @@ -1,46 +1,69 @@ package mhtree; -import messif.objects.util.DistanceRankedObject; -import messif.objects.util.RankedAbstractObject; -import messif.operations.query.ApproxKNNQueryOperation; -import messif.operations.query.KNNQueryOperation; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.function.Function; -import java.util.stream.Collectors; +import java.util.BitSet; +import java.util.function.BiPredicate; 
public class Utils { - // comparing done based on distances, counts how many of the same distances of KNNQueryOperation were presents in the answer of ApproxKNNQueryOperation - public static double measureRecall(ApproxKNNQueryOperation approxKNNQueryOperation, KNNQueryOperation knnQueryOperation) { - if (approxKNNQueryOperation.getAnswerCount() == 0) return 0d; - if (knnQueryOperation.getAnswerCount() == 0) return -1d; - - List<RankedAbstractObject> objects = new ArrayList<>(knnQueryOperation.getAnswerCount()); - for (RankedAbstractObject object : knnQueryOperation) - objects.add(object); - - Map<Float, Long> frequencyMap = objects.parallelStream() - .map(DistanceRankedObject::getDistance) - .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); - - long trueCount = 0; - - for (RankedAbstractObject approxObject : approxKNNQueryOperation) { - float distance = approxObject.getDistance(); - if (frequencyMap.containsKey(distance)) { - long count = frequencyMap.get(distance); - if (count == 1) { - frequencyMap.remove(distance); - } else { - frequencyMap.replace(distance, count - 1); - } - trueCount++; + /** + * Returns an index from {@code validIndices} chosen greedily: starting at the first valid index, it keeps moving to the valid index furthest from the current one while that distance grows, so the whole {@code distanceMatrix} is not scanned. NOTE(review): {@code Float.MIN_VALUE} is the smallest positive float, so a largest distance of 0 returns the starting index; a {@code validIndices} with no set bit makes row -1 be read.
+ * @param distanceMatrix a distance matrix + * @param validIndices valid indices in {@code distanceMatrix} + * @return the index from {@code validIndices} + */ + public static int maxDistanceIndex(float[][] distanceMatrix, BitSet validIndices) { + float maxDistance = Float.MIN_VALUE; + int furthestIndex = validIndices.nextSetBit(0); + + while (true) { + float[] distances = distanceMatrix[furthestIndex]; + int candidateIndex = maxDistanceIndex(distances, validIndices); + + if (!(distances[candidateIndex] > maxDistance)) { + return furthestIndex; + } + + maxDistance = distances[candidateIndex]; + furthestIndex = candidateIndex; + } + } + + /** + * @param distances an array of distances + * @param validIndices the indices of {@code distances} that are considered + * @return the index set in {@code validIndices} with the minimal value in {@code distances} (the lowest such index on ties), or -1 when {@code validIndices} has no set bit + */ + public static int minDistanceIndex(float[] distances, BitSet validIndices) { + return getDistanceIndex(distances, validIndices, (minDistance, newDistance) -> minDistance > newDistance); + } + + /** + * @param distances an array of distances + * @param validIndices the indices of {@code distances} that are considered + * @return the index set in {@code validIndices} with the maximal value in {@code distances} (the lowest such index on ties), or -1 when {@code validIndices} has no set bit + */ + private static int maxDistanceIndex(float[] distances, BitSet validIndices) { + return getDistanceIndex(distances, validIndices, (maxDistance, newDistance) -> maxDistance < newDistance); + } + + /** + * @param distances an array of distances + * @param validIndices the indices of {@code distances} that are considered + * @param comparator called with (best distance so far, candidate distance); returning true makes the candidate the new best + * @return the selected index, or -1 when {@code validIndices} has no set bit + */ + private static int getDistanceIndex(float[] distances, BitSet validIndices, BiPredicate<Float, Float> comparator) { + int index = -1; + float minDistance = Float.MAX_VALUE; + + for (int i = validIndices.nextSetBit(0); i >= 0; i = validIndices.nextSetBit(i + 1)) { + if (index == -1 || comparator.test(minDistance, distances[i])) { + minDistance = distances[i]; + index = i; + } + } + + return index; + } } -- GitLab