import { jsTorchWasmInstance } from "../wasm";
import { Tensor, TensorDevice } from "../tensor/tensor";
import {
  addWebgl,
  batchNorm2dWebgl,
  convTranspose2dWebgl,
  matmulWebgl,
  reluWebgl,
  sigmoidWebgl,
} from "../webgl/webgl";

// Development toggle for debugWasmWebGLEquality: when enabled, every layer whose
// forward() calls that helper is run on both the WASM and the WebGL backend and the
// two outputs are compared and logged.
// The leading `false &&` is a manual switch — change it to `true` to enable. Because
// `&&` short-circuits, `process.env.NODE_ENV` is not evaluated here while it is `false`.
const DEBUG_WASM_WEBGL_EQUALITY: boolean =
  false && process.env.NODE_ENV === "development";
if (DEBUG_WASM_WEBGL_EQUALITY) {
  // Emitted once when the module is loaded, because the cross-check is expensive.
  console.warn(
    "ATTENTION: DEBUG_WASM_WEBGL_EQUALITY is enabled. This will slow down the application significantly as it will compare the results of WASM and WebGL implementations of layers."
  );
}
// Development toggle for Sequential.forwardWithIntermediates: when enabled it clears
// the console and logs each layer's wall-clock time. The leading `true &&` is a manual
// switch; set it to `false` to silence the timing output in development builds.
const SHOW_SEQUENTIAL_TIMINGS: boolean =
  true && process.env.NODE_ENV === "development";

/**
 * Development-only cross-check of a layer's two backends.
 *
 * When `DEBUG_WASM_WEBGL_EQUALITY` is enabled, runs `self.forwardWASM` and
 * `self.forwardWebGL` on `input` moved to each device, compares the results with
 * `Tensor.closeTo` (tolerance 0.001) and logs whether they match. When the flag is
 * disabled this function returns immediately.
 *
 * The properties named in `parametersNames` are reassigned on `self` while the
 * comparison runs and are moved back to the input's original device afterwards —
 * also when one of the forward passes throws, so the layer is never left on the
 * wrong backend.
 *
 * @param self - layer under test
 * @param input - input tensor; its current device is where parameters are restored to
 * @param parametersNames - names of Tensor-valued properties of `self` to move
 * @param name - label included in the log messages
 */
function debugWasmWebGLEquality(
  self: Layer,
  input: Tensor,
  parametersNames: string[],
  name: string
) {
  if (!DEBUG_WASM_WEBGL_EQUALITY) return;

  const initialDevice = input.getDevice();

  // Parameters are addressed by property name, so they are read and written through
  // Reflect instead of indexing `self` under a @ts-ignore.
  try {
    // move parameters and input to wasm and get the result
    for (const parameterName of parametersNames) {
      try {
        const tensor: Tensor = Reflect.get(self, parameterName);
        Reflect.set(self, parameterName, tensor.toDevice(TensorDevice.WASM));
      } catch (e: unknown) {
        const message = e instanceof Error ? e.message : String(e);
        console.error(
          `Error moving ${parameterName} to WASM at ${name}: ${message}`
        );
      }
    }
    const wasmInput = input.toDevice(TensorDevice.WASM);
    const wasmResult = self.forwardWASM(wasmInput).toDevice(TensorDevice.WASM);

    // move parameters and input to webgl and get the result
    for (const parameterName of parametersNames) {
      const tensor: Tensor = Reflect.get(self, parameterName);
      Reflect.set(self, parameterName, tensor.toDevice(TensorDevice.WEBGL));
    }
    const webglInput = wasmInput.toDevice(TensorDevice.WEBGL);
    const webglResult = self
      .forwardWebGL(webglInput)
      .toDevice(TensorDevice.WASM);

    if (!webglResult.closeTo(wasmResult, 0.001)) {
      console.warn("WebGL and WASM results do not match at " + name);
      console.warn("WebGL result");
      webglResult.print("all");
      console.warn("WASM result");
      // Print both results the same way so they can be compared side by side.
      wasmResult.print("all");
      console.warn("-----------------");
    } else {
      console.warn("WebGL and WASM results match at " + name);
    }
  } finally {
    // move parameters back to initial device, even if a forward pass threw
    for (const parameterName of parametersNames) {
      const tensor: Tensor = Reflect.get(self, parameterName);
      Reflect.set(self, parameterName, tensor.toDevice(initialDevice));
    }
  }
}

