import { useCallback, useMemo } from 'react';
import {
    Fingerprint,
    MixedTargetCandidateGroup,
    UserInputFingerprints,
} from '../../../services/dataService';
import { GoogleCandidateStatus } from '_types';
import { TargetPersonSelectionState } from '_enums';
import { ProcessTarget } from '@indicium/common';
import {
    candidateToReferenceInput,
    groupReferenceInput,
    negativeUserInputToNormalizedReferenceInput,
    NormalizedReferenceInput,
    scoreCandidate,
    userInputToNormalizedReferenceInput,
} from './candidateScoring/scoreCandidate';
import {
    normalizeExactMatchScore,
    normalizeGoogleExactMatchScore,
} from './candidateScoring/utils';
import { scoreGoogleCandidates } from './candidateScoring/scoreGoogleCandidate';
import { isNeitherNullNorUndefined } from '_utils';
import {
    checkIsV2Embedding,
    cosineSimilarityF32,
    embeddingV2Magic,
} from '../../../services/fingerprints/embeddings/v2_text_embeddings';
import { normalize_values_average } from '../../../services/fingerprints/embeddings/utils';
import { useCache } from '../../../hooks/useCache';
import { NEGATIVE_USER_INPUT_ID, POSITIVE_USER_INPUT_ID } from './constants';

/**
 * Inputs for `useCandidateSorting`.
 */
interface CandidateSortingInput {
    /** Every candidate group to bucket and, if still pending, score. */
    allCandidates: MixedTargetCandidateGroup[];
    /** Groups the user skipped; they are dropped from the pending list. */
    skippedCandidates: MixedTargetCandidateGroup[];
    /** Selection state per group id, consulted for non-Google groups. */
    selectedCandidates:
        | Record<string, TargetPersonSelectionState | undefined>
        | undefined;
    /** Google candidate statuses per group id, consulted for Google groups. */
    selectedGoogleCandidates: Record<string, GoogleCandidateStatus[]>;
    /** Fingerprint per candidate id, used for pairwise similarity scoring. */
    candidateEmbeddingMap: ReadonlyMap<string, Fingerprint>;
    /** Currently not read by `useCandidateSorting`. */
    userInputFingerprints?: UserInputFingerprints;
    /** The user's original target input; seeds the reference criteria. */
    userInput?: ProcessTarget;
    /** When true, pending groups are ranked by hard-criteria score alone. */
    sortByHardCriteriaScore?: boolean;
}

/**
 * An undecided candidate group annotated with the scores used to rank it.
 */
export type PendingCandidate = MixedTargetCandidateGroup & {
    /** Average fingerprint-similarity score over the group's candidates. */
    fingerprintScore: number;
    /** Normalized exact-match score against the reference criteria. */
    hardCriteriaScore: number;
    /** Mean of `hardCriteriaScore` and `fingerprintScore`; default sort key. */
    unifiedScore: number;
};

/** Candidate groups bucketed by decision state, as returned by `useCandidateSorting`. */
type SortedCandidates = {
    /** Undecided, non-skipped groups, highest score first. */
    pending: PendingCandidate[];
    /** Confirmed groups (for Google groups: every listed status confirmed). */
    included: MixedTargetCandidateGroup[];
    /** Ignored groups (for Google groups: every listed status ignored). */
    excluded: MixedTargetCandidateGroup[];
    /** Groups carrying at least one confirmed or ignored decision. */
    decided: MixedTargetCandidateGroup[];
};

/**
 * Averages the fingerprint scores of a group's candidate ids.
 *
 * Ids missing from `scores` contribute 0. An empty id list yields 0 instead
 * of `NaN` (0 / 0), which would otherwise make the pending-candidate sort
 * comparator inconsistent.
 */
const averageFingerprintScore = (
    ids: readonly string[],
    scores: ReadonlyMap<string, number>,
): number => {
    if (ids.length === 0) {
        return 0;
    }

    const total = ids.reduce((sum, id) => sum + (scores.get(id) ?? 0), 0);

    return total / ids.length;
};

/**
 * Converts the non-Google groups in `groups` into normalized reference
 * inputs. Google groups, and any group that yields null/undefined, are
 * skipped.
 */
const toNonGoogleReferenceInputs = (
    groups: MixedTargetCandidateGroup[],
): NormalizedReferenceInput[] =>
    groups
        .map((candidate) =>
            'isGoogleCandidate' in candidate
                ? null
                : candidateToReferenceInput(candidate),
        )
        .filter(isNeitherNullNorUndefined);

