import { SwarmPlotCanvas } from '@nivo/swarmplot';
import classNames from 'classnames';
import { ShapChartData } from 'common/dist/types/reportData/shapChartV1ReportData';
import React, { FC, useCallback, useMemo } from 'react';

import styles from './styles.module.scss';
import { useDimensions, useThemeColor } from '../../../../../../../utils';
import { ReportElementProps } from '../../../types/meta';
import { ShapChartConfig, ShapChartReportData } from '../type';

/**
 * Per-feature SHAP distribution shape.
 *
 * NOTE(review): this type is not referenced elsewhere in this file; the chart
 * itself consumes `ShapChartData` from `common`. Presumably kept for external
 * consumers — confirm before removing or changing it.
 */
export type DataType = {
  /** Name of the feature. */
  feature: string;
  /** Whether the feature is numerical or categorical. */
  featureType: 'numerical' | 'categorical';
  /**
   * Sample pairs. Presumably `[featureValue, shapValue]`, mirroring how this
   * file reads `ShapChartData.distribution` — TODO confirm.
   */
  distribution: [number, number][];
};

/** One swarm-plot datum: a single (feature value, SHAP value) sample of one feature. */
type FlattenedElement = {
  /** `<feature>_<index>`, where index is the sample's position within that feature's distribution */
  id: string;
  /** Name of the feature; also the swarm-plot group this datum belongs to */
  feature: string;
  /** Value of the feature for this sample (drives the node colour) */
  featureVal: number;
  /** SHAP value for this sample (drives the node position on the value axis) */
  shapVal: number;
};

/**
 * Props of `ShapChart`: the report data plus the element's config fields.
 * Note that `ShapChart` only reads `data`; the config fields are accepted but unused.
 */
export type Props = { data: ShapChartReportData } & ShapChartConfig;

/**
 * Parses a `#rrggbb` or `#rgb` colour into its red, green and blue channels.
 * Surrounding whitespace is ignored. Malformed input is not validated and
 * yields NaN channels.
 */
function parseHexColor(color: string): [number, number, number] {
  let hex = color.trim();
  // Expand the CSS shorthand "#rgb" to "#rrggbb".
  if (hex.length === 4) {
    hex =
      '#' +
      hex
        .slice(1)
        .split('')
        .map((c) => c + c)
        .join('');
  }
  const channel = (start: number) =>
    parseInt(hex.substring(start, start + 2), 16);
  return [channel(1), channel(3), channel(5)];
}

/**
 * Linearly interpolates between two hex colours.
 *
 * @param color1 - Colour returned at `ratio = 0` (`#rrggbb` or `#rgb`).
 * @param color2 - Colour returned at `ratio = 1` (`#rrggbb` or `#rgb`).
 * @param ratio - Interpolation position. It is clamped to [0, 1]; NaN (e.g. from
 *   a 0/0 normalisation) is treated as 0.
 * @returns The interpolated colour as a lowercase `#rrggbb` string.
 */
function calculateColor(color1: string, color2: string, ratio: number): string {
  // Clamping keeps every channel inside 0..255. Mapping NaN to 0 stops callers
  // from ever receiving the invalid colour "#NaNNaNNaN".
  const t = Number.isNaN(ratio) ? 0 : Math.min(1, Math.max(0, ratio));

  // Use Math.round rather than Math.ceil, so floating-point noise such as
  // 100 * 0.3 === 30.000000000000004 cannot bump a channel up by one.
  const mix = (a: number, b: number) => Math.round(a * (1 - t) + b * t);
  const toHex = (x: number) => x.toString(16).padStart(2, '0');

  const [r1, g1, b1] = parseHexColor(color1);
  const [r2, g2, b2] = parseHexColor(color2);

  return `#${toHex(mix(r1, r2))}${toHex(mix(g1, g2))}${toHex(mix(b1, b2))}`;
}

/** Height in px allotted to each feature row of the swarm plot. */
const LINE_HEIGHT = 100;
/** Gradient colour used for the lowest value of each feature. */
const LOW_VALUE_COLOR = '#dce2eb';

/** Inclusive numeric bounds. */
type ValueRange = { min: number; max: number };

/**
 * Computes the min and max of `pick(item)` over `items` in a single pass.
 *
 * This replaces `Math.min(...values)`, because spreading a very large array into
 * call arguments can exceed the engine's argument/stack limit. NaN values are
 * skipped. An empty input yields `{ min: Infinity, max: -Infinity }`.
 */
function valueRange<T>(
  items: readonly T[],
  pick: (item: T) => number
): ValueRange {
  let min = Infinity;
  let max = -Infinity;
  for (const item of items) {
    const value = pick(item);
    if (value < min) min = value;
    if (value > max) max = value;
  }
  return { min, max };
}

/**
 * Flattens the per-feature distributions into one swarm-plot datum per
 * `[featureVal, shapVal]` pair.
 */
function flattenData(data: ShapChartData[]): FlattenedElement[] {
  return data.flatMap((singleFeature) =>
    singleFeature.distribution.map(([featureVal, shapVal], i) => ({
      id: `${singleFeature.feature}_${i}`,
      feature: singleFeature.feature,
      featureVal,
      shapVal,
    }))
  );
}

/**
 * Computes the min/max feature value (first element of each distribution pair)
 * for every feature, keyed by feature name. If a feature name repeats, the
 * last occurrence wins.
 */
