From 10bdf184913f6343adc72d88a4455c40e24b63b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Wed, 10 Feb 2021 13:58:51 +0100 Subject: [PATCH] ADD: support for different types of distance measurements in Node --- src/mhtree/Node.java | 60 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 50 insertions(+), 10 deletions(-) diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java index 6f042bf..05344b4 100644 --- a/src/mhtree/Node.java +++ b/src/mhtree/Node.java @@ -6,9 +6,8 @@ import messif.buckets.BucketStorageException; import messif.objects.LocalAbstractObject; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; +import java.util.*; +import java.util.function.Function; import java.util.stream.Collectors; public abstract class Node implements Serializable { @@ -18,28 +17,37 @@ public abstract class Node implements Serializable { */ private static final long serialVersionUID = 420L; protected final InsertType insertType; + private final DistanceMeasure distanceMeasure; protected HullOptimizedRepresentationV3 hull; protected Node parent; - Node(PrecomputedDistances distances, InsertType insertType) { + Node(PrecomputedDistances distances, InsertType insertType, DistanceMeasure distanceMeasure) { this.hull = new HullOptimizedRepresentationV3(distances); this.hull.build(); this.insertType = insertType; + this.distanceMeasure = distanceMeasure; } - public static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType) { + public static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, DistanceMeasure distanceMeasure) { List<LocalAbstractObject> objects = nodes.stream() .map(Node::getObjects) .flatMap(Collection::stream) .collect(Collectors.toList()); - return new InternalNode(distances.getSubset(objects), insertType); + return new InternalNode(distances.getSubset(objects), insertType, distanceMeasure); } - public float getDistance(LocalAbstractObject object) { - return hull.getHull().stream() - .map(h -> h.getDistance(object)) - .reduce(Float.MAX_VALUE, Math::min); + public double getDistance(LocalAbstractObject object) { + switch (distanceMeasure) { + case FURTHEST_HULL_OBJECT: + return getDistanceToFurthestHullObject(object); + case SUM_OF_DISTANCES_TO_HULL_OBJECTS: + return getSumOfDistancesToHullObjects(object); + case MEDOID: + return getDistanceToMedoid(object); + default: + return getDistanceToNearestHullObject(object); + } } public boolean isCovered(LocalAbstractObject object) { @@ -97,4 +105,36 @@ public abstract class Node implements Serializable { rebuildHull(object); } + + private double getDistanceToNearestHullObject(LocalAbstractObject object) { + return hull.getHull().stream() + .mapToDouble(object::getDistance) + .min() + .orElse(Double.MAX_VALUE); + } + + private double getDistanceToFurthestHullObject(LocalAbstractObject object) { + return hull.getHull().stream() + .mapToDouble(object::getDistance) + .max() + .orElse(Double.MIN_VALUE); + } + + private double getSumOfDistancesToHullObjects(LocalAbstractObject object) { + return hull.getHull().stream() + .mapToDouble(object::getDistance) + .sum(); + } + + private double getDistanceToMedoid(LocalAbstractObject object) { + List<LocalAbstractObject> objects = getObjects(); + Function<LocalAbstractObject, Double> sumOfDistanceToObject = obj -> objects.stream().mapToDouble(obj::getDistance).sum(); + + Map<LocalAbstractObject, Double> objectToObjectDistance = objects.stream() + .collect(Collectors.toMap(Function.identity(), sumOfDistanceToObject)); + + LocalAbstractObject medoid = Collections.min(objectToObjectDistance.entrySet(), Map.Entry.comparingByValue()).getKey(); + + return medoid.getDistance(object); + } } -- GitLab