import { useState, useEffect } from "react";
import { fetchAndInstantiateWasm } from "../../utils/wasmUtils";

/**
 * Typed view of an instantiated jstorch WebAssembly module and its exports.
 *
 * Tensor arguments and most results are plain numbers. NOTE(review): these
 * are presumably byte offsets ("pointers") into `exports.memory` addressing
 * float32 data (this file also exports `sizeOfFloat32`) — confirm against the
 * native source. The trailing comment on each pointer parameter names the
 * expected layout; the `...Shape0/1/2/3` parameters appear to give the size of
 * each dimension of that layout, in order.
 *
 * NOTE(review): ownership of returned pointers is not visible here; presumably
 * each result is a fresh allocation the caller must release with `free` —
 * TODO confirm.
 */
export type JsTorchWebAssemblyInstance = WebAssembly.Instance & {
  exports: WebAssembly.Exports & {
    // Raw allocator exports. NOTE(review): `size` is presumably in bytes — confirm.
    malloc: (size: number) => number;
    free: (ptr: number) => void;
    // The module's linear memory (presumably what every pointer indexes into).
    memory: WebAssembly.Memory;
    matmul: (
      aPtr: number, // 2D tensor [row, column]
      bPtr: number, // 2D tensor [row, column]
      aShape0: number,
      aShape1: number,
      bShape0: number,
      bShape1: number
    ) => number; // pointer to 2D tensor [row, column]
    convTranspose2d: (
      inputPtr: number, // 3D tensor [inputChannel, row, column]
      weightPtr: number, // 4D tensor [inputChannel, outputChannel, row, column]
      biasPtr: number, // 1D tensor [outputChannel]
      inputShape0: number,
      inputShape1: number,
      inputShape2: number,
      weightShape0: number,
      weightShape1: number,
      weightShape2: number,
      weightShape3: number,
      biasShape0: number,
      stride: number,
      padding: number,
      outputPadding: number,
      outputShape0: number,
      outputShape1: number,
      outputShape2: number
    ) => number; // pointer to 3D tensor [outputChannel, row, column]
    batchNorm2d: (
      inputPtr: number, // 3D tensor [channel, row, column]
      weightPtr: number, // 1D tensor [channel]
      biasPtr: number, // 1D tensor [channel]
      runningMeanPtr: number, // 1D tensor [channel]
      runningVariancePtr: number, // 1D tensor [channel]
      epsilon: number,
      inputShape0: number,
      inputShape1: number,
      inputShape2: number
    ) => number; // pointer to 3D tensor [channel, row, column]
    // Element-wise activations over a flat buffer of `inputSize` elements.
    relu: (
      inputPtr: number, // any dimensional tensor of size x
      inputSize: number // x
    ) => number; // pointer to any dimensional tensor of size x
    sigmoid: (
      inputPtr: number, // any dimensional tensor of size x
      inputSize: number // x
    ) => number; // pointer to any dimensional tensor of size x
    // Tensor constructors returning a buffer of `size` elements.
    // NOTE(review): the distributions of `random_` / `normal` are not visible
    // here (presumably uniform / Gaussian), and `seed` presumably makes them
    // deterministic — confirm against the native source.
    // NOTE(review): the trailing underscore in `random_`, `div_` and `std_`
    // presumably avoids a clash with C library symbols — confirm before renaming.
    zeros: (size: number) => number; // pointer to any dimensional tensor
    random_: (size: number, seed: number) => number; // pointer to any dimensional tensor
    normal: (size: number, seed: number) => number; // pointer to any dimensional tensor
    // NOTE(review): unlike most exports, `min`/`max` (and `norm`, `std_`,
    // `mean` below) carry no "pointer to" comment; they appear to return a
    // scalar value rather than a pointer — confirm.
    min: (ptr: number, size: number) => number;
    max: (ptr: number, size: number) => number;
    // Element-wise binary ops (tensor-tensor and tensor-scalar variants).
    add: (
      aPtr: number, // any dimensional tensor of size x
      bPtr: number, // any dimensional tensor of size x
      size: number // x
    ) => number; // pointer to any dimensional tensor of size x
    addScalar: (
      aPtr: number, // any dimensional tensor of size x
      size: number, // x
      scalar: number
    ) => number; // pointer to any dimensional tensor of size x
    sub: (
      aPtr: number, // any dimensional tensor of size x
      bPtr: number, // any dimensional tensor of size x
      size: number // x
    ) => number; // pointer to any dimensional tensor of size x
    subScalar: (
      aPtr: number, // any dimensional tensor of size x
      size: number, // x
      scalar: number
    ) => number; // pointer to any dimensional tensor of size x
    mul: (
      aPtr: number, // any dimensional tensor of size x
      bPtr: number, // any dimensional tensor of size x
      size: number // x
    ) => number; // pointer to any dimensional tensor of size x
    mulScalar: (
      aPtr: number, // any dimensional tensor of size x
      size: number, // x
      scalar: number
    ) => number; // pointer to any dimensional tensor of size x
    div_: (
      aPtr: number, // any dimensional tensor of size x
      bPtr: number, // any dimensional tensor of size x
      size: number // x
    ) => number; // pointer to any dimensional tensor of size x
    divScalar: (
      aPtr: number, // any dimensional tensor of size x
      size: number, // x
      scalar: number
    ) => number; // pointer to any dimensional tensor of size x
    norm: (
      aPtr: number, // any dimensional tensor of size x
      size: number // x
    ) => number;
    // 2D resampling from [inputShape0, inputShape1] to [outputShape0, outputShape1].
    bilinear2d: (
      inputPtr: number, // 2D tensor [row, column]
      inputShape0: number,
      inputShape1: number,
      outputShape0: number,
      outputShape1: number
    ) => number; // pointer to 2D tensor [outputShape0, outputShape1]
    bicubic2d: (
      inputPtr: number, // 2D tensor [row, column]
      inputShape0: number,
      inputShape1: number,
      outputShape0: number,
      outputShape1: number
    ) => number; // pointer to 2D tensor [outputShape0, outputShape1]
    nearest2d: (
      inputPtr: number, // 2D tensor [row, column]
      inputShape0: number,
      inputShape1: number,
      outputShape0: number,
      outputShape1: number
    ) => number; // pointer to 2D tensor [outputShape0, outputShape1]
    threshold: (
      inputPtr: number, // any dimensional tensor of size x
      size: number, // x
      threshold: number
    ) => number; // pointer to any dimensional tensor of size x
    tensor2dToPixels: (
      inputPtr: number, // 2D tensor [row, column]
      inputShape0: number,
      inputShape1: number
    ) => number; // pointer to Uint8ClampedArray
    std_: (
      inputPtr: number, // any dimensional tensor of size x
      size: number // x
    ) => number;
    mean: (
      inputPtr: number, // any dimensional tensor of size x
      size: number // x
    ) => number;
    normalize: (
      inputPtr: number, // any dimensional tensor of size x
      size: number // x
    ) => number; // pointer to any dimensional tensor of size x
    normalizeExcept: (
      inputPtr: number, // any dimensional tensor of size x
      size: number, // x
      exceptIndex: number,
      iterations: number
    ) => number; // pointer to any dimensional tensor of size x
    // NOTE(review): a WebAssembly export cannot hand a JS boolean back to the
    // caller — i32 results arrive as JS numbers (presumably 0 / 1 here). The
    // `boolean` below is therefore optimistic: truthiness checks work, but
    // `=== true` comparisons will fail. Consider changing it to `number`.
    closeTo: (
      aPtr: number, // any dimensional tensor of size x
      bPtr: number, // any dimensional tensor of size x
      size: number, // x
      threshold: number
    ) => boolean;
  };
};

