Skip to content
Snippets Groups Projects
MHTree.java 20.2 KiB
Newer Older
import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation;
import messif.algorithms.Algorithm;
import messif.buckets.BucketDispatcher;
import messif.buckets.BucketStorageException;
import messif.buckets.LocalBucket;
import messif.buckets.impl.MemoryStorageBucket;
import messif.objects.LocalAbstractObject;
import messif.operations.data.InsertOperation;
import messif.operations.query.ApproxKNNQueryOperation;
import messif.operations.query.KNNQueryOperation;
import messif.statistics.Statistics;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.IntSummaryStatistics;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

public class MHTree extends Algorithm implements Serializable {

    /**
     * Serialization ID
     */
    private static final long serialVersionUID = 42L;

    /**
     * Minimal number of objects in leaf node's bucket.
     */
     * Maximal degree of an internal node.
    private final InsertType insertType;
    private final ObjectToNodeDistance objectToNodeDistance;
    private final BucketDispatcher bucketDispatcher;

    @AlgorithmConstructor(description = "MH-Tree", arguments = {
    private MHTree(Builder builder) {
        leafCapacity = builder.leafCapacity;
        nodeDegree = builder.nodeDegree;
        bucketDispatcher = builder.bucketDispatcher;
        insertType = builder.insertType;
        objectToNodeDistance = builder.objectToNodeDistance;
    public void approxKNN(ApproxKNNQueryOperation operation) {
        SearchState state = (SearchState) operation.suppData;
        long distanceComputations = (long) Statistics.getStatistics("DistanceComputations").getValue();
        LocalAbstractObject queryObject = operation.getQueryObject();
        long toThisIterationDistanceComputations = 0;
        if (state.queue == null || state.approxState == null) {
            state.queue = new PriorityQueue<>();
            state.queue.add(new ObjectToNodeDistanceRank(queryObject, root));
            state.approxState = ApproxState.create(operation, this);
        } else {
            toThisIterationDistanceComputations = state.approxState.getComputedDistances() - 1000;
            state.approxState.setComputedDistances(0);
        }
        while (!state.queue.isEmpty()) {
            if (state.approxState.stop()) {
                return;
            }
            Node node = state.queue.remove().getNode();
                for (LocalAbstractObject object : node.getObjects()) {
                    if (!operation.isAnswerFull() || queryObject.getDistance(object) < operation.getAnswerDistance()) {
                        operation.addToAnswer(object);
                long changeInDistanceComputations = (long) Statistics.getStatistics("DistanceComputations").getValue() - distanceComputations;

                state.approxState.update(
                        (LeafNode) node,
                        toThisIterationDistanceComputations + changeInDistanceComputations);
                for (Node child : ((InternalNode) node).getChildren()) {
                    state.queue.add(new ObjectToNodeDistanceRank(queryObject, child));
                }
    /**
     * Evaluates a kNN query by offering every object returned by
     * {@code root.getObjects()} to the operation's answer, then ends the operation.
     *
     * @param operation the kNN query operation to evaluate
     */
    public void kNN(KNNQueryOperation operation) {
        for (LocalAbstractObject candidate : root.getObjects()) {
            operation.addToAnswer(candidate);
        }

        operation.endOperation();
    }

    /**
     * Returns the object count reported by the bucket dispatcher.
     *
     * @return the value of {@code bucketDispatcher.getObjectCount()}
     */
    public int getObjectCount() {
        final int storedObjectCount = bucketDispatcher.getObjectCount();
        return storedObjectCount;
    }

    /**
     * Returns the objects reported by the root node.
     *
     * @return the list obtained from {@code root.getObjects()}
     */
    public List<LocalAbstractObject> getObjects() {
        final List<LocalAbstractObject> rootObjects = root.getObjects();
        return rootObjects;
    }

    public void insert(InsertOperation operation) throws BucketStorageException {
        LocalAbstractObject object = operation.getInsertedObject();
        Node node = root;
        while (!node.isLeaf()) {
            node.addObject(object);
            node = ((InternalNode) node).getNearestChild(object);
        node.addObject(object);
    /**
     * Collects the nodes of the tree by asking the root to gather them into a fresh list.
     *
     * @return a new list filled by {@code root.gatherNodes}
     */
    private List<Node> getNodes() {
        final List<Node> gathered = new ArrayList<>();

        root.gatherNodes(gathered);

        return gathered;
    }

    /**
     * Collects the leaf nodes of the tree by asking the root to gather them into a fresh list.
     *
     * @return a new list filled by {@code root.gatherLeafNodes}
     */
    public List<LeafNode> getLeafNodes() {
        final List<LeafNode> gathered = new ArrayList<>();

        root.gatherLeafNodes(gathered);

        return gathered;
    }

    /**
     * Prints a summary of the tree to standard output: the configured insert type,
     * the root's height, node degree and leaf capacity, object and node counts,
     * and min/avg/max statistics of hull objects per node and stored objects per leaf.
     */
    public void printStatistics() {
        // Gather the nodes once; the same list feeds the hull statistics and the node count.
        List<Node> allNodes = getNodes();

        IntSummaryStatistics nodeHullObjects = allNodes
                .stream()
                .mapToInt(Node::getHullObjectCount)
                .summaryStatistics();
        IntSummaryStatistics leafNodeObjects = getLeafNodes()
                .stream()
                .mapToInt(LeafNode::getObjectCount)
                .summaryStatistics();
        int numberOfObjects = bucketDispatcher
                .getAllBuckets()
                .stream()
                .mapToInt(LocalBucket::getObjectCount)
                .sum();
        int numberOfNodes = allNodes.size();

        System.out.println("Insert type: " + insertType);
        System.out.println("Height: " + root.getHeight());
        System.out.println("Node degree: " + nodeDegree);
        System.out.println("Leaf object capacity: " + leafCapacity);

        System.out.println("Number of objects: " + numberOfObjects);
        System.out.println("Number of nodes: " + numberOfNodes);
        // The leaf statistics' count is the number of leaf nodes, so the remainder are internal nodes.
        System.out.println("Number of internal nodes: " + (numberOfNodes - leafNodeObjects.getCount()));
        System.out.println("Number of leaf nodes: " + leafNodeObjects.getCount());

        // %n emits the platform line separator, matching what println produces above.
        System.out.printf("Number of hull objects per node - min: %d, avg: %.2f, max: %d, sum: %d%n",
                nodeHullObjects.getMin(),
                nodeHullObjects.getAverage(),
                nodeHullObjects.getMax(),
                nodeHullObjects.getSum());

        System.out.printf("Number of stored objects per leaf node - min: %d, avg: %.2f, max: %d%n",
                leafNodeObjects.getMin(),
                leafNodeObjects.getAverage(),
                leafNodeObjects.getMax());
    }

    /**
     * Returns a textual description listing the leaf capacity, node degree,
     * insert type and object-to-node distance of this tree.
     *
     * @return the description in the form {@code MHTree{...}}
     */
    @Override
    public String toString() {
        StringBuilder description = new StringBuilder("MHTree{");

        description.append("leafCapacity=").append(leafCapacity);
        description.append(", nodeDegree=").append(nodeDegree);
        description.append(", insertType=").append(insertType);
        description.append(", objectToNodeDistance=").append(objectToNodeDistance);
        description.append('}');

        return description.toString();
    }
    public static class Builder {

        /**
         * List of objects used while building the MH-Tree.
         */
        private final List<LocalAbstractObject> objects;

        /**
         * Minimal number of objects in leaf node's bucket.
         */

        /**
         * Specifies which method to use when adding a new object.
         */
        private InsertType insertType;

        /**
         * Specifies how to measure distance between an object and a node.
         */
        private ObjectToNodeDistance objectToNodeDistance;

        /**
         * A dispatcher for maintaining a set of local buckets.
         */
        private BucketDispatcher bucketDispatcher;

        /**
         * Precomputed objects distances.
         */
        private AbstractRepresentation.PrecomputedDistances objectDistances;

        /**
         * Stores leaf nodes and subsequently internal nodes.
         */
        private Node[] nodes;

        /**
         * Identifies which indices in {@code nodes} are valid.
         */
        private BitSet validNodeIndices;

        /**
         * Precomputed node distances.
         */
        private PrecomputedNodeDistances nodeDistances;

        /**
         * Root of MH-Tree.
         */
        private Node root;

        public Builder(List<LocalAbstractObject> objects, int leafCapacity, int nodeDegree) {
            this.objects = objects;
            this.leafCapacity = leafCapacity;
            this.nodeDegree = nodeDegree;

            this.insertType = InsertType.GREEDY;
            this.objectToNodeDistance = ObjectToNodeDistance.NEAREST;
            this.bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, MemoryStorageBucket.class, null);
        /**
         * Sets the insert type stored in this builder.
         *
         * @param insertType the insert type to use
         * @return this builder, to allow call chaining
         */
        public Builder insertType(InsertType insertType) {
            this.insertType = insertType;
            return this;
        }
        public Builder objectToNodeDistance(ObjectToNodeDistance objectToNodeDistance) {
            this.objectToNodeDistance = objectToNodeDistance;
            return this;
        /**
         * Replaces the bucket dispatcher stored in this builder.
         *
         * @param bucketDispatcher the dispatcher to use
         * @return this builder, to allow call chaining
         */
        public Builder bucketDispatcher(BucketDispatcher bucketDispatcher) {
            this.bucketDispatcher = bucketDispatcher;
            return this;
        }
        public Builder bucketDispatcher(Class<? extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) {
            this.bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams);
        /**
         * Builds the MH-Tree from the objects given to this builder.
         *
         * <p>When the number of objects does not exceed {@code leafCapacity}, the tree
         * consists of a single leaf node serving as the root. Otherwise the leaf nodes
         * are created first, the pairwise node distances are precomputed from them, and
         * the nodes are merged bottom-up until a single root remains.
         *
         * @return the constructed tree
         * @throws BucketStorageException propagated from bucket creation or leaf node construction
         */
        public MHTree build() throws BucketStorageException {
            nodes = new Node[objects.size() / leafCapacity];
            validNodeIndices = new BitSet(nodes.length);
            validNodeIndices.set(0, nodes.length);
            objectDistances = new AbstractRepresentation.PrecomputedDistances(objects);

            // Every object is stored in the root
            if (objectDistances.getObjectCount() <= leafCapacity) {
                root = new LeafNode(objectDistances, bucketDispatcher.createBucket(), insertType, objectToNodeDistance);
                return new MHTree(this);
            }

            // The leaf nodes must exist before the distance matrix is built: its constructor
            // reads the hull objects of every entry in {@code nodes}.
            createLeafNodes(leafCapacity);
            nodeDistances = new PrecomputedNodeDistances();

            // Merging relies on the precomputed node distances, so it runs last.
            root = createRoot(nodeDegree);

            return new MHTree(this);
        }

        /**
         * Merges the valid nodes level by level until a single valid node remains.
         *
         * <p>Each pass works on a copy of {@code validNodeIndices}. While more than
         * {@code arity} nodes are pending, the index chosen by
         * {@code nodeDistances.getFurthestIndex} is merged with {@code arity - 1}
         * nodes selected through {@link #findClosestNodeIndex}; the remaining
         * (at most {@code arity}) pending nodes are merged together.
         *
         * @param arity the number of nodes merged into one parent
         * @return the node at the only index left set in {@code validNodeIndices}
         */
        private Node createRoot(int arity) {
            while (validNodeIndices.cardinality() != 1) {
                BitSet pending = (BitSet) validNodeIndices.clone();

                // Full groups: one seed node plus its (arity - 1) selected companions.
                while (!pending.isEmpty() && pending.cardinality() > arity) {
                    int seedIndex = nodeDistances.getFurthestIndex(pending);
                    pending.clear(seedIndex);

                    List<Integer> companions = findClosestItems(this::findClosestNodeIndex, seedIndex, arity - 1, pending);
                    mergeNodes(seedIndex, companions);
                }

                // Whatever is still pending fits into a single parent.
                if (!pending.isEmpty()) {
                    mergeNodes(pending);
                }
            }

            int rootIndex = validNodeIndices.nextSetBit(0);
            return nodes[rootIndex];
        }

        /**
         * Partitions the objects into leaf nodes stored in {@code nodes}.
         *
         * <p>Each leaf is seeded with the index returned by {@code Utils.maxDistanceIndex}
         * over the not yet processed objects and filled with {@code leafCapacity - 1}
         * further objects selected through {@link #findClosestObjectIndex}. Once fewer
         * than {@code leafCapacity} objects remain, each of them is added to the leaf
         * chosen by {@link #getClosestNodeIndex} and the method returns.
         *
         * @param leafCapacity the number of objects placed into each newly created leaf
         * @throws BucketStorageException propagated from bucket creation or from adding an object to a leaf
         */
        private void createLeafNodes(int leafCapacity) throws BucketStorageException {
            final int totalObjects = objectDistances.getObjectCount();
            BitSet remaining = new BitSet(totalObjects);
            remaining.set(0, totalObjects);

            int nodeIndex = 0;
            while (!remaining.isEmpty()) {
                if (remaining.cardinality() < leafCapacity) {
                    // Too few objects for another leaf: distribute them among the existing leaves.
                    int leftover = remaining.nextSetBit(0);
                    while (leftover >= 0) {
                        LocalAbstractObject object = objectDistances.getObject(leftover);
                        LeafNode target = (LeafNode) nodes[getClosestNodeIndex(object)];
                        target.addObject(object, objectDistances);
                        leftover = remaining.nextSetBit(leftover + 1);
                    }

                    return;
                }

                List<Integer> memberIndices = new ArrayList<>(leafCapacity);

                // The seed object of the new leaf
                int seedIndex = Utils.maxDistanceIndex(objectDistances.getDistances(), remaining);
                remaining.clear(seedIndex);
                memberIndices.add(seedIndex);

                // The remaining members, up to leafCapacity objects in total
                List<Integer> companions = findClosestItems(this::findClosestObjectIndex, seedIndex, leafCapacity - 1, remaining);
                memberIndices.addAll(companions);

                List<LocalAbstractObject> members = new ArrayList<>();
                for (Integer memberIndex : memberIndices) {
                    members.add(objectDistances.getObject(memberIndex));
                }

                nodes[nodeIndex] = new LeafNode(objectDistances.getSubset(members), bucketDispatcher.createBucket(), insertType, objectToNodeDistance);
                nodeIndex++;
            }
        }

        /**
         * Finds the position in {@code nodes} of the node with the smallest
         * {@code getDistance(object, objectDistances)} value.
         *
         * @param object the object to which the node distances are measured
         * @return the index of the first node with the smallest distance, or {@code -1}
         * when no node has a distance below {@code Double.MAX_VALUE}
         */
        private int getClosestNodeIndex(LocalAbstractObject object) {
            int bestIndex = -1;
            double bestDistance = Double.MAX_VALUE;

            int position = 0;
            for (Node candidate : nodes) {
                double candidateDistance = candidate.getDistance(object, objectDistances);

                // Strict comparison keeps the earliest node on ties.
                if (candidateDistance < bestDistance) {
                    bestDistance = candidateDistance;
                    bestIndex = position;
                }

                position++;
            }

            return bestIndex;
        }

        /**
         * Selects {@code numberOfItems} indices, one at a time, using the supplied selector.
         *
         * <p>The selector is called with the group selected so far (starting with
         * {@code itemIndex} alone) and the set of still available indices. Every
         * returned index joins the group, is appended to the result, and is cleared
         * from {@code notProcessedItemIndices}.
         *
         * @param findClosestItemIndex    picks the next index for the given group and available indices
         * @param itemIndex               the index the group starts with; it is not part of the result
         * @param numberOfItems           the number of indices to select
         * @param notProcessedItemIndices the available indices; modified by this method
         * @return the selected indices in selection order
         */
        private List<Integer> findClosestItems(BiFunction<List<Integer>, BitSet, Integer> findClosestItemIndex, int itemIndex, int numberOfItems, BitSet notProcessedItemIndices) {
            List<Integer> group = new ArrayList<>(1 + numberOfItems);
            group.add(itemIndex);

            List<Integer> selected = new ArrayList<>(numberOfItems);

            for (int taken = 0; taken != numberOfItems; taken++) {
                int next = findClosestItemIndex.apply(group, notProcessedItemIndices);

                group.add(next);
                selected.add(next);

                notProcessedItemIndices.clear(next);
            }

            return selected;
        }

        /**
         * Picks a node index for the given group of node indices.
         *
         * <p>For every group member, the candidate returned by
         * {@code nodeDistances.getClosestIndex} is looked up; the candidate with the
         * smallest precomputed distance to its group member is returned.
         *
         * @param indices          the node indices already in the group
         * @param validNodeIndices the node indices that may be selected
         * @return the chosen candidate index, or {@code -1} when no candidate has a
         * distance below {@code Double.MAX_VALUE}
         */
        private int findClosestNodeIndex(List<Integer> indices, BitSet validNodeIndices) {
            int bestIndex = -1;
            double bestDistance = Double.MAX_VALUE;

            for (int member : indices) {
                int candidate = nodeDistances.getClosestIndex(member, validNodeIndices);
                double candidateDistance = nodeDistances.getDistance(member, candidate);

                if (candidateDistance < bestDistance) {
                    bestIndex = candidate;
                    bestDistance = candidateDistance;
                }
            }

            return bestIndex;
        }

        /**
         * Picks an object index for the given group of object indices.
         *
         * <p>For every group member, a candidate is obtained from
         * {@code objectDistances.minDistInArray} over that member's distance row
         * (presumably the nearest still valid object; confirm against that helper).
         * The candidate's score is the sum of its distances to all group members,
         * and the candidate with the smallest score is returned.
         *
         * @param indices            the object indices already in the group
         * @param validObjectIndices the object indices that may be selected
         * @return the chosen candidate index, or {@code -1} when no candidate has a
         * score below {@code Double.MAX_VALUE}
         */
        private int findClosestObjectIndex(List<Integer> indices, BitSet validObjectIndices) {
            int bestIndex = -1;
            double bestScore = Double.MAX_VALUE;

            for (int member : indices) {
                int candidate = objectDistances.minDistInArray(objectDistances.getDistances(member), validObjectIndices);

                // Total distance from the whole group to this candidate.
                double score = indices
                        .stream()
                        .mapToDouble(groupMember -> objectDistances.getDistance(groupMember, candidate))
                        .sum();

                if (score < bestScore) {
                    bestIndex = candidate;
                    bestScore = score;
                }
            }

            return bestIndex;
        }

        /**
         * Merges the nodes whose indices are set in {@code nodeIndices}.
         * The first set index becomes the position of the merged node; the
         * remaining set indices are merged into it.
         *
         * @param nodeIndices the bitset of node indices to be merged
         */
        private void mergeNodes(BitSet nodeIndices) {
            List<Integer> setIndices = new ArrayList<>();

            for (int bit = nodeIndices.nextSetBit(0); bit >= 0; bit = nodeIndices.nextSetBit(bit + 1)) {
                setIndices.add(bit);
            }

            // The lowest set index hosts the parent; the rest are its siblings.
            int parentNodeIndex = setIndices.remove(0);

            mergeNodes(parentNodeIndex, setIndices);
        }

        /**
         * Merges the given nodes into a new parent stored at {@code parentNodeIndex} in {@code nodes}.
         *
         * <p>Does nothing when {@code nodeIndices} is empty. Otherwise
         * {@code parentNodeIndex} is appended to {@code nodeIndices}, a parent is
         * created over all listed nodes, their slots are cleared and invalidated,
         * the parent takes the slot {@code parentNodeIndex}, and the node distances
         * for that slot are updated.
         *
         * @param parentNodeIndex the index where the new node is placed
         * @param nodeIndices     the indices merged together with {@code parentNodeIndex};
         *                        this list is modified by the call
         */
        private void mergeNodes(int parentNodeIndex, List<Integer> nodeIndices) {
            if (nodeIndices.isEmpty()) {
                return;
            }

            nodeIndices.add(parentNodeIndex);

            List<Node> children = new ArrayList<>();
            for (Integer childIndex : nodeIndices) {
                children.add(nodes[childIndex]);
            }

            InternalNode parent = Node.createParent(children, objectDistances, insertType, objectToNodeDistance);

            // Retire every merged slot, including the parent's, before re-occupying it.
            for (Integer mergedIndex : nodeIndices) {
                validNodeIndices.clear(mergedIndex);
                nodes[mergedIndex] = null;
            }

            nodes[parentNodeIndex] = parent;
            validNodeIndices.set(parentNodeIndex);

            nodeDistances.updateNodeDistances(parentNodeIndex);
        }

        /**
         * Holds a square matrix of distances between the nodes stored in {@code nodes}
         * and offers methods for computing, updating and querying it.
         */
        private class PrecomputedNodeDistances {
            /** Distance matrix indexed by positions in {@code nodes}. */
            private final float[][] distances;

            /**
             * Allocates a {@code nodes.length x nodes.length} matrix and fills it
             * with the distances between all pairs of nodes.
             */
            PrecomputedNodeDistances() {
                final int nodeCount = nodes.length;
                distances = new float[nodeCount][nodeCount];

                computeNodeDistances();
            }

            /**
             * Returns the stored distance between two nodes.
             *
             * @param i an index of a node in {@code nodes}
             * @param j an index of a node in {@code nodes}
             * @return the matrix entry for the pair {@code (i, j)}
             */
            private float getDistance(int i, int j) {
                return distances[i][j];
            }

            /**
             * Recomputes the distances between the given node and every node whose
             * index is set in {@code validNodeIndices}.
             *
             * @param nodeIndex the index of the node whose distances are refreshed
             */
            private void updateNodeDistances(int nodeIndex) {
                for (int other = validNodeIndices.nextSetBit(0); other >= 0; other = validNodeIndices.nextSetBit(other + 1)) {
                    storeSymmetric(nodeIndex, other, computeDistanceBetweenNodes(nodeIndex, other));
                }
            }

            /**
             * Delegates to {@code Utils.minDistanceIndex} on the matrix row of the given node.
             *
             * @param nodeIndex      the index of the node whose row is searched
             * @param notUsedIndexes the indices passed to the helper as candidates
             * @return the index returned by {@code Utils.minDistanceIndex}
             */
            private int getClosestIndex(int nodeIndex, BitSet notUsedIndexes) {
                float[] row = distances[nodeIndex];
                return Utils.minDistanceIndex(row, notUsedIndexes);
            }

            /**
             * Delegates to {@code Utils.maxDistanceIndex} on the whole matrix.
             *
             * @param validIndices the indices passed to the helper as candidates
             * @return the index returned by {@code Utils.maxDistanceIndex}
             */
            private int getFurthestIndex(BitSet validIndices) {
                return Utils.maxDistanceIndex(distances, validIndices);
            }

            /**
             * Computes the distance for every pair of distinct nodes in {@code nodes}
             * and stores it in both matrix positions of the pair.
             */
            private void computeNodeDistances() {
                final int nodeCount = nodes.length;

                for (int first = 0; first < nodeCount; first++) {
                    for (int second = first + 1; second < nodeCount; second++) {
                        storeSymmetric(first, second, computeDistanceBetweenNodes(first, second));
                    }
                }
            }

            /**
             * Writes a distance into both matrix positions of a node pair.
             *
             * @param a        an index of a node in {@code nodes}
             * @param b        an index of a node in {@code nodes}
             * @param distance the value to store
             */
            private void storeSymmetric(int a, int b, float distance) {
                distances[a][b] = distance;
                distances[b][a] = distance;
            }

            /**
             * Computes the distance between two nodes as the minimum of the
             * precomputed object distances over all pairs of their hull objects.
             *
             * @param i an index of a node in {@code nodes}
             * @param j an index of a node in {@code nodes}
             * @return the smallest hull-object distance, or {@code Float.MAX_VALUE}
             * when there is no pair of hull objects
             */
            private float computeDistanceBetweenNodes(int i, int j) {
                float closest = Float.MAX_VALUE;

                for (LocalAbstractObject hullOfFirst : nodes[i].getHullObjects()) {
                    for (LocalAbstractObject hullOfSecond : nodes[j].getHullObjects()) {
                        closest = Math.min(closest, objectDistances.getDistance(hullOfFirst, hullOfSecond));
                    }
                }

                return closest;
            }
        }