import { jsTorchWasmInstance } from "../wasm";
import { flattenList, reshapeList } from "../listUtils";
import { Tensor, TensorDevice } from "./tensor";

/**
 * Inspection and rendering helpers that are called with a `Tensor` as `this`.
 * All methods here read values through the tensor's `wasmData`. Tensors
 * without it (the error messages describe these as WebGL tensors) are either
 * rejected or first moved to the WASM device.
 */
const visualization = {
  /**
   * Builds a one-line summary of the tensor, followed by its min/max values.
   *
   * @returns A string of the form
   *   `{{Tensor <name> shape: <shape> at byte:<pointer>}}\nmin: <min> max: <max>`,
   *   with min/max rounded to 4 decimals.
   * @throws Error if the tensor has no `wasmData`.
   */
  string(this: Tensor): string {
    const wasm = this.wasmData;
    if (!wasm)
      throw new Error("Tensor.string not implemented for WebGL tensors");

    const label = this.name ? " " + this.name : "";
    // min() is evaluated before max(), matching left-to-right template order.
    const minText = this.min().toFixed(4);
    const maxText = this.max().toFixed(4);
    return `{{Tensor${label} shape: ${this.shape} at byte:${wasm.pointer}}}\nmin: ${minText} max: ${maxText}`;
  },

  /**
   * Copies the tensor's float32 data into a plain nested JS array that
   * follows `this.shape`.
   *
   * @throws Error if the tensor has no `wasmData`.
   */
  toList(this: Tensor): any[] {
    const wasm = this.wasmData;
    if (!wasm)
      throw new Error("Tensor.toList not implemented for WebGL tensors");
    return reshapeList(Array.from(wasm.float32Array), this.shape) as any[];
  },

  /**
   * Logs a summary of the tensor (see `string`), then logs the first
   * `numValuesShown` flattened values rounded to 4 decimals.
   *
   * @param numValuesShown - How many leading values to log, or `"all"`.
   *   Negative counts are treated as 0.
   * @returns `this`, so calls can be chained.
   * @throws Error if the global WASM instance has not been set.
   */
  print(this: Tensor, numValuesShown: number | "all" = 10): Tensor {
    if (!jsTorchWasmInstance)
      throw new Error("globalJsTorchWasmInstance not set");
    // NOTE(review): if toDevice returns a fresh copy, it is not released here.
    // Confirm the ownership rules for toDevice.
    const wasmTensor = this.toDevice(TensorDevice.WASM);
    const data = flattenList(wasmTensor.toList());
    console.log(wasmTensor.string());

    // A negative count would make slice() drop values from the end instead.
    const count =
      numValuesShown === "all" ? data.length : Math.max(0, numValuesShown);
    console.log(
      data.slice(0, count).map((v) => Math.round(v * 10000) / 10000)
    );
    return this;
  },

  /**
   * Draws the tensor into a canvas via a 2D context. If `shape[0]` / `shape[1]`
   * do not match the canvas height / width, the tensor is first resampled with
   * the chosen upsampling method.
   *
   * Does nothing if `canvas` is null or has no 2D context.
   *
   * @param canvas - Target canvas, or null to skip rendering.
   * @param upsampling - Name of the Tensor method used to resize.
   */
  renderOnCanvas(
    this: Tensor,
    canvas: HTMLCanvasElement | null,
    upsampling: "nearest2d" | "bilinear2d" = "nearest2d"
  ): void {
    if (!canvas) return;
    // Acquire the context before allocating pixels. This way a missing
    // context cannot leave an unreleased pixel buffer behind.
    const ctx = canvas.getContext("2d");
    if (!ctx) return;

    const { width, height } = canvas;
    const needsResize = this.shape[0] !== height || this.shape[1] !== width;
    const resizedTensor: Tensor = needsResize
      ? this[upsampling]([height, width])
      : this;

    const [pixels, free] = resizedTensor.toPixels();
    try {
      ctx.putImageData(new ImageData(pixels, width, height), 0, 0);
    } finally {
      // Call the release callback from toPixels even if ImageData
      // construction or putImageData throws.
      free();
    }
  },
};

export default visualization;
