import { Tensor, TensorDevice } from "./tensor";

/**
 * Converts a multi-dimensional coordinate into a flat offset using the
 * row-major layout shared by the `getNd`/`setNd` accessors (the last
 * coordinate varies fastest).
 *
 * @param shape - Dimensions of the tensor being indexed.
 * @param indices - One coordinate per dimension.
 * @returns The offset of the addressed element in the flat backing array.
 * @throws Error `Tensor must be <N>D` when `indices.length !== shape.length`.
 * @throws Error `Index ... out of bounds` when a coordinate is not an integer
 *   in `[0, shape[d])`. NaN and fractional coordinates are rejected too,
 *   because a typed array reads them as `undefined` and silently drops writes.
 */
function toFlatIndex(
  shape: readonly number[],
  indices: readonly number[]
): number {
  if (shape.length !== indices.length)
    throw new Error(`Tensor must be ${indices.length}D`);

  let offset = 0;
  for (const [d, idx] of indices.entries()) {
    // `?? 0` is unreachable (ranks match above) but keeps the type `number`.
    const dim = shape[d] ?? 0;
    if (!Number.isInteger(idx) || idx < 0 || idx >= dim)
      throw new Error(
        indices.length === 1
          ? `Index ${idx} out of bounds`
          : `Index (${indices.join(", ")}) out of bounds`
      );
    // Horner form of i*s1*s2*... + j*s2*... + ... + last.
    offset = offset * dim + idx;
  }
  return offset;
}

/**
 * Reads `data[offset]`. It throws instead of returning `undefined` when the
 * backing array is shorter than the tensor's shape implies.
 */
function readAt(data: Float32Array, offset: number): number {
  const value = data[offset];
  if (value === undefined)
    throw new Error(
      `Offset ${offset} is outside tensor data of length ${data.length}`
    );
  return value;
}

/**
 * Writes `data[offset] = value`. It throws instead of silently dropping the
 * write when the backing array is shorter than the tensor's shape implies.
 */
function writeAt(data: Float32Array, offset: number, value: number): void {
  if (offset >= data.length)
    throw new Error(
      `Offset ${offset} is outside tensor data of length ${data.length}`
    );
  data[offset] = value;
}

