import { useCallback, useMemo, useState, MouseEvent, Dispatch, SetStateAction, useRef } from "react";
import { IColumn, Icon, IDetailsColumnProps, IDetailsListProps, IGroup, IObjectWithKey, IRenderFunction, IShimmeredDetailsListProps, Selection, SelectionMode, ShimmeredDetailsList, Stack, TooltipHost } from '@fluentui/react';
import React from "react";
import { useConst } from "@fluentui/react-hooks";

/**
 * `IShimmeredDetailsListProps` members that `DataGrid` sets itself (row renderer, header-click
 * sorting and grouping) and therefore removes from the props it accepts.
 */
type disabledShimmeredDetailsListProps = "groups" | "onColumnHeaderClick" | "onRenderRow";

/**
 * Props accepted by `DataGrid`: every `ShimmeredDetailsList` prop except the ones the grid controls,
 * with strongly typed items/columns plus grouping and selection helpers.
 */
interface Props<TItem> extends Omit<IShimmeredDetailsListProps, disabledShimmeredDetailsListProps> {
    /** Rows to display. */
    items: TItem[];
    /** Column definitions, typed against `TItem`. */
    columns: DataGridColumn<TItem>[];
    /** When provided, rows are grouped by the name this returns for each item. */
    getItemGroupName?: (item: TItem) => string;
    /** Order in which group names are displayed. */
    groupSortOrder?: string[];
    /** Receives the selected items (falsy entries removed) whenever the selection changes. */
    setSelectedItems?: Dispatch<SetStateAction<TItem[] | undefined>>;
}

/**
 * A value a column can be sorted by. Strings, numbers, bigints, Dates and booleans are compared by
 * value; `undefined` is classified as unsortable. The primitive `bigint` is used (not the boxed
 * `BigInt` wrapper type) because sort keys are classified with `typeof key === "bigint"`.
 */
type DataGridSortKey = string | number | bigint | Date | boolean | undefined;

/**
 * Column definition for `DataGrid`: a Fluent UI `IColumn` whose item callbacks are typed against
 * `TItem`, plus sorting controls and an optional header tooltip.
 */
export interface DataGridColumn<TItem> extends IColumn {
    /** Renders the cell content for `item`; also the last-resort source of the sort key. */
    onRender?: (item: TItem, index?: number, column?: DataGridColumn<TItem>) => any;
    /** `IColumn.getValueKey`, typed for `TItem`. */
    getValueKey?: (item?: TItem, index?: number, column?: IColumn) => string;
    /** Value to sort by; when absent, sorting falls back to the `fieldName` property, then to `onRender`. */
    getSortKey?: (item: TItem) => DataGridSortKey;
    /** When true, clicking this column's header does not change the sort. */
    notSortable?: boolean;
    /** Text shown in a tooltip on an info icon next to the column header. */
    tooltip?: string;
}

/** The grid's active sort: which column is sorted and in which direction. */
interface SortState {
    /** `true` for descending order, `false` for ascending. */
    isSortedDescending: boolean;
    /** `key` of the column being sorted. */
    columnKey: string;
}

/**
 * Sortable, optionally grouped wrapper around Fluent UI's `ShimmeredDetailsList`.
 *
 * - Clicking a column header sorts by that column: ascending first, then toggling direction on
 *   repeated clicks. Columns with `notSortable` leave the current sort unchanged.
 * - When `getItemGroupName` is given, items are grouped by its result. Groups named in
 *   `groupSortOrder` come first, in that order; any other groups follow in order of first
 *   appearance, so no item is ever dropped. Groups are not passed to the list while the shimmer
 *   is enabled.
 * - Every column gets `onRenderHeader` (shows the column `tooltip`) unless it supplies its own.
 * - Selection changes are reported through `setSelectedItems` with falsy entries removed. No
 *   selection object is passed when `selectionMode` is `SelectionMode.none`.
 *
 * All remaining props are forwarded to `ShimmeredDetailsList` and take precedence over the props
 * set here, because they are spread last.
 */
