/*
 * Decompiled with CFR 0.152.
 */
package me.moros.bending.common.collision;

import java.util.Arrays;
import java.util.Comparator;
import me.moros.bending.api.collision.geometry.AABB;
import me.moros.bending.common.collision.AABBUtil;
import me.moros.bending.common.collision.Boundable;
import me.moros.bending.common.collision.CollisionQuery;
import me.moros.bending.common.collision.CollisionQueryImpl;
import me.moros.bending.common.collision.MortonEncoded;
import org.jspecify.annotations.Nullable;

/**
 * A linear bounding volume hierarchy (LBVH) over {@link Boundable} elements.
 *
 * <p>The hierarchy is built from elements sorted by Morton code. {@code n} elements become
 * {@code n} leaf nodes and {@code n - 1} internal nodes; internal node 0 is the root. When there
 * is exactly one element, its leaf acts as the root.
 *
 * <p>Internal-node ranges and splits are found by comparing the lengths of the common bit
 * prefixes of neighbouring codes. The comparison uses {@code numberOfLeadingZeros} of the XOR.
 * Runs of identical codes are chained into a right-leaning sequence of internal nodes.
 *
 * @param <E> the element type
 */
public class LBVH<E extends Boundable> {
    private final Node<E>[] treeNodes;
    private final Node<E>[] leafNodes;

    private LBVH(Node<E>[] treeNodes, Node<E>[] leafNodes) {
        this.treeNodes = treeNodes;
        this.leafNodes = leafNodes;
    }

    /**
     * Returns the number of elements (leaves) stored in this hierarchy.
     *
     * @return the element count
     */
    public int size() {
        return this.leafNodes.length;
    }

    /**
     * Returns the root node. Must only be called when the tree holds at least one element.
     */
    private Node<E> root() {
        // A single element produces no internal nodes; its leaf is then the root.
        return this.treeNodes.length > 0 ? this.treeNodes[0] : this.leafNodes[0];
    }

    /**
     * Tests every stored element against the whole hierarchy.
     *
     * <p>Each pair of distinct elements whose boxes intersect is passed to the result. An element
     * is never paired with itself.
     *
     * @return the collected potential collisions; empty when fewer than two elements are stored
     */
    public CollisionQuery<E> queryAll() {
        CollisionQueryImpl<E> result = new CollisionQueryImpl<>();
        if (this.leafNodes.length < 2) {
            // Nothing to pair: an element is never tested against itself.
            return result;
        }
        Node<E> root = this.root();
        for (Node<E> leaf : this.leafNodes) {
            this.recursiveQuery(leaf.element, root, result);
        }
        return result;
    }

    /**
     * Tests a single element against the hierarchy.
     *
     * <p>If {@code element} is itself stored in the tree (same reference), it is skipped as a match.
     *
     * @param element the element to test
     * @return the collected potential collisions; empty when the tree has no elements
     */
    public CollisionQuery<E> query(E element) {
        CollisionQueryImpl<E> result = new CollisionQueryImpl<>();
        if (this.leafNodes.length == 0) {
            return result;
        }
        this.recursiveQuery(element, this.root(), result);
        return result;
    }

    /**
     * Descends from {@code node} into every subtree whose box intersects {@code toCheck}'s box.
     * For each such leaf, the pair is added to {@code potential}.
     */
    private void recursiveQuery(E toCheck, Node<E> node, CollisionQueryImpl<E> potential) {
        // Identity check: skip the leaf that holds the queried element itself.
        if (node.element == toCheck) {
            return;
        }
        if (toCheck.box().intersects(node.box)) {
            if (node.element != null) {
                potential.add(toCheck, node.element);
            } else {
                this.recursiveQuery(toCheck, node.left, potential);
                this.recursiveQuery(toCheck, node.right, potential);
            }
        }
    }

    /**
     * Builds a hierarchy over the given elements.
     *
     * <p>Note: {@code elements} is sorted in place by Morton code. The sort is stable, so elements
     * with equal codes keep their relative order. Empty and single-element arrays are supported.
     *
     * @param elements the elements to index; this array is reordered
     * @param <E> the element type
     * @return the built hierarchy
     */
    public static <E extends Boundable & MortonEncoded> LBVH<E> buildTree(E[] elements) {
        Arrays.sort(elements, Comparator.comparingInt(MortonEncoded::morton));
        int length = elements.length;
        // Codes are extracted once so the range/split searches work on a plain int array.
        int[] codes = new int[length];
        Node<E>[] leafNodes = newNodeArray(length);
        for (int i = 0; i < length; i++) {
            E element = elements[i];
            codes[i] = element.morton();
            Node<E> leaf = new Node<>();
            leaf.element = element;
            leaf.box = element.box();
            leafNodes[i] = leaf;
        }
        // n leaves need n - 1 internal nodes; there are none for 0 or 1 element.
        Node<E>[] treeNodes = newNodeArray(Math.max(0, length - 1));
        for (int i = 0; i < treeNodes.length; i++) {
            treeNodes[i] = new Node<>();
        }
        for (int i = 0; i < treeNodes.length; i++) {
            generateNode(codes, treeNodes, leafNodes, i);
        }
        if (treeNodes.length > 0) {
            calculateVolumeHierarchy(treeNodes[0]);
        }
        return new LBVH<>(treeNodes, leafNodes);
    }

