From ff598df0e9b4a9a1ec8ae535ef29e2800ef1148b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Wed, 10 Feb 2021 14:07:50 +0100 Subject: [PATCH] ADD: support for distance measure, number of threads is now set externally --- src/mhtree/InternalNode.java | 6 +++--- src/mhtree/LeafNode.java | 4 ++-- src/mhtree/MHTree.java | 9 ++++----- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/mhtree/InternalNode.java b/src/mhtree/InternalNode.java index c786940..4b961df 100644 --- a/src/mhtree/InternalNode.java +++ b/src/mhtree/InternalNode.java @@ -17,8 +17,8 @@ public class InternalNode extends Node implements Serializable { private final List<Node> children; - InternalNode(PrecomputedDistances distances, InsertType insertType) { - super(distances, insertType); + InternalNode(PrecomputedDistances distances, InsertType insertType, DistanceMeasure distanceMeasure) { + super(distances, insertType, distanceMeasure); children = new ArrayList<>(); } @@ -31,7 +31,7 @@ public class InternalNode extends Node implements Serializable { } public Node getNearestChild(LocalAbstractObject object) { - Map<Node, Float> nodeToObjectDistance = children.stream() + Map<Node, Double> nodeToObjectDistance = children.stream() .collect(Collectors.toMap(Function.identity(), node -> node.getDistance(object))); return Collections.min(nodeToObjectDistance.entrySet(), Map.Entry.comparingByValue()).getKey(); diff --git a/src/mhtree/LeafNode.java b/src/mhtree/LeafNode.java index 49ea709..d720c98 100644 --- a/src/mhtree/LeafNode.java +++ b/src/mhtree/LeafNode.java @@ -19,8 +19,8 @@ public class LeafNode extends Node implements Serializable { private static final long serialVersionUID = 1L; private LocalBucket bucket; - LeafNode(PrecomputedDistances distances, LocalBucket bucket, InsertType insertType) throws BucketStorageException { - super(distances, insertType); + LeafNode(PrecomputedDistances distances, LocalBucket bucket, InsertType insertType, DistanceMeasure distanceMeasure) throws BucketStorageException { + super(distances, insertType, distanceMeasure); this.bucket = bucket; this.bucket.addObjects(distances.getObjects()); diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java index f7290cd..27fe6e7 100644 --- a/src/mhtree/MHTree.java +++ b/src/mhtree/MHTree.java @@ -1,6 +1,5 @@ package mhtree; -import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.PrecomputedDistances; import messif.algorithms.Algorithm; import messif.buckets.BucketDispatcher; import messif.buckets.BucketErrorCode; @@ -36,20 +35,20 @@ public class MHTree extends Algorithm implements Serializable { "arity", "number of threads used in precomputing distances", "insert type", + "distance measure", "storage class for buckets", "storage class parameters" }) - public MHTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, int numberOfThreads, InsertType insertType, Class<? extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) throws BucketStorageException { + public MHTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, DistanceMeasure distanceMeasure, Class<? extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) throws BucketStorageException { super("MH-Tree"); this.leafCapacity = leafCapacity; this.arity = arity; this.insertType = insertType; - bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams); - PrecomputedDistances.COMPUTATION_THREADS = numberOfThreads; + bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams); - root = new BuildTree(objects, leafCapacity, arity, insertType, bucketDispatcher).getRoot(); + root = new BuildTree(objects, leafCapacity, arity, insertType, distanceMeasure, bucketDispatcher).getRoot(); } public void approxKNN(ApproxKNNQueryOperation operation) { -- GitLab