package mhtree;

import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation;
import messif.algorithms.Algorithm;
import messif.buckets.BucketDispatcher;
import messif.buckets.BucketStorageException;
import messif.buckets.LocalBucket;
import messif.buckets.impl.MemoryStorageBucket;
import messif.objects.LocalAbstractObject;
import messif.operations.data.InsertOperation;
import messif.operations.query.ApproxKNNQueryOperation;
import messif.operations.query.KNNQueryOperation;
import messif.statistics.Statistics;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.IntSummaryStatistics;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

/**
 * MH-Tree index built bottom-up from a list of objects.
 * <p>
 * Building groups the objects into leaf nodes of {@code leafCapacity} objects each. It then repeatedly merges
 * nodes into parents of at most {@code nodeDegree} children until a single root remains (see {@link Builder}).
 * Instances are created through {@link Builder#build()}.
 */
public class MHTree extends Algorithm implements Serializable {

    /**
     * Serialization ID
     */
    private static final long serialVersionUID = 42L;

    /**
     * Minimal number of objects in leaf node's bucket.
     */
    private final int leafCapacity;

    /**
     * Maximal degree of an internal node.
     */
    private final int nodeDegree;

    /**
     * Root of the tree; it may be a {@link LeafNode} when the data set fits into a single leaf.
     */
    private final Node root;

    /**
     * Method used when adding a new object (configured on the {@link Builder}).
     */
    private final InsertType insertType;

    /**
     * How the distance between an object and a node is measured (configured on the {@link Builder}).
     */
    private final ObjectToNodeDistance objectToNodeDistance;

    /**
     * Dispatcher owning the buckets of the leaf nodes.
     */
    private final BucketDispatcher bucketDispatcher;

    @AlgorithmConstructor(description = "MH-Tree", arguments = {
            "MH-Tree builder object",
    })
    private MHTree(Builder builder) {
        super("MH-Tree");

        leafCapacity = builder.leafCapacity;
        nodeDegree = builder.nodeDegree;
        bucketDispatcher = builder.bucketDispatcher;
        insertType = builder.insertType;
        objectToNodeDistance = builder.objectToNodeDistance;
        root = builder.root;
    }

    /**
     * Evaluates an approximate k-NN query. Evaluation may span several invocations.
     * <p>
     * The search state lives in the operation's {@code suppData}, which is cast to {@link SearchState}.
     * <ul>
     *   <li>Nodes are taken from a priority queue of {@link ObjectToNodeDistanceRank}; children of internal
     *   nodes are pushed back into the queue.</li>
     *   <li>Each object of a visited leaf is offered to the answer.</li>
     *   <li>The approximation state is then updated with the number of distance computations so far.</li>
     *   <li>The method returns early (leaving the state resumable) when {@code approxState.stop()} is true.</li>
     *   <li>Once the queue is exhausted, the state is marked done and {@code endOperation()} is called.</li>
     * </ul>
     *
     * @param operation the approximate k-NN query to evaluate
     */
    public void approxKNN(ApproxKNNQueryOperation operation) {
        SearchState state = (SearchState) operation.suppData;

        if (state.done) {
            return;
        }

        // Snapshot of the global counter; the computations done during this call are measured against it.
        long distanceComputations = (long) Statistics.getStatistics("DistanceComputations").getValue();
        LocalAbstractObject queryObject = operation.getQueryObject();
        long toThisIterationDistanceComputations = 0;

        if (state.queue == null || state.approxState == null) {
            // First invocation: start the traversal at the root.
            state.queue = new PriorityQueue<>();
            state.queue.add(new ObjectToNodeDistanceRank(queryObject, root));
            state.approxState = ApproxState.create(operation, this);
        } else {
            // Resumed invocation: carry over the computations counted by the previous invocations.
            // NOTE(review): the meaning of the 1000 offset is not visible here; confirm against ApproxState.
            toThisIterationDistanceComputations = state.approxState.getComputedDistances() - 1000;
            state.approxState.setComputedDistances(0);
        }

        while (!state.queue.isEmpty()) {
            if (state.approxState.stop()) {
                return;
            }

            Node node = state.queue.remove().getNode();

            if (node.isLeaf()) {
                for (LocalAbstractObject object : node.getObjects()) {
                    if (!operation.isAnswerFull() || queryObject.getDistance(object) < operation.getAnswerDistance()) {
                        operation.addToAnswer(object);
                    }
                }

                long changeInDistanceComputations = (long) Statistics.getStatistics("DistanceComputations").getValue() - distanceComputations;
                state.approxState.update(
                        (LeafNode) node,
                        toThisIterationDistanceComputations + changeInDistanceComputations);
            } else {
                for (Node child : ((InternalNode) node).getChildren()) {
                    state.queue.add(new ObjectToNodeDistanceRank(queryObject, child));
                }
            }
        }

        state.done = true;
        operation.endOperation();
    }

