import {
  Linear,
  Unflatten,
  ConvTranspose2d,
  BatchNorm2d,
  Relu,
  Sigmoid,
  Sequential,
  Layer,
} from "../../jstorch/layers/layers";
import { loadModelWeights } from "../../utils/modelLoading";
import { Tensor, TensorDevice } from "../../jstorch/tensor/tensor";

/**
 * Loads the pretrained VAE decoder weights, assembles the decoder network
 * from them, and hands the results to the supplied setters.
 *
 * Network, in application order:
 *   1. `fc`: Linear, 16 -> 512 features.
 *   2. `unflatten`: Unflatten to shape [512, 1, 1].
 *   3. `conv_t_block0` .. `conv_t_block4`. Each block is
 *      ConvTranspose2d (kernel 3, stride 2, padding 1, output padding 1)
 *      -> BatchNorm2d -> activation.
 *      - Channels narrow 512 -> 256 -> 128 -> 64 -> 32 -> 1.
 *      - Every block uses ReLU except the last, which uses a sigmoid.
 *
 * NOTE(review): under PyTorch `ConvTranspose2d` semantics these settings
 * double H and W per block (1x1 -> 32x32). Confirm jstorch matches.
 *
 * Every layer reads its parameters from the loaded weight map by string key,
 * e.g. `conv_t_block0.conv_t.weight` or `conv_t_block0.bn.running_var`.
 *
 * @param setModel - receives the assembled `Sequential` decoder.
 * @param setModelWeights - receives `Object.values` of the loaded weight map.
 * @param tensorDevice - forwarded to `loadModelWeights`.
 */
export async function loadModel(
  setModel: (model: Sequential) => void,
  setModelWeights: (weights: Tensor[]) => void,
  tensorDevice: TensorDevice
) {
  console.log("Loading model...");

  const params = await loadModelWeights(
    "./gz-model-weights/vae_decoder.json.gz",
    "./gz-model-weights/vae_decoder.gz",
    tensorDevice
  );

  // Width of the fully connected output, which becomes the channel count
  // of the first transposed-conv block after unflattening.
  const baseChannels = 512;

  // [inChannels, outChannels] for each transposed-conv block. The array
  // position is the block number used in weight keys and layer names.
  const channelPairs: Array<[number, number]> = [
    [baseChannels, 256],
    [256, 128],
    [128, 64],
    [64, 32],
    [32, 1],
  ];

  const layers: Layer[] = [
    new Linear(
      {
        inFeatures: 16,
        outFeatures: baseChannels,
        weights: params["fc.weight"],
        bias: params["fc.bias"],
      },
      "fc"
    ),
    new Unflatten({ shape: [baseChannels, 1, 1] }, "unflatten"),
  ];

  channelPairs.forEach(([inChannels, outChannels], blockIndex) => {
    const prefix = `conv_t_block${blockIndex}`;
    const convName = `${prefix}.conv_t`;
    const bnName = `${prefix}.bn`;
    // Only the final block squashes its output with a sigmoid.
    const isFinalBlock = blockIndex === channelPairs.length - 1;

    // Arguments are evaluated left to right, so the layers are constructed
    // in the same conv -> batch-norm -> activation order they run in.
    layers.push(
      new ConvTranspose2d(
        {
          inChannels,
          outChannels,
          kernelSize: 3,
          stride: 2,
          padding: 1,
          outputPadding: 1,
          weights: params[`${convName}.weight`],
          bias: params[`${convName}.bias`],
        },
        convName
      ),
      new BatchNorm2d(
        {
          numFeatures: outChannels,
          weights: params[`${bnName}.weight`],
          bias: params[`${bnName}.bias`],
          runningMean: params[`${bnName}.running_mean`],
          runningVariance: params[`${bnName}.running_var`],
        },
        bnName
      ),
      isFinalBlock
        ? new Sigmoid(`${prefix}.sigmoid`)
        : new Relu(`${prefix}.relu`)
    );
  });

  const decoder = new Sequential(layers, "decoder");

  console.log("Model loaded");
  setModel(decoder);
  setModelWeights(Object.values(params));
}
