import { jsTorchWasmInstance, sizeOfFloat32 } from "../wasm";
import { Tensor } from "./tensor";

/**
 * Returns the initialised wasm instance, throwing the standard error for
 * Tensor operation `op` when it has not been set.
 */
function requireWasm(op: string): NonNullable<typeof jsTorchWasmInstance> {
  if (!jsTorchWasmInstance)
    throw new Error(
      `cannot call Tensor.${op} on wasm tensor without jsTorchWasmInstance`
    );
  return jsTorchWasmInstance;
}

/**
 * Returns the wasm heap pointer of `tensor`'s data for Tensor operation `op`.
 *
 * @throws if the tensor has been freed (its wasm memory may already be
 *   released), or if it has no wasm data — the WebGL path of these
 *   operations is not implemented.
 */
function wasmPointer(tensor: Tensor, op: string): number {
  if (tensor.freed)
    throw new Error(`Cannot call Tensor.${op} on a freed tensor`);
  if (!tensor.wasmData)
    throw new Error(`Tensor.${op} not implemented for WebGL tensors`);
  return tensor.wasmData.pointer;
}

/**
 * Element-wise, reduction and shape operations for wasm-backed tensors,
 * written with an explicit `this: Tensor` so they can be attached to Tensor.
 *
 * Every tensor returned here receives its own copy of the shape array, so an
 * in-place shape change (`unsqueeze` / `squeeze`) on one tensor can never
 * leak into another tensor that was built from the same array.
 */
