tidymodelsのrecipesパッケージがworkflowsパッケージの使用を推奨し始めた

by
カテゴリ:
タグ:

tidymodelsを使ったモデリングにおいて、recipesパッケージは特徴量エンジニアリングを担います。従来、recipesパッケージは単体で、特徴量抽エンジニアリング方法の

  1. 定義
    • recipe関数 + step_*関数群
  2. 学習
    • prep関数
  3. 適用
    • bake関数(汎用)
    • juice関数(学習データ専用)

の一連の流れを担っていました。学習と適用の分割は、テストへのリークが発生対策です。標準化やPCAを行うとして、そのパラメータは学習データから決めようというわけですね。しかし、学習と適用はworkflowsパッケージに任せるのが最新式なようです。

If you are using a recipe as a preprocessor for modeling, we highly recommend that you use a workflow() instead of manually estimating a recipe (see the example in [recipe()]). 2021-06-29の更新

これは早急にworkflowsの使い方を学ばねばなりませんね。

tidymodelsを使って線型モデルを学習する例

基本の流れはこんな感じ。

  1. 訓練データとテストデータの分割方法の定義実行(rsampleパッケージ)
  2. 前処理方法の定義(recipesパッケージ)
  3. モデルの定義(parsnipパッケージ)
  4. 前処理・モデルの統合と実行(workflowsパッケージ)

なんとなくですが、詳細な定義はrecipesやparsnipでやって、実行はworkflowsっていう流れみたいです。将来的には、2値分類の閾値変更などの後処理も担うそうです。一方でデータの分割は実行も自身でやるのが気になるところですが、今のところ、workflowsパッケージが分割の実行を担うことはなさそうです(検索結果)。

library(magrittr)

set.seed(1L)

# データの分割
split <- ggplot2::diamonds %>%
  dplyr::select(where(is.numeric)) %>%
  rsample::initial_split(prop = .9)
training_data <-rsample::training(split)
testing_data <- rsample::testing(split)

# 特徴量エンジニアリング方法の定義
preprocessor <- recipes::recipe(training_data, price ~ .) %>%
  recipes::step_center(recipes::all_numeric_predictors()) %>%
  recipes::step_scale(recipes::all_numeric_predictors())

# モデルの定義
spec <- parsnip::linear_reg() %>%
  parsnip::set_engine("lm")

# ワークフローの定義
wf <- workflows::workflow() %>%
  workflows::add_recipe(preprocessor) %>%
  workflows::add_model(spec)

# ワークフローの学習
trained <- generics::fit(wf, training_data)

ワークフローに特徴量エンジニアリングとモデリングの両方を追加していますが、どちらか一方でもいいですし、順序も問いません。自動的に特徴量エンジニアリング、モデリングの順になります。

学習に使ったgenerics::fit関数は内部的にはworkflows:::fit.workflowを呼んでいます。 tidymodelsにおいてはparsnipパッケージがfit関数をエクスポートしていて、parsnip::fit関数でも同様に処理できます。しかし、workflowsパッケージとparsnipパッケージの役割が混ざるので、genericsパッケージから呼びました。 workflowsパッケージにfitをエクスポートしてもらった方がいい気がしますね。

学習結果の調査

tidy関数を使うとモデルの学習結果や、特徴量エンジニアリングの概要を見れます。

# 学習結果
broom::tidy(trained, "model")
#> # A tibble: 7 x 5
#>   term        estimate std.error statistic   p.value
#>   <chr>          <dbl>     <dbl>     <dbl>     <dbl>
#> 1 (Intercept)   3928.       6.76    581.   0        
#> 2 carat         5067.      31.3     162.   0        
#> 3 depth         -294.       8.18    -36.0  6.99e-280
#> 4 table         -229.       7.23    -31.8  4.01e-219
#> 5 x            -1474.      49.0     -30.1  7.82e-197
#> 6 y               60.4     29.1       2.07 3.80e-  2
#> 7 z               34.3     31.2       1.10 2.71e-  1

# 特徴量エンジニアリングの概要
broom::tidy(trained, "recipe")
#> # A tibble: 2 x 6
#>   number operation type   trained skip  id          
#>    <int> <chr>     <chr>  <lgl>   <lgl> <chr>       
#> 1      1 step      center TRUE    FALSE center_bZ3xW
#> 2      2 step      scale  TRUE    FALSE scale_pTXgd

特徴量エンジニアリングについて詳しく見たい場合はworkflows::pull_workflow_preprocessor関数を使うらしい。返り値はrecipeクラスオブジェクト。

workflows::pull_workflow_prepped_recipe(trained) %>% class
#> [1] "recipe"

というわけで従来、recipesパッケージでやっていたように、broom::tidy関数でid引数を指定してやれば、たとえば中心化に使ったパラメータ(訓練データの各特徴量の平均値)が見れるはず。

workflows::pull_workflow_prepped_recipe(trained) %>%
  broom::tidy(id = "center_bZ3xW")
#> # A tibble: 6 x 3
#>   terms  value id          
#>   <chr>  <dbl> <chr>       
#> 1 carat  0.798 center_bZ3xW
#> 2 depth 61.8   center_bZ3xW
#> 3 table 57.5   center_bZ3xW
#> 4 x      5.73  center_bZ3xW
#> 5 y      5.73  center_bZ3xW
#> 6 z      3.54  center_bZ3xW