    /** Creates a typed node array; generic array creation requires an unchecked cast. */
    @SuppressWarnings("unchecked")
    private static <E> Node<E>[] newNodeArray(int size) {
        return (Node<E>[]) new Node<?>[size];
    }

    /**
     * Fills in internal-node boxes bottom-up. Each internal node's box is the combination of its
     * two children's boxes. Leaves already carry their element's box.
     */
    private static void calculateVolumeHierarchy(Node<?> node) {
        if (node.element != null) {
            return;
        }
        calculateVolumeHierarchy(node.left);
        calculateVolumeHierarchy(node.right);
        node.box = AABBUtil.combine(node.left.box, node.right.box);
    }

    /**
     * Links internal node {@code idx} to its two children.
     *
     * <p>A child is a leaf when its side of the split covers a single index. Otherwise it is the
     * internal node with the same index as that side's boundary.
     */
    private static <E> void generateNode(int[] codes, Node<E>[] treeNodes, Node<E>[] leafNodes, int idx) {
        Range range = determineRange(codes, idx);
        int split = findSplit(codes, range.start(), range.end());
        Node<E> left = split == range.start() ? leafNodes[split] : treeNodes[split];
        Node<E> right = split + 1 == range.end() ? leafNodes[split + 1] : treeNodes[split + 1];
        Node<E> node = treeNodes[idx];
        node.left = left;
        node.right = right;
        left.parent = node;
        right.parent = node;
    }

    /**
     * Determines the inclusive range of leaf indices covered by internal node {@code index}.
     *
     * <p>Requires {@code 0 <= index < codes.length - 1}, so {@code index + 1} is always valid.
     */
    private static Range determineRange(int[] codes, int index) {
        int lastIndex = codes.length - 1;
        if (index == 0) {
            return new Range(0, lastIndex);
        }
        int prevCode = codes[index - 1];
        int currCode = codes[index];
        int nextCode = codes[index + 1];
        if (prevCode == currCode && nextCode == currCode) {
            // Inside a run of duplicates: the node covers from here to the end of the run.
            int end = index + 1;
            while (end < lastIndex && codes[end] == codes[end + 1]) {
                end++;
            }
            return new Range(index, end);
        }
        int prefixPrev = Integer.numberOfLeadingZeros(currCode ^ prevCode);
        int prefixNext = Integer.numberOfLeadingZeros(currCode ^ nextCode);
        // Grow toward the neighbour sharing the longer prefix. The shorter prefix is the lower
        // bound that every element inside the range must exceed.
        int dir;
        int minPrefix;
        if (prefixPrev > prefixNext) {
            dir = -1;
            minPrefix = prefixNext;
        } else {
            dir = 1;
            minPrefix = prefixPrev;
        }
        // Exponential search for an upper bound on the range length.
        int maxLength = 2;
        while (sharesLongerPrefix(codes, currCode, index + maxLength * dir, minPrefix)) {
            maxLength *= 2;
        }
        // Binary search for the exact range length within that bound.
        int rangeLength = 0;
        for (int step = maxLength / 2; step >= 1; step /= 2) {
            if (sharesLongerPrefix(codes, currCode, index + (rangeLength + step) * dir, minPrefix)) {
                rangeLength += step;
            }
        }
        int other = index + rangeLength * dir;
        return dir > 0 ? new Range(index, other) : new Range(other, index);
    }

    /**
     * Checks whether {@code codes[j]} exists and shares more than {@code minPrefix} leading bits
     * with {@code code}.
     */
    private static boolean sharesLongerPrefix(int[] codes, int code, int j, int minPrefix) {
        return j >= 0 && j < codes.length && Integer.numberOfLeadingZeros(code ^ codes[j]) > minPrefix;
    }

    /**
     * Finds the split position inside {@code [first, last]}.
     *
     * <p>The result is the highest index whose code shares a longer prefix with {@code codes[first]}
     * than {@code codes[last]} does. For an all-equal range, {@code first} is returned.
     */
    private static int findSplit(int[] codes, int first, int last) {
        int firstCode = codes[first];
        int lastCode = codes[last];
        if (firstCode == lastCode) {
            return first;
        }
        int commonPrefix = Integer.numberOfLeadingZeros(firstCode ^ lastCode);
        int split = first;
        int step = last - first;
        do {
            step = (step + 1) >> 1;
            int candidate = split + step;
            if (candidate < last && Integer.numberOfLeadingZeros(firstCode ^ codes[candidate]) > commonPrefix) {
                split = candidate;
            }
        } while (step > 1);
        return split;
    }

    /** A tree node: leaves hold an element; internal nodes hold two children. */
    private static final class Node<E> {
        Node<E> left = null;
        Node<E> right = null;
        Node<E> parent = null;
        @Nullable E element = null;
        AABB box;

        private Node() {
        }
    }

    /** An inclusive range of leaf indices. */
    private record Range(int start, int end) {
    }
}

