/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.sting.gatk.walkers.variantrecalibration;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.Tranche;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.VariantDatum;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.VariantRecalibratorArgumentCollection;
import org.broadinstitute.sting.utils.exceptions.UserException;

public class TrancheManager {
    protected static final Logger logger = Logger.getLogger(TrancheManager.class);

    /**
     * Finds one tranche per threshold without writing a debugging file.
     * Equivalent to calling the five-argument overload with a {@code null} debug file.
     *
     * @param data     variants to partition; sorted in place
     * @param tranches tranche thresholds, in the order they should be searched
     * @param metric   selection metric used to place each tranche cut
     * @param model    recalibration mode recorded in each resulting {@link Tranche}
     * @return the tranches found, in threshold order
     */
    public static List<Tranche> findTranches(ArrayList<VariantDatum> data, double[] tranches, SelectionMetric metric, VariantRecalibratorArgumentCollection.Mode model) {
        return findTranches(data, tranches, metric, model, null);
    }

    /**
     * Finds one tranche per threshold.
     *
     * <p>Sorts {@code data} in place using {@code VariantDatum}'s natural ordering (defined outside
     * this file; presumably ascending lod, which is what the index-based cut search below relies on
     * -- confirm). Then asks {@code metric} to precompute its running values and, if requested,
     * dumps them to {@code debugFile}.
     *
     * <p>Thresholds are searched in the given order. If the first one yields no tranche, a
     * {@link UserException} is thrown. If a later one yields no tranche, the search stops there and
     * the tranches found so far are returned; remaining thresholds are not attempted.
     *
     * @param data              variants to partition; mutated (sorted) by this call
     * @param trancheThresholds tranche thresholds, in search order
     * @param metric            selection metric used to place each tranche cut
     * @param model             recalibration mode recorded in each resulting {@link Tranche}
     * @param debugFile         optional file for a per-variant metric dump; {@code null} to skip
     * @return the tranches found, in threshold order (never empty when this returns normally)
     * @throws UserException if no tranche can be found for the first threshold, or the debug file
     *                       cannot be created
     */
    public static List<Tranche> findTranches(ArrayList<VariantDatum> data, double[] trancheThresholds, SelectionMetric metric, VariantRecalibratorArgumentCollection.Mode model, File debugFile) {
        logger.info(String.format("Finding %d tranches for %d variants", trancheThresholds.length, data.size()));
        // Sorting mutates the caller's list; the running metric and the cut indices below are
        // only meaningful relative to this order.
        Collections.sort(data);
        metric.calculateRunningMetric(data);
        if (debugFile != null) {
            writeTranchesDebuggingInfo(debugFile, data, metric);
        }
        List<Tranche> tranches = new ArrayList<Tranche>(trancheThresholds.length);
        for (double trancheThreshold : trancheThresholds) {
            Tranche t = findTranche(data, metric, trancheThreshold, model);
            if (t == null) {
                if (tranches.isEmpty()) {
                    throw new UserException(String.format("Couldn't find any tranche containing variants with a %s > %.2f. Are you sure the truth files contain unfiltered variants which overlap the input data?", metric.getName(), metric.getThreshold(trancheThreshold)));
                }
                // At least one tranche exists: keep it and stop searching later thresholds.
                break;
            }
            tranches.add(t);
        }
        return tranches;
    }

    /**
     * Writes one line per variant with its lod, the metric's per-datum value and the metric's
     * running value at that index, preceded by a header line.
     * The stream is closed even if writing fails part-way through.
     *
     * @throws UserException.CouldNotCreateOutputFile if {@code f} cannot be opened for writing
     */
    private static void writeTranchesDebuggingInfo(File f, List<VariantDatum> tranchesData, SelectionMetric metric) {
        PrintStream out;
        try {
            out = new PrintStream(f);
        } catch (FileNotFoundException e) {
            throw new UserException.CouldNotCreateOutputFile(f, (Exception) e);
        }
        try {
            out.println("Qual metricValue runningValue");
            for (int i = 0; i < tranchesData.size(); ++i) {
                VariantDatum d = tranchesData.get(i);
                out.printf("%.4f %d %.4f%n", d.lod, metric.datumValue(d), metric.getRunningMetric(i));
            }
        } finally {
            // Previously the stream leaked if anything above threw.
            out.close();
        }
    }