function DataGridImpl<TItem>(props: Props<TItem>) {
    const { items, columns, getItemGroupName, groupSortOrder, setSelectedItems, ...restProps } = props;
    const [sortState, setSortState] = useState<SortState>();

    // The selection object below is created only once, so it reaches the caller's latest
    // `setSelectedItems` through a ref that is refreshed on every render.
    const setSelectedItemsRef = useRef(setSelectedItems);
    setSelectedItemsRef.current = setSelectedItems;
    const setSelectedItemsRefConst = useCallback((selected: TItem[]) => {
        setSelectedItemsRef.current?.(selected);
    }, []);
    const selection = useConst(() => new DataGridSelection({
        onSelectionChanged: () => {
            // Drop empty entries (e.g. shimmer placeholders) before reporting the selection.
            const selectedItems = (selection.getSelection() as TItem[])?.filter(x => x);
            setSelectedItemsRefConst(selectedItems);
        }
    }));

    const { mappedItems, mappedColumns, groups } = useMemo(() => {
        // `onRenderHeader` is only a default: a column's own `onRenderHeader` overrides it.
        let mappedColumns: DataGridColumn<TItem>[] = columns.map(column => ({
            onRenderHeader,
            ...column
        }));
        let mappedItems: TItem[] = items;

        if (sortState) {
            const { columnKey, isSortedDescending } = sortState;
            mappedItems = getSortedItems(items, mappedColumns, sortState);
            mappedColumns = mappedColumns.map(column => column.key === columnKey
                ? { ...column, isSorted: true, isSortedDescending }
                : column);
        }

        let groups: IGroup[] | undefined;
        if (getItemGroupName) {
            const grouped = groupItems(mappedItems, getItemGroupName, groupSortOrder);
            mappedItems = grouped.items;
            groups = grouped.groups;
        }

        return { mappedItems, mappedColumns, groups };
    }, [getItemGroupName, items, columns, sortState, groupSortOrder]);

    // Same column: flip the direction. Other sortable column: sort it ascending.
    // `notSortable` column (or no column): keep the current sort.
    const onColumnHeaderClick = useCallback((_ev?: MouseEvent<any>, column?: DataGridColumn<TItem>) => {
        setSortState(current => {
            if (!column || column.notSortable) {
                return current;
            }
            return {
                columnKey: column.key,
                isSortedDescending: current?.columnKey === column.key ? !current.isSortedDescending : false
            };
        });
    }, [setSortState]);

    return (
        <ShimmeredDetailsList
            items={mappedItems}
            columns={mappedColumns}
            onRenderRow={onRenderRow}
            onColumnHeaderClick={onColumnHeaderClick}
            groups={restProps.enableShimmer ? undefined : groups}
            selection={restProps.selectionMode === SelectionMode.none ? undefined : selection}
            {...restProps}
        />
    );
}

/**
 * Reorders `items` so that each group's members are contiguous and builds the matching `IGroup`
 * list. Groups listed in `groupSortOrder` come first, in that order (duplicates ignored); groups it
 * does not mention follow in order of first appearance. Item order within a group is preserved.
 */
function groupItems<TItem>(
    items: TItem[],
    getItemGroupName: (item: TItem) => string,
    groupSortOrder: string[] | undefined
): { items: TItem[]; groups: IGroup[] } {
    // A Map (not a plain object) so names such as "constructor" or "__proto__" cannot collide with
    // Object.prototype members.
    const itemsByGroup = new Map<string, TItem[]>();
    items.forEach(item => {
        const groupName = getItemGroupName(item);
        const members = itemsByGroup.get(groupName);
        if (members) {
            members.push(item);
        } else {
            itemsByGroup.set(groupName, [item]);
        }
    });

    // Set keeps first-insertion order and drops duplicates: explicit order first, then the rest.
    const orderedNames = Array.from(new Set((groupSortOrder ?? []).concat(Array.from(itemsByGroup.keys()))));
    const groupedItems: TItem[] = [];
    const groups: IGroup[] = [];
    orderedNames.forEach(groupName => {
        const members = itemsByGroup.get(groupName);
        if (!members) {
            return; // listed in groupSortOrder but has no items
        }
        groups.push({
            key: groupName,
            name: groupName,
            startIndex: groupedItems.length,
            count: members.length
        });
        members.forEach(member => groupedItems.push(member));
    });

    return { items: groupedItems, groups };
}

