package pal.eval;

import pal.alignment.SitePattern;
import pal.datatype.DataType;
import pal.substmodel.RateMatrix;
import pal.tree.Node;
import pal.tree.NodeUtils;
import pal.tree.Tree;
import pal.tree.TreeUtils;

/* loaded from: input_file:pal/eval/SimpleLikelihoodCalculator.class */
/**
 * A straightforward {@link LikelihoodCalculator} that evaluates the log-likelihood of a
 * {@link SitePattern} on a {@link Tree} under a single {@link RateMatrix}.
 *
 * <p>Every call to {@link #calculateLogLikelihood()} recomputes all partial likelihoods
 * from scratch by a post-order traversal of the tree.
 *
 * <p>Both a tree and a rate matrix must be installed before a likelihood is requested.
 * Use either the three-argument constructor, or {@link #setTree} and {@link #setRateMatrix}.
 */
public class SimpleLikelihoodCalculator implements LikelihoodCalculator {
    /** Alignment patterns being evaluated. */
    SitePattern sitePattern_;
    /** Tree being evaluated; {@code null} until {@link #setTree} succeeds. */
    Tree tree_;
    /** Substitution model; {@code null} until {@link #setRateMatrix} succeeds. */
    RateMatrix model_;
    /** Data type of {@link #sitePattern_}, used to recognise unknown states at the tips. */
    DataType patternDatatype_;
    /** Not read or written by this class; left in place because it is package-visible. */
    boolean modelChanged_ = false;
    /** Partial likelihood buffers indexed as [node key][pattern][state] (see {@link #getKey}). */
    private double[][][] partials_;
    private int numberOfStates_;
    private int numberOfPatterns_;
    /** Equilibrium state frequencies of {@link #model_}. */
    private double[] frequency_;
    /** Per-pattern log-likelihoods from the most recent evaluation. */
    private double[] siteLogL_;

    /**
     * Creates a calculator for the given patterns.
     *
     * <p>{@link #setTree} and {@link #setRateMatrix} must be called before computing a
     * likelihood.
     *
     * @param sitePattern the alignment patterns to evaluate
     */
    public SimpleLikelihoodCalculator(SitePattern sitePattern) {
        setPattern(sitePattern);
    }

    /**
     * Caches the pattern, its data type and pattern count, and sizes the per-site
     * log-likelihood array.
     */
    private void setPattern(SitePattern sitePattern) {
        this.sitePattern_ = sitePattern;
        this.patternDatatype_ = this.sitePattern_.getDataType();
        this.numberOfPatterns_ = this.sitePattern_.numPatterns;
        this.siteLogL_ = new double[this.numberOfPatterns_];
    }

    /**
     * Creates a fully configured calculator.
     *
     * @param sitePattern the alignment patterns to evaluate
     * @param tree the tree; its external nodes receive the matching pattern rows
     * @param rateMatrix the substitution model
     * @throws IllegalArgumentException if {@code tree} or {@code rateMatrix} is {@code null}
     */
    public SimpleLikelihoodCalculator(SitePattern sitePattern, Tree tree, RateMatrix rateMatrix) {
        setPattern(sitePattern);
        setTree(tree);
        setRateMatrix(rateMatrix);
    }

    /** No resources are held, so there is nothing to release. */
    @Override
    public void release() {
    }

    /**
     * Computes the weighted log-likelihood of all site patterns on the current tree.
     *
     * @return the sum over patterns of {@code log(site likelihood) * pattern weight}
     */
    @Override
    public double calculateLogLikelihood() {
        return treeLikelihood();
    }

    /** @return the site patterns this calculator evaluates */
    public SitePattern getSitePattern() {
        return this.sitePattern_;
    }

    /** @return the currently installed tree, or {@code null} if none has been set */
    public Tree getTree() {
        return this.tree_;
    }

    /**
     * Installs the substitution model and caches its equilibrium frequencies and state count.
     *
     * <p>Partial-likelihood storage is reallocated if its shape no longer matches.
     *
     * @param rateMatrix the substitution model
     * @throws IllegalArgumentException if {@code rateMatrix} is {@code null}
     */
    public void setRateMatrix(RateMatrix rateMatrix) {
        if (rateMatrix == null) {
            throw new IllegalArgumentException("Assertion error : SetModel called with null model!");
        }
        this.model_ = rateMatrix;
        this.frequency_ = this.model_.getEquilibriumFrequencies();
        this.numberOfStates_ = this.model_.getDataType().getNumStates();
        // One buffer per node key: 2n - 2 slots for n sequences (see getKey).
        allocatePartialMemory((2 * this.sitePattern_.getSequenceCount()) - 2);
    }

