import { jsTorchWasmInstance, sizeOfFloat32 } from "../wasm";
import { maxTextureSize } from "../webgl/webgl";
import {
  cloneBuffer,
  createBufferFromArray,
  deleteBuffer,
  deleteTexture,
  padBuffer,
  readBufferToArray,
} from "../webgl/webglLowLevel";
import { Tensor, TensorDevice, nonFreedTensors } from "./tensor";

/**
 * Returns the wasm instance, or throws when it is not available.
 *
 * @param method name of the Tensor method being executed (used in the error)
 * @param tensorKind backend of the tensor the method was called on (used in the error)
 * @throws Error when `jsTorchWasmInstance` is falsy
 */
function requireWasmInstance(
  method: string,
  tensorKind: string
): NonNullable<typeof jsTorchWasmInstance> {
  if (!jsTorchWasmInstance)
    throw new Error(
      `cannot call Tensor.${method} on ${tensorKind} tensor without jsTorchWasmInstance`
    );
  return jsTorchWasmInstance;
}

/**
 * Memory-management methods for `Tensor`. Every method is written against
 * `this: Tensor`, so it must be invoked with a tensor as its receiver.
 */
const memory = {
  /**
   * Releases the storage backing this tensor and marks it as freed.
   *
   * For a wasm tensor the pointer is handed to the wasm `free` export; for a
   * WebGL tensor the texture (if any) and the buffer are deleted. The tensor is
   * then removed from `nonFreedTensors`.
   *
   * @throws Error if the tensor was already freed, if it is a wasm tensor and
   *   the wasm instance is unavailable, or if it has neither wasm nor WebGL data
   */
  free(this: Tensor) {
    if (this.freed) {
      throw new Error("Tensor already freed");
    }

    if (this.wasmData) {
      const wasm = requireWasmInstance("free", "wasm");
      wasm.exports.free(this.wasmData.pointer);
    } else if (this.webglData) {
      this.freeTexture();
      deleteBuffer(this.webglData.buffer);
    } else {
      throw new Error("tensor is neither wasm nor webgl tensor");
    }

    this.freed = true;

    // indexOf returns -1 for a tensor that is not tracked; splice(-1, 1) would
    // then remove the *last* tracked tensor, so only splice on a real hit.
    const trackedIndex = nonFreedTensors.indexOf(this);
    if (trackedIndex !== -1) {
      nonFreedTensors.splice(trackedIndex, 1);
    }
  },

  /**
   * Creates a new tensor on the same backend with a copy of this tensor's
   * data, the same shape and the same name.
   *
   * @returns the newly allocated tensor
   * @throws Error if the tensor was freed, if it is a wasm tensor and the wasm
   *   instance is unavailable, or if it has neither wasm nor WebGL data
   */
  copy(this: Tensor): Tensor {
    if (this.freed) {
      throw new Error("Cannot copy a freed tensor");
    }

    if (this.wasmData) {
      const wasm = requireWasmInstance("copy", "wasm");
      const pointer = wasm.exports.malloc(this.size * sizeOfFloat32);
      // NOTE(review): if malloc can grow wasm memory, a view created before the
      // call would be detached — confirm wasmData.float32Array is derived from
      // the current memory buffer when it is read here.
      const destination = new Float32Array(
        wasm.exports.memory.buffer,
        pointer,
        this.size
      );
      destination.set(this.wasmData.float32Array);
      return new Tensor({ shape: this.shape, pointer }).setName(this.name);
    }

    if (this.webglData) {
      const buffer = cloneBuffer(this.webglData.buffer);
      return new Tensor({ shape: this.shape, buffer }).setName(this.name);
    }

    throw new Error("Tensor.copy only implemented for wasm and webgl tensors");
  },

  /**
   * Uploads this tensor's data into a new WebGL buffer and returns a new
   * tensor backed by it. The receiver is left untouched.
   *
   * @returns a new WebGL tensor with the same shape
   * @throws Error if the tensor is already a WebGL tensor or was freed
   */
  toWebgl(this: Tensor): Tensor {
    if (this.webglData) throw new Error("Tensor already in WebGL format");
    // Same guard as copy(): do not read storage that was already released.
    if (this.freed) throw new Error("Cannot convert a freed tensor to WebGL");

    let buffer = createBufferFromArray(this.getFloat32Array());

    // Buffers holding more than maxTextureSize elements are padded up to the
    // next multiple of maxTextureSize; otherwise the textures won't be able
    // to be created.
    if (this.size > maxTextureSize) {
      const paddedSize = Math.ceil(this.size / maxTextureSize) * maxTextureSize;
      buffer = padBuffer(buffer, paddedSize);
    }

    return new Tensor({ shape: this.shape, buffer });
  },

  /**
   * Reads this WebGL tensor's buffer back into freshly allocated wasm memory
   * and returns a new tensor backed by it. The receiver is left untouched.
   *
   * @returns a new wasm tensor with the same shape
   * @throws Error if the tensor is already a wasm tensor, was freed, has no
   *   WebGL data, or if the wasm instance is unavailable
   */
  toWasm(this: Tensor): Tensor {
    if (this.wasmData) throw new Error("Tensor already in WASM format");
    // Same guard as copy(): do not read a buffer that was already deleted.
    if (this.freed) throw new Error("Cannot convert a freed tensor to WASM");

    if (this.webglData) {
      const wasm = requireWasmInstance("toWasm", "webgl");
      const pointer = wasm.exports.malloc(this.size * sizeOfFloat32);
      try {
        const destination = new Float32Array(
          wasm.exports.memory.buffer,
          pointer,
          this.size
        );
        readBufferToArray(this.webglData.buffer, destination);
      } catch (error) {
        // No tensor owns the allocation yet, so release it before rethrowing.
        wasm.exports.free(pointer);
        throw error;
      }
      return new Tensor({ shape: this.shape, pointer });
    }

    throw new Error("Tensor.toWasm not implemented for this tensor");
  },

  /**
   * Returns a tensor living on `device`: the receiver itself when it is
   * already there, otherwise a converted tensor.
   *
   * @param device target backend
   * @throws Error if `device` is neither `TensorDevice.WASM` nor
   *   `TensorDevice.WEBGL`, or if the conversion itself throws
   */
  toDevice(this: Tensor, device: TensorDevice): Tensor {
    switch (device) {
      case TensorDevice.WASM:
        return this.wasmData ? this : this.toWasm();
      case TensorDevice.WEBGL:
        return this.webglData ? this : this.toWebgl();
      default:
        throw new Error("Invalid device");
    }
  },

  /**
   * Deletes the texture attached to this WebGL tensor, if there is one, and
   * clears the texture and texture-unit references. The buffer is kept.
   *
   * @throws Error if the tensor has no WebGL data
   */
  freeTexture(this: Tensor) {
    if (!this.webglData) {
      throw new Error("Tensor.freeTexture not implemented for WASM tensors");
    }

    const { texture } = this.webglData;
    if (texture) {
      deleteTexture(texture);
      this.webglData.texture = null;
      this.webglData.textureUnit = null;
    }
  },
};

export default memory;