/**
 * Column header renderer: renders the default header and, when the column defines a `tooltip`,
 * an info icon next to it that shows the tooltip. Renders nothing if Fluent passes no props or no
 * default renderer.
 */
const onRenderHeader: IRenderFunction<IDetailsColumnProps> = (props, defaultRender) => {
    if (!props || !defaultRender) {
        return null;
    }
    const { column } = props;
    const tooltip = (column as DataGridColumn<unknown>).tooltip;
    return (
        <Stack horizontal>
            {defaultRender(props)}
            {tooltip ? (
                <TooltipHost
                    content={tooltip}
                    id={`${column.key}-header-tooltip`}
                >
                    <Icon iconName="Info" styles={{ root: { marginLeft: 4 } }} />
                </TooltipHost>
            ) : null}
        </Stack>
    );
};

/**
 * Row renderer: delegates to the default renderer with the row root's `user-select` set to
 * `"text"` so cell text can be selected and copied.
 *
 * The `styles` passed here replace any `styles` already present on the row props.
 * NOTE(review): confirm no caller relies on row-level `styles`.
 */
const onRenderRow: IDetailsListProps["onRenderRow"] = (props, defaultRender) => {
    if (!props || !defaultRender) {
        return null;
    }
    return defaultRender({
        ...props,
        styles: {
            root: {
                // "text" is a valid user-select keyword ("any" is not and would be ignored).
                userSelect: "text"
            }
        }
    });
};

/**
 * Returns a sorted copy of `items` according to `sortState`. `items` itself is returned unchanged
 * when the sort column does not exist or offers no way to produce a sort key.
 *
 * The sort key for each item comes from the column's `getSortKey`, else the item property named by
 * `fieldName`, else the column's `onRender` output. Keys are computed once per item before sorting;
 * the input array is never mutated.
 */
function getSortedItems<TItem>(items: TItem[], columns: DataGridColumn<TItem>[], sortState: SortState): TItem[] {
    const { columnKey, isSortedDescending } = sortState;
    const sortColumn = columns.find(column => column.key === columnKey);
    if (!sortColumn) {
        return items;
    }

    const { fieldName } = sortColumn;
    const getSortKey: ((item: TItem) => DataGridSortKey) | undefined =
        sortColumn.getSortKey
        ?? (fieldName ? (item: TItem) => (item as any)[fieldName] as DataGridSortKey : undefined)
        ?? sortColumn.onRender;
    if (!getSortKey) {
        // Nothing to sort by: keep the current order instead of throwing.
        return items;
    }

    const direction = isSortedDescending ? -1 : 1;
    return items
        .map(item => {
            const sortKey = getSortKey(item);
            return { item, sortKey, keyType: getSortKeyType(sortKey) };
        })
        // `?? 0` treats a missing comparison result as a tie rather than producing NaN.
        .sort((a, b) => direction * (compare(a.sortKey, b.sortKey, a.keyType, b.keyType) ?? 0))
        .map(entry => entry.item);
}

/**
 * Kinds of sort keys. When keys of different kinds are compared, the kind with the lower value
 * sorts first, so `unsorted` keys (e.g. `undefined`) come last in ascending order.
 *
 * Declared as a frozen-literal object plus a union type rather than a `const enum`, so that
 * single-file transpilers need no const-enum support. Member names and values are unchanged.
 */
const SortKeyType = {
    text: 1,
    numeric: 2,
    bigint: 3,
    date: 4,
    bool: 5,
    unsorted: 6,
} as const;
type SortKeyType = typeof SortKeyType[keyof typeof SortKeyType];

/**
 * Classifies a sort key so `compare` knows how to compare it. Strings, numbers, bigints, booleans
 * and `Date` instances each have their own kind; anything else (e.g. `undefined`) is `unsorted`.
 */
