From 42fb9595f0788decf2c4462f8b5f798bfd87950e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Wed, 30 Dec 2020 12:38:04 +0100 Subject: [PATCH] ADD: LocalBucket into LeafNode --- src/mhtree/BuildTree.java | 26 ++++++++++++++++-------- src/mhtree/BuildTreeApp.java | 6 ++++-- src/mhtree/InternalNode.java | 1 + src/mhtree/LeafNode.java | 25 ++++++++++++++++------- src/mhtree/MHTree.java | 14 ++++++++++--- src/mhtree/Node.java | 3 ++- src/mhtree/ObjectToNodeDistanceRank.java | 1 + 7 files changed, 54 insertions(+), 22 deletions(-) diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java index e1be946..fbfd69f 100644 --- a/src/mhtree/BuildTree.java +++ b/src/mhtree/BuildTree.java @@ -1,6 +1,9 @@ package mhtree; import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; +import messif.buckets.BucketDispatcher; +import messif.buckets.BucketStorageException; +import messif.buckets.LocalBucket; import messif.objects.LocalAbstractObject; import java.util.*; @@ -18,10 +21,11 @@ class BuildTree { private final AbstractRepresentation.PrecomputedDistances objectDistances; private final float[][] nodeDistances; - private Node root; + private final BucketDispatcher bucketDispatcher; + private Node root; - BuildTree(List<LocalAbstractObject> objects, int leafCapacity, int numberOfChildren) { + BuildTree(List<LocalAbstractObject> objects, int leafCapacity, int numberOfChildren, BucketDispatcher bucketDispatcher) throws BucketStorageException { this.numberOfChildren = numberOfChildren; this.leafCapacity = leafCapacity; @@ -33,6 +37,8 @@ class BuildTree { objectDistances = new AbstractRepresentation.PrecomputedDistances(objects); nodeDistances = new float[nodes.length][nodes.length]; + this.bucketDispatcher = bucketDispatcher; + buildTree(); } @@ -40,7 +46,7 @@ class BuildTree { return root; } - private void buildTree() { + private void buildTree() throws BucketStorageException { if (!initHullPoints()) return; precomputeHullDistances(); @@ -78,9 +84,9 @@ class BuildTree { root = nodes[validNodeIndices.nextSetBit(0)]; } - private boolean initHullPoints() { + private boolean initHullPoints() throws BucketStorageException { if (objectDistances.getObjectCount() < leafCapacity) { - root = new LeafNode(objectDistances); + root = new LeafNode(objectDistances, bucketDispatcher.createBucket()); return false; } @@ -89,7 +95,8 @@ class BuildTree { for (int nodeIndex = 0; !notProcessedObjectIndices.isEmpty(); nodeIndex++) { if (notProcessedObjectIndices.cardinality() < leafCapacity) { - notProcessedObjectIndices.stream().forEach(this::addObjectToClosestNode); + for (int i = notProcessedObjectIndices.nextSetBit(0); i >= 0; i = notProcessedObjectIndices.nextSetBit(i + 1)) + addObjectToClosestNode(i); return true; } @@ -104,19 +111,20 @@ class BuildTree { findClosestObjectIndices(furthestIndex, leafCapacity - 1, notProcessedObjectIndices) .forEach(i -> objects.add(objectDistances.getObject(i))); - nodes[nodeIndex] = new LeafNode(objects); + nodes[nodeIndex] = new LeafNode(objects, bucketDispatcher.createBucket()); } return true; } - private void addObjectToClosestNode(int objectIndex) { + private void addObjectToClosestNode(int objectIndex) throws BucketStorageException { int[] hullObjectIndices = getEveryHullObjectIndex(); int closestHullObjectIndex = hullObjectIndices[getNNIndex(objectIndex, hullObjectIndices)]; int nodeIndex = findCorrespondingHullIndex(objectDistances.getObject(closestHullObjectIndex)); - nodes[nodeIndex] = LeafNode.addObject((LeafNode) nodes[nodeIndex], objectDistances.getObject(objectIndex)); + LeafNode node = (LeafNode) nodes[nodeIndex]; + node.addObject(objectDistances.getObject(objectIndex)); } private int getNNIndex(int centerIndex, int[] dataIndices) { diff --git a/src/mhtree/BuildTreeApp.java b/src/mhtree/BuildTreeApp.java index d9eeda7..98c4281 100644 --- a/src/mhtree/BuildTreeApp.java +++ b/src/mhtree/BuildTreeApp.java @@ -1,5 +1,7 @@ package mhtree; +import messif.buckets.BucketStorageException; +import messif.buckets.impl.MemoryStorageBucket; import messif.objects.AbstractObject; import messif.objects.LocalAbstractObject; import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2; @@ -26,7 +28,7 @@ public class BuildTreeApp { AbstractStreamObjectIterator<ObjectFloatVectorNeuralNetworkL2> iter = new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, arg); AbstractObjectList<LocalAbstractObject> objects = new AbstractObjectList<>(iter); - MHTree tree = new MHTree(objects, 10, 5); + MHTree tree = new MHTree(objects, 10, 5, MemoryStorageBucket.class, null); for (LocalAbstractObject object : objects) { @@ -47,7 +49,7 @@ public class BuildTreeApp { } } - } catch (IOException ex) { + } catch (IOException | BucketStorageException ex) { Logger.getLogger("BuildTree").log(Level.SEVERE, null, ex); } } diff --git a/src/mhtree/InternalNode.java b/src/mhtree/InternalNode.java index d30976e..762bee3 100644 --- a/src/mhtree/InternalNode.java +++ b/src/mhtree/InternalNode.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.stream.Stream; public class InternalNode extends Node implements Serializable { + /** * Serialization ID */ diff --git a/src/mhtree/LeafNode.java b/src/mhtree/LeafNode.java index 072de8e..f3dd645 100644 --- a/src/mhtree/LeafNode.java +++ b/src/mhtree/LeafNode.java @@ -1,6 +1,9 @@ package mhtree; import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; +import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3; +import messif.buckets.BucketStorageException; +import messif.buckets.LocalBucket; import messif.objects.LocalAbstractObject; import java.io.Serializable; @@ -8,23 +11,31 @@ import java.util.ArrayList; import java.util.List; public class LeafNode extends Node implements Serializable { + /** * Serialization ID */ private static final long serialVersionUID = 1L; + private LocalBucket bucket; - LeafNode(List<LocalAbstractObject> objects) { + LeafNode(List<LocalAbstractObject> objects, LocalBucket bucket) throws BucketStorageException { super(objects); + + this.bucket = bucket; + this.bucket.addObjects(objects); } - LeafNode(AbstractRepresentation.PrecomputedDistances distances) { - super(distances); + LeafNode(AbstractRepresentation.PrecomputedDistances distances, LocalBucket bucket) throws BucketStorageException { + this(distances.getObjects(), bucket); } - public static LeafNode addObject(LeafNode node, LocalAbstractObject object) { - List<LocalAbstractObject> objects = new ArrayList<>(node.getObjects()); - objects.add(object); + public void addObject(LocalAbstractObject object) throws BucketStorageException { + bucket.addObject(object); - return new LeafNode(objects); + if (isCovered(object)) return; + + List<LocalAbstractObject> objects = new ArrayList<>(getObjects()); + objects.add(object); + hull = new HullOptimizedRepresentationV3(objects); } } diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java index 742285a..094410a 100644 --- a/src/mhtree/MHTree.java +++ b/src/mhtree/MHTree.java @@ -1,12 +1,16 @@ package mhtree; import messif.algorithms.Algorithm; +import messif.buckets.BucketDispatcher; +import messif.buckets.BucketStorageException; +import messif.buckets.LocalBucket; import messif.objects.LocalAbstractObject; import messif.operations.data.InsertOperation; import messif.operations.query.ApproxKNNQueryOperation; import java.io.Serializable; import java.util.List; +import java.util.Map; import java.util.PriorityQueue; public class MHTree extends Algorithm implements Serializable { @@ -20,19 +24,23 @@ public class MHTree extends Algorithm implements Serializable { private final int numberOfChildren; private final Node root; + private final BucketDispatcher bucketDispatcher; @AlgorithmConstructor(description = "MH-Tree", arguments = { "list of objects", "number of objects in leaf node", - "number of children in internal node" + "number of children in internal node", + "storage class for buckets", + "storage class parameters" }) - MHTree(List<LocalAbstractObject> objects, int leafCapacity, int numberOfChildren) { + MHTree(List<LocalAbstractObject> objects, int leafCapacity, int numberOfChildren, Class<? extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) throws BucketStorageException { super("MH-Tree"); this.leafCapacity = leafCapacity; this.numberOfChildren = numberOfChildren; + bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams); - root = new BuildTree(objects, leafCapacity, numberOfChildren).getRoot(); + root = new BuildTree(objects, leafCapacity, numberOfChildren, bucketDispatcher).getRoot(); } public void approxKNN(ApproxKNNQueryOperation operation) { diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java index ee08758..518023e 100644 --- a/src/mhtree/Node.java +++ b/src/mhtree/Node.java @@ -10,13 +10,14 @@ import java.util.stream.Collectors; import java.util.stream.Stream; public abstract class Node implements Serializable { + /** * Serialization ID */ private static final long serialVersionUID = 420L; private Node parent; - private HullOptimizedRepresentationV3 hull; + protected HullOptimizedRepresentationV3 hull; Node(HullOptimizedRepresentationV3 hull) { hull.build(); diff --git a/src/mhtree/ObjectToNodeDistanceRank.java b/src/mhtree/ObjectToNodeDistanceRank.java index 6da2fa9..fa48091 100644 --- a/src/mhtree/ObjectToNodeDistanceRank.java +++ b/src/mhtree/ObjectToNodeDistanceRank.java @@ -3,6 +3,7 @@ package mhtree; import messif.objects.LocalAbstractObject; public class ObjectToNodeDistanceRank implements Comparable<ObjectToNodeDistanceRank> { + private final Node node; private final LocalAbstractObject object; private final float distance; -- GitLab