From 70ee6dc18d5b8adde987741e1da45a9fdb4ee624 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Wed, 24 Feb 2021 13:59:25 +0100 Subject: [PATCH] FIX: number of leaf node objects is taken from buckets --- src/mhtree/BuildTree.java | 12 ++++++------ src/mhtree/InternalNode.java | 11 ++--------- src/mhtree/LeafNode.java | 12 +++++------- 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java index c32fcd0..f553cd9 100644 --- a/src/mhtree/BuildTree.java +++ b/src/mhtree/BuildTree.java @@ -22,16 +22,16 @@ class BuildTree { private final float[][] nodeDistances; private final InsertType insertType; - private final DistanceMeasure distanceMeasure; + private final ObjectToNodeDistance objectToNodeDistance; private final BucketDispatcher bucketDispatcher; private Node root; - BuildTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, DistanceMeasure distanceMeasure, BucketDispatcher bucketDispatcher) throws BucketStorageException { + BuildTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, BucketDispatcher bucketDispatcher) throws BucketStorageException { this.arity = arity; this.leafCapacity = leafCapacity; this.insertType = insertType; - this.distanceMeasure = distanceMeasure; + this.objectToNodeDistance = objectToNodeDistance; this.bucketDispatcher = bucketDispatcher; nodes = new Node[objects.size() / leafCapacity]; @@ -44,7 +44,7 @@ class BuildTree { // Every object is stored in the root if (objectDistances.getObjectCount() < leafCapacity) { - root = new LeafNode(objectDistances, bucketDispatcher.createBucket(), insertType, distanceMeasure); + root = new LeafNode(objectDistances, bucketDispatcher.createBucket(), insertType, objectToNodeDistance); return; } @@ -114,7 +114,7 @@ class BuildTree { // Select the rest of the objects up to the total of leafCapacity objects.addAll(findClosestObjects(furthestIndex, leafCapacity - 1, notProcessedObjectIndices)); - nodes[nodeIndex] = new LeafNode(objectDistances.getSubset(objects), bucketDispatcher.createBucket(), insertType, distanceMeasure); + nodes[nodeIndex] = new LeafNode(objectDistances.getSubset(objects), bucketDispatcher.createBucket(), insertType, objectToNodeDistance); } } @@ -201,7 +201,7 @@ class BuildTree { Set<Node> nodes = indices.stream().map(i -> this.nodes[i]).collect(Collectors.toSet()); - InternalNode parent = Node.createParent(nodes, objectDistances, insertType, distanceMeasure); + InternalNode parent = Node.createParent(nodes, objectDistances, insertType, objectToNodeDistance); nodes.forEach(node -> node.setParent(parent)); parent.addChildren(nodes); diff --git a/src/mhtree/InternalNode.java b/src/mhtree/InternalNode.java index 3a383a1..e628e86 100644 --- a/src/mhtree/InternalNode.java +++ b/src/mhtree/InternalNode.java @@ -17,8 +17,8 @@ public class InternalNode extends Node implements Serializable { private final Set<Node> children; - InternalNode(PrecomputedDistances distances, InsertType insertType, DistanceMeasure distanceMeasure) { - super(distances, insertType, distanceMeasure); + InternalNode(PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) { + super(distances, insertType, objectToNodeDistance); children = new HashSet<>(); } @@ -52,13 +52,6 @@ public class InternalNode extends Node implements Serializable { return children.stream().anyMatch(child -> child.contains(object)); } - public List<Integer> getLeafNodesObjectCounts() { - return children.stream() - .map(Node::getLeafNodesObjectCounts) - .flatMap(Collection::stream) - .collect(Collectors.toList()); - } - public int getHeight() { return children.stream() .mapToInt(Node::getHeight) diff --git a/src/mhtree/LeafNode.java b/src/mhtree/LeafNode.java index bc36cbd..07a610a 100644 --- a/src/mhtree/LeafNode.java +++ b/src/mhtree/LeafNode.java @@ -7,7 +7,9 @@ import messif.objects.LocalAbstractObject; import messif.objects.util.AbstractObjectIterator; import java.io.Serializable; -import java.util.*; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; public class LeafNode extends Node implements Serializable { @@ -17,8 +19,8 @@ public class LeafNode extends Node implements Serializable { private static final long serialVersionUID = 1L; private LocalBucket bucket; - LeafNode(PrecomputedDistances distances, LocalBucket bucket, InsertType insertType, DistanceMeasure distanceMeasure) throws BucketStorageException { - super(distances, insertType, distanceMeasure); + LeafNode(PrecomputedDistances distances, LocalBucket bucket, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) throws BucketStorageException { + super(distances, insertType, objectToNodeDistance); this.bucket = bucket; this.bucket.addObjects(distances.getObjects()); @@ -47,10 +49,6 @@ public class LeafNode extends Node implements Serializable { return false; } - public List<Integer> getLeafNodesObjectCounts() { - return Collections.singletonList(bucket.getObjectCount()); - } - public int getHeight() { return 0; } -- GitLab