diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java index f6f7c975baa17c67d5dc4f0029cbb7ad33a2a215..961126ee5513ae9819e36720be28043dccc0bcb6 100644 --- a/src/mhtree/BuildTree.java +++ b/src/mhtree/BuildTree.java @@ -65,11 +65,11 @@ class BuildTree { while (!notProcessedNodeIndices.isEmpty()) { if (notProcessedNodeIndices.cardinality() < arity) { - List<Integer> nodeIndices = new ArrayList<>(); + Set<Integer> nodeIndices = new HashSet<>(); - notProcessedNodeIndices.stream().forEach(nodeIndices::add); + int mainNodeIndex = notProcessedNodeIndices.nextSetBit(0); + notProcessedNodeIndices.stream().skip(1).forEach(nodeIndices::add); - int mainNodeIndex = nodeIndices.remove(0); mergeNodes(mainNodeIndex, nodeIndices); break; @@ -78,7 +78,7 @@ class BuildTree { int furthestNodeIndex = getFurthestIndex(nodeDistances, notProcessedNodeIndices); notProcessedNodeIndices.clear(furthestNodeIndex); - List<Integer> nnNodeIndices = new ArrayList<>(); + Set<Integer> nnNodeIndices = new HashSet<>(); for (int i = 0; i < arity - 1; i++) { int index = objectDistances.minDistInArrayExceptIdx(nodeDistances[furthestNodeIndex], notProcessedNodeIndices, furthestNodeIndex); @@ -148,7 +148,7 @@ class BuildTree { } } - private List<LocalAbstractObject> findClosestObjects(int baseObjectIndex, int numberOfObjects, BitSet notProcessedIndices) { + private Set<LocalAbstractObject> findClosestObjects(int baseObjectIndex, int numberOfObjects, BitSet notProcessedIndices) { List<Integer> objectIndices = new ArrayList<>(); objectIndices.add(baseObjectIndex); @@ -160,7 +160,7 @@ class BuildTree { int nnIndex = objectDistances.minDistInArray(objectDistances.getDistances(index), notProcessedIndices); float distanceSum = objectIndices.stream() - .map(pointIndex -> objectDistances.getDistance(pointIndex, nnIndex)) + .map(objectIndex -> objectDistances.getDistance(objectIndex, nnIndex)) .reduce(0f, Float::sum); indexToDistance.put(nnIndex, distanceSum); @@ -172,7 +172,7 @@ class BuildTree { objectIndices.add(closestPointIndex); return objectDistances.getObject(closestPointIndex); - }).collect(Collectors.toList()); + }).collect(Collectors.toSet()); } private void precomputeNodeDistances() { @@ -193,13 +193,13 @@ class BuildTree { } } - private void mergeNodes(int mainNodeIndex, List<Integer> nodeIndices) { + private void mergeNodes(int mainNodeIndex, Set<Integer> nodeIndices) { if (nodeIndices.size() == 0) return; - List<Integer> indices = new ArrayList<>(nodeIndices); + Set<Integer> indices = new HashSet<>(nodeIndices); indices.add(mainNodeIndex); - List<Node> nodes = indices.stream().map(i -> this.nodes[i]).collect(Collectors.toList()); + Set<Node> nodes = indices.stream().map(i -> this.nodes[i]).collect(Collectors.toSet()); InternalNode parent = Node.createParent(nodes, objectDistances, insertType, distanceMeasure); nodes.forEach(node -> node.setParent(parent)); @@ -215,7 +215,7 @@ class BuildTree { updateNodeDistances(mainNodeIndex, nodeIndices); } - private void updateNodeDistances(int baseNodeIndex, List<Integer> nodeIndices) { + private void updateNodeDistances(int baseNodeIndex, Set<Integer> nodeIndices) { if (nodeIndices.size() == 0) return; validNodeIndices.stream().forEach(i -> { diff --git a/src/mhtree/Histogram.java b/src/mhtree/Histogram.java index 470c6f681109ba1e129afb7ca38724f03c5da136..3f96f9daff19736b283b0b4a9ecc57b6a20ff521 100644 --- a/src/mhtree/Histogram.java +++ b/src/mhtree/Histogram.java @@ -5,6 +5,7 @@ import messif.objects.LocalAbstractObject; import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Set; public class Histogram { private final HashMap<Integer, HashMap<Long, Integer>> levelToHullsCoveredToObjectCount = new HashMap<>(); @@ -16,7 +17,7 @@ public class Histogram { Histogram histogram = new Histogram(); for (int level = 1; level <= x.getHeight() + 1; level++) { - List<Node> levelNodes = x.getNodesOnLevel(level); + Set<Node> levelNodes = x.getNodesOnLevel(level); for (LocalAbstractObject object : x.getObjects()) { long coveredObjectCount = levelNodes.stream() diff --git a/src/mhtree/InternalNode.java b/src/mhtree/InternalNode.java index 4b961df514c9ad55bbc11669590eed719ca46b0a..3a383a1c7dd04bd6d6ea8661b300e530ec879852 100644 --- a/src/mhtree/InternalNode.java +++ b/src/mhtree/InternalNode.java @@ -15,18 +15,18 @@ public class InternalNode extends Node implements Serializable { */ private static final long serialVersionUID = 2L; - private final List<Node> children; + private final Set<Node> children; InternalNode(PrecomputedDistances distances, InsertType insertType, DistanceMeasure distanceMeasure) { super(distances, insertType, distanceMeasure); - children = new ArrayList<>(); + children = new HashSet<>(); } - public void addChildren(List<Node> children) { + public void addChildren(Set<Node> children) { this.children.addAll(children); } - public List<Node> getChildren() { + public Set<Node> getChildren() { return children; } @@ -41,11 +41,11 @@ public class InternalNode extends Node implements Serializable { addNewObject(object); } - public List<LocalAbstractObject> getObjects() { + public Set<LocalAbstractObject> getObjects() { return children.stream() .map(Node::getObjects) .flatMap(Collection::stream) - .collect(Collectors.toList()); + .collect(Collectors.toSet()); } public boolean contains(LocalAbstractObject object) { @@ -66,12 +66,12 @@ public class InternalNode extends Node implements Serializable { .getMax() + 1; } - public List<Node> getNodesOnLevel(int level) { - if (getLevel() == level) return Collections.singletonList(this); + public Set<Node> getNodesOnLevel(int level) { + if (getLevel() == level) return Collections.singleton(this); return children.stream() .map(child -> child.getNodesOnLevel(level)) .flatMap(Collection::stream) - .collect(Collectors.toList()); + .collect(Collectors.toSet()); } } diff --git a/src/mhtree/LeafNode.java b/src/mhtree/LeafNode.java index d720c984167e4a7b53c034f9174afd19409a86a5..d4c92824a2423dca306c1cb400a8bc4299a5160c 100644 --- a/src/mhtree/LeafNode.java +++ b/src/mhtree/LeafNode.java @@ -7,9 +7,7 @@ import messif.objects.LocalAbstractObject; import messif.objects.util.AbstractObjectIterator; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; +import java.util.*; public class LeafNode extends Node implements Serializable { @@ -32,8 +30,8 @@ public class LeafNode extends Node implements Serializable { addNewObject(object); } - public List<LocalAbstractObject> getObjects() { - List<LocalAbstractObject> objects = new ArrayList<>(); + public Set<LocalAbstractObject> getObjects() { + Set<LocalAbstractObject> objects = new HashSet<>(); for (AbstractObjectIterator<LocalAbstractObject> it = bucket.getAllObjects(); it.hasNext(); ) objects.add(it.next()); @@ -57,7 +55,7 @@ public class LeafNode extends Node implements Serializable { return 0; } - public List<Node> getNodesOnLevel(int level) { - return Collections.singletonList(this); + public Set<Node> getNodesOnLevel(int level) { + return Collections.singleton(this); } } diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java index 05344b492c0ae200655578ffc4955c4c8fa89edc..731b6887c778a1ade08e752590a627681be9335c 100644 --- a/src/mhtree/Node.java +++ b/src/mhtree/Node.java @@ -16,8 +16,8 @@ public abstract class Node implements Serializable { * Serialization ID */ private static final long serialVersionUID = 420L; - protected final InsertType insertType; private final DistanceMeasure distanceMeasure; + protected final InsertType insertType; protected HullOptimizedRepresentationV3 hull; protected Node parent; @@ -28,7 +28,7 @@ public abstract class Node implements Serializable { this.distanceMeasure = distanceMeasure; } - public static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, DistanceMeasure distanceMeasure) { + public static InternalNode createParent(Set<Node> nodes, PrecomputedDistances distances, InsertType insertType, DistanceMeasure distanceMeasure) { List<LocalAbstractObject> objects = nodes.stream() .map(Node::getObjects) .flatMap(Collection::stream) @@ -77,7 +77,7 @@ public abstract class Node implements Serializable { public abstract void addObject(LocalAbstractObject object) throws BucketStorageException; - public abstract List<LocalAbstractObject> getObjects(); + public abstract Set<LocalAbstractObject> getObjects(); public abstract boolean contains(LocalAbstractObject object); @@ -85,7 +85,7 @@ public abstract class Node implements Serializable { public abstract int getHeight(); - public abstract List<Node> getNodesOnLevel(int level); + public abstract Set<Node> getNodesOnLevel(int level); protected void rebuildHull(LocalAbstractObject object) { List<LocalAbstractObject> objects = new ArrayList<>(getObjects()); @@ -106,6 +106,12 @@ public abstract class Node implements Serializable { rebuildHull(object); } + private static <T extends Collection<LocalAbstractObject>> double sumOfDistanceToObject(LocalAbstractObject object, T objects) { + return objects.stream() + .mapToDouble(object::getDistance) + .sum(); + } + private double getDistanceToNearestHullObject(LocalAbstractObject object) { return hull.getHull().stream() .mapToDouble(object::getDistance) @@ -121,17 +127,14 @@ public abstract class Node implements Serializable { } private double getSumOfDistancesToHullObjects(LocalAbstractObject object) { - return hull.getHull().stream() - .mapToDouble(object::getDistance) - .sum(); + return sumOfDistanceToObject(object, hull.getHull()); } private double getDistanceToMedoid(LocalAbstractObject object) { - List<LocalAbstractObject> objects = getObjects(); - Function<LocalAbstractObject, Double> sumOfDistanceToObject = obj -> objects.stream().mapToDouble(obj::getDistance).sum(); + Set<LocalAbstractObject> objects = getObjects(); Map<LocalAbstractObject, Double> objectToObjectDistance = objects.stream() - .collect(Collectors.toMap(Function.identity(), sumOfDistanceToObject)); + .collect(Collectors.toMap(Function.identity(), o -> sumOfDistanceToObject(o, objects))); LocalAbstractObject medoid = Collections.min(objectToObjectDistance.entrySet(), Map.Entry.comparingByValue()).getKey();