Commit dff7af13 authored by Marek Horský's avatar Marek Horský Committed by Radek Ošlejšek
Browse files

Implement OpenCL ICP registration

parent d985cf02
Loading
Loading
Loading
Loading
+127 −0
Original line number Diff line number Diff line
package cz.fidentis.analyst.engines.face;

import com.jogamp.opencl.CLContext;
import cz.fidentis.analyst.data.face.HumanFace;
import cz.fidentis.analyst.data.landmarks.Landmark;
import cz.fidentis.analyst.data.shapes.Plane;
import cz.fidentis.analyst.engines.icp.IcpConfig;
import cz.fidentis.analyst.engines.icp.IcpServicesOpenCL;
import cz.fidentis.analyst.math.Quaternion;
import cz.fidentis.analyst.opencl.memory.CLResources;

import javax.vecmath.Point3d;
import javax.vecmath.Tuple3d;
import javax.vecmath.Vector3d;

/**
 * Handles face registration accelerated using OpenCL. Is responsible for releasing of OpenCL resources.
 *
 * @author Marek Horský
 */
public class FaceRegistrationServicesOpenCL implements CLResources {

    private final IcpServicesOpenCL icpServicesOpenCL;

    public FaceRegistrationServicesOpenCL(CLContext clContext) {
        this.icpServicesOpenCL = new IcpServicesOpenCL(clContext);
    }

    /**
     * Superimpose given face to the face included in the ICP configuration object.
     *
     * @param transformedFace A face to be transformed.
     * @param icpConfig       ICP configuration
     */
    public void alignMeshes(HumanFace transformedFace, IcpConfig icpConfig) {
        // transform mesh:
        var trHistory = icpServicesOpenCL.transform(transformedFace.getMeshModel().getFacets(), icpConfig);

        FaceStateServices.updateKdTree(transformedFace, FaceStateServices.Mode.DELETE);
        FaceStateServices.updateLeftBalancedKdTree(transformedFace, FaceStateServices.Mode.DELETE);
        FaceStateServices.updateOctree(transformedFace, FaceStateServices.Mode.DELETE);
        FaceStateServices.updateBoundingBox(transformedFace, FaceStateServices.Mode.RECOMPUTE_IF_PRESENT);

        // transform feature points:
        if (transformedFace.getLandmarks().hasLandmarks()) {
            trHistory.values()
                    .forEach(trList -> trList.forEach(tr -> {
                        for (int i = 0; i < transformedFace.getLandmarks().getAllLandmarks().size(); i++) {
                            Landmark fp = transformedFace.getLandmarks().getAllLandmarks().get(i);
                            fp.setPosition(tr.transformPoint(fp.getPosition(), icpConfig.scale()));
                        }
                    }));
        }

        // transform symmetry plane:
        if (transformedFace.hasSymmetryPlane()) {
            trHistory.values()
                    .forEach(trList -> trList.forEach(tr -> transformedFace.setSymmetryPlane(
                                    transformPlane(
                                            transformedFace.getSymmetryPlane(),
                                            tr.rotation(),
                                            tr.translation(),
                                            tr.scaleFactor()))
                            )
                    );
        }

    }

    /**
     * Transforms the whole plane, i.e., its normal and position.
     *
     * @param plane       plane to be transformed
     * @param rot         rotation
     * @param translation translation
     * @param scale       scale
     * @return transformed plane
     */
    public Plane transformPlane(Plane plane, Quaternion rot, Vector3d translation, double scale) {
        Point3d point = new Point3d(plane.getNormal());
        transformNormal(point, rot);
        Plane retPlane = new Plane(point, plane.getDistance());

        // ... then translate and scale a point projected on the rotate plane:
        point.scale(retPlane.getDistance()); // point laying on the rotated plane
        transformPoint(point, null, translation, scale); // translate and scale only
        Vector3d normal = retPlane.getNormal();
        double dist = ((normal.x * point.x) + (normal.y * point.y) + (normal.z * point.z))
                / Math.sqrt(normal.dot(normal)); // distance of transformed surface point in the plane's normal direction

        return new Plane(retPlane.getNormal(), dist);
    }

