import { useCallback, useEffect, useRef, useState } from "react";
import DrawableCanvas from "../../components/DrawableCanvas";
import Matrix from "../../components/Matrix";
import { useWindowSize } from "@react-hook/window-size";
import { ScreenSizes } from "../../data";
import { predict } from "./predict";
import ExperimentExplanation from "./ExperimentExplanation";
import { loadImageAndTransformToPixels } from "../../utils/listUtils";
import { useOnnxModel } from "../../utils/onnxUtils";
import Button from "../../components/Button";

/**
 * Loads one of the ten example digit images (`/experiments/mnist/numbers/0.png`
 * … `9.png`) at random and pushes it into the drawable canvas. The example that
 * is currently shown is never picked again.
 *
 * @param setDrawableCanvasImage - receives the pixel data of the chosen example
 * @param currentExample - path of the example currently on screen; it is skipped
 * @param setCurrentExample - receives the path of the example that was loaded
 */
async function loadRandomImage(
  setDrawableCanvasImage: (image: number[][][] | null) => void,
  currentExample: string,
  setCurrentExample: (example: string) => void
): Promise<void> {
  // Re-roll the digit until the candidate differs from what is already displayed.
  let candidatePath: string;
  do {
    const digit = Math.floor(Math.random() * 10);
    candidatePath = `/experiments/mnist/numbers/${digit}.png`;
  } while (candidatePath === currentExample);

  const pixels = await loadImageAndTransformToPixels(candidatePath);
  setDrawableCanvasImage(pixels);
  setCurrentExample(candidatePath);
}

// Output names passed to `predict`, in network order. The trailing comment on
// each entry gives that tensor's dimensions.
const tensorNames = [
  "onnx::MaxPool_31", // 1,32,32,32
  "onnx::MaxPool_35", // 1,64,16,16
  "onnx::MaxPool_39", // 1,128,8,8
  "input.40", // 1,10,4,4
  "45", // 1,10 -> this is not used in the matrices state
];

// Heading rendered above each activation column (one per entry of `matrices`).
const columnDemoNames = [
  "1st Block Activations",
  "2nd Block Activations",
  "3rd Block Activations",
  "Last Act.",
];

// Index of the final activation column (3). The component gives only this
// column an explicit min/max range; the other columns get `undefined`.
const lastLayerIndex = columnDemoNames.length - 1;

// Per-column grid shape: a fixed column count on desktop and a fixed row count
// on mobile; the other dimension is derived from the number of matrices.
const blockColumnsDesktop = [3, 5, 8, 1];
const blockRowsMobile = [4, 6, 10, 1];

// Gap between grid cells, in pixels.
const lineGap = 4;

/** Grid dimensions and matrix canvas size for each activation column. */
interface BlockLayout {
  /** Desktop only: rows per column, including one row for the column title. */
  blockRowsDesktop: number[];
  /** Mobile only: grid columns per block (the title row is not counted). */
  blockColumnsMobile: number[];
  /** Width and height (px) of every matrix canvas in the column. */
  blockMatrixSize: number[];
}

/**
 * Derives the activation-grid layout from the current matrices and window size.
 *
 * Desktop: the column count of each block comes from `blockColumnsDesktop`; the
 * row count is derived and matrices are sized to fit the window height.
 * Mobile: the row count comes from `blockRowsMobile`; the column count is derived
 * and matrices are sized to fit the window width. Only the arrays for the active
 * mode are filled. Sizes are clamped at 0 so a tiny window never produces a
 * negative canvas size.
 */