/**
 * Common base of every layer in this module.
 *
 * A concrete layer validates its parameter shapes in `assertCorrectShapes`, exposes a
 * backend-specific implementation for WASM (`forwardWASM`) and WebGL
 * (`forwardWebGL`), and a `forward` entry point.
 */
export abstract class Layer {
  /**
   * @param name - identifier of this layer; `Sequential` uses it to key the
   *   intermediate outputs it collects.
   */
  constructor(public name: string) {}

  /** Throws if the layer's parameters do not have the expected shapes. */
  abstract assertCorrectShapes(): void;

  /** Runs the layer on `input`. */
  abstract forward(input: Tensor): Tensor;

  /** Runs the layer on the WASM backend. */
  abstract forwardWASM(input: Tensor): Tensor;

  /** Runs the layer on the WebGL backend. */
  abstract forwardWebGL(input: Tensor): Tensor;
}

/**
 * Chains a list of layers, feeding each one's output into the next.
 *
 * Only `forwardWithIntermediates` is supported. `forward`, `forwardWASM` and
 * `forwardWebGL` throw.
 */
export class Sequential extends Layer {
  layers: Layer[];

  constructor(layers: Layer[], name: string) {
    super(name);
    this.layers = layers;
    this.assertCorrectShapes();
  }

  /**
   * Nothing to check at this level. The layer classes in this module validate their
   * own parameters in their constructors.
   */
  assertCorrectShapes(): void {}

  forward(input: Tensor): Tensor {
    throw new Error("Use forwardWithIntermediates instead");
  }

  forwardWASM(input: Tensor): Tensor {
    throw new Error("Not implemented");
  }

  forwardWebGL(input: Tensor): Tensor {
    throw new Error("Not implemented");
  }

  /**
   * Runs every layer in order and records each layer's output.
   *
   * @param input - tensor fed to the first layer
   * @returns map from `"after_" + layer.name` to that layer's output tensor
   *   (layers sharing a name overwrite each other's entry)
   */
  forwardWithIntermediates(input: Tensor): { [key: string]: Tensor } {
    const logTimings = SHOW_SEQUENTIAL_TIMINGS;
    if (logTimings) {
      console.clear();
    }

    const intermediates: Record<string, Tensor> = {};
    let current = input;
    for (const layer of this.layers) {
      const startedAt = performance.now();
      current = layer.forward(current);
      intermediates[`after_${layer.name}`] = current;
      const elapsed = performance.now() - startedAt;
      if (logTimings) {
        console.log(
          `Layer ${layer.constructor.name} took ${elapsed.toFixed(2)}ms`
        );
      }
    }
    return intermediates;
  }
}

/**
 * Fully connected layer computing `weights · input + bias` for a 1-D input.
 *
 * `weights` must have shape [outFeatures, inFeatures] and `bias` shape
 * [outFeatures]. Both are checked in the constructor.
 */
export class Linear extends Layer {
  inFeatures: number;
  outFeatures: number;
  weights: Tensor;
  bias: Tensor;

  constructor(
    {
      inFeatures,
      outFeatures,
      weights,
      bias,
    }: {
      inFeatures: number;
      outFeatures: number;
      weights: Tensor;
      bias: Tensor;
    },
    name: string
  ) {
    super(name);
    this.inFeatures = inFeatures;
    this.outFeatures = outFeatures;
    this.weights = weights;
    this.bias = bias;

    this.assertCorrectShapes();
  }

  /** @throws Error if `weights` or `bias` do not match the declared feature counts. */
  assertCorrectShapes() {
    if (!this.weights.isOfShape([this.outFeatures, this.inFeatures])) {
      throw new Error(
        `Weights shape mismatch: ${this.weights.shape} != [${this.outFeatures}, ${this.inFeatures}]`
      );
    }
    if (!this.bias.isOfShape([this.outFeatures])) {
      throw new Error(
        `Bias shape mismatch: ${this.bias.shape} != [${this.outFeatures}]`
      );
    }
  }