似たような名前のworkflows::pull_workflow_preprocessorもあるが、こいつは学習前の定義を取り出すだけなので注意。こいつ、いる……?

Session Info

sessioninfo::session_info(c("recipes", "parsnip", "workflows"))
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.1.0 (2021-05-18)
#>  os       Ubuntu 20.04.2 LTS          
#>  system   x86_64, linux-gnu           
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Etc/UTC                     
#>  date     2021-07-01                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version     date       lib source                               
#>  class         7.3-19      2021-05-03 [2] CRAN (R 4.1.0)                       
#>  cli           2.5.0       2021-04-26 [1] RSPM (R 4.1.0)                       
#>  codetools     0.2-18      2020-11-04 [2] CRAN (R 4.1.0)                       
#>  cpp11         0.3.1       2021-06-25 [1] RSPM (R 4.1.0)                       
#>  crayon        1.4.1       2021-02-08 [1] RSPM (R 4.1.0)                       
#>  dplyr         1.0.7       2021-06-18 [1] RSPM (R 4.1.0)                       
#>  ellipsis      0.3.2       2021-04-29 [1] RSPM (R 4.1.0)                       
#>  fansi         0.5.0       2021-05-25 [1] RSPM (R 4.1.0)                       
#>  generics      0.1.0       2020-10-31 [1] RSPM (R 4.1.0)                       
#>  globals       0.14.0      2020-11-22 [1] RSPM (R 4.1.0)                       
#>  glue          1.4.2       2020-08-27 [1] RSPM (R 4.1.0)                       
#>  gower         0.2.2       2020-06-23 [1] RSPM (R 4.1.0)                       
#>  hardhat       0.1.5       2020-11-09 [1] RSPM (R 4.1.0)                       
#>  ipred         0.9-11      2021-03-12 [1] RSPM (R 4.1.0)                       
#>  KernSmooth    2.23-20     2021-05-03 [2] CRAN (R 4.1.0)                       
#>  lattice       0.20-44     2021-05-02 [2] CRAN (R 4.1.0)                       
#>  lava          1.6.9       2021-03-11 [1] RSPM (R 4.1.0)                       
#>  lifecycle     1.0.0       2021-02-15 [1] RSPM (R 4.1.0)                       
#>  lubridate     1.7.10      2021-02-26 [1] RSPM (R 4.1.0)                       
#>  magrittr    * 2.0.1       2020-11-17 [1] RSPM (R 4.1.0)                       
#>  MASS          7.3-54      2021-05-03 [2] CRAN (R 4.1.0)                       
#>  Matrix        1.3-3       2021-05-04 [2] CRAN (R 4.1.0)                       
#>  nnet          7.3-16      2021-05-03 [2] CRAN (R 4.1.0)                       
#>  numDeriv      2016.8-1.1  2019-06-06 [1] RSPM (R 4.1.0)                       
#>  parsnip       0.1.6.9000  2021-07-01 [1] Github (tidymodels/parsnip@89f8f93)  
#>  pillar        1.6.1       2021-05-16 [1] RSPM (R 4.1.0)                       
#>  pkgconfig     2.0.3       2019-09-22 [1] RSPM (R 4.1.0)                       
#>  prettyunits   1.1.1       2020-01-24 [1] RSPM (R 4.1.0)                       
#>  prodlim       2019.11.13  2019-11-17 [1] RSPM (R 4.1.0)                       
#>  purrr         0.3.4       2020-04-17 [1] RSPM (R 4.1.0)                       
#>  R6            2.5.0       2020-10-28 [1] RSPM (R 4.1.0)                       
#>  Rcpp          1.0.6       2021-01-15 [1] RSPM (R 4.1.0)                       
#>  recipes       0.1.16.9000 2021-07-01 [1] Github (tidymodels/recipes@39bc4e8)  
#>  rlang         0.4.11      2021-04-30 [1] RSPM (R 4.1.0)                       
#>  rpart         4.1-15      2019-04-12 [2] CRAN (R 4.1.0)                       
#>  SQUAREM       2021.1      2021-01-13 [1] RSPM (R 4.1.0)                       
#>  survival      3.2-11      2021-04-26 [2] CRAN (R 4.1.0)                       
#>  tibble        3.1.2       2021-05-16 [1] RSPM (R 4.1.0)                       
#>  tidyr         1.1.3       2021-03-03 [1] RSPM (R 4.1.0)                       
#>  tidyselect    1.1.1       2021-04-30 [1] RSPM (R 4.1.0)                       
#>  timeDate      3043.102    2018-02-21 [1] RSPM (R 4.1.0)                       
#>  utf8          1.2.1       2021-03-12 [1] RSPM (R 4.1.0)                       
#>  vctrs         0.3.8       2021-04-29 [1] RSPM (R 4.1.0)                       
#>  withr         2.4.2       2021-04-18 [1] RSPM (R 4.1.0)                       
#>  workflows     0.2.2.9000  2021-07-01 [1] Github (tidymodels/workflows@8ad5a9d)
#> 
#> [1] /usr/local/lib/R/site-library
#> [2] /usr/local/lib/R/library