import { AnyD3Scale, createScale } from "@visx/scale";
import { castArray, map, zip } from "lodash";
import { createContext, useContext, useMemo } from "react";

import { useSvgSize } from "../common/SvgSize";
import {
  domainForScaleType,
  useDeepCompareMemo,
  useFlatMemo,
  useId,
} from "../common/util";
import type { ScaleConfig } from "../types";
import { AnySeries, Series, useThemedSeries } from "./Series";

/* eslint-disable @typescript-eslint/no-explicit-any */

/**
 * Shape of the value published through XYDataContext and returned by
 * useXYData.
 */
export interface XYData {
  /** All x scales used in this chart; a series picks its own via xScaleIdx. */
  xScales: AnyD3Scale[];
  /** All y scales used in this chart. */
  yScales: AnyD3Scale[];
  /** All data series used in this chart, keyed by each series' seriesKey. */
  series: Record<string, Series>;
  /**
   * Are all series empty? A series counts as empty when none of its defined
   * data points falls inside the domain of the x scale it is plotted on.
   */
  hasNoData: boolean;
}

// Context carrying the chart's XYData. The default is undefined so that
// useXYData can tell when it is called without a provider above it.
const XYDataContext = createContext<XYData | undefined>(undefined);

/**
 * Hook to retrieve the current chart's series and scales.
 *
 * @returns The XYData value read from XYDataContext.
 * @throws Error when the context holds no value, i.e. the hook is used
 *   outside of XYDataProvider.
 */
export const useXYData = () => {
  const xyData = useContext(XYDataContext);
  if (xyData) return xyData;
  throw new Error("useXYData must be called inside XYDataProvider");
};

/** Props accepted by XYDataProvider. */
export interface XYDataProviderProps {
  /**
   * Config for the x scale, or for several x scales. A config without its own
   * `range` gets `[0, width]` from the SVG size.
   */
  xScale: ScaleConfig | readonly ScaleConfig[];
  /**
   * Config for the y scale, or for several y scales. A config without its own
   * `range` gets `[height, 0]` from the SVG size.
   */
  yScale: ScaleConfig | readonly ScaleConfig[];
  /** Series to plot; each one selects its scales via xScaleIdx / yScaleIdx. */
  series: Series[];
  /** Subtree that gets access to the chart data through useXYData. */
  children: React.ReactNode;
}

/**
 * Creates the domain for a scale from all series that will use the scale.
 *
 * @param scale Config of the scale; only its `type` is read here.
 * @param series The series plotted against that scale.
 * @param accessorFactory Given a series, returns the accessor that extracts
 *   the relevant value from one of that series' data points.
 * @returns The result of domainForScaleType applied to the pooled per-series
 *   domains.
 */
const makeCombinedDomain = <T extends ScaleConfig>(
  scale: T,
  series: Series[],
  accessorFactory: <S extends AnySeries>(s: S) => (d: S["data"][number]) => any,
) => {
  // A domain over all series can be built in two steps: first one domain per
  // series, then the domain of those domains (e.g. extent(series.map(extent))).
  const perSeriesDomains = series.map((current) =>
    domainForScaleType(scale.type, current.data, accessorFactory(current)),
  );
  const pooledValues = perSeriesDomains.flat();
  // NOTE(review): sort() without a comparator orders entries by their string
  // form. Whether that order matters depends on what domainForScaleType does
  // with it for the given scale type -- confirm before changing.
  pooledValues.sort();
  return domainForScaleType(scale.type, pooledValues, (value) => value);
};

/**
 * Converts a hard-coded time-scale domain written with date strings into Date
 * objects.
 *
 * @param type Scale type; only "time" domains are touched.
 * @param domain Hard-coded domain taken from the scale config.
 * @returns For a "time" scale whose domain contains at least one string, a new
 *   array with every entry passed through `new Date(...)`; in every other case
 *   `domain` itself, unchanged.
 */
const fixTimeDomain = (
  type: ScaleConfig["type"],
  domain: NonNullable<ScaleConfig["domain"]>,
) => {
  // Inspect every entry rather than just the first one: a mixed domain such
  // as [Date, "2021-01-01"] would otherwise keep its string bound.
  const hasStringEntry =
    type === "time" &&
    domain &&
    (domain as readonly unknown[]).some((entry) => typeof entry === "string");
  if (!hasStringEntry) return domain;
  // Every entry goes through the Date constructor (as before); entries that
  // already are Dates come out as Dates with the same time value.
  return (domain as [string, string]).map((date) => new Date(date));
};

/** Arguments of useScales. */
interface UseScalesParams {
  /** Configs of the x scales; a series refers to one by index (xScaleIdx). */
  xScales: ScaleConfig[];
  /** Configs of the y scales; a series refers to one by index (yScaleIdx). */
  yScales: ScaleConfig[];
  /** Series whose data determines every domain that is not hard-coded. */
  series: Series[];
}

/**
 * Constructs memoized d3 scales given scale configs and the chart's data.
 *
 * @param params.xScales Configs of the x scales.
 * @param params.yScales Configs of the y scales.
 * @param params.series Series used to derive domains that are not hard-coded.
 * @returns An object `{ xScales, yScales }` holding one created scale per
 *   config, in the same order as the configs.
 */
