Record Class ClassificationValidation<M>

java.lang.Object
java.lang.Record
smile.validation.ClassificationValidation<M>
Type Parameters:
M - The model type.
Record Components:
model - The classification model.
truth - The ground truth of the validation data.
prediction - The model prediction.
posteriori - The posterior probabilities of the predictions if the model is a soft classifier.
confusion - The confusion matrix.
metrics - The classification metrics.
All Implemented Interfaces:
Serializable

public record ClassificationValidation<M>(M model, int[] truth, int[] prediction, double[][] posteriori, ConfusionMatrix confusion, ClassificationMetrics metrics) extends Record implements Serializable
Classification model validation results.
See Also:
Serialized Form
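To illustrate how the confusion and metrics components relate to truth and prediction, the following sketch builds a confusion matrix and derives accuracy from it. It is not Smile's implementation: ConfusionSketch and its methods are hypothetical, rows are taken as true classes and columns as predicted classes, and labels are assumed to be encoded as 0 to k-1.

```java
import java.util.Arrays;

// Hypothetical sketch, not Smile's ConfusionMatrix or ClassificationMetrics.
class ConfusionSketch {
    /** Entry [i][j] counts samples whose true class is i and predicted class is j. */
    static int[][] confusion(int[] truth, int[] prediction) {
        if (truth.length != prediction.length) {
            throw new IllegalArgumentException("truth and prediction lengths differ");
        }
        // Assumes labels are 0, 1, ..., k-1; k is the largest label seen plus one.
        int k = 0;
        for (int i = 0; i < truth.length; i++) {
            k = Math.max(k, Math.max(truth[i], prediction[i]) + 1);
        }
        int[][] matrix = new int[k][k];
        for (int i = 0; i < truth.length; i++) {
            matrix[truth[i]][prediction[i]]++;
        }
        return matrix;
    }

    /** Accuracy: the diagonal (correct predictions) divided by the total count. */
    static double accuracy(int[][] matrix) {
        int correct = 0, total = 0;
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                total += matrix[i][j];
                if (i == j) correct += matrix[i][j];
            }
        }
        return total == 0 ? Double.NaN : (double) correct / total;
    }

    public static void main(String[] args) {
        int[] truth      = {0, 0, 1, 1, 2, 2};
        int[] prediction = {0, 1, 1, 1, 2, 0};
        int[][] m = confusion(truth, prediction);
        System.out.println(Arrays.deepToString(m)); // [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
        System.out.println(accuracy(m));            // 4 of 6 correct
    }
}
```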
  • Constructor Details

    • ClassificationValidation

      public ClassificationValidation(M model, double fitTime, double scoreTime, int[] truth, int[] prediction)
      Constructor.
      Parameters:
      model - the model.
      fitTime - the time in milliseconds to fit the model.
      scoreTime - the time in milliseconds to score the validation data.
      truth - the ground truth.
      prediction - the predictions.
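The fitTime and scoreTime arguments are elapsed times in milliseconds. One way to obtain them, sketched here with a hypothetical TimingSketch helper, is to wrap the fitting and scoring steps with System.nanoTime and convert nanoseconds to milliseconds. The lambdas in main are placeholders for real training and scoring code.

```java
import java.util.function.Supplier;

// Hypothetical helper for measuring elapsed time; not part of Smile.
class TimingSketch {
    /** A computed value paired with the elapsed time in milliseconds. */
    record Timed<V>(V value, double millis) {}

    /** Runs the task and reports its elapsed wall-clock time in milliseconds. */
    static <V> Timed<V> time(Supplier<V> task) {
        long start = System.nanoTime();
        V value = task.get();
        double millis = (System.nanoTime() - start) / 1e6; // nanoseconds -> milliseconds
        return new Timed<>(value, millis);
    }

    public static void main(String[] args) {
        // Placeholders standing in for "fit the model" and "score the validation data".
        Timed<int[]> fit = time(() -> new int[]{1, 2, 3});
        Timed<Integer> score = time(() -> fit.value().length);
        // fit.millis() and score.millis() would be passed as fitTime and scoreTime.
        System.out.printf("fitTime=%.3f ms, scoreTime=%.3f ms%n", fit.millis(), score.millis());
    }
}
```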
    • ClassificationValidation

      public ClassificationValidation(M model, double fitTime, double scoreTime, int[] truth, int[] prediction, double[][] posteriori)
      Constructor for soft classifier validation.
      Parameters:
      model - the model.
      fitTime - the time in milliseconds to fit the model.
      scoreTime - the time in milliseconds to score the validation data.
      truth - the ground truth.
      prediction - the predictions.
      posteriori - the posterior probabilities of the predictions.
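For a soft classifier, the posteriori array carries information that the hard predictions do not. The sketch below assumes posteriori[i][c] is the estimated probability that sample i belongs to class c. It recovers the hard prediction by argmax and computes the mean cross-entropy (log loss) of the true classes, a metric that needs probabilities rather than labels. PosterioriSketch is hypothetical and does not claim to reproduce how Smile computes its metrics.

```java
import java.util.Arrays;

// Hypothetical sketch of using posterior probabilities; not Smile's code.
class PosterioriSketch {
    /** The hard prediction is the class with the largest posterior probability. */
    static int argmax(double[] p) {
        int best = 0;
        for (int c = 1; c < p.length; c++) {
            if (p[c] > p[best]) best = c;
        }
        return best;
    }

    /** Mean cross-entropy of the true classes; probabilities are clamped to avoid log(0). */
    static double crossEntropy(int[] truth, double[][] posteriori) {
        double sum = 0.0;
        for (int i = 0; i < truth.length; i++) {
            double p = Math.max(posteriori[i][truth[i]], 1e-15);
            sum -= Math.log(p);
        }
        return sum / truth.length;
    }

    public static void main(String[] args) {
        int[] truth = {0, 1};
        double[][] posteriori = {{0.8, 0.2}, {0.4, 0.6}};
        int[] prediction = new int[truth.length];
        for (int i = 0; i < truth.length; i++) prediction[i] = argmax(posteriori[i]);
        System.out.println(Arrays.toString(prediction)); // [0, 1]
        System.out.println(crossEntropy(truth, posteriori));
    }
}
```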
    • ClassificationValidation

      public ClassificationValidation(M model, int[] truth, int[] prediction, double[][] posteriori, ConfusionMatrix confusion, ClassificationMetrics metrics)
      Creates an instance of a ClassificationValidation record class.
      Parameters:
      model - the value for the model record component
      truth - the value for the truth record component
      prediction - the value for the prediction record component
      posteriori - the value for the posteriori record component
      confusion - the value for the confusion record component
      metrics - the value for the metrics record component
  • Method Details

    • toString

      public String toString()
      Returns a string representation of this record class. The representation contains the name of the class, followed by the name and value of each of the record components.
      Specified by:
      toString in class Record
      Returns:
      a string representation of this object
    • of

      public static <T, M extends Classifier<T>> ClassificationValidation<M> of(T[] x, int[] y, T[] testx, int[] testy, BiFunction<T[],int[],M> trainer)
      Trains and validates a model on a train/validation split.
      Type Parameters:
      T - the data type of samples.
      M - the model type.
      Parameters:
      x - the training data.
      y - the class labels of training data.
      testx - the validation data.
      testy - the class labels of validation data.
      trainer - the lambda to train the model.
      Returns:
      the validation results.
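The single-split overload can be pictured as a simple flow: call the trainer on (x, y), predict each element of testx, and compare the results with testy. This is a hypothetical sketch, not Smile's code. Classifier here is a local one-method stand-in for smile.classification.Classifier, majority is a toy trainer that always predicts the most frequent training label, and only accuracy is computed.

```java
import java.util.function.BiFunction;

// Hypothetical sketch of train/validate on one split; not Smile's implementation.
class SplitSketch {
    /** Minimal stand-in for a classifier: maps a sample to a class label. */
    interface Classifier<T> {
        int predict(T x);
    }

    /** Trains with the supplied trainer and returns accuracy on the validation data. */
    static <T, M extends Classifier<T>> double validate(
            T[] x, int[] y, T[] testx, int[] testy, BiFunction<T[], int[], M> trainer) {
        M model = trainer.apply(x, y);
        int correct = 0;
        for (int i = 0; i < testx.length; i++) {
            if (model.predict(testx[i]) == testy[i]) correct++;
        }
        return (double) correct / testx.length;
    }

    /** Toy trainer: always predicts the most frequent training label (labels 0..k-1). */
    static <T> Classifier<T> majority(T[] x, int[] y) {
        int k = 0;
        for (int label : y) k = Math.max(k, label + 1);
        int[] counts = new int[k];
        for (int label : y) counts[label]++;
        int best = 0;
        for (int c = 1; c < k; c++) if (counts[c] > counts[best]) best = c;
        final int answer = best;
        return sample -> answer;
    }

    public static void main(String[] args) {
        Double[] x = {1.0, 2.0, 3.0, 4.0};
        int[] y = {1, 1, 0, 1};
        Double[] testx = {5.0, 6.0};
        int[] testy = {1, 0};
        double acc = validate(x, y, testx, testy, SplitSketch::<Double>majority);
        System.out.println(acc); // 0.5
    }
}
```

The trainer is passed as a lambda or method reference, which mirrors the BiFunction parameter in the signature above: validation code stays independent of how a particular model is fitted.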
    • of

      public static <T, M extends Classifier<T>> ClassificationValidations<M> of(Bag[] bags, T[] x, int[] y, BiFunction<T[],int[],M> trainer)
      Trains and validates a model on multiple train/validation splits.
      Type Parameters:
      T - the data type of samples.
      M - the model type.
      Parameters:
      bags - the data splits.
      x - the data, split into training and validation sets according to bags.
      y - the class labels.
      trainer - the lambda to train the model.
      Returns:
      the validation results.
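With multiple splits, each bag selects which rows of x and y train the model and which are held out. The sketch below assumes a bag holds two index arrays, samples (training rows) and oob (held-out rows). The local Bag record, the contiguous kfold helper (no shuffling), and the Function-based model are hypothetical stand-ins, and only per-bag accuracy is computed.

```java
import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.function.Function;

// Hypothetical sketch of validation over several bags; not Smile's implementation.
class BagSketch {
    /** Stand-in: indices used for training (samples) and held out for validation (oob). */
    record Bag(int[] samples, int[] oob) {}

    /** Splits 0..n-1 into k contiguous folds; bag i holds out fold i. No shuffling. */
    static Bag[] kfold(int n, int k) {
        Bag[] bags = new Bag[k];
        for (int i = 0; i < k; i++) {
            int start = i * n / k, end = (i + 1) * n / k;
            int[] oob = new int[end - start];
            int[] samples = new int[n - oob.length];
            for (int j = 0, s = 0; j < n; j++) {
                if (j >= start && j < end) oob[j - start] = j;
                else samples[s++] = j;
            }
            bags[i] = new Bag(samples, oob);
        }
        return bags;
    }

    /** Trains on each bag's samples, scores on its oob rows, returns per-bag accuracy. */
    static <T> double[] validate(Bag[] bags, T[] x, int[] y,
                                 BiFunction<T[], int[], Function<T, Integer>> trainer) {
        double[] accuracy = new double[bags.length];
        for (int b = 0; b < bags.length; b++) {
            int[] idx = bags[b].samples();
            T[] trainx = Arrays.copyOf(x, idx.length); // same runtime array type as x
            int[] trainy = new int[idx.length];
            for (int i = 0; i < idx.length; i++) {
                trainx[i] = x[idx[i]];
                trainy[i] = y[idx[i]];
            }
            Function<T, Integer> model = trainer.apply(trainx, trainy);
            int correct = 0;
            for (int j : bags[b].oob()) if (model.apply(x[j]) == y[j]) correct++;
            accuracy[b] = (double) correct / bags[b].oob().length;
        }
        return accuracy;
    }

    public static void main(String[] args) {
        Double[] x = {1.0, 2.0, 3.0, 10.0, 11.0, 12.0};
        int[] y = {0, 0, 0, 1, 1, 1};
        // Toy 1-nearest-neighbour trainer on one-dimensional data.
        BiFunction<Double[], int[], Function<Double, Integer>> nn = (tx, ty) -> q -> {
            int best = 0;
            for (int i = 1; i < tx.length; i++) {
                if (Math.abs(tx[i] - q) < Math.abs(tx[best] - q)) best = i;
            }
            return ty[best];
        };
        System.out.println(Arrays.toString(validate(kfold(x.length, 3), x, y, nn))); // [1.0, 1.0, 1.0]
    }
}
```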
    • of

      public static <M extends DataFrameClassifier> ClassificationValidation<M> of(Formula formula, DataFrame train, DataFrame test, BiFunction<Formula,DataFrame,M> trainer)
      Trains and validates a model on a train/validation split.
      Type Parameters:
      M - the model type.
      Parameters:
      formula - the model formula.
      train - the training data.
      test - the validation data.
      trainer - the lambda to train the model.
      Returns:
      the validation results.
    • of

      public static <M extends DataFrameClassifier> ClassificationValidations<M> of(Bag[] bags, Formula formula, DataFrame data, BiFunction<Formula,DataFrame,M> trainer)
      Trains and validates a model on multiple train/validation splits.
      Type Parameters:
      M - the model type.
      Parameters:
      bags - the data splits.
      formula - the model formula.
      data - the data, split into training and validation sets according to bags.
      trainer - the lambda to train the model.
      Returns:
      the validation results.
    • hashCode

      public final int hashCode()
      Returns a hash code value for this object. The value is derived from the hash code of each of the record components.
      Specified by:
      hashCode in class Record
      Returns:
      a hash code value for this object
    • equals

      public final boolean equals(Object o)
      Indicates whether some other object is "equal to" this one. The objects are equal if the other object is of the same class and if all the record components are equal. All components in this record class are compared with Objects::equals(Object,Object).
      Specified by:
      equals in class Record
      Parameters:
      o - the object with which to compare
      Returns:
      true if this object is the same as the o argument; false otherwise.
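Because every component is compared with Objects.equals, the array components truth, prediction, and posteriori are compared by reference, not by content: two results built from equal but distinct arrays are not equal. The hypothetical two-component Result record below demonstrates this and shows a content-based comparison with Arrays.equals.

```java
import java.util.Arrays;

// Hypothetical stand-in record with array components, to show record equality semantics.
class EqualsSketch {
    record Result(int[] truth, int[] prediction) {}

    /** Compares the array contents, unlike the generated equals. */
    static boolean sameContents(Result a, Result b) {
        return Arrays.equals(a.truth(), b.truth()) && Arrays.equals(a.prediction(), b.prediction());
    }

    public static void main(String[] args) {
        int[] truth = {0, 1, 1};
        Result a = new Result(truth, new int[]{0, 1, 0});
        Result b = new Result(truth.clone(), new int[]{0, 1, 0});
        System.out.println(a.equals(b));        // false: the arrays are different objects
        System.out.println(sameContents(a, b)); // true
        System.out.println(a.equals(new Result(a.truth(), a.prediction()))); // true: same references
    }
}
```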
    • model

      public M model()
      Returns the value of the model record component.
      Returns:
      the value of the model record component
    • truth

      public int[] truth()
      Returns the value of the truth record component.
      Returns:
      the value of the truth record component
    • prediction

      public int[] prediction()
      Returns the value of the prediction record component.
      Returns:
      the value of the prediction record component
    • posteriori

      public double[][] posteriori()
      Returns the value of the posteriori record component.
      Returns:
      the value of the posteriori record component
    • confusion

      public ConfusionMatrix confusion()
      Returns the value of the confusion record component.
      Returns:
      the value of the confusion record component
    • metrics

      public ClassificationMetrics metrics()
      Returns the value of the metrics record component.
      Returns:
      the value of the metrics record component
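These accessors simply return the value of each record component, so for the array components they return the stored array itself rather than a copy. The sketch below, using a hypothetical one-component Result record, shows that writing through the returned array changes the record's state, and that cloning first avoids this.

```java
// Hypothetical stand-in record showing that implicit accessors return the stored array.
class AccessorSketch {
    record Result(int[] truth) {}

    public static void main(String[] args) {
        Result r = new Result(new int[]{0, 1, 2});
        int[] view = r.truth();           // the implicit accessor returns the field itself
        view[0] = 9;                      // so this also changes r's component
        System.out.println(r.truth()[0]); // 9
        int[] copy = r.truth().clone();   // copy first if the array must be modified
        copy[1] = 7;
        System.out.println(r.truth()[1]); // 1
    }
}
```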