import {
  normalize2d,
  flatten,
  reshapeArray,
  findMinAndMax1d,
} from "../../utils/listUtils";
import * as onnx from "onnxjs";

/**
 * Runs `session` on a single image and publishes the results through the
 * supplied setters.
 *
 * The image is normalized with `normalize2d`, flattened, and fed to the
 * session as one float32 tensor of shape `inputDims`. If building the tensor
 * or running the session fails, the error is logged and the function returns
 * without calling any setter (best-effort: previous state is left untouched).
 *
 * @param session - Loaded onnxjs inference session.
 * @param image - 2-D input image; its flattened size must match `inputDims`.
 * @param tensorNames - Output tensor names to read from the session result.
 * @param setMatrices - Receives, for every 4-D output tensor in `tensorNames`
 *   (in order, non-4-D tensors skipped), the matrices found at index 0 of its
 *   first axis.
 * @param setPrediction - Receives the flattened data of the last tensor in
 *   `tensorNames`, if that tensor is 2-D.
 * @param setLastLayerMinMax - Receives `[min, max]` over the data of
 *   `tensorNames[lastLayerIndex]`, if that tensor is 4-D.
 * @param lastLayerIndex - Index into `tensorNames` of the layer whose value
 *   range is reported. Out-of-range indices are ignored.
 * @param inputDims - Shape of the input tensor. Defaults to `[1, 1, 32, 32]`,
 *   the shape this function previously hard-coded.
 */
export async function predict(
  session: onnx.InferenceSession,
  image: number[][],
  tensorNames: string[],
  setMatrices: (matrices: number[][][][]) => void, // column, matrix index, image height, image width
  setPrediction: (prediction: number[]) => void,
  setLastLayerMinMax: (minMax: [number, number]) => void,
  lastLayerIndex: number,
  inputDims: readonly number[] = [1, 1, 32, 32]
): Promise<void> {
  const normalizedImage = normalize2d(image);

  let outputMap: onnx.InferenceSession.TensorsMapType;
  try {
    const onnxImage = new onnx.Tensor(
      flatten(normalizedImage),
      "float32",
      [...inputDims]
    );
    outputMap = await session.run([onnxImage]);
  } catch (error) {
    // Deliberate best-effort: log and leave all setters uncalled.
    console.error(error);
    return;
  }

  // Looks up an output tensor by a possibly-missing name (e.g. an
  // out-of-range index into `tensorNames`) without passing `undefined`
  // to `Map.get`.
  const tensorAt = (name: string | undefined) =>
    name === undefined ? undefined : outputMap.get(name);

  // get intermediate activations
  const newMatrices: number[][][][] = [];
  for (const tensorName of tensorNames) {
    const onnxTensor = outputMap.get(tensorName);
    // Only rank-4 tensors are rendered; others are skipped, so indices of
    // `newMatrices` need not line up with indices of `tensorNames`.
    if (onnxTensor?.dims.length !== 4) {
      continue;
    }
    const numberList = reshapeArray(
      onnxTensor.data as Float32Array,
      [...onnxTensor.dims]
    ) as number[][][][];
    // Keep every matrix at index 0 of the first axis (presumably the single
    // batch entry — TODO confirm). Shallow copy of the outer list.
    newMatrices.push([...(numberList[0] ?? [])]);
  }
  setMatrices(newMatrices);

  // get last layer min and max
  const lastLayerTensor = tensorAt(tensorNames[lastLayerIndex]);
  if (lastLayerTensor?.dims.length === 4) {
    const lastLayerList = Array.from(lastLayerTensor.data as Float32Array);
    const [min, max] = findMinAndMax1d(lastLayerList);
    setLastLayerMinMax([min, max]);
  }

  // get prediction values from the final named tensor
  const predictionTensor = tensorAt(tensorNames[tensorNames.length - 1]);
  if (predictionTensor?.dims.length === 2) {
    setPrediction(Array.from(predictionTensor.data as Float32Array));
  }
}