/**
 * Buckets candidate groups by the user's selections and ranks the pending ones.
 *
 * - `decided`: groups with at least one confirmed or ignored decision.
 * - `included` / `excluded`: groups that are entirely confirmed / ignored.
 *   For Google groups this means every listed status, and only when at least
 *   one status is listed.
 * - `pending`: undecided groups that are not in `skippedCandidates`. Each is
 *   scored by fingerprint similarity to confirmed vs. ignored candidates and
 *   by hard-criteria matching against `userInput` plus the included/excluded
 *   groups. The list is sorted descending by `unifiedScore`, or by
 *   `hardCriteriaScore` when `sortByHardCriteriaScore` is set.
 *
 * `included`, `excluded` and `decided` keep the order of `allCandidates`.
 */
export const useCandidateSorting = ({
    allCandidates,
    candidateEmbeddingMap,
    selectedGoogleCandidates,
    selectedCandidates,
    userInput,
    sortByHardCriteriaScore,
    skippedCandidates,
}: CandidateSortingInput): SortedCandidates => {
    const {
        pendingCandidates,
        includedCandidates,
        excludedCandidates,
        decidedCandidates,
        positiveCandidateIds,
        negativeCandidateIds,
    } = useMemo(() => {
        // The synthetic user-input ids always anchor the positive/negative sets.
        const positiveCandidateIds: string[] = [POSITIVE_USER_INPUT_ID];
        const negativeCandidateIds: string[] = [NEGATIVE_USER_INPUT_ID];

        const pendingCandidates: MixedTargetCandidateGroup[] = [];
        const includedCandidates: MixedTargetCandidateGroup[] = [];
        const excludedCandidates: MixedTargetCandidateGroup[] = [];
        const decidedCandidates: MixedTargetCandidateGroup[] = [];

        allCandidates.forEach((candidateGroup) => {
            if ('isGoogleCandidate' in candidateGroup) {
                let includedCandidateCount = 0;
                let excludedCandidateCount = 0;

                const googleCandidates =
                    selectedGoogleCandidates[candidateGroup.groupId] ?? [];

                googleCandidates.forEach((googleCandidate) => {
                    if (
                        googleCandidate.status ===
                        TargetPersonSelectionState.Confirmed
                    ) {
                        includedCandidateCount += 1;
                        positiveCandidateIds.push(googleCandidate.id);
                    }

                    if (
                        googleCandidate.status ===
                        TargetPersonSelectionState.Ignored
                    ) {
                        excludedCandidateCount += 1;
                        negativeCandidateIds.push(googleCandidate.id);
                    }
                });

                if (includedCandidateCount > 0 || excludedCandidateCount > 0) {
                    decidedCandidates.push(candidateGroup);
                } else {
                    pendingCandidates.push(candidateGroup);
                }

                // Without this guard, a group with no listed statuses
                // matches both checks (0 === 0). It would then land in
                // pending, included AND excluded at the same time.
                const hasStatuses = googleCandidates.length > 0;

                if (
                    hasStatuses &&
                    includedCandidateCount === googleCandidates.length
                ) {
                    includedCandidates.push(candidateGroup);
                }

                if (
                    hasStatuses &&
                    excludedCandidateCount === googleCandidates.length
                ) {
                    excludedCandidates.push(candidateGroup);
                }
            } else {
                const selection = selectedCandidates?.[candidateGroup.groupId];

                if (selection === TargetPersonSelectionState.Confirmed) {
                    positiveCandidateIds.push(...candidateGroup.candidateIds);

                    includedCandidates.push(candidateGroup);
                    decidedCandidates.push(candidateGroup);
                } else if (selection === TargetPersonSelectionState.Ignored) {
                    negativeCandidateIds.push(...candidateGroup.candidateIds);

                    excludedCandidates.push(candidateGroup);
                    decidedCandidates.push(candidateGroup);
                } else {
                    pendingCandidates.push(candidateGroup);
                }
            }
        });

        // Set lookup keeps the skip filter linear instead of O(pending * skipped).
        const skippedGroupIds = new Set(
            skippedCandidates.map((skippedCandidate) => skippedCandidate.groupId),
        );

        const filteredPendingCandidates = pendingCandidates.filter(
            (candidate) => !skippedGroupIds.has(candidate.groupId),
        );

        return {
            positiveCandidateIds,
            negativeCandidateIds,
            pendingCandidates: filteredPendingCandidates,
            includedCandidates,
            excludedCandidates,
            decidedCandidates,
        };
    }, [
        allCandidates,
        selectedGoogleCandidates,
        selectedCandidates,
        skippedCandidates,
    ]);

    const cache = useCache<string, number>();

    /**
     * Cosine similarity between the V2 embeddings of two candidate ids.
     * Returns null if either id has no embedding or either embedding is not V2.
     * Successful results are memoized under an order-independent key.
     */
    const calculateSimilarity = useCallback(
        (leftId: string, rightId: string): number | null => {
            // Order the pair so (a, b) and (b, a) share one cache entry.
            const key =
                leftId > rightId
                    ? `${leftId}:${rightId}`
                    : `${rightId}:${leftId}`;

            // NOTE(review): entries are keyed only by the id pair. If
            // `candidateEmbeddingMap` can change the embedding of an
            // already-cached id, a stale value would be returned. Confirm
            // whether `useCache` is reset in that case.
            const oldValue = cache.get(key) ?? null;

            if (oldValue !== null) {
                return oldValue;
            }

            const leftEmbedding = candidateEmbeddingMap.get(leftId);
            const rightEmbedding = candidateEmbeddingMap.get(rightId);

            if (!leftEmbedding || !rightEmbedding) {
                return null;
            }

            if (
                !checkIsV2Embedding(leftEmbedding) ||
                !checkIsV2Embedding(rightEmbedding)
            ) {
                return null;
            }

            // Third argument is presumably an offset that skips the V2 magic
            // header. TODO confirm against cosineSimilarityF32.
            const newValue = cosineSimilarityF32(
                leftEmbedding,
                rightEmbedding,
                embeddingV2Magic.length,
            );

            cache.set(key, newValue);

            return newValue;
        },
        [cache, candidateEmbeddingMap],
    );

    const scoredFingerprints = useMemo(
        function buildCandidateScoreMap() {
            const map = new Map<string, number>();

            for (const candidateId of candidateEmbeddingMap.keys()) {
                const positiveSimilarities: number[] = [];

                for (const id of positiveCandidateIds) {
                    const similarity = calculateSimilarity(candidateId, id);

                    if (similarity !== null) {
                        positiveSimilarities.push(similarity);
                    }
                }

                const negativeSimilarities: number[] = [];

                for (const id of negativeCandidateIds) {
                    const similarity = calculateSimilarity(candidateId, id);

                    if (similarity !== null) {
                        negativeSimilarities.push(similarity);
                    }
                }

                const positive_score =
                    normalize_values_average(positiveSimilarities);
                const negative_score =
                    normalize_values_average(negativeSimilarities);

                // Centers the positive/negative difference around 0.5. If
                // both averages lie in [0, 1], the result does too.
                const score = (positive_score - negative_score) / 2 + 0.5;

                map.set(candidateId, score);
            }

            return map;
        },
        [
            candidateEmbeddingMap,
            positiveCandidateIds,
            negativeCandidateIds,
            calculateSimilarity,
        ],
    );

    const positiveReferenceInput = useMemo(
        () =>
            groupReferenceInput([
                userInputToNormalizedReferenceInput(userInput),
                ...toNonGoogleReferenceInputs(includedCandidates),
            ]),
        [userInput, includedCandidates],
    );

    const negativeReferenceInput = useMemo(
        () =>
            groupReferenceInput([
                negativeUserInputToNormalizedReferenceInput(userInput),
                ...toNonGoogleReferenceInputs(excludedCandidates),
            ]),
        [userInput, excludedCandidates],
    );

    const sortedPendingCandidates = useMemo(
        () =>
            pendingCandidates
                .map((candidateGroup) => {
                    let fingerprintScore: number;
                    let hardCriteriaScore: number;

                    if ('isGoogleCandidate' in candidateGroup) {
                        fingerprintScore = averageFingerprintScore(
                            candidateGroup.candidates.map(
                                (googleCandidate) => googleCandidate.id,
                            ),
                            scoredFingerprints,
                        );

                        hardCriteriaScore = normalizeGoogleExactMatchScore(
                            scoreGoogleCandidates(
                                candidateGroup.candidates,
                                positiveReferenceInput,
                                negativeReferenceInput,
                            ),
                        );
                    } else {
                        fingerprintScore = averageFingerprintScore(
                            candidateGroup.candidateIds,
                            scoredFingerprints,
                        );

                        hardCriteriaScore = normalizeExactMatchScore(
                            scoreCandidate(
                                candidateGroup,
                                positiveReferenceInput,
                                negativeReferenceInput,
                            ),
                        );
                    }

                    return {
                        ...candidateGroup,
                        hardCriteriaScore,
                        fingerprintScore,
                        unifiedScore:
                            (hardCriteriaScore + fingerprintScore) / 2,
                    };
                })
                // `.map` returned a fresh array, so sorting in place does not
                // mutate `pendingCandidates`.
                .sort((a, b) =>
                    sortByHardCriteriaScore
                        ? b.hardCriteriaScore - a.hardCriteriaScore
                        : b.unifiedScore - a.unifiedScore,
                ),
        [
            pendingCandidates,
            scoredFingerprints,
            positiveReferenceInput,
            negativeReferenceInput,
            sortByHardCriteriaScore,
        ],
    );

    return {
        pending: sortedPendingCandidates,
        // Included/excluded/decided preserve the order of `allCandidates`.
        included: includedCandidates,
        excluded: excludedCandidates,
        decided: decidedCandidates,
    };
};
