import {
  JsTorchWebAssemblyInstance,
  jsTorchWasmInstance,
  sizeOfFloat32,
} from "../wasm";
import { flattenList, listShape } from "../listUtils";
import { Tensor, nonFreedTensors } from "./tensor";

/**
 * Produce a pseudo-random seed for the wasm generators (it is passed as the
 * seed argument to the `random_` and `normal` exports below).
 *
 * @returns An integer in the half-open range [0, 1000000).
 */
function randomSeed(): number {
  const upperBound = 1e6;
  // Math.random() is never negative, so truncation is the same as flooring.
  return Math.trunc(upperBound * Math.random());
}

/**
 * Resolve which WebAssembly instance a static factory should use.
 *
 * An explicitly supplied instance wins. Otherwise this falls back to the
 * module-level `jsTorchWasmInstance` imported from "../wasm".
 *
 * @param caller - Name of the `Tensor` static method, used in the error message.
 * @param explicitInstance - Instance passed by the caller, if any.
 * @returns The instance to use.
 * @throws Error if neither an explicit nor a module-level instance is available.
 */
function requireWasmInstance(
  caller: string,
  explicitInstance?: JsTorchWebAssemblyInstance
): JsTorchWebAssemblyInstance {
  const instance = explicitInstance ?? jsTorchWasmInstance;
  if (!instance) {
    throw new Error(`cannot call Tensor.${caller} without jsTorchWasmInstance`);
  }
  return instance;
}

/**
 * Number of elements in a tensor with the given shape (product of its dims).
 *
 * The reduction is seeded with 1, so an empty shape (`[]`, a scalar) yields 1
 * instead of throwing "Reduce of empty array with no initial value".
 */
function numElements(shape: readonly number[]): number {
  return shape.reduce((count, dim) => count * dim, 1);
}

/** Static helpers (constructors and bulk-free utilities) for `Tensor`. */
const _static = {
  /**
   * Create a tensor by copying a (possibly nested) JS list into wasm memory.
   *
   * The list is flattened with `flattenList`, copied element-wise into a
   * freshly `malloc`ed Float32 buffer, and its shape is taken from `listShape`.
   *
   * @throws Error if no wasm instance has been loaded.
   */
  fromList(data: any[]): Tensor {
    const wasm = requireWasmInstance("fromList");
    const flattenedData = flattenList(data);
    const pointer = wasm.exports.malloc(flattenedData.length * sizeOfFloat32);
    // Read `memory.buffer` only after malloc: growing wasm memory detaches any
    // previously obtained ArrayBuffer.
    new Float32Array(
      wasm.exports.memory.buffer,
      pointer,
      flattenedData.length
    ).set(flattenedData);
    return new Tensor({
      shape: listShape(data),
      pointer,
    });
  },

  /** Call `free()` on every tensor in `nonFreedTensors`. */
  freeAll() {
    // NOTE(review): `Tensor.free()` and `nonFreedTensors` are defined elsewhere.
    // Suppose `free()` removes the tensor from `nonFreedTensors` and that
    // collection is an array. Then this forEach would skip elements. A Set is
    // safe to delete from during forEach. Confirm which it is.
    nonFreedTensors.forEach((tensor) => tensor.free());
  },

  /**
   * Call `free()` on every tensor in `nonFreedTensors` except those listed
   * (compared by identity).
   */
  freeAllExcept(exceptions: readonly Tensor[]) {
    // A Set keeps membership checks O(1). It uses the same SameValueZero
    // comparison as Array.prototype.includes.
    const keep = new Set(exceptions);
    nonFreedTensors.forEach((tensor) => {
      if (!keep.has(tensor)) {
        tensor.free();
      }
    });
  },

  /**
   * Create a tensor of the given shape filled by the wasm `zeros` export.
   *
   * @param shape - Tensor dimensions; `[]` produces a single-element tensor.
   * @param jsTorchWasmInstance - Optional instance to use. Defaults to the
   *   module-level instance.
   * @throws Error if no wasm instance is available.
   */
  zeros(
    shape: number[],
    jsTorchWasmInstance?: JsTorchWebAssemblyInstance
  ): Tensor {
    const wasm = requireWasmInstance("zeros", jsTorchWasmInstance);
    const pointer = wasm.exports.zeros(numElements(shape));
    return new Tensor({
      shape,
      pointer,
    });
  },

  /**
   * Create a tensor of the given shape filled by the wasm `random_` export,
   * using a fresh seed from `randomSeed()`.
   *
   * @param shape - Tensor dimensions; `[]` produces a single-element tensor.
   * @param jsTorchWasmInstance - Optional instance to use. Defaults to the
   *   module-level instance.
   * @throws Error if no wasm instance is available.
   */
  random(
    shape: number[],
    jsTorchWasmInstance?: JsTorchWebAssemblyInstance
  ): Tensor {
    const wasm = requireWasmInstance("random", jsTorchWasmInstance);
    const pointer = wasm.exports.random_(numElements(shape), randomSeed());
    return new Tensor({
      shape,
      pointer,
    });
  },

  /**
   * Create a tensor of the given shape filled by the wasm `normal` export,
   * using a fresh seed from `randomSeed()`.
   *
   * @param shape - Tensor dimensions; `[]` produces a single-element tensor.
   * @throws Error if no wasm instance has been loaded.
   */
  normal(shape: number[]): Tensor {
    const wasm = requireWasmInstance("normal");
    const pointer = wasm.exports.normal(numElements(shape), randomSeed());
    return new Tensor({
      shape,
      pointer,
    });
  },
};

export default _static;
