import { isInteger } from "../../utils/utils";
import { Tensor } from "../tensor/tensor";
import {
  assignIntUniformLocation,
  createEmptyTexture,
  createFragmentShader,
  createFramebuffer,
  createProgram,
  createQuadAcrossCanvas,
  createVertexShader,
  loadTextureDataToBuffer,
  render,
  glUseProgram,
  assignFloatUniformLocation,
} from "./webglLowLevel";

// Running counter of texture indices handed out during one program call.
// NOTE(review): presumably this maps to a WebGL texture unit — confirm in
// webglLowLevel / Tensor.toTextureAtUniformLocation.
let _textureIndex = 0;

/** Returns the current texture index, then advances the counter by one. */
function getTextureIndex() {
  const current = _textureIndex;
  _textureIndex += 1;
  return current;
}

/** Rewinds the counter so the next allocation starts again at index 0. */
function resetTextureIndex() {
  _textureIndex = 0;
}

/**
 * Vertex shader source shared by every program in this file. It forwards the
 * 2-D `a_position` attribute unchanged as the clip-space position
 * (z = 0, w = 1).
 */
const vertexShaderSource = `#version 300 es
in vec2 a_position;
void main() {
  gl_Position = vec4(a_position, 0, 1);
}`;

// Created once at module load; createProgramFromShaderSource pairs it with
// each fragment shader.
const vertexShader = createVertexShader(vertexShaderSource);

/**
 * Row width used when a flat element index is mapped to 2-D texel
 * coordinates: the shader helpers compute `row = index / maxTextureSize` and
 * `col = index % maxTextureSize`, and result textures are at most this wide.
 * NOTE(review): this is a hard-coded value, not queried from the GL context —
 * confirm it does not exceed the device's MAX_TEXTURE_SIZE.
 */
export const maxTextureSize = 10000;
/**
 * Common prologue of every fragment shader in this file: GLSL ES 3.00 version
 * directive, high float precision, the `maxTextureSize` value exposed to GLSL
 * as a global int, and the single `outColor` output.
 */
const fragmentShaderHeader = `#version 300 es
precision highp float;
int maxTextureSize = ${maxTextureSize};
out vec4 outColor;
`;

/**
 * GLSL helpers that read one element of a tensor stored in a texture.
 *
 * Each helper flattens its indices in row-major order (the `xShapeN`
 * arguments are the sizes of the trailing dimensions; the leading dimension
 * is not needed), converts the flat index to texel coordinates with
 * `row = index / maxTextureSize`, `col = index % maxTextureSize`, and returns
 * the red channel of that texel fetched at mip level 0.
 */
const fragmentShaderUtils = `
float get1D(sampler2D x, int i) {
  int index = i;
  int row = index / maxTextureSize;
  int col = index % maxTextureSize;
  return texelFetch(x, ivec2(col, row), 0).r;
}
float get2D(sampler2D x, int xShape1, int i, int j) {
  int index = i * xShape1 + j;
  int row = index / maxTextureSize;
  int col = index % maxTextureSize;
  return texelFetch(x, ivec2(col, row), 0).r;
}
float get3D(sampler2D x, int xShape1, int xShape2, int i, int j, int k) {
  int index = i * xShape1 * xShape2 + j * xShape2 + k;
  int row = index / maxTextureSize;
  int col = index % maxTextureSize;
  return texelFetch(x, ivec2(col, row), 0).r;
}
float get4D(sampler2D x, int xShape1, int xShape2, int xShape3, int i, int j, int k, int l) {
  int index = i * xShape1 * xShape2 * xShape3 + j * xShape2 * xShape3 + k * xShape3 + l;
  int row = index / maxTextureSize;
  int col = index % maxTextureSize;
  return texelFetch(x, ivec2(col, row), 0).r;
}
`;

/**
 * Builds a program from the given fragment shader source and the shared
 * module-level vertex shader.
 *
 * @param fragmentShaderSource Complete GLSL source of the fragment shader.
 * @returns Whatever `createProgram` returns for the shader pair.
 */
function createProgramFromShaderSource(fragmentShaderSource: string) {
  return createProgram(
    vertexShader,
    createFragmentShader(fragmentShaderSource)
  );
}