    /**
     * Transform a single 3d point.
     *
     * @param point       point to be transformed
     * @param rotation    rotation, can be {@code null}
     * @param translation translation
     * @param scale       scale
     */
    private void transformPoint(Tuple3d point, Quaternion rotation, Vector3d translation, double scale) {
        Quaternion rotQuat = new Quaternion(point.x, point.y, point.z, 1);

        if (rotation != null) {
            Quaternion rotationCopy = Quaternion.multiply(rotQuat, rotation.getConjugate());
            rotQuat = Quaternion.multiply(rotation, rotationCopy);
        }

        point.set(
                rotQuat.x * scale + translation.x,
                rotQuat.y * scale + translation.y,
                rotQuat.z * scale + translation.z
        );
    }

    private void transformNormal(Tuple3d normal, Quaternion rotation) {
        if (normal != null) {
            transformPoint(normal, rotation, new Vector3d(0, 0, 0), 1.0); // rotate only
        }
    }

    @Override
    public void release() {
        icpServicesOpenCL.release();
    }
}
+2 −1
Original line number Diff line number Diff line
@@ -2,13 +2,14 @@ package cz.fidentis.analyst.engines.face.batch.registration;

import cz.fidentis.analyst.data.face.HumanFace;
import cz.fidentis.analyst.data.mesh.MeshModel;
import cz.fidentis.analyst.opencl.memory.CLResources;

/**
 * N:N registration and/or the computation of the average face.
 *
 * @author Radek Oslejsek
 */
public interface BatchFaceRegistration {
public interface BatchFaceRegistration extends CLResources {

    /**
     * Register and/or update (metamorphose) the template face.
+8 −2
Original line number Diff line number Diff line
@@ -31,7 +31,13 @@ public class BatchFaceRegistrationServices {
         * Use Iterative Closets Point algorithm to mutually align mesh vertices.
         * No landmarks are required.
         */
        ICP
        ICP,

        /**
         * Uses Iterative Closets Point algorithm reimplemented for GPU
         * No landmarks are required.
         */
        ICP_GPU
    }

