package pal.distance;

import java.io.Serializable;
import java.util.Arrays;
import pal.alignment.SitePattern;
import pal.datatype.AmbiguousDataType;
import pal.datatype.DataType;
import pal.math.UnivariateFunction;
import pal.substmodel.SubstitutionModel;

/* loaded from: input_file:pal/distance/SequencePairLikelihood.class */
/**
 * Negative log-likelihood of a pair of aligned sequences as a function of the scalar
 * passed to {@link #evaluate(double)}. That scalar is handed straight to
 * {@link SubstitutionModel#getTransitionProbabilities}, so it is presumably the
 * evolutionary distance / branch length between the two sequences (confirm against the
 * model's contract).
 *
 * <p>The two sequences are pattern rows of a {@link SitePattern}. Each pattern column is
 * weighted by {@code sitePattern.weight}. Columns where either sequence holds an unknown
 * state are skipped. Weighted state-pair counts are first tallied into a square matrix,
 * so each distinct (from, to) pair costs a single {@code Math.log} call.
 *
 * <p>Not safe for concurrent use: {@link #evaluate(double)} writes into per-instance
 * scratch arrays ({@code fastMatchCount_} and {@code transitionProbabilityStores_})
 * without synchronization.
 */
public class SequencePairLikelihood implements UnivariateFunction, Serializable {
    private SubstitutionModel model;
    private SitePattern sitePattern;
    private DataType patternDataType_;
    private int numPatterns;
    // Assigned from sitePattern.getSiteCount() but not read within this class; kept so
    // the serialized form is unchanged.
    private int numSites;
    // State count of the *pattern* data type; drives the non-ambiguous evaluate() loops.
    private int numStates;
    private int numberOfTransitionCategories_;
    // Per-pattern column weights, shared with (not copied from) the SitePattern.
    private int[] weight;
    private double[] logEquilibriumFrequencies_;
    // Scratch buffer [category][from][to] filled by the model on every evaluation.
    private double[][][] transitionProbabilityStores_;
    private double[] transitionCategoryProbabilities_;
    // Scratch buffer [from][to] of weighted site-pair counts; sized in updateModel().
    private int[][] fastMatchCount_;
    private byte[] seqPat1;
    private byte[] seqPat2;
    boolean isAmbiguous_;

    /**
     * Creates a likelihood function over the given site patterns and substitution model.
     * Call {@link #setSequences} before evaluating.
     *
     * <p>NOTE(review): this calls the overridable {@link #updateSitePattern} and
     * {@link #updateModel} from the constructor, so subclass overrides would run before
     * subclass fields are initialized.
     *
     * @param sitePattern       the patterns supplying the sequences and column weights
     * @param substitutionModel the model supplying transition probabilities and
     *                          equilibrium frequencies
     */
    public SequencePairLikelihood(SitePattern sitePattern, SubstitutionModel substitutionModel) {
        updateSitePattern(sitePattern);
        updateModel(substitutionModel);
    }

    /**
     * Installs a new substitution model, caching its transition-category mixture and
     * log equilibrium frequencies, and (re)allocating the scratch buffers.
     *
     * <p>NOTE(review): the buffers are sized from the <em>model's</em> data type state
     * count, while {@link #evaluate(double)} iterates over the <em>pattern's</em> state
     * count. The two are assumed to agree (or, for ambiguous patterns, the model's count is
     * assumed to equal the specific data type's count); otherwise evaluation indexes out of
     * bounds.
     *
     * @param substitutionModel the model to use from now on
     */
    public void updateModel(SubstitutionModel substitutionModel) {
        this.model = substitutionModel;
        final double[] equilibriumFrequencies = substitutionModel.getEquilibriumFrequencies();
        this.numberOfTransitionCategories_ = this.model.getNumberOfTransitionCategories();
        this.transitionCategoryProbabilities_ = this.model.getTransitionCategoryProbabilities();
        // Deliberately not named numStates: this is the model's count, not the pattern's.
        final int modelStateCount = this.model.getDataType().getNumStates();
        this.transitionProbabilityStores_ =
                new double[this.numberOfTransitionCategories_][modelStateCount][modelStateCount];
        this.fastMatchCount_ = new int[modelStateCount][modelStateCount];
        this.logEquilibriumFrequencies_ = new double[equilibriumFrequencies.length];
        for (int i = 0; i < equilibriumFrequencies.length; i++) {
            this.logEquilibriumFrequencies_[i] = Math.log(equilibriumFrequencies[i]);
        }
    }

    /**
     * Installs new site patterns and caches their size, weights and data type. The scratch
     * buffers are not resized here; they are sized by {@link #updateModel}.
     *
     * @param sitePattern the patterns to use from now on
     */
    public void updateSitePattern(SitePattern sitePattern) {
        this.sitePattern = sitePattern;
        this.numPatterns = sitePattern.numPatterns;
        this.numSites = sitePattern.getSiteCount();
        this.patternDataType_ = sitePattern.getDataType();
        this.isAmbiguous_ = this.patternDataType_.isAmbiguous();
        this.numStates = this.patternDataType_.getNumStates();
        this.weight = sitePattern.weight;
    }

    /**
     * Selects the two sequences to compare by their row indices in
     * {@code sitePattern.pattern}.
     *
     * @param i  row index of the first sequence
     * @param i2 row index of the second sequence
     */
    public void setSequences(int i, int i2) {
        setSequences(this.sitePattern.pattern[i], this.sitePattern.pattern[i2]);
    }

