diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java index 16fa9890267ff7347811637d85b82674f16f150c..e1be9461815b15c4cf949509dcdc75f1ac4a5204 100644 --- a/src/mhtree/BuildTree.java +++ b/src/mhtree/BuildTree.java @@ -1,31 +1,267 @@ package mhtree; +import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; import messif.objects.LocalAbstractObject; -import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2; -import messif.objects.util.AbstractObjectList; -import messif.objects.util.AbstractStreamObjectIterator; -import messif.objects.util.StreamGenericAbstractObjectIterator; -import java.io.IOException; -import java.util.logging.Level; -import java.util.logging.Logger; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.IntStream; -public class BuildTree { +class BuildTree { - public static void main(String[] args) { - if (args.length == 0) args = new String[]{"./data/mix_m.txt"}; + private final int leafCapacity; + private final int numberOfChildren; - try { - for (String arg : args) { - System.out.println("Processing argument " + arg); + private final Node[] nodes; + private final BitSet validNodeIndices; - AbstractStreamObjectIterator<ObjectFloatVectorNeuralNetworkL2> iter = new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, arg); - AbstractObjectList<LocalAbstractObject> objects = new AbstractObjectList<>(iter); + private final AbstractRepresentation.PrecomputedDistances objectDistances; + private final float[][] nodeDistances; - MHTree tree = new MHTree(objects, 3, 7); + private Node root; + + + BuildTree(List<LocalAbstractObject> objects, int leafCapacity, int numberOfChildren) { + this.numberOfChildren = numberOfChildren; + this.leafCapacity = leafCapacity; + + nodes = new Node[objects.size() / leafCapacity]; + + validNodeIndices = new BitSet(nodes.length); + validNodeIndices.set(0, nodes.length); + + objectDistances = new AbstractRepresentation.PrecomputedDistances(objects); + nodeDistances = new float[nodes.length][nodes.length]; + + buildTree(); + } + + public Node getRoot() { + return root; + } + + private void buildTree() { + if (!initHullPoints()) return; + + precomputeHullDistances(); + + while (validNodeIndices.cardinality() != 1) { + BitSet notProcessedIndices = (BitSet) validNodeIndices.clone(); + + while (!notProcessedIndices.isEmpty()) { + if (notProcessedIndices.cardinality() < numberOfChildren) { + List<Integer> restOfTheIndices = new ArrayList<>(); + + notProcessedIndices.stream().forEach(restOfTheIndices::add); + + int mainIndex = restOfTheIndices.remove(0); + mergeHulls(mainIndex, restOfTheIndices); + + break; + } + + int furthestNodeIndex = getFurthestIndex(nodeDistances, notProcessedIndices); + notProcessedIndices.clear(furthestNodeIndex); + + List<Integer> nnIndices = new ArrayList<>(); + + for (int i = 0; i < numberOfChildren - 1; i++) { + int index = objectDistances.minDistInArrayExceptIdx(nodeDistances[furthestNodeIndex], notProcessedIndices, furthestNodeIndex); + notProcessedIndices.clear(index); + nnIndices.add(index); + } + + mergeHulls(furthestNodeIndex, nnIndices); } - } catch (IOException ex) { - Logger.getLogger("BuildTree").log(Level.SEVERE, null, ex); } + + root = nodes[validNodeIndices.nextSetBit(0)]; + } + + private boolean initHullPoints() { + if (objectDistances.getObjectCount() < leafCapacity) { + root = new LeafNode(objectDistances); + return false; + } + + BitSet notProcessedObjectIndices = new BitSet(objectDistances.getObjectCount()); + notProcessedObjectIndices.set(0, objectDistances.getObjectCount()); + + for (int nodeIndex = 0; !notProcessedObjectIndices.isEmpty(); nodeIndex++) { + if (notProcessedObjectIndices.cardinality() < leafCapacity) { + notProcessedObjectIndices.stream().forEach(this::addObjectToClosestNode); + return true; + } + + List<LocalAbstractObject> objects = new ArrayList<>(); + + // Select a base object + int furthestIndex = getFurthestIndex(objectDistances.getDistances(), notProcessedObjectIndices); + notProcessedObjectIndices.clear(furthestIndex); + objects.add(objectDistances.getObject(furthestIndex)); + + // Select the rest of the points up to the total of leafCapacity + findClosestObjectIndices(furthestIndex, leafCapacity - 1, notProcessedObjectIndices) + .forEach(i -> objects.add(objectDistances.getObject(i))); + + nodes[nodeIndex] = new LeafNode(objects); + } + + return true; + } + + private void addObjectToClosestNode(int objectIndex) { + int[] hullObjectIndices = getEveryHullObjectIndex(); + + int closestHullObjectIndex = hullObjectIndices[getNNIndex(objectIndex, hullObjectIndices)]; + int nodeIndex = findCorrespondingHullIndex(objectDistances.getObject(closestHullObjectIndex)); + + nodes[nodeIndex] = LeafNode.addObject((LeafNode) nodes[nodeIndex], objectDistances.getObject(objectIndex)); + } + + private int getNNIndex(int centerIndex, int[] dataIndices) { + AbstractRepresentation.Ranked[] ranked = new AbstractRepresentation.Ranked[dataIndices.length]; + + for (int i = 0; i < dataIndices.length; ++i) { + int objIdx = dataIndices[i]; + ranked[i] = new AbstractRepresentation.Ranked(objectDistances.getDistances(centerIndex)[objIdx], i); + } + + Arrays.sort(ranked); + + return ranked[0].index; + } + + private int getNNIndex(int centerIdx, BitSet dataIndices) { + AbstractRepresentation.Ranked[] ranked = new AbstractRepresentation.Ranked[dataIndices.cardinality()]; + int i = 0; + + for (int index = dataIndices.nextSetBit(0); index >= 0; index = dataIndices.nextSetBit(index + 1)) { + ranked[i] = new AbstractRepresentation.Ranked(objectDistances.getDistances(centerIdx)[index], index); + ++i; + } + + Arrays.sort(ranked); + + return ranked[0].index; + } + + private int[] getEveryHullObjectIndex() { + return objectDistances.getIndexes( + Arrays.stream(nodes) + .map(Node::getHullObjects) + .flatMap(Collection::stream) + .collect(Collectors.toList())); + } + + private int findCorrespondingHullIndex(LocalAbstractObject object) { + return IntStream.range(0, nodes.length) + .filter(i -> nodes[i].contains(object)) + .findFirst() + .getAsInt(); + } + + private int getFurthestIndex(float[][] distMatrix, BitSet notUsedIndices) { + float max = Float.MIN_VALUE; + int maxIndex = notUsedIndices.nextSetBit(0); + + while (true) { + float[] distances = distMatrix[maxIndex]; + int candidateMaxIdx = this.objectDistances.maxDistInArray(distances, notUsedIndices); + if (!(distances[candidateMaxIdx] > max)) { + return maxIndex; + } + + max = distances[candidateMaxIdx]; + maxIndex = candidateMaxIdx; + } + } + + private List<Integer> findClosestObjectIndices(int basePointIndex, int numberOfIndices, BitSet notProcessedIndices) { + List<Integer> closestIndices = new ArrayList<>(); + List<Integer> hullPointIndices = new ArrayList<>(); + + hullPointIndices.add(basePointIndex); + + for (int i = 0; i < numberOfIndices; i++) { + List<Integer> candidateIndices = new ArrayList<>(); + List<Float> candidateDistances = new ArrayList<>(); + + for (int index : hullPointIndices) { + int nnIndex = getNNIndex(index, notProcessedIndices); + + float distanceSum = hullPointIndices.stream() + .map(pointIndex -> objectDistances.getDistance(pointIndex, nnIndex)) + .reduce(0f, Float::sum); + + candidateIndices.add(nnIndex); + candidateDistances.add(distanceSum); + } + + int smallestDistanceIndex = candidateDistances.indexOf(Collections.min(candidateDistances)); + int closestPointIndex = candidateIndices.get(smallestDistanceIndex); + + notProcessedIndices.clear(closestPointIndex); + + hullPointIndices.add(closestPointIndex); + closestIndices.add(closestPointIndex); + } + + return closestIndices; + } + + private void precomputeHullDistances() { + for (int i = 0; i < nodes.length; i++) { + for (int j = i + 1; j < nodes.length; j++) { + float minDistance = Float.MAX_VALUE; + + for (LocalAbstractObject firstPoint : nodes[i].getObjects()) + for (LocalAbstractObject secondPoint : nodes[j].getObjects()) + if (objectDistances.getDistance(firstPoint, secondPoint) < minDistance) + minDistance = objectDistances.getDistance(firstPoint, secondPoint); + + if (minDistance == 0f) + throw new RuntimeException("Zero distance between " + nodes[i].toString() + " and " + nodes[j].toString()); + + nodeDistances[i][j] = minDistance; + nodeDistances[j][i] = minDistance; + } + } + } + + private void mergeHulls(int mainHullIndex, List<Integer> otherHullIndices) { + if (otherHullIndices.size() == 0) return; + + List<Integer> indices = new ArrayList<>(otherHullIndices); + indices.add(mainHullIndex); + + List<Node> nodes = indices.stream().map(i -> this.nodes[i]).collect(Collectors.toList()); + + InternalNode parent = Node.createParent(nodes); + nodes.forEach(node -> node.setParent(parent)); + parent.addChildren(nodes); + + this.nodes[mainHullIndex] = parent; + + for (int i : otherHullIndices) { + validNodeIndices.clear(i); + this.nodes[i] = null; + } + + updateHullDistances(mainHullIndex, otherHullIndices); + } + + private void updateHullDistances(int baseHullIndex, List<Integer> otherHullIndices) { + if (otherHullIndices.size() == 0) return; + + validNodeIndices.stream().forEach(i -> { + float minDistance = nodeDistances[baseHullIndex][i]; + + for (int index : otherHullIndices) + minDistance = Math.min(minDistance, nodeDistances[index][i]); + + nodeDistances[baseHullIndex][i] = minDistance; + nodeDistances[i][baseHullIndex] = minDistance; + }); } } diff --git a/src/mhtree/BuildTreeApp.java b/src/mhtree/BuildTreeApp.java new file mode 100644 index 0000000000000000000000000000000000000000..d9eeda797bd225f78a24c0b3a7949d104b8303d5 --- /dev/null +++ b/src/mhtree/BuildTreeApp.java @@ -0,0 +1,54 @@ +package mhtree; + +import messif.objects.AbstractObject; +import messif.objects.LocalAbstractObject; +import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2; +import messif.objects.util.AbstractObjectList; +import messif.objects.util.AbstractStreamObjectIterator; +import messif.objects.util.StreamGenericAbstractObjectIterator; +import messif.operations.query.ApproxKNNQueryOperation; + +import java.io.IOException; +import java.util.Iterator; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class BuildTreeApp { + + public static void main(String[] args) { + if (args.length == 0) + args = new String[]{"./data/mix_m.txt", "./data/intersection.txt", "./data/middle_c.txt", "./data/r1.txt"}; + + try { + for (String arg : args) { + System.out.println("Processing argument " + arg); + + AbstractStreamObjectIterator<ObjectFloatVectorNeuralNetworkL2> iter = new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, arg); + AbstractObjectList<LocalAbstractObject> objects = new AbstractObjectList<>(iter); + + MHTree tree = new MHTree(objects, 10, 5); + + for (LocalAbstractObject object : objects) { + + ApproxKNNQueryOperation op = new ApproxKNNQueryOperation(object, 1); + tree.approxKNN(op); + + Iterator<AbstractObject> answerObjects = op.getAnswerObjects(); + + if (op.getAnswerCount() == 0) throw new RuntimeException("no result"); + if (op.getAnswerCount() != 1) throw new RuntimeException("too many results"); + + while (answerObjects.hasNext()) { + AbstractObject answerObject = answerObjects.next(); + + if (!answerObject.getLocatorURI().equals(object.getLocatorURI())) + throw new RuntimeException("returned different object"); + } + } + + } + } catch (IOException ex) { + Logger.getLogger("BuildTree").log(Level.SEVERE, null, ex); + } + } +} diff --git a/src/mhtree/InternalNode.java b/src/mhtree/InternalNode.java index 1f0898370a0547eb4f31422b56a5e62a9d837948..d30976ecfce47392dd5fdf5a4e511191daedddd9 100644 --- a/src/mhtree/InternalNode.java +++ b/src/mhtree/InternalNode.java @@ -1,8 +1,11 @@ package mhtree; -import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3; +import messif.objects.LocalAbstractObject; import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; public class InternalNode extends Node implements Serializable { /** @@ -10,7 +13,18 @@ public class InternalNode extends Node implements Serializable { */ private static final long serialVersionUID = 2L; - InternalNode(HullOptimizedRepresentationV3 hull) { - super(hull); + private List<Node> children; + + InternalNode(Stream<LocalAbstractObject> objects) { + super(objects); + children = new ArrayList<>(); + } + + public void addChildren(List<Node> children) { + this.children.addAll(children); + } + + public List<Node> getChildren() { + return children; } } diff --git a/src/mhtree/LeafNode.java b/src/mhtree/LeafNode.java index 041277f982e049690afcd4a5097241700698a5df..072de8e71d09fd075321cb3db49cbe0cd3e2e117 100644 --- a/src/mhtree/LeafNode.java +++ b/src/mhtree/LeafNode.java @@ -1,8 +1,11 @@ package mhtree; -import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3; +import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; +import messif.objects.LocalAbstractObject; import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; public class LeafNode extends Node implements Serializable { /** @@ -10,7 +13,18 @@ public class LeafNode extends Node implements Serializable { */ private static final long serialVersionUID = 1L; - LeafNode(HullOptimizedRepresentationV3 hull) { - super(hull); + LeafNode(List<LocalAbstractObject> objects) { + super(objects); + } + + LeafNode(AbstractRepresentation.PrecomputedDistances distances) { + super(distances); + } + + public static LeafNode addObject(LeafNode node, LocalAbstractObject object) { + List<LocalAbstractObject> objects = new ArrayList<>(node.getObjects()); + objects.add(object); + + return new LeafNode(objects); } } diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java index a034f3fc63456a8d56b97949fe187730ab856cf2..742285aae5e9bd094b2d16c404008e621c1f4193 100644 --- a/src/mhtree/MHTree.java +++ b/src/mhtree/MHTree.java @@ -1,16 +1,13 @@ package mhtree; -import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; -import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3; -import cz.muni.fi.disa.similarityoperators.cover.HullRepresentation; import messif.algorithms.Algorithm; import messif.objects.LocalAbstractObject; import messif.operations.data.InsertOperation; import messif.operations.query.ApproxKNNQueryOperation; import java.io.Serializable; -import java.util.*; -import java.util.stream.Collectors; +import java.util.List; +import java.util.PriorityQueue; public class MHTree extends Algorithm implements Serializable { @@ -19,416 +16,66 @@ public class MHTree extends Algorithm implements Serializable { */ private static final long serialVersionUID = 42L; - private final Node[] nodes; - private final HullOptimizedRepresentationV3[] hulls; - private final BitSet validHullIndices; + private final int leafCapacity; + private final int numberOfChildren; - private final AbstractRepresentation.PrecomputedDistances shared; - private final float[][] hullMinDistances; - - private final int initialHullsSize; - private final int hullsMergedIntoGroupsOf; - - private Node root; + private final Node root; @AlgorithmConstructor(description = "MH-Tree", arguments = { "list of objects", - "number of objects in one hull in leaf node", - "how many hulls should be merge into one", + "number of objects in leaf node", + "number of children in internal node" }) - MHTree(List<LocalAbstractObject> objects, int initialHullsSize, int mergeHullsIntoGroupsOf) { + MHTree(List<LocalAbstractObject> objects, int leafCapacity, int numberOfChildren) { super("MH-Tree"); - hullsMergedIntoGroupsOf = mergeHullsIntoGroupsOf; - this.initialHullsSize = initialHullsSize; - - hulls = new HullOptimizedRepresentationV3[objects.size() / initialHullsSize]; - nodes = new Node[objects.size() / initialHullsSize]; - - validHullIndices = new BitSet(hulls.length); - validHullIndices.set(0, hulls.length); - - shared = new AbstractRepresentation.PrecomputedDistances(objects); - hullMinDistances = new float[hulls.length][hulls.length]; + this.leafCapacity = leafCapacity; + this.numberOfChildren = numberOfChildren; - buildTree(); + root = new BuildTree(objects, leafCapacity, numberOfChildren).getRoot(); } public void approxKNN(ApproxKNNQueryOperation operation) { LocalAbstractObject object = operation.getQueryObject(); - int k = operation.getK(); - // TODO: - operation.endOperation(); - } - - - public boolean insert(InsertOperation operation) { - LocalAbstractObject object = operation.getInsertedObject(); - // TODO: - operation.endOperation(); - return true; - } - - - private void buildTree() { - if (!initHullPoints()) return; + // int k = operation.getK(); - precomputeHullDistances(); + PriorityQueue<ObjectToNodeDistanceRank> queue = new PriorityQueue<>(); + queue.add(new ObjectToNodeDistanceRank(root, object)); - while (validHullIndices.cardinality() != 1) { - BitSet notProcessedIndices = (BitSet) validHullIndices.clone(); + while (!queue.isEmpty()) { + Node currentNode = queue.poll().getNode(); - while (!notProcessedIndices.isEmpty()) { - if (notProcessedIndices.cardinality() < hullsMergedIntoGroupsOf) { - List<Integer> restOfTheIndices = new ArrayList<>(); - - for (int j = notProcessedIndices.nextSetBit(0); j >= 0; j = notProcessedIndices.nextSetBit(j + 1)) - restOfTheIndices.add(j); - - int baseIndex = restOfTheIndices.remove(0); - mergeHulls(baseIndex, restOfTheIndices); - - break; - } - - int furthestHullIndex = getFurthestIdx(hullMinDistances, notProcessedIndices); - notProcessedIndices.clear(furthestHullIndex); - - List<Integer> nnIndices = new ArrayList<>(); - - for (int i = 0; i < hullsMergedIntoGroupsOf - 1; i++) { - int index = shared.minDistInArrayExceptIdx(hullMinDistances[furthestHullIndex], notProcessedIndices, furthestHullIndex); - notProcessedIndices.clear(index); - nnIndices.add(index); - } - - mergeHulls(furthestHullIndex, nnIndices); + if (currentNode.isLeaf()) { + for (LocalAbstractObject obj : currentNode.getObjects()) + if (obj.getLocatorURI().equals(object.getLocatorURI())) + operation.addToAnswer(obj); } - } - - root = nodes[validHullIndices.nextSetBit(0)]; - } - - private boolean initHullPoints() { - if (shared.getObjectCount() < initialHullsSize) { - root = new LeafNode(new HullOptimizedRepresentationV3(shared)); - - return false; - } - - BitSet notProcessedIndices = new BitSet(shared.getObjectCount()); - notProcessedIndices.set(0, shared.getObjectCount()); - int hullIndex = 0; - while (!notProcessedIndices.isEmpty()) { - if (notProcessedIndices.cardinality() < initialHullsSize) { - for (int i = notProcessedIndices.nextSetBit(0); i >= 0; i = notProcessedIndices.nextSetBit(i + 1)) - addObjectToClosestHull(i); + if (!currentNode.isLeaf()) { + InternalNode node = (InternalNode) currentNode; - return true; + for (Node child : node.getChildren()) + if (child.isCovered(object)) + queue.add(new ObjectToNodeDistanceRank(child, object)); } - - List<LocalAbstractObject> hullPoints = new ArrayList<>(); - - // Select a base point - int furthestIndex = getFurthestIdx(notProcessedIndices); - notProcessedIndices.clear(furthestIndex); - hullPoints.add(shared.getObject(furthestIndex)); - - // Select the rest of the `initialHullsSize` number of points - findClosestIndices(furthestIndex, initialHullsSize - 1, notProcessedIndices) - .forEach(i -> hullPoints.add(shared.getObject(i))); - - HullOptimizedRepresentationV3 newHull = new HullOptimizedRepresentationV3(hullPoints); - newHull.getHull().addAll(hullPoints); - - hulls[hullIndex] = newHull; - nodes[hullIndex] = new LeafNode(newHull); - - hullIndex++; - } - - return true; - } - - private void addObjectToClosestHull(int index) { - int[] hullIndices = getAllCurrentHullIndexes(); - - // returns closest index within the data - LocalAbstractObject object = shared.getObject(index); - int closestIndex = hullIndices[getNNIndex(index, hullIndices)]; - HullRepresentation closestHull = hulls[findCorrespondingHullIdx(shared.getObject(closestIndex))]; - - closestHull.getObjects().add(object); // add just to data - - if (isCoveredByHull(index, shared.getIndexes(closestHull.getHull()))) { - return; // don't add to hulls } - closestHull.getHull().add(object); //otherwise add to hull points - } - - private int getNNIndex(int centerIdx, int[] dataIdxs) { - return getKNNIndexesShared(1, dataIdxs, shared.getDistances(centerIdx))[0]; - } - - private int[] getKNNIndexesShared(int k, int[] dataIdxs, float[] dists) { - AbstractRepresentation.Ranked[] ranked = getRankedIdxs(dataIdxs, dists); - int length = Math.min(k, ranked.length); - int[] result = new int[length]; - - for (int i = 0; i < length; ++i) { - result[i] = ranked[i].index; - } - - return result; - } - - private AbstractRepresentation.Ranked[] getRankedIdxs(int[] dataIdxs, float[] dists) { - AbstractRepresentation.Ranked[] ranked = new AbstractRepresentation.Ranked[dataIdxs.length]; - - for (int i = 0; i < dataIdxs.length; ++i) { - int objIdx = dataIdxs[i]; - ranked[i] = new AbstractRepresentation.Ranked(dists[objIdx], i); - } - - Arrays.sort(ranked); - return ranked; - } - - private int[] getAllCurrentHullIndexes() { - return shared.getIndexes(Arrays - .stream(hulls) - .map(hull -> hull.getHull()) - .flatMap(Collection::stream) - .collect(Collectors.toList())); - } - - private int findCorrespondingHullIdx(LocalAbstractObject object) { - for (int i = 0; i < hulls.length; i++) - if (hulls[i].getHull().contains(object)) - return i; - - return -1; // some hull has to be closest, does not make sense otherwise - } - - public boolean isCoveredByHull(int objIdx, int[] hullIdxs) { - return this.isCoveredByDistSum(objIdx, hullIdxs[getNNIndex(objIdx, hullIdxs)], hullIdxs); - } - - private boolean isCoveredByDistSum(int objIdx, int nnIdx, int[] nnIdxs) { - float[] distsObj = shared.getDistances(objIdx); - float sumObj = shared.sumArray(distsObj, nnIdxs) - distsObj[nnIdx]; - float[] distsNN = shared.getDistances(nnIdx); - float sumNN = shared.sumArray(distsNN, nnIdxs); - return sumObj <= sumNN; - } - - private int getFurthestIdx(BitSet notUsedIdxs) { - return this.getFurthestIdx(shared.getDistances(), notUsedIdxs); - } - - private int getFurthestIdx(float[][] distMatrix, BitSet notUsedIdxs) { - float max = 1.4E-45F; - int maxIdx = notUsedIdxs.nextSetBit(0); - - while (true) { - float[] dists = distMatrix[maxIdx]; - int candidateMaxIdx = this.shared.maxDistInArray(dists, notUsedIdxs); - if (!(dists[candidateMaxIdx] > max)) { - return maxIdx; - } - - max = dists[candidateMaxIdx]; - maxIdx = candidateMaxIdx; - } - } - - private List<Integer> findClosestIndices(int basePointIndex, int numberOfIndices, BitSet notProcessedIndices) { - List<Integer> closestIndices = new ArrayList<>(); - List<Integer> hullPointIndices = new ArrayList<>(); - - hullPointIndices.add(basePointIndex); - - for (int i = 0; i < numberOfIndices; i++) { - List<Integer> candidateIndices = new ArrayList<>(); - List<Float> candidateDistances = new ArrayList<>(); - - for (int index : hullPointIndices) { - int nnIndex = getNNIndex(index, notProcessedIndices); - - float distanceSum = hullPointIndices.stream() - .map(pointIndex -> shared.getDistance(pointIndex, nnIndex)) - .reduce(0f, Float::sum); - - candidateIndices.add(nnIndex); - candidateDistances.add(distanceSum); - } - - int smallestDistanceIndex = candidateDistances.indexOf(Collections.min(candidateDistances)); - int closestPointIndex = candidateIndices.get(smallestDistanceIndex); - - notProcessedIndices.clear(closestPointIndex); - - hullPointIndices.add(closestPointIndex); - closestIndices.add(closestPointIndex); - } - - return closestIndices; - } - - private int getNNIndex(int centerIdx, BitSet idxs) { - return getRankedIdxs(idxs, shared.getDistances(centerIdx))[0].index; - } - - private AbstractRepresentation.Ranked[] getRankedIdxs(BitSet dataIdxs, float[] dists) { - AbstractRepresentation.Ranked[] ranked = new AbstractRepresentation.Ranked[dataIdxs.cardinality()]; - int i = 0; - - for (int index = dataIdxs.nextSetBit(0); index >= 0; index = dataIdxs.nextSetBit(index + 1)) { - ranked[i] = new AbstractRepresentation.Ranked(dists[index], index); - ++i; - } - - Arrays.sort(ranked); - return ranked; - } - - private void precomputeHullDistances() { - for (int i = 0; i < hulls.length; i++) { // pre vsetky hullsIdxs - for (int j = i + 1; j < hulls.length; j++) { // prejdi zvysne - float minDist = Float.MAX_VALUE;// udrzi min vzdialenost medzi vsetkymi komb hullPointov - for (LocalAbstractObject firstPoint : hulls[i].getObjects()) { // prejdi vsetky hull points prveho a najdi najmensiu vzial k hociakemu druhemu - for (LocalAbstractObject secondPoint : hulls[j].getObjects()) { - if (shared.getDistance(firstPoint, secondPoint) < minDist) { - minDist = shared.getDistance(firstPoint, secondPoint); - } - } - } - if (minDist == 0f) - throw new RuntimeException("Zero distance between " + hulls[i].toString() + - " and " + hulls[j].toString()); - - - hullMinDistances[i][j] = minDist; - hullMinDistances[j][i] = minDist; - } - } - } - - private void mergeHulls(int baseHullIndex, List<Integer> otherHullIndices) { - if (otherHullIndices.size() == 0) return; - - // Create list of all the hulls - List<HullOptimizedRepresentationV3> hulls = new ArrayList<>(); - hulls.add(this.hulls[baseHullIndex]); - otherHullIndices.forEach(i -> hulls.add(this.hulls[i])); - - // Create a new node - HullOptimizedRepresentationV3 newHull = mergeHulls(hulls); - InternalNode newNode = new InternalNode(newHull); - - // Set parent for every corresponding input node - nodes[baseHullIndex].setParent(newNode); - otherHullIndices.forEach(i -> nodes[i].setParent(newNode)); - - // Set children for the new node - newNode.addChildren(nodes[baseHullIndex]); - newNode.addChildren(otherHullIndices.stream().map(i -> nodes[i])); - - // Cleanup the nodes array and place the new node onto the first index - nodes[baseHullIndex] = newNode; - otherHullIndices.forEach(i -> nodes[i] = null); - - // Cleanup the merged hulls - this.hulls[baseHullIndex] = newHull; - otherHullIndices.forEach(i -> { - validHullIndices.clear(i); - this.hulls[i] = null; - }); - - updateHullDistances(baseHullIndex, otherHullIndices); - } - - private HullOptimizedRepresentationV3 mergeHulls(List<HullOptimizedRepresentationV3> hulls) { - if (hulls.size() == 0) return null; - if (hulls.size() == 1) return hulls.get(0); - - List<LocalAbstractObject> points = new ArrayList<>(); - - // Minimal number of points in a hull is 3 - int maxCoveredPoints = -3; - int maxIndex = hulls.get(0).getHull().size(); - - for (HullRepresentation hull : hulls) { - points.addAll(hull.getObjects()); - maxCoveredPoints += hull.getHull().size(); - maxIndex = Math.max(maxIndex, hull.getHull().size()); - } - - HashSet<LocalAbstractObject> coveredObjects = new HashSet<>(); - HullOptimizedRepresentationV3 newHull = new HullOptimizedRepresentationV3(points); - - for (int i = 0; i < maxIndex; i++) - for (HullOptimizedRepresentationV3 hull : hulls) - if (i < hull.getHull().size()) - checkCoveredObjects(hull, hulls, hull.getHull().get(i), coveredObjects, maxCoveredPoints); - - hulls.forEach(hull -> addNotCoveredHullPoints(hull, newHull, coveredObjects)); - - return newHull; - } - - private void checkCoveredObjects(HullOptimizedRepresentationV3 firstHull, - List<HullOptimizedRepresentationV3> hulls, - LocalAbstractObject excludedPoint, - HashSet<LocalAbstractObject> coveredPoints, - int maxCoveredPoints) { - List<HullOptimizedRepresentationV3> everyHull = new ArrayList<>(); - everyHull.add(firstHull); - everyHull.addAll(hulls); - - List<LocalAbstractObject> pointsForCoverageCheck = getPointsToComputeCoverage(everyHull, excludedPoint, coveredPoints); - if (isCoveredByHull(shared.getIndex(excludedPoint), this.shared.getIndexes(pointsForCoverageCheck)) && coveredPoints.size() < maxCoveredPoints) { - coveredPoints.add(excludedPoint); - } - } - - private List<LocalAbstractObject> getPointsToComputeCoverage(List<HullOptimizedRepresentationV3> hulls, - LocalAbstractObject excludedPoint, - HashSet<LocalAbstractObject> covered) { - List<LocalAbstractObject> hullPoints = new ArrayList<>(); - hulls.forEach(hull -> hullPoints.addAll(hull.getHull())); - - hullPoints.remove(excludedPoint); - hullPoints.removeAll(covered); - - return hullPoints; + operation.endOperation(); } - /** - * TODO: Vlasta 2020/06/23 -- should not we try to replace an existing hull object in newHull instead of "blindly" adding the hull object? - */ - private void addNotCoveredHullPoints(HullOptimizedRepresentationV3 hull, HullOptimizedRepresentationV3 newHull, - HashSet<LocalAbstractObject> coveredObjects) { - for (LocalAbstractObject hullPoint : hull.getHull()) { - if (coveredObjects.contains(hullPoint)) - continue; - newHull.getHull().add(hullPoint); - } + public boolean insert(InsertOperation operation) { + LocalAbstractObject object = operation.getInsertedObject(); + // TODO: + operation.endOperation(); + return true; } - private void updateHullDistances(int baseHullIndex, List<Integer> otherHullIndices) { - if (otherHullIndices.size() == 0) return; - - for (int i = validHullIndices.nextSetBit(0); i >= 0; i = validHullIndices.nextSetBit(i + 1)) { - float minDistance = hullMinDistances[baseHullIndex][i]; - - for (int index : otherHullIndices) - minDistance = Math.min(minDistance, hullMinDistances[index][i]); - - hullMinDistances[baseHullIndex][i] = minDistance; - hullMinDistances[i][baseHullIndex] = minDistance; - } + @Override + public String toString() { + return "MHTree{" + + "leafCapacity=" + leafCapacity + + ", numberOfChildren=" + numberOfChildren + + ", root=" + root + + '}'; } } diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java index 49fd7513c281907fd1dfd19a1e4fda0b78dc83c3..ee08758791690529cbbf04e694b9e86a207c77ca 100644 --- a/src/mhtree/Node.java +++ b/src/mhtree/Node.java @@ -1,10 +1,12 @@ package mhtree; +import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3; +import messif.objects.LocalAbstractObject; import java.io.Serializable; -import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; import java.util.stream.Stream; public abstract class Node implements Serializable { @@ -14,23 +16,61 @@ public abstract class Node implements Serializable { private static final long serialVersionUID = 420L; private Node parent; - private List<Node> children; private HullOptimizedRepresentationV3 hull; Node(HullOptimizedRepresentationV3 hull) { - children = new ArrayList<>(); + hull.build(); this.hull = hull; } + Node(List<LocalAbstractObject> objects) { + this(new HullOptimizedRepresentationV3(objects)); + } + + Node(Stream<LocalAbstractObject> objects) { + this(new HullOptimizedRepresentationV3(objects.collect(Collectors.toList()))); + } + + Node(AbstractRepresentation.PrecomputedDistances distances) { + this(new HullOptimizedRepresentationV3(distances)); + } + + public static InternalNode createParent(List<Node> nodes) { + return new InternalNode(nodes.stream().flatMap(node -> node.getObjects().stream())); + } + + public float getDistance(LocalAbstractObject object) { + return hull.getHull().stream() + .map(h -> h.getDistance(object)) + .reduce(Float.MAX_VALUE, Math::min); + } + + public boolean isCovered(LocalAbstractObject object) { + return hull.isExternalCovered(object); + } + + public boolean isLeaf() { + return (this instanceof LeafNode); + } + + public boolean contains(LocalAbstractObject object) { + return hull.getObjects().contains(object); + } + public void setParent(Node parent) { this.parent = parent; } - public void addChildren(Node child) { - children.add(child); + public List<LocalAbstractObject> getHullObjects() { + return hull.getHull(); + } + + public List<LocalAbstractObject> getObjects() { + return hull.getObjects(); } - public void addChildren(Stream<Node> children) { - children.forEach(this.children::add); + @Override + public String toString() { + return "Node{hull=" + hull + '}'; } } diff --git a/src/mhtree/ObjectToNodeDistanceRank.java b/src/mhtree/ObjectToNodeDistanceRank.java new file mode 100644 index 0000000000000000000000000000000000000000..ea83379c7ab14fd9fc3537ee0b0ac6f53838541d --- /dev/null +++ b/src/mhtree/ObjectToNodeDistanceRank.java @@ -0,0 +1,21 @@ +package mhtree; + +import messif.objects.LocalAbstractObject; + +public class ObjectToNodeDistanceRank implements Comparable<ObjectToNodeDistanceRank> { + private final Node node; + private final LocalAbstractObject object; + + public ObjectToNodeDistanceRank(Node node, LocalAbstractObject object) { + this.node = node; + this.object = object; + } + + public int compareTo(ObjectToNodeDistanceRank rank) { + return Float.compare(node.getDistance(object), rank.node.getDistance(object)); + } + + public Node getNode() { + return node; + } +}