package cz.fidentis.analyst.registration;

import cz.fidentis.analyst.Logger;
import cz.fidentis.analyst.canvas.Canvas;
import cz.fidentis.analyst.core.ControlPanelAction;
import cz.fidentis.analyst.face.HumanFace;
import cz.fidentis.analyst.face.HumanFaceFactory;
import cz.fidentis.analyst.feature.FeaturePoint;
import cz.fidentis.analyst.icp.IcpTransformation;
import cz.fidentis.analyst.icp.IcpTransformer;
import cz.fidentis.analyst.icp.NoUndersampling;
import cz.fidentis.analyst.icp.RandomStrategy;
import cz.fidentis.analyst.icp.UndersamplingStrategy;
import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshPoint;
import cz.fidentis.analyst.procrustes.ProcrustesAnalysis;
import cz.fidentis.analyst.visitors.mesh.HausdorffDistance;

import java.awt.Color;
import java.awt.event.ActionEvent;
import java.util.List;
import java.util.zip.DataFormatException;
import javax.swing.JCheckBox;
import javax.swing.JComboBox;
import javax.swing.JFormattedTextField;
import javax.swing.JSlider;
import javax.swing.JTabbedPane;
import javax.swing.JToggleButton;
import javax.vecmath.Point3d;
import javax.vecmath.Vector3d;

/**
 * Action listener for the registration control panel.
 * <p>
 * Handles manual transformations (shift, rotation, scale) of the secondary face,
 * ICP and Procrustes registration of the secondary face onto the primary face,
 * and the feature-point / Hausdorff-distance statistics and heatmap displayed in
 * {@link RegistrationPanel}.
 * </p>
 *
 * @author Richard Pajersky
 * @author Radek Oslejsek
 * @author Daniel Schramm
 */
public class RegistrationAction extends ControlPanelAction {

    /*
     * Attributes handling the state
     */
    /** Scale flag passed to {@link IcpTransformer}. */
    private boolean scale = true;
    /** Maximum number of ICP iterations. */
    private int maxIterations = 10;
    /** Error value passed to {@link IcpTransformer} (presumably a convergence threshold — confirm). */
    private double error = 0.3;
    /** Undersampling strategy used by ICP. */
    private UndersamplingStrategy undersampling = new RandomStrategy(200);
    /** Cached Hausdorff distance used for the heatmap; {@code null} forces recomputation. */
    private HausdorffDistance hdVisitor = null;
    /** Hausdorff distance strategy selected in the panel. */
    private String strategy = RegistrationPanel.STRATEGY_POINT_TO_POINT;
    /** Relative-distance flag passed to {@link HausdorffDistance} when computing the heatmap. */
    private boolean relativeDist = false;
    /** Whether the heatmap is rendered on the secondary face. */
    private boolean heatmapRender = false;
    /** Whether Procrustes analysis is allowed to scale a face. */
    private boolean procrustesScalingEnabled = false;


    /*
     * Coloring threshold and statistical values of feature point distances:
     */
    private double fpThreshold = 5.0;

    private final RegistrationPanel controlPanel;

    /**
     * Constructor.
     *
     * @param canvas OpenGL canvas
     * @param topControlPanel Top component for placing control panels
     */
    public RegistrationAction(Canvas canvas, JTabbedPane topControlPanel) {
        super(canvas, topControlPanel);
        this.controlPanel = new RegistrationPanel(this);

        // Place control panel to the topControlPanel
        topControlPanel.addTab(controlPanel.getName(), controlPanel.getIcon(), controlPanel);
        topControlPanel.addChangeListener(e -> {
            // If the registration panel is focused...
            if (((JTabbedPane) e.getSource()).getSelectedComponent() instanceof RegistrationPanel) {
                // ... display heatmap and feature points relevant to the registration
                getCanvas().getScene().setDefaultColors();
                calculateFeaturePoints();
                setHeatmap();
                getSecondaryDrawableFace().setRenderHeatmap(heatmapRender);
            }
        });
        topControlPanel.setSelectedComponent(controlPanel); // Focus registration panel
        calculateHausdorffDistance();
    }

