package cz.fidentis.analyst.batch;

import cz.fidentis.analyst.Logger;
import cz.fidentis.analyst.core.ProgressDialog;
import cz.fidentis.analyst.face.HumanFace;
import cz.fidentis.analyst.face.HumanFaceFactory;
import cz.fidentis.analyst.visitors.mesh.HausdorffDistance;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.stream.Collectors;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_context_properties;
import org.jocl.cl_device_id;
import org.jocl.cl_kernel;
import org.jocl.cl_mem;
import org.jocl.cl_platform_id;
import org.jocl.cl_program;
import org.jocl.cl_queue_properties;

/**
 * A GPU accelerated variant of {@link ApproxHausdorffDistTask}.
 * The computation is accelerated by using both the CPU and GPU parallelism.
 * The OpenCL code is based on the
 * <a href="https://github.com/gpu/JOCLSamples/blob/master/src/main/java/org/jocl/samples/JOCLSample.java">JOCLSample</a>.
 * However, the computation is slower than using CPU-based multitasking only
 * (for 500 faces, 40 seconds of the HD finalization phase on GPU compared to 18 seconds on 8-core CPU).
 * This class is marked as deprecated for this reason.
 * <p>
 * The exact computation parameters are taken from the {@code BatchPanel}.
 * </p>
 * <p>
 * OpenCL resource handling: the per-face distance buffers and the built program
 * are released when the background computation ends (normally, by cancellation,
 * or by an exception). The command queue and the context are released in {@link #done()}.
 * </p>
 *
 * @author Radek Oslejsek
 * @deprecated Slower than the CPU-only implementation.
 *             Use {@link ApproxHausdorffDistTask} instead.
 */
@Deprecated
public class ApproxHausdorffDistTaskGPU extends SimilarityTask {

    /**
     * OpenCL device
     */
    private cl_device_id device;

    /**
     * OpenCL context
     */
    private cl_context context;

    /**
     * OpenCL command queue
     */
    private cl_command_queue commandQueue;

    /**
     * Built OpenCL program
     */
    private cl_program program;

    /**
     * GPU memory where the distances of faces to the template face are stored
     */
    private cl_mem[] srcArrays;

    private final Stopwatch totalTime = new Stopwatch("Total computation time:\t\t");
    private final Stopwatch hdComputationTime = new Stopwatch("Hausdorff distance preparation time:\t");
    private final Stopwatch loadTime = new Stopwatch("Disk access time:\t\t");
    private final Stopwatch kdTreeConstructionTime = new Stopwatch("KD trees construction time:\t\t");
    private final Stopwatch finalHdComputationTime = new Stopwatch("Hausdorff distance finalization time:\t");

    /**
     * Index (into the list of face paths) of the face used as the template
     */
    private final int templateFaceIndex;

    /**
     * The source code of the OpenCL program to execute.
     * For every element it computes the absolute difference of the two inputs.
     */
    private static final String PROGRAM_SOURCE =
        "__kernel void "+
        "hdKernel(__global const float *a,"+
        "         __global const float *b,"+
        "         __global float *c)"+
        "{"+
        "    int gid = get_global_id(0);"+
        "    c[gid] = fabs(a[gid] - b[gid]);"+
        "}";

    /**
     * Constructor.
     *
     * @param progressDialog A window that show the progress of the computation. Must not be {@code null}
     * @param controlPanel A control panel with computation parameters. Must not be {@code null}
     * @param templateFaceIndex Index of the template face in the list of face paths
     *        provided by the control panel
     */
    public ApproxHausdorffDistTaskGPU(ProgressDialog progressDialog, BatchPanel controlPanel, int templateFaceIndex) {
        super(progressDialog, controlPanel);
        this.templateFaceIndex = templateFaceIndex;
    }