function computeBlockLayout(
  matrices: number[][][][] | null,
  isMobile: boolean,
  windowWidth: number,
  windowHeight: number
): BlockLayout {
  const layout: BlockLayout = {
    blockRowsDesktop: [],
    blockColumnsMobile: [],
    blockMatrixSize: [],
  };
  if (matrices === null) return layout;

  for (let i = 0; i < blockColumnsDesktop.length; i++) {
    const matrixCount = matrices[i]?.length ?? 0;
    if (isMobile) {
      // The full-width title is not counted here; it may spill into an extra
      // implicit grid row. At least one column avoids a division by zero.
      const columns = Math.max(1, Math.ceil(matrixCount / blockRowsMobile[i]));
      layout.blockColumnsMobile.push(columns);
      const availableWidth = windowWidth - columns * lineGap - 24; // 24 is the p-3 padding on both sides
      layout.blockMatrixSize.push(Math.max(0, availableWidth / columns));
    } else {
      const rows = Math.ceil(matrixCount / blockColumnsDesktop[i]) + 1; // add 1 for the title
      layout.blockRowsDesktop.push(rows);
      const availableHeight = windowHeight - rows * lineGap;
      layout.blockMatrixSize.push(Math.max(0, availableHeight / rows));
    }
  }
  return layout;
}

/**
 * MNIST playground. The user draws a digit (or loads a random example), the
 * ONNX model at ./onnx/mnist_model.onnx is run on it through `predict`, and the
 * intermediate activations plus one prediction bar per digit are rendered.
 */
