/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifiers;

import de.jstacs.DataType;
import de.jstacs.NotTrainedException;
import de.jstacs.classifiers.AbstractClassifier;
import de.jstacs.classifiers.ClassDimensionException;
import de.jstacs.classifiers.performanceMeasures.AbstractPerformanceMeasure;
import de.jstacs.classifiers.performanceMeasures.AbstractPerformanceMeasureParameterSet;
import de.jstacs.classifiers.performanceMeasures.PerformanceMeasure;
import de.jstacs.classifiers.utils.PValueComputation;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.ImageResult;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.results.Result;
import de.jstacs.results.ResultSet;
import de.jstacs.utils.REnvironment;
import de.jstacs.utils.ToolBox;
import java.util.AbstractList;
import java.util.Arrays;
import java.util.LinkedList;
import javax.naming.OperationNotSupportedException;

/**
 * Base class for classifiers that compute one score per class for a sequence and predict the class
 * with the highest score.
 *
 * <p>Every class additionally carries a class weight. The number of class weights defines the
 * number of classes. How subclasses incorporate the weights into
 * {@link #getScore(Sequence, int, boolean)} is not defined here (presumably they are added to the
 * class score; confirm in the subclasses).
 */
public abstract class AbstractScoreBasedClassifier extends AbstractClassifier {

    /** One weight per class; its length is the number of classes of this classifier. */
    private double[] classWeights;

    /**
     * Creates a classifier for sequences of arbitrary length with all class weights set to 0.
     *
     * @param abc the alphabets of the sequences
     * @param classes the number of classes (at least 2)
     * @throws IllegalArgumentException if {@code classes < 2}
     */
    public AbstractScoreBasedClassifier(AlphabetContainer abc, int classes) {
        this(abc, 0, classes, 0.0);
    }

    /**
     * Creates a classifier for sequences of arbitrary length with all class weights set to
     * {@code classWeight}.
     *
     * @param abc the alphabets of the sequences
     * @param classes the number of classes (at least 2)
     * @param classWeight the initial weight of every class
     * @throws IllegalArgumentException if {@code classes < 2}
     */
    public AbstractScoreBasedClassifier(AlphabetContainer abc, int classes, double classWeight) {
        this(abc, 0, classes, classWeight);
    }

    /**
     * Creates a classifier for sequences of the given length with all class weights set to 0.
     *
     * @param abc the alphabets of the sequences
     * @param length the required sequence length; 0 disables the length check
     * @param classes the number of classes (at least 2)
     * @throws IllegalArgumentException if {@code classes < 2}
     */
    public AbstractScoreBasedClassifier(AlphabetContainer abc, int length, int classes) {
        this(abc, length, classes, 0.0);
    }

    /**
     * Creates a classifier for sequences of the given length with all class weights set to
     * {@code classWeight}.
     *
     * @param abc the alphabets of the sequences
     * @param length the required sequence length; 0 disables the length check
     * @param classes the number of classes (at least 2)
     * @param classWeight the initial weight of every class
     * @throws IllegalArgumentException if {@code classes < 2}
     */
    public AbstractScoreBasedClassifier(AlphabetContainer abc, int length, int classes, double classWeight)
            throws IllegalArgumentException {
        super(abc, length);
        if (classes < 2) {
            throw new IllegalArgumentException("You should have at least 2 classes.");
        }
        this.createDefaultClassWeights(classes, classWeight);
    }

