import { localPoint } from "@visx/event";
import type { AnyD3Scale } from "@visx/scale";
import { every, isNil, mapValues, minBy, omitBy } from "lodash";
import { RefObject, useCallback, useRef } from "react";

import HoverDetector, { HoverHandler } from "../common/HoverDetector";
import { useSvgSize } from "../common/SvgSize";
import {
  TooltipData,
  TooltipDatumBase,
  useTooltip,
  useTooltipSetter,
} from "../common/Tooltip";
import { invertScale } from "../common/util";
import type { AnySeries, Series } from "./Series";
import { useXYData } from "./XYData";

/**
 * Tooltip entry for one series: the datum found for the cursor position plus
 * the series metadata needed to render it.
 *
 * Builds on `TooltipDatumBase`, parameterized with the series' datum type
 * (`T["definedData"][number]`).
 */
export interface XYTooltipDatum<T extends AnySeries = AnySeries>
  extends TooltipDatumBase<T["definedData"][number]> {
  /** Key of the series this datum belongs to. */
  seriesKey: T["seriesKey"];
  /** The series' color, when it has one. */
  color?: T["color"];
  // we're guaranteeing that the data is defined at this point, so it should
  // not have null x or y values
  x: NonNullable<ReturnType<T["x"]>>;
  y: NonNullable<ReturnType<T["y"]>>;
}

/**
 * Props received by a tooltip component: `TooltipData` parameterized with an
 * array of `XYTooltipDatum`.
 *
 * `T` may be a single series type or an array of series types; for an array,
 * the datum type is built from the element type `T[number]`.
 */
export type TooltipProps<T extends AnySeries | AnySeries[]> = TooltipData<
  XYTooltipDatum<T extends AnySeries[] ? T[number] : T>[]
>;

/** Type for a Component that renders tooltips. */
export type TooltipComponent<T extends AnySeries | AnySeries[]> = (
  props: TooltipProps<T>,
) => JSX.Element | null;

/** `useTooltip` specialized to the XY tooltip datum array (`XYTooltipDatum[]`). */
export const useXYTooltip = useTooltip<XYTooltipDatum[]>;

// The already-inverted x, y point for the primary axes (xScales[0] and
// yScales[0]), i.e. in data-domain units rather than pixels. Either
// coordinate may be missing when the cursor position could not be inverted;
// consumers treat a missing coordinate as "close the tooltip".
interface TooltipPoint {
  x0?: number | Date;
  y0?: number | Date;
}

/**
 * Finds the datum of `series` closest (along x) to the screen x-coordinate
 * `x`, and projects it back to screen space.
 *
 * Returns `undefined` when the cursor cannot be inverted, the series has no
 * nearby datum, the datum is not defined within the x-scale's domain, or the
 * datum lies more than `tolerancePx` pixels away horizontally. Band scales are
 * centred on their band. `y` is accepted for call-site symmetry but unused.
 */
const nearestPoint = ({
  series,
  xScale,
  yScale,
  x,
  tolerancePx,
}: {
  series: Series;
  xScale: AnyD3Scale;
  yScale: AnyD3Scale;
  x: number;
  y: number;
  tolerancePx: number;
}) => {
  const domainX = invertScale(xScale, x);
  if (domainX == null) return undefined;

  const candidate = series.xNearest(domainX);
  if (candidate == null) return undefined;

  const datumX = series.x(candidate);
  const datumY = series.y(candidate);

  // Offset by half a band so band scales (e.g. bars) report the band centre.
  const halfBandX = "bandwidth" in xScale ? xScale.bandwidth() / 2 : 0;
  const halfBandY = "bandwidth" in yScale ? yScale.bandwidth() / 2 : 0;
  const screenX = xScale(datumX) + halfBandX;
  const screenY = yScale(datumY) + halfBandY;

  // Written as `!(… <= …)` rather than `… >` so NaN distances are rejected.
  const outOfRange =
    !series.definedInXDomain(candidate, xScale.domain()) ||
    !(Math.abs(x - screenX) <= tolerancePx);
  if (outOfRange) return undefined;

  return {
    seriesKey: series.seriesKey,
    label: series.label,
    color: series.color,
    datum: candidate,
    x: datumX,
    y: datumY,
    screenX,
    screenY,
  };
};

/**
 * Builds the hover/leave handlers that drive the XY tooltip.
 *
 * On hover, the cursor is inverted through the primary (index 0) x/y scales
 * and mapped back to the scales' range. Then the nearest datum of every series
 * is looked up, and the tooltip is opened on the datum closest to the cursor.
 *
 * @param thisRef - element that cursor coordinates are measured relative to.
 * @param options.tolerance - max horizontal distance between the cursor and a
 *   series' nearest datum, as a fraction of the SVG width from `useSvgSize`.
 * @returns `onHover` and `onLeave` handlers suitable for `HoverDetector`.
 */
