diff --git a/jars/similarityoperators.jar b/jars/similarityoperators.jar index a5338eb5aeed572b39aea82e811317957c03d858..fde259d77bad5c2e611aa6b4b0c6d12664b2fa37 100644 Binary files a/jars/similarityoperators.jar and b/jars/similarityoperators.jar differ diff --git a/src/mhtree/ApproxState.java b/src/mhtree/ApproxState.java new file mode 100644 index 0000000000000000000000000000000000000000..8c1b08415a582559dad42220a40950e35b5922e6 --- /dev/null +++ b/src/mhtree/ApproxState.java @@ -0,0 +1,60 @@ +package mhtree; + +import messif.operations.Approximate; + +public class ApproxState { + protected int limit; + protected int objectsChecked; + protected int bucketsVisited; + + protected ApproxState() { + } + + protected ApproxState(int limit) { + this.limit = limit; + } + + public static ApproxState create(Approximate limits, MHTree mhTree) { + switch (limits.getLocalSearchType()) { + case PERCENTAGE: + return new ApproxStateObjects(Math.round((float) mhTree.getObjectCount() * (float) limits.getLocalSearchParam() / 100f)); + case ABS_OBJ_COUNT: + return new ApproxStateObjects(limits.getLocalSearchParam()); + case DATA_PARTITIONS: + return new ApproxStateBuckets(limits.getLocalSearchParam()); + default: + return new ApproxState(); + } + } + + public void update(LeafNode node) { + objectsChecked += node.getObjectCount(); + bucketsVisited++; + } + + public boolean stop() { + return false; + } + + private static class ApproxStateBuckets extends ApproxState { + private ApproxStateBuckets(int limit) { + super(limit); + } + + @Override + public boolean stop() { + return bucketsVisited >= limit; + } + } + + private static class ApproxStateObjects extends ApproxState { + private ApproxStateObjects(int limit) { + super(limit); + } + + @Override + public boolean stop() { + return objectsChecked >= limit; + } + } +} diff --git a/src/mhtree/InternalNode.java b/src/mhtree/InternalNode.java index 22929b507edc84887a0944fb805d3072e03e37b8..8f2fee7befa44bd543ecfdafa42106a78d4b28fe 100644 --- a/src/mhtree/InternalNode.java +++ b/src/mhtree/InternalNode.java @@ -20,7 +20,7 @@ class InternalNode extends Node implements Serializable { private final List<Node> children; - InternalNode(PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, List<Node> children) { + protected InternalNode(PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, List<Node> children) { super(distances, insertType, objectToNodeDistance); this.children = children; @@ -31,7 +31,7 @@ class InternalNode extends Node implements Serializable { * * @return the list of child nodes */ - List<Node> getChildren() { + protected List<Node> getChildren() { return children; } @@ -41,16 +41,16 @@ class InternalNode extends Node implements Serializable { * @param object object to which the distance is measured * @return the nearest child to the {@code object} */ - Node getNearestChild(LocalAbstractObject object) { + protected Node getNearestChild(LocalAbstractObject object) { Node nearestChild = children.get(0); double minDistance = nearestChild.getDistance(object); - for (Node child : children) { - double distance = child.getDistance(object); + for (int i = 1; i < children.size(); i++) { + double distance = children.get(i).getDistance(object); if (distance < minDistance) { minDistance = distance; - nearestChild = child; + nearestChild = children.get(i); } } @@ -62,7 +62,7 @@ class InternalNode extends Node implements Serializable { * * @param object object to be added */ - void addObject(LocalAbstractObject object) { + protected void addObject(LocalAbstractObject object) { addObjectIntoHull(object); } @@ -71,7 +71,7 @@ class InternalNode extends Node implements Serializable { * * @return the list of objects stored in node's descendants */ - List<LocalAbstractObject> getObjects() { + protected List<LocalAbstractObject> getObjects() { return children .stream() .map(Node::getObjects) @@ -84,7 +84,7 @@ class InternalNode extends Node implements Serializable { * * @return the height of this node */ - int getHeight() { + protected int getHeight() { return children .stream() .mapToInt(Node::getHeight) @@ -97,7 +97,7 @@ class InternalNode extends Node implements Serializable { * * @param nodes list of nodes */ - void gatherNodes(List<Node> nodes) { + protected void gatherNodes(List<Node> nodes) { nodes.add(this); nodes.addAll(children); } @@ -107,7 +107,7 @@ class InternalNode extends Node implements Serializable { * * @param leafNodes list of leaf nodes */ - void gatherLeafNodes(List<LeafNode> leafNodes) { + protected void gatherLeafNodes(List<LeafNode> leafNodes) { children.forEach(child -> child.gatherLeafNodes(leafNodes)); } @@ -116,7 +116,7 @@ class InternalNode extends Node implements Serializable { * * @return the number of internal nodes in this subtree */ - int getInternalNodesCount() { + protected int getInternalNodesCount() { return children .stream() .filter(Node::isInternal) diff --git a/src/mhtree/LeafNode.java b/src/mhtree/LeafNode.java index 9d30a59ddf285dd9c8a553b225f925dfab719067..66d41e1bcba9644313bb4b570133df958617e2d1 100644 --- a/src/mhtree/LeafNode.java +++ b/src/mhtree/LeafNode.java @@ -4,7 +4,6 @@ import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.Precompu import messif.buckets.BucketStorageException; import messif.buckets.LocalBucket; import messif.objects.LocalAbstractObject; -import messif.objects.util.AbstractObjectIterator; import java.io.Serializable; import java.util.ArrayList; @@ -25,7 +24,7 @@ class LeafNode extends Node implements Serializable { */ private final LocalBucket bucket; - LeafNode(PrecomputedDistances distances, LocalBucket bucket, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) throws BucketStorageException { + protected LeafNode(PrecomputedDistances distances, LocalBucket bucket, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) throws BucketStorageException { super(distances, insertType, objectToNodeDistance); this.bucket = bucket; @@ -38,10 +37,14 @@ class LeafNode extends Node implements Serializable { * @param object object to be added * @throws BucketStorageException addition of object into bucket exception */ - void addObject(LocalAbstractObject object) throws BucketStorageException { + protected void addObject(LocalAbstractObject object) throws BucketStorageException { + addObject(object, null); + } + + protected void addObject(LocalAbstractObject object, PrecomputedDistances distances) throws BucketStorageException { bucket.addObject(object); - addObjectIntoHull(object); + addObjectIntoHull(object, distances); } /** @@ -49,12 +52,12 @@ class LeafNode extends Node implements Serializable { * * @return a list of objects in node's bucket */ - List<LocalAbstractObject> getObjects() { + protected List<LocalAbstractObject> getObjects() { List<LocalAbstractObject> objects = new ArrayList<>(bucket.getObjectCount()); - for (AbstractObjectIterator<LocalAbstractObject> it = bucket.getAllObjects(); it.hasNext(); ) { - objects.add(it.next()); - } + bucket + .getAllObjects() + .forEachRemaining(objects::add); return objects; } @@ -64,7 +67,7 @@ class LeafNode extends Node implements Serializable { * * @return the number of objects stored in node's bucket */ - int getObjectCount() { + protected int getObjectCount() { return bucket.getObjectCount(); } @@ -73,7 +76,7 @@ class LeafNode extends Node implements Serializable { * * @return the height of this node */ - int getHeight() { + protected int getHeight() { return 0; } @@ -82,7 +85,7 @@ class LeafNode extends Node implements Serializable { * * @param nodes list of nodes */ - void gatherNodes(List<Node> nodes) { + protected void gatherNodes(List<Node> nodes) { nodes.add(this); } @@ -91,7 +94,7 @@ class LeafNode extends Node implements Serializable { * * @param leafNodes list of leaf nodes */ - void gatherLeafNodes(List<LeafNode> leafNodes) { + protected void gatherLeafNodes(List<LeafNode> leafNodes) { leafNodes.add(this); } } diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java index 80db3a7f994f5cffc52504067fa3d9ff4ad90900..9ec88c077f9aa85b1a097c7c122707c3d9008092 100644 --- a/src/mhtree/MHTree.java +++ b/src/mhtree/MHTree.java @@ -3,13 +3,15 @@ package mhtree; import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; import messif.algorithms.Algorithm; import messif.buckets.BucketDispatcher; -import messif.buckets.BucketErrorCode; import messif.buckets.BucketStorageException; import messif.buckets.LocalBucket; import messif.buckets.impl.MemoryStorageBucket; import messif.objects.LocalAbstractObject; import messif.operations.data.InsertOperation; import messif.operations.query.ApproxKNNQueryOperation; +import messif.operations.query.KNNQueryOperation; +import messif.statistics.OperationStatistics; +import messif.statistics.StatisticCounter; import java.io.Serializable; import java.util.ArrayList; @@ -27,6 +29,8 @@ public class MHTree extends Algorithm implements Serializable { * Serialization ID */ private static final long serialVersionUID = 42L; + private static final String STAT_NAME_LEAVES_VISITED = "Node.Leaf.Visited"; + private static final StatisticCounter statVisitedLeaves = StatisticCounter.getStatistics(STAT_NAME_LEAVES_VISITED); /** * Minimal number of objects in leaf node's bucket. @@ -37,7 +41,6 @@ public class MHTree extends Algorithm implements Serializable { * Maximal degree of internal node. */ private final int NODE_DEGREE; - private final Node root; private final InsertType insertType; private final ObjectToNodeDistance objectToNodeDistance; @@ -60,21 +63,38 @@ public class MHTree extends Algorithm implements Serializable { } public void approxKNN(ApproxKNNQueryOperation operation) { + approxKNNSearch(operation, ApproxState.create(operation, this)); + } + + private void approxKNNSearch(ApproxKNNQueryOperation operation, ApproxState approxState) { LocalAbstractObject queryObject = operation.getQueryObject(); PriorityQueue<ObjectToNodeDistanceRank> queue = new PriorityQueue<>(); queue.add(new ObjectToNodeDistanceRank(queryObject, root)); + statVisitedLeaves.reset(); + StatisticCounter counter = OperationStatistics.getOpStatisticCounter(STAT_NAME_LEAVES_VISITED); + counter.bindTo(statVisitedLeaves); + while (!queue.isEmpty()) { - Node node = queue.poll().getNode(); + if (approxState.stop()) + break; + + Node node = queue.remove().getNode(); if (operation.isAnswerFull() && isPrunable(node, queryObject, operation)) continue; if (node.isLeaf()) { - for (LocalAbstractObject object : node.getObjects()) - if (!operation.isAnswerFull() || queryObject.getDistance(object) < operation.getAnswerDistance()) + statVisitedLeaves.add(); + + for (LocalAbstractObject object : node.getObjects()) { + if (!operation.isAnswerFull() || queryObject.getDistance(object) < operation.getAnswerDistance()) { operation.addToAnswer(object); + } + } + + approxState.update((LeafNode) node); } else { for (Node child : ((InternalNode) node).getChildren()) if (!operation.isAnswerFull() || !isPrunable(child, queryObject, operation)) @@ -89,6 +109,15 @@ public class MHTree extends Algorithm implements Serializable { return operation.getAnswerDistance() < child.getDistanceToNearest(queryObject); } + public void kNN(KNNQueryOperation knnQueryOperation) { + root.getObjects().forEach(knnQueryOperation::addToAnswer); + knnQueryOperation.endOperation(); + } + + public int getObjectCount() { + return bucketDispatcher.getObjectCount(); + } + public void insert(InsertOperation operation) throws BucketStorageException { LocalAbstractObject object = operation.getInsertedObject(); @@ -102,7 +131,7 @@ public class MHTree extends Algorithm implements Serializable { node.addObject(object); - operation.endOperation(BucketErrorCode.OBJECT_INSERTED); + operation.endOperation(); } private List<Node> getNodes() { @@ -315,8 +344,7 @@ public class MHTree extends Algorithm implements Serializable { if (notProcessedObjectIndices.cardinality() < leafCapacity) { for (int i = notProcessedObjectIndices.nextSetBit(0); i >= 0; i = notProcessedObjectIndices.nextSetBit(i + 1)) { LocalAbstractObject object = objectDistances.getObject(i); - - nodes[getClosestNodeIndex(object)].addObject(object); + ((LeafNode) nodes[getClosestNodeIndex(object)]).addObject(object, objectDistances); } return; @@ -346,7 +374,7 @@ public class MHTree extends Algorithm implements Serializable { int closestNodeIndex = -1; for (int candidateIndex = 0; candidateIndex < nodes.length; candidateIndex++) { - double distance = nodes[candidateIndex].getDistance(object); + double distance = nodes[candidateIndex].getDistance(object, objectDistances); if (distance < minDistance) { minDistance = distance; diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java index b89d9bc8a414f17e1e2006c9f9e68482bd426230..2bfcdaaf8595a8704a4ab6987ba52de37491a319 100644 --- a/src/mhtree/Node.java +++ b/src/mhtree/Node.java @@ -23,7 +23,7 @@ abstract class Node implements Serializable { private HullOptimizedRepresentationV3 hull; - Node(PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) { + protected Node(PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) { this.hull = new HullOptimizedRepresentationV3(distances); this.hull.build(); @@ -31,12 +31,7 @@ abstract class Node implements Serializable { this.OBJECT_TO_NODE_DISTANCE = objectToNodeDistance; } - @Override - public String toString() { - return "Node{hull=" + hull + '}'; - } - - static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) { + protected static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance) { List<LocalAbstractObject> objects = nodes .stream() .map(Node::getObjects) @@ -46,60 +41,82 @@ abstract class Node implements Serializable { return new InternalNode(distances.getSubset(objects), insertType, objectToNodeDistance, nodes); } - double getDistance(LocalAbstractObject object) { + @Override + public String toString() { + return "Node{hull=" + hull + '}'; + } + + protected double getDistance(LocalAbstractObject object) { return OBJECT_TO_NODE_DISTANCE.getDistance(object, this); } - double getDistanceToNearest(LocalAbstractObject object) { + protected double getDistance(LocalAbstractObject object, PrecomputedDistances distances) { + return OBJECT_TO_NODE_DISTANCE.getDistance(object, this, distances); + } + + protected double getDistanceToNearest(LocalAbstractObject object) { return ObjectToNodeDistance.NEAREST.getDistance(object, this); } - boolean isLeaf() { + protected boolean isLeaf() { return (this instanceof LeafNode); } - boolean isInternal() { + protected boolean isInternal() { return !isLeaf(); } - List<LocalAbstractObject> getHullObjects() { + protected List<LocalAbstractObject> getHullObjects() { return hull.getHull(); } - int getHullObjectCount() { + protected int getHullObjectCount() { return hull.getRepresentativesCount(); } - abstract void addObject(LocalAbstractObject object) throws BucketStorageException; + protected void addObjectIntoHull(LocalAbstractObject object) { + addObjectIntoHull(object, null); + } + + protected void addObjectIntoHull(LocalAbstractObject object, PrecomputedDistances distances) { + if (isCovered(object, distances)) return; + + if (INSERT_TYPE == InsertType.INCREMENTAL) { + hull.addHullObject(object); + return; + } + + rebuildHull(object, distances); + } + + protected abstract void addObject(LocalAbstractObject object) throws BucketStorageException; - abstract List<LocalAbstractObject> getObjects(); + protected abstract List<LocalAbstractObject> getObjects(); - abstract int getHeight(); + protected abstract int getHeight(); - abstract void gatherNodes(List<Node> nodes); + protected abstract void gatherNodes(List<Node> nodes); - abstract void gatherLeafNodes(List<LeafNode> leafNodes); + protected abstract void gatherLeafNodes(List<LeafNode> leafNodes); - private void rebuildHull(LocalAbstractObject object) { + private void rebuildHull(LocalAbstractObject object, PrecomputedDistances distances) { List<LocalAbstractObject> objects = new ArrayList<>(getObjects()); objects.add(object); - hull = new HullOptimizedRepresentationV3(objects); + if (distances == null) { + hull = new HullOptimizedRepresentationV3(objects); + } else { + hull = new HullOptimizedRepresentationV3(distances.getSubset(objects)); + } + hull.build(); } - void addObjectIntoHull(LocalAbstractObject object) { - if (isCovered(object)) return; - - if (INSERT_TYPE == InsertType.INCREMENTAL) { - hull.addHullObject(object); - return; + private boolean isCovered(LocalAbstractObject object, PrecomputedDistances distances) { + if (distances == null) { + return hull.isExternalCovered(object); } - rebuildHull(object); - } - - private boolean isCovered(LocalAbstractObject object) { - return hull.isExternalCovered(object); + return hull.isExternalCovered(object, distances); } } diff --git a/src/mhtree/ObjectToNodeDistance.java b/src/mhtree/ObjectToNodeDistance.java index 410d904685ed50bbfca880a59ead92b4eeeae3ef..8d11bfde979436e9b2d2003cd00004752e08dee3 100644 --- a/src/mhtree/ObjectToNodeDistance.java +++ b/src/mhtree/ObjectToNodeDistance.java @@ -1,21 +1,25 @@ package mhtree; +import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.PrecomputedDistances; import messif.objects.LocalAbstractObject; +import java.util.function.ToDoubleFunction; + /** * Specifies possible distance measurements between an object and a node. */ public enum ObjectToNodeDistance { + /** * Average distance between {@code object} and every hull object in {@code node}. */ AVERAGE { @Override - public double getDistance(LocalAbstractObject object, Node node) { + protected double getDistance(Node node, ToDoubleFunction<? super LocalAbstractObject> distanceFunction) { return node .getHullObjects() .stream() - .mapToDouble(object::getDistance) + .mapToDouble(distanceFunction) .sum() / node.getHullObjects().size(); } }, @@ -25,11 +29,11 @@ public enum ObjectToNodeDistance { */ FURTHEST { @Override - public double getDistance(LocalAbstractObject object, Node node) { + protected double getDistance(Node node, ToDoubleFunction<? super LocalAbstractObject> distanceFunction) { return node .getHullObjects() .stream() - .mapToDouble(object::getDistance) + .mapToDouble(distanceFunction) .max() .orElse(Double.MIN_VALUE); } @@ -40,22 +44,47 @@ public enum ObjectToNodeDistance { */ NEAREST { @Override - public double getDistance(LocalAbstractObject object, Node node) { + protected double getDistance(Node node, ToDoubleFunction<? super LocalAbstractObject> distanceFunction) { return node .getHullObjects() .stream() - .mapToDouble(object::getDistance) + .mapToDouble(distanceFunction) .min() .orElse(Double.MAX_VALUE); } + }; /** * Returns the distance between {@code object} and {@code node}. * * @param object an object - * @param node a node + * @param node a node * @return the distance between {@code object} and {@code node} */ - public abstract double getDistance(LocalAbstractObject object, Node node); + public double getDistance(LocalAbstractObject object, Node node) { + return this.getDistance(node, getDistanceFunction(object, null)); + } + + /** + * Returns the distance between {@code object} and {@code node}. + * + * @param object an object + * @param node a node + * @param distances precomputed object distances + * @return the distance between {@code object} and {@code node} + */ + public double getDistance(LocalAbstractObject object, Node node, PrecomputedDistances distances) { + return this.getDistance(node, getDistanceFunction(object, distances)); + } + + protected abstract double getDistance(Node node, ToDoubleFunction<? super LocalAbstractObject> distanceFunction); + + private ToDoubleFunction<? super LocalAbstractObject> getDistanceFunction(LocalAbstractObject object, PrecomputedDistances distances) { + if (distances == null) { + return object::getDistance; + } + + return (o) -> distances.getDistance(object, o); + } } diff --git a/src/mhtree/benchmarking/Tree.java b/src/mhtree/benchmarking/Tree.java deleted file mode 100644 index fb6d637a499db47e7861cc231aa7456480f91803..0000000000000000000000000000000000000000 --- a/src/mhtree/benchmarking/Tree.java +++ /dev/null @@ -1,15 +0,0 @@ -package mhtree.benchmarking; - -import messif.operations.query.KNNQueryOperation; -import messif.utility.reflection.NoSuchInstantiatorException; - -import java.lang.reflect.InvocationTargetException; - -/** - * Abstraction over MH-Tree and M-Tree for the benchmark purposes. - */ -public interface Tree { - void kNN(KNNQueryOperation knnQueryOperation) throws ClassNotFoundException, NoSuchInstantiatorException, InvocationTargetException; - - int getObjectCount(); -}