/**
 * Makes `program` the active program and then sets up the quad geometry via
 * `createQuadAcrossCanvas`. The order matters to this function only in that
 * the program is selected first.
 * NOTE(review): presumably the quad covers the whole render target so every
 * output texel gets one fragment — confirm in webglLowLevel.
 *
 * @param program Program to activate.
 */
function prepareProgram(program: WebGLProgram) {
  glUseProgram(program);
  createQuadAcrossCanvas();
}

/**
 * Uploads every entry of `variables` to the uniform of the same name.
 *
 * - A `Tensor` is bound through `toTextureAtUniformLocation`, consuming the
 *   next texture index.
 * - A number goes to the int setter when `isInteger` accepts it, otherwise to
 *   the float setter.
 *
 * NOTE(review): the int/float choice is made from the value, not from the
 * uniform's declared GLSL type, so an integer-valued number aimed at a
 * `uniform float` (e.g. epsilon = 1) is sent through the int setter — confirm
 * callers never do that.
 *
 * @param program Program whose uniforms are assigned.
 * @param variables Map from uniform name to its value.
 * @throws Error if a value is neither a number nor a `Tensor`.
 */
function setProgramVariables(
  program: WebGLProgram,
  variables: { [key: string]: Tensor | number }
) {
  for (const name in variables) {
    const entry = variables[name];

    if (entry instanceof Tensor) {
      entry.toTextureAtUniformLocation(program, name, getTextureIndex());
      continue;
    }

    if (typeof entry !== "number") {
      throw new Error("Invalid variable type");
    }

    if (isInteger(entry)) {
      assignIntUniformLocation(program, name, entry);
    } else {
      assignFloatUniformLocation(program, name, entry);
    }
  }
}

/**
 * Renders the active program into a fresh texture large enough for
 * `resultSize` elements and reads the texture back.
 *
 * The texture is `min(maxTextureSize, resultSize)` texels wide and has as
 * many rows as needed, so it can hold more than `resultSize` texels when the
 * last row is only partly used.
 * NOTE(review): a `resultSize` of 0 gives a width of 0 and a NaN height —
 * confirm callers never pass an empty tensor.
 *
 * @param resultSize Number of output elements the program produces.
 * @returns The value returned by `loadTextureDataToBuffer` for the texture.
 */
function getResultBuffer(resultSize: number) {
  const cols = Math.min(maxTextureSize, resultSize);
  const rows = Math.ceil(resultSize / cols);

  const target = createEmptyTexture(cols, rows, getTextureIndex());
  createFramebuffer(target);
  render(cols, rows);

  return loadTextureDataToBuffer(target, cols, rows);
}

/**
 * Runs one program end to end: resets the texture index counter, activates
 * the program, uploads its uniforms/textures, then renders and reads back the
 * result.
 *
 * @param program Program to execute.
 * @param variables Map from uniform name to a number or a `Tensor`.
 * @param resultSize Number of output elements; determines the size of the
 *   result texture.
 * @returns The buffer produced by `getResultBuffer`.
 */
function callProgramWithVariables(
  program: WebGLProgram,
  variables: { [key: string]: Tensor | number },
  resultSize: number
) {
  // Texture indices restart at 0 for every call, so inputs are numbered in
  // iteration order of `variables` and the result texture gets the next one.
  resetTextureIndex();
  prepareProgram(program);
  setProgramVariables(program, variables);
  return getResultBuffer(resultSize);
}

/**
 * Fragment shader for a matrix product. `a` is read as a 2-D tensor with
 * `aShape1` columns and `b` as a 2-D tensor with `bShape1` columns.
 *
 * Each fragment derives its flat output index from `gl_FragCoord`, splits it
 * into `(outputI, outputJ)` using `bShape1`, and accumulates
 * `a[outputI, k] * b[k, outputJ]` for `k` in `[0, aShape1)`. The sum is
 * written to all four channels of `outColor`.
 * NOTE(review): there is no bounds check for fragments whose index lies past
 * the end of the output — confirm the padding texels are ignored downstream.
 */
