import { MaxPriorityQueue } from '@datastructures-js/priority-queue';
import { log } from '@logtail/next';
import _ from 'lodash';
import { kmeans } from 'ml-kmeans';

/**
 * Round `num` to `digits` fractional places in the given base using Math.round
 * (ties go towards +Infinity).
 * @param num Value to round.
 * @param digits Number of fractional places to keep.
 * @param base Radix of the places; falsy values (undefined, 0, NaN) fall back to 10.
 */
export function toFixedNumber(num: number, digits: number, base?: number): number {
    const scale = Math.pow(base ? base : 10, digits);
    const scaled = num * scale;
    return Math.round(scaled) / scale;
}

/**
 * Truncate `num` towards -Infinity at `digits` fractional places in the given base.
 * @param num Value to round down.
 * @param digits Number of fractional places to keep.
 * @param base Radix of the places; falsy values fall back to 10.
 */
export function toFixedFloor(num: number, digits: number, base?: number): number {
    const scale = Math.pow(base ? base : 10, digits);
    const scaled = num * scale;
    return Math.floor(scaled) / scale;
}

/**
 * Round `num` towards +Infinity at `digits` fractional places in the given base.
 * @param num Value to round up.
 * @param digits Number of fractional places to keep.
 * @param base Radix of the places; falsy values fall back to 10.
 */
export function toFixedCeil(num: number, digits: number, base?: number): number {
    const scale = Math.pow(base ? base : 10, digits);
    const scaled = num * scale;
    return Math.ceil(scaled) / scale;
}

/**
 * Loose "looks like a number" check.
 *
 * Rejects `undefined`, the empty string, and anything the global `isNaN` coerces to NaN.
 * Because of that coercion, `null` and whitespace-only strings are accepted (they coerce to 0).
 */
export function isValidNumber(toTest: any) {
    if (toTest === undefined || toTest === "") {
        return false;
    }
    // eslint-disable-next-line @typescript-eslint/no-unsafe-argument
    return !isNaN(toTest);
}

/**
 * Format a frame position as a string in the requested axis unit:
 * beats with 2 decimals, seconds with 3 decimals, or whole frames.
 * Intended for display — the result is rounded by `toFixed`.
 */
export function frameToXAxisType(frame: number, xaxisType: "frames" | "seconds" | "beats", fps: number, bpm: number) {
    switch (xaxisType) {
        case "beats":
            return frameToBeat(frame, fps, bpm).toFixed(2);
        case "seconds":
            return frameToSec(frame, fps).toFixed(3);
        case "frames":
            return frame.toFixed(0);
    }
}

/**
 * Convert a position expressed in the given axis unit (frames, seconds or beats)
 * back to a (possibly fractional) frame number. No rounding is applied.
 */
export function xAxisTypeToFrame(position: number, xaxisType: "frames" | "seconds" | "beats", fps: number, bpm: number) {
    switch (xaxisType) {
        case "beats":
            return beatToFrame(position, fps, bpm);
        case "seconds":
            return secToFrame(position, fps);
        case "frames":
            return position;
    }
}

/** Convert a frame number to a beat count, given frames per second and beats per minute. */
export function frameToBeat(frame: number, fps: number, bpm: number): number {
    const framesPerBeat = (fps * 60) / bpm;
    return frame / framesPerBeat;
}

/** Convert a frame number to seconds at the given frame rate. */
export function frameToSec(frame: number, fps: number): number {
    const seconds = frame / fps;
    return seconds;
}

/** Convert seconds to a (possibly fractional) frame number at the given frame rate. */
export function secToFrame(sec: number, fps: number): number {
    const frames = fps * sec;
    return frames;
}

/** Convert a beat count to a (possibly fractional) frame number, given fps and bpm. */
export function beatToFrame(beat: number, fps: number, bpm: number): number {
    const framesPerBeat = (fps * 60) / bpm;
    return framesPerBeat * beat;
}