    @Override
    public void actionPerformed(ActionEvent ae) {
        double value;
        String action = ae.getActionCommand();

        switch (action) {
            case RegistrationPanel.ACTION_COMMAND_APPLY_ICP:
                applyICP();
                calculateFeaturePoints();
                calculateHausdorffDistance();
                this.hdVisitor = null; // geometry changed, recompute the heatmap
                setHeatmap();
                break;
            case RegistrationPanel.ACTION_COMMAND_SHIFT_X:
                value = readDouble(ae);
                getSecondaryDrawableFace().getTranslation().x = value;
                if (getSecondaryFeaturePoints() != null) {
                    getSecondaryFeaturePoints().getTranslation().x = value;
                }
                calculateFeaturePoints();
                break;
            case RegistrationPanel.ACTION_COMMAND_SHIFT_Y:
                value = readDouble(ae);
                getSecondaryDrawableFace().getTranslation().y = value;
                if (getSecondaryFeaturePoints() != null) {
                    getSecondaryFeaturePoints().getTranslation().y = value;
                }
                calculateFeaturePoints();
                break;
            case RegistrationPanel.ACTION_COMMAND_SHIFT_Z:
                value = readDouble(ae);
                getSecondaryDrawableFace().getTranslation().z = value;
                if (getSecondaryFeaturePoints() != null) {
                    getSecondaryFeaturePoints().getTranslation().z = value;
                }
                calculateFeaturePoints();
                break;
            case RegistrationPanel.ACTION_COMMAND_ROTATE_X:
                value = readDouble(ae);
                getSecondaryDrawableFace().getRotation().x = value;
                if (getSecondaryFeaturePoints() != null) {
                    getSecondaryFeaturePoints().getRotation().x = value;
                }
                calculateFeaturePoints();
                break;
            case RegistrationPanel.ACTION_COMMAND_ROTATE_Y:
                value = readDouble(ae);
                getSecondaryDrawableFace().getRotation().y = value;
                if (getSecondaryFeaturePoints() != null) {
                    getSecondaryFeaturePoints().getRotation().y = value;
                }
                calculateFeaturePoints();
                break;
            case RegistrationPanel.ACTION_COMMAND_ROTATE_Z:
                value = readDouble(ae);
                getSecondaryDrawableFace().getRotation().z = value;
                if (getSecondaryFeaturePoints() != null) {
                    getSecondaryFeaturePoints().getRotation().z = value;
                }
                calculateFeaturePoints();
                break;
            case RegistrationPanel.ACTION_COMMAND_SCALE:
                value = readDouble(ae);
                getSecondaryDrawableFace().getScale().x = value;
                getSecondaryDrawableFace().getScale().y = value;
                getSecondaryDrawableFace().getScale().z = value;
                if (getSecondaryFeaturePoints() != null) {
                    getSecondaryFeaturePoints().getScale().x = value;
                    getSecondaryFeaturePoints().getScale().y = value;
                    getSecondaryFeaturePoints().getScale().z = value;
                }
                calculateFeaturePoints();
                break;
            case RegistrationPanel.ACTION_COMMAND_FRONT_VIEW:
                getCanvas().getCamera().initLocation();
                break;
            case RegistrationPanel.ACTION_COMMAND_SIDE_VIEW:
                getCanvas().getCamera().initLocation();
                getCanvas().getCamera().rotate(0, 90);
                break;
            case RegistrationPanel.ACTION_COMMAND_RESET_TRANSLATION:
                getSecondaryDrawableFace().setTranslation(new Vector3d(0, 0, 0));
                if (getSecondaryFeaturePoints() != null) {
                    getSecondaryFeaturePoints().setTranslation(new Vector3d(0, 0, 0));
                }
                calculateFeaturePoints();
                break;
            case RegistrationPanel.ACTION_COMMAND_RESET_ROTATION:
                getSecondaryDrawableFace().setRotation(new Vector3d(0, 0, 0));
                if (getSecondaryFeaturePoints() != null) {
                    getSecondaryFeaturePoints().setRotation(new Vector3d(0, 0, 0));
                }
                calculateFeaturePoints();
                break;
            case RegistrationPanel.ACTION_COMMAND_RESET_SCALE:
                // Scale is stored as an offset: transformPoint() multiplies by (1 + scale)
                getSecondaryDrawableFace().setScale(new Vector3d(0, 0, 0));
                if (getSecondaryFeaturePoints() != null) {
                    getSecondaryFeaturePoints().setScale(new Vector3d(0, 0, 0));
                }
                calculateFeaturePoints();
                break;
            case RegistrationPanel.ACTION_COMMAND_APPLY_TRANSFORMATIONS:
                // Bake the pending transformation into the geometry, then clear it
                transformFace();
                getSecondaryDrawableFace().setTranslation(new Vector3d(0, 0, 0));
                getSecondaryDrawableFace().setRotation(new Vector3d(0, 0, 0));
                getSecondaryDrawableFace().setScale(new Vector3d(0, 0, 0));
                if (getSecondaryFeaturePoints() != null) {
                    getSecondaryFeaturePoints().setTranslation(new Vector3d(0, 0, 0));
                    getSecondaryFeaturePoints().setRotation(new Vector3d(0, 0, 0));
                    getSecondaryFeaturePoints().setScale(new Vector3d(0, 0, 0));
                }
                calculateHausdorffDistance();
                this.hdVisitor = null; // vertices were physically moved, recompute the heatmap
                setHeatmap();
                break;
            case RegistrationPanel.ACTION_COMMAND_TRANSPARENCY:
                int transparency = ((JSlider) ae.getSource()).getValue();
                setTransparency(transparency);
                break;
            case RegistrationPanel.ACTION_COMMAND_FP_CLOSENESS_THRESHOLD:
                fpThreshold = readDouble(ae);
                calculateFeaturePoints();
                break;
            case RegistrationPanel.ACTION_COMMAND_ICP_SCALE:
                this.scale = ((JCheckBox) ae.getSource()).isSelected();
                break;
            case RegistrationPanel.ACTION_COMMAND_ICP_MAX_ITERATIONS:
                maxIterations = ((Number) (((JFormattedTextField) ae.getSource()).getValue())).intValue();
                break;
            case RegistrationPanel.ACTION_COMMAND_ICP_ERROR:
                error = readDouble(ae);
                break;
            case RegistrationPanel.ACTION_COMMAND_ICP_UNDERSAMPLING:
                String item = (String) ((JComboBox<?>) ae.getSource()).getSelectedItem();
                switch (item) {
                    case "None":
                        this.undersampling = new NoUndersampling();
                        break;
                    case "Random 200":
                        this.undersampling = new RandomStrategy(200);
                        break;
                    default:
                        throw new UnsupportedOperationException(item);
                }
                break;
            case RegistrationPanel.ACTION_COMMAND_HD_HEATMAP:
                heatmapRender = ((JToggleButton) ae.getSource()).isSelected();
                if (heatmapRender) {
                    calculateHausdorffDistance();
                    setHeatmap();
                }
                getSecondaryDrawableFace().setRenderHeatmap(heatmapRender);
                break;
            case RegistrationPanel.ACTION_COMMAND_HD_STRATEGY:
                strategy = (String) ((JComboBox<?>) ae.getSource()).getSelectedItem();
                this.hdVisitor = null; // recompute
                setHeatmap();
                break;
            case RegistrationPanel.ACTION_COMMAND_HD_RELATIVE_DIST:
                this.relativeDist = ((JToggleButton) ae.getSource()).isSelected();
                this.hdVisitor = null; // recompute
                setHeatmap();
                break;
            case RegistrationPanel.ACTION_COMMAND_PROCRUSTES:
                initiateProcrustesAnalysis();
                break;
            case RegistrationPanel.ACTION_COMMAND_PROCRUSTES_RESET:
                resetProcrustesFace();
                break;
            case RegistrationPanel.ACTION_COMMAND_PROCRUSTES_SCALING:
                this.procrustesScalingEnabled = !this.procrustesScalingEnabled;
                Logger.print("Procrustes: scaling enabled is " + this.procrustesScalingEnabled);
                break;
            default:
                // do nothing
        }

        renderScene();
    }

