/*!
Copyright 2018 Propel http://propel.site/.  All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// copied with minor modifications from the following source:
// https://github.com/MaximeKjaer/tfjs-npy-node/blob/master/src/npy.ts

import ndarray from 'ndarray';

/**
 * Magic prefix expected at the very start of the buffer: the byte 0x93
 * followed by the ASCII characters "NUMPY" (6 bytes in total).
 */
const MAGIC_STRING = '\x93NUMPY';

/** Typed-array classes that a `DescrInfo.createArray` function may return. */
export type TypedArray = Float32Array | Int32Array | Uint8Array;

/**
 * Element type labels stored in `DescrInfo.dtype`.
 *
 * FIXME: int64, float64 and int8 look broken
 */
export type DType = 'float32' | 'int32' | 'bool';

/**
 * Information about how to read and, where supported, write the elements
 * stored under one Numpy descr.
 */
interface DescrInfo {
  /** Number of bytes needed for a single element as stored in the file. */
  bytes: number;
  /**
   * dtype label the elements are exposed as after reading. This can be
   * narrower than the stored type (the table in this file labels '<f8' as
   * 'float32' and '<i8' as 'int32').
   */
  dtype: DType;
  /** Function for creating a typed array from a buffer of raw element bytes. */
  createArray: (buf: ArrayBuffer) => TypedArray;
  /**
   * Function for writing one element into `view` at byte offset `pos`.
   * Despite its name, `byte` is forwarded to the DataView setter as the
   * element value by the entries defined in this file. Undefined if
   * serialization is not supported.
   */
  write?: (view: DataView, pos: number, byte: number) => void;
}

/**
 * Union type of the Numpy descr strings that have an entry in
 * `numpyDescrInfo`. Every listed descr can be read; only entries that define
 * `write` can also be serialized.
 *
 * In Numpy's descr notation the first character is the byte order (`<` for
 * little-endian, `|` for not applicable), followed by a kind letter and the
 * element size in bytes.
 */
type SupportedDescr = '<f8' | '<f4' | '<i8' | '<i4' | '|u1' | '|b1';

/**
 * Lookup table from each supported Numpy descr to its element size, the dtype
 * label it is exposed as, a typed-array factory and (where serialization is
 * supported) a DataView writer.
 */
const numpyDescrInfo: Readonly<Record<SupportedDescr, DescrInfo>> = {
  // 64-bit floats: converted element by element into 32-bit floats.
  '<f8': {
    bytes: 8,
    dtype: 'float32',
    createArray: (buffer) => {
      const doubles = new Float64Array(buffer);
      return new Float32Array(doubles);
    },
  },
  // 32-bit floats: viewed directly, written little-endian.
  '<f4': {
    bytes: 4,
    dtype: 'float32',
    createArray: (buffer) => {
      return new Float32Array(buffer);
    },
    write: (view, pos, value) => {
      view.setFloat32(pos, value, true);
    },
  },
  // 64-bit ints: only every even-indexed 32-bit word is kept, so each
  // element is reduced to one of its two 32-bit halves.
  '<i8': {
    bytes: 8,
    dtype: 'int32',
    createArray: (buffer) => {
      const words = new Int32Array(buffer);
      return words.filter((_word, index) => index % 2 === 0);
    },
  },
  // 32-bit ints: viewed directly, written little-endian.
  '<i4': {
    bytes: 4,
    dtype: 'int32',
    createArray: (buffer) => {
      return new Int32Array(buffer);
    },
    write: (view, pos, value) => {
      view.setInt32(pos, value, true);
    },
  },
  // Booleans: one byte per element.
  '|b1': {
    bytes: 1,
    dtype: 'bool',
    createArray: (buffer) => {
      return new Uint8Array(buffer);
    },
    write: (view, pos, value) => {
      view.setUint8(pos, value);
    },
  },
  // Unsigned bytes. FIXME: dtype should be uint8
  '|u1': {
    bytes: 1,
    dtype: 'int32',
    createArray: (buffer) => {
      return new Uint8Array(buffer);
    },
    write: (view, pos, value) => {
      view.setUint8(pos, value);
    },
  },
};

/**
 * Wrap a buffer in a DataView.
 *
 * @param buf Buffer to expose.
 * @param byteOffset Optional index of the first byte of `buf` covered by the
 *   view; forwarded unchanged to the DataView constructor.
 * @param byteLength Optional number of bytes covered by the view; forwarded
 *   unchanged to the DataView constructor.
 * @returns A DataView over `buf`; with both optional arguments omitted it
 *   spans the whole buffer.
 */
function getView(
  buf: ArrayBuffer,
  byteOffset?: number,
  byteLength?: number,
): DataView {
  const view = new DataView(buf, byteOffset, byteLength);
  return view;
}

/**
 * Decode the bytes of a DataView into a string, one character per byte.
 *
 * Decoding stops at the first zero byte (which is not included in the
 * result) or at the end of the view, whichever comes first.
 */
function dataViewToAscii(dv: DataView): string {
  const chars: string[] = [];
  let offset = 0;
  while (offset < dv.byteLength) {
    const code = dv.getUint8(offset);
    if (code === 0) {
      // A NUL byte terminates the string early.
      break;
    }
    chars.push(String.fromCharCode(code));
    offset += 1;
  }
  return chars.join('');
}

