package org.broadinstitute.gatk.tools.walkers.variantrecalibration;

import freemarker.template.Template;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.VariantContextBuilder;
import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import htsjdk.variant.vcf.VCFConstants;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang.ArrayUtils;
import org.apache.log4j.Logger;
import org.apache.log4j.helpers.DateLayout;
import org.broadinstitute.gatk.tools.walkers.genotyper.StandardCallerArgumentCollection;
import org.broadinstitute.gatk.tools.walkers.variantrecalibration.VariantRecalibratorArgumentCollection;
import org.broadinstitute.gatk.utils.GenomeLoc;
import org.broadinstitute.gatk.utils.MathUtils;
import org.broadinstitute.gatk.utils.Utils;
import org.broadinstitute.gatk.utils.collections.ExpandingArrayList;
import org.broadinstitute.gatk.utils.exceptions.UserException;
import org.broadinstitute.gatk.utils.refdata.RefMetaDataTracker;
import org.broadinstitute.gatk.utils.variant.GATKVCFConstants;

/* loaded from: input_file:org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.class */
/**
 * Holds the per-variant annotation data and the training-set definitions used during variant
 * recalibration, and provides the helpers that normalize, partition and export that data.
 *
 * <p>Annotation values live in {@code VariantDatum.annotations} and are indexed in the same order as
 * {@link #annotationKeys}. {@link #normalizeData()} standardizes the values in place and reorders the
 * keys, the mean/standard-deviation vectors and every datum's arrays consistently.
 */
public class VariantDataManager {
    private List<VariantDatum> data;
    /** Per-annotation training mean recorded by {@link #normalizeData()}. */
    private double[] meanVector;
    /** Per-annotation training standard deviation recorded by {@link #normalizeData()} (despite the name). */
    private double[] varianceVector;
    public List<String> annotationKeys;
    private final VariantRecalibratorArgumentCollection VRAC;
    protected static final Logger logger = Logger.getLogger(VariantDataManager.class);
    protected final List<TrainingSet> trainingSets = new ArrayList<TrainingSet>();
    /** Margin added on both sides of the [0, MQ_CAP] range before the logit transform. */
    private static final double SAFETY_OFFSET = 0.01d;
    /** Tolerance used when testing whether an annotation sits on a value that needs jitter. */
    private static final double PRECISION = 0.01d;
    /** Natural logarithm of two, the strand-odds-ratio value that receives jitter. */
    private static final double LOG_OF_TWO = 0.6931472d;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager$MyDoubleForSorting.class */
    /**
     * Pairs a value with the position it originally occupied so that, after sorting by value, the
     * resulting permutation can be read back from {@link #originalIndex}.
     */
    public class MyDoubleForSorting implements Comparable<MyDoubleForSorting> {
        final Double myData;
        final int originalIndex;

        /**
         * @param d the value to sort by
         * @param i the index the value had before sorting
         */
        public MyDoubleForSorting(double d, int i) {
            this.myData = Double.valueOf(d);
            this.originalIndex = i;
        }

        /** Orders instances by {@link #myData} using {@link Double#compareTo(Double)}. */
        @Override // java.lang.Comparable
        public int compareTo(MyDoubleForSorting myDoubleForSorting) {
            return this.myData.compareTo(myDoubleForSorting.myData);
        }
    }

    /**
     * Creates a manager with no data and zero-filled mean / standard-deviation vectors.
     *
     * @param list the annotation keys; copied into a new list
     * @param variantRecalibratorArgumentCollection the argument collection consulted by this manager
     */
    public VariantDataManager(List<String> list, VariantRecalibratorArgumentCollection variantRecalibratorArgumentCollection) {
        this.data = Collections.emptyList();
        this.annotationKeys = new ArrayList<String>(list);
        this.VRAC = variantRecalibratorArgumentCollection;
        this.meanVector = new double[this.annotationKeys.size()];
        this.varianceVector = new double[this.annotationKeys.size()];
    }

    /** Replaces the managed data with {@code list}; the list is stored by reference, not copied. */
    public void setData(List<VariantDatum> list) {
        this.data = list;
    }

    /** Returns the managed data list itself (not a copy). */
    public List<VariantDatum> getData() {
        return this.data;
    }

