import * as d3 from 'd3';
import { Calibration, BasicCalibration } from '../../types';
import * as amp from '../../amp';
import { UnsafeSharedValue, assert, randomInt } from '../../util';
import { addScrubber } from '.';
import * as _ from 'remeda';
import { stackedSeries, type StackedSeries } from './data';

/**
 * Turn a calibration history into positive and negative stacked series,
 * one series per basic package calibration class.
 *
 * Each basic calibration is modelled as a step (a rectangle spanning from
 * when it was created until it changes again):
 * - At `created - 1` ms, every class is emitted with its previous amount.
 * - From `created` onwards, the calibration's class carries the new amount.
 * - A final point per class is appended at `end`, so the last step extends
 *   to the end of the period.
 * Classes that have not been calibrated yet start at 0.
 *
 * Positive and negative amounts are stacked separately (each clamped at 0).
 * This lets positive adjustments stack upwards from zero and negative ones
 * downwards.
 *
 * @param calibrations - Full calibration list. Only the entries returned by
 *   `amp.filterBasicPackageCalibrations` are used.
 * @param end - End of the charted period.
 * @returns `{ pos, neg }`: stacked series of the clamped positive and
 *   negative parts.
 */
const transform_to_stacks = (calibrations: Calibration[], end: Date) => {
  const cal_classes = amp.basic_package_calibration_classes;
  // Running amount per class; every class starts at 0.
  const latest_calibration: Record<string, number> = Object.fromEntries(
    _.zip(cal_classes, Array(cal_classes.length).fill(0)),
  );

  // The step construction below only makes sense in chronological order,
  // so sort a copy instead of relying on the order we were handed.
  const basic_calibrations = [
    ...amp.filterBasicPackageCalibrations(calibrations),
  ].sort((a, b) => a.created.getTime() - b.created.getTime());

  // One point per class at `timestamp`, using the current running amounts.
  const snapshot = (timestamp: number) =>
    cal_classes.map((cls) => ({
      timestamp,
      class: cls,
      amount: latest_calibration[cls],
    }));

  const calibration_points = basic_calibrations.flatMap((c) => {
    const created = c.created.getTime();
    // Close the previous step 1 ms before this calibration takes effect...
    const before = snapshot(created - 1);
    latest_calibration[amp.calibrationClass(c)] = c.amount;
    // ...and open the new step at the calibration's creation time.
    const after = snapshot(created);
    return [...before, ...after];
  });

  // Extend the last step to the end of the period.
  calibration_points.push(...snapshot(end.getTime()));

  const stack = (clamp: (amount: number) => number) =>
    stackedSeries<number>('class', 'timestamp', 'amount', cal_classes)(
      calibration_points.map((p) => ({ ...p, amount: clamp(p.amount) })),
    );

  return {
    pos: stack((amount) => Math.max(amount, 0)),
    neg: stack((amount) => Math.min(amount, 0)),
  };
};

// We calculate the totals from the stacks because the stacks
// have all the points we need to create the flat spans of time corresponding
// to the duration that a calibration was set for.
type TotalPoint = {
  timestamp: number;
  amount: number;
};

/**
 * Combine the per-class stacked series into a single "total" series.
 *
 * At each point index `i`, each series' height (`span[1] - span[0]`) is
 * treated as a fractional adjustment `a`. The adjustments are compounded
 * multiplicatively: `amount = Π(1 + a) - 1`. The timestamp is taken from
 * the first series.
 *
 * NOTE(review): this assumes every series has the same index at position
 * `i`. `transform_to_stacks` emits one point per class per timestamp, but
 * confirm that `stackedSeries` preserves that alignment.
 *
 * @param pos_stack - Stacked series of positive parts.
 * @param neg_stack - Stacked series of negative parts (same length).
 * @returns One point per index; empty when there are no series.
 * @throws If the stacks differ in length or the series differ in point count.
 */
const total = (
  pos_stack: StackedSeries<number>[],
  neg_stack: StackedSeries<number>[],
): TotalPoint[] => {
  assert(
    pos_stack.length === neg_stack.length,
    'Stacks must be of the same length',
  );

  const paths = [
    ...pos_stack.map(({ values }) => values),
    ...neg_stack.map(({ values }) => values),
  ];
  // No series at all: nothing to total (and no paths[0] to read).
  if (paths.length === 0) {
    return [];
  }

  const point_count = paths[0].length;
  // Every series is indexed with the same `i` below.
  assert(
    paths.every((path) => path.length === point_count),
    'All series must have the same number of points',
  );

  return paths[0].map(({ index }, i) => ({
    timestamp: index,
    amount:
      paths.reduce(
        (acc, path) => acc * (1 + (path[i].span[1] - path[i].span[0])),
        1,
      ) - 1,
  }));
};

/**
 * Build a renderer for the package calibration history chart.
 *
 * The outer call fixes the chart's pixel size. The returned function renders
 * one chart for `calibrations` between `from` and `to` and returns the
 * detached `<svg>` element. The chart contains:
 * - per-class calibration amounts as stacked areas (positive and negative
 *   parts stacked separately), clipped to the plotting area;
 * - a horizontal rule at zero;
 * - a line for the compounded total calibration (see `total`);
 * - a time x-axis and a percentage y-axis;
 * - a scrubber (via `addScrubber`, wired to `selectedDate`). Its x position
 *   is snapped to the nearest calibration's `created` date.
 *
 * @param width - Total SVG width in pixels, including the y-axis gutter.
 * @param height - Total SVG height in pixels, including the x-axis gutter.
 * @returns A function `(data, selectedDate) => SVGSVGElement`.
 */