  /** Dispatches to the backend that currently holds `input`'s data. */
  forward(input: Tensor): Tensor {
    debugWasmWebGLEquality(this, input, ["weights", "bias"], "Linear.forward");

    if (input.wasmData) {
      return this.forwardWASM(input);
    } else if (input.webglData) {
      return this.forwardWebGL(input);
    } else throw new Error("Input tensor not in WASM or WebGL format");
  }

  /**
   * WASM implementation: a matmul followed by a bias add, each allocating a new buffer.
   *
   * @throws Error if the WASM instance is missing, a tensor is not in WASM format,
   *   the input is not 1-D, or a WASM kernel reports failure (returns pointer 0).
   */
  forwardWASM(input: Tensor): Tensor {
    if (!jsTorchWasmInstance)
      throw new Error(
        "cannot call Linear.forwardWASM without jsTorchWasmInstance"
      );
    if (!input.wasmData)
      throw new Error("Linear.input tensor not in WASM format");
    if (!this.weights.wasmData)
      throw new Error("Linear.weights not in WASM format");
    if (!this.bias.wasmData) throw new Error("Linear.bias not in WASM format");

    if (!input.is1d()) {
      throw new Error("Input tensor must be 1D");
    }
    // The input is passed as an (input.shape[0] x 1) column. The argument order looks
    // like (A, B, rowsA, colsA, rowsB, colsB) — confirm against the WASM export.
    const afterMatmulPointer = jsTorchWasmInstance.exports.matmul(
      this.weights.wasmData.pointer,
      input.wasmData.pointer,
      this.weights.shape[0],
      this.weights.shape[1],
      input.shape[0],
      1
    );
    if (afterMatmulPointer === 0)
      throw new Error("matmul failed on wasm backend");

    // The matmul product is only an intermediate. Release it on every path,
    // including when `add` fails, so a failed call does not leak WASM memory.
    try {
      const resultPointer = jsTorchWasmInstance.exports.add(
        afterMatmulPointer,
        this.bias.wasmData.pointer,
        this.outFeatures
      );
      if (resultPointer === 0) throw new Error("add failed on wasm backend");

      return new Tensor({
        shape: [this.outFeatures],
        pointer: resultPointer,
      });
    } finally {
      jsTorchWasmInstance.exports.free(afterMatmulPointer);
    }
  }

  /**
   * WebGL implementation: matmul followed by a bias add.
   *
   * @throws Error if the input or a parameter is not in WebGL format.
   */
  forwardWebGL(input: Tensor): Tensor {
    if (!input.webglData)
      throw new Error("Linear.input tensor not in WebGL format");
    if (!this.weights.webglData)
      throw new Error("Linear.weights not in WebGL format");
    if (!this.bias.webglData)
      throw new Error("Linear.bias not in WebGL format");

    const afterMatmul = matmulWebgl(this.weights, input, this.outFeatures);
    const afterAdd = addWebgl(afterMatmul, this.bias);

    return afterAdd;
  }
}

/**
 * Reshapes its input to the fixed `shape` given at construction.
 *
 * `forward` delegates to `Tensor.reshape` whatever backend holds the data. The
 * explicit per-backend entry points are not implemented.
 */
export class Unflatten extends Layer {
  shape: number[];

  constructor(options: { shape: number[] }, name: string) {
    super(name);
    this.shape = options.shape;
    this.assertCorrectShapes();
  }

  /** No parameters to validate; any size mismatch is left to `Tensor.reshape`. */
  assertCorrectShapes() {}

  forward(input: Tensor): Tensor {
    return input.reshape(this.shape);
  }

  forwardWASM(input: Tensor): Tensor {
    throw new Error("Method not implemented.");
  }

  forwardWebGL(input: Tensor): Tensor {
    throw new Error("Method not implemented.");
  }
}

/**
 * 2-D transposed convolution.
 *
 * `weights` must have shape [inChannels, outChannels, kernelSize, kernelSize] and
 * `bias` shape [outChannels]. Both are checked in the constructor. The input is
 * indexed as a 3-D [channels, height, width] tensor (no batch axis is read).
 */