const matmulShaderSource = `${fragmentShaderHeader}
uniform sampler2D a;
uniform sampler2D b;
uniform int aShape1;
uniform int bShape1;
${fragmentShaderUtils}
void main() {
  int index = int(gl_FragCoord.x) + int(gl_FragCoord.y) * maxTextureSize;
  int outputI = index / bShape1;
  int outputJ = index % bShape1;

  float sum = 0.0;
  for (int k = 0; k < aShape1; ++k) {
    float a_value = get2D(a, aShape1, outputI, k);
    float b_value = get2D(b, bShape1, k, outputJ);
    sum += a_value * b_value;
  }
  outColor = vec4(sum);
}`;
// Program is created once at module load and reused by matmulWebgl.
const matmulProgram = createProgramFromShaderSource(matmulShaderSource);

/**
 * Multiplies `weights` by `input` with the matmul program.
 *
 * `input` is treated as a single column (the shader's `bShape1` is fixed to
 * 1), and the inner dimension is taken from `weights.shape[1]`.
 *
 * @param weights Left operand; its second dimension is the reduction length.
 * @param input Right operand, read as a column.
 * @param outFeatures Number of output elements to compute.
 * @returns A new tensor of shape `[outFeatures]`.
 */
export function matmulWebgl(
  weights: Tensor,
  input: Tensor,
  outFeatures: number
) {
  const uniforms = {
    a: weights,
    b: input,
    aShape1: weights.shape[1],
    bShape1: 1,
  };
  const buffer = callProgramWithVariables(matmulProgram, uniforms, outFeatures);
  return new Tensor({ shape: [outFeatures], buffer });
}

/**
 * Fragment shader for element-wise addition: each fragment reads the element
 * at its flat index from `a` and from `b` and writes their sum to all four
 * channels of `outColor`.
 */
const addShaderSource = `${fragmentShaderHeader}
uniform sampler2D a;
uniform sampler2D b;
${fragmentShaderUtils}
void main() {
  int index = int(gl_FragCoord.x) + int(gl_FragCoord.y) * maxTextureSize;
  float a_value = get1D(a, index);
  float b_value = get1D(b, index);
  outColor = vec4(a_value + b_value);
}`;
// Program is created once at module load and reused by addWebgl.
const addProgram = createProgramFromShaderSource(addShaderSource);

/**
 * Adds two tensors element by element with the add program.
 *
 * The output takes its shape and element count from `a`.
 * NOTE(review): nothing here checks that `b` has the same size as `a` —
 * confirm callers guarantee it.
 *
 * @param a First operand; defines the output shape.
 * @param b Second operand.
 * @returns A new tensor with `a`'s shape.
 */
export function addWebgl(a: Tensor, b: Tensor) {
  const summed = callProgramWithVariables(addProgram, { a, b }, a.size);
  return new Tensor({ shape: a.shape, buffer: summed });
}

/**
 * Fragment shader for a 2-D transposed convolution with bias.
 *
 * `_input` is read as a 3-D tensor (`inputShape0` x `inputShape1` x
 * `inputShape2`), `weight` as a 4-D tensor indexed `[ic, oc, kh, kw]`, and
 * `bias` as a 1-D tensor indexed by output channel.
 *
 * Each fragment derives its flat output index from `gl_FragCoord`, writes 0
 * when that index lies past the end of the output, and otherwise decomposes it
 * into `(oc, oh, ow)`. Starting from `bias[oc]` it adds
 * `input[ic, ih, iw] * weight[ic, oc, kh, kw]` for every kernel tap whose
 * offset `oh + padding - kh` (resp. `ow + padding - kw`) is a non-negative
 * multiple of `stride` and maps to an in-range input row/column.
 *
 * Negative offsets are rejected BEFORE `%` and `/` are applied: GLSL ES 3.00
 * leaves `%` undefined when an operand is negative, so the earlier form
 * (divide first, test afterwards) relied on driver-specific behaviour.
 * The row test is hoisted out of the `kw` loop; the accumulation order
 * (ic, then kh, then kw) is unchanged.
 */