    /**
     * Finds the tranche for a single threshold.
     *
     * <p>Scans {@code data} from index 0 upward for the first index whose running metric is
     * {@code >=} the metric threshold derived from {@code trancheThreshold}. It then builds the
     * tranche from that index. {@code metric.calculateRunningMetric} must already have been
     * called on this same list, in this same order.
     *
     * @return the tranche, or {@code null} if no index meets the threshold
     */
    public static Tranche findTranche(List<VariantDatum> data, SelectionMetric metric, double trancheThreshold, VariantRecalibratorArgumentCollection.Mode model) {
        double metricThreshold = metric.getThreshold(trancheThreshold);
        logger.info(String.format("  Tranche threshold %.2f => selection metric threshold %.3f", trancheThreshold, metricThreshold));
        for (int i = 0, n = data.size(); i < n; ++i) {
            double runningValue = metric.getRunningMetric(i);
            // A NaN running value never satisfies >=, so such indices are skipped.
            if (runningValue >= metricThreshold) {
                Tranche t = trancheOfVariants(data, i, trancheThreshold, model);
                logger.info(String.format("  Found tranche for %.3f: %.3f threshold starting with variant %d; running score is %.3f ", trancheThreshold, metricThreshold, i, runningValue));
                logger.info(String.format("  Tranche is %s", t));
                return t;
            }
        }
        return null;
    }

    /**
     * Builds a {@link Tranche} from every variant whose lod is {@code >=} the lod at index
     * {@code minI}.
     *
     * <p>Known and novel variants are counted separately. Ti/Tv ratios use only SNPs, and the
     * denominator is clamped to at least 1 so that a zero transversion count does not divide by
     * zero. Truth-site counts are taken over the whole list and over the selected subset.
     *
     * @param data  variants (expected to be the sorted list used by {@link #findTranche})
     * @param minI  index whose lod defines the cut; must be a valid index into {@code data}
     * @param ts    tranche threshold recorded in the result
     * @param model recalibration mode recorded in the result
     */
    public static Tranche trancheOfVariants(List<VariantDatum> data, int minI, double ts, VariantRecalibratorArgumentCollection.Mode model) {
        int numKnown = 0;
        int numNovel = 0;
        int knownTi = 0;
        int knownTv = 0;
        int novelTi = 0;
        int novelTv = 0;
        double minLod = data.get(minI).lod;
        for (VariantDatum datum : data) {
            // Written as !(a >= b) rather than a < b so that NaN lods are excluded.
            if (!(datum.lod >= minLod)) {
                continue;
            }
            if (datum.isKnown) {
                ++numKnown;
                if (datum.isSNP) {
                    if (datum.isTransition) {
                        ++knownTi;
                    } else {
                        ++knownTv;
                    }
                }
            } else {
                ++numNovel;
                if (datum.isSNP) {
                    if (datum.isTransition) {
                        ++novelTi;
                    } else {
                        ++novelTv;
                    }
                }
            }
        }
        double knownTiTv = knownTi / Math.max((double) knownTv, 1.0);
        double novelTiTv = novelTi / Math.max((double) novelTv, 1.0);
        int accessibleTruthSites = countCallsAtTruth(data, Double.NEGATIVE_INFINITY);
        int nCallsAtTruth = countCallsAtTruth(data, minLod);
        return new Tranche(ts, minLod, numKnown, knownTiTv, numNovel, novelTiTv, accessibleTruthSites, nCallsAtTruth, model);
    }

    /**
     * Maps a false-discovery rate (in percent) to an expected Ti/Tv ratio.
     * The result interpolates linearly: it equals {@code targetTiTv} at 0% FDR and 0.5 at 100% FDR.
     *
     * @param desiredFDR FDR in percent (e.g. 1.0 for 1%)
     * @param targetTiTv Ti/Tv ratio expected for a 0% FDR call set
     */
    public static double fdrToTiTv(double desiredFDR, double targetTiTv) {
        return (1.0 - desiredFDR / 100.0) * (targetTiTv - 0.5) + 0.5;
    }

    /**
     * Counts variants flagged {@code atTruthSite} whose lod is {@code >= minLOD}.
     * Passing {@link Double#NEGATIVE_INFINITY} counts every truth-site variant with a non-NaN lod.
     */
    public static int countCallsAtTruth(List<VariantDatum> data, double minLOD) {
        int n = 0;
        for (VariantDatum d : data) {
            if (d.atTruthSite && d.lod >= minLOD) {
                ++n;
            }
        }
        return n;
    }