const packageCalibrationHistorySvg =
  ({ width, height }: { width: number; height: number }) =>
  (
    {
      calibrations,
      from,
      to,
    }: {
      calibrations: Calibration[];
      from: Date;
      to: Date;
    },
    selectedDate: UnsafeSharedValue<Date>,
  ) => {
    // Layout: a gutter on the left for the y-axis and one at the bottom for
    // the x-axis. `cheight` is the height of the plotting area above it.
    const x_axis_height = 20;
    const y_axis_width = 40;
    const cheight = height - x_axis_height;
    // Headroom added above and below the data extent on the y scale.
    const y_padding = 0.1;

    const svg = d3.create('svg');
    svg.attr('width', width).attr('height', height);

    const x = d3.scaleTime().domain([from, to]).range([y_axis_width, width]);

    const stacks = transform_to_stacks(calibrations, to);
    const total_calibration_path = total(stacks.pos, stacks.neg);

    // The compounded total can lie outside the stacked areas: two +50%
    // adjustments stack to +100% but compound to +125%. Include the total in
    // the extent so its line is not clipped. `?? 0` covers the empty case,
    // where d3.min/d3.max return undefined.
    const y_lo = Math.min(
      d3.min(stacks.neg, (d) => d.limit) ?? 0,
      d3.min(total_calibration_path, (d) => d.amount) ?? 0,
    );
    const y_hi = Math.max(
      d3.max(stacks.pos, (d) => d.limit) ?? 0,
      d3.max(total_calibration_path, (d) => d.amount) ?? 0,
    );
    const y = d3
      .scaleLinear()
      .domain([y_lo - y_padding, y_hi + y_padding])
      .range([cheight, 0]);

    // Unique id so that several charts on one page don't share a clip path.
    const clip_id = 'clip-package-' + randomInt();

    // Clip drawn data to the plotting area (excluding the axis gutters).
    svg
      .append('defs')
      .append('svg:clipPath')
      .attr('id', clip_id)
      .append('rect')
      .attr('x', y_axis_width)
      .attr('y', 0)
      .attr('width', width - y_axis_width)
      .attr('height', cheight);

    // Draw one filled area per stacked series, coloured by its class name.
    const append_area = (
      data: StackedSeries<number>[],
      opts: { opacity: number },
    ) => {
      svg
        .append('g')
        .attr('clip-path', `url(#${clip_id})`)
        .selectAll('.area')
        .data(data)
        .enter()
        .append('path')
        .attr('class', 'area')
        .attr('opacity', opts.opacity)
        .attr('fill', (d) => amp.calibrationColor(d.name) as string)
        .attr('d', ({ values }) =>
          d3
            .area<typeof values[0]>()
            .x((d) => x(d.index))
            .y0((d) => y(d.span[0]))
            .y1((d) => y(d.span[1]))(values),
        );
    };

    append_area(stacks.neg, { opacity: 0.5 });
    append_area(stacks.pos, { opacity: 0.5 });

    // Zero line.
    svg
      .append('line')
      .attr('stroke', 'white')
      .attr('stroke-width', 1.5)
      .style('opacity', 1)
      .attr('x1', y_axis_width)
      .attr('x2', width)
      .attr('y1', y(0))
      .attr('y2', y(0));

    // Compounded total calibration line.
    svg
      .append('path')
      .attr('clip-path', `url(#${clip_id})`)
      .datum(total_calibration_path)
      .attr('fill', 'none')
      .attr('stroke', amp.amp_adjustment_style.color)
      .attr('stroke-width', amp.amp_adjustment_style.stroke_width)
      .attr('stroke-opacity', amp.amp_adjustment_style.opacity)
      .attr(
        'd',
        d3
          .line<TotalPoint>()
          .x((d) => x(d.timestamp))
          .y((d) => y(d.amount)),
      );

    // x-axis along the bottom of the plotting area.
    svg
      .append('g')
      .attr('transform', `translate(0,${cheight})`)
      .call(
        d3
          .axisBottom(x)
          .ticks(width / 80)
          .tickSizeOuter(0),
      );

    // y-axis in the left gutter, formatted as whole percentages.
    svg
      .append('g')
      .attr('transform', `translate(${y_axis_width},0)`)
      .call(
        d3
          .axisLeft(y)
          .ticks(height / 30)
          .tickFormat(d3.format('.0%')),
      );

    // Snap a hovered date to the `created` date of the nearest calibration.
    // On ties, the earlier entry in `calibrations` wins.
    // NOTE(review): this searches *all* calibrations, not only the basic
    // package ones drawn here. Confirm that this is intended.
    const snapToClosestCalibration = (date: Date): Date => {
      // Nothing to snap to: follow the pointer instead of letting `reduce`
      // throw on an empty array.
      if (calibrations.length === 0) {
        return date;
      }
      const target = date.getTime();
      const distance = (c: Calibration) =>
        Math.abs(c.created.getTime() - target);
      return calibrations.reduce((best, c) =>
        distance(c) < distance(best) ? c : best,
      ).created;
    };

    addScrubber(
      svg,
      (x_coord) => snapToClosestCalibration(x.invert(x_coord)),
      selectedDate,
      {
        vertical_rule: {
          y1: 0,
          y2: cheight,
        },
        region: {
          x: y_axis_width,
          y: 0,
          width: width - y_axis_width,
          height: cheight,
        },
        invert: (date) => ({ x: x(date) }),
      },
    );

    return svg.node() as SVGSVGElement;
  };

export { packageCalibrationHistorySvg };
