Skip to contents

Transform sample model output into a forecast object

Usage

transform_sample_model_out(
  model_out_tbl,
  oracle_output,
  compound_taskid_set = NULL
)

Arguments

model_out_tbl

Model output tibble containing the sample predictions to transform (rows with output_type "sample", as in the examples below).

oracle_output

Oracle output: the predictions that would have been generated by an oracle model that knew the observed target data values in advance.

compound_taskid_set

Character vector of task ID column names that stay constant within each sample draw (i.e., define the compound modeling task grouping). When NULL (the default), each modeling task is scored independently (marginal scoring). When provided, sample draws are treated as joint predictions across the task ID dimensions not in compound_taskid_set, and multivariate scoring metrics are used.

Value

A forecast_sample object (when compound_taskid_set is NULL) or a forecast_multivariate_sample object (when compound_taskid_set is provided).

Examples

# Marginal sample forecast: each modeling task scored independently
sample_forecast <- hubExamples::forecast_outputs |>
  dplyr::filter(.data[["output_type"]] == "sample") |>
  transform_sample_model_out(
    oracle_output = hubExamples::forecast_oracle_output
  )
sample_forecast
#> Forecast type: sample
#> Forecast unit:
#> model, reference_date, target, horizon, location, and target_end_date
#> 
#>       sample_id predicted observed             model reference_date
#>          <char>     <num>    <num>            <char>         <Date>
#>    1:      2101         0       79 Flusight-baseline     2022-11-19
#>    2:      2102         2       79 Flusight-baseline     2022-11-19
#>    3:      2103        52       79 Flusight-baseline     2022-11-19
#>    4:      2104        47       79 Flusight-baseline     2022-11-19
#>    5:      2105        56       79 Flusight-baseline     2022-11-19
#>   ---                                                              
#> 4796:      4396       978     1170          PSI-DICE     2022-12-17
#> 4797:      4397      1025     1170          PSI-DICE     2022-12-17
#> 4798:      4398      1040     1170          PSI-DICE     2022-12-17
#> 4799:      4399      1339     1170          PSI-DICE     2022-12-17
#> 4800:      4400      1175     1170          PSI-DICE     2022-12-17
#>                target horizon location target_end_date
#>                <char>   <int>   <char>          <Date>
#>    1: wk inc flu hosp       0       25      2022-11-19
#>    2: wk inc flu hosp       0       25      2022-11-19
#>    3: wk inc flu hosp       0       25      2022-11-19
#>    4: wk inc flu hosp       0       25      2022-11-19
#>    5: wk inc flu hosp       0       25      2022-11-19
#>   ---                                                 
#> 4796: wk inc flu hosp       3       48      2023-01-07
#> 4797: wk inc flu hosp       3       48      2023-01-07
#> 4798: wk inc flu hosp       3       48      2023-01-07
#> 4799: wk inc flu hosp       3       48      2023-01-07
#> 4800: wk inc flu hosp       3       48      2023-01-07

# Compound sample forecast: jointly score across non-compound task IDs
compound_forecast <- hubExamples::forecast_outputs |>
  dplyr::filter(.data[["output_type"]] == "sample") |>
  transform_sample_model_out(
    oracle_output = hubExamples::forecast_oracle_output,
    compound_taskid_set = c("reference_date", "location")
  )
compound_forecast
#> Forecast type: multivariate_sample
#> Forecast unit:
#> model, reference_date, target, horizon, location, and target_end_date
#> Joint across:
#> horizon and target_end_date
#> 
#>       sample_id predicted observed             model reference_date
#>          <char>     <num>    <num>            <char>         <Date>
#>    1:      2101         0       79 Flusight-baseline     2022-11-19
#>    2:      2102         2       79 Flusight-baseline     2022-11-19
#>    3:      2103        52       79 Flusight-baseline     2022-11-19
#>    4:      2104        47       79 Flusight-baseline     2022-11-19
#>    5:      2105        56       79 Flusight-baseline     2022-11-19
#>   ---                                                              
#> 4796:      4396       978     1170          PSI-DICE     2022-12-17
#> 4797:      4397      1025     1170          PSI-DICE     2022-12-17
#> 4798:      4398      1040     1170          PSI-DICE     2022-12-17
#> 4799:      4399      1339     1170          PSI-DICE     2022-12-17
#> 4800:      4400      1175     1170          PSI-DICE     2022-12-17
#>                target horizon location target_end_date .mv_group_id
#>                <char>   <int>   <char>          <Date>        <int>
#>    1: wk inc flu hosp       0       25      2022-11-19            1
#>    2: wk inc flu hosp       0       25      2022-11-19            1
#>    3: wk inc flu hosp       0       25      2022-11-19            1
#>    4: wk inc flu hosp       0       25      2022-11-19            1
#>    5: wk inc flu hosp       0       25      2022-11-19            1
#>   ---                                                              
#> 4796: wk inc flu hosp       3       48      2023-01-07           12
#> 4797: wk inc flu hosp       3       48      2023-01-07           12
#> 4798: wk inc flu hosp       3       48      2023-01-07           12
#> 4799: wk inc flu hosp       3       48      2023-01-07           12
#> 4800: wk inc flu hosp       3       48      2023-01-07           12