/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous;

import de.jstacs.algorithms.graphs.DAG;
import de.jstacs.algorithms.graphs.MST;
import de.jstacs.algorithms.graphs.tensor.SymmetricTensor;
import de.jstacs.algorithms.graphs.tensor.Tensor;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.ConstraintManager;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.CombinationIterator;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.InhCondProb;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * Learns the dependency structure of inhomogeneous discrete models for sequences of a fixed length.
 *
 * <p>Structures are returned as {@code int[][] dep} with one row per sequence position. For the
 * structures built directly in this class (independent positions, {@link ModelType#IMM} and the
 * order-1 {@link ModelType#BN} tree) the last entry of {@code dep[i]} is {@code i} itself and the
 * preceding entries are its parents. Structures delegated to {@code DAG} are presumably in the
 * same format (NOTE(review): confirm against {@code DAG}).
 */
public class StructureLearner {
    /** The discrete alphabet(s) of the modeled sequences. */
    private final AlphabetContainer con;
    /** The number of sequence positions (nodes). */
    private final int length;
    /** The equivalent sample size; kept non-negative by {@link #setESS(double)}. */
    private double ess;
    /** {@code alphabetLength[i]} is the alphabet size reported by {@code con} for position {@code i}. */
    private final int[] alphabetLength;

    /**
     * Creates a new structure learner.
     *
     * @param con the alphabet container; must be discrete and either report a possible length of
     *     {@code 0} or exactly {@code length}
     * @param length the number of sequence positions; must be non-negative
     * @param ess the equivalent sample size; must be non-negative
     * @throws IllegalArgumentException if one of the conditions above is violated
     */
    public StructureLearner(AlphabetContainer con, int length, double ess) throws IllegalArgumentException {
        if (!con.isDiscrete()) {
            throw new IllegalArgumentException("The instance of AlphabetContainer has to be discrete.");
        }
        int possibleLength = con.getPossibleLength();
        if (possibleLength != 0 && possibleLength != length) {
            throw new IllegalArgumentException("The instance of AlphabetContainer and length are not matching.");
        }
        if (length < 0) {
            // previously surfaced as a NegativeArraySizeException below
            throw new IllegalArgumentException("The length has to be non-negative.");
        }
        this.con = con;
        this.length = length;
        this.alphabetLength = new int[length];
        for (int i = 0; i < length; i++) {
            this.alphabetLength[i] = (int) con.getAlphabetLengthAt(i);
        }
        this.setESS(ess);
    }

    /**
     * Creates a new structure learner with an equivalent sample size of {@code 0}.
     *
     * @param con the discrete alphabet container
     * @param length the number of sequence positions
     * @throws IllegalArgumentException if {@code con} is not discrete or does not match {@code length}
     */
    public StructureLearner(AlphabetContainer con, int length) throws IllegalArgumentException {
        this(con, length, 0.0);
    }

    /**
     * @return the alphabet container given at construction
     */
    public AlphabetContainer getAlphabetContainer() {
        return this.con;
    }

    /**
     * @return the current equivalent sample size
     */
    public double getEss() {
        return this.ess;
    }

    /**
     * Sets the equivalent sample size.
     *
     * @param ess the new equivalent sample size; must be non-negative
     * @throws IllegalArgumentException if {@code ess} is negative
     */
    public void setESS(double ess) throws IllegalArgumentException {
        if (ess < 0.0) {
            throw new IllegalArgumentException("The value for ess has to be non-negative.");
        }
        this.ess = ess;
    }

    /**
     * Determines a dependency structure for the positions of the modeled sequences.
     *
     * <ul>
     *   <li>{@code order == 0}: every position is independent ({@code dep[i] = {i}}); {@code data},
     *       {@code weights}, {@code model} and {@code method} are not used.</li>
     *   <li>{@link ModelType#IMM}: position {@code i} depends on the (up to) {@code order}
     *       directly preceding positions; the data is not used.</li>
     *   <li>otherwise: a score tensor is computed from the data via
     *       {@link #getTensor(DataSet, double[], byte, LearningType)} and handed to
     *       {@link #getStructure(Tensor, ModelType, byte)}.</li>
     * </ul>
     *
     * @param data the data used for the score tensor
     * @param weights the weights of the sequences in {@code data}, passed through to the counting
     *     (NOTE(review): whether {@code null} means uniform weights is decided in
     *     {@code ConstraintManager} — confirm)
     * @param model the model type
     * @param order the maximal number of parents per position; must be non-negative
     * @param method the learning principle used for the score tensor
     * @return the structure, one row per position
     * @throws IllegalArgumentException if {@code order} is negative
     * @throws Exception if the score tensor or the structure cannot be computed
     */
    public int[][] getStructure(DataSet data, double[] weights, ModelType model, byte order, LearningType method) throws Exception {
        if (order < 0) {
            throw new IllegalArgumentException("The order has to be non-negative.");
        }
        if (order == 0) {
            return this.getIndependentStructure();
        }
        if (model == ModelType.IMM) {
            return this.getMarkovStructure(order);
        }
        SymmetricTensor t = this.getTensor(data, weights, order, method);
        return StructureLearner.getStructure(t, model, order);
    }

    /** Builds the structure in which every position only depends on itself. */
    private int[][] getIndependentStructure() {
        int[][] dep = new int[this.length][];
        for (int pos = 0; pos < this.length; pos++) {
            dep[pos] = new int[] { pos };
        }
        return dep;
    }

    /**
     * Builds the structure in which position {@code pos} depends on positions
     * {@code max(0, pos - order) .. pos - 1}. Also correct when {@code order >= length}, where
     * every position depends on all preceding positions.
     */
    private int[][] getMarkovStructure(int order) {
        int[][] dep = new int[this.length][];
        for (int pos = 0; pos < this.length; pos++) {
            int first = Math.max(0, pos - order);
            dep[pos] = new int[pos - first + 1];
            for (int k = 0; k < dep[pos].length; k++) {
                dep[pos][k] = first + k;
            }
        }
        return dep;
    }

    /**
     * Computes a structure from a precomputed score tensor.
     *
     * <ul>
     *   <li>{@link ModelType#BN} with {@code order == 1}: the edge weights
     *       {@code t.getValue(1, i, j)} for {@code i < j} are passed to {@code MST.kruskal}; the
     *       resulting tree is oriented away from position {@code 0}.</li>
     *   <li>{@link ModelType#BN} with any other order: {@code DAG.computeMaximalKDAG(t)}.</li>
     *   <li>any other model type: {@code DAG.getStructureFromPath(DAG.computeMaximalHP(t), t)};
     *       {@code order} is not used.</li>
     * </ul>
     *
     * @param t the score tensor
     * @param model the model type
     * @param order the order used to read the tensor
     * @return the structure, one row per node of {@code t}
     * @throws IllegalStateException if the spanning tree returned by {@code MST.kruskal} does not
     *     connect all of its edges to node {@code 0}
     * @throws Exception if the graph algorithms fail
     */
    public static int[][] getStructure(Tensor t, ModelType model, byte order) throws Exception {
        if (model != ModelType.BN) {
            return DAG.getStructureFromPath(DAG.computeMaximalHP(t), t);
        }
        if (order != 1) {
            return DAG.computeMaximalKDAG(t);
        }
        return StructureLearner.getTreeStructure(t, order);
    }

    /** Computes a spanning tree over the tensor's edge weights and orients it away from node 0. */
    private static int[][] getTreeStructure(Tensor t, byte order) throws Exception {
        int length = t.getNumberOfNodes();
        int[][] dep = new int[length][];
        if (length == 0) {
            // nothing to connect; previously failed with an AIOOBE on dep[0]
            return dep;
        }
        // upper-triangular weight matrix: w[i][j - i - 1] is the weight of edge (i, j), i < j
        double[][] w = new double[length][];
        for (int i = 0; i < length; i++) {
            w[i] = new double[length - 1 - i];
            for (int j = i + 1; j < length; j++) {
                w[i][j - i - 1] = t.getValue(order, i, j);
            }
        }
        ArrayList<int[]> edges = new ArrayList<int[]>(Arrays.asList(MST.kruskal(w)));

        boolean[] used = new boolean[length];
        dep[0] = new int[] { 0 };
        used[0] = true;
        // Repeatedly attach every edge touching an already reached node; each edge array is
        // reordered in place to {parent, child} and becomes the child's row.
        while (!edges.isEmpty()) {
            boolean attached = false;
            int k = 0;
            while (k < edges.size()) {
                int[] edge = edges.get(k);
                if (!used[edge[0]] && !used[edge[1]]) {
                    k++;
                    continue;
                }
                if (used[edge[1]]) {
                    int tmp = edge[1];
                    edge[1] = edge[0];
                    edge[0] = tmp;
                }
                edges.remove(k);
                dep[edge[1]] = edge;
                used[edge[1]] = true;
                attached = true;
            }
            if (!attached) {
                // without this check a disconnected edge set would loop forever
                throw new IllegalStateException("The spanning tree is not connected to node 0.");
            }
        }
        return dep;
    }

    /**
     * Computes, for every position set of size {@code 1 .. order + 1}, a summand used by
     * {@link #fillTensor(double[][], byte, double)}.
     *
     * <p>For {@link LearningType#ML_OR_MAP} the summand is {@code (sum + ess) * getEntropy(c)} after
     * {@code c.estimateUnConditional(ess, sum)}, minus {@code s * logGamma(ess / s)} with
     * {@code s = c.getNumberOfSpecificConstraints()} if {@code ess > 0}. Otherwise it is
     * {@code ConstraintManager.getLogGammaSum(c, ess)}.
     *
     * @param extra output array; {@code extra[0]} receives {@code logGamma(ess)} (or {@code 0} if
     *     {@code ess == 0}), additionally decreased by {@code logGamma(sum + ess)} for
     *     {@link LearningType#BMA}
     * @return {@code h[k][i]}: the summand of the {@code i}-th position set of size {@code k + 1}
     * @throws IllegalArgumentException if {@code method} is BMA and {@code ess == 0}, or if there
     *     are more combinations than fit into an {@code int}
     * @throws WrongAlphabetException if counting the data fails
     */
    private double[][] getSummands(DataSet data, double[] weights, byte order, LearningType method, double[] extra) throws IllegalArgumentException, WrongAlphabetException {
        if (method == LearningType.BMA && this.ess == 0.0) {
            throw new IllegalArgumentException("The ESS has to be strict positive for BMA.");
        }
        InhCondProb[][] constr = new InhCondProb[order + 1][];
        ArrayList<InhCondProb> list = new ArrayList<InhCondProb>();
        CombinationIterator com = new CombinationIterator(this.length, (byte) (order + 1));
        for (int size = 1; size <= order + 1; size++) {
            com.setCurrentLength((byte) size);
            int n = toIntIndex(com.getNumberOfCombinations((byte) size));
            InhCondProb[] current = new InhCondProb[n];
            // filled back to front so that the slots line up with CombinationIterator.getIndex,
            // as used by fillTensor (NOTE(review): relies on the iterator's enumeration order)
            for (int i = n - 1; i >= 0; i--) {
                current[i] = new InhCondProb(com.getCombination(), this.alphabetLength, false);
                list.add(current[i]);
                com.next();
            }
            constr[size - 1] = current;
        }

        double sum = ConstraintManager.countInhomogeneous(this.con, this.length, data, weights, true, list.toArray(new InhCondProb[0]));
        double all = sum + this.ess;
        double[][] h = new double[order + 1][];
        for (int k = 0; k <= order; k++) {
            h[k] = new double[constr[k].length];
            for (int i = 0; i < h[k].length; i++) {
                InhCondProb c = constr[k][i];
                if (method == LearningType.ML_OR_MAP) {
                    c.estimateUnConditional(this.ess, sum);
                    h[k][i] = all * ConstraintManager.getEntropy(c);
                    if (this.ess > 0.0) {
                        double specific = c.getNumberOfSpecificConstraints();
                        h[k][i] -= specific * Gamma.logOfGamma(this.ess / specific);
                    }
                } else {
                    h[k][i] = ConstraintManager.getLogGammaSum(c, this.ess);
                }
            }
        }

        double offset = this.ess > 0.0 ? Gamma.logOfGamma(this.ess) : 0.0;
        if (method == LearningType.BMA) {
            offset -= Gamma.logOfGamma(all);
        }
        extra[0] = offset;
        return h;
    }

    /**
     * Computes the score tensor for the given data.
     *
     * <p>For every parent set {@code P} of size {@code k = 1 .. order} and every position
     * {@code i} not in {@code P}, the entry is
     * {@code h(i) + h(P) - h(P ∪ {i}) - extra}, with {@code h} and {@code extra} as computed by
     * {@link #getSummands(DataSet, double[], byte, LearningType, double[])}. For ML with
     * {@code ess == 0} this is presumably a weighted mutual information.
     *
     * @param data the data
     * @param weights the weights of the sequences in {@code data}
     * @param order the maximal number of parents
     * @param method the learning principle
     * @return the filled tensor
     * @throws IllegalArgumentException if {@code method} is BMA and the ESS is {@code 0}
     * @throws WrongAlphabetException if counting the data fails
     */
    public SymmetricTensor getTensor(DataSet data, double[] weights, byte order, LearningType method) throws IllegalArgumentException, WrongAlphabetException {
        double[] extra = new double[1];
        double[][] summands = this.getSummands(data, weights, order, method, extra);
        return this.fillTensor(summands, order, extra[0]);
    }

    /** Combines the summands into tensor entries, see {@link #getTensor(DataSet, double[], byte, LearningType)}. */
    private SymmetricTensor fillTensor(double[][] summands, byte order, double extra) {
        SymmetricTensor t = new SymmetricTensor(this.length, order);
        CombinationIterator com = new CombinationIterator(this.length, (byte) (order + 1));
        boolean[] used = new boolean[this.length];
        for (int k = 1; k <= order; k++) {
            byte parentCount = (byte) k;
            com.setCurrentLength(parentCount);
            int[] parents = new int[k];
            int numberOfParentSets = toIntIndex(com.getNumberOfCombinations(parentCount));
            // family = sorted union of the parents and the current child
            int[] family = new int[k + 1];
            for (int remaining = numberOfParentSets; remaining > 0; remaining--) {
                int[] comb = com.getCombination();
                Arrays.fill(used, false);
                for (int p = 0; p < k; p++) {
                    parents[p] = comb[p];
                    used[parents[p]] = true;
                }
                int parentIndex = toIntIndex(com.getIndex(comb));
                System.arraycopy(comb, 0, family, 1, k);
                // pos is the slot holding the child; children are visited in increasing order, so
                // each new child overwrites the previous one and only has to bubble towards the end
                int pos = 0;
                for (int child = 0; child < this.length; child++) {
                    if (used[child]) {
                        continue;
                    }
                    family[pos] = child;
                    while (pos < k && family[pos] > family[pos + 1]) {
                        int tmp = family[pos];
                        family[pos] = family[pos + 1];
                        family[pos + 1] = tmp;
                        pos++;
                    }
                    int familyIndex = toIntIndex(com.getIndex(family));
                    // singletons are stored back to front, see getSummands
                    double value = summands[0][this.length - 1 - child] + summands[k - 1][parentIndex] - summands[k][familyIndex] - extra;
                    t.setValue(parentCount, value, child, parents);
                }
                com.next();
            }
        }
        return t;
    }

    /**
     * Converts a combination count or index to {@code int}.
     *
     * @throws IllegalArgumentException if {@code value} exceeds {@link Integer#MAX_VALUE}
     */
    private static int toIntIndex(long value) {
        if (value > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Too many combinations to be indexed by an int: " + value);
        }
        return (int) value;
    }

    /** The learning principle used to score structures. */
    public static enum LearningType {
        /** Entropy-based scores; the ESS (possibly {@code 0}) enters as a pseudo count. */
        ML_OR_MAP,
        /** Log-gamma based scores; requires a strictly positive ESS. */
        BMA;

    }

    /** The kind of model whose structure is learned. */
    public static enum ModelType {
        /** Each position depends on the directly preceding positions (data-independent). */
        IMM,
        /** Structure derived from a path computed by {@code DAG.computeMaximalHP}. */
        PMM,
        /** Tree (order 1) or {@code DAG.computeMaximalKDAG} structure. */
        BN;

    }
}