    /**
     * Computes the distances of all faces to the template face, caches them
     * in the GPU memory, and then derives the pairwise similarities from them.
     *
     * @return always {@code null}
     * @throws Exception if the OpenCL initialization or the computation fails
     */
    @Override
    protected Void doInBackground() throws Exception {
        try {
            initOpenCL();

            HumanFaceFactory factory = getControlPanel().getHumanFaceFactory();
            List<Path> faces = getControlPanel().getFacePaths();

            totalTime.start();

            factory.setReuseDumpFile(true); // it's safe because no changes are made to models
            factory.setStrategy(HumanFaceFactory.Strategy.MRU); // keep first X faces in the memory

            // We don't need to reload the template face periodically for two reasons:
            //   - It is never dumped from memory to disk because we use MRU
            //   - Even if dumped, the face is kept in the memory as long as we hold the pointer to it
            loadTime.start();
            String templateFaceId = factory.loadFace(faces.get(templateFaceIndex).toFile());
            HumanFace templateFace = factory.getFace(templateFaceId);
            loadTime.stop();

            // Cache of distances of individual vertices in the GPU memory
            // srcArrays[i] stores distance of vertices of i-th face to the template face
            srcArrays = new cl_mem[faces.size()];
            int numVertices = (int) templateFace.getMeshModel().getNumVertices();

            for (int i = 0; i < faces.size(); i++) {

                if (isCancelled()) { // the user canceled the process
                    return null; // already uploaded buffers are released in the finally block
                }

                if (i != templateFaceIndex) {

                    loadTime.start();
                    String faceId = factory.loadFace(faces.get(i).toFile());
                    HumanFace face = factory.getFace(faceId);
                    loadTime.stop();

                    kdTreeConstructionTime.start();
                    face.computeKdTree(false);
                    kdTreeConstructionTime.stop();

                    hdComputationTime.start();
                    HausdorffDistance hd = new HausdorffDistance(
                            face.getKdTree(),
                            HausdorffDistance.Strategy.POINT_TO_POINT,
                            true, // relative
                            true, // parallel
                            true  // crop
                    );
                    templateFace.getMeshModel().compute(hd, true);

                    // Copy computed distances into the GPU memory:
                    float[] auxDist = new float[numVertices];
                    int k = 0;
                    for (float val: hd.getDistances().values().stream().flatMap(List::stream)
                            .map(x -> x.floatValue()).collect(Collectors.toList())) {
                        auxDist[k++] = val;
                    }
                    srcArrays[i] = uploadDistances(auxDist);

                    face.removeKdTree(); // TO BE TESTED

                    hdComputationTime.stop();
                } else { // distance to template face itself is zero
                    srcArrays[i] = uploadDistances(new float[numVertices]); // filled by zeros by default
                }

                // update progress bar
                int progress = (int) Math.round(100.0 * (i+1) / faces.size());
                getProgressDialog().setValue(progress);

                //Logger.print(factory.toString());
            }

            finalHdComputationTime.start();
            finalizeHdConcurrently(faces.size(), numVertices);
            finalHdComputationTime.stop();

            totalTime.stop();

            printTimeStats();

            return null;
        } finally {
            // Both calls are idempotent; they guarantee that nothing leaks
            // when the task is cancelled or fails before/inside the finalization.
            releaseSourceBuffers();
            releaseProgram();
        }
    }

    /**
     * Releases the OpenCL command queue and context (if they were created).
     */
    @Override
    protected void done() {
        super.done();
        // The handles are null if initOpenCL() failed before creating them.
        if (commandQueue != null) {
            CL.clReleaseCommandQueue(commandQueue);
        }
        if (context != null) {
            CL.clReleaseContext(context);
        }
    }

    /**
     * Prints the measured times of the individual computation phases.
     */
    protected void printTimeStats() {
        Logger.print(hdComputationTime.toString());
        Logger.print(loadTime.toString());
        Logger.print(finalHdComputationTime.toString());
        Logger.print(kdTreeConstructionTime.toString());
        Logger.print(totalTime.toString());
    }

