import { useState, useEffect, useCallback, useRef } from "react";
import { useJsTorchWasmLoaded } from "../jstorch/wasm";
import Button from "./Button";
import Checkbox from "./Checkbox";
import { Tensor } from "../jstorch/tensor/tensor";

/**
 * Tailwind text colour for a slider's numeric readout:
 * red for negative values, black for exactly zero, green otherwise.
 */
function sliderTextColor(value: number) {
  if (value < 0) return "text-red-500";
  return value === 0 ? "text-black" : "text-green-500";
}

/**
 * Tailwind text colour for the displayed latent mean: red once the mean
 * is more than 0.05 away from zero, grey otherwise.
 */
function meanValueTextColor(value: number) {
  const drifted = Math.abs(value) > 0.05;
  return drifted ? "text-red-500" : "text-gray-500";
}

/**
 * Tailwind text colour for the displayed latent standard deviation: red once
 * it is more than 0.05 away from 1, grey otherwise.
 */
function stdValueTextColor(value: number) {
  const deviation = Math.abs(value - 1);
  return deviation > 0.05 ? "text-red-500" : "text-gray-500";
}

/**
 * Interactive editor for a latent vector: one range slider (-3..3) per
 * component, plus a readout of the vector's mean and standard deviation.
 *
 * @param latentSize - number of sliders rendered (indices 0..latentSize-1).
 *   NOTE(review): assumed to equal `initialLatent.length`; not checked here.
 * @param initialLatent - values the latent is (re)initialised from once the
 *   jstorch WASM module is loaded, and again whenever this prop's identity
 *   changes.
 * @param onLatentChange - called with every new latent tensor, including the
 *   initial one.
 */
export default function LatentSliders({
  latentSize,
  initialLatent,
  onLatentChange,
}: {
  latentSize: number;
  initialLatent: number[];
  onLatentChange: (latent: Tensor) => void;
}) {
  const jsTorchWasmLoaded = useJsTorchWasmLoaded();
  const [autoNormalize, setAutoNormalize] = useState(true);
  const [latent, setLatent] = useState<Tensor | null>(null);
  const [latentMean, setLatentMean] = useState<number | null>(null);
  const [latentStd, setLatentStd] = useState<number | null>(null);

  // Keep the latest callback in a ref instead of an effect dependency. If the
  // parent passes an inline arrow, depending on it directly would give
  // `changeLatent` a new identity every render. That would re-run the init
  // effect below, which wipes the user's edits back to `initialLatent`.
  const onLatentChangeRef = useRef(onLatentChange);
  useEffect(() => {
    onLatentChangeRef.current = onLatentChange;
  }, [onLatentChange]);

  // Publish a new latent and refresh the derived statistics shown below.
  const changeLatent = useCallback((newLatent: Tensor) => {
    onLatentChangeRef.current(newLatent);
    setLatentMean(newLatent.mean());
    setLatentStd(newLatent.std());
    setLatent(newLatent);
  }, []);

  useEffect(() => {
    if (!jsTorchWasmLoaded) return;
    changeLatent(Tensor.fromList(initialLatent).setName("latent"));
  }, [jsTorchWasmLoaded, initialLatent, changeLatent]);

  if (latent === null || latentMean === null || latentStd === null)
    return <p>Loading...</p>;

  // Apply a slider change. Copy first so the tensor currently held in React
  // state is never modified in place; with auto-normalize on, the other
  // components are adjusted via `normalizeExcept`.
  // NOTE(review): the meaning of the `10` argument is defined by jstorch.
  const updateComponent = (index: number, value: number) => {
    const updated = latent.copy().set1d(index, value);
    changeLatent(autoNormalize ? updated.normalizeExcept(index, 10) : updated);
  };

  // `toFixed` renders small negatives as "-0.00"; show a plain "0.00" instead,
  // without hiding genuinely non-zero rounded values such as "0.01".
  const roundedMean = latentMean.toFixed(2);
  const meanText = roundedMean === "-0.00" ? "0.00" : roundedMean;
  const isNormalized =
    Math.abs(latentMean) <= 0.05 && Math.abs(latentStd - 1) <= 0.05;

  return (
    <div>
      <h2 className="my-2">
        Modify the latent vector below to generate different variations of
        digits from 0 to 9
      </h2>
      <div className="grid grid-cols-4 gap-3">
        {Array.from({ length: latentSize }).map((_, i) => (
          <div
            className="flex flex-col items-center bg-gray-200 p-1 rounded shadow-md"
            key={i}
          >
            <input
              type="range"
              min={-3}
              max={3}
              step={0.01}
              value={latent.get1d(i)}
              className="w-full"
              onChange={(e) => updateComponent(i, parseFloat(e.target.value))}
            />
            <p className={sliderTextColor(latent.get1d(i)) + " text-sm"}>
              {latent.get1d(i).toFixed(2)}
            </p>
          </div>
        ))}
      </div>
      <p>
        Latent Mean:{" "}
        <b className={meanValueTextColor(latentMean)}>{meanText}</b>{" "}
        - Latent Std:{" "}
        <b className={stdValueTextColor(latentStd)}>{latentStd.toFixed(2)}</b>
      </p>
      {!isNormalized && (
        <p className="text-red-500 text-sm">
          Latent vector is not normalized. Normalized latent vectors generate
          better results
        </p>
      )}

      <div className="flex flex-row gap-2 justify-center mt-2">
        <Button
          onClick={() => {
            changeLatent(Tensor.normal([latentSize]).setName("latent"));
          }}
        >
          Random normal
        </Button>
        <Checkbox
          label="Auto normalize"
          checked={autoNormalize}
          onChange={(checked) => {
            // Normalize a copy so the state tensor is not modified in place
            // and the parent always receives a fresh tensor.
            if (checked) changeLatent(latent.copy().normalize());
            setAutoNormalize(checked);
          }}
        />
      </div>
    </div>
  );
}