export default function MnistExperiment() {
  const onnxSession = useOnnxModel("./onnx/mnist_model.onnx");
  const [image, setImage] = useState<number[][] | null>(null); // image height, image width
  // Pixel data pushed into the drawable canvas to replace its content (e.g. a
  // random example). 3-D array; exact axis meaning: TODO confirm.
  const [drawableCanvasImage, setDrawableCanvasImage] = useState<
    number[][][] | null
  >(null);
  const [matrices, setMatrices] = useState<number[][][][] | null>(null); // column, matrix index, image height, image width
  const [prediction, setPrediction] = useState<number[] | null>(null);
  const [lastLayerMinMax, setLastLayerMinMax] = useState<
    [number, number] | null
  >(null);
  const [autoPredict, setAutoPredict] = useState<boolean>(true);
  const [windowWidth, windowHeight] = useWindowSize();
  const [currentExample, setCurrentExample] = useState<string>(""); // path of the current shown example digit, we keep track of this to avoid loading the same digit again

  const isMobile = windowWidth < ScreenSizes.SM;

  const drawableCanvasSize = isMobile ? windowWidth - 24 : 300;
  const drawableCanvasPixelSize = Math.floor(drawableCanvasSize / 50);
  const drawableCanvasStrokeSize = Math.floor(drawableCanvasSize / 20);

  const { blockRowsDesktop, blockColumnsMobile, blockMatrixSize } =
    computeBlockLayout(matrices, isMobile, windowWidth, windowHeight);

  // Prediction bars per grid row: 5 on mobile, 1 on desktop (mirrors the
  // grid-cols-5 / sm:grid-cols-1 classes on the predictions column).
  const predictionGridColumns = isMobile ? 5 : 1;

  const predictWrapper = useCallback(() => {
    if (onnxSession !== null && image !== null) {
      predict(
        onnxSession,
        image,
        tensorNames,
        setMatrices,
        setPrediction,
        setLastLayerMinMax,
        lastLayerIndex
      );
    }
  }, [onnxSession, image]);

  useEffect(() => {
    if (autoPredict) predictWrapper();
  }, [autoPredict, predictWrapper]);

  // we need a current example ref for the random image selection below,
  // if we just use the state, the state change inside loadRandomImage will trigger a new render causing a loop
  const currentExampleRef = useRef(currentExample);
  useEffect(() => {
    currentExampleRef.current = currentExample;
  }, [currentExample]);

  // everytime image is set to null, choose a random predrawn image
  useEffect(() => {
    if (image !== null) return;
    // The returned promise is not awaited by React; handle a failed load here
    // instead of leaving an unhandled rejection.
    loadRandomImage(
      setDrawableCanvasImage,
      currentExampleRef.current,
      setCurrentExample
    ).catch((error: unknown) => {
      console.error("Failed to load a random example image", error);
    });
  }, [image]);

  const columnTitle = (text: string) => (
    <p className="text-sm font-bold col-span-full text-center mt-3 sm:mt-0">
      {text}
    </p>
  );

  return (
    <div>
      <div
        className="w-screen sm:h-screen p-3 bg-sky-100"
        style={{
          display: !isMobile ? "grid" : "flex",
          gridTemplateColumns: !isMobile ? "1fr auto 1fr" : undefined,
          gridTemplateRows: !isMobile ? "1fr" : undefined,
          gap: !isMobile ? "1rem" : undefined,
          flexDirection: isMobile ? "column" : undefined,
        }}
      >
        {/* Column with drawable canvas and buttons */}
        <div className="flex flex-col items-left space-y-2">
          <h2
            style={{ maxWidth: drawableCanvasSize }}
            className="font-bold text-2xl"
          >
            Draw a number from 0 to 9 on the canvas below:
          </h2>
          <DrawableCanvas
            setImage={setImage}
            outputSize={[32, 32]}
            canvasSize={[drawableCanvasSize, drawableCanvasSize]}
            pixelSize={drawableCanvasPixelSize}
            strokeSize={drawableCanvasStrokeSize}
            drawableCanvasImage={drawableCanvasImage}
            setDrawableCanvasImage={setDrawableCanvasImage}
          />

          <p
            style={{ maxWidth: drawableCanvasSize }}
            className="text-sm italic"
          >
            Disable auto predict if you are having performance issues
          </p>
          <div className="flex items-center space-x-2">
            <input
              type="checkbox"
              id="autoPredict"
              name="autoPredict"
              checked={autoPredict}
              onChange={(e) => setAutoPredict(e.target.checked)}
            />
            <label htmlFor="autoPredict">Auto Predict</label>
          </div>
          {/* Use the shared Button so the manual trigger is styled like the other action */}
          {!autoPredict && <Button onClick={predictWrapper}>Predict</Button>}
          <Button onClick={() => setImage(null)}>Load random example</Button>
        </div>

        {/* Column with activations */}
        <div className="flex flex-col sm:flex-row sm:space-x-6 overflow-auto">
          {matrices?.map((column, i) => (
            <div
              key={i}
              className="grid"
              style={{
                gridTemplateColumns: `repeat(${
                  isMobile ? blockColumnsMobile[i] : blockColumnsDesktop[i]
                }, 1fr)`,
                gridTemplateRows: `repeat(${
                  isMobile ? blockRowsMobile[i] : blockRowsDesktop[i]
                }, 1fr)`,
                gap: `${lineGap}px`,
              }}
            >
              {columnTitle(columnDemoNames[i])}

              {column.map((matrix, j) => (
                <Matrix
                  key={j}
                  matrix={matrix}
                  canvasSize={[blockMatrixSize[i], blockMatrixSize[i]]}
                  minMax={
                    i === lastLayerIndex ? lastLayerMinMax ?? [0, 1] : undefined
                  }
                />
              ))}
            </div>
          ))}
        </div>

        {/* Column with predictions */}
        {prediction !== null && (
          <div
            className="grid w-full sm:w-32 grid-cols-5 sm:grid-cols-1"
            style={{
              // one title row plus as many rows as the bars need
              // (3 on mobile, 11 on desktop for 10 classes)
              gridTemplateRows: `repeat(${
                Math.ceil(prediction.length / predictionGridColumns) + 1
              }, 1fr)`,
            }}
          >
            {columnTitle("Predictions")}
            {prediction.map((value, i) => (
              <div
                key={i}
                className="relative"
                style={{ height: blockMatrixSize[lastLayerIndex] }}
              >
                <div
                  className="absolute left-0 top-0 h-full bg-red-300 z-10"
                  style={{
                    width: `${value * 100}%`,
                  }}
                />
                <div className="absolute inset-0 flex items-center justify-center z-20">
                  <p className="font-bold text-sm sm:text-xl">
                    {i} -&gt; {Math.floor(value * 100)}%
                  </p>
                </div>
              </div>
            ))}
          </div>
        )}
      </div>
      <ExperimentExplanation />
    </div>
  );
}
