diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java index f4c21703cdfaf1c4037d14102f419225a0db8c2b..e4a217fba55a23d3bfbd4f27b1261b2eedc01852 100644 --- a/src/mhtree/MHTree.java +++ b/src/mhtree/MHTree.java @@ -87,7 +87,7 @@ public class MHTree extends Algorithm implements Serializable { } private boolean isPrunable(Node child, LocalAbstractObject queryObject, ApproxKNNQueryOperation operation, double coefficient) { - return operation.getAnswerDistance() * coefficient < child.getDistanceToNearestHullObject(queryObject); + return operation.getAnswerDistance() * coefficient < child.getNearestDistance(queryObject); } public void insert(InsertOperation operation) throws BucketStorageException { diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java index 731b6887c778a1ade08e752590a627681be9335c..a8065c35971d74af980a50d3b31f75f9dd5c55c3 100644 --- a/src/mhtree/Node.java +++ b/src/mhtree/Node.java @@ -6,8 +6,10 @@ import messif.buckets.BucketStorageException; import messif.objects.LocalAbstractObject; import java.io.Serializable; -import java.util.*; -import java.util.function.Function; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Set; import java.util.stream.Collectors; public abstract class Node implements Serializable { @@ -16,37 +18,36 @@ public abstract class Node implements Serializable { * Serialization ID */ private static final long serialVersionUID = 420L; - private final DistanceMeasure distanceMeasure; + private final ObjectToNodeDistance objectToNodeDistance; + protected final InsertType insertType; protected HullOptimizedRepresentationV3 hull; protected Node parent; - Node(PrecomputedDistances distances, InsertType insertType, DistanceMeasure distanceMeasure) { + Node(PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) { this.hull = new HullOptimizedRepresentationV3(distances); this.hull.build(); this.insertType = insertType; - this.distanceMeasure = distanceMeasure; + this.objectToNodeDistance = objectToNodeDistance; } - public static InternalNode createParent(Set<Node> nodes, PrecomputedDistances distances, InsertType insertType, DistanceMeasure distanceMeasure) { + public static InternalNode createParent(Set<Node> nodes, PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) { List<LocalAbstractObject> objects = nodes.stream() .map(Node::getObjects) .flatMap(Collection::stream) .collect(Collectors.toList()); - return new InternalNode(distances.getSubset(objects), insertType, distanceMeasure); + return new InternalNode(distances.getSubset(objects), insertType, objectToNodeDistance); } public double getDistance(LocalAbstractObject object) { - switch (distanceMeasure) { + switch (objectToNodeDistance) { case FURTHEST_HULL_OBJECT: - return getDistanceToFurthestHullObject(object); - case SUM_OF_DISTANCES_TO_HULL_OBJECTS: - return getSumOfDistancesToHullObjects(object); - case MEDOID: - return getDistanceToMedoid(object); + return getFurthestDistance(object); + case AVERAGE_DISTANCE: + return getAverageDistance(object); default: - return getDistanceToNearestHullObject(object); + return getNearestDistance(object); } } @@ -81,8 +82,6 @@ public abstract class Node implements Serializable { public abstract boolean contains(LocalAbstractObject object); - public abstract List<Integer> getLeafNodesObjectCounts(); - public abstract int getHeight(); public abstract Set<Node> getNodesOnLevel(int level); @@ -106,38 +105,23 @@ public abstract class Node implements Serializable { rebuildHull(object); } - private static <T extends Collection<LocalAbstractObject>> double sumOfDistanceToObject(LocalAbstractObject object, T objects) { - return objects.stream() - .mapToDouble(object::getDistance) - .sum(); - } - - private double getDistanceToNearestHullObject(LocalAbstractObject object) { + public double getNearestDistance(LocalAbstractObject object) { return hull.getHull().stream() .mapToDouble(object::getDistance) .min() .orElse(Double.MAX_VALUE); } - private double getDistanceToFurthestHullObject(LocalAbstractObject object) { + private double getFurthestDistance(LocalAbstractObject object) { return hull.getHull().stream() .mapToDouble(object::getDistance) .max() .orElse(Double.MIN_VALUE); } - private double getSumOfDistancesToHullObjects(LocalAbstractObject object) { - return sumOfDistanceToObject(object, hull.getHull()); - } - - private double getDistanceToMedoid(LocalAbstractObject object) { - Set<LocalAbstractObject> objects = getObjects(); - - Map<LocalAbstractObject, Double> objectToObjectDistance = objects.stream() - .collect(Collectors.toMap(Function.identity(), o -> sumOfDistanceToObject(o, objects))); - - LocalAbstractObject medoid = Collections.min(objectToObjectDistance.entrySet(), Map.Entry.comparingByValue()).getKey(); - - return medoid.getDistance(object); + private double getAverageDistance(LocalAbstractObject object) { + return hull.getHull().stream() + .mapToDouble(object::getDistance) + .sum() / hull.getHull().size(); } } diff --git a/src/mhtree/ObjectToNodeDistance.java b/src/mhtree/ObjectToNodeDistance.java index fc1d0704d1fbd4b195c95d046113da88206ec996..f9430f273544340b388207b3d4a2258dd231d556 100644 --- a/src/mhtree/ObjectToNodeDistance.java +++ b/src/mhtree/ObjectToNodeDistance.java @@ -1,7 +1,7 @@ package mhtree; -public enum DistanceMeasure { +public enum ObjectToNodeDistance { NEAREST_HULL_OBJECT, FURTHEST_HULL_OBJECT, - AVERAGE_DISTANCE + AVERAGE_DISTANCE, }