/**
 * Guard that rejects anything whose nesting depth is not exactly two.
 *
 * The depth is whatever `getListShape` reports, i.e. it is derived by
 * following the first element at each level only.
 *
 * @param input - Value expected to be an array of arrays.
 * @throws Error if the reported shape does not have exactly two dimensions.
 */
function assert2dArray(input: number[][]): void {
  const dims = getListShape(input);
  if (dims.length === 2) {
    return;
  }
  throw new Error(`Expected a 2D array, but got shape: ${dims.join("x")}`);
}

/**
 * Resizes a 2D grid to `targetWidth` x `targetHeight` by box averaging.
 *
 * Every source cell `(i, j)` is assigned to the output bucket
 * `(floor(i / yRatio), floor(j / xRatio))` and each bucket holds the mean of
 * the cells assigned to it. The source width is taken from the first row.
 *
 * When a target dimension is larger than the source dimension some buckets
 * receive no source cell at all; those buckets are filled with the nearest
 * source sample instead of the `0 / 0 = NaN` a plain average would produce.
 *
 * @param input - Source grid (rows of numbers).
 * @param targetWidth - Number of columns in the result; must be > 0.
 * @param targetHeight - Number of rows in the result; must be > 0.
 * @returns A new grid; `input` is not modified.
 * @throws Error if `input` is not reported as 2D by `assert2dArray`, or if a
 *   target dimension is not a number greater than 0.
 */
export function downsampleArray2d(
  input: number[][],
  targetWidth: number,
  targetHeight: number
): number[][] {
  assert2dArray(input);
  // Written as !(x > 0) so that NaN is rejected here as well, rather than
  // failing later with an unrelated TypeError.
  if (!(targetWidth > 0) || !(targetHeight > 0)) {
    throw new Error(
      "Target width and height must be greater than 0, but got: targetWidth=" +
        targetWidth +
        ", targetHeight=" +
        targetHeight
    );
  }

  const sourceHeight = input.length;
  const sourceWidth = input[0].length;
  const xRatio = sourceWidth / targetWidth;
  const yRatio = sourceHeight / targetHeight;

  // Running sums (turned into means below) and per-bucket sample counts.
  const output: number[][] = Array.from({ length: targetHeight }, () =>
    new Array<number>(targetWidth).fill(0)
  );
  const counts: number[][] = Array.from({ length: targetHeight }, () =>
    new Array<number>(targetWidth).fill(0)
  );

  // Scatter every source cell into the bucket it falls in.
  for (let i = 0; i < sourceHeight; i++) {
    const y = Math.floor(i / yRatio);
    for (let j = 0; j < input[i].length; j++) {
      const x = Math.floor(j / xRatio);
      output[y][x] += input[i][j];
      counts[y][x]++;
    }
  }

  for (let y = 0; y < targetHeight; y++) {
    for (let x = 0; x < targetWidth; x++) {
      const samples = counts[y][x];
      if (samples > 0) {
        output[y][x] /= samples;
        continue;
      }
      // Empty bucket (target finer than source along some axis): use the
      // nearest source sample. If the source row has no cells there is
      // nothing to sample, so the bucket stays NaN.
      const sourceRow = input[Math.min(Math.floor(y * yRatio), sourceHeight - 1)];
      const sourceX = Math.min(Math.floor(x * xRatio), sourceRow.length - 1);
      output[y][x] = sourceX >= 0 ? sourceRow[sourceX] : NaN;
    }
  }

  return output;
}

/**
 * Scans a list once and reports its smallest and largest value.
 *
 * @param input - Values to scan.
 * @returns `[min, max]`; for an empty list this is `[Infinity, -Infinity]`.
 */
export function findMinAndMax1d(input: number[]): [number, number] {
  let lowest = Infinity;
  let highest = -Infinity;

  for (let index = 0; index < input.length; index++) {
    const current = input[index];
    lowest = Math.min(lowest, current);
    highest = Math.max(highest, current);
  }

  const extremes: [number, number] = [lowest, highest];
  return extremes;
}

/**
 * Scans every cell of a grid and reports the smallest and largest value.
 *
 * @param input - Rows of values; rows may have different lengths.
 * @returns `[min, max]`; `[Infinity, -Infinity]` when there are no cells.
 */
export function findMinAndMax2d(input: number[][]): [number, number] {
  let lowest = Infinity;
  let highest = -Infinity;

  for (let r = 0; r < input.length; r++) {
    const cells = input[r];
    for (let c = 0; c < cells.length; c++) {
      const current = cells[c];
      lowest = Math.min(lowest, current);
      highest = Math.max(highest, current);
    }
  }

  const extremes: [number, number] = [lowest, highest];
  return extremes;
}

/**
 * Computes the mean and the population standard deviation (divide by N) of
 * all cells in a grid.
 *
 * The deviation is computed in a second pass as the mean of squared
 * differences from the mean. The single-pass `E[x^2] - mean^2` form loses
 * precision when the mean is large relative to the spread and can even go
 * slightly negative, which would make the square root NaN.
 *
 * @param input - Rows of values; rows may have different lengths.
 * @returns `[mean, std]`; both are NaN when the grid has no cells.
 */