const basicOperations = {
  /**
   * Returns a copy of this tensor with its shape set to `shape`; `this` is
   * left untouched (unlike `unsqueeze` / `squeeze`, which modify `this`).
   * NOTE(review): element-count validation presumably happens in
   * `_setShape` — confirm.
   *
   * @throws if this tensor has been freed
   */
  reshape(this: Tensor, shape: number[]): Tensor {
    if (this.freed) {
      throw new Error("Cannot reshape a freed tensor");
    }
    const newTensor = this.copy();
    // Copy so later mutation of the caller's array cannot alter the result.
    newTensor._setShape([...shape]);
    return newTensor;
  },

  /**
   * Inserts a size-1 axis at `dim`, modifying this tensor's shape in place,
   * and returns `this`. A negative `dim` counts from the end, so `-1`
   * appends a trailing axis.
   *
   * @throws if `dim` is not an integer in `[-(numDims() + 1), numDims()]`
   */
  unsqueeze(this: Tensor, dim: number): Tensor {
    const rank = this.numDims();
    const index = dim < 0 ? dim + rank + 1 : dim;
    if (!Number.isInteger(dim) || index < 0 || index > rank) {
      throw new Error(`Invalid dimension ${dim}`);
    }
    // Install a fresh array through _setShape (as squeeze does) rather than
    // splicing this.shape, which another tensor may reference.
    this._setShape([
      ...this.shape.slice(0, index),
      1,
      ...this.shape.slice(index),
    ]);
    return this;
  },

  /** Removes every size-1 axis from this tensor's shape in place; returns `this`. */
  squeeze(this: Tensor): Tensor {
    this._setShape(this.shape.filter((v) => v !== 1));
    return this;
  },

  /** Returns the result of the wasm `min` export over all `size` elements. */
  min(this: Tensor): number {
    const pointer = wasmPointer(this, "min");
    return requireWasm("min").exports.min(pointer, this.size);
  },

  /** Returns the result of the wasm `max` export over all `size` elements. */
  max(this: Tensor): number {
    const pointer = wasmPointer(this, "max");
    return requireWasm("max").exports.max(pointer, this.size);
  },

  /**
   * Element-wise `this + tensor` via the wasm `add` export.
   * @throws if devices differ, either tensor is freed/non-wasm, or shapes differ
   */
  add(this: Tensor, tensor: Tensor): Tensor {
    if (this.getDevice() !== tensor.getDevice())
      throw new Error("Cannot add tensors of different devices");
    const lhs = wasmPointer(this, "add");
    const rhs = wasmPointer(tensor, "add");
    const wasm = requireWasm("add");
    if (!this.isOfShape(tensor.shape))
      throw new Error("Tensor shapes must match");
    const pointer = wasm.exports.add(lhs, rhs, this.size);
    return new Tensor({ shape: [...this.shape], pointer });
  },

  /** Returns a new tensor from the wasm `addScalar(pointer, scalar, size)` export. */
  addScalar(this: Tensor, scalar: number): Tensor {
    const source = wasmPointer(this, "addScalar");
    const pointer = requireWasm("addScalar").exports.addScalar(
      source,
      scalar,
      this.size
    );
    return new Tensor({ shape: [...this.shape], pointer });
  },

  /**
   * Element-wise `this - tensor` via the wasm `sub` export.
   * @throws if devices differ, either tensor is freed/non-wasm, or shapes differ
   */
  sub(this: Tensor, tensor: Tensor): Tensor {
    if (this.getDevice() !== tensor.getDevice())
      throw new Error("Cannot subtract tensors of different devices");
    const lhs = wasmPointer(this, "sub");
    const rhs = wasmPointer(tensor, "sub");
    const wasm = requireWasm("sub");
    if (!this.isOfShape(tensor.shape))
      throw new Error("Tensor shapes must match");
    const pointer = wasm.exports.sub(lhs, rhs, this.size);
    return new Tensor({ shape: [...this.shape], pointer });
  },

  /** Returns a new tensor from the wasm `subScalar(pointer, scalar, size)` export. */
  subScalar(this: Tensor, scalar: number): Tensor {
    const source = wasmPointer(this, "subScalar");
    const pointer = requireWasm("subScalar").exports.subScalar(
      source,
      scalar,
      this.size
    );
    return new Tensor({ shape: [...this.shape], pointer });
  },

  /**
   * Element-wise `this * tensor` via the wasm `mul` export.
   * @throws if devices differ, either tensor is freed/non-wasm, or shapes differ
   */
  mul(this: Tensor, tensor: Tensor): Tensor {
    if (this.getDevice() !== tensor.getDevice())
      throw new Error("Cannot multiply tensors of different devices");
    const lhs = wasmPointer(this, "mul");
    const rhs = wasmPointer(tensor, "mul");
    const wasm = requireWasm("mul");
    if (!this.isOfShape(tensor.shape))
      throw new Error("Tensor shapes must match");
    const pointer = wasm.exports.mul(lhs, rhs, this.size);
    return new Tensor({ shape: [...this.shape], pointer });
  },

  /** Returns a new tensor from the wasm `mulScalar(pointer, scalar, size)` export. */
  mulScalar(this: Tensor, scalar: number): Tensor {
    const source = wasmPointer(this, "mulScalar");
    const pointer = requireWasm("mulScalar").exports.mulScalar(
      source,
      scalar,
      this.size
    );
    return new Tensor({ shape: [...this.shape], pointer });
  },

  /**
   * Element-wise `this / tensor` via the wasm `div_` export.
   * @throws if devices differ, either tensor is freed/non-wasm, or shapes differ
   */
  div(this: Tensor, tensor: Tensor): Tensor {
    if (this.getDevice() !== tensor.getDevice())
      throw new Error("Cannot divide tensors of different devices");
    const lhs = wasmPointer(this, "div");
    const rhs = wasmPointer(tensor, "div");
    const wasm = requireWasm("div");
    if (!this.isOfShape(tensor.shape))
      throw new Error("Tensor shapes must match");
    const pointer = wasm.exports.div_(lhs, rhs, this.size);
    return new Tensor({ shape: [...this.shape], pointer });
  },

  /** Returns a new tensor from the wasm `divScalar(pointer, scalar, size)` export. */
  divScalar(this: Tensor, scalar: number): Tensor {
    const source = wasmPointer(this, "divScalar");
    const pointer = requireWasm("divScalar").exports.divScalar(
      source,
      scalar,
      this.size
    );
    return new Tensor({ shape: [...this.shape], pointer });
  },

  /** Returns the result of the wasm `norm` export over all `size` elements. */
  norm(this: Tensor): number {
    const pointer = wasmPointer(this, "norm");
    return requireWasm("norm").exports.norm(pointer, this.size);
  },

  /**
   * Returns a new tensor from the wasm `threshold` export.
   * NOTE(review): unlike the *Scalar ops the value is passed after `size`
   * — `(pointer, size, threshold)`; confirm against the wasm signature.
   */
  threshold(this: Tensor, threshold: number): Tensor {
    const source = wasmPointer(this, "threshold");
    const pointer = requireWasm("threshold").exports.threshold(
      source,
      this.size,
      threshold
    );
    return new Tensor({ shape: [...this.shape], pointer });
  },

  /** Returns the result of the wasm `std_` export over all `size` elements. */
  std(this: Tensor): number {
    const pointer = wasmPointer(this, "std");
    return requireWasm("std").exports.std_(pointer, this.size);
  },

  /** Returns the result of the wasm `mean` export over all `size` elements. */
  mean(this: Tensor): number {
    const pointer = wasmPointer(this, "mean");
    return requireWasm("mean").exports.mean(pointer, this.size);
  },

  /**
   * Returns a new tensor from the wasm `normalize(pointer, size)` export; the
   * normalisation formula itself lives in the wasm module.
   */
  normalize(this: Tensor): Tensor {
    const source = wasmPointer(this, "normalize");
    const pointer = requireWasm("normalize").exports.normalize(
      source,
      this.size
    );
    return new Tensor({ shape: [...this.shape], pointer });
  },

  /**
   * Returns a new tensor from the wasm
   * `normalizeExcept(pointer, size, exceptIndex, iterations)` export; the
   * meaning of `exceptIndex` / `iterations` is defined by the wasm module.
   */
  normalizeExcept(
    this: Tensor,
    exceptIndex: number,
    iterations: number
  ): Tensor {
    const source = wasmPointer(this, "normalizeExcept");
    const pointer = requireWasm("normalizeExcept").exports.normalizeExcept(
      source,
      this.size,
      exceptIndex,
      iterations
    );
    return new Tensor({ shape: [...this.shape], pointer });
  },

  /**
   * Splits an N-dimensional tensor along its first dimension into
   * `shape[0]` tensors of shape `shape.slice(1)`, each copied into its own
   * `malloc`-ed wasm buffer.
   *
   * @throws if the tensor is 0-D or 1-D, freed, or not wasm-backed
   */
  split(this: Tensor): Tensor[] {
    const source = wasmPointer(this, "split");
    const wasm = requireWasm("split");
    if (this.shape.length === 0) throw new Error("Cannot split a 0D tensor");
    if (this.is1d()) throw new Error("Cannot split a 1D tensor");

    // The default is never used (0-D was rejected above); it only types count.
    const [count = 0, ...innerShape] = this.shape;
    const chunkLength = innerShape.reduce((acc, d) => acc * d, 1);
    const chunkBytes = chunkLength * sizeOfFloat32;

    const parts: Tensor[] = [];
    for (let i = 0; i < count; i++) {
      const pointer = wasm.exports.malloc(chunkBytes);
      // Fetch memory.buffer only after malloc: if malloc grows the wasm
      // memory, ArrayBuffers obtained before the growth are detached.
      const buffer = wasm.exports.memory.buffer;
      new Float32Array(buffer, pointer, chunkLength).set(
        new Float32Array(buffer, source + i * chunkBytes, chunkLength)
      );
      // Each part gets its own shape array (see note on the object).
      parts.push(new Tensor({ shape: [...innerShape], pointer }));
    }
    return parts;
  },

  /**
   * Returns whether the wasm `closeTo(lhs, rhs, size, threshold)` export
   * reports the two tensors as close.
   *
   * @throws if devices differ, either tensor is freed/non-wasm, or shapes differ
   */
  closeTo(this: Tensor, tensor: Tensor, threshold: number): boolean {
    if (this.getDevice() !== tensor.getDevice())
      throw new Error(
        `Cannot compare tensors of different devices at Tensor.closeTo`
      );
    const lhs = wasmPointer(this, "closeTo");
    const rhs = wasmPointer(tensor, "closeTo");
    const wasm = requireWasm("closeTo");
    if (!this.isOfShape(tensor.shape))
      throw new Error(
        `Tensor shapes must match at Tensor.closeTo, but got ${this.shape} and ${tensor.shape}`
      );

    // Raw wasm exports hand i32 results back as JS numbers (0 / 1), so
    // coerce to honour the declared boolean return type.
    return Boolean(wasm.exports.closeTo(lhs, rhs, this.size, threshold));
  },
};

export default basicOperations;
