From d0e726c7334fd7bc4c696c8417349a2513d3bd16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Fri, 19 Feb 2021 09:49:35 +0100 Subject: [PATCH] ADD: more explicit tree building steps --- src/mhtree/BuildTree.java | 90 +++++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 46 deletions(-) diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java index 42e1f09..f6f7c97 100644 --- a/src/mhtree/BuildTree.java +++ b/src/mhtree/BuildTree.java @@ -30,6 +30,9 @@ class BuildTree { BuildTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, DistanceMeasure distanceMeasure, BucketDispatcher bucketDispatcher) throws BucketStorageException { this.arity = arity; this.leafCapacity = leafCapacity; + this.insertType = insertType; + this.distanceMeasure = distanceMeasure; + this.bucketDispatcher = bucketDispatcher; nodes = new Node[objects.size() / leafCapacity]; @@ -39,9 +42,15 @@ class BuildTree { objectDistances = new PrecomputedDistances(objects); nodeDistances = new float[nodes.length][nodes.length]; - this.insertType = insertType; - this.distanceMeasure = distanceMeasure; - this.bucketDispatcher = bucketDispatcher; + // Every object is stored in the root + if (objectDistances.getObjectCount() < leafCapacity) { + root = new LeafNode(objectDistances, bucketDispatcher.createBucket(), insertType, distanceMeasure); + return; + } + + createLeafNodes(); + + precomputeNodeDistances(); buildTree(); } @@ -50,50 +59,41 @@ class BuildTree { return root; } - private void buildTree() throws BucketStorageException { - if (!initHullPoints()) return; - - precomputeHullDistances(); - + private void buildTree() { while (validNodeIndices.cardinality() != 1) { - BitSet notProcessedIndices = (BitSet) validNodeIndices.clone(); + BitSet notProcessedNodeIndices = (BitSet) validNodeIndices.clone(); - while (!notProcessedIndices.isEmpty()) { - if (notProcessedIndices.cardinality() < arity) { - List<Integer> restOfTheIndices = new ArrayList<>(); + while (!notProcessedNodeIndices.isEmpty()) { + if (notProcessedNodeIndices.cardinality() < arity) { + List<Integer> nodeIndices = new ArrayList<>(); - notProcessedIndices.stream().forEach(restOfTheIndices::add); + notProcessedNodeIndices.stream().forEach(nodeIndices::add); - int mainIndex = restOfTheIndices.remove(0); - mergeHulls(mainIndex, restOfTheIndices); + int mainNodeIndex = nodeIndices.remove(0); + mergeNodes(mainNodeIndex, nodeIndices); break; } - int furthestNodeIndex = getFurthestIndex(nodeDistances, notProcessedIndices); - notProcessedIndices.clear(furthestNodeIndex); + int furthestNodeIndex = getFurthestIndex(nodeDistances, notProcessedNodeIndices); + notProcessedNodeIndices.clear(furthestNodeIndex); - List<Integer> nnIndices = new ArrayList<>(); + List<Integer> nnNodeIndices = new ArrayList<>(); for (int i = 0; i < arity - 1; i++) { - int index = objectDistances.minDistInArrayExceptIdx(nodeDistances[furthestNodeIndex], notProcessedIndices, furthestNodeIndex); - notProcessedIndices.clear(index); - nnIndices.add(index); + int index = objectDistances.minDistInArrayExceptIdx(nodeDistances[furthestNodeIndex], notProcessedNodeIndices, furthestNodeIndex); + notProcessedNodeIndices.clear(index); + nnNodeIndices.add(index); } - mergeHulls(furthestNodeIndex, nnIndices); + mergeNodes(furthestNodeIndex, nnNodeIndices); } } root = nodes[validNodeIndices.nextSetBit(0)]; } - private boolean initHullPoints() throws BucketStorageException { - if (objectDistances.getObjectCount() < leafCapacity) { - root = new LeafNode(objectDistances, bucketDispatcher.createBucket(), insertType, distanceMeasure); - return false; - } - + private void createLeafNodes() throws BucketStorageException { BitSet notProcessedObjectIndices = new BitSet(objectDistances.getObjectCount()); notProcessedObjectIndices.set(0, objectDistances.getObjectCount()); @@ -101,7 +101,7 @@ class BuildTree { if (notProcessedObjectIndices.cardinality() < leafCapacity) { for (int i = notProcessedObjectIndices.nextSetBit(0); i >= 0; i = notProcessedObjectIndices.nextSetBit(i + 1)) addObjectToClosestNode(i); - return true; + return; } List<LocalAbstractObject> objects = new ArrayList<>(); @@ -116,15 +116,13 @@ class BuildTree { nodes[nodeIndex] = new LeafNode(objectDistances.getSubset(objects), bucketDispatcher.createBucket(), insertType, distanceMeasure); } - - return true; } private void addObjectToClosestNode(int objectIndex) throws BucketStorageException { LocalAbstractObject object = objectDistances.getObject(objectIndex); Function<Node, Float> getMinHullObjectDistance = node -> node.getHullObjects().stream() - .map(obj -> obj.getDistance(object)) + .map(object::getDistance) .reduce(Float.MAX_VALUE, Math::min); Map<Node, Float> nodeToObjectDistance = Arrays.stream(nodes) @@ -177,7 +175,7 @@ class BuildTree { }).collect(Collectors.toList()); } - private void precomputeHullDistances() { + private void precomputeNodeDistances() { for (int i = 0; i < nodes.length; i++) { for (int j = i + 1; j < nodes.length; j++) { float minDistance = Float.MAX_VALUE; @@ -195,11 +193,11 @@ class BuildTree { } } - private void mergeHulls(int mainHullIndex, List<Integer> otherHullIndices) { - if (otherHullIndices.size() == 0) return; + private void mergeNodes(int mainNodeIndex, List<Integer> nodeIndices) { + if (nodeIndices.size() == 0) return; - List<Integer> indices = new ArrayList<>(otherHullIndices); - indices.add(mainHullIndex); + List<Integer> indices = new ArrayList<>(nodeIndices); + indices.add(mainNodeIndex); List<Node> nodes = indices.stream().map(i -> this.nodes[i]).collect(Collectors.toList()); @@ -207,27 +205,27 @@ class BuildTree { nodes.forEach(node -> node.setParent(parent)); parent.addChildren(nodes); - this.nodes[mainHullIndex] = parent; + this.nodes[mainNodeIndex] = parent; - for (int i : otherHullIndices) { + for (int i : nodeIndices) { validNodeIndices.clear(i); this.nodes[i] = null; } - updateHullDistances(mainHullIndex, otherHullIndices); + updateNodeDistances(mainNodeIndex, nodeIndices); } - private void updateHullDistances(int baseHullIndex, List<Integer> otherHullIndices) { - if (otherHullIndices.size() == 0) return; + private void updateNodeDistances(int baseNodeIndex, List<Integer> nodeIndices) { + if (nodeIndices.size() == 0) return; validNodeIndices.stream().forEach(i -> { - float minDistance = nodeDistances[baseHullIndex][i]; + float minDistance = nodeDistances[baseNodeIndex][i]; - for (int index : otherHullIndices) + for (int index : nodeIndices) minDistance = Math.min(minDistance, nodeDistances[index][i]); - nodeDistances[baseHullIndex][i] = minDistance; - nodeDistances[i][baseHullIndex] = minDistance; + nodeDistances[baseNodeIndex][i] = minDistance; + nodeDistances[i][baseNodeIndex] = minDistance; }); } } -- GitLab