import { isEqual, sampleSize } from 'lodash-es';

/**
 * Choose starting centroids by randomly sampling up to `k` dataset entries.
 *
 * The returned centroids are the sampled entries' own `vector` arrays, not
 * copies.
 *
 * @param dataset Entries to sample from.
 * @param k Number of centroids requested.
 * @returns The `vector` of each sampled entry.
 */
function getRandomCentroids<T>(dataset: VectorWithData<T>[], k: number) {
  const sampledEntries = sampleSize(dataset, k);
  return sampledEntries.map(({ vector }) => vector);
}

/**
 * Squared Euclidean distance between two vectors: the sum of the squared
 * component-wise differences. No square root is taken.
 *
 * Iteration is driven by the indices of `a`.
 *
 * @param a First vector.
 * @param b Second vector.
 * @returns Sum of `(a[i] - b[i])²` over the indices of `a`.
 */
function squaredEuclideanDistance(a: Vector, b: Vector) {
  return a.reduce((total, component, i) => {
    const delta = component - b[i];
    return total + delta * delta;
  }, 0);
}

/** A numeric vector, stored as a plain array of its components. */
export type Vector = number[];

/** A vector paired with a payload of type `T`. */
export type VectorWithData<T> = {
  /** Coordinates used when measuring the distance to centroids. */
  vector: Vector;
  /** Payload associated with the vector. */
  data: T;
  /**
   * Optional label. When it is not `undefined` (`null` included),
   * `clusterVectors` copies it onto the cluster this entry is assigned to.
   */
  value?: string | null;
};

/** Result returned by `kmeans`. */
export type KmeansClusteringResults<T> = {
  /** Clusters produced by the last assignment step (empty if no iteration ran). */
  clusters: Cluster<T>[];
  /** Centroids after the last update step, or the initial centroids if no iteration ran. */
  centroids: Vector[];
  /** Number of iterations performed. */
  iterations: number;
  /** Whether the run is reported as converged rather than stopped by the iteration limit. */
  converged: boolean;
};

/**
 * Partition `dataset` into `k` clusters with the k-means algorithm.
 *
 * Each iteration assigns every entry to its nearest centroid and then moves
 * each centroid to the mean of its assigned vectors. The loop stops when an
 * iteration leaves the centroids unchanged or when `maxIterations` is reached.
 *
 * The returned `clusters` come from the last assignment step, so when the run
 * did not converge their `centroid` fields are the centroids from before the
 * final update, not the returned `centroids`.
 *
 * @param dataset Entries to cluster.
 * @param k Number of clusters.
 * @param options Optional tuning:
 *   - `maxIterations`: upper bound on iterations (default 50).
 *   - `initializeCentroids`: produces the starting centroids (default: a
 *     random sample of dataset vectors).
 * @returns The clusters, the centroids, the number of iterations performed and
 *   whether the centroids stopped changing.
 * @throws Error if the dataset has fewer than `k` entries.
 */
export function kmeans<T>(
  dataset: VectorWithData<T>[],
  k: number,
  options: {
    maxIterations?: number;
    initializeCentroids?: (dataset: VectorWithData<T>[], k: number) => Vector[];
  } = {}
): KmeansClusteringResults<T> {
  if (dataset.length < k) {
    throw new Error(`Invalid dataset length ${dataset.length} < k = ${k}`);
  }

  const { maxIterations = 50, initializeCentroids = getRandomCentroids } =
    options;

  let prevCentroids: Vector[] | undefined;
  let centroids = initializeCentroids(dataset, k);
  let clusters: Cluster<T>[] = [];
  let iterations = 0;

  while (iterations < maxIterations && !isEqual(prevCentroids, centroids)) {
    // No copy needed: `updateCentroids` returns a fresh array and nothing in
    // the loop mutates the previous one.
    prevCentroids = centroids;
    clusters = clusterVectors(dataset, centroids);
    centroids = updateCentroids(dataset, clusters);
    iterations++;
  }

  return {
    clusters,
    centroids,
    iterations,
    // Compare the centroids directly instead of inferring convergence from the
    // iteration count: a run that stabilises on exactly the last allowed
    // iteration has converged even though `iterations === maxIterations`.
    converged: isEqual(prevCentroids, centroids),
  };
}

/** A centroid together with the dataset entries assigned to it. */
type Cluster<T> = {
  /** Centroid the member entries were assigned to. */
  centroid: Vector;
  /** Dataset entries assigned to this cluster. */
  vectors: VectorWithData<T>[];
  /**
   * Set by `clusterVectors` from member entries whose `value` is not
   * `undefined`; the last such entry in dataset order wins.
   */
  value?: string | null;
};

/**
 * Group the dataset entries by their nearest centroid.
 *
 * One cluster is created per centroid, in the same order. Distance is the
 * squared Euclidean distance; on a tie the centroid with the lowest index
 * wins. A cluster's `value` is taken from its member entries whose `value` is
 * not `undefined`, with later entries overwriting earlier ones.
 *
 * @param dataset Entries to assign.
 * @param centroids Current centroids.
 * @returns One cluster per centroid, holding the entries assigned to it.
 */
function clusterVectors<T>(dataset: VectorWithData<T>[], centroids: Vector[]) {
  const clusters = centroids.map(
    (centroid): Cluster<T> => ({ centroid, vectors: [] })
  );

  for (const entry of dataset) {
    // Treat the first centroid as the best candidate, then look for a
    // strictly closer one among the rest.
    let nearestIndex = 0;
    let nearestDistance = squaredEuclideanDistance(entry.vector, centroids[0]);

    for (let index = 1; index < centroids.length; index += 1) {
      const candidateDistance = squaredEuclideanDistance(
        entry.vector,
        centroids[index]
      );
      if (candidateDistance < nearestDistance) {
        nearestDistance = candidateDistance;
        nearestIndex = index;
      }
    }

    const target = clusters[nearestIndex];
    target.vectors.push(entry);
    if (entry.value !== undefined) {
      target.value = entry.value;
    }
  }

  return clusters;
}

/**
 * Compute the next centroid for every cluster.
 *
 * A non-empty cluster's centroid becomes the mean of its member vectors. An
 * empty cluster has no mean, so it is re-seeded with the vector of a randomly
 * sampled dataset entry.
 *
 * @param dataset Full dataset, used to re-seed empty clusters.
 * @param clusters Clusters from the latest assignment step.
 * @returns One centroid per cluster, in cluster order.
 */
function updateCentroids<T>(
  dataset: VectorWithData<T>[],
  clusters: Cluster<T>[]
) {
  return clusters.map(({ vectors: members }) => {
    if (members.length === 0) {
      return getRandomCentroids(dataset, 1)[0];
    }
    return getMeanVector(members);
  });
}

/**
 * Component-wise mean of the given entries' vectors.
 *
 * The result has the dimensionality of the first entry's vector. Each
 * component is divided by the entry count before being accumulated.
 *
 * @param vectors Entries to average; the first entry is read unconditionally,
 *   so the array must not be empty.
 * @returns The mean vector.
 */
function getMeanVector<T>(vectors: VectorWithData<T>[]) {
  const count = vectors.length;
  const mean: number[] = Array.from(
    { length: vectors[0].vector.length },
    () => 0
  );

  for (const { vector: components } of vectors) {
    for (let d = 0; d < components.length; d += 1) {
      mean[d] += components[d] / count;
    }
  }

  return mean;
}
