import pako from "pako";
import { Tensor, TensorDevice } from "../jstorch/tensor/tensor";

/** In-memory cache of successfully fetched response bodies, keyed by request path. */
const fetchCache = new Map<string, ArrayBuffer>();

/**
 * Fetches `path` and resolves with the response body as an `ArrayBuffer`.
 * A body that was fetched successfully before is served from `fetchCache`
 * without issuing another request.
 *
 * @param path - URL or path handed to `fetch`.
 * @returns The raw response bytes.
 * @throws Error when the response status is not OK (outside 200-299). Failed
 *   responses are never cached, so a later call retries the request.
 */
async function fetchCached(path: string): Promise<ArrayBuffer> {
  // A Map lookup avoids false hits on inherited keys such as "constructor".
  const cached = fetchCache.get(path);
  if (cached !== undefined) {
    return cached;
  }

  const response = await fetch(path);
  if (!response.ok) {
    // Without this check an error page would be cached and later fed to the
    // decompressor as if it were valid data.
    throw new Error(
      `Failed to fetch ${path}: ${response.status} ${response.statusText}`
    );
  }

  const buffer = await response.arrayBuffer();
  fetchCache.set(path, buffer);
  return buffer;
}

/**
 * Retrieves the compressed resource at `path` (via the fetch cache) and
 * inflates it with pako into a string.
 *
 * @param path - Location of the compressed resource.
 * @returns The decompressed content as a string.
 */
export async function loadGzString(path: string) {
  const compressed = new Uint8Array(await fetchCached(path));
  const text = pako.inflate(compressed, { to: "string" });
  return text;
}

/**
 * Retrieves the compressed resource at `path` (via the fetch cache), inflates
 * it with pako into a string, and parses that string as JSON.
 *
 * @param path - Location of the compressed JSON resource.
 * @returns The value produced by `JSON.parse`; it is not validated here.
 */
export async function loadGzJson(path: string) {
  const compressed = new Uint8Array(await fetchCached(path));
  const jsonText = pako.inflate(compressed, { to: "string" });
  return JSON.parse(jsonText);
}

/**
 * Retrieves the compressed resource at `path` (via the fetch cache) and
 * inflates it with pako, leaving the result as raw bytes.
 *
 * @param path - Location of the compressed binary resource.
 * @returns The decompressed bytes.
 */
export async function loadGzBytes(path: string) {
  const compressed = new Uint8Array(await fetchCached(path));
  const inflated = pako.inflate(compressed);
  return inflated;
}

/**
 * Per-weight metadata read from the weights JSON; one entry per named weight.
 */
type WeightInfo = {
  // Tensor dimensions; their product is the number of quantized values that
  // belong to this weight, and the loaded tensor is reshaped to this shape.
  shape: number[];
  // Offset added as the last step of unquantization.
  zero_point: number;
  // Factor the quantized values are multiplied by during unquantization.
  scale: number;
  // Bit width of the quantized values; unquantization divides by 2 ** num_bits - 1.
  num_bits: number;
  // Reference values; not read by the code in this file.
  // Presumably the first five values before quantization — TODO confirm.
  first_5_original: number[];
  // Reference values; not read by the code in this file.
  // Presumably the first five quantized values — TODO confirm.
  first_5_quantized: number[];
  // Expected first values after unquantization; compared (tolerance 1e-5)
  // against the tensor produced at load time as a sanity check.
  first_5_unquantized: number[];
};

/**
 * Loads quantized model weights and returns them as unquantized tensors on
 * `tensorDevice`, keyed by weight name.
 *
 * The JSON resource maps each weight name to its `WeightInfo`; the binary
 * resource holds the quantized values of all weights back to back, one byte
 * per value, in the iteration order of the JSON keys.
 *
 * Consistency problems (total parameter count differing from the byte count,
 * or the first unquantized values differing from the reference values stored
 * in the JSON) are reported with `console.error`; loading still continues.
 *
 * @param jsonPath - Location of the compressed JSON metadata.
 * @param weightsPath - Location of the compressed quantized weight bytes.
 * @param tensorDevice - Device every resulting tensor is moved to.
 * @returns An object mapping each weight name to its unquantized tensor.
 */
export async function loadModelWeights(
  jsonPath: string,
  weightsPath: string,
  tensorDevice: TensorDevice
): Promise<{ [key: string]: Tensor }> {
  // Number of scalar values described by a shape (1 for an empty shape).
  const elementCount = (shape: number[]) => shape.reduce((a, b) => a * b, 1);

  // The two resources are independent, so fetch and inflate them concurrently.
  const [parsedJson, weightsBytes] = await Promise.all([
    loadGzJson(jsonPath),
    loadGzBytes(weightsPath),
  ]);
  const weightsJson: { [key: string]: WeightInfo } = parsedJson;
  const weightEntries = Object.entries(weightsJson);

  // first check if number of parameters match
  const numParamsJson = weightEntries.reduce(
    (total, [, info]) => total + elementCount(info.shape),
    0
  );
  const numParamsBytes = weightsBytes.length;
  if (numParamsJson !== numParamsBytes) {
    console.error(
      "Number of parameters do not match between json and bytes, check for errors in the quantization process"
    );
  }

  // create tensors from json and bytes
  const weightTensors: { [key: string]: Tensor } = {};
  let offset = 0;

  for (const [key, weightInfo] of weightEntries) {
    const numParams = elementCount(weightInfo.shape);
    // Take a view of this weight's bytes and convert only that view to a
    // number list, rather than boxing the entire weights blob up front.
    const quantizedWeights = Array.from(
      weightsBytes.subarray(offset, offset + numParams)
    );
    offset += numParams;

    const tensor = unquantizeWeights(
      Tensor.fromList(quantizedWeights),
      weightInfo.zero_point,
      weightInfo.scale,
      weightInfo.num_bits,
      weightInfo.shape
    );

    // check if first 5 values match between original reference from experiment ipynb and web unquantized tensor
    const first5 = tensor.getFloat32Array().slice(0, 5);
    const first5Match = first5.every(
      (v, i) => Math.abs(v - weightInfo.first_5_unquantized[i]) < 1e-5
    );
    if (!first5Match) {
      console.error(
        `First 5 values do not match for ${key}, check for errors in the quantization process`
      );
    }

    weightTensors[key] = tensor.toDevice(tensorDevice);
  }

  return weightTensors;
}

/**
 * Maps a tensor of quantized values back to real values.
 *
 * The input is reshaped to `shape`, then each element `q` becomes
 * `q * scale / (2 ** numBits - 1) + zeroPoint`, applied as three successive
 * scalar operations in that order.
 *
 * @param quantizedWeights - Tensor holding the quantized values.
 * @param zeroPoint - Offset added after scaling.
 * @param scale - Multiplicative factor applied first.
 * @param numBits - Bit width of the quantized values.
 * @param shape - Target shape of the resulting tensor.
 * @returns The tensor returned by the final scalar addition.
 */
function unquantizeWeights(
  quantizedWeights: Tensor,
  zeroPoint: number,
  scale: number,
  numBits: number,
  shape: number[]
) {
  const maxLevel = 2 ** numBits - 1;
  const shaped = quantizedWeights.reshape(shape);
  const scaled = shaped.mulScalar(scale);
  const normalized = scaled.divScalar(maxLevel);
  return normalized.addScalar(zeroPoint);
}