/** Convert a beat count to seconds at the given tempo (beats per minute). */
export function beatToSec(beat: number, bpm: number): number {
    const minutes = beat / bpm;
    return 60 * minutes;
}

/** Convert seconds to a beat count at the given tempo (beats per minute). */
export function secToBeat(sec: number, bpm: number): number {
    const beatSeconds = bpm * sec;
    return beatSeconds / 60;
}


/**
 * Move a frame position to a new frame rate / tempo while keeping its position fixed in the
 * unit named by `keyframeLock` (same frame number, same time in seconds, or same beat).
 *
 * The locked position is computed at full precision. (Previously it went through the display
 * formatter `frameToXAxisType`, whose `toFixed` rounding — 3 decimals for seconds, 2 for beats —
 * shifted remapped frames, e.g. 10 frames @24fps/120bpm → 12.45 instead of 12.5 @30fps.)
 *
 * @returns The frame number at `newFps`/`newBpm`; may be fractional.
 */
export function remapFrameCount(frame: number, keyframeLock: "frames" | "seconds" | "beats", oldFps: number, oldBpm: number, newFps: number, newBpm: number) {
    const lockedPosition =
        keyframeLock === "seconds" ? frameToSec(frame, oldFps)
        : keyframeLock === "beats" ? frameToBeat(frame, oldFps, oldBpm)
        : frame;
    return xAxisTypeToFrame(lockedPosition, keyframeLock, newFps, newBpm);
}


/**
 * Snap a rough axis step to a "nice" value: 1, 2, 5 or 10 times a power of ten.
 *
 * The step is scaled into [1, 10) by its power of ten and mapped as
 * < 1.5 → 1, < 3 → 2, < 7 → 5, otherwise 10. Floating-point division can land just under a
 * boundary (0.3 / 0.1 evaluates to 2.999…, so 0.3 → 0.2).
 *
 * @param roughStepSize Step to snap; meant to be positive (0 yields 0, negatives yield NaN).
 */
export function calculateNiceStepSize(roughStepSize: number): number {
    const magnitude = Math.pow(10, Math.floor(Math.log10(roughStepSize)));
    const fraction = roughStepSize / magnitude;

    const brackets: Array<[number, number]> = [[1.5, 1], [3, 2], [7, 5]];
    const bracket = brackets.find(([upperLimit]) => fraction < upperLimit);
    const niceFraction = bracket ? bracket[1] : 10;

    return niceFraction * magnitude;
}


/**
 * Pick a key at random with probability proportional to its weight (uses Math.random).
 * Keys with weight 0 are never returned.
 *
 * @param probabilities Map of key → non-negative finite weight; at least one weight must be > 0.
 * @returns The chosen key.
 * @throws Error if any weight is negative or non-finite, or if all weights are 0 (or there are none).
 */
export function pickUsingProbabilities(probabilities: { [key: string]: number }): string {
    const entries = Object.entries(probabilities);
    if (entries.some(([, weight]) => !Number.isFinite(weight) || weight < 0)) {
        throw new Error('Invalid weights configuration: weights must be 0 or above');
    }

    const totalWeight = entries.reduce((sum, [, weight]) => sum + weight, 0);
    if (totalWeight <= 0) {
        throw new Error('Invalid weights configuration: at least one weight must be above 0');
    }

    let scaledRandom = Math.random() * totalWeight;
    let lastPickableKey: string | undefined;
    for (const [key, weight] of entries) {
        if (weight > 0) {
            lastPickableKey = key;
        }
        scaledRandom -= weight;
        if (scaledRandom < 0) {
            return key;
        }
    }

    // Rounding in the sequential subtraction can leave a tiny non-negative residue when
    // Math.random() is very close to 1; that draw belongs to the last non-zero-weight key.
    if (lastPickableKey === undefined) {
        // Unreachable: totalWeight > 0 implies some positive weight.
        throw new Error('Invalid weights configuration: at least one weight must be above 0');
    }
    return lastPickableKey;
}

