From a6331c851d8a5c1518105eabab7c8ad9307c7f5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Wed, 30 Dec 2020 18:21:31 +0100 Subject: [PATCH] ADD: MH-Tree insert method --- src/mhtree/BuildTree.java | 4 +--- src/mhtree/BuildTreeApp.java | 19 ++++++++---------- src/mhtree/InternalNode.java | 12 ++++++++++++ src/mhtree/LeafNode.java | 6 +----- src/mhtree/MHTree.java | 37 ++++++++++++++++++++---------------- src/mhtree/Node.java | 19 ++++++++++++------ 6 files changed, 56 insertions(+), 41 deletions(-) diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java index fbfd69f..c0f88c1 100644 --- a/src/mhtree/BuildTree.java +++ b/src/mhtree/BuildTree.java @@ -3,7 +3,6 @@ package mhtree; import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; import messif.buckets.BucketDispatcher; import messif.buckets.BucketStorageException; -import messif.buckets.LocalBucket; import messif.objects.LocalAbstractObject; import java.util.*; @@ -123,8 +122,7 @@ class BuildTree { int closestHullObjectIndex = hullObjectIndices[getNNIndex(objectIndex, hullObjectIndices)]; int nodeIndex = findCorrespondingHullIndex(objectDistances.getObject(closestHullObjectIndex)); - LeafNode node = (LeafNode) nodes[nodeIndex]; - node.addObject(objectDistances.getObject(objectIndex)); + nodes[nodeIndex].addObject(objectDistances.getObject(objectIndex)); } private int getNNIndex(int centerIndex, int[] dataIndices) { diff --git a/src/mhtree/BuildTreeApp.java b/src/mhtree/BuildTreeApp.java index 98c4281..62780a6 100644 --- a/src/mhtree/BuildTreeApp.java +++ b/src/mhtree/BuildTreeApp.java @@ -2,11 +2,11 @@ package mhtree; import messif.buckets.BucketStorageException; import messif.buckets.impl.MemoryStorageBucket; -import messif.objects.AbstractObject; import messif.objects.LocalAbstractObject; import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2; import messif.objects.util.AbstractObjectList; import messif.objects.util.AbstractStreamObjectIterator; +import messif.objects.util.RankedAbstractObject; import messif.objects.util.StreamGenericAbstractObjectIterator; import messif.operations.query.ApproxKNNQueryOperation; @@ -29,25 +29,22 @@ public class BuildTreeApp { AbstractObjectList<LocalAbstractObject> objects = new AbstractObjectList<>(iter); MHTree tree = new MHTree(objects, 10, 5, MemoryStorageBucket.class, null); + int k = 15; for (LocalAbstractObject object : objects) { - - ApproxKNNQueryOperation op = new ApproxKNNQueryOperation(object, 1); + ApproxKNNQueryOperation op = new ApproxKNNQueryOperation(object, k); tree.approxKNN(op); - Iterator<AbstractObject> answerObjects = op.getAnswerObjects(); - if (op.getAnswerCount() == 0) throw new RuntimeException("no result"); - if (op.getAnswerCount() != 1) throw new RuntimeException("too many results"); + if (op.getAnswerCount() != k) throw new RuntimeException("too many results"); - while (answerObjects.hasNext()) { - AbstractObject answerObject = answerObjects.next(); + for (Iterator<RankedAbstractObject> answerObjects = op.getAnswer(); answerObjects.hasNext(); ) { + RankedAbstractObject rankedAnswer = answerObjects.next(); - if (!answerObject.getLocatorURI().equals(object.getLocatorURI())) - throw new RuntimeException("returned different object"); + System.out.println(rankedAnswer.getObject().getLocatorURI()); + System.out.println(rankedAnswer.getDistance()); } } - } } catch (IOException | BucketStorageException ex) { Logger.getLogger("BuildTree").log(Level.SEVERE, null, ex); diff --git a/src/mhtree/InternalNode.java b/src/mhtree/InternalNode.java index 762bee3..b4c70d7 100644 --- a/src/mhtree/InternalNode.java +++ b/src/mhtree/InternalNode.java @@ -4,7 +4,9 @@ import messif.objects.LocalAbstractObject; import java.io.Serializable; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; import java.util.stream.Stream; public class InternalNode extends Node implements Serializable { @@ -28,4 +30,14 @@ public class InternalNode extends Node implements Serializable { public List<Node> getChildren() { return children; } + + public Node getNearestChild(LocalAbstractObject object) { + List<Float> distances = children.stream().map(child -> child.getDistance(object)).collect(Collectors.toList()); + + return children.get(distances.indexOf(Collections.min(distances))); + } + + public void addObject(LocalAbstractObject object) { + rebuildHull(object); + } } diff --git a/src/mhtree/LeafNode.java b/src/mhtree/LeafNode.java index f3dd645..5e47197 100644 --- a/src/mhtree/LeafNode.java +++ b/src/mhtree/LeafNode.java @@ -1,13 +1,11 @@ package mhtree; import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; -import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3; import messif.buckets.BucketStorageException; import messif.buckets.LocalBucket; import messif.objects.LocalAbstractObject; import java.io.Serializable; -import java.util.ArrayList; import java.util.List; public class LeafNode extends Node implements Serializable { @@ -34,8 +32,6 @@ public class LeafNode extends Node implements Serializable { if (isCovered(object)) return; - List<LocalAbstractObject> objects = new ArrayList<>(getObjects()); - objects.add(object); - hull = new HullOptimizedRepresentationV3(objects); + rebuildHull(object); } } diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java index 094410a..213ecf1 100644 --- a/src/mhtree/MHTree.java +++ b/src/mhtree/MHTree.java @@ -2,6 +2,7 @@ package mhtree; import messif.algorithms.Algorithm; import messif.buckets.BucketDispatcher; +import messif.buckets.BucketErrorCode; import messif.buckets.BucketStorageException; import messif.buckets.LocalBucket; import messif.objects.LocalAbstractObject; @@ -38,14 +39,13 @@ public class MHTree extends Algorithm implements Serializable { this.leafCapacity = leafCapacity; this.numberOfChildren = numberOfChildren; - bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams); + bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams); root = new BuildTree(objects, leafCapacity, numberOfChildren, bucketDispatcher).getRoot(); } public void approxKNN(ApproxKNNQueryOperation operation) { LocalAbstractObject object = operation.getQueryObject(); - // int k = operation.getK(); PriorityQueue<ObjectToNodeDistanceRank> queue = new PriorityQueue<>(); queue.add(new ObjectToNodeDistanceRank(root, object)); @@ -53,28 +53,33 @@ public class MHTree extends Algorithm implements Serializable { while (!queue.isEmpty()) { Node currentNode = queue.poll().getNode(); - if (currentNode.isLeaf()) { + if (currentNode.isLeaf()) for (LocalAbstractObject obj : currentNode.getObjects()) - if (obj.getLocatorURI().equals(object.getLocatorURI())) - operation.addToAnswer(obj); - } + operation.addToAnswer(obj); - if (!currentNode.isLeaf()) { - InternalNode node = (InternalNode) currentNode; - - for (Node child : node.getChildren()) - if (child.isCovered(object)) - queue.add(new ObjectToNodeDistanceRank(child, object)); - } + if (!currentNode.isLeaf()) + for (Node child : ((InternalNode) currentNode).getChildren()) + queue.add(new ObjectToNodeDistanceRank(child, object)); } operation.endOperation(); } - public boolean insert(InsertOperation operation) { + public boolean insert(InsertOperation operation) throws BucketStorageException { LocalAbstractObject object = operation.getInsertedObject(); - // TODO: - operation.endOperation(); + + Node currentNode = root; + + while (!currentNode.isLeaf()) { + if (!currentNode.isCovered(object)) + currentNode.addObject(object); + + currentNode = ((InternalNode) currentNode).getNearestChild(object); + } + + currentNode.addObject(object); + + operation.endOperation(BucketErrorCode.OBJECT_INSERTED); return true; } diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java index 518023e..333055f 100644 --- a/src/mhtree/Node.java +++ b/src/mhtree/Node.java @@ -1,10 +1,11 @@ package mhtree; -import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3; +import messif.buckets.BucketStorageException; import messif.objects.LocalAbstractObject; import java.io.Serializable; +import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -16,8 +17,8 @@ public abstract class Node implements Serializable { */ private static final long serialVersionUID = 420L; - private Node parent; protected HullOptimizedRepresentationV3 hull; + private Node parent; Node(HullOptimizedRepresentationV3 hull) { hull.build(); @@ -32,10 +33,6 @@ public abstract class Node implements Serializable { this(new HullOptimizedRepresentationV3(objects.collect(Collectors.toList()))); } - Node(AbstractRepresentation.PrecomputedDistances distances) { - this(new HullOptimizedRepresentationV3(distances)); - } - public static InternalNode createParent(List<Node> nodes) { return new InternalNode(nodes.stream().flatMap(node -> node.getObjects().stream())); } @@ -74,4 +71,14 @@ public abstract class Node implements Serializable { public String toString() { return "Node{hull=" + hull + '}'; } + + public abstract void addObject(LocalAbstractObject object) throws BucketStorageException; + + protected void rebuildHull(LocalAbstractObject object) { + List<LocalAbstractObject> objects = new ArrayList<>(getObjects()); + objects.add(object); + + hull = new HullOptimizedRepresentationV3(objects); + hull.build(); + } } -- GitLab