/** Path handed to `fetchAndInstantiateWasm` to load the jstorch binary. */
const jsTorchWasmPath = "/wasm/jstorch.wasm";
/** Byte width of one float32 element (4). */
export const sizeOfFloat32 = Float32Array.BYTES_PER_ELEMENT;
/**
 * The instantiated jstorch module, or `null` until loading has finished.
 *
 * Assigned exactly once by the loader below; ES module live bindings let
 * importers observe the update. React components should use
 * `useJsTorchWasmLoaded` to re-render once it becomes available.
 */
export let jsTorchWasmInstance: JsTorchWebAssemblyInstance | null = null;

/** Size of one WebAssembly linear-memory page in bytes (fixed by the Wasm spec). */
const WASM_PAGE_SIZE = 65536;

// Start loading as a module side effect. `jsTorchWasmInstance` is always null
// at this point (it was initialised just above), so no "already loaded" guard
// is needed.
fetchAndInstantiateWasm(jsTorchWasmPath, {
  env: {
    /**
     * Emscripten heap-growth hook: called by the module's allocator when it
     * needs linear memory of at least `requestedSize` bytes.
     *
     * @returns `true` once memory is large enough.
     * @throws Error if the instance is not ready or `memory.grow` fails
     *   (e.g. the memory's declared maximum would be exceeded).
     */
    emscripten_resize_heap: (requestedSize: number) => {
      if (
        !(
          jsTorchWasmInstance &&
          jsTorchWasmInstance.exports.memory instanceof WebAssembly.Memory
        )
      )
        throw new Error("Memory not yet initialized");

      // The byte count crosses the boundary as a signed i32; reinterpret it
      // as unsigned so requests of 2 GiB or more are not seen as negative
      // (Emscripten's own JS runtime does the same).
      const requestedBytes = requestedSize >>> 0;

      const memory = jsTorchWasmInstance.exports.memory;
      const currentPages = memory.buffer.byteLength / WASM_PAGE_SIZE;
      const requestedPages = Math.ceil(requestedBytes / WASM_PAGE_SIZE);

      // Already big enough: nothing to grow (and grow() rejects negative deltas).
      if (requestedPages <= currentPages) return true;

      try {
        memory.grow(requestedPages - currentPages);
        return true;
      } catch (e) {
        // Throw rather than report failure: an allocator that silently
        // returns NULL would let callers write through pointer 0.
        throw new Error("Failed to grow memory: " + e);
      }
    },
    /** Wall-clock time in milliseconds since the Unix epoch. */
    emscripten_date_now: () => Date.now(),
  },
})
  .then((instance) => {
    jsTorchWasmInstance = instance as JsTorchWebAssemblyInstance;
    console.log("jsTorchWasmInstance loaded");
  })
  .catch((error: unknown) => {
    // Without this handler a failed fetch/instantiate is an unhandled
    // rejection, which terminates a Node process (e.g. during server-side
    // rendering, where a relative path may not be fetchable).
    console.error("Failed to load jsTorch WebAssembly module:", error);
  });

/**
 * React hook reporting whether the jstorch WebAssembly module has loaded.
 *
 * If `jsTorchWasmInstance` is already set at first render, the hook returns
 * `true` immediately. Otherwise it polls every 100 ms and re-renders once the
 * instance appears. Polling stops on success or when the component unmounts.
 *
 * @returns `true` once `jsTorchWasmInstance` is non-null, otherwise `false`.
 */
export function useJsTorchWasmLoaded(): boolean {
  // Lazy initialiser: components mounted after loading completed must not
  // flash a "not loaded" render while waiting for the first poll tick.
  const [isLoaded, setIsLoaded] = useState(() => jsTorchWasmInstance != null);

  useEffect(() => {
    // Loading may also have finished between render and this effect.
    if (jsTorchWasmInstance) {
      setIsLoaded(true); // no-op re-render bail-out if already true
      return undefined;
    }

    const intervalId = setInterval(() => {
      if (jsTorchWasmInstance) {
        setIsLoaded(true);
        clearInterval(intervalId);
      }
    }, 100);
    return () => clearInterval(intervalId);
  }, []);

  return isLoaded;
}