export class ConvTranspose2d extends Layer {
  inChannels: number;
  outChannels: number;
  kernelSize: number;
  stride: number;
  padding: number;
  outputPadding: number;
  weights: Tensor;
  bias: Tensor;

  constructor(
    {
      inChannels,
      outChannels,
      kernelSize,
      stride,
      padding,
      outputPadding,
      weights,
      bias,
    }: {
      inChannels: number;
      outChannels: number;
      kernelSize: number;
      stride: number;
      padding: number;
      outputPadding: number;
      weights: Tensor;
      bias: Tensor;
    },
    name: string
  ) {
    super(name);
    this.inChannels = inChannels;
    this.outChannels = outChannels;
    this.kernelSize = kernelSize;
    this.stride = stride;
    this.padding = padding;
    this.outputPadding = outputPadding;
    this.weights = weights;
    this.bias = bias;

    this.assertCorrectShapes();
  }

  /** @throws Error if `weights` or `bias` do not match the declared dimensions. */
  assertCorrectShapes() {
    if (
      !this.weights.isOfShape([
        this.inChannels,
        this.outChannels,
        this.kernelSize,
        this.kernelSize,
      ])
    )
      throw new Error(
        `Weights shape mismatch: ${this.weights.shape} != [${this.inChannels}, ${this.outChannels}, ${this.kernelSize}, ${this.kernelSize}]`
      );

    if (!this.bias.isOfShape([this.outChannels]))
      throw new Error(
        `Bias shape mismatch: ${this.bias.shape} != [${this.outChannels}]`
      );
  }

  /** Dispatches to the backend that currently holds `input`'s data. */
  forward(input: Tensor): Tensor {
    debugWasmWebGLEquality(
      this,
      input,
      ["weights", "bias"],
      "ConvTranspose2d.forward"
    );

    if (input.wasmData) {
      return this.forwardWASM(input);
    } else if (input.webglData) {
      return this.forwardWebGL(input);
    } else throw new Error("Input tensor not in WASM or WebGL format");
  }

  /**
   * Output shape [outChannels, outHeight, outWidth], where each spatial size is
   * `(in - 1) * stride - 2 * padding + kernel + outputPadding`.
   */
  getOutputShape(input: Tensor): [number, number, number] {
    return [
      this.weights.shape[1],
      (input.shape[1] - 1) * this.stride -
        2 * this.padding +
        this.weights.shape[2] +
        this.outputPadding,
      (input.shape[2] - 1) * this.stride -
        2 * this.padding +
        this.weights.shape[3] +
        this.outputPadding,
    ];
  }

  /**
   * WASM implementation.
   *
   * @throws Error if the WASM instance is missing, a tensor is not in WASM format,
   *   or the kernel reports failure (returns pointer 0).
   */
  forwardWASM(input: Tensor): Tensor {
    if (!jsTorchWasmInstance)
      throw new Error(
        "cannot call ConvTranspose2d.forwardWASM without jsTorchWasmInstance"
      );
    if (!input.wasmData)
      throw new Error("ConvTranspose2d.input tensor not in WASM format");
    if (!this.weights.wasmData)
      throw new Error("ConvTranspose2d.weights not in WASM format");
    if (!this.bias.wasmData)
      throw new Error("ConvTranspose2d.bias not in WASM format");

    const outputShape = this.getOutputShape(input);
    const resultPointer = jsTorchWasmInstance.exports.convTranspose2d(
      input.wasmData.pointer,
      this.weights.wasmData.pointer,
      this.bias.wasmData.pointer,
      input.shape[0],
      input.shape[1],
      input.shape[2],
      this.weights.shape[0],
      this.weights.shape[1],
      this.weights.shape[2],
      this.weights.shape[3],
      this.bias.shape[0],
      this.stride,
      this.padding,
      this.outputPadding,
      outputShape[0],
      outputShape[1],
      outputShape[2]
    );
    // Same failure convention as the other WASM kernels in this file: 0 means failure.
    if (resultPointer === 0)
      throw new Error("convTranspose2d failed on wasm backend");

    return new Tensor({
      shape: outputShape,
      pointer: resultPointer,
    });
  }