    /**
     * Reads the numeric value of the {@link JFormattedTextField} that fired the event.
     *
     * @param ae event whose source is a {@link JFormattedTextField} holding a {@link Number}
     * @return the field value as {@code double}
     */
    private static double readDouble(ActionEvent ae) {
        return ((Number) ((JFormattedTextField) ae.getSource()).getValue()).doubleValue();
    }

    /**
     * Runs Procrustes analysis of the primary and secondary face.
     * <p>
     * A {@link ProcrustesAnalysis} is created from both faces and its analysis step
     * superimposes and rotates the faces. If {@code procrustesScalingEnabled} is
     * {@code true}, one of them is scaled as well. On success, the feature-point and
     * Hausdorff-distance statistics and the heatmap are recomputed. A
     * {@link DataFormatException} is logged and not propagated.
     * </p>
     */
    protected void initiateProcrustesAnalysis() {
        Logger.print("Procrustes Analysis");
        Logger out = Logger.measureTime();

        HumanFace primaryFace = getScene().getHumanFace(0);
        HumanFace secondaryFace = getScene().getHumanFace(1);

        try {
            ProcrustesAnalysis procrustesAnalysis = new ProcrustesAnalysis(primaryFace, secondaryFace, this.procrustesScalingEnabled);
            procrustesAnalysis.analyze();
            calculateFeaturePoints();
            calculateHausdorffDistance();
            this.hdVisitor = null; // geometry changed, recompute the heatmap
            setHeatmap();
        } catch (DataFormatException e) {
            Logger.print("Procrustes Analysis experienced exception: " + e.getMessage());
        }

        out.printDuration("Computation of Procrustes for models with "
                + getPrimaryDrawableFace().getHumanFace().getMeshModel().getNumVertices()
                + "/"
                + getSecondaryDrawableFace().getHumanFace().getMeshModel().getNumVertices()
                + " vertices and "
                + getPrimaryDrawableFace().getHumanFace().getFeaturePoints().size()
                + " feature points"
        );
    }

