import React, { useCallback, useMemo } from 'react';
import { CustomLayer, Point } from '@nivo/line';
import { flatMap } from 'lodash';

/**
 * Point marker: a white-filled circle of radius 5 outlined in the colour of
 * the serie it belongs to, positioned at (`x`, `y`).
 *
 * Both the wrapper group and the circle set `pointerEvents="none"`.
 */
const Circle = (props: Pick<Point, 'serieColor' | 'x' | 'y' | 'id'>) => {
  const { x, y, serieColor } = props;
  const transform = `translate(${x}, ${y})`;

  return (
    <g transform={transform} pointerEvents="none">
      <circle
        r={5}
        fill="#fff"
        stroke={serieColor}
        strokeWidth={2}
        pointerEvents="none"
      />
    </g>
  );
};

/**
 * Custom layer that draws the line, area and points layers itself.
 *
 * The only difference from the regular layers is how the second serie is
 * positioned: its y value is stacked on top of the first serie's value at the
 * same index, and a missing (null) value in the first serie counts as 0.
 *
 * What gets rendered:
 * - a dashed line through the stacked points of the second serie,
 * - one filled area per run of consecutive non-null points, for both series,
 * - a circle for every point whose y is not null.
 *
 * The dashed line and the areas are only produced when exactly two series
 * are present; circles are drawn for whatever series exist.
 */
const Graphs: CustomLayer = ({
  series,
  lineGenerator,
  yScale,
  innerHeight
}) => {
  // Scaled position of two raw values stacked on top of each other.
  const getStackedY = useCallback(
    // @ts-expect-error wrong type from nivo
    (y1: number, y2: number) => yScale(y1 + y2),
    [yScale]
  );

  // Raw y values of the first serie, with missing values replaced by 0 so
  // they can always be used as the base of the stack.
  const baseValues = useMemo(() => {
    if (!series[0]) return [];

    return series[0].data.map((datum) =>
      datum.data.y === null || datum.data.y === undefined ? 0 : datum.data.y
    );
  }, [series]);

  // One array of drawable points per serie.
  const points = useMemo(() => {
    if (!series.length) return [];

    return series.map((serie, serieIndex) =>
      serie.data.map((datum, index) => ({
        ...datum.position,
        y:
          // Stack only for the second set of data
          serieIndex === 1
            ? datum.data.y !== null
              ? getStackedY(
                  datum.data.y as number,
                  // The first serie may be shorter than the second one;
                  // stack on 0 instead of reading past its end.
                  (index < baseValues.length ? baseValues[index] : 0) as number
                )
              : null
            : datum.position.y,
        serieColor: serie.color,
        id: `${serie.id}-${index}`
      }))
    );
  }, [series, getStackedY, baseValues]);

  // Path data of the (stacked) second serie.
  const path = useMemo(() => {
    // `d` is a string attribute, so fall back to an empty path string.
    if (points.length !== 2) return '';

    return lineGenerator(
      points[1].map((datum) => ({
        x: datum.x,
        y: datum.y
      }))
    );
  }, [lineGenerator, points]);

  // For each serie, the list of closed outlines to fill. A null point ends
  // the current outline, so gaps in the data are left unfilled.
  const areas = useMemo(() => {
    if (points.length !== 2) return [];

    return points.map((pointSet) => {
      const result = [];
      let group = [];

      pointSet.forEach((point, index) => {
        if (point.y !== null) {
          group.push(point);
          // Keep collecting unless this was the very last point
          if (index !== pointSet.length - 1) return;
        }

        // Reached a gap or the end of the serie: close the current group,
        // then empty it for a new group
        if (group.length > 0) {
          // Add closing points down at innerHeight on both ends
          result.push([
            { x: group[0].x, y: innerHeight },
            ...group,
            { x: group[group.length - 1].x, y: innerHeight }
          ]);
          group = [];
        }
      });

      return result;
    });
  }, [points, innerHeight]);

  return (
    <>
      <path
        d={path}
        fill="none"
        strokeWidth={2}
        stroke={series[1] ? series[1].color : ''}
        style={{
          strokeDasharray: '4',
          strokeWidth: 2
        }}
      />
      {areas.map((group, serieIndex) =>
        group.map((areaPoints, index) => (
          <path
            pointerEvents="none"
            key={`${serieIndex}-${index}`}
            d={lineGenerator(areaPoints)}
            fill={series[serieIndex] ? series[serieIndex].color : ''}
            fillOpacity="0.2"
          />
        ))
      )}
      {flatMap(points)
        .filter((point) => point.y !== null)
        .map((point) => (
          <Circle key={point.id} {...point} />
        ))}
    </>
  );
};

// Default export: the `Graphs` custom layer (typed as nivo's `CustomLayer`).
export default Graphs;
