package weka.classifiers.trees;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.classifiers.trees.adtree.ReferenceInstances;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.ContingencyTables;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.TestInstances;
import weka.core.Utils;

/* loaded from: input_file:weka/classifiers/trees/LADTree.class */
public class LADTree extends Classifier implements Drawable, AdditionalMeasureProducer, TechnicalInformationHandler {
    /** Serialization version identifier. */
    private static final long serialVersionUID = -4940716114518300302L;
    /** Number of class values; assigned from {@code instances.numClasses()} in initClassifier. */
    protected int m_numOfClasses;
    /** Training set of LADInstance references; built in initClassifier from instances whose class is not missing. */
    protected ReferenceInstances m_trainInstances;
    /** Indices of the numeric, non-class attributes; filled by generateStaticPotentialSplittersAndNumericIndices. */
    protected int[] m_numericAttIndices;
    /** Best score seen during the current search; a candidate replaces the best only when its score is strictly greater. */
    protected double m_search_smallestLeastSquares;
    /** Prediction node under which the best candidate of the current search would be inserted. */
    protected PredictionNode m_search_bestInsertionNode;
    /** Best candidate splitter of the current search. */
    protected Splitter m_search_bestSplitter;
    /** Instances that reached the insertion node of the best candidate; set back to null at the end of boost(). */
    protected Instances m_search_bestPathInstances;
    /** Pre-built TwoWayNominalSplit candidates for the nominal attributes. */
    protected FastVector m_staticPotentialSplitters2way;
    /** Magnitude bound applied to LADInstance.zVector entries in updateZVector. */
    protected double Z_MAX = 4.0d;
    /** Root prediction node of the tree; created in initClassifier. */
    protected PredictionNode m_root = null;
    /** Counter incremented by PredictionNode.addChild and stamped on each newly inserted splitter as orderAdded. */
    protected int m_lastAddedSplitNum = 0;
    /** Number of prediction nodes visited by the recursive split search. */
    protected int m_nodesExpanded = 0;
    /** Running total of instance counts seen at the visited prediction nodes. */
    protected int m_examplesCounted = 0;
    /** NOTE(review): presumably the number of boosting iterations to run; it is not read in the visible part of the file - confirm. */
    protected int m_boostingIterations = 10;

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:weka/classifiers/trees/LADTree$LADInstance.class */
    /**
     * Training instance decorated with the per-class working vectors used while boosting:
     * accumulated scores (F), probability estimates (P), working responses (Z) and
     * working weights (W). All four arrays have one slot per class value.
     */
    public class LADInstance extends Instance {
        /** Accumulated per-class scores; {@link #updateWeights} adds each new prediction vector to it. */
        public double[] fVector;
        /** Per-class weights, recomputed by {@link #updateWVector()}. */
        public double[] wVector;
        /** Per-class probability estimates; uniform at construction, refreshed by {@link #updatePVector}. */
        public double[] pVector;
        /** Per-class responses, limited in magnitude by Z_MAX in {@link #updateZVector()}. */
        public double[] zVector;

        /**
         * Copies {@code instance}, keeps its dataset reference, and initialises the working
         * vectors from a uniform probability estimate.
         *
         * @param instance the instance to wrap
         */
        public LADInstance(Instance instance) {
            super(instance);
            setDataset(instance.dataset());
            final int classCount = LADTree.this.m_numOfClasses;
            this.fVector = new double[classCount];
            this.wVector = new double[classCount];
            this.pVector = new double[classCount];
            this.zVector = new double[classCount];
            final double uniform = 1.0d / classCount;
            int k = 0;
            while (k < classCount) {
                this.pVector[k] = uniform;
                k++;
            }
            // Z must be refreshed before W because W divides by Z.
            updateZVector();
            updateWVector();
        }

        /**
         * Adds a prediction vector to the accumulated scores and refreshes P, Z and W.
         *
         * @param dArr per-class values to add to {@link #fVector}
         */
        public void updateWeights(double[] dArr) {
            final double[] scores = this.fVector;
            for (int k = 0; k < scores.length; k++) {
                scores[k] += dArr[k];
            }
            updateVectors(this.fVector);
        }

        /**
         * Recomputes P from the given scores, then Z, then W.
         *
         * @param dArr per-class scores the probabilities are derived from
         */
        public void updateVectors(double[] dArr) {
            updatePVector(dArr);
            updateZVector();
            updateWVector();
        }

        /**
         * Sets P to the normalised exponentials of the scores. The largest score is
         * subtracted before exponentiating.
         *
         * @param dArr per-class scores
         */
        public void updatePVector(double[] dArr) {
            final double peak = dArr[Utils.maxIndex(dArr)];
            final double[] probs = this.pVector;
            for (int k = 0; k < probs.length; k++) {
                probs[k] = Math.exp(dArr[k] - peak);
            }
            Utils.normalize(this.pVector);
        }

        /** Sets each weight to {@code (y - p) / z} for its class. */
        public void updateWVector() {
            int k = 0;
            while (k < this.wVector.length) {
                final double residual = yVector(k) - this.pVector[k];
                this.wVector[k] = residual / this.zVector[k];
                k++;
            }
        }

        /**
         * Sets each response to {@code 1/p} for the instance's own class and to
         * {@code -1/(1-p)} for every other class, limiting the magnitude to Z_MAX.
         */
        public void updateZVector() {
            for (int k = 0; k < this.zVector.length; k++) {
                if (yVector(k) == 1.0d) {
                    final double raw = 1.0d / this.pVector[k];
                    this.zVector[k] = raw;
                    if (raw > LADTree.this.Z_MAX) {
                        this.zVector[k] = LADTree.this.Z_MAX;
                    }
                    continue;
                }
                final double raw = (-1.0d) / (1.0d - this.pVector[k]);
                this.zVector[k] = raw;
                if (raw < (-LADTree.this.Z_MAX)) {
                    this.zVector[k] = -LADTree.this.Z_MAX;
                }
            }
        }

        /**
         * Indicator target for class {@code i}.
         *
         * @param i class index
         * @return 1.0 when {@code i} is this instance's class value, otherwise
         *         {@code KStarConstants.FLOOR} (presumably 0.0 - TODO confirm)
         */
        public double yVector(int i) {
            return i == ((int) classValue()) ? 1.0d : KStarConstants.FLOOR;
        }

        /**
         * Returns a new LADInstance built from the superclass copy, carrying the same
         * F, W, P and Z values.
         */
        @Override
        public Object copy() {
            final LADInstance duplicate = new LADInstance((Instance) super.copy());
            final int fLen = this.fVector.length;
            final int wLen = this.wVector.length;
            final int pLen = this.pVector.length;
            final int zLen = this.zVector.length;
            System.arraycopy(this.fVector, 0, duplicate.fVector, 0, fLen);
            System.arraycopy(this.wVector, 0, duplicate.wVector, 0, wLen);
            System.arraycopy(this.pVector, 0, duplicate.pVector, 0, pLen);
            System.arraycopy(this.zVector, 0, duplicate.zVector, 0, zLen);
            return duplicate;
        }

        /** Returns the superclass text followed by the F, P and W vectors, three decimals each. */
        @Override
        public String toString() {
            final StringBuilder suffix = new StringBuilder();
            suffix.append(" * F(");
            appendVector(suffix, this.fVector);
            suffix.append(") P(");
            appendVector(suffix, this.pVector);
            suffix.append(") W(");
            appendVector(suffix, this.wVector);
            suffix.append(")");
            return super.toString() + suffix.toString();
        }