function calcMinMaxFeatureValPerFeature(
  data: ShapChartData[]
): Map<string, ValueRange> {
  const ranges = new Map<string, ValueRange>();
  for (const singleFeature of data) {
    ranges.set(
      singleFeature.feature,
      valueRange(singleFeature.distribution, (point) => point[0])
    );
  }
  return ranges;
}

/**
 * Beeswarm-style SHAP summary chart.
 *
 * Each feature gets one horizontal row and each (feature value, SHAP value)
 * sample gets one dot. A dot is positioned by its SHAP value. Its colour runs
 * along a gradient from `LOW_VALUE_COLOR` (the feature's lowest value) to the
 * theme's `primary-highlight` colour (the feature's highest value).
 *
 * Only `data` is read; the remaining `ShapChartConfig` props are ignored.
 * Renders nothing when `data` is empty.
 */
export const ShapChart: FC<Props> = ({ data }) => {
  const [ref, { width }] = useDimensions<HTMLDivElement>(3);
  const highValueColor = useThemeColor('primary-highlight');

  const flattenedData = useMemo(() => flattenData(data), [data]);

  // When there are no points at all, fall back to a degenerate [0, 0] scale
  // instead of [Infinity, -Infinity].
  const shapRange = useMemo(
    () =>
      flattenedData.length === 0
        ? { min: 0, max: 0 }
        : valueRange(flattenedData, (d) => d.shapVal),
    [flattenedData]
  );

  const minMaxPerFeature = useMemo(
    () => calcMinMaxFeatureValPerFeature(data),
    [data]
  );

  const calculateNodeColor = useCallback(
    (node: { data: FlattenedElement }) => {
      const range = minMaxPerFeature.get(node.data.feature);
      // A feature whose values are all equal has no low/high distinction, and
      // dividing by its zero span would give NaN. Use the gradient midpoint instead.
      if (!range || range.max === range.min) {
        return calculateColor(LOW_VALUE_COLOR, highValueColor, 0.5);
      }
      const ratio = (node.data.featureVal - range.min) / (range.max - range.min);
      return calculateColor(LOW_VALUE_COLOR, highValueColor, ratio);
    },
    [minMaxPerFeature, highValueColor]
  );

  // Early return comes after all hooks so the hook order stays stable.
  if (data.length === 0) {
    return null;
  }

  const ordering = data.map((f) => f.feature);
  // FIXME-CM not sure why, but subtracting 16 is necessary to prevent a
  // resizing issue. Clamp at 0 so a not-yet-measured container (width 0 or
  // undefined) never yields a negative canvas width.
  const chartWidth = Math.max(0, (width || 0) - 16);

  return (
    <div ref={ref} className={styles.container}>
      <div className={styles.shapChart}>
        <SwarmPlotCanvas
          height={LINE_HEIGHT * ordering.length}
          width={chartWidth}
          data={flattenedData}
          groups={ordering}
          groupBy={'feature'}
          value={'shapVal'}
          valueFormat={'.2f'}
          valueScale={{
            type: 'linear',
            min: shapRange.min,
            max: shapRange.max,
            reverse: false,
          }}
          size={2}
          colors={calculateNodeColor}
          borderWidth={3}
          borderColor={{ from: 'color' }}
          layout={'horizontal'}
          forceStrength={4}
          simulationIterations={60}
          margin={{
            top: 0,
            right: 150,
            bottom: 80,
            left: 5, // Otherwise some points are cut off
          }}
          axisTop={null}
          axisRight={{
            tickSize: 10,
            tickPadding: 5,
            tickRotation: 0,
          }}
          axisBottom={{
            tickSize: 10,
            tickPadding: 5,
            tickRotation: 0,
            legend: 'SHAP Value',
            legendPosition: 'middle',
            legendOffset: 46,
          }}
          axisLeft={null}
          enableGridX={true}
          enableGridY={true}
          // NOTE(review): the ts-ignore suggests these motion props are not in
          // this @nivo/swarmplot version's typings. Check whether they still
          // have any effect or should be replaced (e.g. by `motionConfig`).
          // @ts-ignore
          motionStiffness={50}
          motionDamping={10}
          tooltip={(node) => (
            <div className={styles.tooltip}>
              <p>
                <b>Feature Value:</b> {node.data.featureVal}
              </p>
              <p>
                <b>SHAP Value:</b> {(node.data.shapVal || 0).toFixed(2)}
              </p>
            </div>
          )}
        />
      </div>

      <div className={styles.legend}>
        <div className={classNames(styles.label, styles.left)}>
          <span>Low Feature Value</span>
        </div>
        <div
          className={styles.gradient}
          style={{
            backgroundImage: `linear-gradient(to right, ${LOW_VALUE_COLOR} , ${highValueColor})`,
          }}
        ></div>
        <div className={styles.label}>
          <span>High Feature Value</span>
        </div>
      </div>
    </div>
  );
};

/**
 * Report-element adapter for `ShapChart`.
 *
 * Uses `input.reportValue` as the chart data, then forwards the element
 * config followed by every remaining prop. Because of that spread order,
 * later props override earlier ones.
 */
export const ShapChartSingle: FC<
  ReportElementProps<ShapChartReportData, ShapChartConfig>
> = (props) => {
  const { input, config, ...passThrough } = props;
  const chartData = input.reportValue;
  return <ShapChart data={chartData} {...config} {...passThrough} />;
};