const convTranspose2dShaderSource = `${fragmentShaderHeader}
uniform sampler2D _input;
uniform sampler2D weight;
uniform sampler2D bias;
uniform int inputShape0;
uniform int inputShape1;
uniform int inputShape2;
uniform int weightShape1;
uniform int weightShape2;
uniform int weightShape3;
uniform int stride;
uniform int padding;
uniform int outputShape0;
uniform int outputShape1;
uniform int outputShape2;
${fragmentShaderUtils}
void main() {
  int index = int(gl_FragCoord.x) + int(gl_FragCoord.y) * maxTextureSize;
  if (index >= outputShape0 * outputShape1 * outputShape2) {
    outColor = vec4(0.0);
    return;
  }

  int oc = index / (outputShape1 * outputShape2);
  int oh = (index / outputShape2) % outputShape1;
  int ow = index % outputShape2;

  float sum = get1D(bias, oc);

  for (int ic = 0; ic < inputShape0; ++ic) {
    for (int kh = 0; kh < weightShape2; ++kh) {
      int hOffset = oh + padding - kh;
      if (hOffset < 0 || hOffset % stride != 0) continue;
      int ih = hOffset / stride;
      if (ih >= inputShape1) continue;

      for (int kw = 0; kw < weightShape3; ++kw) {
        int wOffset = ow + padding - kw;
        if (wOffset < 0 || wOffset % stride != 0) continue;
        int iw = wOffset / stride;
        if (iw >= inputShape2) continue;

        float inputVal = get3D(_input, inputShape1, inputShape2, ic, ih, iw);
        float weightVal = get4D(weight, weightShape1, weightShape2, weightShape3, ic, oc, kh, kw);
        sum += inputVal * weightVal;
      }
    }
  }

  outColor = vec4(sum);
}`;

// Program is created once at module load and reused by convTranspose2dWebgl.
const convTranspose2dProgram = createProgramFromShaderSource(
  convTranspose2dShaderSource
);

/**
 * Runs the transposed-convolution program and wraps the result in a tensor.
 *
 * The three input dimensions come from `input.shape[0..2]` and the last three
 * weight dimensions from `weight.shape[1..3]`; the output dimensions are
 * supplied by the caller rather than derived here.
 * NOTE(review): the shader declares `stride` and `padding` as int uniforms and
 * divides by `stride` — presumably callers pass a positive integer stride;
 * confirm.
 *
 * @param input Input tensor, read as 3-D.
 * @param weight Kernel tensor, read as 4-D.
 * @param bias Bias tensor, read per output channel.
 * @param stride Stride forwarded to the shader.
 * @param padding Padding forwarded to the shader.
 * @param outputShape Shape of the result; its product is the element count.
 * @returns A new tensor of shape `outputShape`.
 */
export function convTranspose2dWebgl(
  input: Tensor,
  weight: Tensor,
  bias: Tensor,
  stride: number,
  padding: number,
  outputShape: [number, number, number]
) {
  const [outChannels, outHeight, outWidth] = outputShape;
  // Key order is kept as-is: textures are numbered in iteration order.
  const uniforms = {
    _input: input,
    weight,
    bias,
    inputShape0: input.shape[0],
    inputShape1: input.shape[1],
    inputShape2: input.shape[2],
    weightShape1: weight.shape[1],
    weightShape2: weight.shape[2],
    weightShape3: weight.shape[3],
    stride,
    padding,
    outputShape0: outChannels,
    outputShape1: outHeight,
    outputShape2: outWidth,
  };
  const buffer = callProgramWithVariables(
    convTranspose2dProgram,
    uniforms,
    outChannels * outHeight * outWidth
  );
  return new Tensor({ shape: outputShape, buffer });
}

/**
 * Fragment shader for 2-D batch normalisation using stored statistics.
 *
 * Each fragment decomposes its flat index into `(ic, ih, iw)` using
 * `inputShape1` and `inputShape2`, reads the input element, and looks up
 * `weight`, `bias`, `runningMean` and `runningVariance` at channel `ic`.
 * It writes `weight * (x - mean) / sqrt(variance + epsilon) + bias` to all
 * four channels of `outColor`.
 * NOTE(review): there is no bounds check for fragments whose index lies past
 * the end of the input — confirm the padding texels are ignored downstream.
 */