/**
 * Return a uniformly random element of `arr` (via Math.random).
 * An empty array yields `undefined` at runtime despite the `T` return type.
 */
export function pickRandom<T>(arr: T[]): T {
    const randomIndex = Math.floor(arr.length * Math.random());
    return arr[randomIndex];
}


/**
 * Linearly rescale a Float32Array so its minimum maps to 0 and its maximum to 1.
 *
 * If every element is equal, every output element is 0.5. NaN elements are ignored when
 * finding the bounds and stay NaN in the output. An empty input gives an empty output.
 *
 * @param array Values to rescale; not modified.
 * @returns A new Float32Array of the same length.
 */
export function normalizeFloat32Array(array: Float32Array): Float32Array {
    let lowest = Infinity;
    let highest = -Infinity;
    array.forEach((sample) => {
        if (sample < lowest) lowest = sample;
        if (sample > highest) highest = sample;
    });

    if (lowest === highest) {
        // Flat input: no range to scale by.
        return new Float32Array(array.length).fill(0.5);
    }

    const span = highest - lowest;
    return array.map((sample) => (sample - lowest) / span);
}

/**
 * Locate the first beat of a beat grid that falls after the start of a segment.
 *
 * The grid has beats at `offset + k * (60 / bpm)` seconds. The returned beat is strictly after
 * `segmentStart`: a beat lying exactly on `segmentStart` is skipped in favour of the next one.
 *
 * @param bpm Tempo in beats per minute.
 * @param offset Time of a reference beat, in seconds.
 * @param segmentStart Segment start time, in seconds.
 * @returns `positionOfFirstBeatInSegment` (seconds) and `beatOffset`, its distance from `segmentStart`.
 */
export function firstBeatInSegment(bpm: number, offset: number, segmentStart: number) {
    const beatLength = 60 / bpm;
    const wholeBeatsElapsed = Math.floor((segmentStart - offset) / beatLength);
    const positionOfFirstBeatInSegment = offset + (wholeBeatsElapsed + 1) * beatLength;

    return {
        positionOfFirstBeatInSegment,
        beatOffset: positionOfFirstBeatInSegment - segmentStart,
    };
}


/** A local maximum in a sampled signal. */
export interface Peak {
    /** Signal value at `position` (the peak finders in this file record `signal[position]`). */
    value: number;
    /** Sample index of the peak within its signal. */
    position: number;
}

/**
 * Select up to `maxPeakCount` of the highest strict local maxima in `signal`, greedily by
 * value, such that every pair of selected peaks is at least `minSampleGap` samples apart.
 *
 * A sample is a local maximum when it is strictly greater than each neighbour it has
 * (the first and last samples only need to beat their single neighbour).
 *
 * Previously a candidate was compared only with the most recently accepted peak using a signed
 * difference, so any peak left of an earlier pick was always rejected and distances to other
 * picks were never checked. Now both nearest accepted neighbours are checked.
 *
 * @returns Selected peaks sorted by ascending position.
 */
export function findNHighestPeaks(signal: Float32Array, maxPeakCount: number, minSampleGap: number): Peak[] {
    const peaks = new MaxPriorityQueue<Peak>((peak: Peak) => peak.value);

    for (let i = 0; i < signal.length; i++) {
        const value = signal[i];
        if ((i === 0 || value > signal[i - 1]) &&
            (i === signal.length - 1 || value > signal[i + 1])) {
            peaks.enqueue({ position: i, value });
        }
    }

    // Accepted peaks, kept sorted by position so a candidate's neighbours are found by binary search.
    const result: Peak[] = [];
    while (!peaks.isEmpty() && result.length < maxPeakCount) {
        const peak = peaks.dequeue();
        const insertAt = lowerBoundPeakIndexByPosition(result, peak.position);
        const left = result[insertAt - 1];
        const right = result[insertAt];
        // Accepted peaks are already pairwise >= minSampleGap apart, so only the nearest
        // neighbour on each side can be too close.
        const clearOfLeft = left === undefined || peak.position - left.position >= minSampleGap;
        const clearOfRight = right === undefined || right.position - peak.position >= minSampleGap;
        if (clearOfLeft && clearOfRight) {
            result.splice(insertAt, 0, peak);
        }
    }

    return result;
}