    /**
     * Standardizes every annotation in place and reorders the annotations.
     *
     * <p>For each annotation the mean and standard deviation over the training-site data are computed
     * and recorded; every datum's value becomes {@code (x - mean) / std}, or a small Gaussian random
     * value when the datum is flagged as null for that annotation. Each datum is then flagged as
     * failing the threshold when any of its standardized values exceeds {@code VRAC.STD_THRESHOLD} in
     * absolute value. Finally the keys, both vectors and every datum's arrays are permuted according
     * to {@link #calculateSortOrder(double[])}.
     *
     * @throws UserException.BadInput if an annotation has a NaN training mean, or if any annotation has
     *         a standard deviation below 1e-5
     */
    public void normalizeData() {
        boolean foundZeroVarianceAnnotation = false;
        for (int annIdx = 0; annIdx < this.meanVector.length; annIdx++) {
            final double theMean = mean(annIdx, true);
            final double theStd = standardDeviation(theMean, annIdx, true);
            logger.info(this.annotationKeys.get(annIdx) + String.format(": \t mean = %.2f\t standard deviation = %.2f", Double.valueOf(theMean), Double.valueOf(theStd)));
            if (Double.isNaN(theMean)) {
                throw new UserException.BadInput("Values for " + this.annotationKeys.get(annIdx) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations.");
            }
            // Remember the problem but keep going so that every annotation is logged before failing.
            foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || theStd < 1.0E-5d;
            this.meanVector[annIdx] = theMean;
            this.varianceVector[annIdx] = theStd;
            for (VariantDatum datum : this.data) {
                if (datum.isNull[annIdx]) {
                    // Null values are replaced by a draw from a narrow Gaussian.
                    datum.annotations[annIdx] = 0.1d * Utils.getRandomGenerator().nextGaussian();
                } else {
                    datum.annotations[annIdx] = (datum.annotations[annIdx] - theMean) / theStd;
                }
            }
        }
        if (foundZeroVarianceAnnotation) {
            throw new UserException.BadInput("Found annotations with zero variance. They must be excluded before proceeding.");
        }

        for (VariantDatum datum : this.data) {
            boolean remove = false;
            for (double val : datum.annotations) {
                if (Math.abs(val) > this.VRAC.STD_THRESHOLD) {
                    remove = true;
                    break;
                }
            }
            datum.failingSTDThreshold = remove;
        }

        // Apply one permutation to everything that is indexed by annotation.
        final List<Integer> theOrder = calculateSortOrder(this.meanVector);
        this.annotationKeys = reorderList(this.annotationKeys, theOrder);
        this.varianceVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(this.varianceVector), theOrder));
        this.meanVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(this.meanVector), theOrder));
        for (VariantDatum datum : this.data) {
            datum.annotations = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.annotations), theOrder));
            datum.isNull = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.isNull), theOrder));
        }
        logger.info("Annotations are now ordered by their information content: " + this.annotationKeys.toString());
    }

    /** Returns the internal mean vector (not a copy). */
    public double[] getMeanVector() {
        return this.meanVector;
    }

    /** Returns the internal vector filled with standard deviations by {@link #normalizeData()} (not a copy). */
    public double[] getVarianceVector() {
        return this.varianceVector;
    }

    /**
     * Computes a permutation of the indices of {@code dArr}.
     *
     * <p>Index {@code i} is keyed by {@code -|dArr[i] - mean(i, false)|}, where {@code mean(i, false)}
     * is the mean of annotation {@code i} over the data not at a training site, and the keys are
     * sorted ascending. Indices therefore come out in order of decreasing absolute difference.
     *
     * @param dArr one value per annotation
     * @return the original indices in sorted order
     */
    protected List<Integer> calculateSortOrder(double[] dArr) {
        final List<MyDoubleForSorting> toBeSorted = new ArrayList<MyDoubleForSorting>(dArr.length);
        for (int idx = 0; idx < dArr.length; idx++) {
            final double key = (-1.0d) * Math.abs(dArr[idx] - mean(idx, false));
            toBeSorted.add(new MyDoubleForSorting(key, idx));
        }
        Collections.sort(toBeSorted);

        final List<Integer> theOrder = new ArrayList<Integer>(dArr.length);
        for (MyDoubleForSorting entry : toBeSorted) {
            theOrder.add(Integer.valueOf(entry.originalIndex));
        }
        return theOrder;
    }

    /**
     * Array counterpart of {@link #reorderList(List, List)}. The result is written through
     * {@link List#toArray(Object[])} using {@code tArr} as the target, so when the sizes match the
     * passed-in array is overwritten and returned.
     */
    private <T> T[] reorderArray(T[] tArr, List<Integer> list) {
        return reorderList(Arrays.asList(tArr), list).toArray(tArr);
    }

    /**
     * Builds a new list whose k-th element is {@code list.get(list2.get(k))}.
     *
     * @param list the elements to permute
     * @param list2 the source index for each output position
     */
    private <T> List<T> reorderList(List<T> list, List<Integer> list2) {
        final List<T> reordered = new ArrayList<T>(list.size());
        for (Integer sourceIndex : list2) {
            reordered.add(list.get(sourceIndex.intValue()));
        }
        return reordered;
    }

    /**
     * Reverses the standardization for one value.
     *
     * @param d a standardized value
     * @param i the annotation index (in the current, possibly reordered, annotation order)
     * @return {@code d * std[i] + mean[i]}
     */
    public double denormalizeDatum(double d, int i) {
        return (d * this.varianceVector[i]) + this.meanVector[i];
    }

    /** Registers a training set to be consulted by {@link #parseTrainingSets}. */
    public void addTrainingSet(TrainingSet trainingSet) {
        this.trainingSets.add(trainingSet);
    }

    /** Returns the current annotation key list (the same object as the public field). */
    public List<String> getAnnotationKeys() {
        return this.annotationKeys;
    }

    /** Returns true when at least one registered training set has {@code isTraining} set. */
    public boolean checkHasTrainingSet() {
        for (TrainingSet trainingSet : this.trainingSets) {
            if (trainingSet.isTraining) {
                return true;
            }
        }
        return false;
    }

    /** Returns true when at least one registered training set has {@code isTruth} set. */
    public boolean checkHasTruthSet() {
        for (TrainingSet trainingSet : this.trainingSets) {
            if (trainingSet.isTruth) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collects the data at training sites that did not fail the standard-deviation threshold.
     *
     * <p>A warning is logged when fewer than {@code VRAC.MIN_NUM_BAD_VARIANTS} are found. When more than
     * {@code VRAC.MAX_NUM_TRAINING_DATA} are found the collection is shuffled and only the first
     * {@code MAX_NUM_TRAINING_DATA} entries are returned (as a sub-list view).
     */
    public List<VariantDatum> getTrainingData() {
        final List<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
        for (VariantDatum datum : this.data) {
            if (datum.atTrainingSite && !datum.failingSTDThreshold) {
                trainingData.add(datum);
            }
        }
        logger.info("Training with " + trainingData.size() + " variants after standard deviation thresholding.");
        if (trainingData.size() < this.VRAC.MIN_NUM_BAD_VARIANTS) {
            logger.warn("WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable.");
        } else if (trainingData.size() > this.VRAC.MAX_NUM_TRAINING_DATA) {
            logger.warn("WARNING: Very large training set detected. Downsampling to " + this.VRAC.MAX_NUM_TRAINING_DATA + " training variants.");
            Collections.shuffle(trainingData, Utils.getRandomGenerator());
            return trainingData.subList(0, this.VRAC.MAX_NUM_TRAINING_DATA);
        }
        return trainingData;
    }

    /**
     * Collects the non-null data that passed the standard-deviation threshold and whose finite LOD is
     * strictly below {@code VRAC.BAD_LOD_CUTOFF}, marking each selected datum as an anti-training site.
     */
    public List<VariantDatum> selectWorstVariants() {
        final List<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
        for (VariantDatum datum : this.data) {
            if (datum == null || datum.failingSTDThreshold || Double.isInfinite(datum.lod)) {
                continue;
            }
            if (datum.lod < this.VRAC.BAD_LOD_CUTOFF) {
                datum.atAntiTrainingSite = true;
                trainingData.add(datum);
            }
        }
        // The selection above is strict, so the message reports "<" rather than "<=".
        logger.info("Training with worst " + trainingData.size() + " scoring variants --> variants with LOD < " + String.format("%.4f", Double.valueOf(this.VRAC.BAD_LOD_CUTOFF)) + ".");
        return trainingData;
    }

    /**
     * Collects the non-null data that passed the standard-deviation threshold and are neither at a
     * training site nor at an anti-training site.
     */
    public List<VariantDatum> getEvaluationData() {
        final List<VariantDatum> evaluationData = new ExpandingArrayList<VariantDatum>();
        for (VariantDatum datum : this.data) {
            if (datum != null && !datum.failingSTDThreshold && !datum.atTrainingSite && !datum.atAntiTrainingSite) {
                evaluationData.add(datum);
            }
        }
        return evaluationData;
    }

    /** Removes, in place, every datum whose {@code isAggregate} flag is set. */
    public void dropAggregateData() {
        final Iterator<VariantDatum> iter = this.data.iterator();
        while (iter.hasNext()) {
            if (iter.next().isAggregate) {
                iter.remove();
            }
        }
    }

    /**
     * Builds a shuffled sample containing up to {@code i} entries from each of the three lists.
     *
     * <p>Note that the three argument lists are themselves shuffled in place.
     *
     * @param i maximum number of entries taken from each list
     * @param list first source list
     * @param list2 second source list
     * @param list3 third source list
     * @return a new, shuffled list of the sampled entries
     */
    public List<VariantDatum> getRandomDataForPlotting(int i, List<VariantDatum> list, List<VariantDatum> list2, List<VariantDatum> list3) {
        final List<VariantDatum> returnData = new ExpandingArrayList<VariantDatum>();
        Collections.shuffle(list, Utils.getRandomGenerator());
        Collections.shuffle(list2, Utils.getRandomGenerator());
        Collections.shuffle(list3, Utils.getRandomGenerator());
        returnData.addAll(list.subList(0, Math.min(i, list.size())));
        returnData.addAll(list2.subList(0, Math.min(i, list2.size())));
        returnData.addAll(list3.subList(0, Math.min(i, list3.size())));
        Collections.shuffle(returnData, Utils.getRandomGenerator());
        return returnData;
    }

    /**
     * Averages annotation {@code i} over the data whose {@code atTrainingSite} flag equals {@code z}
     * and whose value for that annotation is not null.
     *
     * @return the mean, or NaN when no datum qualifies (0.0 / 0)
     */
    protected double mean(int i, boolean z) {
        double sum = 0.0d;
        int numNonNull = 0;
        for (VariantDatum datum : this.data) {
            if (z == datum.atTrainingSite && !datum.isNull[i]) {
                sum += datum.annotations[i];
                numNonNull++;
            }
        }
        return sum / numNonNull;
    }

    /**
     * Computes the standard deviation of annotation {@code i} around {@code d}, over the same subset of
     * data used by {@link #mean(int, boolean)}. The sum of squared deviations is divided by the number
     * of qualifying data (not by that number minus one).
     *
     * @param d the mean to measure deviations from
     * @return the standard deviation, or NaN when no datum qualifies
     */
    protected double standardDeviation(double d, int i, boolean z) {
        double sumOfSquares = 0.0d;
        int numNonNull = 0;
        for (VariantDatum datum : this.data) {
            if (z == datum.atTrainingSite && !datum.isNull[i]) {
                final double deviation = datum.annotations[i] - d;
                sumOfSquares += deviation * deviation;
                numNonNull++;
            }
        }
        return Math.sqrt(sumOfSquares / numNonNull);
    }

    /**
     * Fills {@code variantDatum.annotations} and {@code variantDatum.isNull} from {@code variantContext},
     * one entry per annotation key in the current key order. A value that decodes to NaN is flagged
     * as null.
     *
     * @param z whether random jitter may be added to selected annotations (see the private decoder)
     */
    public void decodeAnnotations(VariantDatum variantDatum, VariantContext variantContext, boolean z) {
        final int numAnnotations = this.annotationKeys.size();
        final double[] annotations = new double[numAnnotations];
        final boolean[] isNull = new boolean[numAnnotations];
        int idx = 0;
        for (String key : this.annotationKeys) {
            annotations[idx] = decodeAnnotation(key, variantContext, z, this.VRAC, variantDatum);
            isNull[idx] = Double.isNaN(annotations[idx]);
            idx++;
        }
        variantDatum.annotations = annotations;
        variantDatum.isNull = isNull;
    }

    /** Returns {@code ln((d - d2) / (d3 - d))}, i.e. the logit of {@code d} rescaled from [d2, d3]. */
    private static double logitTransform(double d, double d2, double d3) {
        return Math.log((d - d2) / (d3 - d));
    }

    /**
     * Adds {@code scale} times a Gaussian draw to {@code value} when {@code value} is within
     * {@link #PRECISION} of {@code target}; otherwise returns {@code value} unchanged.
     */
    private static double jitterIfNear(double value, double target, double scale) {
        if (MathUtils.compareDoubles(value, target, PRECISION) == 0) {
            return value + scale * Utils.getRandomGenerator().nextGaussian();
        }
        return value;
    }

    /**
     * Reads one annotation value from {@code variantContext}.
     *
     * <p>In allele-specific mode, keys starting with the allele-specific prefix are read as a list and
     * the entry for the datum's alternate allele is parsed; otherwise the attribute is read as a
     * double. Infinite values become NaN. When {@code z} is set, values sitting on specific points are
     * perturbed with Gaussian noise.
     *
     * <p>Any exception raised while decoding, including the {@link IllegalStateException} thrown below
     * for a missing allele, is absorbed and reported as NaN, so callers see the annotation as missing.
     *
     * @return the decoded value, or NaN when it is absent, infinite or could not be decoded
     */
    private static double decodeAnnotation(String str, VariantContext variantContext, boolean z, VariantRecalibratorArgumentCollection variantRecalibratorArgumentCollection, VariantDatum variantDatum) {
        final boolean jitter = z;
        double value;
        try {
            if (variantRecalibratorArgumentCollection.useASannotations && str.startsWith(GATKVCFConstants.ALLELE_SPECIFIC_PREFIX)) {
                final List<Object> valueList = variantContext.getAttributeAsList(str);
                if (!variantContext.hasAllele(variantDatum.alternateAllele)) {
                    throw new IllegalStateException("VariantDatum allele " + variantDatum.alternateAllele + " is not contained in the input VariantContext.");
                }
                // The list has one entry per alternate allele, hence the shift past the reference.
                final int altIndex = variantContext.getAlleleIndex(variantDatum.alternateAllele) - 1;
                value = Double.parseDouble((String) valueList.get(altIndex));
            } else {
                value = variantContext.getAttributeAsDouble(str, Double.NaN);
            }
            if (Double.isInfinite(value)) {
                value = Double.NaN;
            }
            if (jitter && str.equalsIgnoreCase(GATKVCFConstants.HAPLOTYPE_SCORE_KEY)) {
                value = jitterIfNear(value, 0.0d, 0.01d);
            }
            // NOTE(review): the allele-specific key compared here is the filter-status key, whereas the
            // other pairs below use the allele-specific twin of the same annotation - confirm intent.
            if (jitter && (str.equalsIgnoreCase(GATKVCFConstants.FISHER_STRAND_KEY) || str.equalsIgnoreCase(GATKVCFConstants.AS_FILTER_STATUS_KEY))) {
                value = jitterIfNear(value, 0.0d, 0.01d);
            }
            if (jitter && str.equalsIgnoreCase(GATKVCFConstants.INBREEDING_COEFFICIENT_KEY)) {
                value = jitterIfNear(value, 0.0d, 0.01d);
            }
            if (jitter && (str.equalsIgnoreCase(GATKVCFConstants.STRAND_ODDS_RATIO_KEY) || str.equalsIgnoreCase(GATKVCFConstants.AS_STRAND_ODDS_RATIO_KEY))) {
                value = jitterIfNear(value, LOG_OF_TWO, 0.01d);
            }
            if (jitter && (str.equalsIgnoreCase(VCFConstants.RMS_MAPPING_QUALITY_KEY) || str.equalsIgnoreCase(GATKVCFConstants.AS_RMS_MAPPING_QUALITY_KEY))) {
                if (variantRecalibratorArgumentCollection.MQ_CAP > 0) {
                    // Logit-transform within the padded [0, MQ_CAP] range, then jitter values at the cap.
                    final double upperBound = variantRecalibratorArgumentCollection.MQ_CAP + SAFETY_OFFSET;
                    value = logitTransform(value, -SAFETY_OFFSET, upperBound);
                    final double transformedCap = logitTransform(variantRecalibratorArgumentCollection.MQ_CAP, -SAFETY_OFFSET, upperBound);
                    value = jitterIfNear(value, transformedCap, variantRecalibratorArgumentCollection.MQ_JITTER);
                } else {
                    value = jitterIfNear(value, variantRecalibratorArgumentCollection.MQ_CAP, variantRecalibratorArgumentCollection.MQ_JITTER);
                }
            }
        } catch (Exception ignored) {
            // Deliberate best effort: anything that cannot be decoded is treated as a missing value.
            value = Double.NaN;
        }
        return value;
    }

    /**
     * Resets the training-related flags and prior of {@code variantDatum}, then updates them from every
     * record the registered training sets provide at {@code genomeLoc}.
     *
     * <p>In allele-specific mode, records that do not carry the datum's alternate allele are skipped.
     * A record that passes {@code isValidVariant} contributes its set's known/truth/training flags,
     * raises the prior to the set's prior when larger, and increments the consensus count for
     * consensus sets. Any non-null record contributes the set's anti-training flag.
     *
     * @param variantContext the variant being evaluated
     * @param z when true, records are accepted without requiring them to be polymorphic in their samples
     */
    public void parseTrainingSets(RefMetaDataTracker refMetaDataTracker, GenomeLoc genomeLoc, VariantContext variantContext, VariantDatum variantDatum, boolean z) {
        variantDatum.isKnown = false;
        variantDatum.atTruthSite = false;
        variantDatum.atTrainingSite = false;
        variantDatum.atAntiTrainingSite = false;
        variantDatum.prior = 2.0d;
        for (TrainingSet trainingSet : this.trainingSets) {
            for (VariantContext trainVC : refMetaDataTracker.getValues(trainingSet.rodBinding, genomeLoc)) {
                if (this.VRAC.useASannotations && !doAllelesMatch(trainVC, variantDatum)) {
                    continue;
                }
                if (isValidVariant(variantContext, trainVC, z)) {
                    variantDatum.isKnown = variantDatum.isKnown || trainingSet.isKnown;
                    variantDatum.atTruthSite = variantDatum.atTruthSite || trainingSet.isTruth;
                    variantDatum.atTrainingSite = variantDatum.atTrainingSite || trainingSet.isTraining;
                    variantDatum.prior = Math.max(variantDatum.prior, trainingSet.prior);
                    variantDatum.consensusCount += trainingSet.isConsensus ? 1 : 0;
                }
                if (trainVC != null) {
                    variantDatum.atAntiTrainingSite = variantDatum.atAntiTrainingSite || trainingSet.isAntiTraining;
                }
            }
        }
    }

    /**
     * Returns true when {@code variantContext2} is a non-null, unfiltered variant of the same variation
     * class as {@code variantContext} and, unless {@code z} is set, either has no genotypes or is
     * polymorphic in its samples.
     */
    private boolean isValidVariant(VariantContext variantContext, VariantContext variantContext2, boolean z) {
        if (variantContext2 == null || !variantContext2.isNotFiltered() || !variantContext2.isVariant()) {
            return false;
        }
        if (!checkVariationClass(variantContext, variantContext2)) {
            return false;
        }
        return z || !variantContext2.hasGenotypes() || variantContext2.isPolymorphicInSamples();
    }

    /**
     * Returns true when the datum has no alternate allele recorded, or when {@code variantContext}
     * lists that allele among its alternate alleles. A null {@code variantContext} never matches a
     * recorded allele.
     */
    private boolean doAllelesMatch(VariantContext variantContext, VariantDatum variantDatum) {
        if (variantDatum.alternateAllele == null) {
            return true;
        }
        return variantContext != null && variantContext.getAlternateAlleles().contains(variantDatum.alternateAllele);
    }

    /**
     * Checks {@code variantContext} against the recalibration mode implied by the type of
     * {@code variantContext2}: SNP and MNP map to SNP mode; INDEL, MIXED and SYMBOLIC map to INDEL
     * mode; any other type yields false.
     */
    protected static boolean checkVariationClass(VariantContext variantContext, VariantContext variantContext2) {
        switch (variantContext2.getType()) {
            case SNP:
            case MNP:
                return checkVariationClass(variantContext, VariantRecalibratorArgumentCollection.Mode.SNP);
            case INDEL:
            case MIXED:
            case SYMBOLIC:
                return checkVariationClass(variantContext, VariantRecalibratorArgumentCollection.Mode.INDEL);
            default:
                return false;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Tells whether {@code variantContext} belongs to the variation class selected by {@code mode}.
     *
     * @throws IllegalStateException for a mode other than SNP, INDEL or BOTH
     */
    public static boolean checkVariationClass(VariantContext variantContext, VariantRecalibratorArgumentCollection.Mode mode) {
        switch (mode) {
            case SNP:
                return variantContext.isSNP() || variantContext.isMNP();
            case INDEL:
                return variantContext.isStructuralIndel() || variantContext.isIndel() || variantContext.isMixed() || variantContext.isSymbolic();
            case BOTH:
                return true;
            default:
                throw new IllegalStateException("Encountered unknown recal mode: " + mode);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Allele-level variant of the check: in SNP mode the allele must have the same length as the
     * reference allele; in INDEL mode it must differ in length or be symbolic.
     *
     * @throws IllegalStateException for a mode other than SNP, INDEL or BOTH
     */
    public static boolean checkVariationClass(VariantContext variantContext, Allele allele, VariantRecalibratorArgumentCollection.Mode mode) {
        switch (mode) {
            case SNP:
                return variantContext.getReference().length() == allele.length();
            case INDEL:
                return variantContext.getReference().length() != allele.length() || allele.isSymbolic();
            case BOTH:
                return true;
            default:
                throw new IllegalStateException("Encountered unknown recal mode: " + mode);
        }
    }

    /**
     * Sorts the managed data in place by location and writes one record per datum to
     * {@code variantContextWriter}.
     *
     * <p>Each record carries the end position, the formatted LOD, the key of the worst annotation (or
     * "NULL" when none is recorded) and positive / negative training labels. Alleles are the datum's
     * own reference and alternate alleles in allele-specific mode, and a placeholder pair otherwise.
     */
    public void writeOutRecalibrationTable(VariantContextWriter variantContextWriter) {
        Collections.sort(this.data, new Comparator<VariantDatum>() { // from class: org.broadinstitute.gatk.tools.walkers.variantrecalibration.VariantDataManager.1
            @Override // java.util.Comparator
            public int compare(VariantDatum variantDatum, VariantDatum variantDatum2) {
                return variantDatum.loc.compareTo(variantDatum2.loc);
            }
        });

        // Placeholder alleles used when records are not allele-specific.
        final List<Allele> placeholderAlleles = Arrays.asList(Allele.create("N", true), Allele.create("<VQSR>", false));
        for (VariantDatum datum : this.data) {
            final List<Allele> alleles = this.VRAC.useASannotations
                    ? Arrays.asList(datum.referenceAllele, datum.alternateAllele)
                    : placeholderAlleles;
            final VariantContextBuilder builder = new VariantContextBuilder("VQSR", datum.loc.getContig(), datum.loc.getStart(), datum.loc.getStop(), alleles);
            builder.attribute(VCFConstants.END_KEY, Integer.valueOf(datum.loc.getStop()));
            builder.attribute(GATKVCFConstants.VQS_LOD_KEY, String.format("%.4f", Double.valueOf(datum.lod)));
            final String culprit = datum.worstAnnotation != -1 ? this.annotationKeys.get(datum.worstAnnotation) : "NULL";
            builder.attribute(GATKVCFConstants.CULPRIT_KEY, culprit);
            if (datum.atTrainingSite) {
                builder.attribute(GATKVCFConstants.POSITIVE_LABEL_KEY, true);
            }
            if (datum.atAntiTrainingSite) {
                builder.attribute(GATKVCFConstants.NEGATIVE_LABEL_KEY, true);
            }
            variantContextWriter.add(builder.make());
        }
    }
}
