import { Row, RowData, Table as ReactTable } from '@tanstack/react-table';
import { isEmpty, min, size, uniq } from 'lodash-es';
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useBulkEditHubContext } from 'src/contexts/BulkEditHubContext';
import { useGenericDialogContext } from 'src/contexts/GenericDialogContext/GenericDialogContext';

import { PosTableData } from './Table';

/**
 * Computes which row keyboard navigation should select next.
 *
 * @param rows - Rows of the table's current row model, in display order.
 * @param selectedRows - Currently selected rows. They are not guaranteed to be
 *   present in `rows`; that case is handled explicitly below.
 * @param step - `-1` to move up, `1` to move down.
 * @returns The id of the row to select, or `undefined` when the selection should
 *   stay as it is (already at the edge, or there are no rows to move to).
 */
const getNextRowId = <T extends RowData>(
  rows: Row<T>[],
  selectedRows: Row<T>[],
  step: -1 | 1
): string | undefined => {
  // A selection spanning several depths is first collapsed onto its shallowest row.
  const parent = getParentRow(selectedRows);
  if (parent) {
    return parent.id;
  }
  if (size(rows) === 0) {
    return undefined;
  }

  // Locate the span of the current row model covered by the selection.
  const selectedIds = new Set(selectedRows.map((row) => row.id));
  let firstSelected = -1;
  let lastSelected = -1;
  for (let index = 0; index < rows.length; index++) {
    if (selectedIds.has(rows[index].id)) {
      if (firstSelected === -1) {
        firstSelected = index;
      }
      lastSelected = index;
    }
  }

  if (firstSelected === -1) {
    // Nothing selected, or none of the selection is part of the current row
    // model: enter the list from the last row when moving up and from the
    // first row when moving down.
    return step < 0 ? rows[rows.length - 1].id : rows[0].id;
  }

  const targetIndex = step < 0 ? firstSelected - 1 : lastSelected + 1;
  return targetIndex >= 0 && targetIndex < rows.length
    ? rows[targetIndex].id
    : undefined;
};

/**
 * Keyboard navigation for a TanStack table.
 *
 * - ArrowDown / ArrowUp replace the selection with the next / previous row of
 *   the current row model. With nothing selected, ArrowDown selects the first
 *   row and ArrowUp the last one. When the selection spans several depths,
 *   the shallowest selected row is selected instead.
 * - Enter toggles expansion of the shallowest selected row when the selection
 *   spans several depths. Otherwise it toggles expansion of the selected row
 *   if it has sub-rows, and passes it to `onRowSelection` if it does not.
 *
 * The hotkeys do nothing, and leave the key event untouched, while `enabled`
 * is false or while the bulk-edit main dialog or a generic dialog is open.
 *
 * @param enabled - Master switch for the hotkeys.
 * @param table - Table instance whose row selection is driven.
 * @param onRowSelection - Called on Enter with the selected row when it has no sub-rows.
 */
export const useTableHotkeys = <T extends PosTableData>(
  enabled: boolean,
  table: ReactTable<T>,
  onRowSelection?: (row: Row<T>) => void
) => {
  const { mainDialogOpened } = useBulkEditHubContext();
  const { isDialogOpen } = useGenericDialogContext();
  const isActive = enabled && !mainDialogOpened && !isDialogOpen;

  /** Replaces the current selection with exactly the row `id`. */
  const selectRow = useCallback(
    (id: string) => {
      table.setRowSelection((prev) => {
        // Build a fresh object: state updaters must not mutate `prev`.
        const next: Record<string, boolean> = {};
        for (const key of Object.keys(prev)) {
          next[key] = false;
        }
        // Use the row id verbatim; parsing it to an integer would turn
        // sub-row ids such as "0.1" into their parent's id.
        next[id] = true;
        return next;
      });
    },
    [table]
  );

  /** Shared ArrowUp / ArrowDown handler. */
  const moveSelection = useCallback(
    (event: KeyboardEvent, step: -1 | 1) => {
      if (!isActive) {
        // Let the browser / focused dialog handle the key normally.
        return;
      }
      event.preventDefault();
      const { rows } = table.getRowModel();
      const selectedRows = Object.values(table.getSelectedRowModel().rowsById);
      const nextId = getNextRowId(rows, selectedRows, step);
      if (nextId !== undefined) {
        selectRow(nextId);
      }
    },
    [isActive, selectRow, table]
  );

  const onArrowUp = useCallback(
    (event: KeyboardEvent) => moveSelection(event, -1),
    [moveSelection]
  );

  const onArrowDown = useCallback(
    (event: KeyboardEvent) => moveSelection(event, 1),
    [moveSelection]
  );

  const onEnter = useCallback(
    (event: KeyboardEvent) => {
      if (!isActive) {
        // Let the browser / focused dialog handle the key normally.
        return;
      }
      event.preventDefault();
      const selectedRows = Object.values(table.getSelectedRowModel().rowsById);

      if (isEmpty(selectedRows)) {
        return;
      }
      // A selection spanning several depths acts on its shallowest row.
      const parent = getParentRow(selectedRows);
      if (parent) {
        parent.toggleExpanded(!parent.getIsExpanded());
        return;
      }

      const row = selectedRows[0];
      if (!isEmpty(row.subRows)) {
        row.toggleExpanded(!row.getIsExpanded());
      } else {
        onRowSelection?.(row);
      }
    },
    [isActive, onRowSelection, table]
  );

  useHotkeys('down', onArrowDown, { keydown: true }, [onArrowDown]);
  useHotkeys('up', onArrowUp, { keydown: true }, [onArrowUp]);
  useHotkeys('enter', onEnter, { keydown: true }, [onEnter]);
};

/**
 * Picks the shallowest row out of a selection that spans several tree depths.
 *
 * @param rows - Selected rows, in any order.
 * @returns `null` when every row sits at the same depth (including a single
 *   row). Otherwise, the first row in `rows` whose `depth` equals the minimum
 *   depth. For an empty array this is `undefined`. NOTE(review): this is the
 *   shallowest row, which is assumed (not checked) to be the parent of the
 *   other selected rows.
 */
const getParentRow = <T extends RowData>(rows: Row<T>[]) => {
  const levels = uniq(rows.map((row) => row.depth));
  const topLevel = min(levels);
  return levels.length === 1
    ? null
    : rows.find((row) => row.depth === topLevel);
};