function getSortKeyType(a: DataGridSortKey): SortKeyType {
    switch (typeof a) {
        case "string":
            return SortKeyType.text;
        case "number":
            return SortKeyType.numeric;
        case "bigint":
            return SortKeyType.bigint;
        case "boolean":
            return SortKeyType.bool;
        default:
            // Only objects remain here; Dates are the one object kind that sorts by value.
            return a instanceof Date ? SortKeyType.date : SortKeyType.unsorted;
    }
}

/**
 * Three-way comparison of two sort keys, consistent enough for `Array.prototype.sort`.
 *
 * Keys of different kinds are ordered by their `SortKeyType` value. Keys of the same kind compare
 * by value:
 * - text: `localeCompare`
 * - numbers: numerically, with `NaN` after every other number
 * - bigints: numerically
 * - dates: by timestamp, with invalid dates after valid ones
 * - booleans: `true` before `false`
 * - unsortable keys: always equal
 *
 * @returns a negative number if `a` sorts first, a positive number if `b` does, 0 on a tie.
 */
function compare(a: DataGridSortKey, b: DataGridSortKey, typeA: SortKeyType, typeB: SortKeyType): number {
    if (typeA !== typeB) {
        return typeA - typeB;
    }
    switch (typeA) {
        case SortKeyType.text:
            return (a as string).localeCompare(b as string);
        case SortKeyType.numeric:
            return compareNumbers(a as number, b as number);
        case SortKeyType.bigint: {
            const x = a as bigint;
            const y = b as bigint;
            return x > y ? 1 : x < y ? -1 : 0;
        }
        case SortKeyType.date:
            return compareNumbers((a as Date).getTime(), (b as Date).getTime());
        case SortKeyType.bool:
            // Equal booleans must tie, otherwise the comparator is inconsistent.
            return a === b ? 0 : a ? -1 : 1;
        default:
            return 0;
    }
}

/**
 * Sign-only numeric comparison that orders `NaN` after every number and equal to itself, and
 * avoids `Infinity - Infinity` producing `NaN`.
 */
function compareNumbers(x: number, y: number): number {
    const xIsNaN = Number.isNaN(x);
    const yIsNaN = Number.isNaN(y);
    if (xIsNaN || yIsNaN) {
        return xIsNaN === yIsNaN ? 0 : xIsNaN ? 1 : -1;
    }
    return x < y ? -1 : x > y ? 1 : 0;
}

// `React.memo` viewed as an identity-typed function, so memoizing a generic component through it
// keeps the component's generic signature (presumably why this alias exists instead of calling
// `React.memo` directly).
const typedMemo: <TComponent>(component: TComponent) => TComponent = React.memo;

/** Memoized `DataGridImpl`; used exactly like the generic component it wraps. */
export const DataGrid = typedMemo(DataGridImpl);

/**
 * `Selection` that preserves the selected keys across `setItems` calls.
 *
 * Whenever the item list is replaced (navigation, filtering, "load more"), the keys selected before
 * the call are re-selected in the new list. While the shimmer is displayed the items are a
 * placeholder array of empty entries; the keys selected before the shimmer appeared are remembered
 * across any number of placeholder updates and restored once real items arrive.
 */
class DataGridSelection<TItem = IObjectWithKey> extends Selection<TItem> {
    /** Keys selected before the shimmer placeholder replaced the real items; `undefined` otherwise. */
    private _selectedKeysBeforeShimmer: string[] | undefined;

    public setItems(items: TItem[], shouldClear?: boolean) {
        const selectedKeys = (this.getSelection() || []).map(item => this.getKey(item));
        const reselectKeys = this._selectedKeysBeforeShimmer || selectedKeys;
        const isShimmer = items.length > 0 && items.every(item => !item);
        // Carry the pre-shimmer keys forward across consecutive placeholder updates; storing
        // `selectedKeys` here would replace them with keys read from placeholder items.
        this._selectedKeysBeforeShimmer = isShimmer ? reselectKeys : undefined;
        super.setItems(items, shouldClear);
        // Re-select with change events suspended, then re-enable them.
        this.setChangeEvents(false);
        reselectKeys.forEach(key => this.setKeySelected(key, true, false));
        this.setChangeEvents(true);
    }
}
