From 24024c78eb60bb88fd87219bacafeb6a3394f8cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Thu, 28 Jan 2021 10:19:18 +0100 Subject: [PATCH] ADD: subset of precomputed distances is used to speed up creation of a new node --- src/mhtree/BuildTree.java | 16 ++--- src/mhtree/InternalNode.java | 5 +- src/mhtree/LeafNode.java | 12 ++-- src/mhtree/MHTree.java | 4 +- src/mhtree/Node.java | 22 ++++--- src/mhtree/benchmarking/BenchmarkConfig.java | 6 +- src/mhtree/benchmarking/MHTreeConfig.java | 3 +- src/mhtree/benchmarking/RunBenchmark.java | 62 ++++++++++++++++---- 8 files changed, 84 insertions(+), 46 deletions(-) diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java index fc36b71..ddab38a 100644 --- a/src/mhtree/BuildTree.java +++ b/src/mhtree/BuildTree.java @@ -1,11 +1,14 @@ package mhtree; -import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; +import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.PrecomputedDistances; import messif.buckets.BucketDispatcher; import messif.buckets.BucketStorageException; import messif.objects.LocalAbstractObject; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.BitSet; +import java.util.List; import java.util.function.Function; import java.util.stream.Collectors; @@ -17,7 +20,7 @@ class BuildTree { private final Node[] nodes; private final BitSet validNodeIndices; - private final AbstractRepresentation.PrecomputedDistances objectDistances; + private final PrecomputedDistances objectDistances; private final float[][] nodeDistances; private final BucketDispatcher bucketDispatcher; @@ -33,7 +36,7 @@ class BuildTree { validNodeIndices = new BitSet(nodes.length); validNodeIndices.set(0, nodes.length); - objectDistances = new AbstractRepresentation.PrecomputedDistances(objects); + objectDistances = new PrecomputedDistances(objects); nodeDistances = new float[nodes.length][nodes.length]; this.bucketDispatcher = bucketDispatcher; @@ -110,7 +113,7 @@ class BuildTree { findClosestObjectIndices(furthestIndex, leafCapacity - 1, notProcessedObjectIndices) .forEach(i -> objects.add(objectDistances.getObject(i))); - nodes[nodeIndex] = new LeafNode(objects, bucketDispatcher.createBucket(), insertType); + nodes[nodeIndex] = new LeafNode(objectDistances.getSubset(objects), bucketDispatcher.createBucket(), insertType); } return true; @@ -127,7 +130,6 @@ class BuildTree { .map(getMinHullObjectDistance) .collect(Collectors.toList()); - Node node = nodes[Utils.getIndexOfMinElement(hullObjectDistances)]; node.addObject(object); } @@ -206,7 +208,7 @@ class BuildTree { List<Node> nodes = indices.stream().map(i -> this.nodes[i]).collect(Collectors.toList()); - InternalNode parent = Node.createParent(nodes, insertType); + InternalNode parent = Node.createParent(nodes, objectDistances, insertType); nodes.forEach(node -> node.setParent(parent)); parent.addChildren(nodes); diff --git a/src/mhtree/InternalNode.java b/src/mhtree/InternalNode.java index 2d6dd3d..65195d6 100644 --- a/src/mhtree/InternalNode.java +++ b/src/mhtree/InternalNode.java @@ -1,5 +1,6 @@ package mhtree; +import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.PrecomputedDistances; import messif.objects.LocalAbstractObject; import java.io.Serializable; @@ -17,8 +18,8 @@ public class InternalNode extends Node implements Serializable { private final List<Node> children; - InternalNode(List<LocalAbstractObject> objects, InsertType insertType) { - super(objects, insertType); + InternalNode(PrecomputedDistances distances, InsertType insertType) { + super(distances, insertType); children = new ArrayList<>(); } diff --git a/src/mhtree/LeafNode.java b/src/mhtree/LeafNode.java index 6f94ab0..0378200 100644 --- a/src/mhtree/LeafNode.java +++ b/src/mhtree/LeafNode.java @@ -1,6 +1,6 @@ package mhtree; -import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; +import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.PrecomputedDistances; import messif.buckets.BucketStorageException; import messif.buckets.LocalBucket; import messif.objects.LocalAbstractObject; @@ -18,15 +18,11 @@ public class LeafNode extends Node implements Serializable { private static final long serialVersionUID = 1L; private LocalBucket bucket; - LeafNode(List<LocalAbstractObject> objects, LocalBucket bucket, InsertType insertType) throws BucketStorageException { - super(objects, insertType); + LeafNode(PrecomputedDistances distances, LocalBucket bucket, InsertType insertType) throws BucketStorageException { + super(distances, insertType); this.bucket = bucket; - this.bucket.addObjects(objects); - } - - LeafNode(AbstractRepresentation.PrecomputedDistances distances, LocalBucket bucket, InsertType insertType) throws BucketStorageException { - this(distances.getObjects(), bucket, insertType); + this.bucket.addObjects(distances.getObjects()); } public void addObject(LocalAbstractObject object) throws BucketStorageException { diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java index 99fd3e0..ba9beff 100644 --- a/src/mhtree/MHTree.java +++ b/src/mhtree/MHTree.java @@ -1,6 +1,6 @@ package mhtree; -import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; +import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.PrecomputedDistances; import messif.algorithms.Algorithm; import messif.buckets.BucketDispatcher; import messif.buckets.BucketErrorCode; @@ -46,7 +46,7 @@ public class MHTree extends Algorithm implements Serializable { this.insertType = insertType; bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams); - AbstractRepresentation.PrecomputedDistances.COMPUTATION_THREADS = numberOfThreads; + PrecomputedDistances.COMPUTATION_THREADS = numberOfThreads; root = new BuildTree(objects, leafCapacity, numberOfChildren, insertType, bucketDispatcher).getRoot(); } diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java index 995d10e..c8e2b55 100644 --- a/src/mhtree/Node.java +++ b/src/mhtree/Node.java @@ -1,5 +1,6 @@ package mhtree; +import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.PrecomputedDistances; import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3; import messif.buckets.BucketStorageException; import messif.objects.LocalAbstractObject; @@ -15,25 +16,22 @@ public abstract class Node implements Serializable { * Serialization ID */ private static final long serialVersionUID = 420L; - - protected HullOptimizedRepresentationV3 hull; protected final InsertType insertType; + protected HullOptimizedRepresentationV3 hull; private Node parent; - Node(HullOptimizedRepresentationV3 hull, InsertType insertType) { - hull.build(); - this.hull = hull; + Node(PrecomputedDistances distances, InsertType insertType) { + this.hull = new HullOptimizedRepresentationV3(distances); + this.hull.build(); this.insertType = insertType; } - Node(List<LocalAbstractObject> objects, InsertType insertType) { - this(new HullOptimizedRepresentationV3(objects), insertType); - } - - public static InternalNode createParent(List<Node> nodes, InsertType insertType) { - return new InternalNode(nodes.stream() + public static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType) { + List<LocalAbstractObject> objects = nodes.stream() .flatMap(node -> node.getObjects().stream()) - .collect(Collectors.toList()), insertType); + .collect(Collectors.toList()); + + return new InternalNode(distances.getSubset(objects), insertType); } public float getDistance(LocalAbstractObject object) { diff --git a/src/mhtree/benchmarking/BenchmarkConfig.java b/src/mhtree/benchmarking/BenchmarkConfig.java index 17bdfd9..647e6f7 100644 --- a/src/mhtree/benchmarking/BenchmarkConfig.java +++ b/src/mhtree/benchmarking/BenchmarkConfig.java @@ -39,8 +39,10 @@ public class BenchmarkConfig { for (MHTreeConfig mhTreeConfig : mhTreeConfigs) { mhTreeConfig.setObjects(objects); - for (Consumer<MHTreeConfig> benchmarkFunction : benchmarkFunctions) - benchmarkFunction.accept(mhTreeConfig); + for (int i = 0; i < benchmarkFunctions.size(); i++) { + System.out.println("~Running" + mhTreeConfig + " on " + i + "-th function"); + benchmarkFunctions.get(i).accept(mhTreeConfig); + } } } } diff --git a/src/mhtree/benchmarking/MHTreeConfig.java b/src/mhtree/benchmarking/MHTreeConfig.java index 5a90644..7ff5f04 100644 --- a/src/mhtree/benchmarking/MHTreeConfig.java +++ b/src/mhtree/benchmarking/MHTreeConfig.java @@ -46,8 +46,7 @@ public class MHTreeConfig { @Override public String toString() { return "MHTreeConfig{" + - "objects=" + objects + - ", leafCapacity=" + leafCapacity + + "leafCapacity=" + leafCapacity + ", numberOfChildren=" + numberOfChildren + ", numberOfThreads=" + numberOfThreads + ", insertType=" + insertType + diff --git a/src/mhtree/benchmarking/RunBenchmark.java b/src/mhtree/benchmarking/RunBenchmark.java index 9b121fe..9a0994a 100644 --- a/src/mhtree/benchmarking/RunBenchmark.java +++ b/src/mhtree/benchmarking/RunBenchmark.java @@ -3,6 +3,7 @@ package mhtree.benchmarking; import messif.buckets.BucketStorageException; import messif.buckets.impl.MemoryStorageBucket; import messif.objects.LocalAbstractObject; +import messif.operations.data.InsertOperation; import messif.operations.query.ApproxKNNQueryOperation; import mhtree.InsertType; import mhtree.MHTree; @@ -31,11 +32,53 @@ public class RunBenchmark { datasets, Arrays.asList( RunBenchmark::benchBuildTree, - (MHTreeConfig config) -> benchApproxKNN(config, 10), - (MHTreeConfig config) -> benchApproxKNN(config, 30) + (MHTreeConfig config) -> benchApproxKNN(config, 30), + RunBenchmark::benchInsert )).execute(); } + public static void benchInsert(MHTreeConfig config) { + try { + List<LocalAbstractObject> objects = config.getObjects(); + List<LocalAbstractObject> initialObjects = objects.subList(0, objects.size() / 2); + List<LocalAbstractObject> restOfTheObjects = objects.subList(objects.size() / 2, objects.size()); + + long timeStart = System.currentTimeMillis(); + + MHTree tree = new MHTree( + initialObjects, + config.getLeafCapacity(), + config.getNumberOfChildren(), + config.getNumberOfThreads(), + config.getInsertType(), + MemoryStorageBucket.class, + null); + + long treeBuilt = System.currentTimeMillis(); + + long tookMax = Long.MIN_VALUE; + long tookMin = Long.MAX_VALUE; + + for (LocalAbstractObject object : restOfTheObjects) { + long insertStart = System.currentTimeMillis(); + tree.insert(new InsertOperation(object)); + long took = System.currentTimeMillis() - insertStart; + + tookMax = Math.max(tookMax, took); + tookMin = Math.min(tookMin, took); + } + + long timeStop = System.currentTimeMillis(); + + System.out.println("~Building took: " + (treeBuilt - timeStart) + " ms"); + System.out.println("~max: " + tookMax + " ms"); + System.out.println("~min: " + tookMin + " ms"); + System.out.println("~Took: " + (timeStop - timeStart) + " ms"); + } catch (BucketStorageException ex) { + Logger.getLogger("BuildTree").log(Level.SEVERE, null, ex); + } + } + public static void benchApproxKNN(MHTreeConfig config, int k) { try { MHTree tree = new MHTree(config.getObjects(), @@ -46,8 +89,6 @@ public class RunBenchmark { MemoryStorageBucket.class, null); - System.out.println("Tree built"); - long timeStart = System.currentTimeMillis(); long tookMax = Long.MIN_VALUE; @@ -56,8 +97,7 @@ public class RunBenchmark { for (LocalAbstractObject object : config.getObjects()) { long nnStart = System.currentTimeMillis(); - ApproxKNNQueryOperation op = new ApproxKNNQueryOperation(object, k); - tree.approxKNN(op); + tree.approxKNN(new ApproxKNNQueryOperation(object, k)); long took = System.currentTimeMillis() - nnStart; @@ -67,11 +107,11 @@ public class RunBenchmark { long timeStop = System.currentTimeMillis(); - System.out.println("nn took: " + (timeStop - timeStart) + " ms"); - System.out.println("max: " + tookMax + " ms"); - System.out.println("min: " + tookMin + " ms"); + System.out.println("~nn took: " + (timeStop - timeStart) + " ms"); + System.out.println("~max: " + tookMax + " ms"); + System.out.println("~min: " + tookMin + " ms"); - System.out.println(config + " k: " + k); + System.out.println("~" + config + " k: " + k); } catch (BucketStorageException ex) { Logger.getLogger("BuildTree").log(Level.SEVERE, null, ex); @@ -91,7 +131,7 @@ public class RunBenchmark { null); long timeStop = System.currentTimeMillis(); - System.out.println("Building took: " + (timeStop - timeStart) + " ms"); + System.out.println("~Building took: " + (timeStop - timeStart) + " ms"); } catch (BucketStorageException ex) { Logger.getLogger("BuildTree").log(Level.SEVERE, null, ex); } -- GitLab