    /**
     * Selection metric based on truth-site counts.
     *
     * <p>The running value stored at index {@code i} is {@code 1 - c(i) / nTrueSites}, where
     * {@code c(i)} is the number of truth-site datums at indices {@code i..end}. That is, it is
     * the complement of the truth sensitivity of the suffix starting at {@code i}. If
     * {@code nTrueSites} is 0, every stored value is NaN or -Infinity, so no index can satisfy a
     * finite threshold.
     */
    public static class TruthSensitivityMetric
    extends SelectionMetric {
        double[] runningSensitivity;
        int nTrueSites = 0;

        public TruthSensitivityMetric(int nTrueSites) {
            super("TruthSensitivity");
            this.nTrueSites = nTrueSites;
        }

        /** Returns {@code 1 - tranche / 100}. */
        @Override
        public double getThreshold(double tranche) {
            return 1.0 - tranche / 100.0;
        }

        @Override
        public double getTarget() {
            return 1.0;
        }

        @Override
        public void calculateRunningMetric(List<VariantDatum> data) {
            int nCalledAtTruth = 0;
            this.runningSensitivity = new double[data.size()];
            // Walk from the end so each entry reflects the suffix starting at its index.
            for (int i = data.size() - 1; i >= 0; --i) {
                if (data.get(i).atTruthSite) {
                    ++nCalledAtTruth;
                }
                this.runningSensitivity[i] = 1.0 - nCalledAtTruth / (double) this.nTrueSites;
            }
        }

        @Override
        public double getRunningMetric(int i) {
            return this.runningSensitivity[i];
        }

        /** Returns 1 for a truth-site datum, else 0. */
        @Override
        public int datumValue(VariantDatum d) {
            return d.atTruthSite ? 1 : 0;
        }
    }

    /**
     * Selection metric based on the Ti/Tv ratio of novel (non-known) variants.
     *
     * <p>The running value at a novel datum's index is the transition/transversion ratio over the
     * novel datums at indices {@code i..end}. The denominator is clamped to at least 1.
     * NOTE(review): entries for known datums are never assigned and stay 0.0. Also, every novel
     * datum that is not a transition (including non-SNPs) is counted as a transversion here,
     * whereas {@link TrancheManager#trancheOfVariants} counts SNPs only. Confirm this is intended.
     */
    public static class NovelTiTvMetric
    extends SelectionMetric {
        double[] runningTiTv;
        double targetTiTv = 0.0;

        public NovelTiTvMetric(double target) {
            super("NovelTiTv");
            this.targetTiTv = target;
        }

        /** Converts a tranche FDR (percent) into a Ti/Tv threshold via {@link TrancheManager#fdrToTiTv}. */
        @Override
        public double getThreshold(double tranche) {
            return fdrToTiTv(tranche, this.targetTiTv);
        }

        @Override
        public double getTarget() {
            return this.targetTiTv;
        }

        @Override
        public void calculateRunningMetric(List<VariantDatum> data) {
            int ti = 0;
            int tv = 0;
            this.runningTiTv = new double[data.size()];
            for (int i = data.size() - 1; i >= 0; --i) {
                VariantDatum datum = data.get(i);
                if (datum.isKnown) {
                    continue;
                }
                if (datum.isTransition) {
                    ++ti;
                } else {
                    ++tv;
                }
                this.runningTiTv[i] = ti / Math.max((double) tv, 1.0);
            }
        }

        @Override
        public double getRunningMetric(int i) {
            return this.runningTiTv[i];
        }

        /** Returns 1 for a transition, else 0. */
        @Override
        public int datumValue(VariantDatum d) {
            return d.isTransition ? 1 : 0;
        }
    }

    /**
     * Strategy for choosing tranche cut points.
     * {@link #calculateRunningMetric} must be called on the sorted data before
     * {@link #getRunningMetric} is queried for indices into that same list.
     */
    public static abstract class SelectionMetric {
        String name = null;

        public SelectionMetric(String name) {
            this.name = name;
        }

        /** Human-readable metric name, used in log and error messages. */
        public String getName() {
            return this.name;
        }

        /** Converts a user-facing tranche value into the threshold compared against running values. */
        public abstract double getThreshold(double tranche);

        public abstract double getTarget();

        /** Precomputes the per-index running metric for {@code data} in its current order. */
        public abstract void calculateRunningMetric(List<VariantDatum> data);

        /** Returns the running metric previously computed for index {@code i}. */
        public abstract double getRunningMetric(int i);

        /** Per-datum contribution, written to the debugging dump. */
        public abstract int datumValue(VariantDatum d);
    }
}