function findMeanAndStd2d(input: number[][]): [number, number] {
  let sum = 0;
  let count = 0;

  for (const row of input) {
    for (const value of row) {
      sum += value;
      count++;
    }
  }

  const mean = sum / count;

  let squaredDeviationSum = 0;
  for (const row of input) {
    for (const value of row) {
      const deviation = value - mean;
      squaredDeviationSum += deviation * deviation;
    }
  }

  // A sum of squares is never negative, so the square root is always defined
  // (0 / 0 for an empty grid still yields NaN, as documented).
  const std = Math.sqrt(squaredDeviationSum / count);

  return [mean, std];
}

/**
 * Standardizes a grid: subtracts the mean of all cells and divides by their
 * standard deviation, both as reported by `findMeanAndStd2d`.
 *
 * If the deviation is zero or NaN (for example, every cell holds the same
 * value) the division is skipped, so constant input maps to all zeros rather
 * than to NaN.
 *
 * @param input - Rows of values.
 * @returns A new grid in which each row has the same length as the
 *   corresponding input row; `input` is not modified.
 */
export function normalize2d(input: number[][]): number[][] {
  const [mean, std] = findMeanAndStd2d(input);

  // `std > 0` is false for both 0 and NaN; dividing by 1 keeps the centered
  // values finite in those cases.
  const divisor = std > 0 ? std : 1;

  // Mapping row by row keeps each output row exactly as long as its source
  // row instead of sizing every row from the first one.
  return input.map((row) => row.map((value) => (value - mean) / divisor));
}

/**
 * Concatenates the rows of a grid into one list, in row-major order.
 *
 * @param input - Rows of values; rows may have different lengths.
 * @returns A new flat array containing every cell.
 */
export function flatten(input: number[][]): number[] {
  const flat: number[] = [];

  for (let r = 0; r < input.length; r++) {
    const cells = input[r];
    for (let c = 0; c < cells.length; c++) {
      flat.push(cells[c]);
    }
  }

  return flat;
}

/**
 * Converts a flat typed array into nested plain arrays with the given shape,
 * reading the data in row-major order (last dimension varies fastest).
 *
 * Intermediate levels are addressed through `subarray` views, so the buffer
 * is copied only once, when the innermost rows are materialized. The result
 * never aliases `data`.
 *
 * If `data` holds fewer elements than the shape requires, trailing rows come
 * out shorter or empty. An empty `dims` yields an empty array.
 *
 * @param data - Flat source values.
 * @param dims - Size of each dimension, outermost first.
 * @returns Nested arrays of numbers, `dims.length` levels deep.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- nesting depth is only known at runtime
export function reshapeArray(data: Float32Array, dims: number[]): any {
  function build(view: Float32Array, shape: number[]): unknown[] {
    if (shape.length === 1) {
      // Innermost level: copy the values out into a plain number[].
      return Array.from(view.subarray(0, shape[0]));
    }

    const [outer, ...inner] = shape;
    // Number of flat elements covered by one entry of the outer dimension.
    const stride = inner.reduce((a, b) => a * b, 1);
    const nested: unknown[] = [];
    for (let i = 0; i < outer; i++) {
      nested.push(build(view.subarray(i * stride, (i + 1) * stride), inner));
    }
    return nested;
  }

  return build(data, dims);
}

/**
 * Loads an image through an `Image` element, paints it onto a canvas of the
 * same size and reads the pixel buffer back as nested arrays.
 *
 * @param src - Value assigned to the image element's `src`.
 * @returns Rows of pixels indexed `[y][x]`; each pixel holds the four
 *   consecutive values stored for it in the canvas image data (RGBA order
 *   per the Canvas `ImageData` API).
 * @throws Error if no 2D canvas context can be obtained. The returned promise
 *   also rejects with whatever the image's `onerror` handler receives.
 */
export async function loadImageAndTransformToPixels(
  src: string
): Promise<number[][][]> {
  const image = new Image();
  image.src = src;
  // Suspend until the element reports success or failure.
  await new Promise((onLoaded, onFailed) => {
    image.onload = onLoaded;
    image.onerror = onFailed;
  });

  const canvas = document.createElement("canvas");
  const context = canvas.getContext("2d");
  if (!context) {
    throw new Error("Could not get canvas context");
  }

  const width = image.width;
  const height = image.height;
  canvas.width = width;
  canvas.height = height;
  context.drawImage(image, 0, 0, width, height);
  const rgba = context.getImageData(0, 0, width, height).data;

  // The buffer is row-major with four values per pixel, so a single running
  // offset visits the pixels in the same order as (y * width + x) * 4.
  const rows: number[][][] = [];
  let offset = 0;
  for (let y = 0; y < height; y++) {
    const row: number[][] = [];
    for (let x = 0; x < width; x++) {
      row.push([
        rgba[offset],
        rgba[offset + 1],
        rgba[offset + 2],
        rgba[offset + 3],
      ]);
      offset += 4;
    }
    rows.push(row);
  }
  return rows;
}

