import { useEffect, useState } from "react";
import { fetchAndInstantiateWasm } from "../utils/wasmUtils";

/** Location of the compiled matmul WebAssembly module; SandboxWasm fetches it once on mount. */
const wasmPath = "/wasm/matmul.wasm";

/**
 * Renders `matrix` as a table of number inputs, optionally with controls to
 * edit cells and add/remove rows and columns.
 *
 * Every edit builds a brand-new matrix (the input arrays are never mutated)
 * and hands it to `setMatrix`, so a React state setter can be passed directly.
 * Cell values are displayed rounded to two decimals; the stored values are not
 * rounded.
 *
 * @param matrix - Rows of numbers; the column count is taken from the first row.
 * @param editable - When true, cells are editable and the row/column buttons
 *   are shown; when false, the inputs are read-only.
 * @param setMatrix - Receives the updated matrix. Required when `editable` is true.
 * @throws Error when `editable` is true and `setMatrix` is not provided.
 */
function matrixTable(
  matrix: number[][],
  editable: boolean = false,
  setMatrix?: (matrix: number[][]) => void
) {
  if (editable && !setMatrix) {
    throw new Error("setMatrix is required when editable is true");
  }

  const rowCount = matrix.length;
  // Optional chaining keeps an empty matrix from throwing here.
  const columnCount = matrix[0]?.length ?? 0;

  const handleInputChange = (i: number, j: number, value: string) => {
    if (!editable || !setMatrix) return;
    // Copy the edited row instead of writing through a shallow copy of the
    // outer array, which would mutate row arrays still owned by the caller.
    const newMatrix = matrix.map((row, r) =>
      r === i ? row.map((cell, c) => (c === j ? Number(value) : cell)) : row
    );
    setMatrix(newMatrix);
  };

  const addRow = () => {
    if (!setMatrix) return;
    setMatrix([...matrix, new Array<number>(columnCount).fill(0)]);
  };

  // Removal stops at one row / one column: the width is read from
  // `matrix[0]`, so an empty matrix would break adding rows and shape display.
  const removeRow = () => {
    if (!setMatrix || rowCount <= 1) return;
    setMatrix(matrix.slice(0, -1));
  };

  const addColumn = () => {
    if (!setMatrix) return;
    setMatrix(matrix.map((row) => [...row, 0]));
  };

  const removeColumn = () => {
    if (!setMatrix || columnCount <= 1) return;
    setMatrix(matrix.map((row) => row.slice(0, -1)));
  };

  return (
    <div>
      {editable && (
        <div className="grid grid-cols-2 gap-2 p-2">
          <button onClick={addRow}>Add row</button>
          <button onClick={removeRow} disabled={rowCount <= 1}>Remove row</button>
          <button onClick={addColumn}>Add column</button>
          <button onClick={removeColumn} disabled={columnCount <= 1}>Remove column</button>
        </div>
      )}
      <table className="border border-gray-300">
        <tbody>
          {matrix.map((row, i) => (
            <tr key={i} className="border-b border-gray-300">
              {row.map((cell, j) => (
                <td key={j} className="border-r border-gray-300 p-2">
                  <input
                    type="number"
                    value={Math.round(cell * 100) / 100}
                    readOnly={!editable}
                    onChange={(e) => handleInputChange(i, j, e.target.value)}
                    className="w-16 text-center"
                  />
                </td>
              ))}
            </tr>
          ))}
        </tbody>
      </table>
    </div>
  );
}

/**
 * A WebAssembly instance whose exports include the functions this page uses
 * from the matmul module.
 *
 * The shape is applied with a type assertion after instantiation and is not
 * checked at runtime, so it must be kept in sync with the compiled module.
 */
interface ExtendedWebAssemblyInstance extends WebAssembly.Instance {
  readonly exports: WebAssembly.Exports & {
    /** Called with a byte count; the page uses the result as an address in `memory`. */
    malloc: (size: number) => number;
    /** Called with addresses obtained from `malloc` and from `matmul`. */
    free: (ptr: number) => void;
    /**
     * Multiplies the float32 matrices stored at `aPtr` and `bPtr`. The page
     * treats the returned number as the address of the product in `memory`
     * and passes it to `free` afterwards.
     */
    matmul: (
      aPtr: number,
      bPtr: number,
      aHeight: number,
      aWidth: number,
      bHeight: number,
      bWidth: number
    ) => number;
    /** Linear memory that the page wraps in `Float32Array` views. */
    memory: WebAssembly.Memory;
  };
}

/**
 * Concatenates the rows of `arr` into one array in row-major order.
 *
 * Built in a single pass; the previous `reduce` + `concat` form copied the
 * whole accumulator for every row, which is quadratic in the element count.
 */
const flatten2DArray = (arr: number[][]): number[] => {
  const flat: number[] = [];
  for (const row of arr) {
    for (const value of row) {
      flat.push(value);
    }
  }
  return flat;
};

/**
 * Returns `[rows, columns]` for `matrix`.
 *
 * The column count comes from the first row; an empty matrix yields `[0, 0]`
 * instead of throwing a TypeError on `matrix[0].length`.
 */
const matrixShape = (matrix: number[][]): number[] => {
  return [matrix.length, matrix[0]?.length ?? 0];
};

/**
 * Number of columns in `matrix`, taken from its first row.
 * Returns 0 for an empty matrix instead of throwing.
 */
const matrixWidth = (matrix: number[][]): number => {
  return matrix[0]?.length ?? 0;
};

/** Number of rows in `matrix` (0 for an empty matrix). */
const matrixHeight = (matrix: number[][]): number => matrix.length;

