Compute decision tree from data set

compute_tree(
  x,
  data_test = NULL,
  target_lab = NULL,
  task = c("classification", "regression"),
  feat_types = NULL,
  label_map = NULL,
  clust_samps = TRUE,
  clust_target = TRUE,
  custom_layout = NULL,
  lev_fac = 1.3,
  panel_space = 0.001
)

Arguments

x

A dataframe, or a `party` or `partynode` object representing a custom tree. If a dataframe is supplied, a conditional inference tree is computed. A custom tree must follow the partykit syntax: https://cran.r-project.org/web/packages/partykit/vignettes/partykit.pdf

data_test

Tidy test dataset. Required if `x` is a `partynode` object. If NULL, the heatmap displays the (training) data `x`.

target_lab

Name of the column in data that contains target/label information.

task

Character string indicating the type of problem, either 'classification' (categorical outcome) or 'regression' (continuous outcome).

feat_types

Named vector indicating the type of each feature, e.g., c(sex = 'factor', age = 'numeric'). If feature types are not supplied, they are inferred from the column types.
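As a rough sketch of what such a vector looks like next to the column classes R reports, the snippet below builds a toy dataframe (the `dat` object and its values are made up for illustration; only the column names `sex` and `age` come from the example above) and compares its classes with an explicit `feat_types` vector. How `compute_tree` performs its own inference internally is not shown here.

```r
# Toy data; values are invented, column names follow the example above.
dat <- data.frame(
  sex = factor(c("f", "m", "f")),
  age = c(31, 45, 27)
)

# Inspect the first class R assigns to each column.
col_classes <- vapply(dat, function(col) class(col)[1], character(1))
col_classes

# Explicit named vector in the format described for feat_types.
feat_types <- c(sex = "factor", age = "numeric")
```

Passing `feat_types` explicitly is mainly useful when a column's stored class does not match how it should be treated, for example an integer-coded category.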

label_map

Named vector of the meaning of the target values, e.g., c(`0` = 'Edible', `1` = 'Poisonous').
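A minimal sketch of building such a vector in base R follows. The backticks are needed because `0` and `1` are not syntactic names. The lookup at the end only illustrates the mapping; `raw_target` is a hypothetical vector, and this is not necessarily how `compute_tree` applies `label_map` internally.

```r
# Named vector mapping raw target values to readable labels.
label_map <- c(`0` = "Edible", `1` = "Poisonous")

# Hypothetical raw target values, for illustration only.
raw_target <- c(0, 1, 1, 0)

# Index by name: convert the raw values to character first.
readable <- unname(label_map[as.character(raw_target)])
readable
```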

clust_samps

Logical. If TRUE, hierarchical clustering is performed among samples within each leaf node.

clust_target

Logical. If TRUE, the target/label is included in the hierarchical clustering of samples within each leaf node, which might yield a more interpretable heatmap.

custom_layout

Dataframe with 3 columns, id, x and y, for manually specifying a custom layout.
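For illustration, a layout dataframe with the three required columns could be assembled as below. The coordinates are made-up placeholders on a 0 to 1 scale, and the sketch assumes that ids refer to node ids of the tree; whether rows for only a subset of nodes are accepted is not stated here, so treat that as an assumption.

```r
# Hypothetical positions for three nodes; coordinates are placeholders.
custom_layout <- data.frame(
  id = c(1, 2, 3),
  x  = c(0.50, 0.25, 0.75),
  y  = c(1.0, 0.5, 0.5)
)

# Basic sanity checks before handing the layout to compute_tree().
has_cols  <- all(c("id", "x", "y") %in% names(custom_layout))
unique_id <- !anyDuplicated(custom_layout$id)

# Assumed call, using only arguments from the signature above:
# fit_tree <- compute_tree(penguins, target_lab = 'species',
#                          custom_layout = custom_layout)
```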

lev_fac

Relative weight of child node positions according to their levels; commonly ranges from 1 to 1.5. A value of 1 places each parent node exactly midway between its child nodes.

panel_space

Spacing between facets relative to viewport, recommended to range from 0.001 to 0.01.

Value

A list of results from `partykit::ctree` or the provided custom tree, including the fit, estimates, smart layout and terminal data.

Examples

fit_tree <- compute_tree(penguins, target_lab = 'species')
fit_tree$fit
#> 
#> Model formula:
#> species ~ island + culmen_length_mm + culmen_depth_mm + flipper_length_mm + 
#>     body_mass_g + sex
#> 
#> Fitted party:
#> [1] root
#> |   [2] island in Torgersen, Dream
#> |   |   [3] culmen_length_mm <= 44.1
#> |   |   |   [4] culmen_length_mm <= 42.3: Adelie (n = 100, err = 1.0%)
#> |   |   |   [5] culmen_length_mm > 42.3: Adelie (n = 12, err = 41.7%)
#> |   |   [6] culmen_length_mm > 44.1: Chinstrap (n = 64, err = 3.1%)
#> |   [7] island in Biscoe
#> |   |   [8] flipper_length_mm <= 203
#> |   |   |   [9] culmen_length_mm <= 41.4: Adelie (n = 38, err = 0.0%)
#> |   |   |   [10] culmen_length_mm > 41.4: Adelie (n = 8, err = 25.0%)
#> |   |   [11] flipper_length_mm > 203: Gentoo (n = 122, err = 0.0%)
#> 
#> Number of inner nodes:    5
#> Number of terminal nodes: 6
fit_tree$layout
#>    id         x   y
#> 1   1 0.5274183 1.0
#> 2   2 0.3289129 0.8
#> 3   3 0.2194596 0.6
#> 4   7 0.7259236 0.8
#> 5   8 0.6000531 0.6
#> 6   4 0.1400150 0.0
#> 7   5 0.2989042 0.0
#> 8   6 0.4131078 0.0
#> 9   9 0.5660389 0.0
#> 10 10 0.6340674 0.0
#> 11 11 0.8227470 0.0
dplyr::select(fit_tree$term_dat, - contains('nodedata'))
#>   id parent birth_order breaks_label info    info_list splitvar level kids
#> 1  4      3           1 NA <= NA....   NA c(0.8000....     <NA>     3    0
#> 2  5      3           2 NA >  NA....   NA           NA     <NA>     3    0
#> 3  6      2           2 NA >  NA....   NA c(64, 7.....     <NA>     3    0
#> 4  9      8           1 NA <= NA....   NA           NA     <NA>     3    0
#> 5 10      8           2 NA >  NA....   NA           NA     <NA>     3    0
#> 6 11      7           2 NA >  NA....   NA           NA     <NA>     3    0
#>   nodesize      p.value horizontal  x_parent y_parent     y_hat   n         x y
#> 1      100 6.811624e-01      FALSE 0.1666667      0.6    Adelie  94 0.1400150 0
#> 2       12           NA      FALSE 0.1666667      0.6    Adelie  12 0.2989042 0
#> 3       64 7.465153e-15      FALSE 0.2500000      0.8 Chinstrap  64 0.4131078 0
#> 4       38           NA      FALSE 0.6666667      0.6    Adelie  38 0.5660389 0
#> 5        8           NA      FALSE 0.6666667      0.6    Adelie   7 0.6340674 0
#> 6      122           NA      FALSE 0.7500000      0.8    Gentoo 119 0.8227470 0
#>   term_node
#> 1    Adelie
#> 2    Adelie
#> 3 Chinstrap
#> 4    Adelie
#> 5    Adelie
#> 6    Gentoo