    /**
     * Placeholder for restoring the secondary face to its state before Procrustes analysis.
     * Currently it only logs that the feature is not implemented.
     */
    private void resetProcrustesFace() {
        //TODO: create deep copy of second face and reset it on button press
        Logger.print("Reset faces after procrustes is not implemented yet.");
    }

    /**
     * Registers the secondary face to the primary face using ICP.
     * <p>
     * The secondary face mesh is physically transformed by the {@link IcpTransformer}.
     * The ICP transformations are then replayed on the secondary feature points,
     * if there are any, so that they follow the transformed mesh.
     * </p>
     */
    protected void applyICP() {
        Logger out = Logger.measureTime();

        IcpTransformer visitor = new IcpTransformer(getPrimaryDrawableFace().getModel(), maxIterations, scale, error, undersampling);
        getSecondaryDrawableFace().getModel().compute(visitor); // NOTE: the secondary face is physically transformed

        // Feature points are optional (guarded everywhere else); skip them when absent
        if (getSecondaryFeaturePoints() != null) {
            for (List<IcpTransformation> trList : visitor.getTransformations().values()) {
                for (IcpTransformation tr : trList) {
                    for (int i = 0; i < getSecondaryFeaturePoints().getFeaturePoints().size(); i++) {
                        FeaturePoint fp = getSecondaryFeaturePoints().getFeaturePoints().get(i);
                        Point3d trPoint = tr.transformPoint(fp.getPosition(), scale);
                        getSecondaryFeaturePoints().getFeaturePoints().set(i, new FeaturePoint(trPoint.x, trPoint.y, trPoint.z, fp.getFeaturePointType()));
                    }
                }
            }
        }

        out.printDuration("Computation of ICP for models with "
                + getPrimaryDrawableFace().getHumanFace().getMeshModel().getNumVertices()
                + "/"
                + getSecondaryDrawableFace().getHumanFace().getMeshModel().getNumVertices()
                + " vertices"
        );
    }

    /**
     * Sets the transparency of {@link #getPrimaryDrawableFace()} or
     * {@link #getSecondaryDrawableFace()} based on the slider value and
     * {@link RegistrationPanel#TRANSPARENCY_RANGE}.
     * <p>
     * Values below the range make the primary face transparent, values above it
     * make the secondary face transparent, and a value equal to it keeps both opaque.
     * </p>
     *
     * @param value slider value
     */
    private void setTransparency(int value) {
        if (value < RegistrationPanel.TRANSPARENCY_RANGE) {
            getPrimaryDrawableFace().setTransparency(value / 10f);
            getSecondaryDrawableFace().setTransparency(1);
        } else if (value > RegistrationPanel.TRANSPARENCY_RANGE) {
            getPrimaryDrawableFace().setTransparency(1);
            getSecondaryDrawableFace().setTransparency((2 * RegistrationPanel.TRANSPARENCY_RANGE - value) / 10f);
        } else {
            getPrimaryDrawableFace().setTransparency(1);
            getSecondaryDrawableFace().setTransparency(1);
        }
    }

