From 10bdf184913f6343adc72d88a4455c40e24b63b8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev>
Date: Wed, 10 Feb 2021 13:58:51 +0100
Subject: [PATCH] ADD: support for different types of distance measurements in
 Node

---
 src/mhtree/Node.java | 60 ++++++++++++++++++++++++++++++++++++--------
 1 file changed, 50 insertions(+), 10 deletions(-)

diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java
index 6f042bf..05344b4 100644
--- a/src/mhtree/Node.java
+++ b/src/mhtree/Node.java
@@ -6,9 +6,8 @@ import messif.buckets.BucketStorageException;
 import messif.objects.LocalAbstractObject;
 
 import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
+import java.util.*;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 public abstract class Node implements Serializable {
@@ -18,28 +17,37 @@ public abstract class Node implements Serializable {
      */
     private static final long serialVersionUID = 420L;
     protected final InsertType insertType;
+    private final DistanceMeasure distanceMeasure;
     protected HullOptimizedRepresentationV3 hull;
     protected Node parent;
 
-    Node(PrecomputedDistances distances, InsertType insertType) {
+    Node(PrecomputedDistances distances, InsertType insertType, DistanceMeasure distanceMeasure) {
         this.hull = new HullOptimizedRepresentationV3(distances);
         this.hull.build();
         this.insertType = insertType;
+        this.distanceMeasure = distanceMeasure;
     }
 
-    public static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType) {
+    public static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, DistanceMeasure distanceMeasure) {
         List<LocalAbstractObject> objects = nodes.stream()
                 .map(Node::getObjects)
                 .flatMap(Collection::stream)
                 .collect(Collectors.toList());
 
-        return new InternalNode(distances.getSubset(objects), insertType);
+        return new InternalNode(distances.getSubset(objects), insertType, distanceMeasure);
     }
 
-    public float getDistance(LocalAbstractObject object) {
-        return hull.getHull().stream()
-                .map(h -> h.getDistance(object))
-                .reduce(Float.MAX_VALUE, Math::min);
+    public double getDistance(LocalAbstractObject object) {
+        switch (distanceMeasure) {
+            case FURTHEST_HULL_OBJECT:
+                return getDistanceToFurthestHullObject(object);
+            case SUM_OF_DISTANCES_TO_HULL_OBJECTS:
+                return getSumOfDistancesToHullObjects(object);
+            case MEDOID:
+                return getDistanceToMedoid(object);
+            default:
+                return getDistanceToNearestHullObject(object);
+        }
     }
 
     public boolean isCovered(LocalAbstractObject object) {
@@ -97,4 +105,36 @@ public abstract class Node implements Serializable {
 
         rebuildHull(object);
     }
+
+    private double getDistanceToNearestHullObject(LocalAbstractObject object) {
+        return hull.getHull().stream()
+                .mapToDouble(object::getDistance)
+                .min()
+                .orElse(Double.MAX_VALUE);
+    }
+
+    private double getDistanceToFurthestHullObject(LocalAbstractObject object) {
+        return hull.getHull().stream()
+                .mapToDouble(object::getDistance)
+                .max()
+                .orElse(Double.MIN_VALUE);
+    }
+
+    private double getSumOfDistancesToHullObjects(LocalAbstractObject object) {
+        return hull.getHull().stream()
+                .mapToDouble(object::getDistance)
+                .sum();
+    }
+
+    private double getDistanceToMedoid(LocalAbstractObject object) {
+        List<LocalAbstractObject> objects = getObjects();
+        Function<LocalAbstractObject, Double> sumOfDistanceToObject = obj -> objects.stream().mapToDouble(obj::getDistance).sum();
+
+        Map<LocalAbstractObject, Double> objectToObjectDistance = objects.stream()
+                .collect(Collectors.toMap(Function.identity(), sumOfDistanceToObject));
+
+        LocalAbstractObject medoid = Collections.min(objectToObjectDistance.entrySet(), Map.Entry.comparingByValue()).getKey();
+
+        return medoid.getDistance(object);
+    }
 }
-- 
GitLab