From a6331c851d8a5c1518105eabab7c8ad9307c7f5e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev>
Date: Wed, 30 Dec 2020 18:21:31 +0100
Subject: [PATCH] ADD: MH-Tree insert method

---
 src/mhtree/BuildTree.java    |  4 +---
 src/mhtree/BuildTreeApp.java | 19 ++++++++----------
 src/mhtree/InternalNode.java | 12 ++++++++++++
 src/mhtree/LeafNode.java     |  6 +-----
 src/mhtree/MHTree.java       | 37 ++++++++++++++++++++----------------
 src/mhtree/Node.java         | 19 ++++++++++++------
 6 files changed, 56 insertions(+), 41 deletions(-)

diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java
index fbfd69f..c0f88c1 100644
--- a/src/mhtree/BuildTree.java
+++ b/src/mhtree/BuildTree.java
@@ -3,7 +3,6 @@ package mhtree;
 import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation;
 import messif.buckets.BucketDispatcher;
 import messif.buckets.BucketStorageException;
-import messif.buckets.LocalBucket;
 import messif.objects.LocalAbstractObject;
 
 import java.util.*;
@@ -123,8 +122,7 @@ class BuildTree {
         int closestHullObjectIndex = hullObjectIndices[getNNIndex(objectIndex, hullObjectIndices)];
         int nodeIndex = findCorrespondingHullIndex(objectDistances.getObject(closestHullObjectIndex));
 
-        LeafNode node = (LeafNode) nodes[nodeIndex];
-        node.addObject(objectDistances.getObject(objectIndex));
+        nodes[nodeIndex].addObject(objectDistances.getObject(objectIndex));
     }
 
     private int getNNIndex(int centerIndex, int[] dataIndices) {
diff --git a/src/mhtree/BuildTreeApp.java b/src/mhtree/BuildTreeApp.java
index 98c4281..62780a6 100644
--- a/src/mhtree/BuildTreeApp.java
+++ b/src/mhtree/BuildTreeApp.java
@@ -2,11 +2,11 @@ package mhtree;
 
 import messif.buckets.BucketStorageException;
 import messif.buckets.impl.MemoryStorageBucket;
-import messif.objects.AbstractObject;
 import messif.objects.LocalAbstractObject;
 import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2;
 import messif.objects.util.AbstractObjectList;
 import messif.objects.util.AbstractStreamObjectIterator;
+import messif.objects.util.RankedAbstractObject;
 import messif.objects.util.StreamGenericAbstractObjectIterator;
 import messif.operations.query.ApproxKNNQueryOperation;
 
@@ -29,25 +29,22 @@ public class BuildTreeApp {
                 AbstractObjectList<LocalAbstractObject> objects = new AbstractObjectList<>(iter);
 
                 MHTree tree = new MHTree(objects, 10, 5, MemoryStorageBucket.class, null);
+                int k = 15;
 
                 for (LocalAbstractObject object : objects) {
-
-                    ApproxKNNQueryOperation op = new ApproxKNNQueryOperation(object, 1);
+                    ApproxKNNQueryOperation op = new ApproxKNNQueryOperation(object, k);
                     tree.approxKNN(op);
 
-                    Iterator<AbstractObject> answerObjects = op.getAnswerObjects();
-
                     if (op.getAnswerCount() == 0) throw new RuntimeException("no result");
-                    if (op.getAnswerCount() != 1) throw new RuntimeException("too many results");
+                    if (op.getAnswerCount() != k) throw new RuntimeException("too many results");
 
-                    while (answerObjects.hasNext()) {
-                        AbstractObject answerObject = answerObjects.next();
+                    for (Iterator<RankedAbstractObject> answerObjects = op.getAnswer(); answerObjects.hasNext(); ) {
+                        RankedAbstractObject rankedAnswer = answerObjects.next();
 
-                        if (!answerObject.getLocatorURI().equals(object.getLocatorURI()))
-                            throw new RuntimeException("returned different object");
+                        System.out.println(rankedAnswer.getObject().getLocatorURI());
+                        System.out.println(rankedAnswer.getDistance());
                     }
                 }
-
             }
         } catch (IOException | BucketStorageException ex) {
             Logger.getLogger("BuildTree").log(Level.SEVERE, null, ex);
diff --git a/src/mhtree/InternalNode.java b/src/mhtree/InternalNode.java
index 762bee3..b4c70d7 100644
--- a/src/mhtree/InternalNode.java
+++ b/src/mhtree/InternalNode.java
@@ -4,7 +4,9 @@ import messif.objects.LocalAbstractObject;
 
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
+import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 public class InternalNode extends Node implements Serializable {
@@ -28,4 +30,14 @@ public class InternalNode extends Node implements Serializable {
     public List<Node> getChildren() {
         return children;
     }
+
+    public Node getNearestChild(LocalAbstractObject object) {
+        List<Float> distances = children.stream().map(child -> child.getDistance(object)).collect(Collectors.toList());
+
+        return children.get(distances.indexOf(Collections.min(distances)));
+    }
+
+    public void addObject(LocalAbstractObject object) {
+        rebuildHull(object);
+    }
 }
diff --git a/src/mhtree/LeafNode.java b/src/mhtree/LeafNode.java
index f3dd645..5e47197 100644
--- a/src/mhtree/LeafNode.java
+++ b/src/mhtree/LeafNode.java
@@ -1,13 +1,11 @@
 package mhtree;
 
 import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation;
-import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3;
 import messif.buckets.BucketStorageException;
 import messif.buckets.LocalBucket;
 import messif.objects.LocalAbstractObject;
 
 import java.io.Serializable;
-import java.util.ArrayList;
 import java.util.List;
 
 public class LeafNode extends Node implements Serializable {
@@ -34,8 +32,6 @@ public class LeafNode extends Node implements Serializable {
 
         if (isCovered(object)) return;
 
-        List<LocalAbstractObject> objects = new ArrayList<>(getObjects());
-        objects.add(object);
-        hull = new HullOptimizedRepresentationV3(objects);
+        rebuildHull(object);
     }
 }
diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java
index 094410a..213ecf1 100644
--- a/src/mhtree/MHTree.java
+++ b/src/mhtree/MHTree.java
@@ -2,6 +2,7 @@ package mhtree;
 
 import messif.algorithms.Algorithm;
 import messif.buckets.BucketDispatcher;
+import messif.buckets.BucketErrorCode;
 import messif.buckets.BucketStorageException;
 import messif.buckets.LocalBucket;
 import messif.objects.LocalAbstractObject;
@@ -38,14 +39,13 @@ public class MHTree extends Algorithm implements Serializable {
 
         this.leafCapacity = leafCapacity;
         this.numberOfChildren = numberOfChildren;
-        bucketDispatcher =  new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams);
+        bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams);
 
         root = new BuildTree(objects, leafCapacity, numberOfChildren, bucketDispatcher).getRoot();
     }
 
     public void approxKNN(ApproxKNNQueryOperation operation) {
         LocalAbstractObject object = operation.getQueryObject();
-        // int k = operation.getK();
 
         PriorityQueue<ObjectToNodeDistanceRank> queue = new PriorityQueue<>();
         queue.add(new ObjectToNodeDistanceRank(root, object));
@@ -53,28 +53,33 @@ public class MHTree extends Algorithm implements Serializable {
         while (!queue.isEmpty()) {
             Node currentNode = queue.poll().getNode();
 
-            if (currentNode.isLeaf()) {
+            if (currentNode.isLeaf())
                 for (LocalAbstractObject obj : currentNode.getObjects())
-                    if (obj.getLocatorURI().equals(object.getLocatorURI()))
-                        operation.addToAnswer(obj);
-            }
+                    operation.addToAnswer(obj);
 
-            if (!currentNode.isLeaf()) {
-                InternalNode node = (InternalNode) currentNode;
-
-                for (Node child : node.getChildren())
-                    if (child.isCovered(object))
-                        queue.add(new ObjectToNodeDistanceRank(child, object));
-            }
+            if (!currentNode.isLeaf())
+                for (Node child : ((InternalNode) currentNode).getChildren())
+                    queue.add(new ObjectToNodeDistanceRank(child, object));
         }
 
         operation.endOperation();
     }
 
-    public boolean insert(InsertOperation operation) {
+    public boolean insert(InsertOperation operation) throws BucketStorageException {
         LocalAbstractObject object = operation.getInsertedObject();
-        // TODO:
-        operation.endOperation();
+
+        Node currentNode = root;
+
+        while (!currentNode.isLeaf()) {
+            if (!currentNode.isCovered(object))
+                currentNode.addObject(object);
+
+            currentNode = ((InternalNode) currentNode).getNearestChild(object);
+        }
+
+        currentNode.addObject(object);
+
+        operation.endOperation(BucketErrorCode.OBJECT_INSERTED);
         return true;
     }
 
diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java
index 518023e..333055f 100644
--- a/src/mhtree/Node.java
+++ b/src/mhtree/Node.java
@@ -1,10 +1,11 @@
 package mhtree;
 
-import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation;
 import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3;
+import messif.buckets.BucketStorageException;
 import messif.objects.LocalAbstractObject;
 
 import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -16,8 +17,8 @@ public abstract class Node implements Serializable {
      */
     private static final long serialVersionUID = 420L;
 
-    private Node parent;
     protected HullOptimizedRepresentationV3 hull;
+    private Node parent;
 
     Node(HullOptimizedRepresentationV3 hull) {
         hull.build();
@@ -32,10 +33,6 @@ public abstract class Node implements Serializable {
         this(new HullOptimizedRepresentationV3(objects.collect(Collectors.toList())));
     }
 
-    Node(AbstractRepresentation.PrecomputedDistances distances) {
-        this(new HullOptimizedRepresentationV3(distances));
-    }
-
     public static InternalNode createParent(List<Node> nodes) {
         return new InternalNode(nodes.stream().flatMap(node -> node.getObjects().stream()));
     }
@@ -74,4 +71,14 @@ public abstract class Node implements Serializable {
     public String toString() {
         return "Node{hull=" + hull + '}';
     }
+
+    public abstract void addObject(LocalAbstractObject object) throws BucketStorageException;
+
+    protected void rebuildHull(LocalAbstractObject object) {
+        List<LocalAbstractObject> objects = new ArrayList<>(getObjects());
+        objects.add(object);
+
+        hull = new HullOptimizedRepresentationV3(objects);
+        hull.build();
+    }
 }
-- 
GitLab