    /**
     * Selects the first device of the first OpenCL platform, creates the context
     * and the command queue for it, and builds the kernel program.
     *
     * @throws IllegalStateException if no OpenCL platform or device is available
     */
    private void initOpenCL() {
        // The platform, device type and device number that will be used
        final int platformIndex = 0;
        final long deviceType = CL.CL_DEVICE_TYPE_ALL;
        final int deviceIndex = 0;

        // Enable exceptions and subsequently omit error checks
        CL.setExceptionsEnabled(true);

        // Obtain the number of platforms
        int[] numPlatformsArray = new int[1];
        CL.clGetPlatformIDs(0, null, numPlatformsArray);
        int numPlatforms = numPlatformsArray[0];
        if (numPlatforms <= platformIndex) {
            throw new IllegalStateException("No OpenCL platform available");
        }

        // Obtain a platform ID
        cl_platform_id[] platforms = new cl_platform_id[numPlatforms];
        CL.clGetPlatformIDs(platforms.length, platforms, null);
        cl_platform_id platform = platforms[platformIndex];

        // Initialize the context properties
        cl_context_properties contextProperties = new cl_context_properties();
        contextProperties.addProperty(CL.CL_CONTEXT_PLATFORM, platform);

        // Obtain the number of devices for the platform
        int[] numDevicesArray= new int[1];
        CL.clGetDeviceIDs(platform, deviceType, 0, null, numDevicesArray);
        int numDevices = numDevicesArray[0];
        if (numDevices <= deviceIndex) {
            throw new IllegalStateException("No OpenCL device available");
        }

        // Obtain a device ID
        cl_device_id[] devices= new cl_device_id[numDevices];
        CL.clGetDeviceIDs(platform, deviceType, numDevices, devices, null);
        this.device = devices[deviceIndex];

        // Create a context for the selected device
        this.context = CL.clCreateContext(
            contextProperties, 1, new cl_device_id[]{device},
            null, null, null);

        // Create a command-queue for the selected device
        cl_queue_properties properties = new cl_queue_properties();
        commandQueue = CL.clCreateCommandQueueWithProperties(
            context, device, properties, null);

        // Create the program from the source code
        program = CL.clCreateProgramWithSource(context,
            1, new String[]{ PROGRAM_SOURCE }, null, null);

        // Build the program
        CL.clBuildProgram(program, 0, null, null, null, null);
    }

    /**
     * Computes the similarity of all pairs of faces from the cached distances
     * to the template face. Pairs are processed concurrently by a thread pool,
     * each pair being evaluated by the OpenCL kernel.
     * When the method finishes, the cached distance buffers and the OpenCL program are released.
     *
     * @param numFaces number of faces (size of the cache of distances)
     * @param numVertices number of values stored for each face
     */
    protected void finalizeHdConcurrently(int numFaces, int numVertices) {
        // Set the work-item dimensions
        final long[] globalWorkSize = new long[]{numVertices};

        ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        final List<Future<Void>> results = new ArrayList<>();

        try {
            for (int i = 0; i < numFaces; i++) {
                for (int j = i + 1; j < numFaces; j++) {
                    final int finalI = i;
                    final int finalJ = j;
                    results.add(executor.submit((Callable<Void>) () -> {
                        computePairSimilarity(finalI, finalJ, numVertices, globalWorkSize);
                        return null;
                    }));
                }
            }
        } finally {
            executor.shutdown(); // no more tasks; already submitted ones still run
        }

        try {
            awaitAll(results);
        } finally {
            // Release Program, and memory objects. All tasks have finished here,
            // so nobody uses the buffers anymore.
            releaseSourceBuffers();
            releaseProgram();
        }
    }

