import { useCallback, useEffect, useRef, useState } from "react";
import { Sequential } from "../../jstorch/layers/layers";
import {
  Tensor,
  TensorDevice,
  tensorDeviceFromString,
} from "../../jstorch/tensor/tensor";
import { loadModel } from "./loadModel";
import { useWindowSize } from "@react-hook/window-size";
import Checkbox from "../../components/Checkbox";
import { ScreenSizes } from "../../data";
import { Canvas, ChannelsGrid } from "../../utils/utils";
import Toggle from "../../components/Toggle";
import { useJsTorchWasmLoaded } from "../../jstorch/wasm";
import { CanvasRefs, predict } from "./predict";
import LatentSliders from "../../components/LatentSliders";
import { useStopwatch } from "../../hooks/useStopwatch";

/**
 * Latent vector the sliders start from (passed to `LatentSliders` as
 * `initialLatent`).
 * NOTE(review): presumably its length must match the decoder's input width —
 * confirm against the loaded model.
 */
const initialLatent = [
  -1.2978, -1.7232, -0.8657, 0.3033, 0.2243, -0.966, -0.3652, -0.1457,
  1.2122, 0.7359, 0.9004, -0.8637, 1.4561, -0.6484, 0.7297, 1.3139,
];

/**
 * Number of latent dimensions exposed to `LatentSliders`. Derived from
 * `initialLatent` so the two can never disagree.
 */
const latentSize = initialLatent.length;

/**
 * Interactive VAE demo page.
 *
 * Loads the model for the selected tensor device (via `loadModel`), runs
 * `predict` whenever the latent sliders report a new latent tensor, and shows
 * the output image together with the intermediate channels that `predict`
 * draws into the canvases held in `canvasRefs`. Timings for the forward pass,
 * output processing and rendering are displayed under the output canvas.
 */
export default function VAEExperiment() {
  const [model, setModel] = useState<Sequential | null>(null);
  const [modelWeights, setModelWeights] = useState<Tensor[] | null>(null);
  const [processOutput, setProcessOutput] = useState(true);
  const [tensorDevice, setTensorDevice] = useState<string>(TensorDevice.WEBGL);

  const jsTorchWasmLoaded = useJsTorchWasmLoaded();

  const forwardPassTime = useStopwatch("Forward pass");
  const outputProcessingTime = useStopwatch("Output processing");
  const renderingTime = useStopwatch("Rendering");

  const [windowWidth, windowHeight] = useWindowSize();
  const isMobile = windowWidth < ScreenSizes.SM;

  // Canvases that `predict` draws into; each entry is filled in by the
  // `setRef` callbacks in the JSX below.
  const canvasRefs = useRef<CanvasRefs>({
    output: null,
    afterLinear: null,
    firstTransposeConv: [],
    secondTransposeConv: [],
    thirdTransposeConv: [],
    fourthTransposeConv: [],
    outputBeforeSigmoid: null,
    outputAfterSigmoid: null,
  });

  // (Re)load the model once the jstorch WASM module is ready and whenever the
  // tensor device changes. If the device is switched again before an earlier
  // load has reported back, that earlier load's results are discarded.
  // Without this guard the load that happens to finish last would win,
  // leaving weights for one device in state while `tensorDevice` names
  // another.
  // NOTE(review): until the new load reports back, the previous model/weights
  // remain in state and may still be used together with the new
  // `tensorDevice`.
  useEffect(() => {
    let superseded = false;

    if (jsTorchWasmLoaded) {
      loadModel(
        (loadedModel) => {
          if (!superseded) setModel(loadedModel);
        },
        (loadedWeights) => {
          if (!superseded) setModelWeights(loadedWeights);
        },
        tensorDeviceFromString(tensorDevice)
      );
    }

    return () => {
      superseded = true;
    };
  }, [jsTorchWasmLoaded, tensorDevice]);

  // Runs a prediction for the latent chosen in the sliders; a no-op until both
  // the model and its weights are available.
  const onLatentChange = useCallback(
    (newLatent: Tensor) => {
      if (!model || !modelWeights) return;
      predict(
        model,
        modelWeights,
        newLatent,
        forwardPassTime.controls,
        outputProcessingTime.controls,
        renderingTime.controls,
        processOutput,
        tensorDevice,
        canvasRefs
      );
    },
    [
      model,
      modelWeights,
      forwardPassTime.controls,
      outputProcessingTime.controls,
      renderingTime.controls,
      processOutput,
      tensorDevice,
    ]
  );

  // On mobile both columns are inset by `mobileMargin` on the left and sized
  // so that the same amount of space is left over on the right.
  const mobileMargin = 25;
  const sideMargin = isMobile ? mobileMargin : 0;
  const outputCanvasSize = isMobile
    ? windowWidth - 2 * mobileMargin
    : windowHeight * 0.4;

  if (!model || !modelWeights) return <p>Loading model...</p>;

  return (
    <div className="flex flex-col justify-center gap-4 min-h-screen sm:flex-row">
      <div
        className="flex flex-col items-center"
        style={{ width: outputCanvasSize, marginLeft: sideMargin }}
      >
        <div>
          <Canvas
            setRef={(el) => (canvasRefs.current.output = el)}
            width={outputCanvasSize}
            height={outputCanvasSize}
          />
          {forwardPassTime.render()}
          {outputProcessingTime.render()}
          {renderingTime.render()}
        </div>

        <LatentSliders
          latentSize={latentSize}
          initialLatent={initialLatent}
          onLatentChange={onLatentChange}
        />

        <div className="flex flex-row gap-2 mt-2">
          <Checkbox
            label="Process output"
            checked={processOutput}
            onChange={setProcessOutput}
          />

          <Toggle
            options={[TensorDevice.WASM, TensorDevice.WEBGL]}
            selected={tensorDevice}
            setSelected={setTensorDevice}
          />
        </div>
      </div>
      <div
        style={{
          width: isMobile ? outputCanvasSize : "auto",
          marginLeft: sideMargin,
        }}
        className="overflow-x-auto scrollbar mb-8 sm:mb-0"
      >
        <h2 className="font-bold">
          Intermediate channels (all shapes are in CHW format):
        </h2>
        <h3 className="text-sm text-gray-500">
          Blank channels mean all values are equal (e.g. all zeros after ReLU)
        </h3>
        <hr className="my-1" />

        <Canvas
          setRef={(el) => (canvasRefs.current.afterLinear = el)}
          width={400}
          height={50}
        >
          After linear: <b>512</b>
        </Canvas>

        <ChannelsGrid
          isMobile={isMobile}
          mobileTemplate="repeat(16, 1fr)"
          desktopTemplate="repeat(32, 1fr)"
          length={256}
          setRef={(el, i) => (canvasRefs.current.firstTransposeConv[i] = el)}
          width={15}
          height={15}
        >
          First transpose block: <b>256x2x2</b>
        </ChannelsGrid>

        <ChannelsGrid
          isMobile={isMobile}
          mobileTemplate="repeat(13, 1fr)"
          desktopTemplate="repeat(26, 1fr)"
          length={128}
          setRef={(el, i) => (canvasRefs.current.secondTransposeConv[i] = el)}
          width={18}
          height={18}
        >
          Second transpose block: <b>128x4x4</b>
        </ChannelsGrid>

        <ChannelsGrid
          isMobile={isMobile}
          mobileTemplate="repeat(11, 1fr)"
          desktopTemplate="repeat(22, 1fr)"
          length={64}
          setRef={(el, i) => (canvasRefs.current.thirdTransposeConv[i] = el)}
          width={23}
          height={23}
        >
          Third transpose block: <b>64x8x8</b>
        </ChannelsGrid>

        <ChannelsGrid
          isMobile={isMobile}
          mobileTemplate="repeat(8, 1fr)"
          desktopTemplate="repeat(16, 1fr)"
          length={32}
          setRef={(el, i) => (canvasRefs.current.fourthTransposeConv[i] = el)}
          width={33}
          height={33}
        >
          Fourth transpose block: <b>32x16x16</b>
        </ChannelsGrid>

        <div className="flex flex-row gap-8">
          <Canvas
            setRef={(el) => (canvasRefs.current.outputBeforeSigmoid = el)}
            width={140}
            height={140}
          >
            Output before sigmoid: <b>1x32x32</b>
          </Canvas>

          <Canvas
            setRef={(el) => (canvasRefs.current.outputAfterSigmoid = el)}
            width={140}
            height={140}
          >
            Output after sigmoid: <b>1x32x32</b>
          </Canvas>
        </div>
      </div>
    </div>
  );
}