    /**
     * Selects the two sequences to compare as raw pattern rows. The arrays are stored by
     * reference, not copied, and each must hold at least {@code numPatterns} entries.
     *
     * @param bArr  first sequence's states, one per pattern
     * @param bArr2 second sequence's states, one per pattern
     */
    public void setSequences(byte[] bArr, byte[] bArr2) {
        this.seqPat1 = bArr;
        this.seqPat2 = bArr2;
    }

    /** Zeroes the top-left {@code i x i} corner of {@code fastMatchCount_}. */
    private final void clearFastMatchCount(int i) {
        for (int row = 0; row < i; row++) {
            Arrays.fill(this.fastMatchCount_[row], 0, i, 0);
        }
    }

    /**
     * Converts the weighted state-pair counts in the top-left
     * {@code stateCount x stateCount} corner of {@code fastMatchCount_} into a
     * log-likelihood. For each observed pair (from, to) it adds
     * {@code count * (log pi[from] + log(sum_c P_c[from][to] * w_c))}, where the
     * {@code P_c} are the transition matrices currently held in
     * {@code transitionProbabilityStores_}.
     */
    private double logLikelihoodFromCounts(int stateCount) {
        double logLikelihood = 0.0d;
        for (int from = 0; from < stateCount; from++) {
            final int[] countRow = this.fastMatchCount_[from];
            for (int to = 0; to < stateCount; to++) {
                final int count = countRow[to];
                if (count > 0) {
                    // Mix the per-category transition probabilities by category weight.
                    double pairProbability = 0.0d;
                    for (int c = 0; c < this.numberOfTransitionCategories_; c++) {
                        pairProbability += this.transitionProbabilityStores_[c][from][to]
                                * this.transitionCategoryProbabilities_[c];
                    }
                    logLikelihood += count
                            * (this.logEquilibriumFrequencies_[from] + Math.log(pairProbability));
                }
            }
        }
        return logLikelihood;
    }

    /**
     * Ambiguous-data variant of {@link #evaluate(double)}: each pattern state is expanded
     * into its set of specific states, and every compatible (specific, specific) pair
     * receives the column's full weight.
     *
     * <p>NOTE(review): this treats a site whose states are sets A and B as |A|*|B|
     * independent fully-weighted sites. The result is a sum of
     * {@code log(pi_a P_ab)} terms, not the usual per-site
     * {@code log(sum_{a,b} pi_a P_ab)}. Confirm this approximation is intended.
     */
    private final double evaluateAmbiguous(double d) {
        final AmbiguousDataType ambiguousVersion = this.patternDataType_.getAmbiguousVersion();
        final int specificStateCount = ambiguousVersion.getSpecificDataType().getNumStates();
        this.model.getTransitionProbabilities(d, this.transitionProbabilityStores_);
        clearFastMatchCount(specificStateCount);
        for (int i = 0; i < this.numPatterns; i++) {
            final byte b = this.seqPat1[i];
            final byte b2 = this.seqPat2[i];
            final boolean firstUnknown = this.patternDataType_.isUnknownState(b);
            final boolean secondUnknown = this.patternDataType_.isUnknownState(b2);
            if (firstUnknown || secondUnknown) {
                continue;
            }
            final int[] fromStates = ambiguousVersion.getSpecificStates(b);
            final int[] toStates = ambiguousVersion.getSpecificStates(b2);
            final int columnWeight = this.weight[i];
            for (int from : fromStates) {
                final int[] countRow = this.fastMatchCount_[from];
                for (int to : toStates) {
                    countRow[to] += columnWeight;
                }
            }
        }
        return -logLikelihoodFromCounts(specificStateCount);
    }

    /**
     * Returns the negative log-likelihood of the currently selected sequence pair, given
     * {@code d} (passed to the model's {@code getTransitionProbabilities}). Columns where
     * either sequence has an unknown state contribute nothing. Delegates to the ambiguous
     * variant when the pattern data type is ambiguous.
     *
     * @param d the argument passed to the model's transition-probability computation
     * @return minus the weighted log-likelihood
     */
    @Override // pal.math.UnivariateFunction
    public final double evaluate(double d) {
        if (this.isAmbiguous_) {
            return evaluateAmbiguous(d);
        }
        this.model.getTransitionProbabilities(d, this.transitionProbabilityStores_);
        clearFastMatchCount(this.numStates);
        for (int i = 0; i < this.numPatterns; i++) {
            final byte b = this.seqPat1[i];
            final byte b2 = this.seqPat2[i];
            final boolean firstUnknown = this.patternDataType_.isUnknownState(b);
            final boolean secondUnknown = this.patternDataType_.isUnknownState(b2);
            if (!firstUnknown && !secondUnknown) {
                // Known states are used directly as matrix indices.
                this.fastMatchCount_[b][b2] += this.weight[i];
            }
        }
        return -logLikelihoodFromCounts(this.numStates);
    }

    /** Smallest argument the optimizer should pass to {@link #evaluate(double)}: 1e-9. */
    @Override // pal.math.UnivariateFunction
    public double getLowerBound() {
        return 1.0E-9d;
    }

    /** Largest argument the optimizer should pass to {@link #evaluate(double)}: 100. */
    @Override // pal.math.UnivariateFunction
    public double getUpperBound() {
        return 100.0d;
    }
}
