import { jsTorchWasmInstance } from "../wasm";
import { Tensor } from "./tensor";

/** Names of the wasm exports that resample a 2D tensor to a new 2D shape. */
type Resize2dKernel = "bilinear2d" | "bicubic2d" | "nearest2d";

/**
 * Shared implementation behind `image.bilinear2d`, `image.bicubic2d` and
 * `image.nearest2d`: validates the tensor, invokes the wasm export named by
 * `kernel`, and wraps the returned pointer in a new Tensor.
 *
 * The wasm export receives, in order: the source data pointer, the two source
 * dimensions (`shape[0]`, `shape[1]`) and the two requested output dimensions.
 *
 * NOTE(review): the requested dimensions are forwarded to wasm unvalidated
 * (no integer / sign check) — confirm the kernels tolerate that or validate
 * at the call sites.
 *
 * @param tensor - Tensor to resample; must carry `wasmData` and be 2D.
 * @param kernel - Name of the wasm export to call; also used in error messages.
 * @param outputShape - Requested shape of the result.
 * @returns A new Tensor built from the pointer returned by the wasm kernel.
 * @throws Error if the tensor has no `wasmData`, if `jsTorchWasmInstance` is
 *   not set, or if the tensor is not 2D.
 */
function resize2d(
  tensor: Tensor,
  kernel: Resize2dKernel,
  outputShape: [number, number]
): Tensor {
  if (!tensor.wasmData)
    throw new Error(`Tensor.${kernel} not implemented for WebGL tensors`);
  if (!jsTorchWasmInstance)
    throw new Error(
      `cannot call Tensor.${kernel} on wasm tensor without jsTorchWasmInstance`
    );
  if (!tensor.is2d())
    throw new Error("Tensor must be 2D but got " + tensor.shape);

  // Copy the requested shape so that a caller mutating its array afterwards
  // cannot change the shape recorded on the returned tensor (the Tensor
  // constructor is not visible here, so do not rely on it copying).
  const shape: [number, number] = [outputShape[0], outputShape[1]];

  const pointer = jsTorchWasmInstance.exports[kernel](
    tensor.wasmData.pointer,
    tensor.shape[0],
    tensor.shape[1],
    shape[0],
    shape[1]
  );
  return new Tensor({ shape, pointer });
}

/**
 * Image-related operations written as functions with an explicit
 * `this: Tensor` parameter, i.e. they must be invoked with a Tensor as the
 * receiver. All of them currently require a wasm-backed tensor (`wasmData`
 * set) and a loaded `jsTorchWasmInstance`.
 */
const image = {
  /**
   * Resamples this 2D tensor to `outputShape` using the wasm `bilinear2d`
   * export.
   *
   * @param outputShape - Requested shape of the result.
   * @returns A new Tensor wrapping the pointer returned by the kernel.
   * @throws Error if the tensor has no `wasmData`, if `jsTorchWasmInstance`
   *   is not set, or if the tensor is not 2D.
   */
  bilinear2d(this: Tensor, outputShape: [number, number]): Tensor {
    return resize2d(this, "bilinear2d", outputShape);
  },

  /**
   * Resamples this 2D tensor to `outputShape` using the wasm `bicubic2d`
   * export.
   *
   * @param outputShape - Requested shape of the result.
   * @returns A new Tensor wrapping the pointer returned by the kernel.
   * @throws Error if the tensor has no `wasmData`, if `jsTorchWasmInstance`
   *   is not set, or if the tensor is not 2D.
   */
  bicubic2d(this: Tensor, outputShape: [number, number]): Tensor {
    return resize2d(this, "bicubic2d", outputShape);
  },

  /**
   * Resamples this 2D tensor to `outputShape` using the wasm `nearest2d`
   * export.
   *
   * @param outputShape - Requested shape of the result.
   * @returns A new Tensor wrapping the pointer returned by the kernel.
   * @throws Error if the tensor has no `wasmData`, if `jsTorchWasmInstance`
   *   is not set, or if the tensor is not 2D.
   */
  nearest2d(this: Tensor, outputShape: [number, number]): Tensor {
    return resize2d(this, "nearest2d", outputShape);
  },

  /**
   * Converts this 2D tensor to a byte buffer via the wasm `tensor2dToPixels`
   * export.
   *
   * The returned array is a view directly over wasm memory (no copy),
   * starting at the pointer returned by the kernel and spanning
   * `this.size * 4` bytes — presumably four channel bytes (RGBA) per element;
   * confirm against the wasm kernel. Because it is a view, it must not be
   * used after `free` has been called. NOTE(review): a later growth of wasm
   * memory would also detach the underlying buffer; copy the bytes out if
   * they need to outlive further wasm allocations.
   *
   * @returns A tuple of the pixel view and a `free` callback that releases
   *   the wasm allocation. `free` is safe to call more than once; only the
   *   first successful call releases the memory.
   * @throws Error if the tensor has no `wasmData`, if `jsTorchWasmInstance`
   *   is not set, if the tensor is 3D (not implemented), or if it is neither
   *   2D nor 3D.
   */
  toPixels(this: Tensor): [Uint8ClampedArray, () => void] {
    if (!this.wasmData)
      throw new Error("Tensor.toPixels not implemented for WebGL tensors");
    if (!jsTorchWasmInstance)
      throw new Error(
        "cannot call Tensor.toPixels on wasm tensor without jsTorchWasmInstance"
      );
    if (!this.is2d()) {
      if (this.is3d()) throw new Error("3D tensor to pixels not implemented");
      throw new Error("Tensor must be 2D or 3D");
    }

    const pointer = jsTorchWasmInstance.exports.tensor2dToPixels(
      this.wasmData.pointer,
      this.shape[0],
      this.shape[1]
    );

    // Guard against double-free: handing the same pointer to the wasm
    // allocator twice would corrupt its state.
    let released = false;
    const free = () => {
      if (released) return;
      // Re-checked here because the closure may run long after toPixels
      // returned, when the module-level instance could have been cleared.
      if (!jsTorchWasmInstance)
        throw new Error(
          "Error freeing tensor pixels, jsTorchWasmInstance is null"
        );
      jsTorchWasmInstance.exports.free(pointer);
      released = true;
    };

    // memory.buffer is read only after the kernel call above: if that call
    // grew wasm memory, a buffer object fetched earlier would be detached.
    const pixels = new Uint8ClampedArray(
      jsTorchWasmInstance.exports.memory.buffer,
      pointer,
      this.size * 4
    );
    return [pixels, free];
  },
};

export default image;
