Validating model fit
treecat.validate
- treecat.validate.eval(dataset_path, param_csv_path, models_dir, result_path, **options)
Evaluate trained models.
- treecat.validate.read_param_csv(param_csv_path, **options)
Reads configs from a csv file.
- Args:
- param_csv_path: The path to a csv file with one line per config.
- options: A dict of extra config parameters.
- Returns:
- A pair (header, configs), where header is a list of parameter names and configs is a list of config dicts.
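As an illustration of the (header, configs) return shape, here is a minimal sketch using only the standard library. It is not treecat's implementation: the name `read_configs`, the column names in the usage example, the choice to leave values as strings, and the rule that CSV columns override `options` on a name collision are all assumptions made for this example.

```python
import csv
import io


def read_configs(csv_file, **options):
    """Hypothetical sketch: parse one config dict per CSV row."""
    reader = csv.reader(csv_file)
    header = next(reader)  # first line holds the parameter names
    configs = []
    for row in reader:
        config = dict(options)            # extra parameters apply to every config
        config.update(zip(header, row))   # assumption: CSV values win on collision
        configs.append(config)
    return header, configs


# Usage with an in-memory file; the column names are made up for illustration.
example = io.StringIO("num_clusters,learning_rate\n8,0.1\n16,0.05\n")
header, configs = read_configs(example, seed="0")
```

Values stay as strings in this sketch; a real reader would likely convert them to the types each parameter expects.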
- treecat.validate.split_data(ragged_index, num_rows, num_parts, partid)
Split a dataset into training + holdout for n-fold crossvalidation.
This splits a dataset into num_parts disjoint parts by randomly holding out cells. Note that whereas supervised crossvalidation typically holds out entire rows, our unsupervised crossvalidation is intended to evaluate a model of the full joint distribution.
- Args:
- ragged_index: A [V+1]-shaped numpy array of indices into the ragged data array, where V is the number of features.
- num_rows: An integer, the number of rows in the dataset.
- num_parts: An integer, the number of folds in n-fold crossvalidation.
- partid: An integer in [0, num_parts).
- Returns:
- A [N,R]-shaped mask where True means held-out and False means training. Here N = num_rows and R = ragged_index[-1].
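To make the cell-wise holdout concrete, here is a minimal NumPy sketch. It is not treecat's actual implementation: the name `split_cells` and the `seed` argument are hypothetical, and the sketch assumes that `ragged_index[0] == 0` and that all ragged columns belonging to one feature in a given row are held out together.

```python
import numpy as np


def split_cells(ragged_index, num_rows, num_parts, partid, seed=0):
    """Hypothetical sketch of a cell-wise holdout mask of shape [N, R]."""
    if not 0 <= partid < num_parts:
        raise ValueError("partid must lie in [0, num_parts)")
    ragged_index = np.asarray(ragged_index)
    num_features = len(ragged_index) - 1  # V

    # Assign every (row, feature) cell to one fold. A fixed seed means every
    # partid sees the same assignment, so the folds are disjoint.
    rng = np.random.RandomState(seed)
    fold = rng.randint(num_parts, size=(num_rows, num_features))
    holdout_cells = fold == partid  # [N, V] boolean

    # Expand each feature's decision across its columns in the ragged array.
    widths = np.diff(ragged_index)  # number of columns per feature
    return np.repeat(holdout_cells, widths, axis=1)  # [N, R] boolean


# Usage: three features occupying 2, 3, and 1 ragged columns, so R = 6.
ragged_index = [0, 2, 5, 6]
masks = [split_cells(ragged_index, num_rows=4, num_parts=3, partid=p)
         for p in range(3)]
```

Because each cell lands in exactly one fold, the masks for all values of `partid` cover every position exactly once, which is what lets each fold serve in turn as the holdout set.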
-
treecat.validate.
train
(dataset_path, param_csv_path, models_dir, **options)¶ Tune parameters specified in a csv file.
- treecat.validate.train_task(dataset_path, model_path, config_str)
INTERNAL: Train a single model.