import React, { JSX, useContext, useState } from 'react';
import { Bar } from '@visx/shape';
import { Group } from '@visx/group';
import {
  BAR_MIN_HEIGHT,
  BAR_OPACITY_BLUR,
  BAR_RADIUS,
  BAR_SELECTED_STROKE_WIDTH,
  BarStackGroupProps,
  SPACE_BETWEEN_BARS
} from '../types';
import { groupByDate } from '../helpers';
import { XYChartContext } from '../XYChart/context';

/**
 * Renders one stacked bar per x-axis group, stacking each group's items
 * vertically with a `SPACE_BETWEEN_BARS` pixel gap between segments.
 *
 * Must be rendered inside an `XYChartContext` provider; returns `null` when no
 * context is available.
 *
 * Interaction:
 * - Hovering a segment dims every segment whose legend key differs (opacity
 *   `BAR_OPACITY_BLUR`) and calls `onMouseEnter` with tooltip data/position;
 *   leaving clears the focus and calls `onMouseLeave`.
 * - Clicking toggles a single selection identified by (legend key, group
 *   index) and calls `onClick(datum, isNowSelected)`. Selected segments get a
 *   stroke from `getSelectedColor`.
 *
 * Segments whose scaled height is below `minBarHeight` (default
 * `BAR_MIN_HEIGHT`) are drawn at that minimum height, and later segments in
 * the same stack are shifted by the added height so they do not overlap.
 */
export function BarStackGroup<T>(props: BarStackGroupProps<T>): JSX.Element | null {
  const {
    data,
    getLegendKey,
    getXValue,
    getYValue,
    getSelectedColor,
    onMouseEnter,
    onMouseLeave,
    onClick,
    maxBarWidth,
    minBarHeight: externalMinBarHeight,
    barStackSorter
  } = props;
  const [focus, setFocus] = useState<string | null>(null);
  const [selection, setSelection] = useState<string | null>(null);

  // Every hook is called above this guard so the hook order is stable across renders.
  const context = useContext(XYChartContext);
  if (!context) return null;

  const { xScale, yScale, colorScale, margin } = context;

  const groupData = groupByDate(data, getXValue);
  const minBarHeight = externalMinBarHeight ?? BAR_MIN_HEIGHT;

  // Horizontal geometry is the same for every segment, so compute it once per render.
  const bandwidth = xScale.bandwidth();
  const barWidth = maxBarWidth !== undefined ? Math.min(bandwidth, maxBarWidth) : bandwidth;
  // When the bar is capped by maxBarWidth, center it inside its band.
  const xOffset = maxBarWidth !== undefined && bandwidth > maxBarWidth ? (bandwidth - maxBarWidth) / 2 : 0;
  const barStyle = onClick ? { cursor: 'pointer' } : undefined;

  return (
    <Group>
      {groupData.map((group, i) => {
        // Running data-space total of the segments already stacked in this group.
        let cumulativeValue = 0;
        // Accumulated pixel lift: the gaps between segments plus any height added to
        // reach minBarHeight by earlier segments.
        let yOffset = SPACE_BETWEEN_BARS;
        const barStackData = barStackSorter ? barStackSorter(group.data) : group.data;
        const barX = (xScale(group.date) ?? 0) + xOffset;
        return (
          <Group key={`bar-stack-${i}`}>
            {barStackData.map((d, j) => {
              const value = getYValue(d);
              const legendKey = getLegendKey(d);
              // Selection identity is the legend key within this particular group.
              const selectedKey = `${legendKey}_${i}`;
              const isSelected = selection === selectedKey;
              const fillColor = colorScale(legendKey);
              const strokeColor = isSelected ? getSelectedColor?.(legendKey) : undefined;
              const strokeWidth = isSelected ? BAR_SELECTED_STROKE_WIDTH : undefined;
              const barHeightActual = yScale(0) - yScale(value);
              const barHeight = Math.max(barHeightActual, minBarHeight);

              // Lift this segment (and, via the accumulator, all later ones) by the
              // height that was added to satisfy minBarHeight.
              yOffset += Math.max(barHeight - barHeightActual, 0);

              const barY = yScale(cumulativeValue + value) - yOffset;

              cumulativeValue += value;
              yOffset += SPACE_BETWEEN_BARS;

              // Explicit null check: an empty-string legend key is a valid focus value
              // and must still dim the other segments.
              const isBlurred = focus !== null && focus !== legendKey;

              return (
                <Bar
                  style={barStyle}
                  key={`bar-${i}-${j}`}
                  x={barX}
                  y={barY}
                  rx={BAR_RADIUS}
                  ry={BAR_RADIUS}
                  width={barWidth}
                  height={barHeight}
                  fill={fillColor}
                  strokeWidth={strokeWidth}
                  stroke={strokeColor}
                  opacity={isBlurred ? BAR_OPACITY_BLUR : '100%'}
                  onMouseLeave={() => {
                    setFocus(null);
                    onMouseLeave?.();
                  }}
                  onMouseEnter={() => {
                    setFocus(legendKey);

                    onMouseEnter?.({
                      tooltipData: d,
                      tooltipTop: barY - margin.top + margin.bottom,
                      tooltipLeft: barX + margin.left + barWidth
                    });
                  }}
                  onClick={() => {
                    setSelection(isSelected ? null : selectedKey);
                    onClick?.(d, !isSelected);
                  }}
                />
              );
            })}
          </Group>
        );
      })}
    </Group>
  );
}

export default BarStackGroup;