    /**
     * Evaluates a k-NN query by offering every object returned by the root to the answer.
     *
     * @param operation the k-NN query to evaluate
     */
    public void kNN(KNNQueryOperation operation) {
        root.getObjects().forEach(operation::addToAnswer);
        operation.endOperation();
    }

    /**
     * Returns the number of objects stored in the buckets managed by this tree's dispatcher.
     *
     * @return the number of stored objects
     */
    public int getObjectCount() {
        return bucketDispatcher.getObjectCount();
    }

    /**
     * Returns the objects reachable from the root.
     *
     * @return the list of objects reported by the root node
     */
    public List<LocalAbstractObject> getObjects() {
        return root.getObjects();
    }

    /**
     * Inserts an object into the tree.
     * <p>
     * The object is added to every internal node on the path from the root down to a leaf. At each level the
     * traversal descends into the child returned by {@code getNearestChild}. Finally the object is added to
     * the reached leaf.
     *
     * @param operation the insert operation carrying the object
     * @throws BucketStorageException if adding the object to a node fails
     */
    public void insert(InsertOperation operation) throws BucketStorageException {
        LocalAbstractObject object = operation.getInsertedObject();

        Node node = root;
        while (!node.isLeaf()) {
            node.addObject(object);
            node = ((InternalNode) node).getNearestChild(object);
        }

        node.addObject(object);

        operation.endOperation();
    }

    /**
     * Collects all nodes of the tree (as gathered by {@code Node.gatherNodes}).
     *
     * @return a list of all nodes
     */
    private List<Node> getNodes() {
        List<Node> nodes = new ArrayList<>();
        root.gatherNodes(nodes);
        return nodes;
    }

    /**
     * Returns a list of leaf nodes.
     *
     * @return a list of leaf nodes
     */
    public List<LeafNode> getLeafNodes() {
        List<LeafNode> leafNodes = new ArrayList<>();
        root.gatherLeafNodes(leafNodes);
        return leafNodes;
    }

    /**
     * Prints a summary of the tree's configuration and shape to standard output.
     */
    public void printStatistics() {
        // Gather the nodes once; the original traversed the tree twice for the same list.
        List<Node> nodes = getNodes();

        IntSummaryStatistics nodeHullObjects = nodes
                .stream()
                .mapToInt(Node::getHullObjectCount)
                .summaryStatistics();

        IntSummaryStatistics leafNodeObjects = getLeafNodes()
                .stream()
                .mapToInt(LeafNode::getObjectCount)
                .summaryStatistics();

        int numberOfObjects = bucketDispatcher
                .getAllBuckets()
                .stream()
                .mapToInt(LocalBucket::getObjectCount)
                .sum();

        int numberOfNodes = nodes.size();

        System.out.println("Insert type: " + insertType);
        System.out.println("Height: " + root.getHeight());
        System.out.println("Node degree: " + nodeDegree);
        System.out.println("Leaf object capacity: " + leafCapacity);
        System.out.println("Number of objects: " + numberOfObjects);
        System.out.println("Number of nodes: " + numberOfNodes);
        System.out.println("Number of internal nodes: " + (numberOfNodes - leafNodeObjects.getCount()));
        System.out.println("Number of leaf nodes: " + leafNodeObjects.getCount());
        System.out.printf("Number of hull objects per node - min: %d, avg: %.2f, max: %d, sum: %d\n",
                nodeHullObjects.getMin(),
                nodeHullObjects.getAverage(),
                nodeHullObjects.getMax(),
                nodeHullObjects.getSum());
        System.out.printf("Number of stored objects per leaf node - min: %d, avg: %.2f, max: %d\n",
                leafNodeObjects.getMin(),
                leafNodeObjects.getAverage(),
                leafNodeObjects.getMax());
    }