/**
 * Reports the dimensions of a nested array by descending through the first
 * element of each level until a non-array value is reached.
 *
 * Only the first element at each level is inspected, so ragged input is not
 * detected. A non-array argument yields `[]`; an empty array yields `[0]`.
 *
 * @param list - Possibly nested array.
 * @returns The length found at each nesting level, outermost first.
 */
function getListShape(list: any[]): number[] {
  const shape: number[] = [];
  let level: unknown = list;
  while (Array.isArray(level)) {
    shape.push(level.length);
    level = level[0];
  }
  return shape;
}

/**
 * Resamples an image (`[y][x][channel]`) to a new resolution with separable
 * bicubic convolution (Keys kernel, a = -0.5, also known as Catmull-Rom).
 *
 * Target index `t` along an axis samples the source at `t * source / target`.
 * Each output pixel is a weighted sum of the surrounding 4x4 source pixels;
 * taps that fall outside the image are clamped to the nearest edge pixel.
 * When a sample lands exactly on a source pixel that pixel is reproduced
 * unchanged. Because the kernel has negative lobes, results near sharp edges
 * can fall slightly outside the range of the input values.
 *
 * The number of channels of each output pixel is taken from the source pixel
 * nearest to the sample position.
 *
 * @param image - Source pixels indexed `[y][x][channel]`; the width is taken
 *   from the first row.
 * @param targetResolution - `[targetWidth, targetHeight]`.
 * @returns A new image of the requested size, or `image` itself (same
 *   reference) when it already has that size.
 * @throws TypeError if `image` has no rows, since the width is read from the
 *   first row.
 */
// ts-unused-exports:disable-next-line
export function bicubicInterpolation3d(
  image: number[][][],
  targetResolution: [number, number]
): number[][][] {
  const [targetWidth, targetHeight] = targetResolution;
  const sourceWidth = image[0].length;
  const sourceHeight = image.length;

  if (sourceWidth === targetWidth && sourceHeight === targetHeight) {
    return image;
  }

  type AxisTap = { indices: number[]; weights: number[] };

  // Keys cubic convolution kernel with a = -0.5. The weights of the four
  // taps around any sample position sum to 1.
  const cubicWeight = (distance: number): number => {
    const a = -0.5;
    const d = Math.abs(distance);
    if (d <= 1) {
      return (a + 2) * d * d * d - (a + 3) * d * d + 1;
    }
    if (d < 2) {
      return a * (d * d * d - 5 * d * d + 8 * d - 4);
    }
    return 0;
  };

  // For every target index along one axis, precompute the four source
  // indices (offsets -1..2 around the sample, clamped to the image) and
  // their kernel weights. These depend on one axis only, so they are built
  // once instead of once per output pixel.
  const buildTaps = (targetSize: number, sourceSize: number): AxisTap[] => {
    const scale = sourceSize / targetSize;
    const taps: AxisTap[] = [];
    for (let t = 0; t < targetSize; t++) {
      const position = t * scale;
      const base = Math.floor(position);
      const fraction = position - base;
      const indices: number[] = [];
      const weights: number[] = [];
      for (let offset = -1; offset <= 2; offset++) {
        indices.push(Math.min(Math.max(base + offset, 0), sourceSize - 1));
        weights.push(cubicWeight(fraction - offset));
      }
      taps.push({ indices, weights });
    }
    return taps;
  };

  const columnTaps = buildTaps(targetWidth, sourceWidth);
  const rowTaps = buildTaps(targetHeight, sourceHeight);
  const output: number[][][] = [];

  for (const rowTap of rowTaps) {
    const outputRow: number[][] = [];
    for (const columnTap of columnTaps) {
      // Index 1 is the tap at offset 0, i.e. the pixel at floor(position).
      const channelCount =
        image[rowTap.indices[1]][columnTap.indices[1]].length;
      const newPixel = new Array<number>(channelCount).fill(0);

      for (let r = 0; r < 4; r++) {
        const sourceRow = image[rowTap.indices[r]];
        for (let c = 0; c < 4; c++) {
          const weight = rowTap.weights[r] * columnTap.weights[c];
          if (weight === 0) {
            continue; // contributes nothing; skip the channel loop
          }
          const sourcePixel = sourceRow[columnTap.indices[c]];
          for (let channel = 0; channel < channelCount; channel++) {
            newPixel[channel] += sourcePixel[channel] * weight;
          }
        }
      }

      outputRow.push(newPixel);
    }
    output.push(outputRow);
  }

  return output;
}

/**
 * Collapses each pixel to a single value: the unweighted mean of its first
 * three channels. Any further channels are ignored.
 *
 * @param image - Pixels indexed `[y][x][channel]`; the number of columns
 *   processed in every row is taken from the first row.
 * @returns A new grid indexed `[y][x]` holding one value per pixel.
 */
export function toGrayscale(image: number[][][]): number[][] {
  return Array.from({ length: image.length }, (_row, y) =>
    Array.from({ length: image[0].length }, (_cell, x) => {
      const channels = image[y][x];
      return (channels[0] + channels[1] + channels[2]) / 3;
    })
  );
}