    /**
     * Installs {@code tree} and attaches a pattern row to each of its external nodes.
     *
     * <p>Row indices come from {@link TreeUtils#mapExternalIdentifiers}: external node
     * {@code i} receives {@code sitePattern.pattern[map[i]]}.
     *
     * @param tree the tree to evaluate
     * @throws IllegalArgumentException if {@code tree} is {@code null}; the previously
     *     installed tree is left unchanged in that case
     */
    public void setTree(Tree tree) {
        // Validate before assigning so that a rejected call does not clobber the current tree.
        if (tree == null) {
            throw new IllegalArgumentException("Assertion error : SetTree called with null tree!");
        }
        this.tree_ = tree;
        int[] patternRowOfLeaf = TreeUtils.mapExternalIdentifiers(this.sitePattern_, this.tree_);
        int leafCount = this.tree_.getExternalNodeCount();
        for (int leaf = 0; leaf < leafCount; leaf++) {
            byte[] row = this.sitePattern_.pattern[patternRowOfLeaf[leaf]];
            this.tree_.getExternalNode(leaf).setSequence(row);
        }
    }

    /**
     * Re-applies {@link #setRateMatrix} to the current model.
     *
     * <p>This refreshes the cached frequencies, the state count and the buffer shape.
     */
    public final void modelUpdated() {
        setRateMatrix(this.model_);
    }

    /** Re-applies {@link #setTree} to the current tree, re-attaching pattern rows to its leaves. */
    public final void treeUpdated() {
        setTree(this.tree_);
    }

    /**
     * Ensures {@link #partials_} holds {@code bufferCount} buffers of shape
     * [patterns][states], reusing the existing allocation when it already matches.
     */
    private void allocatePartialMemory(int bufferCount) {
        if (!partialsHaveShape(bufferCount)) {
            this.partials_ = new double[bufferCount][this.numberOfPatterns_][this.numberOfStates_];
        }
    }

    /**
     * Returns whether the current buffers already have the requested shape.
     *
     * <p>Inner arrays are inspected only when they exist. This avoids an index error when
     * the buffer count or the pattern count is zero.
     */
    private boolean partialsHaveShape(int bufferCount) {
        if (this.partials_ == null || this.partials_.length != bufferCount) {
            return false;
        }
        if (bufferCount == 0) {
            return true;
        }
        double[][] firstBuffer = this.partials_[0];
        if (firstBuffer.length != this.numberOfPatterns_) {
            return false;
        }
        return this.numberOfPatterns_ == 0 || firstBuffer[0].length == this.numberOfStates_;
    }

    /**
     * Maps a node to its slot in {@link #partials_}.
     *
     * <p>External nodes use their own number. Internal nodes are offset by the number of
     * external nodes.
     */
    private int getKey(Node node) {
        int number = node.getNumber();
        return node.isLeaf() ? number : number + this.tree_.getExternalNodeCount();
    }

    /**
     * Returns the partial likelihood buffer for {@code node}, indexed as [pattern][state].
     *
     * @param node the node whose buffer is wanted
     * @return the live buffer (not a copy)
     */
    protected double[][] getPartial(Node node) {
        return this.partials_[getKey(node)];
    }

    /**
     * Returns the sibling that follows {@code node} among {@code parent}'s children.
     *
     * <p>If {@code node} is the last child, {@code parent} itself is returned. If
     * {@code node} is not a child of {@code parent}, the first child is returned (or
     * {@code parent} when it has no children).
     */
    private Node getNextBranchOrRoot(Node node, Node parent) {
        int childCount = parent.getChildCount();
        int index = 0;
        while (index < childCount && parent.getChild(index) != node) {
            index++;
        }
        int next = index + 1;
        if (next > childCount) {
            // node was not among the children: wrap around to the first child.
            next = 0;
        }
        return next == childCount ? parent : parent.getChild(next);
    }

    /**
     * Like {@link #getNextBranchOrRoot}, except that a tree-root result is replaced by the
     * root's first child.
     *
     * @param node the current node
     * @param parent the node whose children are scanned for {@code node}
     * @return the next branch node
     */
    protected Node getNextBranch(Node node, Node parent) {
        Node next = getNextBranchOrRoot(node, parent);
        return next.isRoot() ? next.getChild(0) : next;
    }

