From 239f0ad6bb5b739931ac56e7deaf1f82d1bf9461 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Mon, 14 Dec 2020 10:50:09 +0100 Subject: [PATCH] ADD: standalone MH-Tree based on MESSIF features --- mh-tree.iml | 3 +- src/mhtree/BuildTree.java | 10 +- src/mhtree/HullNode.java | 47 ---- src/mhtree/HullTree.java | 14 -- src/mhtree/InternalNode.java | 16 ++ src/mhtree/LeafNode.java | 16 ++ src/mhtree/MHTree.java | 434 +++++++++++++++++++++++++++++++++++ src/mhtree/Node.java | 36 +++ 8 files changed, 506 insertions(+), 70 deletions(-) delete mode 100644 src/mhtree/HullNode.java delete mode 100644 src/mhtree/HullTree.java create mode 100644 src/mhtree/InternalNode.java create mode 100644 src/mhtree/LeafNode.java create mode 100644 src/mhtree/MHTree.java create mode 100644 src/mhtree/Node.java diff --git a/mh-tree.iml b/mh-tree.iml index 32fa139..02c6c02 100644 --- a/mh-tree.iml +++ b/mh-tree.iml @@ -7,6 +7,7 @@ </content> <orderEntry type="inheritedJdk" /> <orderEntry type="sourceFolder" forTests="false" /> - <orderEntry type="library" name="MESSIF" level="project" /> + <orderEntry type="library" name="messif" level="project" /> + <orderEntry type="library" name="similarityoperators" level="project" /> </component> </module> \ No newline at end of file diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java index 57dbf0f..16fa989 100644 --- a/src/mhtree/BuildTree.java +++ b/src/mhtree/BuildTree.java @@ -1,6 +1,5 @@ package mhtree; -import cz.muni.fi.disa.similarityoperators.cover.HullIncrementalRepresentationV2; import messif.objects.LocalAbstractObject; import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2; import messif.objects.util.AbstractObjectList; @@ -21,14 +20,9 @@ public class BuildTree { System.out.println("Processing argument " + arg); AbstractStreamObjectIterator<ObjectFloatVectorNeuralNetworkL2> iter = new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, arg); - AbstractObjectList<LocalAbstractObject> lst = new AbstractObjectList<>(iter); + AbstractObjectList<LocalAbstractObject> objects = new AbstractObjectList<>(iter); - HullIncrementalRepresentationV2 hull = (HullIncrementalRepresentationV2) new HullIncrementalRepresentationV2(lst, true, 10) - .mergeHullsIntoGroupsOfSize(5) - .build(); - - System.out.println("MH-Tree:"); - System.out.println(hull.getMHTree()); + MHTree tree = new MHTree(objects, 3, 7); } } catch (IOException ex) { Logger.getLogger("BuildTree").log(Level.SEVERE, null, ex); diff --git a/src/mhtree/HullNode.java b/src/mhtree/HullNode.java deleted file mode 100644 index af5170c..0000000 --- a/src/mhtree/HullNode.java +++ /dev/null @@ -1,47 +0,0 @@ -package mhtree; - -import cz.muni.fi.disa.similarityoperators.cover.HullRepresentation; - -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Stream; - -public class HullNode { - private final HullRepresentation hull; - private final List<HullNode> children; - private HullNode parent; - - public HullNode(HullRepresentation hull) { - children = new ArrayList<>(); - this.hull = hull; - } - - public void addChildren(HullNode child) { - children.add(child); - } - - public void addChildren(Stream<HullNode> children) { - children.forEach(this.children::add); - } - - public void addChildren(HullNode first, HullNode second) { - children.add(first); - children.add(second); - } - - public void setParent(HullNode parent) { - this.parent = parent; - } - - public HullRepresentation getHull() { - return hull; - } - - @Override - public String toString() { - return "HullNode{" + - "hullSize=" + hull.getObjects().size() + - ", children=" + children + - '}'; - } -} diff --git a/src/mhtree/HullTree.java b/src/mhtree/HullTree.java deleted file mode 100644 index f9476cd..0000000 --- a/src/mhtree/HullTree.java +++ /dev/null @@ -1,14 +0,0 @@ -package mhtree; - -public class HullTree { - private final HullNode root; - - public HullTree(HullNode root) { - this.root = root; - } - - @Override - public String toString() { - return "HullTree{root=" + root + '}'; - } -} diff --git a/src/mhtree/InternalNode.java b/src/mhtree/InternalNode.java new file mode 100644 index 0000000..1f08983 --- /dev/null +++ b/src/mhtree/InternalNode.java @@ -0,0 +1,16 @@ +package mhtree; + +import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3; + +import java.io.Serializable; + +public class InternalNode extends Node implements Serializable { + /** + * Serialization ID + */ + private static final long serialVersionUID = 2L; + + InternalNode(HullOptimizedRepresentationV3 hull) { + super(hull); + } +} diff --git a/src/mhtree/LeafNode.java b/src/mhtree/LeafNode.java new file mode 100644 index 0000000..041277f --- /dev/null +++ b/src/mhtree/LeafNode.java @@ -0,0 +1,16 @@ +package mhtree; + +import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3; + +import java.io.Serializable; + +public class LeafNode extends Node implements Serializable { + /** + * Serialization ID + */ + private static final long serialVersionUID = 1L; + + LeafNode(HullOptimizedRepresentationV3 hull) { + super(hull); + } +} diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java new file mode 100644 index 0000000..a034f3f --- /dev/null +++ b/src/mhtree/MHTree.java @@ -0,0 +1,434 @@ +package mhtree; + +import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; +import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3; +import cz.muni.fi.disa.similarityoperators.cover.HullRepresentation; +import messif.algorithms.Algorithm; +import messif.objects.LocalAbstractObject; +import messif.operations.data.InsertOperation; +import messif.operations.query.ApproxKNNQueryOperation; + +import java.io.Serializable; +import java.util.*; +import java.util.stream.Collectors; + +public class MHTree extends Algorithm implements Serializable { + + /** + * Serialization ID + */ + private static final long serialVersionUID = 42L; + + private final Node[] nodes; + private final HullOptimizedRepresentationV3[] hulls; + private final BitSet validHullIndices; + + private final AbstractRepresentation.PrecomputedDistances shared; + private final float[][] hullMinDistances; + + private final int initialHullsSize; + private final int hullsMergedIntoGroupsOf; + + private Node root; + + @AlgorithmConstructor(description = "MH-Tree", arguments = { + "list of objects", + "number of objects in one hull in leaf node", + "how many hulls should be merge into one", + }) + MHTree(List<LocalAbstractObject> objects, int initialHullsSize, int mergeHullsIntoGroupsOf) { + super("MH-Tree"); + + hullsMergedIntoGroupsOf = mergeHullsIntoGroupsOf; + this.initialHullsSize = initialHullsSize; + + hulls = new HullOptimizedRepresentationV3[objects.size() / initialHullsSize]; + nodes = new Node[objects.size() / initialHullsSize]; + + validHullIndices = new BitSet(hulls.length); + validHullIndices.set(0, hulls.length); + + shared = new AbstractRepresentation.PrecomputedDistances(objects); + hullMinDistances = new float[hulls.length][hulls.length]; + + buildTree(); + } + + public void approxKNN(ApproxKNNQueryOperation operation) { + LocalAbstractObject object = operation.getQueryObject(); + int k = operation.getK(); + // TODO: + operation.endOperation(); + } + + + public boolean insert(InsertOperation operation) { + LocalAbstractObject object = operation.getInsertedObject(); + // TODO: + operation.endOperation(); + return true; + } + + + private void buildTree() { + if (!initHullPoints()) return; + + precomputeHullDistances(); + + while (validHullIndices.cardinality() != 1) { + BitSet notProcessedIndices = (BitSet) validHullIndices.clone(); + + while (!notProcessedIndices.isEmpty()) { + if (notProcessedIndices.cardinality() < hullsMergedIntoGroupsOf) { + List<Integer> restOfTheIndices = new ArrayList<>(); + + for (int j = notProcessedIndices.nextSetBit(0); j >= 0; j = notProcessedIndices.nextSetBit(j + 1)) + restOfTheIndices.add(j); + + int baseIndex = restOfTheIndices.remove(0); + mergeHulls(baseIndex, restOfTheIndices); + + break; + } + + int furthestHullIndex = getFurthestIdx(hullMinDistances, notProcessedIndices); + notProcessedIndices.clear(furthestHullIndex); + + List<Integer> nnIndices = new ArrayList<>(); + + for (int i = 0; i < hullsMergedIntoGroupsOf - 1; i++) { + int index = shared.minDistInArrayExceptIdx(hullMinDistances[furthestHullIndex], notProcessedIndices, furthestHullIndex); + notProcessedIndices.clear(index); + nnIndices.add(index); + } + + mergeHulls(furthestHullIndex, nnIndices); + } + } + + root = nodes[validHullIndices.nextSetBit(0)]; + } + + private boolean initHullPoints() { + if (shared.getObjectCount() < initialHullsSize) { + root = new LeafNode(new HullOptimizedRepresentationV3(shared)); + + return false; + } + + BitSet notProcessedIndices = new BitSet(shared.getObjectCount()); + notProcessedIndices.set(0, shared.getObjectCount()); + int hullIndex = 0; + + while (!notProcessedIndices.isEmpty()) { + if (notProcessedIndices.cardinality() < initialHullsSize) { + for (int i = notProcessedIndices.nextSetBit(0); i >= 0; i = notProcessedIndices.nextSetBit(i + 1)) + addObjectToClosestHull(i); + + return true; + } + + List<LocalAbstractObject> hullPoints = new ArrayList<>(); + + // Select a base point + int furthestIndex = getFurthestIdx(notProcessedIndices); + notProcessedIndices.clear(furthestIndex); + hullPoints.add(shared.getObject(furthestIndex)); + + // Select the rest of the `initialHullsSize` number of points + findClosestIndices(furthestIndex, initialHullsSize - 1, notProcessedIndices) + .forEach(i -> hullPoints.add(shared.getObject(i))); + + HullOptimizedRepresentationV3 newHull = new HullOptimizedRepresentationV3(hullPoints); + newHull.getHull().addAll(hullPoints); + + hulls[hullIndex] = newHull; + nodes[hullIndex] = new LeafNode(newHull); + + hullIndex++; + } + + return true; + } + + private void addObjectToClosestHull(int index) { + int[] hullIndices = getAllCurrentHullIndexes(); + + // returns closest index within the data + LocalAbstractObject object = shared.getObject(index); + int closestIndex = hullIndices[getNNIndex(index, hullIndices)]; + HullRepresentation closestHull = hulls[findCorrespondingHullIdx(shared.getObject(closestIndex))]; + + closestHull.getObjects().add(object); // add just to data + + if (isCoveredByHull(index, shared.getIndexes(closestHull.getHull()))) { + return; // don't add to hulls + } + + closestHull.getHull().add(object); //otherwise add to hull points + } + + private int getNNIndex(int centerIdx, int[] dataIdxs) { + return getKNNIndexesShared(1, dataIdxs, shared.getDistances(centerIdx))[0]; + } + + private int[] getKNNIndexesShared(int k, int[] dataIdxs, float[] dists) { + AbstractRepresentation.Ranked[] ranked = getRankedIdxs(dataIdxs, dists); + int length = Math.min(k, ranked.length); + int[] result = new int[length]; + + for (int i = 0; i < length; ++i) { + result[i] = ranked[i].index; + } + + return result; + } + + private AbstractRepresentation.Ranked[] getRankedIdxs(int[] dataIdxs, float[] dists) { + AbstractRepresentation.Ranked[] ranked = new AbstractRepresentation.Ranked[dataIdxs.length]; + + for (int i = 0; i < dataIdxs.length; ++i) { + int objIdx = dataIdxs[i]; + ranked[i] = new AbstractRepresentation.Ranked(dists[objIdx], i); + } + + Arrays.sort(ranked); + return ranked; + } + + private int[] getAllCurrentHullIndexes() { + return shared.getIndexes(Arrays + .stream(hulls) + .map(hull -> hull.getHull()) + .flatMap(Collection::stream) + .collect(Collectors.toList())); + } + + private int findCorrespondingHullIdx(LocalAbstractObject object) { + for (int i = 0; i < hulls.length; i++) + if (hulls[i].getHull().contains(object)) + return i; + + return -1; // some hull has to be closest, does not make sense otherwise + } + + public boolean isCoveredByHull(int objIdx, int[] hullIdxs) { + return this.isCoveredByDistSum(objIdx, hullIdxs[getNNIndex(objIdx, hullIdxs)], hullIdxs); + } + + private boolean isCoveredByDistSum(int objIdx, int nnIdx, int[] nnIdxs) { + float[] distsObj = shared.getDistances(objIdx); + float sumObj = shared.sumArray(distsObj, nnIdxs) - distsObj[nnIdx]; + float[] distsNN = shared.getDistances(nnIdx); + float sumNN = shared.sumArray(distsNN, nnIdxs); + return sumObj <= sumNN; + } + + private int getFurthestIdx(BitSet notUsedIdxs) { + return this.getFurthestIdx(shared.getDistances(), notUsedIdxs); + } + + private int getFurthestIdx(float[][] distMatrix, BitSet notUsedIdxs) { + float max = 1.4E-45F; + int maxIdx = notUsedIdxs.nextSetBit(0); + + while (true) { + float[] dists = distMatrix[maxIdx]; + int candidateMaxIdx = this.shared.maxDistInArray(dists, notUsedIdxs); + if (!(dists[candidateMaxIdx] > max)) { + return maxIdx; + } + + max = dists[candidateMaxIdx]; + maxIdx = candidateMaxIdx; + } + } + + private List<Integer> findClosestIndices(int basePointIndex, int numberOfIndices, BitSet notProcessedIndices) { + List<Integer> closestIndices = new ArrayList<>(); + List<Integer> hullPointIndices = new ArrayList<>(); + + hullPointIndices.add(basePointIndex); + + for (int i = 0; i < numberOfIndices; i++) { + List<Integer> candidateIndices = new ArrayList<>(); + List<Float> candidateDistances = new ArrayList<>(); + + for (int index : hullPointIndices) { + int nnIndex = getNNIndex(index, notProcessedIndices); + + float distanceSum = hullPointIndices.stream() + .map(pointIndex -> shared.getDistance(pointIndex, nnIndex)) + .reduce(0f, Float::sum); + + candidateIndices.add(nnIndex); + candidateDistances.add(distanceSum); + } + + int smallestDistanceIndex = candidateDistances.indexOf(Collections.min(candidateDistances)); + int closestPointIndex = candidateIndices.get(smallestDistanceIndex); + + notProcessedIndices.clear(closestPointIndex); + + hullPointIndices.add(closestPointIndex); + closestIndices.add(closestPointIndex); + } + + return closestIndices; + } + + private int getNNIndex(int centerIdx, BitSet idxs) { + return getRankedIdxs(idxs, shared.getDistances(centerIdx))[0].index; + } + + private AbstractRepresentation.Ranked[] getRankedIdxs(BitSet dataIdxs, float[] dists) { + AbstractRepresentation.Ranked[] ranked = new AbstractRepresentation.Ranked[dataIdxs.cardinality()]; + int i = 0; + + for (int index = dataIdxs.nextSetBit(0); index >= 0; index = dataIdxs.nextSetBit(index + 1)) { + ranked[i] = new AbstractRepresentation.Ranked(dists[index], index); + ++i; + } + + Arrays.sort(ranked); + return ranked; + } + + private void precomputeHullDistances() { + for (int i = 0; i < hulls.length; i++) { // pre vsetky hullsIdxs + for (int j = i + 1; j < hulls.length; j++) { // prejdi zvysne + float minDist = Float.MAX_VALUE;// udrzi min vzdialenost medzi vsetkymi komb hullPointov + for (LocalAbstractObject firstPoint : hulls[i].getObjects()) { // prejdi vsetky hull points prveho a najdi najmensiu vzial k hociakemu druhemu + for (LocalAbstractObject secondPoint : hulls[j].getObjects()) { + if (shared.getDistance(firstPoint, secondPoint) < minDist) { + minDist = shared.getDistance(firstPoint, secondPoint); + } + } + } + if (minDist == 0f) + throw new RuntimeException("Zero distance between " + hulls[i].toString() + + " and " + hulls[j].toString()); + + + hullMinDistances[i][j] = minDist; + hullMinDistances[j][i] = minDist; + } + } + } + + private void mergeHulls(int baseHullIndex, List<Integer> otherHullIndices) { + if (otherHullIndices.size() == 0) return; + + // Create list of all the hulls + List<HullOptimizedRepresentationV3> hulls = new ArrayList<>(); + hulls.add(this.hulls[baseHullIndex]); + otherHullIndices.forEach(i -> hulls.add(this.hulls[i])); + + // Create a new node + HullOptimizedRepresentationV3 newHull = mergeHulls(hulls); + InternalNode newNode = new InternalNode(newHull); + + // Set parent for every corresponding input node + nodes[baseHullIndex].setParent(newNode); + otherHullIndices.forEach(i -> nodes[i].setParent(newNode)); + + // Set children for the new node + newNode.addChildren(nodes[baseHullIndex]); + newNode.addChildren(otherHullIndices.stream().map(i -> nodes[i])); + + // Cleanup the nodes array and place the new node onto the first index + nodes[baseHullIndex] = newNode; + otherHullIndices.forEach(i -> nodes[i] = null); + + // Cleanup the merged hulls + this.hulls[baseHullIndex] = newHull; + otherHullIndices.forEach(i -> { + validHullIndices.clear(i); + this.hulls[i] = null; + }); + + updateHullDistances(baseHullIndex, otherHullIndices); + } + + private HullOptimizedRepresentationV3 mergeHulls(List<HullOptimizedRepresentationV3> hulls) { + if (hulls.size() == 0) return null; + if (hulls.size() == 1) return hulls.get(0); + + List<LocalAbstractObject> points = new ArrayList<>(); + + // Minimal number of points in a hull is 3 + int maxCoveredPoints = -3; + int maxIndex = hulls.get(0).getHull().size(); + + for (HullRepresentation hull : hulls) { + points.addAll(hull.getObjects()); + maxCoveredPoints += hull.getHull().size(); + maxIndex = Math.max(maxIndex, hull.getHull().size()); + } + + HashSet<LocalAbstractObject> coveredObjects = new HashSet<>(); + HullOptimizedRepresentationV3 newHull = new HullOptimizedRepresentationV3(points); + + for (int i = 0; i < maxIndex; i++) + for (HullOptimizedRepresentationV3 hull : hulls) + if (i < hull.getHull().size()) + checkCoveredObjects(hull, hulls, hull.getHull().get(i), coveredObjects, maxCoveredPoints); + + hulls.forEach(hull -> addNotCoveredHullPoints(hull, newHull, coveredObjects)); + + return newHull; + } + + private void checkCoveredObjects(HullOptimizedRepresentationV3 firstHull, + List<HullOptimizedRepresentationV3> hulls, + LocalAbstractObject excludedPoint, + HashSet<LocalAbstractObject> coveredPoints, + int maxCoveredPoints) { + List<HullOptimizedRepresentationV3> everyHull = new ArrayList<>(); + everyHull.add(firstHull); + everyHull.addAll(hulls); + + List<LocalAbstractObject> pointsForCoverageCheck = getPointsToComputeCoverage(everyHull, excludedPoint, coveredPoints); + if (isCoveredByHull(shared.getIndex(excludedPoint), this.shared.getIndexes(pointsForCoverageCheck)) && coveredPoints.size() < maxCoveredPoints) { + coveredPoints.add(excludedPoint); + } + } + + private List<LocalAbstractObject> getPointsToComputeCoverage(List<HullOptimizedRepresentationV3> hulls, + LocalAbstractObject excludedPoint, + HashSet<LocalAbstractObject> covered) { + List<LocalAbstractObject> hullPoints = new ArrayList<>(); + hulls.forEach(hull -> hullPoints.addAll(hull.getHull())); + + hullPoints.remove(excludedPoint); + hullPoints.removeAll(covered); + + return hullPoints; + } + + /** + * TODO: Vlasta 2020/06/23 -- should not we try to replace an existing hull object in newHull instead of "blindly" adding the hull object? + */ + private void addNotCoveredHullPoints(HullOptimizedRepresentationV3 hull, HullOptimizedRepresentationV3 newHull, + HashSet<LocalAbstractObject> coveredObjects) { + for (LocalAbstractObject hullPoint : hull.getHull()) { + if (coveredObjects.contains(hullPoint)) + continue; + newHull.getHull().add(hullPoint); + } + } + + private void updateHullDistances(int baseHullIndex, List<Integer> otherHullIndices) { + if (otherHullIndices.size() == 0) return; + + for (int i = validHullIndices.nextSetBit(0); i >= 0; i = validHullIndices.nextSetBit(i + 1)) { + float minDistance = hullMinDistances[baseHullIndex][i]; + + for (int index : otherHullIndices) + minDistance = Math.min(minDistance, hullMinDistances[index][i]); + + hullMinDistances[baseHullIndex][i] = minDistance; + hullMinDistances[i][baseHullIndex] = minDistance; + } + } +} diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java new file mode 100644 index 0000000..49fd751 --- /dev/null +++ b/src/mhtree/Node.java @@ -0,0 +1,36 @@ +package mhtree; + +import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +public abstract class Node implements Serializable { + /** + * Serialization ID + */ + private static final long serialVersionUID = 420L; + + private Node parent; + private List<Node> children; + private HullOptimizedRepresentationV3 hull; + + Node(HullOptimizedRepresentationV3 hull) { + children = new ArrayList<>(); + this.hull = hull; + } + + public void setParent(Node parent) { + this.parent = parent; + } + + public void addChildren(Node child) { + children.add(child); + } + + public void addChildren(Stream<Node> children) { + children.forEach(this.children::add); + } +} -- GitLab