const batchNorm2dShaderSource = `${fragmentShaderHeader}
uniform sampler2D _input;
uniform sampler2D weight;
uniform sampler2D bias;
uniform sampler2D runningMean;
uniform sampler2D runningVariance;
uniform float epsilon;
uniform int inputShape1;
uniform int inputShape2;
${fragmentShaderUtils}
void main() {
  int index = int(gl_FragCoord.x) + int(gl_FragCoord.y) * maxTextureSize;
  int ic = index / (inputShape1 * inputShape2);
  int ih = (index / inputShape2) % inputShape1;
  int iw = index % inputShape2;

  float x = get3D(_input, inputShape1, inputShape2, ic, ih, iw);
  float gamma = get1D(weight, ic);
  float beta = get1D(bias, ic);
  float _mean = get1D(runningMean, ic);
  float _variance = get1D(runningVariance, ic);
  float normalized = (x - _mean) / sqrt(_variance + epsilon);
  float outputValue = gamma * normalized + beta;
  outColor = vec4(outputValue);
}`;
// Program is created once at module load and reused by batchNorm2dWebgl.
const batchNorm2dProgram = createProgramFromShaderSource(
  batchNorm2dShaderSource
);

/**
 * Applies the batch-norm program to `input` using the supplied per-channel
 * parameters and statistics.
 *
 * The shader receives `input.shape[1]` and `input.shape[2]` to split a flat
 * index into channel/row/column.
 * NOTE(review): `epsilon` is bound to a `uniform float`, but the uploader
 * picks the int setter for integer-valued numbers — confirm epsilon is never
 * an integer such as 0 or 1.
 *
 * @param input Tensor to normalise, read as 3-D.
 * @param weight Per-channel scale.
 * @param bias Per-channel shift.
 * @param runningMean Per-channel mean.
 * @param runningVariance Per-channel variance.
 * @param epsilon Value added to the variance before the square root.
 * @returns A new tensor with `input`'s shape.
 */
export function batchNorm2dWebgl(
  input: Tensor,
  weight: Tensor,
  bias: Tensor,
  runningMean: Tensor,
  runningVariance: Tensor,
  epsilon: number
) {
  // Key order is kept as-is: textures are numbered in iteration order.
  const uniforms = {
    _input: input,
    weight,
    bias,
    runningMean,
    runningVariance,
    epsilon,
    inputShape1: input.shape[1],
    inputShape2: input.shape[2],
  };
  const buffer = callProgramWithVariables(
    batchNorm2dProgram,
    uniforms,
    input.size
  );
  return new Tensor({ shape: input.shape, buffer });
}

/**
 * Fragment shader for ReLU: each fragment reads the element at its flat index
 * from `_input` and writes `max(0.0, x)` to all four channels of `outColor`.
 */
const reluShaderSource = `${fragmentShaderHeader}
uniform sampler2D _input;
${fragmentShaderUtils}
void main() {
  int index = int(gl_FragCoord.x) + int(gl_FragCoord.y) * maxTextureSize;
  float x = get1D(_input, index);
  outColor = vec4(max(0.0, x));
}`;
// Program is created once at module load and reused by reluWebgl.
const reluProgram = createProgramFromShaderSource(reluShaderSource);

/**
 * Applies ReLU to every element of `input` with the relu program.
 *
 * @param input Tensor to transform.
 * @returns A new tensor with `input`'s shape.
 */
export function reluWebgl(input: Tensor) {
  const activated = callProgramWithVariables(
    reluProgram,
    { _input: input },
    input.size
  );
  return new Tensor({ shape: input.shape, buffer: activated });
}

/**
 * Fragment shader for the logistic sigmoid: each fragment reads the element at
 * its flat index from `_input` and writes `1 / (1 + exp(-x))` to all four
 * channels of `outColor`.
 */
const sigmoidShaderSource = `${fragmentShaderHeader}
uniform sampler2D _input;
${fragmentShaderUtils}
void main() {
  int index = int(gl_FragCoord.x) + int(gl_FragCoord.y) * maxTextureSize;
  float x = get1D(_input, index);
  outColor = vec4(1.0 / (1.0 + exp(-x)));
}`;
// Program is created once at module load and reused by sigmoidWebgl.
const sigmoidProgram = createProgramFromShaderSource(sigmoidShaderSource);

/**
 * Applies the logistic sigmoid to every element of `input` with the sigmoid
 * program.
 *
 * @param input Tensor to transform.
 * @returns A new tensor with `input`'s shape.
 */
export function sigmoidWebgl(input: Tensor) {
  const activated = callProgramWithVariables(
    sigmoidProgram,
    { _input: input },
    input.size
  );
  return new Tensor({ shape: input.shape, buffer: activated });
}
