Class CART

java.lang.Object
smile.base.cart.CART
All Implemented Interfaces:
Serializable, SHAP<Tuple>
Direct Known Subclasses:
DecisionTree, RegressionTree

public abstract class CART extends Object implements SHAP<Tuple>, Serializable
Classification and regression tree.
  • Field Summary

    Fields
    Modifier and Type
    Field
    Description
    protected Formula
    formula
    The model formula.
    protected double[]
    importance
    Variable importance.
    protected int[]
    index
    An index of samples to their original locations in the training dataset.
    protected int
    maxDepth
    The maximum depth of the tree.
    protected int
    maxNodes
    The maximum number of leaf nodes in the tree.
    protected int
    mtry
    The number of input variables used to determine the decision at a node of the tree.
    protected int
    nodeSize
    The number of instances in a node below which the tree will not split. Setting nodeSize = 5 generally gives good results.
    protected int[][]
    order
    An index of training values.
    protected StructField
    response
    The schema of the response variable.
    protected Node
    root
    The root of the decision tree.
    protected int[]
    samples
    The samples for training this node.
    protected StructType
    schema
    The schema of predictors.
    protected DataFrame
    x
    The training data.
  • Constructor Summary

    Constructors
    Constructor
    Description
    CART(DataFrame x, StructField y, int maxDepth, int maxNodes, int nodeSize, int mtry, int[] samples, int[][] order)
    Constructor.
    CART(Formula formula, StructType schema, StructField response, Node root, double[] importance)
    Constructor.
  • Method Summary

    Modifier and Type
    Method
    Description
    protected void
    clear()
    Clears the workspace used for building the tree.
    String
    dot()
    Returns the graphic representation in Graphviz dot format.
    protected abstract Optional<Split>
    findBestSplit(LeafNode node, int column, double impurity, int lo, int hi)
    Finds the best split for the given column.
    protected Optional<Split>
    findBestSplit(LeafNode node, int lo, int hi, boolean[] unsplittable)
    Finds the best attribute to split on for a set of samples.
    double[]
    importance()
    Returns the variable importance.
    protected abstract double
    impurity(LeafNode node)
    Returns the impurity of a node.
    protected abstract LeafNode
    newNode(int[] nodeSamples)
    Creates a new leaf node.
    static int[][]
    order(DataFrame x)
    Returns the index of ordered samples for each ordinal column.
    protected Tuple
    predictors(Tuple x)
    Returns the predictors extracted by the model formula if it is not null.
    Node
    root()
    Returns the root node.
    double[]
    shap(DataFrame data)
    Returns the average of absolute SHAP values over a data frame.
    double[]
    shap(Tuple x)
    Returns the SHAP values.
    int
    size()
    Returns the number of nodes in the tree.
    protected boolean
    split(Split split, PriorityQueue<Split> queue)
    Splits a node into two child nodes.
    String
    toString()
    Returns a text representation of the tree in R's rpart format.

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

    Methods inherited from interface smile.feature.importance.SHAP

    shap
  • Field Details

    • formula

      protected Formula formula
      The model formula.
    • schema

      protected StructType schema
      The schema of predictors.
    • response

      protected StructField response
      The schema of response variable.
    • root

      protected Node root
      The root of decision tree.
    • maxDepth

      protected int maxDepth
      The maximum depth of the tree.
    • maxNodes

      protected int maxNodes
      The maximum number of leaf nodes in the tree.
    • nodeSize

      protected int nodeSize
      The number of instances in a node below which the tree will not split. Setting nodeSize = 5 generally gives good results.
    • mtry

      protected int mtry
      The number of input variables to be used to determine the decision at a node of the tree.
    • importance

      protected double[] importance
      Variable importance. Every time a node is split on a variable, the impurity criterion (Gini, information gain, etc.) for the two descendant nodes is less than that of the parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of variable importance.
    • x

      protected transient DataFrame x
      The training data.
    • samples

      protected transient int[] samples
      The samples for training this node. Note that samples[i] is the number of times dataset[i] is sampled. 0 means the datum is not included, and values greater than 1 are possible because of sampling with replacement.
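The sampling-with-replacement semantics described above can be sketched in plain Java. This is an illustrative example, not part of the smile API; the class and method names are hypothetical.

```java
import java.util.Random;

// Sketch of building a bootstrap count array with the same semantics
// as the samples field: samples[i] is the number of times instance i
// was drawn with replacement. A count of 0 means the instance is
// excluded; counts greater than 1 are possible.
public class BootstrapCounts {
    public static int[] bootstrap(int n, Random rng) {
        int[] samples = new int[n];
        for (int i = 0; i < n; i++) {
            samples[rng.nextInt(n)]++;   // draw with replacement
        }
        return samples;
    }

    public static void main(String[] args) {
        int[] samples = bootstrap(10, new Random(42));
        int total = 0;
        for (int count : samples) total += count;
        System.out.println("total draws = " + total); // always n = 10
    }
}
```

Whatever the counts, they always sum to n, since exactly n draws are made.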
    • index

      protected transient int[] index
      An index of samples to their original locations in the training dataset.
    • order

      protected transient int[][] order
      An index of training values. Initially, order[i] is a set of indices that iterate through the training values for attribute i in ascending order. During training, the array is rearranged so that all values for each leaf node occupy a contiguous range, but within that range they maintain the original ordering. Note that only numeric attributes will be sorted; non-numeric attributes will have a null in the corresponding place in the array.
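The initial state of the order array can be sketched with plain Java over a numeric matrix. This is an illustrative example, not the smile implementation; the class and method names are hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

// Sketch of the initial order index: for each numeric column j,
// order[j] lists the row indices so that the column's values are
// visited in ascending order. In CART, non-numeric columns would
// hold null instead.
public class OrderIndex {
    public static int[][] order(double[][] data) {
        int n = data.length;
        int p = data[0].length;
        int[][] order = new int[p][];
        for (int j = 0; j < p; j++) {
            final int col = j;
            Integer[] idx = new Integer[n];
            for (int i = 0; i < n; i++) idx[i] = i;
            // sort row indices by the value in this column
            Arrays.sort(idx, Comparator.comparingDouble(i -> data[i][col]));
            order[j] = Arrays.stream(idx).mapToInt(Integer::intValue).toArray();
        }
        return order;
    }

    public static void main(String[] args) {
        double[][] data = {{3.0, 1.0}, {1.0, 2.0}, {2.0, 0.5}};
        System.out.println(Arrays.toString(order(data)[0])); // [1, 2, 0]
    }
}
```

Precomputing these indices once lets the training loop scan each column in sorted order without re-sorting at every node.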
  • Constructor Details

    • CART

      public CART(Formula formula, StructType schema, StructField response, Node root, double[] importance)
      Constructor.
      Parameters:
      formula - The model formula.
      schema - The data schema of predictors.
      response - The response variable.
      root - The root node.
      importance - The feature importance.
    • CART

      public CART(DataFrame x, StructField y, int maxDepth, int maxNodes, int nodeSize, int mtry, int[] samples, int[][] order)
      Constructor.
      Parameters:
      x - the data frame of the explanatory variables.
      y - the response variable.
      maxDepth - the maximum depth of the tree.
      maxNodes - the maximum number of leaf nodes in the tree.
      nodeSize - the minimum size of leaf nodes.
      mtry - the number of input variables to pick to split on at each node. sqrt(p) generally gives good performance, where p is the number of variables.
      samples - the sample set of instances for stochastic learning. samples[i] is the number of times instance i is sampled.
      order - the index of training values in ascending order. Note that only numeric attributes need be sorted.
  • Method Details

    • size

      public int size()
      Returns the number of nodes in the tree.
      Returns:
      the number of nodes in the tree.
    • order

      public static int[][] order(DataFrame x)
      Returns the index of ordered samples for each ordinal column.
      Parameters:
      x - the predictors.
      Returns:
      the index of ordered samples for each ordinal column.
    • predictors

      protected Tuple predictors(Tuple x)
      Returns the predictors extracted by the model formula if it is not null. Otherwise, returns the input tuple.
      Parameters:
      x - the input tuple.
      Returns:
      the predictors.
    • clear

      protected void clear()
      Clears the workspace used for building the tree.
    • split

      protected boolean split(Split split, PriorityQueue<Split> queue)
      Splits a node into two child nodes.
      Parameters:
      split - the split candidate.
      queue - the queue of splits.
      Returns:
      true if the split succeeds.
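The interplay between split and its priority queue can be sketched as a best-first growth loop. This is an illustrative example, not the smile implementation; the class, method, and field names are hypothetical.

```java
import java.util.PriorityQueue;

// Sketch of best-first tree growing with a priority queue of candidate
// splits, which is how a maxNodes budget is typically enforced: always
// apply the highest-gain pending split until the leaf budget runs out.
public class BestFirstGrowth {
    static final class Candidate {
        final String node;
        final double gain;  // impurity reduction of this candidate split
        Candidate(String node, double gain) {
            this.node = node;
            this.gain = gain;
        }
    }

    public static int grow(PriorityQueue<Candidate> queue, int maxNodes) {
        int leaves = 1;                     // start with the root as a leaf
        while (leaves < maxNodes && !queue.isEmpty()) {
            Candidate best = queue.poll();  // highest impurity reduction first
            leaves++;                       // one split: one leaf becomes two
            // a real implementation would push the children's best splits here
        }
        return leaves;
    }

    public static void main(String[] args) {
        PriorityQueue<Candidate> queue = new PriorityQueue<>(
                (a, b) -> Double.compare(b.gain, a.gain));
        queue.add(new Candidate("root", 0.9));
        queue.add(new Candidate("left", 0.4));
        queue.add(new Candidate("right", 0.6));
        System.out.println(grow(queue, 3)); // 3
    }
}
```

Because the queue is ordered by impurity reduction, the maxNodes budget is always spent on the most valuable splits first.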
    • findBestSplit

      protected Optional<Split> findBestSplit(LeafNode node, int lo, int hi, boolean[] unsplittable)
      Finds the best attribute to split on for a set of samples at the current node. Returns Optional.empty if no split exists that reduces the impurity.
      Parameters:
      node - the leaf node to split.
      lo - the inclusive lower bound of the data partition in the reordered sample index array.
      hi - the exclusive upper bound of the data partition in the reordered sample index array.
      unsplittable - unsplittable[j] is true if the column j cannot be split further in the node.
      Returns:
      the best split candidate.
    • impurity

      protected abstract double impurity(LeafNode node)
      Returns the impurity of node.
      Parameters:
      node - the node to calculate the impurity.
      Returns:
      the impurity of node.
    • newNode

      protected abstract LeafNode newNode(int[] nodeSamples)
      Creates a new leaf node.
      Parameters:
      nodeSamples - the samples belonging to this node.
      Returns:
      the new leaf node.
    • findBestSplit

      protected abstract Optional<Split> findBestSplit(LeafNode node, int column, double impurity, int lo, int hi)
      Finds the best split for given column.
      Parameters:
      node - the node to split.
      column - the column to split on.
      impurity - the impurity of node.
      lo - the lower bound of sample index in the node.
      hi - the upper bound of sample index in the node.
      Returns:
      the best split.
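The per-column scan that subclasses implement can be sketched for the regression case. This is an illustrative example, not the smile implementation; the class and method names are hypothetical, and it uses sum-of-squared-errors reduction where a classification tree would use Gini or information gain.

```java
// Sketch of scanning a single sorted numeric column for the split
// threshold that maximally reduces the sum of squared errors, the
// regression analogue of findBestSplit on one column.
public class BestSplitScan {
    /** Returns the best threshold, or NaN if the column cannot be split. */
    public static double bestThreshold(double[] x, double[] y) {
        // x must be sorted ascending; y are the matching targets
        int n = x.length;
        double totalSum = 0;
        for (double v : y) totalSum += v;

        double bestScore = Double.NEGATIVE_INFINITY;
        double bestThreshold = Double.NaN;
        double leftSum = 0;
        for (int i = 0; i < n - 1; i++) {
            leftSum += y[i];
            if (x[i] == x[i + 1]) continue;  // cannot split between equal values
            double rightSum = totalSum - leftSum;
            int leftCount = i + 1;
            int rightCount = n - leftCount;
            // Maximizing sum of (group sum)^2 / (group count) is equivalent
            // to minimizing the within-group sum of squared errors.
            double score = leftSum * leftSum / leftCount
                         + rightSum * rightSum / rightCount;
            if (score > bestScore) {
                bestScore = score;
                bestThreshold = (x[i] + x[i + 1]) / 2;
            }
        }
        return bestThreshold;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 10, 11, 12};
        double[] y = {0, 0, 0, 5, 5, 5};
        System.out.println(bestThreshold(x, y)); // 6.5
    }
}
```

A single left-to-right pass with running sums makes the scan O(n) per column once the values are sorted, which is why CART keeps the precomputed order index.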
    • importance

      public double[] importance()
      Returns the variable importance. Every time a node is split on a variable, the impurity criterion (Gini, information gain, etc.) for the two descendant nodes is less than that of the parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of variable importance.
      Returns:
      the variable importance
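The bookkeeping described above can be sketched in plain Java. This is an illustrative example, not the smile implementation; the class and method names are hypothetical.

```java
// Sketch of accumulating impurity decreases per variable: each split
// on variable j adds that split's weighted impurity decrease to
// importance[j]; larger totals indicate more useful variables.
public class ImportanceAccumulator {
    private final double[] importance;

    public ImportanceAccumulator(int p) {
        importance = new double[p];
    }

    /** Records a split on the given variable. */
    public void onSplit(int variable, double parentImpurity,
                        double leftImpurity, double rightImpurity,
                        int leftSize, int rightSize) {
        int n = leftSize + rightSize;
        double weighted = (leftSize * leftImpurity + rightSize * rightImpurity) / n;
        importance[variable] += parentImpurity - weighted; // the decrease
    }

    public double[] importance() {
        return importance.clone();
    }

    public static void main(String[] args) {
        ImportanceAccumulator acc = new ImportanceAccumulator(2);
        // a perfect split on variable 0: both children become pure
        acc.onSplit(0, 0.5, 0.0, 0.0, 5, 5);
        System.out.println(java.util.Arrays.toString(acc.importance()));
    }
}
```

Weighting each child's impurity by its size matches the usual CART convention, so a split that purifies a large child counts for more than one that purifies a small child.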
    • root

      public Node root()
      Returns the root node.
      Returns:
      root node.
    • dot

      public String dot()
      Returns the graphic representation in Graphviz dot format. Try http://viz-js.com/ to visualize the returned string.
      Returns:
      the graphic representation in Graphviz dot format.
    • toString

      public String toString()
      Returns a text representation of the tree in R's rpart format: a semi-graphical layout in which indentation conveys the tree topology. Information for each node includes the node number, split, size, deviance, and fitted value. For a decision tree, the class probabilities are also printed.
      Overrides:
      toString in class Object
    • shap

      public double[] shap(DataFrame data)
      Returns the average of absolute SHAP values over a data frame.
      Parameters:
      data - the data.
      Returns:
      the average of absolute SHAP values.
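The aggregation performed by shap(DataFrame) can be sketched over a plain matrix of per-instance SHAP vectors. This is an illustrative example, not the smile implementation; the class and method names are hypothetical.

```java
// Sketch of averaging absolute per-instance SHAP vectors, as
// shap(DataFrame) does over a data frame: row i holds the SHAP
// values of instance i, and the result is the column-wise mean
// of absolute values.
public class ShapAverage {
    public static double[] meanAbs(double[][] shapValues) {
        int n = shapValues.length;
        int p = shapValues[0].length;
        double[] mean = new double[p];
        for (double[] row : shapValues) {
            for (int j = 0; j < p; j++) {
                mean[j] += Math.abs(row[j]);
            }
        }
        for (int j = 0; j < p; j++) mean[j] /= n;
        return mean;
    }

    public static void main(String[] args) {
        double[][] shap = {{1.0, -2.0}, {-3.0, 2.0}};
        System.out.println(java.util.Arrays.toString(meanAbs(shap))); // [2.0, 2.0]
    }
}
```

Absolute values are taken before averaging because positive and negative contributions would otherwise cancel and understate a feature's influence.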
    • shap

      public double[] shap(Tuple x)
      Description copied from interface: SHAP
      Returns the SHAP values. For regression, the length of the SHAP values equals the number of features. For classification, the SHAP values form a p x k matrix, where p is the number of features and k is the number of classes. The first k elements are the SHAP values of the first feature over the k classes, respectively; the remaining features follow accordingly.
      Specified by:
      shap in interface SHAP<Tuple>
      Parameters:
      x - an instance.
      Returns:
      the SHAP values.