    /**
     * Restores a classifier from its XML representation.
     *
     * <p>Assumes the super constructor calls {@link #extractFurtherClassifierInfosFromXML} to
     * restore the class weights (TODO confirm in {@code AbstractClassifier}).
     *
     * @param xml the XML representation
     * @throws NonParsableException if {@code xml} cannot be parsed
     */
    public AbstractScoreBasedClassifier(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    @Override
    public AbstractScoreBasedClassifier clone() throws CloneNotSupportedException {
        AbstractScoreBasedClassifier copy = (AbstractScoreBasedClassifier) super.clone();
        // deep-copy the weights so that changing the clone's weights does not affect this instance
        copy.classWeights = this.classWeights.clone();
        return copy;
    }

    /**
     * Classifies {@code seq} after checking it. The result is the index of the class with the
     * highest score.
     */
    @Override
    public byte classify(Sequence seq) throws Exception {
        return this.classify(seq, true);
    }

    /**
     * Computes the score of every class for every sequence of every data set.
     *
     * @param s the data sets
     * @return an array indexed as {@code [dataSet][sequence][class]}
     * @throws Exception if a data set fails {@link #check(DataSet)} or a score cannot be computed
     */
    @Override
    protected double[][][] getMultiClassScores(DataSet[] s) throws Exception {
        for (DataSet data : s) {
            this.check(data);
        }
        int numClasses = this.getNumberOfClasses();
        // outer dimension: data sets; innermost dimension: classes (these were mixed up before)
        double[][][] scores = new double[s.length][][];
        for (int d = 0; d < s.length; d++) {
            int numElements = s[d].getNumberOfElements();
            scores[d] = new double[numElements][numClasses];
            for (int n = 0; n < numElements; n++) {
                Sequence seq = s[d].getElementAt(n);
                for (int c = 0; c < numClasses; c++) {
                    scores[d][n][c] = this.getScore(seq, c, false);
                }
            }
        }
        return scores;
    }

    /**
     * Evaluates the performance measures of {@code params}. For exactly two data sets this uses
     * the score differences of {@link #getScores(DataSet)}. Otherwise it delegates to the super
     * class.
     *
     * @param list receives the computed results
     * @param s the data sets, one per class
     * @param weights optional sequence weights per data set (may be {@code null})
     * @param params the measures to compute
     * @param exceptionIfNotComputeable if {@code true}, a measure that fails or yields no result
     *     raises an exception; otherwise it is skipped
     * @return {@code true} if all computed results were {@link NumericalResultSet}s
     * @throws ClassDimensionException if two data sets are given but the classifier does not have
     *     two classes
     * @throws Exception if a measure fails and {@code exceptionIfNotComputeable} is set
     */
    @Override
    @SuppressWarnings("unchecked") // the overridden signature uses a raw LinkedList
    protected boolean getResults(LinkedList list, DataSet[] s, double[][] weights,
            AbstractPerformanceMeasureParameterSet<? extends PerformanceMeasure> params,
            boolean exceptionIfNotComputeable) throws Exception {
        if (s.length != 2) {
            return super.getResults(list, s, weights, params, exceptionIfNotComputeable);
        }
        if (s.length != this.getNumberOfClasses()) {
            throw new ClassDimensionException();
        }
        double[][] scores = new double[2][];
        double[][] w = new double[2][];
        for (int i = 0; i < s.length; i++) {
            // copy the weights, because they are reordered together with the scores below
            w[i] = (weights != null && weights[i] != null) ? weights[i].clone() : null;
            scores[i] = this.getScores(s[i]);
            ToolBox.sortAlongWith(scores[i], new double[][] {w[i]});
        }
        boolean isNumeric = true;
        for (AbstractPerformanceMeasure measure : params.getAllMeasures()) {
            ResultSet r = null;
            try {
                r = measure.compute(scores[0], w[0], scores[1], w[1]);
            } catch (Exception e) {
                // a failing measure is silently skipped unless the caller requested exceptions
                if (exceptionIfNotComputeable) {
                    throw e;
                }
            }
            if (r == null) {
                if (exceptionIfNotComputeable) {
                    throw new IllegalArgumentException("The measure \"" + measure.getName() + "\" could not be evaluate with this classifier (" + this.getClass() + ").");
                }
            } else {
                isNumeric &= r instanceof NumericalResultSet;
                for (int j = 0; j < r.getNumberOfResults(); j++) {
                    list.add(r.getResultAt(j));
                }
            }
        }
        return isNumeric;
    }

    /**
     * Returns a copy of the class weights.
     *
     * @return a fresh array with one weight per class
     */
    public double[] getClassWeights() {
        return this.classWeights.clone();
    }

    /** Returns the number of classes, i.e. the number of class weights. */
    @Override
    public int getNumberOfClasses() {
        return this.classWeights.length;
    }

    /**
     * Returns the score of class {@code i} for {@code seq} after checking the sequence.
     *
     * @param seq the sequence
     * @param i the class index
     * @return the score
     * @throws Exception if the check fails or the score cannot be computed
     */
    public double getScore(Sequence seq, int i) throws Exception {
        return this.getScore(seq, i, true);
    }

    /**
     * Sets or increments all class weights.
     *
     * @param add if {@code true}, {@code weights} are added to the current weights; otherwise
     *     they replace them
     * @param weights one value per class
     * @throws ClassDimensionException if {@code weights} is {@code null} or its length differs
     *     from the number of classes
     */
    public final void setClassWeights(boolean add, double... weights) throws ClassDimensionException {
        int c = this.getNumberOfClasses();
        if (weights == null || c != weights.length) {
            throw new ClassDimensionException();
        }
        this.setClassWeights(add, weights, 0);
    }

    /**
     * Sets or increments all class weights from a slice of {@code weights}.
     *
     * @param add if {@code true}, the values are added to the current weights; otherwise they
     *     replace them
     * @param weights the source array
     * @param start the index in {@code weights} of the value for class 0
     * @throws ArrayIndexOutOfBoundsException if {@code weights} has fewer than
     *     {@code start + numberOfClasses} entries
     */
    protected final void setClassWeights(boolean add, double[] weights, int start) {
        if (add) {
            for (int i = 0; i < this.classWeights.length; i++) {
                this.classWeights[i] += weights[start + i];
            }
        } else {
            System.arraycopy(weights, start, this.classWeights, 0, this.classWeights.length);
        }
    }

    /**
     * Sets (or increments) the weights of a two-class classifier from a threshold {@code t}. The
     * assigned values satisfy {@code w[1] - w[0] == t} and {@code exp(w[0]) + exp(w[1]) == 1},
     * where {@code w[0] = -log(1 + exp(t))} and {@code w[1] = t + w[0]}.
     *
     * @param add if {@code true}, the values are added to the current weights; otherwise they
     *     replace them
     * @param t the threshold
     * @throws OperationNotSupportedException if the classifier does not have exactly two classes
     */
    public final void setThresholdClassWeights(boolean add, double t) throws OperationNotSupportedException {
        int c = this.getNumberOfClasses();
        if (c != 2) {
            throw new OperationNotSupportedException();
        }
        if (this.classWeights == null) {
            this.classWeights = new double[2];
        }
        // numerically stable -log(1 + exp(t)); the naive form overflows to -Infinity for large t
        double logP = t > 0 ? -(t + Math.log1p(Math.exp(-t))) : -Math.log1p(Math.exp(t));
        if (add) {
            this.classWeights[0] += logP;
            this.classWeights[1] += t + logP;
        } else {
            this.classWeights[0] = logP;
            this.classWeights[1] = t + logP;
        }
    }

    /** Returns the class weights serialized under the tag {@code classWeights}. */
    @Override
    protected StringBuffer getFurtherClassifierInfos() {
        StringBuffer xml = new StringBuffer(300);
        XMLParser.appendObjectWithTags(xml, this.classWeights, "classWeights");
        return xml;
    }

    /**
     * Checks that the classifier is trained and that {@code s} has a matching length (unless the
     * classifier's length is 0) and consistent alphabets.
     *
     * @param s the data set to check
     * @throws NotTrainedException if the classifier is not initialized
     * @throws IllegalArgumentException if length or alphabets do not match
     */
    protected void check(DataSet s) throws NotTrainedException, IllegalArgumentException {
        if (!this.isInitialized()) {
            throw new NotTrainedException("The classifier is not trained yet.");
        }
        int length = this.getLength();
        if (length != 0 && s.getElementLength() != length) {
            throw new IllegalArgumentException("The sequences have not the correct length.");
        }
        if (!this.getAlphabetContainer().checkConsistency(s.getAlphabetContainer())) {
            throw new IllegalArgumentException("The sequences are not defined over the correct alphabets.");
        }
    }

    /**
     * Checks that the classifier is trained and that {@code seq} has a matching length (unless
     * the classifier's length is 0) and consistent alphabets.
     *
     * @param seq the sequence to check
     * @throws NotTrainedException if the classifier is not initialized
     * @throws IllegalArgumentException if length or alphabets do not match
     */
    protected void check(Sequence seq) throws NotTrainedException, IllegalArgumentException {
        if (!this.isInitialized()) {
            throw new NotTrainedException("The classifier is not trained yet.");
        }
        int length = this.getLength();
        if (length != 0 && seq.getLength() != length) {
            throw new IllegalArgumentException("The sequence has not the correct length.");
        }
        if (!this.getAlphabetContainer().checkConsistency(seq.getAlphabetContainer())) {
            throw new IllegalArgumentException("The sequence is not defined over the correct alphabets.");
        }
    }

    /**
     * Returns the index of the class with the highest score. Ties go to the lowest index.
     *
     * @param seq the sequence to classify
     * @param check whether to run {@link #check(Sequence)} first
     * @return the predicted class index, narrowed to {@code byte}
     * @throws Exception if the check fails or a score cannot be computed
     */
    protected byte classify(Sequence seq, boolean check) throws Exception {
        if (check) {
            this.check(seq);
        }
        int best = 0;
        double max = this.getScore(seq, best, false);
        int numClasses = this.getNumberOfClasses();
        // plain int counter: the former byte counter wrapped around for more than 127 classes
        for (int i = 1; i < numClasses; i++) {
            double current = this.getScore(seq, i, false);
            if (current > max) {
                max = current;
                best = i;
            }
        }
        return (byte) best;
    }

    /**
     * Creates {@code classes} class weights, all set to {@code val}.
     *
     * @param classes the number of classes (at least 2)
     * @param val the initial weight of every class
     * @throws IllegalArgumentException if {@code classes < 2}
     */
    protected void createDefaultClassWeights(int classes, double val) throws IllegalArgumentException {
        if (classes < 2) {
            throw new IllegalArgumentException("The number of classes has to be at least 2, but was " + classes + ".");
        }
        this.classWeights = new double[classes];
        Arrays.fill(this.classWeights, val);
    }

    /** Restores the class weights from the tag {@code classWeights}. */
    @Override
    protected void extractFurtherClassifierInfosFromXML(StringBuffer xml) throws NonParsableException {
        this.classWeights = XMLParser.extractObjectForTags(xml, "classWeights", double[].class);
    }

    /**
     * Returns the weight of one class.
     *
     * @param index the class index
     * @return the class weight
     */
    protected double getClassWeight(int index) {
        return this.classWeights[index];
    }

    /**
     * Computes the score of class {@code i} for {@code seq}.
     *
     * @param seq the sequence
     * @param i the class index
     * @param check whether the sequence should be checked first
     * @return the score
     */
    protected abstract double getScore(Sequence seq, int i, boolean check)
            throws IllegalArgumentException, NotTrainedException, Exception;

    /**
     * Returns, for every sequence of {@code s}, the difference between the score of class 0 and
     * the score of class 1. Only available for two-class classifiers.
     *
     * @param s the data set; {@code null} yields an empty array
     * @return one score difference per sequence, in data set order
     * @throws OperationNotSupportedException if the classifier does not have two classes
     * @throws IllegalArgumentException if a score difference is NaN or the check fails
     * @throws Exception if a score cannot be computed
     */
    public double[] getScores(DataSet s) throws Exception {
        if (this.classWeights.length != 2) {
            throw new OperationNotSupportedException("This method is only for 2-class-classifiers.");
        }
        if (s == null) {
            return new double[0];
        }
        this.check(s);
        double[] score = new double[s.getNumberOfElements()];
        DataSet.ElementEnumerator ei = new DataSet.ElementEnumerator(s);
        for (int i = 0; i < score.length; i++) {
            Sequence seq = ei.nextElement();
            double fg = this.getScore(seq, 0, false);
            double bg = this.getScore(seq, 1, false);
            score[i] = fg - bg;
            if (Double.isNaN(score[i])) {
                throw new IllegalArgumentException("Could not classify sequence " + i + ": " + seq + "\nfg: " + fg + "\nbg: " + bg);
            }
        }
        return score;
    }

    /**
     * Computes a p-value for the score difference of {@code candidate} (class 0 minus class 1).
     * The value comes from {@link PValueComputation#getPValue} applied to the sorted score
     * differences of {@code bg}. The exact tail definition is determined by
     * {@code PValueComputation}.
     *
     * @param candidate the sequence to evaluate
     * @param bg the background data set
     * @return the p-value
     * @throws Exception if scores cannot be computed
     */
    public double getPValue(Sequence candidate, DataSet bg) throws Exception {
        double[] scores = this.createStatistic(bg);
        return PValueComputation.getPValue(scores, this.getScore(candidate, 0) - this.getScore(candidate, 1));
    }

    /**
     * Computes the p-value of {@link #getPValue(Sequence, DataSet)} for every sequence in
     * {@code candidates}. The background statistic is computed only once.
     *
     * @param candidates the sequences to evaluate
     * @param bg the background data set
     * @return one p-value per candidate, in data set order
     * @throws Exception if scores cannot be computed
     */
    public double[] getPValue(DataSet candidates, DataSet bg) throws Exception {
        double[] scores = this.createStatistic(bg);
        double[] pVal = new double[candidates.getNumberOfElements()];
        for (int i = 0; i < pVal.length; i++) {
            Sequence candidate = candidates.getElementAt(i);
            pVal[i] = PValueComputation.getPValue(scores, this.getScore(candidate, 0) - this.getScore(candidate, 1));
        }
        return pVal;
    }

    /** Returns the score differences of {@code bg} sorted in ascending order. */
    private double[] createStatistic(DataSet bg) throws Exception {
        double[] scores = this.getScores(bg);
        Arrays.sort(scores);
        return scores;
    }

    /**
     * A {@link Result} holding a table of {@code double} values, one array per line. Plotting
     * uses the first two columns of every table as x and y coordinates.
     */
    public static class DoubleTableResult extends Result {
        private double[][] content;

        /**
         * Creates a table from deep copies of the lines in {@code list}.
         *
         * @param name the name of the result
         * @param comment the comment of the result
         * @param list the lines of the table
         */
        public DoubleTableResult(String name, String comment, AbstractList<double[]> list) {
            super(name, comment, DataType.LIST);
            this.content = new double[list.size()][];
            int i = 0;
            // iterate instead of list.get(i), which is linear per call for a LinkedList
            for (double[] line : list) {
                this.content[i++] = line.clone();
            }
        }

        /**
         * Restores a table from its XML representation.
         *
         * @param representation the XML representation
         * @throws NonParsableException if it cannot be parsed
         */
        public DoubleTableResult(StringBuffer representation) throws NonParsableException {
            super(representation);
        }

        @Override
        public String getXMLTag() {
            return "DoubleTableResult";
        }

        @Override
        protected void extractFurtherInfos(StringBuffer xml) throws NonParsableException {
            this.content = XMLParser.extractObjectForTags(xml, "content", double[][].class);
        }

        /**
         * Returns a copy of one line of the table.
         *
         * @param index the line index
         * @return a fresh copy of the line
         */
        public double[] getLine(int index) {
            return this.content[index].clone();
        }

        /** Returns the number of lines of the table. */
        public int getNumberOfLines() {
            return this.content.length;
        }

        @Override
        public String toString() {
            return "[table] \t " + this.name + " \t(" + this.comment + ")";
        }

        /** Returns a deep copy of the whole table. */
        public double[][] getValue() {
            double[][] res = new double[this.content.length][];
            for (int i = 0; i < res.length; i++) {
                res[i] = this.content[i].clone();
            }
            return res;
        }

        @Override
        protected void appendFurtherInfos(StringBuffer xml) {
            XMLParser.appendObjectWithTags(xml, this.content, "content");
        }

        /**
         * Plots all tables into one image via R. If all tables share the same name (ignoring
         * case), that name is used as the image name and plot options. Otherwise the name is
         * {@code null} and the options are derived from the first table.
         *
         * @param e the R environment
         * @param dtr the tables to plot (at least one)
         * @return the image
         * @throws IllegalArgumentException if no table is given
         * @throws Exception if R fails
         */
        public static final ImageResult plot(REnvironment e, DoubleTableResult... dtr) throws Exception {
            if (dtr == null || dtr.length == 0) {
                throw new IllegalArgumentException("At least one DoubleTableResult has to be given.");
            }
            String opt = dtr[0].name;
            for (int i = 1; i < dtr.length && opt != null; i++) {
                if (!opt.equalsIgnoreCase(dtr[i].name)) {
                    opt = null;
                }
            }
            // avoid the comment "This plot shows the null." when the names differ
            String comment = opt == null ? "This plot shows several curves." : "This plot shows the " + opt + ".";
            return new ImageResult(opt, comment, e.plot(DoubleTableResult.getPlotCommands(e, opt, dtr).toString()));
        }

        /**
         * Returns R commands that plot all tables with default (numeric) colors.
         *
         * @see #getPlotCommands(REnvironment, String, String[], DoubleTableResult...)
         */
        public static final StringBuffer getPlotCommands(REnvironment e, String plotOptions, DoubleTableResult... dtr) throws Exception {
            // the cast resolves the ambiguity between the int[] and String[] overloads
            return DoubleTableResult.getPlotCommands(e, plotOptions, (String[]) null, dtr);
        }

        /**
         * Returns R commands that plot all tables with the given numeric R colors.
         *
         * @param colors one R color number per table, or {@code null} for default colors
         * @see #getPlotCommands(REnvironment, String, String[], DoubleTableResult...)
         */
        public static final StringBuffer getPlotCommands(REnvironment e, String plotOptions, int[] colors, DoubleTableResult... dtr) throws Exception {
            String[] col = null;
            if (colors != null) {
                col = new String[colors.length];
                for (int i = 0; i < col.length; i++) {
                    col[i] = String.valueOf(colors[i]);
                }
            }
            return DoubleTableResult.getPlotCommands(e, plotOptions, col, dtr);
        }

        /**
         * Creates a matrix {@code dtr<i>} in {@code e} for every table and returns R commands that
         * draw an empty plot followed by one line per table.
         *
         * <p>{@code plotOptions} is either the name of a known curve ("Receiver Operating
         * Characteristic curve" or "Precision-Recall curve"), for which predefined axis options are
         * used, or raw R plot arguments. If it is {@code null}, the name of the first table is used
         * instead.
         *
         * @param e the R environment
         * @param plotOptions the plot options or curve name, may be {@code null}
         * @param colors one R color per table, or {@code null}/empty for default numeric colors
         * @param dtr the tables to plot
         * @return the R commands
         * @throws Exception if the matrices cannot be created in R
         */
        public static final StringBuffer getPlotCommands(REnvironment e, String plotOptions, String[] colors, DoubleTableResult... dtr) throws Exception {
            for (int i = 0; i < dtr.length; i++) {
                e.createMatrix("dtr" + i, dtr[i].content);
            }
            if (plotOptions == null) {
                plotOptions = dtr[0].name == null ? "" : dtr[0].name;
            }
            if (plotOptions.equals("Receiver Operating Characteristic curve")) {
                plotOptions = ", xlim=c(0, 1), ylim=c(0, 1), xlab=\"false positive rate\", ylab=\"sensitivity\", main=\"ROC curve\", lwd=3";
            } else if (plotOptions.equals("Precision-Recall curve")) {
                plotOptions = ", xlim=c(0, 1), ylim=c(0, 1), xlab=\"recall\", ylab=\"precision\", main=\"PR curve\", lwd=3";
            } else {
                plotOptions = plotOptions.trim();
                // guard against empty options, which previously caused charAt(0) to throw
                if (plotOptions.length() > 0 && plotOptions.charAt(0) != ',') {
                    plotOptions = ", " + plotOptions;
                }
            }
            boolean defaultColors = colors == null || colors.length == 0;
            StringBuffer p = new StringBuffer(dtr.length * 200);
            p.append("plot( 0:1,0:1, col=0, " + plotOptions + " );");
            for (int i = 0; i < dtr.length; i++) {
                // default colors are the 1-based R palette indices
                String col = defaultColors ? String.valueOf(i + 1) : "\"" + colors[i] + "\"";
                p.append("\nlines( dtr" + i + "[,1], dtr" + i + "[,2], col=" + col + ", lwd=3 );");
            }
            return p;
        }
    }
}