const useXYTooltipCallbacks = (
  thisRef: RefObject<SVGElement | null>,
  { tolerance }: { tolerance: number },
) => {
  const setTooltip = useTooltipSetter<XYTooltipDatum[]>();
  const { margin, width } = useSvgSize();
  const { xScales, yScales, series } = useXYData();
  const tolerancePx = Math.round(tolerance * width);

  const updateTooltip = useCallback(
    (point?: TooltipPoint) => {
      if (point?.x0 == null || point?.y0 == null) {
        setTooltip({ isOpen: false });
        return;
      }
      // Convert x0 and y0 to cursor positions
      const x = xScales[0](point.x0);
      const y = yScales[0](point.y0);
      // Check before x/y are used for any lookups, not afterwards.
      if (x == null || y == null) {
        setTooltip({ isOpen: false });
        return;
      }
      // Find the closest data point to the cursor for each series
      const seriesData = mapValues(series, (s) =>
        nearestPoint({
          series: s,
          xScale: xScales[s.xScaleIdx],
          yScale: yScales[s.yScaleIdx],
          x,
          y,
          tolerancePx,
        }),
      );
      if (every(seriesData, isNil)) {
        // Close the tooltip if we are hovering over the chart but are nowhere
        // near any data points.
        setTooltip({ isOpen: false });
        return;
      }
      type SeriesPoint = NonNullable<(typeof seriesData)[string]>;
      const candidates = Object.values(seriesData).filter(
        (p): p is SeriesPoint => p != null,
      );
      // Only rank defined points. If no distance is comparable (e.g. NaN
      // screen coordinates), minBy yields undefined; fall back to the first
      // defined point rather than publishing an undefined `nearest`.
      const nearest =
        minBy(candidates, (p) => Math.hypot(x - p.screenX, y - p.screenY)) ??
        candidates[0];
      if (nearest === undefined) {
        setTooltip({ isOpen: false });
        return;
      }
      setTooltip({
        isOpen: true,
        data: {
          cursor: { x, y },
          nearest,
          series: omitBy(seriesData, isNil) as unknown as Record<
            string,
            SeriesPoint
          >,
        },
        left: x,
        top: y,
      });
    },
    [series, xScales, yScales, setTooltip, tolerancePx],
  );

  const onHover = useCallback<HoverHandler>(
    (event) => {
      if (!thisRef.current) {
        setTooltip({ isOpen: false });
        return;
      }
      // We need to be explicit about the element this handler is for,
      // otherwise we could end up with a point that is slightly off if this is
      // a hover event that targets a text element!
      const cursor = localPoint(thisRef.current, event);
      if (cursor) {
        // Shift into the plotting area before inverting through the scales.
        const x0 = invertScale(xScales[0], cursor.x - margin.left);
        const y0 = invertScale(yScales[0], cursor.y - margin.top);
        updateTooltip({ x0, y0 });
      } else {
        updateTooltip(undefined);
      }
    },
    [
      thisRef,
      setTooltip,
      xScales,
      yScales,
      margin.left,
      margin.top,
      updateTooltip,
    ],
  );

  const onLeave = useCallback<HoverHandler>(() => {
    updateTooltip(undefined);
  }, [updateTooltip]);

  return { onHover, onLeave };
};

/**
 * An svg `rect` that captures cursor events and updates TooltipContext.
 *
 * `tolerance` is the max horizontal cursor-to-datum distance as a fraction of
 * the chart width; it defaults to 0.2 (20%).
 */
export const XYTooltipHoverDetector = (props: XYTooltipHoverDetectorProps) => {
  const { width, height, tolerance = 0.2 } = props;
  const rectRef = useRef<SVGRectElement>(null);
  const callbacks = useXYTooltipCallbacks(rectRef, { tolerance });

  return (
    <HoverDetector
      ref={rectRef}
      width={width}
      height={height}
      onHover={callbacks.onHover}
      onLeave={callbacks.onLeave}
    />
  );
};

/** Props for `XYTooltipHoverDetector`. */
export interface XYTooltipHoverDetectorProps {
  /** Width of the hover-capturing rect, passed through to `HoverDetector`. */
  width: number;
  /** Height of the hover-capturing rect, passed through to `HoverDetector`. */
  height: number;
  /**
   * Max horizontal distance between the cursor and a series' nearest datum,
   * as a fraction of the SVG width. Defaults to 0.2 (20%).
   */
  tolerance?: number;
}
