import { jsTorchWasmInstance } from "../wasm";
import visualization from "./visualization";
import memory from "./memory";
import _static from "./static";
import gettersSetters from "./gettersSetters";
import basicOperations from "./basicOperations";
import image from "./image";
import webgl from "./webgl";

/**
 * Registry of every successfully constructed Tensor, appended to by the
 * Tensor constructor. Declared `const` because the array itself is never
 * reassigned (ES module importers cannot reassign it either); it is only
 * mutated in place.
 */
export const nonFreedTensors: Tensor[] = [];

/**
 * Backend identifier for a tensor. The string values ("Wasm", "WebGL") are
 * the exact spellings accepted by `tensorDeviceFromString`.
 */
export enum TensorDevice {
  /** WebAssembly backend (presumably corresponds to `Tensor.wasmData`). */
  WASM = "Wasm",
  /** WebGL backend (presumably corresponds to `Tensor.webglData`). */
  WEBGL = "WebGL",
}

/**
 * Parses a device name into a `TensorDevice`.
 *
 * Matching is exact and case-sensitive against the enum's string values.
 *
 * @param device - "Wasm" or "WebGL".
 * @returns The matching `TensorDevice` member.
 * @throws Error if `device` is not one of the enum's string values.
 */
export function tensorDeviceFromString(device: string): TensorDevice {
  switch (device) {
    case TensorDevice.WASM:
      return TensorDevice.WASM;
    case TensorDevice.WEBGL:
      return TensorDevice.WEBGL;
    default:
      throw new Error(`Invalid tensor device: "${device}"`);
  }
}

/**
 * N-dimensional float32 tensor whose data lives on exactly one backend:
 * WebAssembly linear memory (`wasmData`) or a WebGL buffer (`webglData`).
 *
 * Operations are attached as instance/static properties taken from the
 * helper modules imported at the top of this file (memory, static,
 * gettersSetters, visualization, basicOperations, image, webgl).
 */
export class Tensor {
  /** Tensor dimensions; stored by reference, not copied. */
  shape: number[] = [];
  /** Element count: product of `shape` (1 for an empty shape). */
  size: number = 0;
  name: string;
  freed: boolean = false;

  /** Set for wasm-backed tensors, `null` otherwise. */
  wasmData: {
    pointer: number;
    float32Array: Float32Array;
  } | null = null;
  /** Set for WebGL-backed tensors, `null` otherwise. */
  webglData: {
    buffer: WebGLBuffer;
    texture: WebGLTexture | null;
    textureUnit: number | null;
  } | null = null;

  /**
   * Creates a tensor on exactly one backend and registers it in
   * `nonFreedTensors`.
   *
   * @param shape - Dimensions of the tensor; `size` is their product.
   * @param pointer - Byte offset of the float32 data inside the wasm module's
   *   memory; selects the wasm backend. A pointer of 0 is treated as absent
   *   (truthiness check). Must be 4-byte aligned, otherwise the Float32Array
   *   constructor throws a RangeError.
   * @param buffer - WebGL buffer holding the data; selects the WebGL backend.
   * @throws Error if both or neither of `pointer`/`buffer` are given, or if a
   *   wasm tensor is requested while `jsTorchWasmInstance` is unavailable.
   */
  constructor({
    shape,
    pointer,
    buffer,
  }: {
    shape: number[];
    pointer?: number;
    buffer?: WebGLBuffer;
  }) {
    this.shape = shape;
    this.size = shape.reduce((a, b) => a * b, 1);
    this.name = "";

    if (!!pointer && !buffer) {
      // wasm tensor
      if (!jsTorchWasmInstance)
        throw new Error("cannot create Tensor without jsTorchWasmInstance");
      // NOTE(review): this view is bound to the current memory.buffer.
      // Growing WebAssembly memory detaches that ArrayBuffer, so confirm the
      // view is re-created (e.g. in getFloat32Array) after memory growth.
      this.wasmData = {
        pointer,
        float32Array: new Float32Array(
          jsTorchWasmInstance.exports.memory.buffer,
          pointer,
          this.size
        ),
      };
    } else if (!pointer && !!buffer) {
      // WebGL tensor; texture/unit are filled in lazily elsewhere.
      this.webglData = {
        buffer,
        texture: null,
        textureUnit: null,
      };
    } else {
      throw new Error(`Invalid tensor initialization, check your arguments`);
    }

    // Register only after initialisation succeeded, so a constructor that
    // throws does not leave a half-built tensor (no wasmData/webglData) in
    // the global registry.
    nonFreedTensors.push(this);
  }

  free = memory.free;
  copy = memory.copy;
  toWebgl = memory.toWebgl;
  toWasm = memory.toWasm;
  toDevice = memory.toDevice;
  freeTexture = memory.freeTexture;

  static fromList = _static.fromList;
  static freeAll = _static.freeAll;
  static freeAllExcept = _static.freeAllExcept;
  static zeros = _static.zeros;
  static random = _static.random;
  static normal = _static.normal;

  _setShape = gettersSetters._setShape;
  setName = gettersSetters.setName;
  get1d = gettersSetters.get1d;
  get2d = gettersSetters.get2d;
  get3d = gettersSetters.get3d;
  get4d = gettersSetters.get4d;
  is1d = gettersSetters.is1d;
  is2d = gettersSetters.is2d;
  is3d = gettersSetters.is3d;
  is4d = gettersSetters.is4d;
  set1d = gettersSetters.set1d;
  set2d = gettersSetters.set2d;
  set3d = gettersSetters.set3d;
  set4d = gettersSetters.set4d;
  isOfShape = gettersSetters.isOfShape;
  numDims = gettersSetters.numDims;
  getFloat32Array = gettersSetters.getFloat32Array;
  setFloat32Array = gettersSetters.setFloat32Array;
  getDevice = gettersSetters.getDevice;

  string = visualization.string;
  toList = visualization.toList;
  print = visualization.print;
  renderOnCanvas = visualization.renderOnCanvas;

  reshape = basicOperations.reshape;
  unsqueeze = basicOperations.unsqueeze;
  squeeze = basicOperations.squeeze;
  min = basicOperations.min;
  max = basicOperations.max;
  add = basicOperations.add;
  addScalar = basicOperations.addScalar;
  sub = basicOperations.sub;
  subScalar = basicOperations.subScalar;
  mul = basicOperations.mul;
  mulScalar = basicOperations.mulScalar;
  div = basicOperations.div;
  divScalar = basicOperations.divScalar;
  norm = basicOperations.norm;
  threshold = basicOperations.threshold;
  std = basicOperations.std;
  mean = basicOperations.mean;
  normalize = basicOperations.normalize;
  normalizeExcept = basicOperations.normalizeExcept;
  split = basicOperations.split;
  closeTo = basicOperations.closeTo;

  bilinear2d = image.bilinear2d;
  bicubic2d = image.bicubic2d;
  nearest2d = image.nearest2d;
  toPixels = image.toPixels;

  toTextureAtUniformLocation = webgl.toTextureAtUniformLocation;
}