    /**
     * Blocks until all given tasks are finished and updates the progress bar
     * as they complete. The waiting is not aborted by a thread interruption
     * (the interrupt flag is restored at the end instead) because the callers
     * release GPU buffers that running tasks still read.
     *
     * @param results futures of the submitted tasks
     */
    private void awaitAll(List<Future<Void>> results) {
        final java.util.logging.Logger log =
                java.util.logging.Logger.getLogger(ApproxHausdorffDistTaskGPU.class.getName());

        boolean interrupted = false;
        int finished = 0;
        int failures = 0;

        for (Future<Void> res: results) {
            boolean completed = false;
            while (!completed) {
                try {
                    res.get(); // blocks (no busy waiting) until the task is finished
                    completed = true;
                } catch (InterruptedException ex) {
                    interrupted = true; // get() cleared the flag; retry waiting
                } catch (ExecutionException ex) {
                    if (failures == 0) { // full stack trace only once to avoid flooding the log
                        log.log(Level.SEVERE, null, ex);
                    }
                    failures++;
                    completed = true;
                }
            }

            finished++;
            int progress = (int) Math.round(100.0 * finished / results.size());
            getProgressDialog().setValue(progress);
        }

        if (failures > 1) {
            log.log(Level.SEVERE, "{0} face pair computations failed", failures);
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Runs the OpenCL kernel for a single pair of faces and stores the average
     * of the finite results as the similarity of the pair (in both directions).
     *
     * @param i index of the first face
     * @param j index of the second face
     * @param numVertices number of values stored for each face
     * @param globalWorkSize OpenCL work-item dimensions
     */
    private void computePairSimilarity(int i, int j, int numVertices, long[] globalWorkSize) {
        final float[] dstArray = new float[numVertices];
        final long numBytes = (long) Sizeof.cl_float * numVertices;

        cl_mem dstMem = null;
        cl_kernel kernel = null;
        try {
            dstMem = CL.clCreateBuffer(context, CL.CL_MEM_READ_WRITE, numBytes, null, null);

            // Create the kernel (one per task; kernel arguments are not shareable between threads)
            kernel = CL.clCreateKernel(program, "hdKernel", null);

            // Set the arguments for the kernel
            CL.clSetKernelArg(kernel, 0, Sizeof.cl_mem, Pointer.to(srcArrays[i]));
            CL.clSetKernelArg(kernel, 1, Sizeof.cl_mem, Pointer.to(srcArrays[j]));
            CL.clSetKernelArg(kernel, 2, Sizeof.cl_mem, Pointer.to(dstMem));

            // Execute the kernel
            CL.clEnqueueNDRangeKernel(commandQueue, kernel, 1, null,
                    globalWorkSize, null, 0, null, null);

            // Read the output data (blocking read)
            CL.clEnqueueReadBuffer(commandQueue, dstMem, CL.CL_TRUE, 0,
                    numBytes, Pointer.to(dstArray), 0, null, null);
        } finally {
            // Release even if some CL call has thrown an exception
            if (kernel != null) {
                CL.clReleaseKernel(kernel);
            }
            if (dstMem != null) {
                CL.clReleaseMemObject(dstMem);
            }
        }

        // Compute average HD of the face pair
        double sum = 0.0;
        int counter = 0;
        for (int k = 0; k < numVertices; k++) {
            double d = dstArray[k];
            if (Double.isFinite(d)) {
                sum += d;
                counter++;
            }
        }
        setDistSimilarity(i, j, sum / counter);
        setDistSimilarity(j, i, sum / counter);
    }

    /**
     * Copies the given values into a newly created read-only GPU buffer.
     *
     * @param values values to be copied
     * @return the created OpenCL memory object
     */
    private cl_mem uploadDistances(float[] values) {
        return CL.clCreateBuffer(context,
                CL.CL_MEM_READ_ONLY | CL.CL_MEM_COPY_HOST_PTR,
                (long) Sizeof.cl_float * values.length, Pointer.to(values), null);
    }

    /**
     * Releases all cached distance buffers that have not been released yet.
     * Safe to call repeatedly.
     */
    private void releaseSourceBuffers() {
        if (srcArrays == null) {
            return;
        }
        for (int i = 0; i < srcArrays.length; i++) {
            if (srcArrays[i] != null) { // null = never created or already released
                CL.clReleaseMemObject(srcArrays[i]);
                srcArrays[i] = null;
            }
        }
    }

    /**
     * Releases the OpenCL program if it has not been released yet.
     * Safe to call repeatedly.
     */
    private void releaseProgram() {
        if (program != null) {
            CL.clReleaseProgram(program);
            program = null;
        }
    }
}