const useScales = ({ xScales, yScales, series }: UseScalesParams) => {
  /* eslint-disable arrow-body-style */
  /* eslint-disable react-hooks/exhaustive-deps */

  // X domain is based on either a hard-coded domain, or the full extent of the
  // data that appears on the given scale.
  // The dependency list deliberately names only the configs' `type` and
  // `domain` (plus the series), i.e. the only config fields read in here.
  // NOTE(review): useFlatMemo presumably flattens the nested dependency
  // arrays before comparing them -- confirm in ../common/util.
  const xDomains = useFlatMemo(() => {
    return xScales.map((scale, scaleIdx) => {
      if (scale.domain) return fixTimeDomain(scale.type, scale.domain);
      const seriesForScale = series.filter((s) => s.xScaleIdx === scaleIdx);
      return makeCombinedDomain(scale, seriesForScale, (s) => s.x);
    });
  }, [series, map(xScales, "type"), map(xScales, "domain")]);
  // Y domain is based on either a hard-coded domain, or the full extent of the
  // data that appears on the given scale -- but we also need to exclude points
  // that fall outside the _X_ domain, in case we have a hard-coded x domain
  // (e.g. adjusting a date range for zooming).
  const yDomains = useFlatMemo(() => {
    return yScales.map((scale, scaleIdx) => {
      if (scale.domain) return fixTimeDomain(scale.type, scale.domain);
      const seriesForScale = series.filter((s) => s.yScaleIdx === scaleIdx);
      return makeCombinedDomain(scale, seriesForScale, (s) => {
        // Points the series reports as outside its x domain contribute null
        // instead of their y value.
        const xDomain = xDomains[s.xScaleIdx];
        return (d) => (s.definedInXDomain(d, xDomain) ? s.y(d) : null);
      });
    });
  }, [series, map(yScales, "type"), map(yScales, "domain"), xDomains]);
  // The domain calculation is the tricky (and expensive) part, hence the
  // separate memoized calculation above. This memo hook captures any other
  // scale config changes.
  return useDeepCompareMemo(() => {
    return {
      // Pair every config with its computed domain; the computed domain
      // replaces whatever `domain` the config carried.
      xScales: zip(xScales, xDomains).map(([config, domain]: any) =>
        createScale({ ...config, domain }),
      ),
      yScales: zip(yScales, yDomains).map(([config, domain]: any) => {
        const scale = createScale({ ...config, domain });
        if ("interpolate" in scale && domain.every((x: unknown) => x === 0)) {
          // If domain start = end this means we have a "collapsed" domain. D3
          // used to scale these to the range's start value, but it changed it
          // to use the midpoint of the range in 2.2. We want to preserve the
          // old behavior in certain situations: y axes where the only value is
          // 0, so that the 0 tick remains at the bottom of the axis instead of
          // the middle.
          // See https://github.com/d3/d3-scale/issues/117
          scale.interpolate((rangeStart, _rangeEnd) => (_val) => rangeStart);
        }
        return scale;
      }),
    };
  }, [xScales, yScales, xDomains, yDomains]);

  /* eslint-enable arrow-body-style */
  /* eslint-enable react-hooks/exhaustive-deps */
};

/**
 * Transforms series into an object, keyed by seriesKey.
 *
 * The series are first passed through useThemedSeries; the lookup object is
 * rebuilt only when that hook hands back a different value.
 */
const useSeriesMap = (series: Series[]): Record<string, Series> => {
  const themed = useThemedSeries(series);
  return useMemo(() => {
    const entries = themed.map((item) => [item.seriesKey, item] as const);
    return Object.fromEntries(entries);
  }, [themed]);
};

/**
 * Reports whether a series has nothing to show within `domain`: true when
 * none of its defined data points passes the series' own definedInXDomain
 * check (which includes a series without any defined points).
 */
const seriesIsEmpty = (series: Series, domain: any) => {
  const hasPointInDomain = series.definedData.some((point) =>
    series.definedInXDomain(point, domain),
  );
  return !hasPointInDomain;
};

/**
 * Provider for chart series and scales.
 *
 * Reads the SVG dimensions through useSvgSize, so it has to be rendered as a
 * descendant of whatever supplies them (presumably SvgSizeProvider).
 */
export const XYDataProvider = ({
  xScale,
  yScale,
  series,
  children,
}: XYDataProviderProps) => {
  const { width, height } = useSvgSize();
  const seriesMap = useSeriesMap(series);

  // The default ranges span the SVG. Because the config is spread last, a
  // `range` set on the config itself takes precedence over the default.
  const xConfigs: ScaleConfig[] = castArray(xScale).map((config) => ({
    range: [0, width],
    ...config,
  }));
  const yConfigs: ScaleConfig[] = castArray(yScale).map((config) => ({
    range: [height, 0],
    ...config,
  }));
  const { xScales, yScales } = useScales({
    xScales: xConfigs,
    yScales: yConfigs,
    series,
  });

  const idPrefix = useId();
  const value = useMemo(() => {
    // The chart has no data when every series is empty within the domain of
    // the x scale it is plotted against.
    const hasNoData = Object.values(seriesMap).every((s) =>
      seriesIsEmpty(s, xScales[s.xScaleIdx].domain()),
    );
    return {
      series: seriesMap,
      xScales,
      yScales,
      idPrefix,
      hasNoData,
    };
  }, [seriesMap, xScales, yScales, idPrefix]);

  return (
    <XYDataContext.Provider value={value}>{children}</XYDataContext.Provider>
  );
};