    /**
     * Compares corresponding primary and secondary feature points and colours them.
     * <p>
     * The secondary points are first moved by the pending manual transformation
     * (see {@link #transformPoint(Point3d)}). Pairs closer than or equal to the
     * threshold are coloured green; pairs farther apart get their default colour.
     * Average, maximum and minimum distances are pushed to the control panel.
     * Nothing happens when the scene is not initialized, feature points are missing,
     * the two lists differ in size, or they are empty.
     * </p>
     */
    private void calculateFeaturePoints() {
        if (getPrimaryDrawableFace() == null) { // scene not yet initiated
            return;
        }

        if (getPrimaryFeaturePoints() == null
                || getSecondaryFeaturePoints() == null
                || getPrimaryFeaturePoints().getFeaturePoints().size() != getSecondaryFeaturePoints().getFeaturePoints().size()) {
            return;
        }

        int count = getPrimaryFeaturePoints().getFeaturePoints().size();
        if (count == 0) { // nothing to compare; avoids NaN/infinite statistics
            return;
        }

        double fpMaxDist = Double.NEGATIVE_INFINITY;
        double fpMinDist = Double.POSITIVE_INFINITY;
        double distSum = 0.0;
        for (int i = 0; i < count; i++) {
            FeaturePoint primary = getPrimaryFeaturePoints().getFeaturePoints().get(i);
            FeaturePoint secondary = getSecondaryFeaturePoints().getFeaturePoints().get(i);
            Point3d transformed = new Point3d(secondary.getX(), secondary.getY(), secondary.getZ());
            transformPoint(transformed);
            double distance = Math.sqrt(
                    Math.pow(transformed.x - primary.getX(), 2)
                    + Math.pow(transformed.y - primary.getY(), 2)
                    + Math.pow(transformed.z - primary.getZ(), 2));
            if (distance > fpThreshold) {
                getPrimaryFeaturePoints().resetColorToDefault(i);
                getSecondaryFeaturePoints().resetColorToDefault(i);
            } else {
                getPrimaryFeaturePoints().setColor(i, Color.GREEN);
                getSecondaryFeaturePoints().setColor(i, Color.GREEN);
            }
            fpMaxDist = Math.max(fpMaxDist, distance);
            fpMinDist = Math.min(fpMinDist, distance);
            distSum += distance;
        }
        double fpAvgDist = distSum / count;
        this.controlPanel.updateFPStats(fpAvgDist, fpMaxDist, fpMinDist);
    }

    /**
     * Computes the Hausdorff distance statistics of the secondary face against the primary one.
     * <p>
     * It uses the primary face k-d tree and the
     * {@code POINT_TO_POINT_DISTANCE_ONLY} strategy, then pushes the average,
     * maximum and minimum distances to the control panel. This does not
     * update the heatmap; see {@link #setHeatmap()}.
     * </p>
     */
    private void calculateHausdorffDistance() {
        Logger out = Logger.measureTime();

        HumanFace primFace = getScene().getHumanFace(0);
        primFace.computeKdTree(false);
        HausdorffDistance hd = new HausdorffDistance(primFace.getKdTree(), HausdorffDistance.Strategy.POINT_TO_POINT_DISTANCE_ONLY, false, true);
        getScene().getHumanFace(1).getMeshModel().compute(hd);

        out.printDuration("Computation of Hausdorff distance for models with "
                + getPrimaryDrawableFace().getHumanFace().getMeshModel().getNumVertices()
                + "/"
                + getSecondaryDrawableFace().getHumanFace().getMeshModel().getNumVertices()
                + " vertices"
        );

        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        double sum = 0;
        int count = 0;
        for (List<Double> distList : hd.getDistances().values()) {
            for (double dist : distList) {
                min = Math.min(min, dist);
                max = Math.max(max, dist);
                sum += dist;
                count++;
            }
        }

        this.controlPanel.updateHDStats(sum / count, max, min);
    }

    /**
     * Applies the pending manual transformation physically to the vertices of
     * {@link #getSecondaryDrawableFace()}, and then to its feature points if there are any.
     */
    private void transformFace() {
        for (MeshFacet transformedFacet : getSecondaryDrawableFace().getFacets()) {
            for (MeshPoint comparedPoint : transformedFacet.getVertices()) {
                transformPoint(comparedPoint.getPosition());
            }
        }

        if (getSecondaryFeaturePoints() == null) { // face without feature points
            return;
        }
        for (int i = 0; i < getSecondaryFeaturePoints().getFeaturePoints().size(); i++) {
            FeaturePoint point = getSecondaryFeaturePoints().getFeaturePoints().get(i);
            Point3d transformed = new Point3d(point.getX(), point.getY(), point.getZ());
            transformPoint(transformed);
            point = new FeaturePoint(transformed.x, transformed.y, transformed.z, point.getFeaturePointType());
            getSecondaryFeaturePoints().getFeaturePoints().set(i, point);
        }
    }