  /**
   * WebGL implementation. `outputPadding` is not passed to the kernel directly; it
   * takes effect through the precomputed output shape.
   */
  forwardWebGL(input: Tensor): Tensor {
    return convTranspose2dWebgl(
      input,
      this.weights,
      this.bias,
      this.stride,
      this.padding,
      this.getOutputShape(input)
    );
  }
}

/**
 * 2-D batch normalization using the stored running statistics.
 *
 * `weights`, `bias`, `runningMean` and `runningVariance` each hold one value per
 * feature. The per-channel arithmetic lives in the WASM/WebGL kernels. `epsilon` is
 * passed to both kernels.
 */
export class BatchNorm2d extends Layer {
  numFeatures: number;
  weights: Tensor;
  bias: Tensor;
  runningMean: Tensor;
  runningVariance: Tensor;
  epsilon: number = 0.00001;

  constructor(
    {
      numFeatures,
      weights,
      bias,
      runningMean,
      runningVariance,
    }: {
      numFeatures: number;
      weights: Tensor;
      bias: Tensor;
      runningMean: Tensor;
      runningVariance: Tensor;
    },
    name: string
  ) {
    super(name);
    this.numFeatures = numFeatures;
    this.weights = weights;
    this.bias = bias;
    this.runningMean = runningMean;
    this.runningVariance = runningVariance;

    this.assertCorrectShapes();
  }

  /**
   * @throws Error if `weights`/`bias` are not of shape [numFeatures], or if the
   *   running statistics do not hold exactly numFeatures values.
   */
  assertCorrectShapes() {
    if (!this.weights.isOfShape([this.numFeatures]))
      throw new Error(
        `Weights shape mismatch: ${this.weights.shape} != [${this.numFeatures}]`
      );
    if (!this.bias.isOfShape([this.numFeatures]))
      throw new Error(
        `Bias shape mismatch: ${this.bias.shape} != [${this.numFeatures}]`
      );
    // The running statistics are checked by element count only. This accepts any
    // layout of the same numFeatures values while still catching wrong sizes, which
    // would otherwise make the kernels read out of bounds.
    if (this.runningMean.size !== this.numFeatures)
      throw new Error(
        `Running mean size mismatch: ${this.runningMean.size} != ${this.numFeatures}`
      );
    if (this.runningVariance.size !== this.numFeatures)
      throw new Error(
        `Running variance size mismatch: ${this.runningVariance.size} != ${this.numFeatures}`
      );
  }

  /** Dispatches to the backend that currently holds `input`'s data. */
  forward(input: Tensor): Tensor {
    // All four parameter tensors must follow the input between devices. Otherwise the
    // WASM/WebGL paths reject the running statistics in debug mode.
    debugWasmWebGLEquality(
      this,
      input,
      ["weights", "bias", "runningMean", "runningVariance"],
      "BatchNorm2d.forward"
    );

    if (input.wasmData) {
      return this.forwardWASM(input);
    } else if (input.webglData) {
      return this.forwardWebGL(input);
    } else throw new Error("Input tensor not in WASM or WebGL format");
  }

  /**
   * WASM implementation. The output has the same shape as the input; the input is
   * indexed as a 3-D tensor.
   *
   * @throws Error if the WASM instance is missing, a tensor is not in WASM format,
   *   or the kernel reports failure (returns pointer 0).
   */
  forwardWASM(input: Tensor): Tensor {
    if (!jsTorchWasmInstance)
      throw new Error(
        "cannot call BatchNorm2d.forward without jsTorchWasmInstance"
      );
    if (!input.wasmData)
      throw new Error("BatchNorm2d.input tensor not in WASM format");
    if (!this.weights.wasmData)
      throw new Error("BatchNorm2d.weights not in WASM format");
    if (!this.bias.wasmData)
      throw new Error("BatchNorm2d.bias not in WASM format");
    if (!this.runningMean.wasmData)
      throw new Error("BatchNorm2d.runningMean not in WASM format");
    if (!this.runningVariance.wasmData)
      throw new Error("BatchNorm2d.runningVariance not in WASM format");

    const resultPointer = jsTorchWasmInstance.exports.batchNorm2d(
      input.wasmData.pointer,
      this.weights.wasmData.pointer,
      this.bias.wasmData.pointer,
      this.runningMean.wasmData.pointer,
      this.runningVariance.wasmData.pointer,
      this.epsilon,
      input.shape[0],
      input.shape[1],
      input.shape[2]
    );
    // Same failure convention as the other WASM kernels in this file: 0 means failure.
    if (resultPointer === 0)
      throw new Error("batchNorm2d failed on wasm backend");

    return new Tensor({
      shape: input.shape,
      pointer: resultPointer,
    });
  }

