import { extent } from 'd3-array'
import { axisBottom, axisLeft } from 'd3-axis'
import { format } from 'd3-format'
import { scaleBand, scaleLinear } from 'd3-scale'
import { select } from 'd3-selection'
import type { FC } from 'react'
import { useContext, useEffect, useMemo, useRef, useState } from 'react'
import SelectedContext from '../../ToolUsage/SelectedContext'
import useSVGDimensions from '../../useSVGDimensions'
import {
  DEFAULT_PLOT_MARGIN,
  LABEL_NOT_SELECTED_OPACITY,
  LABEL_SELECTED_OPACITY,
  LEGEND_ELEMENT_SIZE,
  NOTHING_SELECTED_OPACITY,
  NOT_SELECTED_OPACITY,
  SELECTED_OPACITY,
} from '../../utilities'
import DownloadSnapshot from '../DownloadSnapshot'
import LegendText from '../LegendText'
import type {
  BarPlotProps,
  PlotDataElement,
  UnresponsivePlotProps,
} from '../types'
import Bar from './Bar'
import BarGroup from './BarGroup'

/**
 * Padding passed to `scaleBand().padding()` for the x axis; it sets both the
 * inner and the outer padding of the bands as a fraction of the band step.
 */
const BAR_PADDING = 0.3

/**
 * Bar chart drawn with React (bars, grid, legend) and d3 (axes).
 *
 * - Without `groupNames`: one `Bar` per data element and one x-axis band per
 *   distinct element name. When `selectedGroupName` is set, its tick label
 *   uses `LABEL_SELECTED_OPACITY` and the other labels use
 *   `LABEL_NOT_SELECTED_OPACITY`.
 * - With `groupNames`: one x-axis band per group name. The data elements whose
 *   `groupName` matches are rendered together by `BarGroup`, and a legend of
 *   the distinct element names (colour swatch + label) is drawn.
 *
 * Clicking a bar toggles the selected tool in `SelectedContext`, and bar
 * opacity reflects the current selection. The plot margins depend on the
 * rendered sizes of the axes and the legend, which `useSVGDimensions` measures
 * after they are drawn.
 */