/**
 * Total element count of a rectangular `matrix` (rows × columns, with the
 * column count taken from the first row). Returns 0 for an empty matrix
 * instead of throwing.
 */
const matrixSize = (matrix: number[][]): number => {
  return matrix.length * (matrix[0]?.length ?? 0);
};

/**
 * Sandbox page that multiplies two user-editable matrices A and B with the
 * matmul WebAssembly module and displays A @ B.
 *
 * The module is fetched once on mount. The product is recomputed whenever
 * either matrix, or the loaded instance, changes.
 */
export default function SandboxWasm() {
  // define matrices
  const [a, setA] = useState<number[][]>([[1, 2, 3, 4]]);
  const [b, setB] = useState<number[][]>([[5], [6], [7], [8]]);
  const [validMultiplication, setValidMultiplication] = useState<boolean>(true);

  const [result, setResult] = useState<number[][] | null>(null);
  const [matmulInstance, setMatmulInstance] =
    useState<ExtendedWebAssemblyInstance | null>(null);

  // Load matmul.wasm once. `cancelled` stops a load that resolves after
  // unmount (or after React StrictMode's dev-only effect re-run) from setting state.
  useEffect(() => {
    let cancelled = false;
    fetchAndInstantiateWasm(wasmPath, {
      env: {
        // Any attempt by the module to grow its heap fails loudly.
        emscripten_resize_heap: () => {
          throw new Error("emscripten_resize_heap was called");
        },
      },
    })
      .then((instance) => {
        if (!cancelled) {
          setMatmulInstance(instance as ExtendedWebAssemblyInstance);
        }
      })
      .catch((error: unknown) => {
        // Without this the rejection was unhandled and the page just never computed.
        console.error(`Failed to load ${wasmPath}:`, error);
      });
    return () => {
      cancelled = true;
    };
  }, []);

  const callMatmul = () => {
    if (!matmulInstance) return;
    const { malloc, free, matmul, memory } = matmulInstance.exports;

    // Read dimensions defensively so an empty matrix cannot throw here.
    const aRows = a.length;
    const aCols = a[0]?.length ?? 0;
    const bRows = b.length;
    const bCols = b[0]?.length ?? 0;

    // A (aRows x aCols) @ B (bRows x bCols) requires aCols === bRows;
    // zero-sized operands are rejected as well.
    if (aCols !== bRows || aRows === 0 || aCols === 0 || bCols === 0) {
      setValidMultiplication(false);
      setResult(null);
      return;
    }
    setValidMultiplication(true);

    const bytesPerFloat = Float32Array.BYTES_PER_ELEMENT;
    const aPtr = malloc(aRows * aCols * bytesPerFloat);
    const bPtr = malloc(bRows * bCols * bytesPerFloat);
    let resultPtr = 0;

    // try/finally so the wasm allocations are released even if matmul throws.
    try {
      if (aPtr === 0 || bPtr === 0) {
        throw new Error("malloc failed while allocating matmul inputs");
      }

      // Both views are created after both mallocs, reading `memory.buffer` fresh each time.
      new Float32Array(memory.buffer, aPtr, aRows * aCols).set(flatten2DArray(a));
      new Float32Array(memory.buffer, bPtr, bRows * bCols).set(flatten2DArray(b));

      // Argument order follows the declared signature
      // (aHeight, aWidth, bHeight, bWidth); the previous call passed each
      // matrix's width before its height.
      resultPtr = matmul(aPtr, bPtr, aRows, aCols, bRows, bCols);
      if (resultPtr === 0) {
        throw new Error("matmul returned a null pointer");
      }

      // Copy the row-major (aRows x bCols) product into plain JS arrays
      // before the wasm buffer is freed below.
      const resultView = new Float32Array(memory.buffer, resultPtr, aRows * bCols);
      const product = Array.from({ length: aRows }, (_, i) =>
        Array.from(resultView.subarray(i * bCols, (i + 1) * bCols))
      );
      setResult(product);
    } finally {
      if (aPtr !== 0) free(aPtr);
      if (bPtr !== 0) free(bPtr);
      if (resultPtr !== 0) free(resultPtr);
    }
  };

  const randomInt = (min: number, max: number) =>
    Math.floor(Math.random() * (max - min + 1) + min);

  // Builds a `rows` x `cols` matrix of random integers in [1, 10].
  const randomMatrix = (rows: number, cols: number) =>
    Array.from({ length: rows }, () =>
      Array.from({ length: cols }, () => randomInt(1, 10))
    );

  // Shapes (i x j) and (j x k) are always compatible for A @ B.
  const setRandomMatrices = () => {
    const [i, j, k] = [randomInt(1, 5), randomInt(1, 5), randomInt(1, 5)];
    setA(randomMatrix(i, j));
    setB(randomMatrix(j, k));
  };

  // Recompute the product whenever an input matrix or the wasm instance changes.
  useEffect(callMatmul, [a, b, matmulInstance]);

  return (
    <div>
      <button onClick={setRandomMatrices}>Set random matrices</button>
      <p>Matrix A (shape: {matrixShape(a).join("x")}):</p>
      {matrixTable(a, true, setA)}
      <p>Matrix B (shape: {matrixShape(b).join("x")}):</p>
      {matrixTable(b, true, setB)}
      {!validMultiplication && (
        <div className="text-red-500">Invalid matrices for multiplication</div>
      )}
      <div className="mt-8">
        {result && <div>Matrix A @ Matrix B: {matrixTable(result)}</div>}
      </div>
    </div>
  );
}