/**
 * Count the elements of an array with the given shape, i.e. the product of
 * all its dimensions. An empty shape yields 1.
 */
function numEls(shape: number[]): number {
  let count = 1;
  shape.forEach((dim) => {
    count *= dim;
  });
  return count;
}

/**
 * Copy a byte range out of a buffer.
 *
 * @param buf Source buffer; it is not modified.
 * @param start Byte index where the copy begins (inclusive).
 * @param end Byte index where the copy stops (exclusive).
 * @returns A new ArrayBuffer holding the selected bytes. No check is made
 *   that `start <= end`; both values are handed to `ArrayBuffer.slice` as-is.
 */
function getSlice(buf: ArrayBuffer, start: number, end: number): ArrayBuffer {
  const copy = buf.slice(start, end);
  return copy;
}

/**
 * Parse the contents of a `.npy` file (format version 1.0) into an ndarray.
 *
 * The buffer is read in this order: the magic string, two version bytes, a
 * little-endian uint16 header length, the header (a Python dict literal), and
 * finally the raw element data.
 *
 * @param buf Complete contents of the `.npy` file.
 * @returns An ndarray with the shape given in the header, backed by a typed
 *   array created from a copy of the data bytes.
 * @throws Error if the magic string is missing, the version is not 1.0, the
 *   buffer is too short for its preamble, header or data, the array is in
 *   Fortran order, the shape is not a list of non-negative integers, or the
 *   descr is not one of the supported ones.
 * @throws SyntaxError if the header cannot be converted into valid JSON.
 */
export function parseToNdArray(buf: ArrayBuffer): ndarray.NdArray {
  // Reject buffers too short to even hold the magic string, instead of
  // letting the DataView constructor fail with an unrelated RangeError.
  if (buf.byteLength < MAGIC_STRING.length) {
    throw Error(`Not a numpy file.`);
  }
  const view = getView(buf);
  let pos = 0;

  // First parse the magic string.
  const magicStr = dataViewToAscii(getView(buf, pos, MAGIC_STRING.length));
  if (magicStr !== MAGIC_STRING) {
    throw Error(`Not a numpy file.`);
  }
  pos += MAGIC_STRING.length;

  // Two version bytes plus the two-byte header length must follow.
  if (buf.byteLength < pos + 4) {
    throw Error(`NPY parse error. File is truncated.`);
  }

  // Parse the version (major byte, then minor byte).
  const major = view.getUint8(pos);
  const minor = view.getUint8(pos + 1);
  pos += 2;
  const version = `${major}.${minor}`;
  if (version !== '1.0') {
    throw Error(`Unsupported npy version ${version}.`);
  }

  // Parse the header length (little-endian uint16).
  const headerLen = view.getUint16(pos, true);
  pos += 2;
  if (pos + headerLen > buf.byteLength) {
    throw Error(
      `NPY parse error. Header length ${headerLen} exceeds the file size.`,
    );
  }

  // Parse the header.
  // The header is almost JSON, so we rewrite it until it is.
  // Example: {'descr': '<f8', 'fortran_order': False, 'shape': (1, 2), }
  const headerPy = dataViewToAscii(getView(buf, pos, headerLen));
  pos += headerLen;
  const headerJson = headerPy
    .replace('True', 'true')
    .replace('False', 'false')
    .replace(/'/g, `"`)
    .replace(/,\s*}/, ' }') // drop the trailing comma before the closing brace
    .replace(/,?\)/, ']') // shape tuple -> JSON array, e.g. (3,) -> [3]
    .replace('(', '[');
  const header: unknown = JSON.parse(headerJson);
  if (typeof header !== 'object' || header === null) {
    throw Error(`NPY parse error. Header is not a dictionary.`);
  }
  const {
    descr,
    fortran_order: fortranOrder,
    shape,
  } = header as { descr?: unknown; fortran_order?: unknown; shape?: unknown };

  if (fortranOrder) {
    throw Error(`NPY parse error. Implement me. ${String(fortranOrder)}`);
  }

  // Parse shape.
  if (
    !Array.isArray(shape) ||
    !shape.every((el: unknown) => Number.isInteger(el) && (el as number) >= 0)
  ) {
    throw Error(
      `NPY parse error. Invalid shape ${JSON.stringify(shape)} in header.`,
    );
  }
  const dims: number[] = shape;
  const size = numEls(dims);

  // Parse descr. An own-property check keeps inherited keys such as
  // "toString" from being mistaken for table entries.
  if (
    typeof descr !== 'string' ||
    !Object.prototype.hasOwnProperty.call(numpyDescrInfo, descr)
  ) {
    throw Error(`Unknown dtype ${JSON.stringify(descr)}. Implement me.`);
  }
  const info = numpyDescrInfo[descr as SupportedDescr];

  // Make sure the data section is complete. Extra trailing bytes are
  // tolerated (and ignored), as they were before this check existed.
  const dataBytes = size * info.bytes;
  const bytesLeft = buf.byteLength - pos;
  if (bytesLeft < dataBytes) {
    throw Error(
      `Expected there to be at least ${dataBytes} bytes left for npy file ` +
        `of dtype descr ${descr}, but there were ${bytesLeft} bytes left`,
    );
  }

  // Finally parse the actual data.
  const slice = getSlice(buf, pos, pos + dataBytes);

  return ndarray(info.createArray(slice), dims);
}
