Skip to content
Snippets Groups Projects
Verified Commit 7f002a6a authored by David Procházka's avatar David Procházka
Browse files

FIX: simpified createRoot method, simplified computation of node distances,...

FIX: simpified createRoot method, simplified computation of node distances, generalized findClosestItem
parent c8f5e882
No related branches found
No related tags found
No related merge requests found
......@@ -5,41 +5,36 @@ import messif.buckets.BucketDispatcher;
import messif.buckets.BucketStorageException;
import messif.objects.LocalAbstractObject;
import java.util.*;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
class BuildTree {
private final int leafCapacity;
private final int arity;
private final Node[] nodes;
private final BitSet validNodeIndices;
private final int leafCapacity;
private final InsertType insertType;
private final ObjectToNodeDistance objectToNodeDistance;
private final NodeToNodeDistance nodeToNodeDistance;
private final BucketDispatcher bucketDispatcher;
private final PrecomputedDistances objectDistances;
private final float[][] nodeDistances;
private final InsertType insertType;
private final ObjectToNodeDistance objectToNodeDistance;
private final BucketDispatcher bucketDispatcher;
private final BiFunction<Float, Float, Float> distanceSelector;
private final Node[] nodes;
private final BitSet validNodeIndices;
private Node root;
private final Node root;
BuildTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, NodeToNodeDistance nodeToNodeDistance, BucketDispatcher bucketDispatcher) throws BucketStorageException {
this.arity = arity;
this.leafCapacity = leafCapacity;
this.arity = arity;
this.insertType = insertType;
this.objectToNodeDistance = objectToNodeDistance;
this.nodeToNodeDistance = nodeToNodeDistance;
this.bucketDispatcher = bucketDispatcher;
// Set distance selector function for two hull objects
boolean selectNearest = nodeToNodeDistance == NodeToNodeDistance.NEAREST_HULL_OBJECTS;
distanceSelector = selectNearest ? Math::min : Math::max;
nodes = new Node[objects.size() / leafCapacity];
validNodeIndices = new BitSet(nodes.length);
......@@ -48,55 +43,38 @@ class BuildTree {
objectDistances = new PrecomputedDistances(objects);
nodeDistances = new float[nodes.length][nodes.length];
// Every object is stored in the root
if (objectDistances.getObjectCount() < leafCapacity) {
// Every object is stored in the root in the case of small number of objects
if (objectDistances.getObjectCount() <= leafCapacity) {
root = new LeafNode(objectDistances, bucketDispatcher.createBucket(), insertType, objectToNodeDistance);
return;
}
createLeafNodes();
precomputeLeafNodeDistances();
buildTree();
root = createRoot();
}
public Node getRoot() {
return root;
}
private void buildTree() {
private Node createRoot() {
while (validNodeIndices.cardinality() != 1) {
BitSet notProcessedNodeIndices = (BitSet) validNodeIndices.clone();
while (!notProcessedNodeIndices.isEmpty()) {
if (notProcessedNodeIndices.cardinality() < arity) {
Set<Integer> nodeIndices = new HashSet<>(notProcessedNodeIndices.cardinality() - 1);
int mainNodeIndex = notProcessedNodeIndices.nextSetBit(0);
notProcessedNodeIndices.stream().skip(1).forEach(nodeIndices::add);
mergeNodes(mainNodeIndex, nodeIndices);
if (notProcessedNodeIndices.cardinality() <= arity) {
mergeNodes(notProcessedNodeIndices.stream().boxed().collect(Collectors.toList()));
break;
}
int furthestNodeIndex = getFurthestIndex(nodeDistances, notProcessedNodeIndices);
notProcessedNodeIndices.clear(furthestNodeIndex);
Set<Integer> nnNodeIndices = new HashSet<>(arity - 1);
for (int i = 0; i < arity - 1; i++) {
int index = objectDistances.minDistInArrayExceptIdx(nodeDistances[furthestNodeIndex], notProcessedNodeIndices, furthestNodeIndex);
notProcessedNodeIndices.clear(index);
nnNodeIndices.add(index);
}
mergeNodes(furthestNodeIndex, nnNodeIndices);
mergeNodes(furthestNodeIndex, findClosestItems(this::findClosestNodeIndex, furthestNodeIndex, arity - 1, notProcessedNodeIndices));
}
}
root = nodes[validNodeIndices.nextSetBit(0)];
return nodes[validNodeIndices.nextSetBit(0)];
}
private void createLeafNodes() throws BucketStorageException {
......@@ -105,128 +83,148 @@ class BuildTree {
for (int nodeIndex = 0; !notProcessedObjectIndices.isEmpty(); nodeIndex++) {
if (notProcessedObjectIndices.cardinality() < leafCapacity) {
for (int i = notProcessedObjectIndices.nextSetBit(0); i >= 0; i = notProcessedObjectIndices.nextSetBit(i + 1))
addObjectToClosestNode(i);
for (int i = notProcessedObjectIndices.nextSetBit(0); i >= 0; i = notProcessedObjectIndices.nextSetBit(i + 1)) {
LocalAbstractObject object = objectDistances.getObject(i);
nodes[getClosestNodeIndex(object)].addObject(object);
}
return;
}
List<LocalAbstractObject> objects = new ArrayList<>(leafCapacity);
List<Integer> objectIndices = new ArrayList<>(leafCapacity);
// Select a base object
int furthestIndex = getFurthestIndex(objectDistances.getDistances(), notProcessedObjectIndices);
notProcessedObjectIndices.clear(furthestIndex);
objects.add(objectDistances.getObject(furthestIndex));
objectIndices.add(furthestIndex);
// Select the rest of the objects up to the total of leafCapacity
objects.addAll(findClosestObjects(furthestIndex, leafCapacity - 1, notProcessedObjectIndices));
objectIndices.addAll(findClosestItems(this::findClosestObjectIndex, furthestIndex, leafCapacity - 1, notProcessedObjectIndices));
List<LocalAbstractObject> objects = objectIndices.stream().map(objectDistances::getObject).collect(Collectors.toList());
nodes[nodeIndex] = new LeafNode(objectDistances.getSubset(objects), bucketDispatcher.createBucket(), insertType, objectToNodeDistance);
}
}
private void addObjectToClosestNode(int objectIndex) throws BucketStorageException {
LocalAbstractObject object = objectDistances.getObject(objectIndex);
private int getClosestNodeIndex(LocalAbstractObject object) {
double minDistance = Double.MAX_VALUE;
int closestNodeIndex = -1;
for (int candidateIndex = 0; candidateIndex < nodes.length; candidateIndex++) {
double distance = nodes[candidateIndex].getDistance(object);
Map<Node, Double> nodeToObjectDistance = Arrays.stream(nodes)
.collect(Collectors.toMap(Function.identity(), node -> node.getDistance(object)));
if (distance < minDistance) {
minDistance = distance;
closestNodeIndex = candidateIndex;
}
}
Collections.min(nodeToObjectDistance.entrySet(), Map.Entry.comparingByValue()).getKey().addObject(object);
return closestNodeIndex;
}
private int getFurthestIndex(float[][] distMatrix, BitSet notUsedIndices) {
float max = Float.MIN_VALUE;
int maxIndex = notUsedIndices.nextSetBit(0);
private int getFurthestIndex(float[][] distanceMatrix, BitSet validIndices) {
float maxDistance = Float.MIN_VALUE;
int furthestIndex = validIndices.nextSetBit(0);
while (true) {
float[] distances = distMatrix[maxIndex];
int candidateMaxIdx = this.objectDistances.maxDistInArray(distances, notUsedIndices);
if (!(distances[candidateMaxIdx] > max)) {
return maxIndex;
float[] distances = distanceMatrix[furthestIndex];
int candidateIndex = this.objectDistances.maxDistInArray(distances, validIndices);
if (!(distances[candidateIndex] > maxDistance)) {
return furthestIndex;
}
max = distances[candidateMaxIdx];
maxIndex = candidateMaxIdx;
maxDistance = distances[candidateIndex];
furthestIndex = candidateIndex;
}
}
private Set<LocalAbstractObject> findClosestObjects(int baseObjectIndex, int numberOfObjects, BitSet notProcessedIndices) {
List<Integer> objectIndices = new ArrayList<>(1 + numberOfObjects);
private List<Integer> findClosestItems(BiFunction<List<Integer>, BitSet, Integer> findClosestItemIndex, int itemIndex, int numberOfItems, BitSet notProcessedItemIndices) {
List<Integer> itemIndices = new ArrayList<>(1 + numberOfItems);
itemIndices.add(itemIndex);
objectIndices.add(baseObjectIndex);
List<Integer> resultItemsIndices = new ArrayList<>();
return IntStream.range(0, numberOfObjects).mapToObj(i -> {
HashMap<Integer, Float> indexToDistance = new HashMap<>(objectIndices.size());
while (resultItemsIndices.size() != numberOfItems) {
int index = findClosestItemIndex.apply(itemIndices, notProcessedItemIndices);
for (int index : objectIndices) {
int nnIndex = objectDistances.minDistInArray(objectDistances.getDistances(index), notProcessedIndices);
itemIndices.add(index);
notProcessedItemIndices.clear(index);
resultItemsIndices.add(index);
}
float distanceSum = objectIndices.stream()
.map(objectIndex -> objectDistances.getDistance(objectIndex, nnIndex))
.reduce(0f, Float::sum);
return resultItemsIndices;
}
indexToDistance.put(nnIndex, distanceSum);
}
private int findClosestNodeIndex(List<Integer> indices, BitSet validNodeIndices) {
double minDistance = Double.MAX_VALUE;
int closestNodeIndex = -1;
int closestPointIndex = Collections.min(indexToDistance.entrySet(), Map.Entry.comparingByValue()).getKey();
for (int candidateIndex = validNodeIndices.nextSetBit(0); candidateIndex >= 0; candidateIndex = validNodeIndices.nextSetBit(candidateIndex + 1)) {
double sum = 0;
for (int index : indices)
sum += nodeDistances[index][candidateIndex];
notProcessedIndices.clear(closestPointIndex);
objectIndices.add(closestPointIndex);
if (sum < minDistance) {
minDistance = sum;
closestNodeIndex = candidateIndex;
}
}
return objectDistances.getObject(closestPointIndex);
}).collect(Collectors.toSet());
return closestNodeIndex;
}
private void precomputeLeafNodeDistances() {
for (int i = 0; i < nodes.length; i++) {
for (int j = i + 1; j < nodes.length; j++) {
float distance = Float.MAX_VALUE;
for (LocalAbstractObject firstHullObject : nodes[i].getHullObjects())
for (LocalAbstractObject secondHullObject : nodes[j].getHullObjects())
distance = distanceSelector.apply(distance, objectDistances.getDistance(firstHullObject, secondHullObject));
private int findClosestObjectIndex(List<Integer> indices, BitSet validObjectIndices) {
double minDistance = Double.MAX_VALUE;
int closestObjectIndex = -1;
if (distance == 0f)
throw new RuntimeException("Zero distance between " + nodes[i].toString() + " and " + nodes[j].toString());
for (int index : indices) {
int candidateIndex = objectDistances.minDistInArray(objectDistances.getDistances(index), validObjectIndices);
double distance = indices.stream().mapToDouble(i -> objectDistances.getDistance(i, candidateIndex)).sum();
nodeDistances[i][j] = distance;
nodeDistances[j][i] = distance;
if (distance < minDistance) {
minDistance = distance;
closestObjectIndex = candidateIndex;
}
}
return closestObjectIndex;
}
private void mergeNodes(List<Integer> nodeIndices) {
int parentNodeIndex = nodeIndices.remove(0);
mergeNodes(parentNodeIndex, nodeIndices);
}
private void mergeNodes(int mainNodeIndex, Set<Integer> nodeIndices) {
private void mergeNodes(int parentNodeIndex, List<Integer> nodeIndices) {
if (nodeIndices.size() == 0) return;
Set<Integer> indices = new HashSet<>(nodeIndices);
indices.add(mainNodeIndex);
nodeIndices.add(parentNodeIndex);
Set<Node> nodes = indices.stream().map(i -> this.nodes[i]).collect(Collectors.toSet());
List<Node> nodes = nodeIndices.stream().map(i -> this.nodes[i]).collect(Collectors.toList());
InternalNode parent = Node.createParent(nodes, objectDistances, insertType, objectToNodeDistance);
nodes.forEach(node -> node.setParent(parent));
parent.addChildren(nodes);
this.nodes[mainNodeIndex] = parent;
nodeIndices.forEach(index -> {
validNodeIndices.clear(index);
this.nodes[index] = null;
});
for (int i : nodeIndices) {
validNodeIndices.clear(i);
this.nodes[i] = null;
}
this.nodes[parentNodeIndex] = parent;
validNodeIndices.set(parentNodeIndex);
updateNodeDistances(mainNodeIndex, nodeIndices);
// Update node distances
validNodeIndices.stream().forEach(index -> computeNodeDistances(parentNodeIndex, index));
}
private void updateNodeDistances(int baseNodeIndex, Set<Integer> nodeIndices) {
if (nodeIndices.size() == 0) return;
validNodeIndices.stream().forEach(i -> {
float distance = nodeDistances[baseNodeIndex][i];
for (int index : nodeIndices)
distance = distanceSelector.apply(distance, nodeDistances[index][i]);
private void computeNodeDistances(int i, int j) {
float distance = nodeToNodeDistance.getDistance(nodes[i], nodes[j], objectDistances::getDistance);
nodeDistances[baseNodeIndex][i] = distance;
nodeDistances[i][baseNodeIndex] = distance;
});
nodeDistances[i][j] = distance;
nodeDistances[j][i] = distance;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment