import { Stopwatch } from "../../hooks/useStopwatch";
import { Sequential } from "../../jstorch/layers/layers";
import {
  tensorDeviceFromString,
  TensorDevice,
  Tensor,
} from "../../jstorch/tensor/tensor";

/**
 * Destination canvases for `predict`.
 *
 * Every slot may be `null` (for example before the element has been mounted);
 * the `*TransposeConv` arrays hold one slot per tensor that `predict` obtains
 * by calling `split()` on the corresponding block's activation.
 */
export type CanvasRefs = {
  /** Final model output, optionally post-processed. */
  output: HTMLCanvasElement | null;
  /** The `after_fc` activation, rendered by `predict` after reshaping to [8, 64]. */
  afterLinear: HTMLCanvasElement | null;
  /** Slices of the `after_conv_t_block0.relu` activation. */
  firstTransposeConv: Array<HTMLCanvasElement | null>;
  /** Slices of the `after_conv_t_block1.relu` activation. */
  secondTransposeConv: Array<HTMLCanvasElement | null>;
  /** Slices of the `after_conv_t_block2.relu` activation. */
  thirdTransposeConv: Array<HTMLCanvasElement | null>;
  /** Slices of the `after_conv_t_block3.relu` activation. */
  fourthTransposeConv: Array<HTMLCanvasElement | null>;
  /** Last block's transposed-conv output, before the sigmoid. */
  outputBeforeSigmoid: HTMLCanvasElement | null;
  /** Last block's output after the sigmoid (unprocessed). */
  outputAfterSigmoid: HTMLCanvasElement | null;
};

/**
 * Runs `model` on `latent`, renders the final output and every intermediate
 * activation onto the canvases in `canvasRefs`, and records timings.
 *
 * The prediction is skipped (nothing is computed, rendered or freed) when
 * `modelWeights` is empty or its first tensor is not on the requested device,
 * e.g. while the weights are being moved to a newly selected device.
 *
 * @param model - network whose `forwardWithIntermediates` result must contain
 *   `after_fc`, `after_conv_t_block{0,1,2,3}.relu`,
 *   `after_conv_t_block4.conv_t` and `after_conv_t_block4.sigmoid`.
 * @param modelWeights - weight tensors; the first one decides whether the
 *   model is ready on `tensorDevice`. They are kept alive by the cleanup.
 * @param latent - model input; kept alive by the cleanup.
 * @param forwardPassTimeControls - timed around the forward pass, including
 *   the copy of all intermediates back to `TensorDevice.WASM`.
 * @param outputProcessingTimeControls - timed around post-processing; reset
 *   when `processOutput` is false.
 * @param renderingTimeControls - timed around all canvas rendering.
 * @param processOutput - when true, the output goes through
 *   `bilinear2d([500, 500])` and `threshold(0.4)` before being rendered.
 * @param tensorDevice - device name, parsed with `tensorDeviceFromString`.
 * @param canvasRefs - destination canvases.
 * @throws Error if the model result lacks one of the expected intermediates.
 *   Even then, every tensor except `latent` and `modelWeights` is freed.
 */
export function predict(
  model: Sequential,
  modelWeights: Tensor[],
  latent: Tensor,
  forwardPassTimeControls: Stopwatch["controls"],
  outputProcessingTimeControls: Stopwatch["controls"],
  renderingTimeControls: Stopwatch["controls"],
  processOutput: boolean,
  tensorDevice: string,
  canvasRefs: React.MutableRefObject<CanvasRefs>
) {
  const device = tensorDeviceFromString(tensorDevice);

  // cancel prediction if there are no weights yet, or if they are still being
  // loaded on a different device
  if (modelWeights[0]?.getDevice() !== device) {
    return;
  }

  try {
    // forward pass; intermediates are copied back to WASM for rendering
    forwardPassTimeControls.start();
    const deviceIntermediates = model.forwardWithIntermediates(
      latent.toDevice(device)
    );
    // the explicit tuple return type keeps `Object.fromEntries` from falling
    // back to its `any` overload
    const intermediates: Record<string, Tensor> = Object.fromEntries(
      Object.entries(deviceIntermediates).map(
        ([name, tensor]): [string, Tensor] => [
          name,
          tensor.toDevice(TensorDevice.WASM),
        ]
      )
    );
    const afterSigmoid = getIntermediate(
      intermediates,
      "after_conv_t_block4.sigmoid"
    );
    let output = afterSigmoid.squeeze().setName("output");
    forwardPassTimeControls.stop();

    // process output if needed
    if (processOutput) {
      outputProcessingTimeControls.start();
      output = output.bilinear2d([500, 500]).threshold(0.4);
      outputProcessingTimeControls.stop();
    } else {
      outputProcessingTimeControls.reset();
    }

    // render tensors on canvas references
    const canvases = canvasRefs.current;
    renderingTimeControls.start();
    output.renderOnCanvas(
      canvases.output,
      processOutput ? "bilinear2d" : "nearest2d"
    );
    getIntermediate(intermediates, "after_fc")
      .reshape([8, 64])
      .renderOnCanvas(canvases.afterLinear);
    renderSplitOnCanvases(
      getIntermediate(intermediates, "after_conv_t_block0.relu"),
      canvases.firstTransposeConv
    );
    renderSplitOnCanvases(
      getIntermediate(intermediates, "after_conv_t_block1.relu"),
      canvases.secondTransposeConv
    );
    renderSplitOnCanvases(
      getIntermediate(intermediates, "after_conv_t_block2.relu"),
      canvases.thirdTransposeConv
    );
    renderSplitOnCanvases(
      getIntermediate(intermediates, "after_conv_t_block3.relu"),
      canvases.fourthTransposeConv
    );
    getIntermediate(intermediates, "after_conv_t_block4.conv_t")
      .squeeze()
      .renderOnCanvas(canvases.outputBeforeSigmoid);
    afterSigmoid.squeeze().renderOnCanvas(canvases.outputAfterSigmoid);
    renderingTimeControls.stop();
  } finally {
    // free every tensor created during this call, also when something threw,
    // so a failed prediction does not leak device/WASM memory
    Tensor.freeAllExcept([latent, ...modelWeights]);
  }
}

/**
 * Looks up a named intermediate produced by `forwardWithIntermediates`.
 * Throws a descriptive error instead of failing later on `undefined`.
 */
function getIntermediate(
  intermediates: Readonly<Record<string, Tensor>>,
  name: string
): Tensor {
  const tensor = intermediates[name];
  if (tensor === undefined) {
    throw new Error(`predict: model did not produce intermediate "${name}"`);
  }
  return tensor;
}

/**
 * Renders each tensor returned by `tensor.split()` onto the canvas at the same
 * index. Indices beyond `canvases` receive `null` (the declared "no canvas"
 * value) rather than `undefined`.
 */
function renderSplitOnCanvases(
  tensor: Tensor,
  canvases: readonly (HTMLCanvasElement | null)[]
): void {
  tensor.split().forEach((part, i) => {
    part.renderOnCanvas(canvases[i] ?? null);
  });
}