const gettersSetters = {
  /**
   * Reinterprets the tensor with a new shape without touching its data.
   *
   * @param shape - New dimensions. It is copied, so later changes to the
   *   caller's array do not affect the tensor. `[]` denotes a scalar (size 1).
   * @returns This tensor, for chaining.
   * @throws Error if any dimension is not a non-negative integer, or if the
   *   product of `shape` differs from `this.size`.
   */
  _setShape(this: Tensor, shape: readonly number[]): Tensor {
    if (!shape.every((d) => Number.isInteger(d) && d >= 0))
      throw new Error(`invalid shape [${shape.join(", ")}] at Tensor._setShape`);
    // Initial value 1 makes the empty (scalar) shape valid instead of throwing.
    const newSize = shape.reduce((a, b) => a * b, 1);
    if (newSize !== this.size)
      throw new Error(
        `original size ${this.size} must match new size ${newSize} at Tensor._setShape`
      );
    this.shape = [...shape];
    return this;
  },

  /** Sets the tensor's `name` and returns this tensor for chaining. */
  setName(this: Tensor, name: string): Tensor {
    this.name = name;
    return this;
  },

  /**
   * Reads element `i` of a 1-D tensor.
   * @throws Error for tensors without `wasmData`, wrong rank, or out-of-bounds index.
   */
  get1d(this: Tensor, i: number): number {
    if (!this.wasmData)
      throw new Error("Tensor.get1d not implemented for WebGL tensors");
    return readAt(this.wasmData.float32Array, toFlatIndex(this.shape, [i]));
  },

  /**
   * Reads element `(i, j)` of a 2-D tensor (row-major).
   * @throws Error for tensors without `wasmData`, wrong rank, or out-of-bounds indices.
   */
  get2d(this: Tensor, i: number, j: number): number {
    if (!this.wasmData)
      throw new Error("Tensor.get2d not implemented for WebGL tensors");
    return readAt(this.wasmData.float32Array, toFlatIndex(this.shape, [i, j]));
  },

  /**
   * Reads element `(i, j, k)` of a 3-D tensor (row-major).
   * @throws Error for tensors without `wasmData`, wrong rank, or out-of-bounds indices.
   */
  get3d(this: Tensor, i: number, j: number, k: number): number {
    if (!this.wasmData)
      throw new Error("Tensor.get3d not implemented for WebGL tensors");
    return readAt(
      this.wasmData.float32Array,
      toFlatIndex(this.shape, [i, j, k])
    );
  },

  /**
   * Reads element `(i, j, k, l)` of a 4-D tensor (row-major).
   * @throws Error for tensors without `wasmData`, wrong rank, or out-of-bounds indices.
   */
  get4d(this: Tensor, i: number, j: number, k: number, l: number): number {
    if (!this.wasmData)
      throw new Error("Tensor.get4d not implemented for WebGL tensors");
    return readAt(
      this.wasmData.float32Array,
      toFlatIndex(this.shape, [i, j, k, l])
    );
  },

  /** True when the tensor has exactly one dimension. */
  is1d(this: Tensor): boolean {
    return this.numDims() === 1;
  },

  /** True when the tensor has exactly two dimensions. */
  is2d(this: Tensor): boolean {
    return this.numDims() === 2;
  },

  /** True when the tensor has exactly three dimensions. */
  is3d(this: Tensor): boolean {
    return this.numDims() === 3;
  },

  /** True when the tensor has exactly four dimensions. */
  is4d(this: Tensor): boolean {
    return this.numDims() === 4;
  },

  /**
   * Writes `value` at index `i` of a 1-D tensor.
   * @returns This tensor, for chaining.
   * @throws Error for tensors without `wasmData`, wrong rank, or an index that
   *   is not an integer in range.
   */
  set1d(this: Tensor, i: number, value: number): Tensor {
    if (!this.wasmData)
      throw new Error("Tensor.set1d not implemented for WebGL tensors");
    writeAt(this.wasmData.float32Array, toFlatIndex(this.shape, [i]), value);
    return this;
  },

  /**
   * Writes `value` at `(i, j)` of a 2-D tensor (row-major).
   * @returns This tensor, for chaining.
   * @throws Error for tensors without `wasmData`, wrong rank, or indices that
   *   are not integers in range.
   */
  set2d(this: Tensor, i: number, j: number, value: number): Tensor {
    if (!this.wasmData)
      throw new Error("Tensor.set2d not implemented for WebGL tensors");
    writeAt(this.wasmData.float32Array, toFlatIndex(this.shape, [i, j]), value);
    return this;
  },

  /**
   * Writes `value` at `(i, j, k)` of a 3-D tensor (row-major).
   * @returns This tensor, for chaining.
   * @throws Error for tensors without `wasmData`, wrong rank, or indices that
   *   are not integers in range.
   */
  set3d(this: Tensor, i: number, j: number, k: number, value: number): Tensor {
    if (!this.wasmData)
      throw new Error("Tensor.set3d not implemented for WebGL tensors");
    writeAt(
      this.wasmData.float32Array,
      toFlatIndex(this.shape, [i, j, k]),
      value
    );
    return this;
  },

  /**
   * Writes `value` at `(i, j, k, l)` of a 4-D tensor (row-major).
   * @returns This tensor, for chaining.
   * @throws Error for tensors without `wasmData`, wrong rank, or indices that
   *   are not integers in range.
   */
  set4d(
    this: Tensor,
    i: number,
    j: number,
    k: number,
    l: number,
    value: number
  ): Tensor {
    if (!this.wasmData)
      throw new Error("Tensor.set4d not implemented for WebGL tensors");
    writeAt(
      this.wasmData.float32Array,
      toFlatIndex(this.shape, [i, j, k, l]),
      value
    );
    return this;
  },

  /** True when the tensor's shape equals `shape` element-wise (same rank and dims). */
  isOfShape(this: Tensor, shape: readonly number[]): boolean {
    return (
      this.shape.length === shape.length &&
      this.shape.every((v, i) => v === shape[i])
    );
  },

  /** Number of dimensions, i.e. `shape.length` (0 for a scalar shape `[]`). */
  numDims(this: Tensor): number {
    return this.shape.length;
  },

  /**
   * Returns the backing `Float32Array` itself (not a copy), so writes through
   * it mutate the tensor.
   * @throws Error for tensors without `wasmData`.
   */
  getFloat32Array(this: Tensor): Float32Array {
    if (!this.wasmData)
      throw new Error(
        "Tensor.getFloat32Array not implemented for WebGL tensors"
      );
    return this.wasmData.float32Array;
  },

  /**
   * Replaces the backing `Float32Array` with `data`. The array is stored by
   * reference, not copied.
   * @returns This tensor, for chaining.
   * @throws Error for tensors without `wasmData`.
   */
  setFloat32Array(this: Tensor, data: Float32Array): Tensor {
    if (!this.wasmData)
      throw new Error(
        "Tensor.setFloat32Array not implemented for WebGL tensors"
      );
    // NOTE(review): `data.length` is not checked against the shape. A shorter
    // array makes the shape-based accessors throw. It is left unchecked in case
    // callers resize the data before calling _setShape; confirm the intended
    // contract.
    this.wasmData.float32Array = data;
    return this;
  },

  /** WASM when the tensor has `wasmData`, otherwise WEBGL. */
  getDevice(this: Tensor): TensorDevice {
    return this.wasmData ? TensorDevice.WASM : TensorDevice.WEBGL;
  },
};

export default gettersSetters;