    @Override
    public String toString() {
        return "MHTree{" +
                "leafCapacity=" + leafCapacity +
                ", nodeDegree=" + nodeDegree +
                ", insertType=" + insertType +
                ", objectToNodeDistance=" + objectToNodeDistance +
                '}';
    }

    /**
     * Builder of {@link MHTree}.
     * <p>
     * The defaults are {@link InsertType#GREEDY}, {@link ObjectToNodeDistance#NEAREST}, and a
     * {@link BucketDispatcher} creating {@link MemoryStorageBucket}s.
     */
    public static class Builder {

        /**
         * List of objects used during the building of MH-Tree.
         */
        private final List<LocalAbstractObject> objects;

        /**
         * Minimal number of objects in leaf node's bucket.
         */
        private final int leafCapacity;

        /**
         * Maximal degree of internal node.
         */
        private final int nodeDegree;

        /**
         * Specifies which method to use when adding a new object.
         */
        private InsertType insertType;

        /**
         * Specifies how to measure distance between an object and a node.
         */
        private ObjectToNodeDistance objectToNodeDistance;

        /**
         * A dispatcher for maintaining a set of local buckets.
         */
        private BucketDispatcher bucketDispatcher;

        /**
         * Precomputed objects distances.
         */
        private AbstractRepresentation.PrecomputedDistances objectDistances;

        /**
         * Stores leaf nodes and subsequently internal nodes.
         */
        private Node[] nodes;

        /**
         * Identifies which indices in {@code nodes} are valid.
         */
        private BitSet validNodeIndices;

        /**
         * Precomputed node distances.
         */
        private PrecomputedNodeDistances nodeDistances;

        /**
         * Root of MH-Tree.
         */
        private Node root;

        /**
         * Creates a builder for the given data set.
         *
         * @param objects      the objects to index; must not be {@code null}
         * @param leafCapacity the number of objects per leaf node; must be positive
         * @param nodeDegree   the maximal number of children of an internal node; must be at least 2.
         *                     With a degree of 1, merging can never reduce the number of nodes, so building
         *                     a tree with more than one leaf would never terminate.
         * @throws NullPointerException     if {@code objects} is {@code null}
         * @throws IllegalArgumentException if {@code leafCapacity < 1} or {@code nodeDegree < 2}
         */
        public Builder(List<LocalAbstractObject> objects, int leafCapacity, int nodeDegree) {
            this.objects = Objects.requireNonNull(objects, "objects");

            if (leafCapacity < 1) {
                throw new IllegalArgumentException("leafCapacity must be positive, got " + leafCapacity);
            }
            if (nodeDegree < 2) {
                throw new IllegalArgumentException("nodeDegree must be at least 2, got " + nodeDegree);
            }

            this.leafCapacity = leafCapacity;
            this.nodeDegree = nodeDegree;

            this.insertType = InsertType.GREEDY;
            this.objectToNodeDistance = ObjectToNodeDistance.NEAREST;
            this.bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, MemoryStorageBucket.class, null);
        }

        public Builder insertType(InsertType insertType) {
            this.insertType = insertType;
            return this;
        }

        public Builder objectToNodeDistance(ObjectToNodeDistance objectToNodeDistance) {
            this.objectToNodeDistance = objectToNodeDistance;
            return this;
        }

        public Builder bucketDispatcher(BucketDispatcher bucketDispatcher) {
            this.bucketDispatcher = bucketDispatcher;
            return this;
        }