  /** WebGL implementation. */
  forwardWebGL(input: Tensor): Tensor {
    return batchNorm2dWebgl(
      input,
      this.weights,
      this.bias,
      this.runningMean,
      this.runningVariance,
      this.epsilon
    );
  }
}

/** Element-wise ReLU activation; has no parameters. */
export class Relu extends Layer {
  constructor(name: string) {
    super(name);

    this.assertCorrectShapes();
  }

  /** No parameters, nothing to validate. */
  assertCorrectShapes() {}

  /** Dispatches to the backend that currently holds `input`'s data. */
  forward(input: Tensor): Tensor {
    debugWasmWebGLEquality(this, input, [], "Relu.forward");

    if (input.wasmData) {
      return this.forwardWASM(input);
    } else if (input.webglData) {
      return this.forwardWebGL(input);
    } else throw new Error("Input tensor not in WASM or WebGL format");
  }

  /**
   * WASM implementation; the result is a new buffer with the input's shape.
   *
   * @throws Error if the WASM instance is missing, the input is not in WASM format,
   *   or the kernel reports failure (returns pointer 0).
   */
  forwardWASM(input: Tensor): Tensor {
    if (!jsTorchWasmInstance)
      throw new Error("cannot call Relu.forward without jsTorchWasmInstance");
    if (!input.wasmData)
      throw new Error("Relu.input tensor not in WASM format");
    const resultPointer = jsTorchWasmInstance.exports.relu(
      input.wasmData.pointer,
      input.size
    );
    // Same failure convention as the other WASM kernels in this file: 0 means failure.
    if (resultPointer === 0) throw new Error("relu failed on wasm backend");

    return new Tensor({
      shape: input.shape,
      pointer: resultPointer,
    });
  }

  /** WebGL implementation. */
  forwardWebGL(input: Tensor): Tensor {
    return reluWebgl(input);
  }
}

/** Element-wise sigmoid activation; has no parameters. */
export class Sigmoid extends Layer {
  constructor(name: string) {
    super(name);

    this.assertCorrectShapes();
  }

  /** No parameters, nothing to validate. */
  assertCorrectShapes() {}

  /** Dispatches to the backend that currently holds `input`'s data. */
  forward(input: Tensor): Tensor {
    debugWasmWebGLEquality(this, input, [], "Sigmoid.forward");

    if (input.wasmData) {
      return this.forwardWASM(input);
    } else if (input.webglData) {
      return this.forwardWebGL(input);
    } else throw new Error("Input tensor not in WASM or WebGL format");
  }

  /**
   * WASM implementation; the result is a new buffer with the input's shape.
   *
   * @throws Error if the WASM instance is missing, the input is not in WASM format,
   *   or the kernel reports failure (returns pointer 0).
   */
  forwardWASM(input: Tensor): Tensor {
    if (!jsTorchWasmInstance)
      throw new Error(
        "cannot call Sigmoid.forward without jsTorchWasmInstance"
      );
    if (!input.wasmData)
      throw new Error("Sigmoid.input tensor not in WASM format");
    const resultPointer = jsTorchWasmInstance.exports.sigmoid(
      input.wasmData.pointer,
      input.size
    );
    // Same failure convention as the other WASM kernels in this file: 0 means failure.
    if (resultPointer === 0) throw new Error("sigmoid failed on wasm backend");

    return new Tensor({
      shape: input.shape,
      pointer: resultPointer,
    });
  }

  /** WebGL implementation. */
  forwardWebGL(input: Tensor): Tensor {
    return sigmoidWebgl(input);
  }
}