        /** Appends the comma-separated entries of {@code vector}, each formatted with three decimals. */
        private void appendVector(StringBuilder out, double[] vector) {
            for (int k = 0; k < vector.length; k++) {
                if (k > 0) {
                    out.append(",");
                }
                out.append(Utils.doubleToString(vector[k], 3));
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:weka/classifiers/trees/LADTree$PredictionNode.class */
    /**
     * Tree node holding one prediction value per class plus the splitter nodes
     * attached beneath it.
     */
    public class PredictionNode implements Serializable, Cloneable {
        /** Per-class prediction values (own copy, length m_numOfClasses). */
        private double[] values;
        /** Splitter children attached to this node. */
        private FastVector children;

        /**
         * Creates a childless node holding a copy of the given values.
         *
         * @param dArr per-class values; the first m_numOfClasses entries are copied
         */
        public PredictionNode(double[] dArr) {
            this.values = new double[LADTree.this.m_numOfClasses];
            setValues(dArr);
            this.children = new FastVector();
        }

        /**
         * Overwrites this node's values with the first m_numOfClasses entries of {@code dArr}.
         *
         * @param dArr source values
         */
        public void setValues(double[] dArr) {
            final int classCount = LADTree.this.m_numOfClasses;
            System.arraycopy(dArr, 0, this.values, 0, classCount);
        }

        /** @return the internal value array (not a copy) */
        public double[] getValues() {
            return this.values;
        }

        /** @return the internal child vector (not a copy) */
        public FastVector getChildren() {
            return this.children;
        }

        /** @return an enumeration over the splitter children */
        public Enumeration children() {
            return this.children.elements();
        }

        /**
         * Attaches a splitter below this node. When no existing child is
         * {@code equalTo} the given splitter, a clone of it is appended and stamped with
         * the next insertion number. Otherwise the branches of the given splitter are
         * merged into the matching child, branch by branch, wherever both sides have a
         * prediction node.
         *
         * @param splitter the splitter to add or merge
         */
        public void addChild(Splitter splitter) {
            Splitter match = null;
            for (Enumeration existing = children(); match == null && existing.hasMoreElements(); ) {
                final Splitter candidate = (Splitter) existing.nextElement();
                if (splitter.equalTo(candidate)) {
                    match = candidate;
                }
            }

            if (match != null) {
                int branch = 0;
                while (branch < splitter.getNumOfBranches()) {
                    final PredictionNode target = match.getChildForBranch(branch);
                    final PredictionNode incoming = splitter.getChildForBranch(branch);
                    if (target != null && incoming != null) {
                        target.merge(incoming);
                    }
                    branch++;
                }
                return;
            }

            final Splitter added = (Splitter) splitter.clone();
            LADTree.this.m_lastAddedSplitNum += 1;
            added.orderAdded = LADTree.this.m_lastAddedSplitNum;
            this.children.addElement(added);
        }

        /** Returns a deep copy: the values are copied and every child splitter is cloned. */
        public Object clone() {
            final PredictionNode duplicate = new PredictionNode(this.values);
            for (Enumeration kids = this.children.elements(); kids.hasMoreElements(); ) {
                final Splitter original = (Splitter) kids.nextElement();
                duplicate.children.addElement((Splitter) original.clone());
            }
            return duplicate;
        }

        /**
         * Adds the other node's values to this node's values and then adds each of the
         * other node's children to this node via {@link #addChild}.
         *
         * @param predictionNode the node to fold into this one
         */
        public void merge(PredictionNode predictionNode) {
            final int classCount = LADTree.this.m_numOfClasses;
            for (int c = 0; c < classCount; c++) {
                this.values[c] += predictionNode.values[c];
            }
            for (Enumeration incoming = predictionNode.children(); incoming.hasMoreElements(); ) {
                addChild((Splitter) incoming.nextElement());
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:weka/classifiers/trees/LADTree$Splitter.class */
    /**
     * Base type for the attribute tests that sit between prediction nodes. The two
     * implementations in this file are both two-way and use branch -1 to denote
     * "attribute value missing".
     */
    public abstract class Splitter implements Serializable, Cloneable {
        /** Index of the attribute this test inspects. */
        protected int attIndex;
        /** Insertion sequence number, assigned by PredictionNode.addChild when the splitter is first attached. */
        public int orderAdded;

        protected Splitter() {
        }

        // --- branching -------------------------------------------------------

        /** @return the number of branches leaving this test */
        public abstract int getNumOfBranches();

        /**
         * @param instance the instance to route
         * @return the branch the instance follows
         */
        public abstract int branchInstanceGoesDown(Instance instance);

        /**
         * @param i branch index
         * @param instances the candidate instances
         * @return the subset of {@code instances} that follows branch {@code i}
         */
        public abstract Instances instancesDownBranch(int i, Instances instances);

        // --- child prediction nodes ------------------------------------------

        /**
         * @param i branch index
         * @return the prediction node attached to branch {@code i}
         */
        public abstract PredictionNode getChildForBranch(int i);

        /**
         * @param i branch index
         * @param predictionNode the prediction node to attach to branch {@code i}
         */
        public abstract void setChildForBranch(int i, PredictionNode predictionNode);

        // --- identity and display --------------------------------------------

        /**
         * @param splitter the splitter to compare with
         * @return whether the two splitters represent the same test
         */
        public abstract boolean equalTo(Splitter splitter);

        /** @return the text naming the tested attribute */
        public abstract String attributeString();

        /**
         * @param i branch index
         * @return the text of the comparison for branch {@code i}
         */
        public abstract String comparisonString(int i);

        /** @return a copy of this splitter */
        public abstract Object clone();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:weka/classifiers/trees/LADTree$TwoWayNominalSplit.class */
    /**
     * Two-way test on a nominal attribute: branch 0 holds instances whose value equals
     * the chosen value, branch 1 holds all other non-missing instances.
     */
    public class TwoWayNominalSplit extends Splitter {
        /** Index of the nominal value that sends an instance down branch 0. */
        private int trueSplitValue;
        /** Prediction nodes for branch 0 and branch 1. */
        private PredictionNode[] children;

        /**
         * @param i index of the attribute to test
         * @param i2 index of the attribute value selecting branch 0
         */
        public TwoWayNominalSplit(int i, int i2) {
            super();
            this.attIndex = i;
            this.trueSplitValue = i2;
            this.children = new PredictionNode[2];
        }

        @Override
        public int getNumOfBranches() {
            return 2;
        }

        /** @return -1 when the attribute is missing, 0 when it equals the split value, else 1 */
        @Override
        public int branchInstanceGoesDown(Instance instance) {
            if (instance.isMissing(this.attIndex)) {
                return -1;
            }
            if (instance.value(this.attIndex) == ((double) this.trueSplitValue)) {
                return 0;
            }
            return 1;
        }

        /**
         * Collects references to the instances belonging to one branch.
         *
         * @param i -1 selects instances with the attribute missing, 0 selects those equal
         *          to the split value, any other number selects the non-missing rest
         * @param instances the instances to filter
         * @return a new ReferenceInstances holding the selected instances
         */
        @Override
        public Instances instancesDownBranch(int i, Instances instances) {
            final ReferenceInstances selected = new ReferenceInstances(instances, 1);
            final Enumeration all = instances.enumerateInstances();
            while (all.hasMoreElements()) {
                final Instance candidate = (Instance) all.nextElement();
                final boolean missing = candidate.isMissing(this.attIndex);
                final boolean wanted;
                if (i == -1) {
                    wanted = missing;
                } else if (missing) {
                    wanted = false;
                } else {
                    final boolean matches = candidate.value(this.attIndex) == this.trueSplitValue;
                    // Branch 0 wants a match; every other branch number wants a non-match.
                    wanted = (i == 0) == matches;
                }
                if (wanted) {
                    selected.addReference(candidate);
                }
            }
            return selected;
        }

        /** @return the name of the tested attribute, looked up in the training data */
        @Override
        public String attributeString() {
            return LADTree.this.m_trainInstances.attribute(this.attIndex).name();
        }

        /**
         * Describes the comparison for a branch. For an attribute with exactly two values
         * both branches are written as "= value"; otherwise branch 0 is "= value" and
         * the other branch is "!= value".
         */
        @Override
        public String comparisonString(int i) {
            final Attribute tested = LADTree.this.m_trainInstances.attribute(this.attIndex);
            if (tested.numValues() == 2) {
                final int shown;
                if (i == 0) {
                    shown = this.trueSplitValue;
                } else {
                    shown = (this.trueSplitValue == 0) ? 1 : 0;
                }
                return "= " + tested.value(shown);
            }
            final String operator = (i == 0) ? "= " : "!= ";
            return operator + tested.value(this.trueSplitValue);
        }

        /** @return true when the other splitter is a nominal split on the same attribute and value */
        @Override
        public boolean equalTo(Splitter splitter) {
            if (splitter instanceof TwoWayNominalSplit) {
                final TwoWayNominalSplit other = (TwoWayNominalSplit) splitter;
                if (this.attIndex != other.attIndex) {
                    return false;
                }
                return this.trueSplitValue == other.trueSplitValue;
            }
            return false;
        }

        @Override
        public void setChildForBranch(int i, PredictionNode predictionNode) {
            this.children[i] = predictionNode;
        }

        @Override
        public PredictionNode getChildForBranch(int i) {
            return this.children[i];
        }

        /** Returns a new split on the same attribute and value with clones of the attached prediction nodes. */
        @Override
        public Object clone() {
            final TwoWayNominalSplit duplicate = new TwoWayNominalSplit(this.attIndex, this.trueSplitValue);
            for (int branch = 0; branch < 2; branch++) {
                final PredictionNode child = this.children[branch];
                if (child != null) {
                    duplicate.setChildForBranch(branch, (PredictionNode) child.clone());
                }
            }
            return duplicate;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:weka/classifiers/trees/LADTree$TwoWayNumericSplit.class */
    /**
     * Two-way test on a numeric attribute: branch 0 holds instances whose value is
     * strictly below the split point, branch 1 holds the non-missing rest.
     */
    public class TwoWayNumericSplit extends Splitter implements Cloneable {
        /** Threshold; values strictly below it go down branch 0. */
        private double splitPoint;
        /** Prediction nodes for branch 0 and branch 1. */
        private PredictionNode[] children;

        /**
         * Creates a split with a known threshold.
         *
         * @param i index of the attribute to test
         * @param d the split point
         */
        public TwoWayNumericSplit(int i, double d) {
            super();
            this.attIndex = i;
            this.splitPoint = d;
            this.children = new PredictionNode[2];
        }

        /**
         * Creates a split whose threshold is chosen from the data by {@code findSplit}.
         * Note that {@code instances} is sorted in place on attribute {@code i} as a
         * side effect.
         *
         * @param i index of the attribute to test
         * @param instances the data used to pick the split point
         * @throws Exception if the split search fails
         */
        public TwoWayNumericSplit(int i, Instances instances) throws Exception {
            super();
            this.attIndex = i;
            this.splitPoint = findSplit(instances, this.attIndex);
            this.children = new PredictionNode[2];
        }

        @Override
        public int getNumOfBranches() {
            return 2;
        }

        /** @return -1 when the attribute is missing, 0 when the value is below the split point, else 1 */
        @Override
        public int branchInstanceGoesDown(Instance instance) {
            if (instance.isMissing(this.attIndex)) {
                return -1;
            }
            return instance.value(this.attIndex) < this.splitPoint ? 0 : 1;
        }

        /**
         * Collects references to the instances belonging to one branch.
         *
         * @param i -1 selects instances with the attribute missing, 0 selects values below
         *          the split point, any other number selects values at or above it
         * @param instances the instances to filter
         * @return a new ReferenceInstances holding the selected instances
         */
        @Override
        public Instances instancesDownBranch(int i, Instances instances) {
            ReferenceInstances selected = new ReferenceInstances(instances, 1);
            Enumeration all = instances.enumerateInstances();
            while (all.hasMoreElements()) {
                Instance candidate = (Instance) all.nextElement();
                boolean missing = candidate.isMissing(this.attIndex);
                boolean wanted;
                if (i == -1) {
                    wanted = missing;
                } else if (missing) {
                    wanted = false;
                } else {
                    double value = candidate.value(this.attIndex);
                    wanted = (i == 0) ? value < this.splitPoint : value >= this.splitPoint;
                }
                if (wanted) {
                    selected.addReference(candidate);
                }
            }
            return selected;
        }

        /** @return the name of the tested attribute, looked up in the training data */
        @Override
        public String attributeString() {
            return LADTree.this.m_trainInstances.attribute(this.attIndex).name();
        }

        /** @return "&lt; point" for branch 0 and "&gt;= point" otherwise, with three decimals */
        @Override
        public String comparisonString(int i) {
            return (i == 0 ? "< " : ">= ") + Utils.doubleToString(this.splitPoint, 3);
        }

        /** @return true when the other splitter is a numeric split on the same attribute and split point */
        @Override
        public boolean equalTo(Splitter splitter) {
            if (!(splitter instanceof TwoWayNumericSplit)) {
                return false;
            }
            TwoWayNumericSplit other = (TwoWayNumericSplit) splitter;
            return this.attIndex == other.attIndex && this.splitPoint == other.splitPoint;
        }

        @Override
        public void setChildForBranch(int i, PredictionNode predictionNode) {
            this.children[i] = predictionNode;
        }

        @Override
        public PredictionNode getChildForBranch(int i) {
            return this.children[i];
        }

        /** Returns a new split on the same attribute and threshold with clones of the attached prediction nodes. */
        @Override
        public Object clone() {
            TwoWayNumericSplit duplicate = new TwoWayNumericSplit(this.attIndex, this.splitPoint);
            if (this.children[0] != null) {
                duplicate.setChildForBranch(0, (PredictionNode) this.children[0].clone());
            }
            if (this.children[1] != null) {
                duplicate.setChildForBranch(1, (PredictionNode) this.children[1].clone());
            }
            return duplicate;
        }

        /**
         * Picks the midpoint between two adjacent distinct values that minimises
         * {@code ContingencyTables.entropyConditionedOnRows} over a three-row class
         * distribution (row 0: below the cut, row 1: at or above, row 2: missing).
         * Sorts {@code instances} in place on attribute {@code i}.
         *
         * @param instances the data to search
         * @param i index of the numeric attribute
         * @return the best midpoint, or 0.0 when no pair of distinct adjacent values exists
         * @throws Exception if the entropy computation fails
         */
        private double findSplit(Instances instances, int i) throws Exception {
            double bestSplit = 0.0d;
            double bestEntropy = Double.MAX_VALUE;
            int numMissing = 0;
            double[][] distribution = new double[3][instances.numClasses()];

            // Seed rows 1 and 2 with instance WEIGHTS. The sweep below moves
            // instance.weight() from row 1 to row 0, so seeding with plain counts
            // (as before) left the rows inconsistent, and possibly negative,
            // whenever a weight differed from 1.
            for (int idx = 0; idx < instances.numInstances(); idx++) {
                Instance current = instances.instance(idx);
                int row;
                if (current.isMissing(i)) {
                    row = 2;
                    numMissing++;
                } else {
                    row = 1;
                }
                distribution[row][(int) current.classValue()] += current.weight();
            }

            // NOTE(review): the loop bound assumes sort() places instances with a
            // missing value for the attribute last - confirm against Instances.sort.
            instances.sort(i);
            int lastCandidate = instances.numInstances() - (numMissing + 1);
            for (int idx = 0; idx < lastCandidate; idx++) {
                Instance current = instances.instance(idx);
                Instance next = instances.instance(idx + 1);
                int cls = (int) current.classValue();
                distribution[0][cls] += current.weight();
                distribution[1][cls] -= current.weight();
                // Only a strict increase between neighbours is a valid cut position.
                if (Utils.sm(current.value(i), next.value(i))) {
                    double cut = (current.value(i) + next.value(i)) / 2.0d;
                    double entropy = ContingencyTables.entropyConditionedOnRows(distribution);
                    if (Utils.sm(entropy, bestEntropy)) {
                        bestSplit = cut;
                        bestEntropy = entropy;
                    }
                }
            }
            return bestSplit;
        }
    }

    /**
     * Returns the description of this classifier followed by the text of its
     * technical reference.
     *
     * @return the description string
     */
    public String globalInfo() {
        final String reference = getTechnicalInformation().toString();
        final String summary = "Class for generating a multi-class alternating decision tree using the LogitBoost strategy. For more info, see\n\n";
        return summary + reference;
    }

    /**
     * Builds the bibliographic record (an in-proceedings entry) for the paper this
     * classifier is based on.
     *
     * @return the populated technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        final TechnicalInformation paper =
                new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        paper.setValue(
                TechnicalInformation.Field.AUTHOR,
                "Geoffrey Holmes and Bernhard Pfahringer and Richard Kirkby and Eibe Frank and Mark Hall");
        paper.setValue(
                TechnicalInformation.Field.TITLE,
                "Multiclass alternating decision trees");
        paper.setValue(TechnicalInformation.Field.BOOKTITLE, "ECML");
        paper.setValue(TechnicalInformation.Field.YEAR, "2001");
        paper.setValue(TechnicalInformation.Field.PAGES, "161-172");
        paper.setValue(TechnicalInformation.Field.PUBLISHER, "Springer");
        return paper;
    }

    /**
     * Prepares the classifier for boosting: clears the statistics counters, checks the
     * data, wraps every instance that has a class value in a LADInstance, creates
     * an all-zero root prediction node and pre-computes the candidate splitters.
     *
     * @param instances the training data
     * @throws Exception if the data contains string attributes or the class is not nominal
     */
    public void initClassifier(Instances instances) throws Exception {
        this.m_nodesExpanded = 0;
        this.m_examplesCounted = 0;
        this.m_lastAddedSplitNum = 0;
        this.m_numOfClasses = instances.numClasses();

        if (instances.checkForStringAttributes()) {
            throw new Exception("Can't handle string attributes!");
        }
        final boolean nominalClass = instances.classAttribute().isNominal();
        if (!nominalClass) {
            throw new Exception("Class must be nominal!");
        }

        this.m_trainInstances = new ReferenceInstances(instances, instances.numInstances());
        for (Enumeration source = instances.enumerateInstances(); source.hasMoreElements(); ) {
            final Instance raw = (Instance) source.nextElement();
            if (raw.classIsMissing()) {
                // Instances without a class value are left out of the training set.
                continue;
            }
            final LADInstance wrapped = new LADInstance(raw);
            this.m_trainInstances.addReference(wrapped);
            wrapped.setDataset(this.m_trainInstances);
        }

        final double[] zeroPrediction = new double[this.m_numOfClasses];
        this.m_root = new PredictionNode(zeroPrediction);
        generateStaticPotentialSplittersAndNumericIndices();
    }

    /**
     * Performs one boosting step by delegating to boost().
     *
     * @param i not used by this method
     * @throws Exception if boosting fails
     */
    public void next(int i) throws Exception {
        boost();
    }

    /**
     * Intentionally empty; this method performs no work.
     *
     * @throws Exception never thrown by this implementation
     */
    public void done() throws Exception {
    }

    /**
     * Runs one boosting iteration: searches the tree for the best test, creates a
     * prediction node for each branch of that test from the instances reaching it,
     * updates those instances' working vectors, and attaches the test under the chosen
     * insertion node.
     *
     * @throws Exception if no training data has been set or the search fails
     */
    private void boost() throws Exception {
        if (this.m_trainInstances == null) {
            throw new Exception("Trying to boost with no training data");
        }
        searchForBestTest();

        // The null check has to come before the debug output: the message below
        // dereferences the best splitter, which previously caused a
        // NullPointerException in debug mode whenever the search found nothing.
        if (this.m_search_bestSplitter == null) {
            return;
        }
        if (this.m_Debug) {
            System.out.println("Best split found: " + this.m_search_bestSplitter.getNumOfBranches() + "-way split on " + this.m_search_bestSplitter.attributeString() + "\nBestGain = " + this.m_search_smallestLeastSquares);
        }

        for (int i = 0; i < this.m_search_bestSplitter.getNumOfBranches(); i++) {
            Instances branchInstances = this.m_search_bestSplitter.instancesDownBranch(i, this.m_search_bestPathInstances);
            double[] predictionValues = calcPredictionValues(branchInstances);
            PredictionNode branchNode = new PredictionNode(predictionValues);
            updateWeights(branchInstances, predictionValues);
            this.m_search_bestSplitter.setChildForBranch(i, branchNode);
        }
        // addChild clones the splitter, or merges it into an equal existing child.
        this.m_search_bestInsertionNode.addChild(this.m_search_bestSplitter);
        if (this.m_Debug) {
            System.out.println("Tree is now:\n" + toString(this.m_root, 1) + "\n");
        }
        // Drop the reference to the path instances once the iteration is finished.
        this.m_search_bestPathInstances = null;
    }

    /**
     * Applies a prediction vector to every instance in the set. Each element is
     * expected to be a LADInstance and is cast accordingly.
     *
     * @param instances the instances to update
     * @param dArr per-class values added to each instance's accumulated scores
     */
    private void updateWeights(Instances instances, double[] dArr) {
        int position = 0;
        while (position < instances.numInstances()) {
            final LADInstance current = (LADInstance) instances.instance(position);
            current.updateWeights(dArr);
            position++;
        }
    }

    /**
     * Pre-computes the split candidates that do not depend on the current tree:
     * one TwoWayNominalSplit per value of every non-numeric, non-class attribute
     * (a single split on value 0 when the attribute has exactly two values), and the
     * list of numeric attribute indices whose split points are searched later.
     *
     * <p>The numeric indices are gathered in a primitive scratch array rather than
     * boxed with the deprecated {@code new Integer(int)} constructor and cast back.
     */
    private void generateStaticPotentialSplittersAndNumericIndices() {
        this.m_staticPotentialSplitters2way = new FastVector();

        int attributeCount = this.m_trainInstances.numAttributes();
        int classIndex = this.m_trainInstances.classIndex();
        int[] numericScratch = new int[attributeCount];
        int numericCount = 0;

        for (int att = 0; att < attributeCount; att++) {
            if (att == classIndex) {
                continue;
            }
            if (this.m_trainInstances.attribute(att).isNumeric()) {
                numericScratch[numericCount++] = att;
                continue;
            }
            int numValues = this.m_trainInstances.attribute(att).numValues();
            // A binary attribute needs only one test: "= value 0" versus the rest
            // already separates its two values.
            int splitCount = (numValues == 2) ? 1 : numValues;
            for (int value = 0; value < splitCount; value++) {
                this.m_staticPotentialSplitters2way.addElement(new TwoWayNominalSplit(att, value));
            }
        }

        this.m_numericAttIndices = new int[numericCount];
        System.arraycopy(numericScratch, 0, this.m_numericAttIndices, 0, numericCount);
    }

    /**
     * Starts a fresh search for the best test over the whole tree, beginning at the
     * root with the full training set.
     *
     * <p>All best-candidate fields are cleared first. Previously only the score
     * threshold was reset, so when an iteration found no candidate above the threshold
     * the splitter chosen in an earlier iteration stayed in
     * {@code m_search_bestSplitter}. boost() had already set
     * {@code m_search_bestPathInstances} to null, so it then called
     * {@code instancesDownBranch} with null instances and failed. With the reset,
     * an unsuccessful search leaves the best splitter null, which boost() treats as
     * "nothing to add".
     *
     * @throws Exception if evaluating a candidate fails
     */
    private void searchForBestTest() throws Exception {
        if (this.m_Debug) {
            System.out.println("Searching for best split...");
        }
        // Candidates must score strictly above this threshold to be accepted
        // (KStarConstants.FLOOR is presumably 0.0 - TODO confirm).
        this.m_search_smallestLeastSquares = KStarConstants.FLOOR;
        this.m_search_bestSplitter = null;
        this.m_search_bestInsertionNode = null;
        this.m_search_bestPathInstances = null;
        searchForBestTest(this.m_root, this.m_trainInstances);
    }

    /**
     * Evaluates every candidate test at one prediction node - first the pre-built
     * nominal splitters, then a numeric split per numeric attribute - and then
     * continues the search below each child splitter of the node.
     *
     * @param predictionNode the node at which candidates would be inserted
     * @param instances the instances that reach this node
     * @throws Exception if evaluating a candidate fails
     */
    private void searchForBestTest(PredictionNode predictionNode, Instances instances) throws Exception {
        // Search statistics.
        this.m_nodesExpanded += 1;
        this.m_examplesCounted += instances.numInstances();

        for (Enumeration nominal = this.m_staticPotentialSplitters2way.elements(); nominal.hasMoreElements(); ) {
            final Splitter candidate = (Splitter) nominal.nextElement();
            evaluateSplitter(candidate, predictionNode, instances);
        }

        int slot = 0;
        while (slot < this.m_numericAttIndices.length) {
            evaluateNumericSplit(predictionNode, instances, this.m_numericAttIndices[slot]);
            slot++;
        }

        final boolean hasChildren = predictionNode.getChildren().size() != 0;
        if (hasChildren) {
            goDownAllPaths(predictionNode, instances);
        }
    }

    /**
     * Continues the search into every branch of every splitter attached to the given
     * node, passing each branch the subset of instances that follows it.
     *
     * @param predictionNode the node whose child splitters are descended
     * @param instances the instances that reach {@code predictionNode}
     * @throws Exception if the search below a branch fails
     */
    private void goDownAllPaths(PredictionNode predictionNode, Instances instances) throws Exception {
        for (Enumeration splitters = predictionNode.children(); splitters.hasMoreElements(); ) {
            final Splitter test = (Splitter) splitters.nextElement();
            int branch = 0;
            while (branch < test.getNumOfBranches()) {
                final PredictionNode below = test.getChildForBranch(branch);
                final Instances reaching = test.instancesDownBranch(branch, instances);
                searchForBestTest(below, reaching);
                branch++;
            }
        }
    }

    /**
     * Scores a candidate splitter at a node and records it as the best candidate when
     * its score is strictly greater than the best seen so far. The score is the
     * least-squares value of the instances with a non-missing attribute value minus
     * the least-squares values of the instances in each branch.
     *
     * @param splitter the candidate test
     * @param predictionNode the node under which the candidate would be inserted
     * @param instances the instances that reach {@code predictionNode}
     * @throws Exception declared for callers; nothing in this method throws it explicitly
     */
    private void evaluateSplitter(Splitter splitter, PredictionNode predictionNode, Instances instances) throws Exception {
        double score = leastSquaresNonMissing(instances, splitter.attIndex);
        int branch = 0;
        while (branch < splitter.getNumOfBranches()) {
            score -= leastSquares(splitter.instancesDownBranch(branch, instances));
            branch++;
        }

        final boolean improved = score > this.m_search_smallestLeastSquares;

        if (this.m_Debug) {
            System.out.print(splitter.getNumOfBranches() + "-way split on " + splitter.attributeString() + " has leastSquares value of " + Utils.doubleToString(score, 3));
            if (improved) {
                System.out.print(" (best so far)");
            }
        }
        if (improved) {
            this.m_search_smallestLeastSquares = score;
            this.m_search_bestInsertionNode = predictionNode;
            this.m_search_bestSplitter = splitter;
            this.m_search_bestPathInstances = instances;
        }
        if (this.m_Debug) {
            System.out.print("\n");
        }
    }

    /**
     * Scores the best numeric split on one attribute at a node and records it as the
     * best candidate when its score is strictly greater than the best seen so far.
     * The split point and its least-squares value come from
     * findNumericSplitpointAndLS (element 0 and element 1 of the returned array).
     *
     * @param predictionNode the node under which the candidate would be inserted
     * @param instances the instances that reach {@code predictionNode}
     * @param i index of the numeric attribute
     */
    private void evaluateNumericSplit(PredictionNode predictionNode, Instances instances, int i) {
        final double[] pointAndLS = findNumericSplitpointAndLS(instances, i);
        final double score = leastSquaresNonMissing(instances, i) - pointAndLS[1];
        final boolean improved = score > this.m_search_smallestLeastSquares;

        if (this.m_Debug) {
            System.out.print("Numeric split on " + instances.attribute(i).name() + " has leastSquares value of " + Utils.doubleToString(score, 3));
            if (improved) {
                System.out.print(" (best so far)");
            }
        }
        if (improved) {
            this.m_search_smallestLeastSquares = score;
            this.m_search_bestInsertionNode = predictionNode;
            this.m_search_bestSplitter = new TwoWayNumericSplit(i, pointAndLS[0]);
            this.m_search_bestPathInstances = instances;
        }
        if (this.m_Debug) {
            System.out.print("\n");
        }
    }

    /**
     * Scans a numeric attribute for the split point with the smallest summed
     * weighted squared error over all classes.
     *
     * <p>For every class the method keeps running sums of {@code w*z*z},
     * {@code w*z} and {@code w} (taken from each instance's {@code wVector} and
     * {@code zVector}) for the instances on the left and on the right of the
     * candidate split. All instances start on the right; walking through the
     * data sorted on the attribute moves them to the left one at a time. A
     * candidate split is only evaluated between two neighbouring instances whose
     * attribute values differ, and the scan stops at the first missing value.
     *
     * <p>Side effect: {@code instances} is sorted on attribute {@code i}.
     *
     * @param instances the instances to split; every element must be a
     *                  {@code LADInstance}
     * @param i         index of the numeric attribute to split on
     * @return a two element array: {@code [0]} is the best split point (the
     *         midpoint between two neighbouring values, or {@code 0.0} when no
     *         candidate was found) and {@code [1]} is the error of that split,
     *         never smaller than {@code KStarConstants.FLOOR}
     *         ({@code Double.POSITIVE_INFINITY} when no candidate was found)
     */
    private double[] findNumericSplitpointAndLS(Instances instances, int i) {
        final boolean debug = this.m_Debug;
        final int numClasses = this.m_numOfClasses;
        // The error of the unsplit data is only reported in the debug trace, so it
        // is not computed otherwise. It is taken before the sort so that the
        // summation order (and therefore the printed figure) is unchanged.
        final double unsplitError = debug ? leastSquares(instances) : 0.0d;

        double[] sumWZZLeft = new double[numClasses];
        double[] sumWZLeft = new double[numClasses];
        double[] sumWLeft = new double[numClasses];
        double[] meanNumLeft = new double[numClasses];
        double[] sumWZZRight = new double[numClasses];
        double[] sumWZRight = new double[numClasses];
        double[] sumWRight = new double[numClasses];
        double[] meanNumRight = new double[numClasses];

        final int numInstances = instances.numInstances();

        // Start with every instance on the right-hand side of the split.
        for (int idx = 0; idx < numInstances; idx++) {
            LADInstance inst = (LADInstance) instances.instance(idx);
            for (int c = 0; c < numClasses; c++) {
                double wz = inst.wVector[c] * inst.zVector[c];
                sumWZZRight[c] += wz * inst.zVector[c];
                sumWZRight[c] += wz;
                sumWRight[c] += inst.wVector[c];
                meanNumRight[c] += wz;
            }
        }

        double bestError = Double.POSITIVE_INFINITY;
        double bestSplit = 0.0d;
        instances.sort(i);

        for (int idx = 0; idx < numInstances - 1; idx++) {
            Instance next = instances.instance(idx + 1);
            if (next.isMissing(i)) {
                // The scan ends at the first missing value among the sorted instances.
                break;
            }
            LADInstance current = (LADInstance) instances.instance(idx);
            // A split is only possible between two distinct attribute values.
            boolean isBoundary = next.value(i) > current.value(i);

            double splitError = 0.0d;
            for (int c = 0; c < numClasses; c++) {
                double w = current.wVector[c];
                double wz = w * current.zVector[c];
                double wzz = wz * current.zVector[c];

                // Move the current instance from the right side to the left side.
                sumWZZLeft[c] += wzz;
                sumWZLeft[c] += wz;
                sumWLeft[c] += w;
                sumWZZRight[c] -= wzz;
                sumWZRight[c] -= wz;
                sumWRight[c] -= w;
                meanNumLeft[c] += wz;
                meanNumRight[c] -= wz;

                if (isBoundary) {
                    double meanLeft = meanNumLeft[c] / sumWLeft[c];
                    double meanRight = meanNumRight[c] / sumWRight[c];
                    // sum(w*(z-m)^2) expanded as sum(w*z*z) - 2*m*sum(w*z) + m*m*sum(w),
                    // once for each side. The grouping of the additions is kept as is
                    // so that floating point results do not change.
                    splitError = splitError
                            + (sumWZZLeft[c] - ((2.0d * meanLeft) * sumWZLeft[c]))
                            + (meanLeft * meanLeft * sumWLeft[c])
                            + (sumWZZRight[c] - ((2.0d * meanRight) * sumWZRight[c]))
                            + (meanRight * meanRight * sumWRight[c]);
                }
            }

            if (isBoundary) {
                double midpoint = (current.value(i) + next.value(i)) / 2.0d;
                if (debug) {
                    System.out.println(i + "/" + midpoint + " = " + (unsplitError - splitError));
                }
                if (splitError < bestError) {
                    bestSplit = midpoint;
                    bestError = splitError;
                }
            }
        }

        double[] result = new double[2];
        result[0] = bestSplit;
        result[1] = bestError > KStarConstants.FLOOR ? bestError : KStarConstants.FLOOR;
        return result;
    }

    /**
     * Computes the weighted squared error of the given instances around the
     * per-class weighted mean of their {@code zVector} entries, using the
     * {@code wVector} entries as weights, summed over all classes.
     *
     * @param instances the instances to evaluate; every element must be a
     *                  {@code LADInstance}
     * @return the summed error, or {@code KStarConstants.FLOOR} when the sum is
     *         not greater than that constant (this includes a NaN sum)
     */
    private double leastSquares(Instances instances) {
        final int numClasses = this.m_numOfClasses;
        final int numInstances = instances.numInstances();
        double[] classMeans = new double[numClasses];
        double[] classWeights = new double[numClasses];

        // First pass: weighted sum of z and total weight for each class.
        for (int idx = 0; idx < numInstances; idx++) {
            LADInstance inst = (LADInstance) instances.instance(idx);
            for (int c = 0; c < numClasses; c++) {
                classMeans[c] += inst.zVector[c] * inst.wVector[c];
                classWeights[c] += inst.wVector[c];
            }
        }

        // Turn the sums into means; a class whose total weight equals the floor
        // constant keeps its raw sum instead of being divided.
        for (int c = 0; c < numClasses; c++) {
            if (classWeights[c] != KStarConstants.FLOOR) {
                classMeans[c] /= classWeights[c];
            }
        }

        // Second pass: accumulate w * (z - mean)^2.
        double error = 0.0d;
        for (int idx = 0; idx < numInstances; idx++) {
            LADInstance inst = (LADInstance) instances.instance(idx);
            for (int c = 0; c < numClasses; c++) {
                double deviation = inst.zVector[c] - classMeans[c];
                error += inst.wVector[c] * deviation * deviation;
            }
        }
        return error > KStarConstants.FLOOR ? error : KStarConstants.FLOOR;
    }

    /**
     * Computes the weighted squared error, summed over all classes, of those
     * instances whose value for attribute {@code i} is not missing.
     *
     * <p>The per-class weighted means are taken over <em>all</em> instances,
     * including the ones with a missing value; only the error accumulation
     * skips them.
     *
     * @param instances the instances to evaluate; every element must be a
     *                  {@code LADInstance}
     * @param i         index of the attribute whose missing values exclude an
     *                  instance from the error sum
     * @return the summed error, or {@code KStarConstants.FLOOR} when the sum is
     *         not greater than that constant (this includes a NaN sum)
     */
    private double leastSquaresNonMissing(Instances instances, int i) {
        final int numClasses = this.m_numOfClasses;
        final int numInstances = instances.numInstances();
        double[] classMeans = new double[numClasses];
        double[] classWeights = new double[numClasses];

        // First pass over every instance: weighted sum of z and total weight per class.
        for (int idx = 0; idx < numInstances; idx++) {
            LADInstance inst = (LADInstance) instances.instance(idx);
            for (int c = 0; c < numClasses; c++) {
                classMeans[c] += inst.zVector[c] * inst.wVector[c];
                classWeights[c] += inst.wVector[c];
            }
        }

        // Turn the sums into means; a class whose total weight equals the floor
        // constant keeps its raw sum instead of being divided.
        for (int c = 0; c < numClasses; c++) {
            if (classWeights[c] != KStarConstants.FLOOR) {
                classMeans[c] /= classWeights[c];
            }
        }

        // Second pass: accumulate w * (z - mean)^2 for instances with a known value.
        double error = 0.0d;
        for (int idx = 0; idx < numInstances; idx++) {
            LADInstance inst = (LADInstance) instances.instance(idx);
            if (inst.isMissing(i)) {
                continue;
            }
            for (int c = 0; c < numClasses; c++) {
                double deviation = inst.zVector[c] - classMeans[c];
                error += inst.wVector[c] * deviation * deviation;
            }
        }
        return error > KStarConstants.FLOOR ? error : KStarConstants.FLOOR;
    }

    /**
     * Computes one prediction value per class for the given instances.
     *
     * <p>For each class the weighted mean of the {@code zVector} entries is
     * computed (weights from {@code wVector}). Each mean is then centred on the
     * average of all class means and scaled by {@code (J - 1) / J}, where
     * {@code J} is the number of classes.
     *
     * @param instances the instances to summarise; every element must be a
     *                  {@code LADInstance}
     * @return an array with one prediction value per class
     */
    private double[] calcPredictionValues(Instances instances) {
        final int numClasses = this.m_numOfClasses;
        final int numInstances = instances.numInstances();
        double[] classMeans = new double[numClasses];
        double[] classWeights = new double[numClasses];
        // The division must be carried out in floating point: with two int operands
        // (J - 1) / J truncates to 0 and every prediction value would collapse to 0.
        double multiplier = (numClasses - 1) / (double) numClasses;

        // Weighted sum of z and total weight for each class.
        for (int idx = 0; idx < numInstances; idx++) {
            LADInstance inst = (LADInstance) instances.instance(idx);
            for (int c = 0; c < numClasses; c++) {
                classMeans[c] += inst.zVector[c] * inst.wVector[c];
                classWeights[c] += inst.wVector[c];
            }
        }

        // Turn the sums into means (a class whose total weight equals the floor
        // constant keeps its raw sum) and add them up for the centring step.
        double meansSum = 0.0d;
        for (int c = 0; c < numClasses; c++) {
            if (classWeights[c] != KStarConstants.FLOOR) {
                classMeans[c] /= classWeights[c];
            }
            meansSum += classMeans[c];
        }

        double averageMean = meansSum / numClasses;
        for (int c = 0; c < numClasses; c++) {
            classMeans[c] = multiplier * (classMeans[c] - averageMean);
        }
        return classMeans;
    }

    /**
     * Returns the class membership estimates for an instance.
     *
     * <p>The prediction values collected from the tree are exponentiated after
     * subtracting their maximum, then normalised to sum to one when the sum is
     * greater than {@code KStarConstants.FLOOR}.
     *
     * @param instance the instance to classify
     * @return an array with one estimate per class
     */
    @Override // weka.classifiers.Classifier
    public double[] distributionForInstance(Instance instance) {
        // A freshly allocated double array is already filled with zeros.
        double[] scores = predictionValuesForInstance(instance, this.m_root, new double[this.m_numOfClasses]);

        // Shift by the largest score before exponentiating.
        double largest = scores[Utils.maxIndex(scores)];
        for (int c = 0; c < this.m_numOfClasses; c++) {
            scores[c] = Math.exp(scores[c] - largest);
        }

        double total = Utils.sum(scores);
        if (total > KStarConstants.FLOOR) {
            Utils.normalize(scores, total);
        }
        return scores;
    }

    /**
     * Adds the values of a prediction node to the running per-class totals and
     * recurses into every splitter below it that routes the instance to a branch.
     *
     * @param instance       the instance being classified
     * @param predictionNode the node whose values are added
     * @param dArr           the running per-class totals; updated in place
     * @return the array of per-class totals after this subtree has been added
     */
    private double[] predictionValuesForInstance(Instance instance, PredictionNode predictionNode, double[] dArr) {
        double[] nodeValues = predictionNode.getValues();
        for (int c = 0; c < this.m_numOfClasses; c++) {
            dArr[c] += nodeValues[c];
        }
        for (Enumeration splitters = predictionNode.children(); splitters.hasMoreElements();) {
            Splitter splitter = (Splitter) splitters.nextElement();
            int branch = splitter.branchInstanceGoesDown(instance);
            if (branch < 0) {
                // A negative branch means the splitter sends the instance nowhere.
                continue;
            }
            dArr = predictionValuesForInstance(instance, splitter.getChildForBranch(branch), dArr);
        }
        return dArr;
    }

    /**
     * Returns a textual description of the classifier: the tree itself, the
     * class legend and a few size and search statistics.
     *
     * @return the description, or a "not built yet" message when there is no tree
     */
    @Override
    public String toString() {
        String name = getClass().getName();
        if (this.m_root == null) {
            return name + " not built yet";
        }
        // Divide in floating point: an int division would truncate the ratio and
        // throw ArithmeticException while no node has been expanded yet.
        double examplesPerNode = (double) this.m_examplesCounted / (double) this.m_nodesExpanded;
        return name + ":\n\n" + toString(this.m_root, 1)
                + "\nLegend: " + legend()
                + "\n#Tree size (total): " + numOfAllNodes(this.m_root)
                + "\n#Tree size (number of predictor nodes): " + numOfPredictionNodes(this.m_root)
                + "\n#Leaves (number of predictor nodes): " + numOfLeafNodes(this.m_root)
                + "\n#Expanded nodes: " + this.m_nodesExpanded
                + "\n#Processed examples: " + this.m_examplesCounted
                + "\n#Ratio e/n: " + examplesPerNode;
    }

    /**
     * Renders a prediction node and everything below it as indented text.
     *
     * @param predictionNode the node to render
     * @param i              the indentation level used for the node's children
     * @return the text for this subtree
     */
    private String toString(PredictionNode predictionNode, int i) {
        StringBuilder text = new StringBuilder(": ");

        // The node's values, comma separated.
        double[] nodeValues = predictionNode.getValues();
        for (int c = 0; c < this.m_numOfClasses; c++) {
            if (c > 0) {
                text.append(",");
            }
            text.append(Utils.doubleToString(nodeValues[c], 3));
        }

        // One line per existing branch of every splitter, followed by its subtree.
        for (Enumeration splitters = predictionNode.children(); splitters.hasMoreElements();) {
            Splitter splitter = (Splitter) splitters.nextElement();
            int numBranches = splitter.getNumOfBranches();
            for (int branch = 0; branch < numBranches; branch++) {
                PredictionNode child = splitter.getChildForBranch(branch);
                if (child == null) {
                    continue;
                }
                text.append("\n");
                for (int level = 0; level < i; level++) {
                    text.append("|  ");
                }
                text.append("(").append(splitter.orderAdded).append(")");
                text.append(splitter.attributeString())
                        .append(TestInstances.DEFAULT_SEPARATORS)
                        .append(splitter.comparisonString(branch));
                text.append(toString(child, i + 1));
            }
        }
        return text.toString();
    }

    /**
     * Returns the tree as a graph description in dot syntax.
     *
     * @return the complete {@code digraph} text
     * @throws Exception if the traversal of the tree fails
     */
    @Override // weka.core.Drawable
    public String graph() throws Exception {
        StringBuffer dot = new StringBuffer("digraph ADTree {\n");
        graphTraverse(this.m_root, dot, 0, 0);
        dot.append("}\n");
        return dot.toString();
    }

    /**
     * Appends the dot description of a prediction node, its splitters and all
     * of their existing children to the given buffer.
     *
     * @param predictionNode the node to describe
     * @param stringBuffer   the buffer the dot text is appended to
     * @param i              the order number of the splitter above this node
     *                       (0 for the node the traversal starts at)
     * @param i2             the branch of that splitter which leads to this node
     * @throws Exception if the traversal fails
     */
    protected void graphTraverse(PredictionNode predictionNode, StringBuffer stringBuffer, int i, int i2) throws Exception {
        String nodeId = "S" + i + "P" + i2;

        // The prediction node itself, labelled with its comma separated values.
        stringBuffer.append(nodeId + " [label=\"");
        double[] nodeValues = predictionNode.getValues();
        for (int c = 0; c < this.m_numOfClasses; c++) {
            if (c > 0) {
                stringBuffer.append(",");
            }
            stringBuffer.append(Utils.doubleToString(nodeValues[c], 3));
        }
        if (i == 0) {
            // Only nodes reached with splitter number 0 carry the legend.
            stringBuffer.append(" (" + legend() + ")");
        }
        stringBuffer.append("\" shape=box style=filled]\n");

        for (Enumeration splitters = predictionNode.children(); splitters.hasMoreElements();) {
            Splitter splitter = (Splitter) splitters.nextElement();
            String splitterId = "S" + splitter.orderAdded;
            stringBuffer.append(nodeId + "->" + splitterId + " [style=dotted]\n");
            stringBuffer.append(splitterId + " [label=\"" + splitter.orderAdded + ": " + splitter.attributeString() + "\"]\n");

            int numBranches = splitter.getNumOfBranches();
            for (int branch = 0; branch < numBranches; branch++) {
                PredictionNode child = splitter.getChildForBranch(branch);
                if (child == null) {
                    continue;
                }
                stringBuffer.append(splitterId + "->" + splitterId + "P" + branch + " [label=\"" + splitter.comparisonString(branch) + "\"]\n");
                graphTraverse(child, stringBuffer, splitter.orderAdded, branch);
            }
        }
    }

    /**
     * Returns a legend that names the class values in the order in which the
     * prediction values of a node are printed.
     *
     * @return the legend, or an empty string when there is no training data or
     *         its class attribute cannot be obtained
     */
    public String legend() {
        if (this.m_trainInstances == null) {
            return "";
        }
        Attribute classAttribute;
        try {
            classAttribute = this.m_trainInstances.classAttribute();
        } catch (Exception e) {
            // Without a class attribute there is nothing to describe. Carrying on
            // would only end in a NullPointerException below.
            return "";
        }
        if (classAttribute == null) {
            return "";
        }
        if (this.m_numOfClasses == 1) {
            // NOTE(review): this branch reads two class values although only one
            // class is counted - confirm that it can be reached with a class
            // attribute that really has a second value.
            return "-ve = " + classAttribute.value(0) + ", +ve = " + classAttribute.value(1);
        }
        StringBuilder text = new StringBuilder();
        for (int c = 0; c < this.m_numOfClasses; c++) {
            if (c > 0) {
                text.append(", ");
            }
            text.append(classAttribute.value(c));
        }
        return text.toString();
    }

    /**
     * Returns the tip text for the number-of-boosting-iterations property.
     *
     * @return the tip text
     */
    public String numOfBoostingIterationsTipText() {
        final String tip = "The number of boosting iterations to use, which determines the size of the tree.";
        return tip;
    }

    /**
     * Returns the number of boosting iterations that building the classifier performs.
     *
     * @return the configured number of boosting iterations
     */
    public int getNumOfBoostingIterations() {
        final int iterations = this.m_boostingIterations;
        return iterations;
    }

    /**
     * Sets the number of boosting iterations that building the classifier performs.
     * The value is stored as given, without any range check.
     *
     * @param i the number of boosting iterations
     */
    public void setNumOfBoostingIterations(int i) {
        m_boostingIterations = i;
    }

    /**
     * Lists the available command line options: the {@code -B} option of this
     * class followed by everything the superclass lists.
     *
     * @return an enumeration of the available options
     */
    @Override // weka.classifiers.Classifier, weka.core.OptionHandler
    public Enumeration listOptions() {
        Vector options = new Vector(1);
        options.addElement(new Option("\tNumber of boosting iterations.\n\t(Default = 10)", "B", 1, "-B <number of boosting iterations>"));
        for (Enumeration inherited = super.listOptions(); inherited.hasMoreElements();) {
            options.addElement(inherited.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses the given options. {@code -B <n>} sets the number of boosting
     * iterations; the remaining options are handed to the superclass, and
     * anything left over after that is checked by
     * {@code Utils.checkForRemainingOptions}.
     *
     * @param strArr the options to parse
     * @throws Exception if an option is not supported or its value is invalid
     */
    @Override // weka.classifiers.Classifier, weka.core.OptionHandler
    public void setOptions(String[] strArr) throws Exception {
        String iterations = Utils.getOption('B', strArr);
        boolean iterationsGiven = iterations.length() > 0;
        if (iterationsGiven) {
            int parsed = Integer.parseInt(iterations);
            setNumOfBoostingIterations(parsed);
        }
        super.setOptions(strArr);
        Utils.checkForRemainingOptions(strArr);
    }

    /**
     * Returns the current option settings: {@code -B} with the number of
     * boosting iterations, followed by the options of the superclass.
     *
     * @return the options as an array that can be passed to {@code setOptions}
     */
    @Override // weka.classifiers.Classifier, weka.core.OptionHandler
    public String[] getOptions() {
        // Query the superclass once, so that the array length and the copied
        // content are guaranteed to come from the same result.
        String[] inherited = super.getOptions();
        String[] options = new String[2 + inherited.length];
        options[0] = "-B";
        options[1] = String.valueOf(getNumOfBoostingIterations());
        // The inherited options fill the rest of the array exactly, so no padding
        // is needed. Padding from index 2 onwards, as was done before, replaced
        // every inherited option with an empty string.
        System.arraycopy(inherited, 0, options, 2, inherited.length);
        return options;
    }

    /**
     * Returns the total number of nodes in the tree, as counted by
     * {@code numOfAllNodes}.
     *
     * @return the node count as a double
     */
    public double measureTreeSize() {
        int totalNodes = numOfAllNodes(this.m_root);
        return totalNodes;
    }

    /**
     * Returns the number of prediction nodes in the tree, as counted by
     * {@code numOfPredictionNodes}. Despite the method name this is not
     * restricted to leaves.
     *
     * @return the prediction node count as a double
     */
    public double measureNumLeaves() {
        int predictionNodes = numOfPredictionNodes(this.m_root);
        return predictionNodes;
    }

    /**
     * Returns the number of prediction nodes without children, as counted by
     * {@code numOfLeafNodes}.
     *
     * @return the leaf count as a double
     */
    public double measureNumPredictionLeaves() {
        int leaves = numOfLeafNodes(this.m_root);
        return leaves;
    }

    /**
     * Returns the current value of the expanded-nodes counter.
     *
     * @return the number of expanded nodes as a double
     */
    public double measureNodesExpanded() {
        int expanded = this.m_nodesExpanded;
        return expanded;
    }

    /**
     * Returns the current value of the processed-examples counter.
     *
     * @return the number of counted examples as a double
     */
    public double measureExamplesCounted() {
        int counted = this.m_examplesCounted;
        return counted;
    }

    /**
     * Lists the names of the additional measures this classifier provides.
     *
     * @return an enumeration of the measure names
     */
    @Override // weka.core.AdditionalMeasureProducer
    public Enumeration enumerateMeasures() {
        String[] measureNames = {
            "measureTreeSize",
            "measureNumLeaves",
            "measureNumPredictionLeaves",
            "measureNodesExpanded",
            "measureExamplesCounted",
        };
        Vector measures = new Vector(5);
        for (int k = 0; k < measureNames.length; k++) {
            measures.addElement(measureNames[k]);
        }
        return measures.elements();
    }

    /**
     * Returns the value of the named additional measure.
     *
     * @param str the name of the measure, exactly as listed by
     *            {@code enumerateMeasures} (the comparison is case sensitive)
     * @return the value of the measure
     * @throws IllegalArgumentException if the name is not one of the supported measures
     */
    @Override // weka.core.AdditionalMeasureProducer
    public double getMeasure(String str) {
        final double value;
        if (str.equals("measureTreeSize")) {
            value = measureTreeSize();
        } else if (str.equals("measureNodesExpanded")) {
            value = measureNodesExpanded();
        } else if (str.equals("measureNumLeaves")) {
            value = measureNumLeaves();
        } else if (str.equals("measureNumPredictionLeaves")) {
            value = measureNumPredictionLeaves();
        } else if (str.equals("measureExamplesCounted")) {
            value = measureExamplesCounted();
        } else {
            throw new IllegalArgumentException(str + " not supported (ADTree)");
        }
        return value;
    }

    /**
     * Counts the prediction nodes in the subtree rooted at the given node,
     * the node itself included.
     *
     * @param predictionNode the root of the subtree; may be {@code null}
     * @return the number of prediction nodes, 0 for a {@code null} node
     */
    protected int numOfPredictionNodes(PredictionNode predictionNode) {
        if (predictionNode == null) {
            return 0;
        }
        int count = 1; // the node itself
        for (Enumeration splitters = predictionNode.children(); splitters.hasMoreElements();) {
            Splitter splitter = (Splitter) splitters.nextElement();
            for (int branch = 0; branch < splitter.getNumOfBranches(); branch++) {
                count += numOfPredictionNodes(splitter.getChildForBranch(branch));
            }
        }
        return count;
    }

    /**
     * Counts the prediction nodes without children in the subtree rooted at the
     * given node.
     *
     * @param predictionNode the root of the subtree; may be {@code null}
     * @return the number of leaves, 0 for a {@code null} node
     */
    protected int numOfLeafNodes(PredictionNode predictionNode) {
        if (predictionNode == null) {
            // A branch without a child, or a tree that has not been built, holds no
            // leaves. This matches how the other node counters treat null.
            return 0;
        }
        if (predictionNode.getChildren().size() == 0) {
            return 1;
        }
        int leaves = 0;
        for (Enumeration splitters = predictionNode.children(); splitters.hasMoreElements();) {
            Splitter splitter = (Splitter) splitters.nextElement();
            for (int branch = 0; branch < splitter.getNumOfBranches(); branch++) {
                leaves += numOfLeafNodes(splitter.getChildForBranch(branch));
            }
        }
        return leaves;
    }

    /**
     * Counts all nodes, prediction nodes and splitters alike, in the subtree
     * rooted at the given node, the node itself included.
     *
     * @param predictionNode the root of the subtree; may be {@code null}
     * @return the number of nodes, 0 for a {@code null} node
     */
    protected int numOfAllNodes(PredictionNode predictionNode) {
        if (predictionNode == null) {
            return 0;
        }
        int count = 1; // the prediction node itself
        for (Enumeration splitters = predictionNode.children(); splitters.hasMoreElements();) {
            Splitter splitter = (Splitter) splitters.nextElement();
            count++; // the splitter
            for (int branch = 0; branch < splitter.getNumOfBranches(); branch++) {
                count += numOfAllNodes(splitter.getChildForBranch(branch));
            }
        }
        return count;
    }

    /**
     * Builds the classifier: initialises it from the training data and then
     * performs the configured number of boosting iterations.
     *
     * @param instances the training data
     * @throws Exception if initialisation or a boosting iteration fails
     */
    @Override // weka.classifiers.Classifier
    public void buildClassifier(Instances instances) throws Exception {
        initClassifier(instances);
        int iteration = 0;
        while (iteration < this.m_boostingIterations) {
            boost();
            iteration++;
        }
    }

    /**
     * Counts the instances that the classifier does not predict correctly.
     * An instance whose classification throws an exception counts as an error.
     *
     * @param instances the instances to test, visited from last to first
     * @return the number of misclassified instances
     */
    public int predictiveError(Instances instances) {
        int errors = 0;
        int idx = instances.numInstances();
        while (--idx >= 0) {
            Instance current = instances.instance(idx);
            boolean wrong;
            try {
                wrong = classifyInstance(current) != current.classValue();
            } catch (Exception e) {
                // A failed classification is deliberately counted as a wrong prediction.
                wrong = true;
            }
            if (wrong) {
                errors++;
            }
        }
        return errors;
    }

    /**
     * Merges the tree of another classifier into the tree of this one.
     *
     * @param lADTree the classifier whose tree is merged into this one
     * @throws Exception if either tree has no root yet, or if the two
     *                   classifiers work with a different number of classes
     */
    public void merge(LADTree lADTree) throws Exception {
        boolean bothBuilt = this.m_root != null && lADTree.m_root != null;
        if (!bothBuilt) {
            throw new Exception("Trying to merge an uninitialized tree");
        }
        boolean sameClassCount = this.m_numOfClasses == lADTree.m_numOfClasses;
        if (!sameClassCount) {
            throw new Exception("Trees not suitable for merge - different sized prediction nodes");
        }
        this.m_root.merge(lADTree.m_root);
    }

    /**
     * Returns the type of graph this classifier produces.
     *
     * @return the graph type code 1 (presumably the tree constant of
     *         {@code Drawable} - confirm against that interface)
     */
    @Override // weka.core.Drawable
    public int graphType() {
        final int graphTypeCode = 1;
        return graphTypeCode;
    }

    /**
     * Returns the revision string of this class.
     *
     * @return the revision as extracted by {@code RevisionUtils.extract}
     */
    @Override // weka.classifiers.Classifier, weka.core.RevisionHandler
    public String getRevision() {
        final String revisionKeyword = "$Revision: 6036 $";
        return RevisionUtils.extract(revisionKeyword);
    }

    /**
     * Returns the capabilities of this classifier: nominal, numeric and date
     * attributes, missing attribute values, a nominal class and missing class
     * values. Everything else is disabled.
     *
     * @return the capabilities of this classifier
     */
    @Override // weka.classifiers.Classifier, weka.core.CapabilitiesHandler
    public Capabilities getCapabilities() {
        Capabilities.Capability[] supported = {
            Capabilities.Capability.NOMINAL_ATTRIBUTES,
            Capabilities.Capability.NUMERIC_ATTRIBUTES,
            Capabilities.Capability.DATE_ATTRIBUTES,
            Capabilities.Capability.MISSING_VALUES,
            Capabilities.Capability.NOMINAL_CLASS,
            Capabilities.Capability.MISSING_CLASS_VALUES,
        };
        Capabilities result = super.getCapabilities();
        // Start from nothing, then switch on only what is supported.
        result.disableAll();
        for (int k = 0; k < supported.length; k++) {
            result.enable(supported[k]);
        }
        return result;
    }

    /**
     * Command line entry point: hands a new {@code LADTree} and the given
     * arguments to {@code runClassifier}.
     *
     * @param strArr the command line arguments
     */
    public static void main(String[] strArr) {
        LADTree classifier = new LADTree();
        runClassifier(classifier, strArr);
    }
}