    /**
     * Multiplies, element-wise and in place, the partials of all of {@code node}'s children
     * into the buffer of its first child.
     *
     * <p>Child 0's buffer is overwritten with the product.
     *
     * @param node a node with at least one child
     */
    protected void productPartials(Node node) {
        double[][] product = getPartial(node.getChild(0));
        for (int child = 1; child < node.getChildCount(); child++) {
            double[][] factor = getPartial(node.getChild(child));
            for (int pattern = 0; pattern < this.numberOfPatterns_; pattern++) {
                double[] productRow = product[pattern];
                double[] factorRow = factor[pattern];
                for (int state = 0; state < this.numberOfStates_; state++) {
                    productRow[state] *= factorRow[state];
                }
            }
        }
    }

    /**
     * Propagates child 0's partials along {@code node}'s branch into {@code node}'s own buffer.
     *
     * <p>Child 0's buffer is expected to already hold the product from
     * {@link #productPartials}. The model's distance is set to the node's branch length;
     * then for every pattern and state {@code i}, the result is
     * {@code sum_j P(i, j) * childPartial[j]}.
     *
     * @param node an internal, non-root node
     */
    protected void partialsInternal(Node node) {
        double[][] target = getPartial(node);
        double[][] source = getPartial(node.getChild(0));
        this.model_.setDistance(node.getBranchLength());
        for (int pattern = 0; pattern < this.numberOfPatterns_; pattern++) {
            double[] targetRow = target[pattern];
            double[] sourceRow = source[pattern];
            for (int from = 0; from < this.numberOfStates_; from++) {
                double sum = 0.0d;
                for (int to = 0; to < this.numberOfStates_; to++) {
                    sum += this.model_.getTransitionProbability(from, to) * sourceRow[to];
                }
                targetRow[from] = sum;
            }
        }
    }

    /**
     * Fills the buffer of external node {@code node} from its attached sequence.
     *
     * <p>An unknown state (per the pattern's data type) contributes 1 for every state.
     * An observed state {@code b} contributes {@code P(i, b)} for the node's branch length.
     *
     * @param node an external node whose sequence was attached by {@link #setTree}
     */
    protected void partialsExternal(Node node) {
        double[][] target = getPartial(node);
        byte[] sequence = node.getSequence();
        this.model_.setDistance(node.getBranchLength());
        for (int pattern = 0; pattern < this.numberOfPatterns_; pattern++) {
            double[] targetRow = target[pattern];
            byte observed = sequence[pattern];
            if (this.patternDatatype_.isUnknownState(observed)) {
                for (int state = 0; state < this.numberOfStates_; state++) {
                    targetRow[state] = 1.0d;
                }
            } else {
                for (int state = 0; state < this.numberOfStates_; state++) {
                    targetRow[state] = this.model_.getTransitionProbability(state, observed);
                }
            }
        }
    }

    /**
     * Post-order traversal that computes partials for every node below the root.
     *
     * <p>The root's own combination step is left to {@link #treeLikelihood()}.
     */
    private void traverseTree(Node node) {
        if (node.isLeaf()) {
            partialsExternal(node);
            return;
        }
        for (int child = 0; child < node.getChildCount(); child++) {
            traverseTree(node.getChild(child));
        }
        if (!node.isRoot()) {
            productPartials(node);
            partialsInternal(node);
        }
    }

    /**
     * Evaluates the tree and returns the weighted sum of per-pattern log-likelihoods.
     *
     * <p>This also records each pattern's log-likelihood in {@link #siteLogL_}.
     */
    private double treeLikelihood() {
        Node root = this.tree_.getRoot();
        traverseTree(root);
        // Combine the root's subtrees; the product lands in the first child's buffer.
        productPartials(root);
        double[][] rootPartials = getPartial(root.getChild(0));
        double logLikelihood = 0.0d;
        for (int pattern = 0; pattern < this.numberOfPatterns_; pattern++) {
            double[] stateLikelihoods = rootPartials[pattern];
            double siteLikelihood = 0.0d;
            for (int state = 0; state < this.numberOfStates_; state++) {
                siteLikelihood += this.frequency_[state] * stateLikelihoods[state];
            }
            this.siteLogL_[pattern] = Math.log(siteLikelihood);
            logLikelihood += this.siteLogL_[pattern] * this.sitePattern_.weight[pattern];
        }
        return logLikelihood;
    }
}