const BarPlot: FC<UnresponsivePlotProps<BarPlotProps>> = ({
  id,
  width,
  height: componentHeight,
  data,
  selectedGroupName,
  colorScale,
  groupNames,
}) => {
  const svgRef = useRef<SVGSVGElement>(null)
  const xAxisRef = useRef<SVGGElement>(null)
  const yAxisRef = useRef<SVGGElement>(null)
  const legendRef = useRef<SVGGElement>(null)

  const { selectedReducer } = useContext(SelectedContext)
  const [selectedState, selectedDispatch] = selectedReducer

  /**
   * Opacity for a bar:
   * - NOTHING_SELECTED_OPACITY when no tool is selected;
   * - SELECTED_OPACITY when the selected tool's name equals the element's name;
   * - NOT_SELECTED_OPACITY otherwise.
   *
   * NOTE(review): this matches by name, while handleClick matches by tool id.
   * Confirm that names are unique per tool.
   */
  const getOpacity = (element: PlotDataElement) => {
    if (!selectedState.tool) return NOTHING_SELECTED_OPACITY
    if (selectedState.tool.name === element.name) return SELECTED_OPACITY
    return NOT_SELECTED_OPACITY
  }

  /**
   * Toggle selection. Clicking the bar of the currently selected tool clears
   * the selection; clicking any other bar selects its tool.
   */
  const handleClick = (element: PlotDataElement) => {
    if (
      selectedState.tool &&
      element.tool &&
      selectedState.tool.id === element.tool.id
    ) {
      selectedDispatch({ type: 'selectTool', tool: undefined })
      return
    }

    selectedDispatch({ type: 'selectTool', tool: element.tool })
  }

  // These values drive re-measurement of the axes: when they change,
  // useSVGDimensions is given a new dependency and the margins can adapt.
  const [maxTickLength, setMaxTickLength] = useState(0)
  const [maxTickValue, setMaxTickValue] = useState(0)

  const { height: xAxisHeight } = useSVGDimensions(xAxisRef, [maxTickLength])
  const { width: yAxisWidth } = useSVGDimensions(yAxisRef, [maxTickValue])
  // NOTE(review): the legend content is derived from the data names, but it is
  // only re-measured when groupNames changes. Confirm that this is sufficient.
  const { width: legendWidth, height: legendHeight } = useSVGDimensions(
    legendRef,
    [groupNames]
  )

  const margin = {
    left: DEFAULT_PLOT_MARGIN + yAxisWidth,
    right: groupNames ? DEFAULT_PLOT_MARGIN + legendWidth : DEFAULT_PLOT_MARGIN,
    top: DEFAULT_PLOT_MARGIN,
    bottom: xAxisHeight,
  }

  // 30px of the component height is not given to the SVG. Presumably this
  // reserves room for the download-button row rendered below it (TODO confirm).
  const height = Math.max(componentHeight - 30, 0)
  // bounds = area inside the graph axis = calculated by subtracting the margins
  const boundsWidth = Math.max(width - margin.right - margin.left, 0)
  const boundsHeight = Math.max(height - margin.top - margin.bottom, 0)

  // Memoized so that xScale (and the axis effect that depends on it) is not
  // rebuilt on every render just because data.map produced a new array.
  const groups = useMemo(
    () => groupNames || data.map(dataElement => dataElement.name),
    [groupNames, data]
  )
  const xScale = useMemo(
    () =>
      scaleBand().domain(groups).range([0, boundsWidth]).padding(BAR_PADDING),
    [boundsWidth, groups]
  )

  const yScale = useMemo(() => {
    const [, max] = extent(data.map(d => d.value))
    // `||` on purpose: an empty data set (undefined) or an all-zero maximum
    // would otherwise give a degenerate [0, 0] domain.
    return scaleLinear()
      .domain([0, max || 10])
      .range([boundsHeight, 0])
  }, [boundsHeight, data])

  // Draw the axes with d3, then record the values their rendered sizes depend
  // on so that the margins are re-measured.
  useEffect(() => {
    if (!xAxisRef.current || !yAxisRef.current) return

    const xAxisTicks = select(xAxisRef.current)
      .call(axisBottom(xScale))
      // rotate the ticks so that the labels do not overlap
      .selectAll('text')
      .style('text-anchor', 'end')
      .attr('dx', '0.5em')
      .attr('dy', '1em')
      .attr('transform', 'rotate(-30)')

    if (!groupNames) {
      // Highlight the label of the selected group. The axis draws one label per
      // band-scale domain value, in domain order. The domain is deduplicated,
      // so bind it rather than the raw data names, which could contain
      // duplicates and shift the highlighting onto the wrong labels.
      xAxisTicks
        .data(xScale.domain())
        .style('opacity', groupName =>
          !selectedGroupName || groupName === selectedGroupName
            ? LABEL_SELECTED_OPACITY
            : LABEL_NOT_SELECTED_OPACITY
        )
    }

    // update the bottom margin
    setMaxTickLength(Math.max(...xScale.domain().map(tick => tick.length), 0))

    // The labels are formatted as integers ('d'), so only whole-number ticks
    // are kept to avoid repeated labels.
    const integerTicks = yScale.ticks().filter(tick => Number.isInteger(tick))
    const yAxisGenerator = axisLeft(yScale)
      .tickValues(integerTicks)
      .tickFormat(format('d'))

    // update the left margin
    setMaxTickValue(Math.max(...integerTicks, 0))

    select(yAxisRef.current).call(yAxisGenerator)
  }, [groupNames, selectedGroupName, xScale, yScale])

  // Horizontal grid lines at about 5 y ticks. The first tick (the 0 baseline,
  // since the domain starts at 0) is skipped.
  const grid = yScale
    .ticks(5)
    .slice(1)
    .map(value => (
      <g key={value}>
        <line
          x1={0}
          x2={boundsWidth}
          y1={yScale(value)}
          y2={yScale(value)}
          stroke='#808080'
          opacity={LABEL_NOT_SELECTED_OPACITY / 2}
        />
      </g>
    ))

  const bars = groupNames
    ? groupNames.map(g => {
        const x = xScale(g)
        // A band scale returns undefined for values outside its domain. 0 is a
        // valid band start, so test for undefined rather than falsiness.
        if (x === undefined) return <g key={g} />
        return (
          <BarGroup
            key={g}
            data={data.filter(dataElement => dataElement.groupName === g)}
            colorScale={colorScale}
            range={[x, x + xScale.bandwidth()]}
            onClick={handleClick}
            getOpacity={getOpacity}
            yScale={yScale}
            height={boundsHeight}
          />
        )
      })
    : data.map(d => (
        <Bar
          key={d.name}
          dataElement={d}
          colorScale={colorScale}
          onClick={() => {
            handleClick(d)
          }}
          opacity={getOpacity(d)}
          xScale={xScale}
          yScale={yScale}
          height={boundsHeight}
        />
      ))

  // One legend entry (colour swatch + label) per distinct element name.
  const legendItems = Array.from(new Set(data.map(d => d.name))).map(name => (
    <g key={name}>
      <rect
        width={LEGEND_ELEMENT_SIZE}
        height={LEGEND_ELEMENT_SIZE}
        x={-LEGEND_ELEMENT_SIZE / 2}
        y={-LEGEND_ELEMENT_SIZE / 2}
        fill={colorScale(name)}
      />
      <LegendText text={name} />
    </g>
  ))

  // Offset the legend by its measured size from the far end of the right
  // margin (horizontally) and from the bottom of the plot area (vertically).
  const legendX = margin.left + boundsWidth + margin.right - legendWidth
  const legendY = margin.top + boundsHeight - legendHeight

  const legend = (
    <g ref={legendRef} transform={`translate(${legendX}, ${legendY})`}>
      {legendItems.map((item, i) => (
        <g key={i} transform={`translate(0, ${i * 20})`}>
          {item}
        </g>
      ))}
    </g>
  )

  return (
    <>
      <svg
        id={id}
        ref={svgRef}
        style={{ fontFamily: 'sans-serif', display: 'block' }}
        width={width}
        height={height}
      >
        <g
          width={boundsWidth}
          height={boundsHeight}
          transform={`translate(${margin.left}, ${margin.top})`}
        >
          {grid}
          {bars}
          <g ref={xAxisRef} transform={`translate(0,${boundsHeight})`} />
          <g ref={yAxisRef} />
        </g>
        {groupNames ? legend : <g />}
      </svg>

      <div style={{ display: 'flex', justifyContent: 'space-between' }}>
        <div />
        <DownloadSnapshot elementId={id} />
      </div>
    </>
  )
}

export default BarPlot