/** Index of the first peak in `sortedPeaks` (ascending by position) whose position is >= `position`. */
function lowerBoundPeakIndexByPosition(sortedPeaks: Peak[], position: number): number {
    let lo = 0;
    let hi = sortedPeaks.length;
    while (lo < hi) {
        const mid = (lo + hi) >>> 1;
        if (sortedPeaks[mid].position < position) {
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }
    return lo;
}


/**
 * Select roughly `targetPeakCount` prominent peaks from `signal` by clustering peak heights.
 *
 * All strict local maxima are found. For k = 2, 3, … (up to min(2·target, #peaks)) the peak
 * values are grouped with k-means, and clusters are merged from the highest-valued one
 * downwards, enforcing `minSampleGap` after each merge, until at least `targetPeakCount`
 * peaks are selected. The merge closest to the target over all k is kept; the search stops
 * once that relative error drops below 20%.
 *
 * @param signal Signal to search.
 * @param targetPeakCount Desired number of peaks; values <= 0 yield an empty result.
 * @param minSampleGap Minimum distance, in samples, between selected peaks.
 * @returns Selected peaks sorted by ascending position.
 */
export function findAdaptiveThresholdPeaks(signal: Float32Array, targetPeakCount: number, minSampleGap: number): Peak[] {
    if (targetPeakCount <= 0) {
        return [];
    }

    // Find all strict local maxima in the signal.
    const allPeaks: Peak[] = [];
    for (let i = 0; i < signal.length; i++) {
        if ((i === 0 || signal[i] > signal[i - 1]) &&
            (i === signal.length - 1 || signal[i] > signal[i + 1])) {
            allPeaks.push({ position: i, value: signal[i] });
        }
    }

    // Sort by value, descending. Cluster member lists built below inherit this order,
    // so the first member of each cluster is its highest peak.
    allPeaks.sort((a, b) => b.value - a.value);

    // Keep peaks at least `minGap` apart, scanning by position; when a peak is too close to the
    // previously kept one, the higher of the two is kept.
    const applyMinGap = (peaks: Peak[], minGap: number): Peak[] => {
        const filteredPeaks: Peak[] = [];
        const peaksSortedByPosition = _.sortBy(peaks, 'position');
        const originalLength = peaksSortedByPosition.length;
        for (const peak of peaksSortedByPosition) {
            if (filteredPeaks.length === 0 || peak.position - _.last(filteredPeaks)!.position >= minGap) {
                filteredPeaks.push(peak);
            } else if (peak.value > _.last(filteredPeaks)!.value) {
                filteredPeaks[filteredPeaks.length - 1] = peak;
            }
        }
        log.debug(`Applied min gap ${minGap} to ${originalLength} peaks, got ${filteredPeaks.length} peaks.`);
        return filteredPeaks;
    };

    if (allPeaks.length < 2) {
        // Nothing to cluster (k-means starts at k = 2); 0 or 1 peaks is the best available answer.
        return [...allPeaks];
    }

    const values = allPeaks.map(peak => [peak.value]);
    let selectedPeaks: Peak[] = [];
    let bestErrorRate: number | undefined;
    const maxBuckets = Math.min(targetPeakCount * 2, allPeaks.length);
    for (let k = 2; k <= maxBuckets; k++) {
        const time = performance.now();
        const { clusters: clusterIds } = kmeans(values, k, { initialization: 'kmeans++' });

        // Group peaks by cluster id. A Map (rather than an array indexed by id) leaves no holes
        // when k-means produces an empty cluster.
        const clustersById = new Map<number, Peak[]>();
        clusterIds.forEach((clusterId, index) => {
            const members = clustersById.get(clusterId) ?? [];
            members.push(allPeaks[index]);
            clustersById.set(clusterId, members);
        });
        // Highest cluster first (cluster[0] is each cluster's highest peak, see above).
        const peakClusterList = Array.from(clustersById.values()).sort((a, b) => b[0].value - a[0].value);

        // Accumulate clusters from the top down until we have selected enough peaks.
        let mergedClusters: Peak[] = [];
        for (const cluster of peakClusterList) {
            mergedClusters = applyMinGap([...mergedClusters, ...cluster], minSampleGap);

            // How close are we to the target peak count?
            const errorRate = Math.abs(targetPeakCount - mergedClusters.length) / targetPeakCount;
            if (bestErrorRate === undefined || errorRate < bestErrorRate) {
                bestErrorRate = errorRate;
                selectedPeaks = [...mergedClusters];
            }

            if (mergedClusters.length >= targetPeakCount) {
                // Adding lower clusters would only move further past the target.
                break;
            }
        }

        log.debug(`K-means clustering with k=${k}: ${performance.now() - time} ms. desired=${targetPeakCount}, got=${selectedPeaks.length} (${bestErrorRate?.toFixed(2)})`);

        if (bestErrorRate !== undefined && bestErrorRate < 0.2) {
            // Within 20% of the target: no need to try a larger k.
            break;
        }
    }

    // Sort the final peaks by position.
    return selectedPeaks.sort((a, b) => a.position - b.position);
}


/**
 * Return a copy of `signal` in which only the samples at the given peak positions are kept;
 * every other sample is 0. Peaks whose position lies outside [0, signal.length) are ignored.
 */
export function zeroOutNonPeaks(signal: Float32Array, peaks: Peak[]): Float32Array {
    // A fresh Float32Array is already zero-filled.
    const peaksOnly = new Float32Array(signal.length);
    peaks
        .filter(({ position }) => position >= 0 && position < signal.length)
        .forEach(({ position }) => {
            peaksOnly[position] = signal[position];
        });
    return peaksOnly;
}


/**
 * Build an envelope that jumps to each selected peak's value and decays exponentially until
 * the next peak, blended with an attenuated copy of the original signal.
 *
 * Output, before normalisation:
 * - samples before the first peak are 0;
 * - at a peak's position the sample is the peak's value;
 * - otherwise, with d the distance in samples from the most recent peak p,
 *   max(signal[i] * originalSignalFactor * originalSignalDecayFactor^d, p.value * decayFactor^d).
 * The result is then divided by its largest absolute value (when > 0), so it lies in [-1, 1].
 *
 * Peaks are expected in ascending, integer `position` order within the signal. Peaks that can no
 * longer be reached in that order (unsorted, duplicate, negative or fractional positions) are
 * skipped instead of stalling the walk.
 *
 * @param signal Source signal; not modified.
 * @param selectedPeaks Peaks to place.
 * @param decayFactor Per-sample multiplier applied to the peak value after each peak.
 * @param originalSignalFactor Scale applied to the original signal between peaks.
 * @param originalSignalDecayFactor Per-sample multiplier applied to that scaled original signal.
 * @returns A new Float32Array of the same length.
 */
export function createPeakDecaySignal(
    signal: Float32Array, 
    selectedPeaks: Peak[],
    decayFactor: number = 0.9,
    originalSignalFactor: number = 0.5,
    originalSignalDecayFactor: number = 1,
): Float32Array {
    // Zero-initialised, which also covers every sample before the first peak.
    const modifiedSignal = new Float32Array(signal.length);

    let nextPeakIndex = 0;
    let previousPeak: Peak | undefined;
    for (let i = 0; i < signal.length; i++) {
        while (nextPeakIndex < selectedPeaks.length && selectedPeaks[nextPeakIndex].position < i) {
            nextPeakIndex++;
        }

        const nextPeak = selectedPeaks[nextPeakIndex];
        if (nextPeak !== undefined && nextPeak.position === i) {
            // We're on a peak.
            modifiedSignal[i] = nextPeak.value;
            previousPeak = nextPeak;
            nextPeakIndex++;
        } else if (previousPeak !== undefined) {
            // Decay from the most recent peak. The decay always continues up to the next peak:
            // the old "stop once the value hits zero" shortcut advanced to a peak that was still
            // ahead, making the distance negative so Math.pow amplified instead of decaying.
            const distanceFromPeak = i - previousPeak.position;
            const decayValue = previousPeak.value * Math.pow(decayFactor, distanceFromPeak);
            const attenuatedOriginal = signal[i] * originalSignalFactor * Math.pow(originalSignalDecayFactor, distanceFromPeak);
            modifiedSignal[i] = Math.max(attenuatedOriginal, decayValue);
        }
    }

    // Normalise to [-1, 1]. A loop rather than Math.max(...array), which throws RangeError once
    // the array exceeds the engine's limit on call arguments. NaN samples are ignored here.
    let maxAbsValue = 0;
    for (let i = 0; i < modifiedSignal.length; i++) {
        const magnitude = Math.abs(modifiedSignal[i]);
        if (magnitude > maxAbsValue) {
            maxAbsValue = magnitude;
        }
    }
    if (maxAbsValue > 0) {
        for (let i = 0; i < modifiedSignal.length; i++) {
            modifiedSignal[i] = modifiedSignal[i] / maxAbsValue;
        }
    }

    return modifiedSignal;
}


/**
 * Largest power of two strictly below `floor(n)`, with these special cases:
 * - n < 1 (including negatives and NaN) → 0;
 * - 1 <= n < 2 → 1.
 * Exact powers of two therefore map to half their value (4 → 2, 1024 → 512); other integers map
 * to the power of two just below them (5 → 4, 1000 → 512).
 *
 * Uses plain arithmetic instead of 32-bit bit-smearing, which returned negative results for
 * n > 2^30 and returned 0 for 1 < n < 2.
 */
export function closestPowerOf2Below(n: number): number {
    const whole = Math.floor(n);
    if (!(whole >= 1)) return 0;
    if (whole === 1) return 1;

    let power = 1;
    while (power * 2 < whole) {
        power *= 2;
    }
    return power;
}


/** Tuning parameters for the adaptive gate. */
export interface GateConfig {
    /** The threshold search stops as soon as more than this many events have been found. */
    targetEvents: number;
    /** Events whose peak positions are closer than this many samples are merged. */
    minGap: number;
    /** Consecutive samples a threshold crossing must persist before the gate opens or closes. */
    sustain: number;
    /**
     * Hysteresis as a percentage of the threshold: the gate opens above
     * threshold + threshold·p/100 and closes below threshold − threshold·p/100.
     */
    hysteresisPc: number;
    /** NOTE(review): not read by any code in this file — confirm whether callers elsewhere use it. */
    tolerance: number;
}

/** One open→close cycle of the gate; positions are sample indices into the gated signal. */
interface GateEvent {
    /** Absolute sample value recorded as the event's peak. */
    peakValue: number;
    /** Sample index recorded for the event's peak. */
    peakPosition: number;
    /** Index of the first sample of the run that opened the gate (for integer sustain >= 1). */
    start: number;
    /** Index at which the gate closed, or the last index if it was still open at the end. */
    end: number;
}

/**
 * Run a hysteresis gate with a sustain requirement over `signal` at a fixed threshold.
 *
 * The gate opens once |sample| stays above threshold + h for `config.sustain` consecutive
 * samples, and closes once it stays below threshold − h for that many samples, where
 * h = threshold · hysteresisPc / 100. Each open→close cycle becomes one event. An event still
 * open at the end of the signal ends at the last index.
 *
 * The event's peak is the largest |sample| from the start of the opening run up to closing.
 * Previously the opening run itself was ignored (only its last sample counted), so a higher
 * sample at the beginning of the event was missed and the peak was misplaced.
 */
function findGateEventsAtSpecificThreshold(signal: Float32Array, threshold: number, config: GateConfig): GateEvent[] {
    const events: GateEvent[] = [];
    let isGateOpen = false;
    let sustainedCount = 0;
    let eventStart = 0;
    let peakValue = 0;
    let peakPosition = 0;
    const { sustain: sustainedSamples, hysteresisPc } = config;
    const hysteresis = threshold * (hysteresisPc / 100);

    for (let i = 0; i < signal.length; i++) {
        const absSample = Math.abs(signal[i]);
        if (isGateOpen) {
            if (absSample > peakValue) {
                peakValue = absSample;
                peakPosition = i;
            }
            if (absSample < threshold - hysteresis) {
                sustainedCount++;
                if (sustainedCount >= sustainedSamples) {
                    events.push({ start: eventStart, end: i, peakValue, peakPosition });
                    isGateOpen = false;
                    sustainedCount = 0;
                    peakValue = 0;
                    peakPosition = 0;
                }
            } else {
                sustainedCount = 0;
            }
        } else {
            if (absSample > threshold + hysteresis) {
                sustainedCount++;
                if (sustainedCount >= sustainedSamples) {
                    isGateOpen = true;
                    // The run that opened the gate is exactly `sustainedCount` samples long (equal to
                    // `sustain` for integer sustain >= 1, and 1 when sustain <= 0).
                    eventStart = i - sustainedCount + 1;
                    // That run belongs to the event, so search it for the peak (earliest wins on ties).
                    peakPosition = eventStart;
                    peakValue = Math.abs(signal[eventStart]);
                    for (let j = eventStart + 1; j <= i; j++) {
                        const candidate = Math.abs(signal[j]);
                        if (candidate > peakValue) {
                            peakValue = candidate;
                            peakPosition = j;
                        }
                    }
                    sustainedCount = 0;
                }
            } else {
                sustainedCount = 0;
            }
        }
    }

    if (isGateOpen) {
        events.push({ start: eventStart, end: signal.length - 1, peakValue, peakPosition });
    }

    return events;
}

/**
 * Search for a gate threshold that yields more than `config.targetEvents` events.
 *
 * Starting at the signal's largest finite absolute sample, the threshold is repeatedly lowered.
 * The events found at each level are merged (using `config.minGap`) into a running set.
 * The threshold drops by 1% per attempt while that set keeps growing. After
 * STAGNATION_THRESHOLD consecutive attempts without growth, it drops by 10% per attempt until
 * the set grows again. The search ends as soon as the set exceeds `config.targetEvents`, after
 * MAX_ATTEMPTS attempts, or when the threshold reaches NOISE_FLOOR.
 *
 * @returns [threshold at which the best event set was found, that event set, attempts made].
 */
export function findGateEventsWithProgressivelyReducingThreshold(signal: Float32Array, config: GateConfig): [number, GateEvent[], number] {
    const MAX_ATTEMPTS = 200;
    const NOISE_FLOOR = 0.000001;
    const STAGNATION_THRESHOLD = 5; // attempts with no improvement before reducing the threshold more aggressively

    // Largest finite |sample|. An explicit loop rather than Math.max(...signal): spreading a long
    // signal into call arguments throws a RangeError beyond the engine's argument limit.
    let maxThreshold = 0;
    for (let i = 0; i < signal.length; i++) {
        const magnitude = Math.abs(signal[i]);
        if (Number.isFinite(magnitude) && magnitude > maxThreshold) {
            maxThreshold = magnitude;
        }
    }

    let currentThreshold = maxThreshold;
    let allEvents: GateEvent[] = [];
    let attempts = 0;

    let bestThreshold = currentThreshold;
    let bestEvents: GateEvent[] = [];
    let stagnationCount = 0;

    while (attempts < MAX_ATTEMPTS && currentThreshold > NOISE_FLOOR) {
        const newEvents = findGateEventsAtSpecificThreshold(signal, currentThreshold, config);
        allEvents = mergeEvents(allEvents, newEvents, config.minGap);

        if (allEvents.length > bestEvents.length) {
            bestThreshold = currentThreshold;
            bestEvents = [...allEvents];
            stagnationCount = 0;
        } else {
            // Previously this also counted improving attempts and never reset, so the 10% step
            // was always in use after the fifth attempt.
            stagnationCount++;
        }

        if (bestEvents.length > config.targetEvents) {
            return [currentThreshold, bestEvents, attempts + 1];
        }

        currentThreshold *= stagnationCount >= STAGNATION_THRESHOLD ? 0.9 : 0.99;

        attempts++;

        log.debug("Gate iteration " + attempts, {attempts, targetEvents: config.targetEvents, bestEvents: bestEvents.length, currentThreshold});
    }

    return [bestThreshold, bestEvents, attempts];
}


/**
 * Combine two lists of gate events into one list sorted by peak position.
 *
 * Walking in peak-position order, an event whose peak lies closer than `minGap` samples to the
 * last kept event is folded into it. The kept event keeps its own start, end and peak position
 * (it is the earlier one by position), and its peak value becomes the maximum of the two.
 *
 * Input events are never mutated: a folded event is replaced by a copy. The previous in-place
 * update also changed events held in earlier snapshots of these lists.
 */
function mergeEvents(existingEvents: GateEvent[], newEvents: GateEvent[], minGap:number): GateEvent[] {
    const candidates = [...existingEvents, ...newEvents].sort((a, b) => a.peakPosition - b.peakPosition);

    const mergedEvents: GateEvent[] = [];
    for (const event of candidates) {
        const lastIndex = mergedEvents.length - 1;
        const lastEvent = mergedEvents[lastIndex];
        if (lastEvent !== undefined && Math.abs(lastEvent.peakPosition - event.peakPosition) < minGap) {
            mergedEvents[lastIndex] = { ...lastEvent, peakValue: Math.max(lastEvent.peakValue, event.peakValue) };
        } else {
            mergedEvents.push(event);
        }
    }
    return mergedEvents;
}

/**
 * Render gate events as a sparse signal: zero everywhere except each event's peak position,
 * which holds the event's peak value. When two events share a position, the later one wins.
 * Only `signal.length` is read from `signal`.
 */
function applyGateEvents(signal: Float32Array, events: GateEvent[]): Float32Array {
    // A fresh Float32Array is already zero-filled.
    const spikes = new Float32Array(signal.length);
    events.forEach(({ peakPosition, peakValue }) => {
        spikes[peakPosition] = peakValue;
    });
    return spikes;
}

/**
 * Gate `signal` with an automatically chosen threshold.
 *
 * @returns [gated signal (event peaks only), chosen threshold, number of events,
 *   threshold-search attempts, the events themselves].
 */
export function adaptiveGate(signal: Float32Array, config: GateConfig): [Float32Array, number, number, number, GateEvent[]] {
    const search = findGateEventsWithProgressivelyReducingThreshold(signal, config);
    const [threshold, gateEvents, attemptCount] = search;
    return [applyGateEvents(signal, gateEvents), threshold, gateEvents.length, attemptCount, gateEvents];
}

/**
 * Copy `arr` into a plain number array, replacing null, NaN and ±Infinity entries with the most
 * recent valid value (0 if none has been seen yet).
 */
export function cleanNumberArrayFrom(arr: Float32Array | number[]): number[] {
    let carried = 0;
    return Array.from(arr, (value: number) => {
        const usable = value !== null && !isNaN(value) && isFinite(value);
        if (usable) {
            carried = value;
        }
        return carried;
    });
}