Training a model

treecat.training

class treecat.training.TreeCatTrainer(table, tree_prior, config)

Class for training a TreeCat model.

compute_edge_logits()

Compute non-normalized logprob of all V(V-1)/2 candidate edges.

This is used for sampling and estimating the latent tree.
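With V features there are K = V(V-1)/2 unordered vertex pairs. This is why edge logits and the tree prior are [K]-shaped. The sketch below enumerates those candidate edges with itertools.combinations. The lexicographic ordering it produces is an assumption for illustration, and TreeCat's internal edge ordering may differ. complete_graph_edges is a hypothetical helper name.

```python
import itertools


def complete_graph_edges(V):
    """Return all V*(V-1)//2 unordered vertex pairs (v1, v2) with v1 < v2.

    The lexicographic ordering is an illustrative assumption; it is not
    necessarily the order TreeCat uses for its [K]-shaped edge arrays.
    """
    return list(itertools.combinations(range(V), 2))


V = 4
edges = complete_graph_edges(V)
K = len(edges)
print(K)      # 6 candidate edges for 4 features
print(edges)  # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
```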

logprob()

Compute non-normalized log probability of data and latent state.

This is used for testing goodness of fit of the latent state kernel.

train()

Train a TreeCat model using subsample-annealed MCMC.

Returns:
A trained model as a dictionary with keys:

config: A global config dict.
tree: A TreeStructure instance with the learned latent structure.
edge_logits: A [K]-shaped array of all edge logits.
suffstats: Sufficient statistics of features, vertices, and edges, and a ragged_index for the features array.
assignments: An [N, V]-shaped numpy array of latent cluster ids for each cell in the dataset, where N is the number of data rows and V is the number of features.

class treecat.training.TreeGaussTrainer(data, tree_prior, config)

Class for training a TreeGauss model.

compute_edge_logits()

Compute non-normalized logprob of all V(V-1)/2 candidate edges.

This is used for sampling and estimating the latent tree.

logprob()

Compute non-normalized log probability of data and latent state.

This is used for testing goodness of fit of the latent state kernel.

train()

Train a TreeGauss model using subsample-annealed MCMC.

Returns:
A trained model as a dictionary with keys:

config: A global config dict.
tree: A TreeStructure instance with the learned latent structure.
edge_logits: A [K]-shaped array of all edge logits.
suffstats: Sufficient statistics of features and vertices.
latent: An [N, V, M]-shaped numpy array of latent states, where N is the number of data rows, V is the number of features, and M is the dimension of each latent variable.

class treecat.training.TreeMogTrainer(data, tree_prior, config)

Class for training a tree mixture-of-Gaussians model.

add_row(row_id)

Add a given row to the current subsample.

compute_edge_logits()

Compute edge log probabilities on the complete graph.

logprob()

Compute non-normalized log probability of data and latent state.

remove_row(row_id)

Remove a given row from the current subsample.

class treecat.training.TreeTrainer(N, V, tree_prior, config)

Abstract base class for training a tree model with various kinds of latent state.

Derived classes must implement:
- add_row(row_id)
- remove_row(row_id)
- compute_edge_logits()
- logprob()
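The page lists the four hooks a derived class must supply but shows no subclass. The sketch below uses Python's abc module to make that contract explicit and fills it with a toy model. This illustrates the interface only and is not TreeCat's implementation. TreeTrainerSketch and CountingTrainer are hypothetical names, and whether TreeTrainer itself uses abc is not stated here.

```python
from abc import ABC, abstractmethod


class TreeTrainerSketch(ABC):
    """Hypothetical stand-in listing the hooks a TreeTrainer subclass supplies."""

    @abstractmethod
    def add_row(self, row_id): ...

    @abstractmethod
    def remove_row(self, row_id): ...

    @abstractmethod
    def compute_edge_logits(self): ...

    @abstractmethod
    def logprob(self): ...


class CountingTrainer(TreeTrainerSketch):
    """Toy subclass that only tracks which rows are in the subsample."""

    def __init__(self, V):
        self.V = V
        self.assigned = set()

    def add_row(self, row_id):
        assert row_id not in self.assigned
        self.assigned.add(row_id)

    def remove_row(self, row_id):
        self.assigned.remove(row_id)

    def compute_edge_logits(self):
        # Toy model: all K = V*(V-1)/2 candidate edges get the same logit.
        K = self.V * (self.V - 1) // 2
        return [0.0] * K

    def logprob(self):
        # Toy model: a constant log probability.
        return 0.0
```

Instantiating TreeTrainerSketch directly raises TypeError, which is how abc enforces that every hook is overridden.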

add_row(row_id)

Add a given row to the current subsample.

compute_edge_logits()

Compute edge log probabilities on the complete graph.

estimate_tree()

Compute a maximum likelihood tree.

Returns:
A pair (edges, edge_logits), where:
edges: A list of (vertex, vertex) pairs.
edge_logits: A [K]-shaped numpy array of edge logits.

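Suppose a tree's log probability decomposes as a sum of per-edge logits (an assumption here). Then the maximum likelihood tree is a maximum-weight spanning tree of the complete graph. The sketch below finds one with Kruskal's algorithm and a union-find. max_spanning_tree is a hypothetical helper, not TreeCat's API, and pairing each logit with the lexicographic itertools.combinations ordering is likewise assumed.

```python
import itertools


def max_spanning_tree(V, edge_logits):
    """Kruskal's algorithm: return the V-1 edges of a maximum-weight spanning tree.

    Assumes edge_logits[k] belongs to the k-th pair of
    itertools.combinations(range(V), 2).
    """
    pairs = list(itertools.combinations(range(V), 2))
    assert len(edge_logits) == len(pairs)
    parent = list(range(V))

    def find(v):
        # Union-find root lookup with path halving.
        while parent[v] != v:
            parent[v] = parent[parent[v]]
            v = parent[v]
        return v

    edges = []
    # Greedily take the highest-logit edges that do not close a cycle.
    for k in sorted(range(len(pairs)), key=lambda k: edge_logits[k], reverse=True):
        v1, v2 = pairs[k]
        r1, r2 = find(v1), find(v2)
        if r1 != r2:
            parent[r1] = r2
            edges.append((v1, v2))
            if len(edges) == V - 1:
                break
    return edges


print(max_spanning_tree(4, [3.0, 0.1, 0.2, 2.0, 0.3, 1.5]))
# [(0, 1), (1, 2), (2, 3)]
```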
get_edges()

Get a list of the edges in the current tree.

Returns:
An E-long list of (vertex, vertex) pairs.

logprob()

Compute non-normalized log probability of data and latent state.

This is used for testing goodness of fit of the latent state kernel. This should only be called after training, i.e. after all rows have been added.

remove_row(row_id)

Remove a given row from the current subsample.

sample_tree()

Sample a random tree.

Returns:
A pair (edges, edge_logits), where:
edges: A list of (vertex, vertex) pairs.
edge_logits: A [K]-shaped numpy array of edge logits.

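The page does not say how sample_tree draws its tree. The sketch below shows one standard MCMC move for sampling spanning trees with probability proportional to exp(sum of edge logits), offered as an assumption. Remove one tree edge, which splits the tree into two components. Then reconnect the components with an edge drawn from all crossing candidates, with probability proportional to exp(logit). Because those are exactly the trees that contain the remaining edges, this is a Gibbs update. gibbs_edge_move is a hypothetical name, and the lexicographic edge ordering is assumed.

```python
import itertools
import math
import random


def gibbs_edge_move(V, edges, edge_logits, rng):
    """One Gibbs move on a spanning tree over V vertices.

    Assumes edge_logits[k] belongs to the k-th pair of
    itertools.combinations(range(V), 2) and that the target is
    p(tree) proportional to exp(sum of the tree's edge logits).
    """
    logit = dict(zip(itertools.combinations(range(V), 2), edge_logits))
    edges = [tuple(sorted(e)) for e in edges]  # copy; caller's list untouched
    removed = edges.pop(rng.randrange(len(edges)))

    # Collect the component containing removed[0] in the remaining forest.
    neighbors = {v: [] for v in range(V)}
    for v1, v2 in edges:
        neighbors[v1].append(v2)
        neighbors[v2].append(v1)
    component = {removed[0]}
    stack = [removed[0]]
    while stack:
        for w in neighbors[stack.pop()]:
            if w not in component:
                component.add(w)
                stack.append(w)

    # Candidates are all edges crossing the cut (including the removed one).
    candidates = [e for e in logit if (e[0] in component) != (e[1] in component)]
    # Sample proportional to exp(logit), shifted by the max for stability.
    top = max(logit[e] for e in candidates)
    weights = [math.exp(logit[e] - top) for e in candidates]
    edges.append(rng.choices(candidates, weights=weights)[0])
    return edges


rng = random.Random(0)
tree = [(0, 1), (1, 2), (2, 3)]
for _ in range(100):
    tree = gibbs_edge_move(4, tree, [3.0, 0.1, 0.2, 2.0, 0.3, 1.5], rng)
print(tree)
```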
set_edges(edges)

Set edges of the latent structure and update statistics.

Args:
edges: An E-long list of (vertex, vertex) pairs.

train()

Train a model using subsample-annealed MCMC.

Returns:
A trained model as a dictionary with keys:

config: A global config dict.
tree: A TreeStructure instance with the learned latent structure.
edge_logits: A [K]-shaped array of all edge logits.

treecat.training.count_pairs(assignments, v1, v2, M)

Construct sufficient statistics for (v1, v2) pairs.

Args:
assignments: An _ x V assignment matrix with values in range(M).
v1, v2: Column ids of the assignments matrix.
M: The number of possible assignment bins.

Returns:
An M x M array of counts.

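The sketch below is a minimal NumPy version of the documented behaviour. Each row's (assignments[row, v1], assignments[row, v2]) pair is tallied into an M x M table. It is consistent with the docstring but is not necessarily TreeCat's implementation, and count_pairs_sketch is a hypothetical name.

```python
import numpy as np


def count_pairs_sketch(assignments, v1, v2, M):
    """Return an M x M array whose [i, j] entry counts the rows with
    assignments[:, v1] == i and assignments[:, v2] == j."""
    counts = np.zeros((M, M), dtype=np.int32)
    # np.add.at accumulates correctly even when an (i, j) pair repeats.
    np.add.at(counts, (assignments[:, v1], assignments[:, v2]), 1)
    return counts


assignments = np.array([[0, 1, 2],
                        [0, 1, 0],
                        [1, 1, 2]])
print(count_pairs_sketch(assignments, 0, 2, M=3))
# [[1 0 1]
#  [0 0 1]
#  [0 0 0]]
```

Plain fancy-index assignment (counts[i, j] += 1) would silently drop repeated pairs, which is why np.add.at is used.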
treecat.training.logprob_dc(counts, prior, axis=None)

Non-normalized log probability of a Dirichlet-Categorical distribution.

See https://en.wikipedia.org/wiki/Dirichlet-multinomial_distribution
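Let per-category counts c_k come from n = sum_k c_k observations under a Dirichlet(alpha) prior with A = sum_k alpha_k. The log probability of that specific observation sequence is then lgamma(A) - lgamma(A + n) + sum_k [lgamma(alpha_k + c_k) - lgamma(alpha_k)]. The sketch below computes this for 1-D inputs with math.lgamma. TreeCat's non-normalized logprob_dc may drop terms it treats as constant, so read this as the reference formula rather than that function's exact output. dc_log_marginal is a hypothetical name.

```python
import math


def dc_log_marginal(counts, prior):
    """Log probability of an observed sequence with the given per-category
    counts, after integrating out a Dirichlet(prior) distribution.

    Omits the multinomial coefficient, so this is the probability of one
    particular ordering of the observations.
    """
    assert len(counts) == len(prior)
    total_prior = sum(prior)
    n = sum(counts)
    result = math.lgamma(total_prior) - math.lgamma(total_prior + n)
    for c, a in zip(counts, prior):
        result += math.lgamma(a + c) - math.lgamma(a)
    return result


# Two draws of category 0 under a uniform Dirichlet(1, 1) prior:
# p = 1/2 * 2/3 = 1/3.
print(math.exp(dc_log_marginal([2, 0], [1.0, 1.0])))  # approximately 0.3333
```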

treecat.training.make_annealing_schedule(num_rows, epochs, sample_tree_rate)

Iterator for subsample annealing, yielding (action, arg) pairs.

This generates a subsample annealing schedule starting from an empty assignment state (no rows are assigned). It then interleaves 'add_row' and 'remove_row' actions so that the number of assigned rows increases gradually, at a linear rate.

Args:
num_rows (int): Number of rows in the dataset. The annealing schedule terminates when all rows are assigned.
epochs (float): Number of epochs in the schedule, i.e. the number of times each datapoint is assigned. The fastest schedule is epochs=1, which simply assigns all datapoints sequentially. More epochs take more time.
sample_tree_rate (float): The rate at which 'sample_tree' actions are generated. At sample_tree_rate=1, trees are sampled after each complete flushing of the subsample.

Yields:
(action, arg) pairs. Actions are one of 'add_row', 'remove_row', or 'sample_tree'. When action is 'add_row' or 'remove_row', arg is a row_id in range(num_rows). When action is 'sample_tree', arg is undefined.

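The docstring pins down the schedule's behaviour but not its mechanics. The sketch below is one way to meet that description. It uses a budget counter in which each add costs epochs - 1 and each remove refunds epochs, so adds outnumber removes in the ratio epochs : (epochs - 1) and each row is assigned about epochs times. annealing_schedule_sketch and its seed parameter are hypothetical. TreeCat's actual generator may differ in details such as row order and exactly when trees are sampled.

```python
import itertools
import random


def annealing_schedule_sketch(num_rows, epochs, sample_tree_rate, seed=0):
    """Yield ('add_row', row_id), ('remove_row', row_id), or ('sample_tree', None).

    Hypothetical sketch: rows are added and removed in the same shuffled
    cyclic order, so the subsample is a sliding window that grows at a
    linear rate until every row is assigned.
    """
    assert epochs >= 1.0 and sample_tree_rate > 0.0
    row_ids = list(range(num_rows))
    random.Random(seed).shuffle(row_ids)
    to_add = itertools.cycle(row_ids)
    to_remove = itertools.cycle(row_ids)

    budget = 2.0 * epochs  # with epochs == 1 the budget never drops: adds only
    num_assigned = 0
    until_sample = 0.0
    while num_assigned < num_rows:
        if budget >= 0.0:
            yield 'add_row', next(to_add)
            budget -= epochs - 1.0
            num_assigned += 1
            until_sample -= sample_tree_rate
        else:
            yield 'remove_row', next(to_remove)
            budget += epochs
            num_assigned -= 1
        if num_assigned > 0 and until_sample <= 0.0:
            # Sample again after about num_assigned / sample_tree_rate more adds.
            yield 'sample_tree', None
            until_sample = num_assigned


for action, arg in annealing_schedule_sketch(4, epochs=1.0, sample_tree_rate=1.0):
    print(action, arg)
```

Removing rows in the same cyclic order they were added means the oldest assigned row always leaves first, so a row is never removed before it is added.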
treecat.training.train_ensemble(table, tree_prior, config)

Train a TreeCat ensemble model using subsample-annealed MCMC.

The ensemble size is controlled by config['model_ensemble_size']. Let N be the number of data rows and V be the number of features.

Args:
table: A Table instance holding N rows of V features of data.
tree_prior: A [K]-shaped numpy array of prior edge log odds, where K is the number of edges in the complete graph on V vertices.
config: A global config dict.

Returns:
A trained model as a dictionary with keys:

tree: A TreeStructure instance with the learned latent structure.
suffstats: Sufficient statistics of features, vertices, and edges.
assignments: An [N, V] numpy array of latent cluster ids for each cell in the dataset.

treecat.training.train_model(table, tree_prior, config)

Train a TreeCat model using subsample-annealed MCMC.

Let N be the number of data rows and V be the number of features.

Args:
table: A Table instance holding N rows of V features of data.
tree_prior: A [K]-shaped numpy array of prior edge log odds, where K is the number of edges in the complete graph on V vertices.
config: A global config dict.

Returns:
A trained model as a dictionary with keys:

config: A global config dict.
tree: A TreeStructure instance with the learned latent structure.
edge_logits: A [K]-shaped array of all edge logits.
suffstats: Sufficient statistics of features, vertices, and edges.
assignments: An [N, V] numpy array of latent cluster ids for each cell in the dataset.
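Both train_model and train_ensemble expect tree_prior to have one entry per edge of the complete graph on V vertices, so K = V(V-1)/2. If a flat prior is wanted (an assumption), an all-zero array gives every candidate edge the same prior log odds. uniform_tree_prior is a hypothetical helper. Building the table and config arguments is outside the scope of this sketch.

```python
import numpy as np


def uniform_tree_prior(V):
    """Return a [K]-shaped array of zero prior edge log odds, K = V*(V-1)//2."""
    K = V * (V - 1) // 2
    return np.zeros(K)


tree_prior = uniform_tree_prior(5)
print(tree_prior.shape)  # (10,)
```

The resulting array would then be passed as the tree_prior argument of train_model(table, tree_prior, config).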