    /**
+45 −16
Original line number Diff line number Diff line
package cz.fidentis.analyst.engines.face.batch.registration.impl;

import cz.fidentis.analyst.data.face.HumanFace;
import cz.fidentis.analyst.data.kdtree.LeftBalancedKdTree;
import cz.fidentis.analyst.data.mesh.MeshModel;
import cz.fidentis.analyst.engines.avgmesh.AvgMeshConfig;
import cz.fidentis.analyst.engines.avgmesh.AvgMeshVisitor;
import cz.fidentis.analyst.engines.face.FaceRegistrationServices;
import cz.fidentis.analyst.engines.face.FaceRegistrationServicesOpenCL;
import cz.fidentis.analyst.engines.face.FaceStateServices;
import cz.fidentis.analyst.engines.face.batch.registration.BatchFaceRegistration;
import cz.fidentis.analyst.engines.face.batch.registration.BatchFaceRegistrationConfig;
import cz.fidentis.analyst.engines.icp.IcpConfig;
import cz.fidentis.analyst.engines.icp.IcpServicesOpenCL;
import cz.fidentis.analyst.engines.sampling.PointSamplingConfig;
import cz.fidentis.analyst.opencl.OpenCLServices;

import java.util.Objects;

import static cz.fidentis.analyst.engines.face.batch.registration.BatchFaceRegistrationServices.RegistrationStrategy.ICP;
import static cz.fidentis.analyst.engines.face.batch.registration.BatchFaceRegistrationServices.RegistrationStrategy.NONE;

/**
@@ -27,6 +30,7 @@ public class BatchFaceRegistrationImpl implements BatchFaceRegistration {
    private final IcpConfig icpConfig;
    private final HumanFace templateFace;

    private FaceRegistrationServicesOpenCL faceRegistrationServicesOpenCL;
    private AvgMeshVisitor avgFaceVisitor = null;

    /**
@@ -42,7 +46,8 @@ public class BatchFaceRegistrationImpl implements BatchFaceRegistration {
        this.templateFace = Objects.requireNonNull(templateFace);
        this.config = config;

        if (config.regStrategy() == ICP) {
        switch (config.regStrategy()) {
            case ICP -> {
                PointSamplingConfig sampling = (config.icpSubsampling() == 0)
                        ? new PointSamplingConfig(PointSamplingConfig.Method.NO_SAMPLING, config.icpSubsampling())
                        : new PointSamplingConfig(PointSamplingConfig.Method.RANDOM, config.icpSubsampling());
@@ -53,17 +58,33 @@ public class BatchFaceRegistrationImpl implements BatchFaceRegistration {
                        config.icpError(),
                        sampling,
                        config.icpAutoCropSince());
        } else {
            this.icpConfig = null;
            }
            case ICP_GPU -> {
                this.faceRegistrationServicesOpenCL = new FaceRegistrationServicesOpenCL(OpenCLServices.createContext());
                PointSamplingConfig sampling = (config.icpSubsampling() == 0)
                        ? new PointSamplingConfig(PointSamplingConfig.Method.NO_SAMPLING, config.icpSubsampling())
                        : new PointSamplingConfig(PointSamplingConfig.Method.RANDOM, config.icpSubsampling());
                FaceStateServices.updateLeftBalancedKdTree(templateFace, FaceStateServices.Mode.COMPUTE_IF_ABSENT);
                this.icpConfig = new IcpConfig(
                        templateFace.getLeftBalancedKdTree(),
                        config.icpIterations(),
                        config.scale(),
                        config.icpError(),
                        sampling,
                        config.icpAutoCropSince());
            }
            default -> this.icpConfig = null;
        }
    }

    @Override
    public void register(HumanFace face) {
        switch (config.regStrategy()) {
            case NONE -> {}
            case NONE -> {
            }
            case ICP -> FaceRegistrationServices.alignMeshes(face, icpConfig);
            case GPA -> FaceRegistrationServices.alignFeaturePoints(templateFace, face, config.scale());
            case ICP_GPU -> faceRegistrationServicesOpenCL.alignMeshes(face, icpConfig);
            default -> throw new IllegalStateException("Unexpected value: " + config.regStrategy());
        }

@@ -80,6 +101,14 @@ public class BatchFaceRegistrationImpl implements BatchFaceRegistration {
        return (avgFaceVisitor == null) ? templateFace.getMeshModel() : avgFaceVisitor.getAveragedMeshModel();
    }

    @Override
    public void release() {
        if (faceRegistrationServicesOpenCL != null) {
            faceRegistrationServicesOpenCL.release();
            //TODO Release context as well?
        }
    }

    protected AvgMeshVisitor computeAvgFaceNN(HumanFace initFace, HumanFace superimposedFace, AvgMeshVisitor avgFaceVisitor) {
        // If the face was moved by registration, then the spatial ordering structure was removed => create it
        // If no transformation was made, then you can use old structure, if exists
+96 −0
Original line number Diff line number Diff line
package cz.fidentis.analyst.engines.face;

import com.jogamp.opencl.CLContext;
import cz.fidentis.analyst.data.face.HumanFace;
import cz.fidentis.analyst.data.face.impl.HumanFaceImpl;
import cz.fidentis.analyst.engines.icp.IcpConfig;
import cz.fidentis.analyst.engines.sampling.PointSamplingConfig;
import cz.fidentis.analyst.opencl.OpenCLServices;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledIf;

import java.io.File;
import java.io.IOException;

import static org.junit.jupiter.api.Assertions.assertTrue;

class FaceRegistrationServicesOpenCLTest {
    private static final int MAX_ITERATION = 100;
    private static final double ERROR = 0.05;
    private static final int DISABLED_CROP = Integer.MAX_VALUE;
    private static final String FIRST_FACE_PATH = "src/test/resources/cz/fidentis/analyst/basic-model-01.obj";
    private static final String SECOND_FACE_PATH = "src/test/resources/cz/fidentis/analyst/basic-model-02.obj";
    private static final PointSamplingConfig NO_SAMPLING = new PointSamplingConfig(PointSamplingConfig.Method.NO_SAMPLING, 100);
    private static final PointSamplingConfig SAMPLING = new PointSamplingConfig(PointSamplingConfig.Method.UNIFORM_SPACE, 100);

    private static CLContext clContext;

    @BeforeAll
    @DisabledIf("OpenCLisNotAvailable")
    static void setUpFaces() {
        clContext = OpenCLServices.createContext();
    }

    @Test
    @DisabledIf("OpenCLisNotAvailable")
    void transformSingleFaceWithScale() throws IOException {
        assertTrue(areFaceTransformationEqual(FIRST_FACE_PATH,
                SECOND_FACE_PATH,
                new IcpConfigTemplate(MAX_ITERATION, true, ERROR, NO_SAMPLING, DISABLED_CROP))
        );
    }

    @Test
    @DisabledIf("OpenCLisNotAvailable")
    void transformSingleFace() throws IOException {
        assertTrue(areFaceTransformationEqual(FIRST_FACE_PATH,
                SECOND_FACE_PATH,
                new IcpConfigTemplate(MAX_ITERATION, false, ERROR, NO_SAMPLING, DISABLED_CROP))
        );
    }

    @Test
    @DisabledIf("OpenCLisNotAvailable")
    void transformSingleFaceWithSubsampling() throws IOException {
        assertTrue(areFaceTransformationEqual(FIRST_FACE_PATH,
                SECOND_FACE_PATH,
                new IcpConfigTemplate(MAX_ITERATION, true, ERROR, SAMPLING, DISABLED_CROP))
        );
    }

    private boolean areFaceTransformationEqual(String targetFacePath, String transformedFacePath, IcpConfigTemplate icpConfigTemplate) throws IOException {
        FaceRegistrationServicesOpenCL faceRegistrationServicesOpenCL = new FaceRegistrationServicesOpenCL(clContext);
        HumanFace targetFace = loadHumanFace(targetFacePath);
        FaceStateServices.updateKdTree(targetFace, FaceStateServices.Mode.COMPUTE_IF_ABSENT);
        FaceStateServices.updateLeftBalancedKdTree(targetFace, FaceStateServices.Mode.COMPUTE_IF_ABSENT);
        HumanFace faceToBeTransformedOnGPU = loadHumanFace(transformedFacePath);
        HumanFace faceToBeTransformedOnCPU = loadHumanFace(transformedFacePath);
        IcpConfig cpuConfig = new IcpConfig(targetFace.getKdTree(), icpConfigTemplate.maxIteration, icpConfigTemplate.scale, icpConfigTemplate.error, icpConfigTemplate.strategy, icpConfigTemplate.crop);
        IcpConfig gpuConfig = new IcpConfig(targetFace.getLeftBalancedKdTree(), icpConfigTemplate.maxIteration, icpConfigTemplate.scale, icpConfigTemplate.error, icpConfigTemplate.strategy, icpConfigTemplate.crop);
        faceRegistrationServicesOpenCL.alignMeshes(faceToBeTransformedOnGPU, gpuConfig);
        FaceRegistrationServices.alignMeshes(faceToBeTransformedOnCPU, cpuConfig);
        faceRegistrationServicesOpenCL.release();
        return FaceTester.areFacesEqual(faceToBeTransformedOnGPU, faceToBeTransformedOnCPU);
    }

    @AfterAll
    @DisabledIf("OpenCLisNotAvailable")
    static void cleanUp(){
        if(clContext != null) {
            clContext.release();
        }
    }

    private static HumanFace loadHumanFace(String path) throws IOException {
        return new HumanFaceImpl(new File(path), true);
    }

    private record IcpConfigTemplate(int maxIteration, boolean scale, double error, PointSamplingConfig strategy, int crop) {
    }

    private static boolean OpenCLisNotAvailable() {
        return !OpenCLServices.isOpenCLAvailable();
    }
}
Loading