        public Builder bucketDispatcher(Class<? extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) {
            this.bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams);
            return this;
        }

        /**
         * Builds the tree.
         * <p>
         * If all objects fit into one leaf, the root is a single {@link LeafNode}. Otherwise
         * {@code objects.size() / leafCapacity} leaf nodes are created, and then they are merged bottom-up
         * into parents of at most {@code nodeDegree} children until one node remains.
         *
         * @return the built tree
         * @throws BucketStorageException if creating a bucket or storing objects in a leaf fails
         */
        public MHTree build() throws BucketStorageException {
            nodes = new Node[objects.size() / leafCapacity];

            validNodeIndices = new BitSet(nodes.length);
            validNodeIndices.set(0, nodes.length);

            objectDistances = new AbstractRepresentation.PrecomputedDistances(objects);

            // Every object is stored in the root
            if (objectDistances.getObjectCount() <= leafCapacity) {
                root = new LeafNode(objectDistances, bucketDispatcher.createBucket(), insertType, objectToNodeDistance);
                return new MHTree(this);
            }

            createLeafNodes(leafCapacity);

            nodeDistances = new PrecomputedNodeDistances();

            root = createRoot(nodeDegree);

            return new MHTree(this);
        }

        /**
         * Merges the valid nodes in rounds until only one remains, and returns it.
         * <p>
         * In each round, every valid node is processed once:
         * <ul>
         *   <li>While more than {@code arity} nodes remain unprocessed, the one returned by
         *   {@code getFurthestIndex} is merged with {@code arity - 1} of the remaining unprocessed nodes,
         *   chosen greedily by {@link #findClosestNodeIndex}.</li>
         *   <li>Once at most {@code arity} nodes remain, they are merged together.</li>
         * </ul>
         *
         * @param arity the maximal number of nodes merged into one parent (at least 2)
         * @return the last remaining node
         */
        private Node createRoot(int arity) {
            while (validNodeIndices.cardinality() != 1) {
                BitSet notProcessedNodeIndices = (BitSet) validNodeIndices.clone();

                while (!notProcessedNodeIndices.isEmpty()) {
                    if (notProcessedNodeIndices.cardinality() <= arity) {
                        mergeNodes(notProcessedNodeIndices);
                        break;
                    }

                    int furthestNodeIndex = nodeDistances.getFurthestIndex(notProcessedNodeIndices);
                    notProcessedNodeIndices.clear(furthestNodeIndex);

                    mergeNodes(furthestNodeIndex, findClosestItems(this::findClosestNodeIndex, furthestNodeIndex, arity - 1, notProcessedNodeIndices));
                }
            }

            return nodes[validNodeIndices.nextSetBit(0)];
        }

        /**
         * Partitions the objects into {@code nodes.length} leaf nodes of {@code leafCapacity} objects each.
         * <p>
         * Each leaf is seeded with the unassigned object picked by {@code Utils.maxDistanceIndex}. It is then
         * filled with {@code leafCapacity - 1} further objects chosen greedily by
         * {@link #findClosestObjectIndex}. Any leftover objects (fewer than {@code leafCapacity}) are added to
         * the leaf closest to each of them.
         *
         * @param leafCapacity the number of objects per leaf
         * @throws BucketStorageException if creating a bucket or adding an object to a leaf fails
         */
        private void createLeafNodes(int leafCapacity) throws BucketStorageException {
            BitSet notProcessedObjectIndices = new BitSet(objectDistances.getObjectCount());
            notProcessedObjectIndices.set(0, objectDistances.getObjectCount());

            for (int nodeIndex = 0; !notProcessedObjectIndices.isEmpty(); nodeIndex++) {
                if (notProcessedObjectIndices.cardinality() < leafCapacity) {
                    for (int i = notProcessedObjectIndices.nextSetBit(0); i >= 0; i = notProcessedObjectIndices.nextSetBit(i + 1)) {
                        LocalAbstractObject object = objectDistances.getObject(i);
                        ((LeafNode) nodes[getClosestNodeIndex(object)]).addObject(object, objectDistances);
                    }
                    return;
                }

                List<Integer> objectIndices = new ArrayList<>(leafCapacity);

                // Select a base object
                int furthestIndex = Utils.maxDistanceIndex(objectDistances.getDistances(), notProcessedObjectIndices);
                notProcessedObjectIndices.clear(furthestIndex);
                objectIndices.add(furthestIndex);

                // Select the rest of the objects up to the total of leafCapacity
                objectIndices.addAll(findClosestItems(this::findClosestObjectIndex, furthestIndex, leafCapacity - 1, notProcessedObjectIndices));

                List<LocalAbstractObject> objects = objectIndices
                        .stream()
                        .map(objectDistances::getObject)
                        .collect(Collectors.toList());

                nodes[nodeIndex] = new LeafNode(objectDistances.getSubset(objects), bucketDispatcher.createBucket(), insertType, objectToNodeDistance);
            }
        }

        /**
         * Returns the index of the node in {@code nodes} with the smallest {@code getDistance} to the object.
         *
         * @param object the object to locate
         * @return the index of the closest node, or -1 if no node is closer than {@link Double#MAX_VALUE}
         */
        private int getClosestNodeIndex(LocalAbstractObject object) {
            double minDistance = Double.MAX_VALUE;
            int closestNodeIndex = -1;

            for (int candidateIndex = 0; candidateIndex < nodes.length; candidateIndex++) {
                double distance = nodes[candidateIndex].getDistance(object, objectDistances);

                if (distance < minDistance) {
                    minDistance = distance;
                    closestNodeIndex = candidateIndex;
                }
            }

            return closestNodeIndex;
        }

        /**
         * Greedily selects {@code numberOfItems} items to accompany {@code itemIndex}.
         * <p>
         * Each step calls {@code findClosestItemIndex} with all items selected so far (including
         * {@code itemIndex}). Each chosen index is cleared in {@code notProcessedItemIndices}, so the bitset
         * is mutated.
         *
         * @param findClosestItemIndex  picks the next index given the selected indices and the candidates
         * @param itemIndex             the seed item
         * @param numberOfItems         how many further items to select
         * @param notProcessedItemIndices the candidate indices; the selected ones are cleared
         * @return the selected indices, excluding {@code itemIndex}, in the order they were selected
         */
        private List<Integer> findClosestItems(BiFunction<List<Integer>, BitSet, Integer> findClosestItemIndex, int itemIndex, int numberOfItems, BitSet notProcessedItemIndices) {
            List<Integer> itemIndices = new ArrayList<>(1 + numberOfItems);
            itemIndices.add(itemIndex);

            List<Integer> resultItemsIndices = new ArrayList<>(numberOfItems);

            while (resultItemsIndices.size() != numberOfItems) {
                int index = findClosestItemIndex.apply(itemIndices, notProcessedItemIndices);

                itemIndices.add(index);
                resultItemsIndices.add(index);
                notProcessedItemIndices.clear(index);
            }

            return resultItemsIndices;
        }

        /**
         * Returns the candidate node with the smallest precomputed distance to any of the given nodes.
         *
         * @param indices              indices of the already selected nodes
         * @param candidateNodeIndices indices of the nodes that may be chosen
         * @return the index of the closest candidate node
         */
        private int findClosestNodeIndex(List<Integer> indices, BitSet candidateNodeIndices) {
            double minDistance = Double.MAX_VALUE;
            int closestNodeIndex = -1;

            for (int index : indices) {
                int candidateIndex = nodeDistances.getClosestIndex(index, candidateNodeIndices);

                float distance = nodeDistances.getDistance(index, candidateIndex);

                if (distance < minDistance) {
                    minDistance = distance;
                    closestNodeIndex = candidateIndex;
                }
            }

            return closestNodeIndex;
        }

        /**
         * Chooses the next object to add to a leaf.
         * <p>
         * For each selected object, its nearest valid candidate is taken. Among those candidates, the one with
         * the smallest sum of distances to all selected objects is returned.
         *
         * @param indices            indices of the already selected objects
         * @param validObjectIndices indices of the objects that may be chosen
         * @return the index of the chosen object
         */
        private int findClosestObjectIndex(List<Integer> indices, BitSet validObjectIndices) {
            double minDistance = Double.MAX_VALUE;
            int closestObjectIndex = -1;

            for (int index : indices) {
                int candidateIndex = objectDistances.minDistInArray(objectDistances.getDistances(index), validObjectIndices);

                double distance = indices
                        .stream()
                        .mapToDouble(i -> objectDistances.getDistance(i, candidateIndex))
                        .sum();

                if (distance < minDistance) {
                    minDistance = distance;
                    closestObjectIndex = candidateIndex;
                }
            }

            return closestObjectIndex;
        }

        /**
         * Merges nodes specified by indices in set state in {@code nodeIndices}.
         * The new node is placed on the first set index in {@code nodeIndices}.
         *
         * @param nodeIndices the bitset of nodes to be merged
         */
        private void mergeNodes(BitSet nodeIndices) {
            List<Integer> indices = nodeIndices
                    .stream()
                    .boxed()
                    .collect(Collectors.toList());

            int parentNodeIndex = indices.remove(0);

            mergeNodes(parentNodeIndex, indices);
        }

        /**
         * Merges the specified nodes into one and places the new node at {@code parentNodeIndex} in {@code nodes}.
         * <p>
         * Does nothing when {@code nodeIndices} is empty. Otherwise {@code parentNodeIndex} is appended to
         * {@code nodeIndices} (the list is mutated). Every merged slot is cleared and invalidated. The parent is
         * then stored at {@code parentNodeIndex`, which is re-validated, and its distances to all valid nodes are
         * recomputed.
         *
         * @param parentNodeIndex an index where the new node is placed
         * @param nodeIndices     indices of the nodes merged with the node at {@code parentNodeIndex}; modified
         */
        private void mergeNodes(int parentNodeIndex, List<Integer> nodeIndices) {
            if (nodeIndices.size() == 0) return;

            nodeIndices.add(parentNodeIndex);

            List<Node> children = nodeIndices
                    .stream()
                    .map(i -> this.nodes[i])
                    .collect(Collectors.toList());

            InternalNode parent = Node.createParent(children, objectDistances, insertType, objectToNodeDistance);

            nodeIndices.forEach(index -> {
                validNodeIndices.clear(index);
                this.nodes[index] = null;
            });

            this.nodes[parentNodeIndex] = parent;
            validNodeIndices.set(parentNodeIndex);

            nodeDistances.updateNodeDistances(parentNodeIndex);
        }

        /**
         * {@code PrecomputedNodeDistances} contains methods for computing, updating,
         * and retrieving distance between nodes stored in {@code nodes}.
         * <p>
         * The distance between two nodes is the minimum object distance over all pairs of their hull objects.
         */
        private class PrecomputedNodeDistances {
            private final float[][] distances;

            PrecomputedNodeDistances() {
                distances = new float[nodes.length][nodes.length];
                computeNodeDistances();
            }

            /**
             * Returns precomputed distance between nodes on indices i and j in {@code nodes}.
             *
             * @param i an index of node in {@code nodes}
             * @param j an index of node in {@code nodes}
             * @return the distance between nodes on indices i and j in {@code nodes}
             */
            private float getDistance(int i, int j) {
                return distances[i][j];
            }

            /**
             * Recomputes the distances between the node at {@code nodeIndex} and every valid node (including
             * itself), and stores them symmetrically.
             *
             * @param nodeIndex an index of node in {@code nodes}
             */
            private void updateNodeDistances(int nodeIndex) {
                validNodeIndices
                        .stream()
                        .forEach(index -> {
                            float distance = computeDistanceBetweenNodes(nodeIndex, index);

                            distances[nodeIndex][index] = distance;
                            distances[index][nodeIndex] = distance;
                        });
            }

            private int getClosestIndex(int nodeIndex, BitSet notUsedIndexes) {
                return Utils.minDistanceIndex(distances[nodeIndex], notUsedIndexes);
            }

            private int getFurthestIndex(BitSet validIndices) {
                return Utils.maxDistanceIndex(distances, validIndices);
            }

            /**
             * Computes distances between nodes in {@code nodes}, storing the result in {@code distances}.
             */
            private void computeNodeDistances() {
                for (int i = 0; i < nodes.length; i++) {
                    for (int j = i + 1; j < nodes.length; j++) {
                        float distance = computeDistanceBetweenNodes(i, j);

                        distances[i][j] = distance;
                        distances[j][i] = distance;
                    }
                }
            }

            /**
             * Computes and returns the distance between nodes on indices i and j in {@code nodes}.
             *
             * @param i an index of node in {@code nodes}
             * @param j an index of node in {@code nodes}
             * @return the distance between nodes on indices i and j in {@code nodes}.
             */
            private float computeDistanceBetweenNodes(int i, int j) {
                float distance = Float.MAX_VALUE;

                for (LocalAbstractObject firstHullObject : nodes[i].getHullObjects())
                    for (LocalAbstractObject secondHullObject : nodes[j].getHullObjects())
                        distance = Math.min(distance, objectDistances.getDistance(firstHullObject, secondHullObject));

                return distance;
            }
        }
    }
}