    /**
     * Transforms a point in place using the transformation of
     * {@link #getSecondaryDrawableFace()}.
     * <p>
     * The steps are applied in this order: rotation around X, Y and Z (angles in
     * degrees; a NaN angle is skipped), then translation, then scaling by
     * {@code 1 + scale} per axis.
     * </p>
     *
     * @param point point to transform; it is modified in place
     * @throws IllegalArgumentException if {@code point} is {@code null}
     */
    private void transformPoint(Point3d point) {
        if (point == null) {
            throw new IllegalArgumentException("point is null");
        }

        Point3d newPoint = new Point3d(0, 0, 0);
        double quotient;

        // rotate around X
        quotient = Math.toRadians(getSecondaryDrawableFace().getRotation().x);
        if (!Double.isNaN(quotient)) {
            double cos = Math.cos(quotient);
            double sin = Math.sin(quotient);
            newPoint.y = point.y * cos - point.z * sin;
            newPoint.z = point.z * cos + point.y * sin;
            point.y = newPoint.y;
            point.z = newPoint.z;
        }

        // rotate around Y
        quotient = Math.toRadians(getSecondaryDrawableFace().getRotation().y);
        if (!Double.isNaN(quotient)) {
            double cos = Math.cos(quotient);
            double sin = Math.sin(quotient);
            newPoint.x = point.x * cos + point.z * sin;
            newPoint.z = point.z * cos - point.x * sin;
            point.x = newPoint.x;
            point.z = newPoint.z;
        }

        // rotate around Z
        quotient = Math.toRadians(getSecondaryDrawableFace().getRotation().z);
        if (!Double.isNaN(quotient)) {
            double cos = Math.cos(quotient);
            double sin = Math.sin(quotient);
            newPoint.x = point.x * cos - point.y * sin;
            newPoint.y = point.y * cos + point.x * sin;
            point.x = newPoint.x;
            point.y = newPoint.y;
        }

        // translate
        point.x += getSecondaryDrawableFace().getTranslation().x;
        point.y += getSecondaryDrawableFace().getTranslation().y;
        point.z += getSecondaryDrawableFace().getTranslation().z;

        // scale (stored as an offset from 1)
        point.x *= 1 + getSecondaryDrawableFace().getScale().x;
        point.y *= 1 + getSecondaryDrawableFace().getScale().y;
        point.z *= 1 + getSecondaryDrawableFace().getScale().z;
    }

    /**
     * Sets the Hausdorff-distance heatmap on the secondary face.
     * <p>
     * The distances are recomputed only when the cached {@code hdVisitor} is
     * {@code null}. Recomputation uses the strategy selected in the panel and
     * the relative-distance flag.
     * </p>
     *
     * @throws UnsupportedOperationException if the selected strategy is unknown
     */
    protected void setHeatmap() {
        HausdorffDistance.Strategy useStrategy;
        switch (strategy) {
            case RegistrationPanel.STRATEGY_POINT_TO_POINT:
                useStrategy = HausdorffDistance.Strategy.POINT_TO_POINT;
                break;
            case RegistrationPanel.STRATEGY_POINT_TO_TRIANGLE:
                useStrategy = HausdorffDistance.Strategy.POINT_TO_TRIANGLE_APPROXIMATE;
                break;
            default:
                throw new UnsupportedOperationException(strategy);
        }

        if (hdVisitor == null) {
            Logger out = Logger.measureTime();

            this.hdVisitor = new HausdorffDistance(getPrimaryDrawableFace().getModel(), useStrategy, relativeDist, true);
            getSecondaryDrawableFace().getModel().compute(hdVisitor);

            out.printDuration("Computation of Hausdorff distance for models with "
                    + getPrimaryDrawableFace().getHumanFace().getMeshModel().getNumVertices()
                    + "/"
                    + getSecondaryDrawableFace().getHumanFace().getMeshModel().getNumVertices()
                    + " vertices"
            );
        }

        getSecondaryDrawableFace().setHeatMap(hdVisitor.getDistances());
    }

}
