Skip to content

Commit

Permalink
balance updates
Browse files Browse the repository at this point in the history
  • Loading branch information
gperrett committed Jun 30, 2024
1 parent 1aafa45 commit d3fce32
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 36 deletions.
2 changes: 1 addition & 1 deletion thinkCausal/R/app_server.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ app_server <- function(input, output, session) {
store <- mod_analysis_variable_selection_server(module_ids$analysis$select, store)
store <- mod_analysis_verify_server(module_ids$analysis$verify, store)
store <- mod_analysis_visualize_server(module_ids$analysis$visualize, store)
store <- mod_analysis_balance_server(module_ids$analysis$balance, store)
#store <- mod_analysis_balance_server(module_ids$analysis$balance, store)
store <- mod_analysis_overlap_server(module_ids$analysis$overlap, store)
store <- mod_analysis_model_server(module_ids$analysis$model, store)
store <- mod_analysis_diagnostics_server(module_ids$analysis$diagnostics, store)
Expand Down
18 changes: 9 additions & 9 deletions thinkCausal/R/app_ui.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,10 @@ app_ui <- function(request) {
tabName = 'analysis_visualize',
mod_analysis_visualize_ui(module_ids$analysis$visualize)
),
bs4Dash::tabItem(
tabName = 'analysis_balance',
mod_analysis_balance_ui(module_ids$analysis$balance)
),
# bs4Dash::tabItem(
# tabName = 'analysis_balance',
# mod_analysis_balance_ui(module_ids$analysis$balance)
# ),
bs4Dash::tabItem(
tabName = 'analysis_overlap',
mod_analysis_overlap_ui(module_ids$analysis$overlap)
Expand Down Expand Up @@ -335,11 +335,11 @@ app_ui <- function(request) {
tabName = 'analysis_visualize',
icon = icon('chart-bar', verify_fa = FALSE)
),
bs4Dash::menuSubItem(
text = 'Check balance',
tabName = 'analysis_balance',
icon = icon('chart-bar', verify_fa = FALSE)
),
# bs4Dash::menuSubItem(
# text = 'Check balance',
# tabName = 'analysis_balance',
# icon = icon('chart-bar', verify_fa = FALSE)
# ),
bs4Dash::menuSubItem(
text = 'Check overlap',
tabName = 'analysis_overlap',
Expand Down
4 changes: 2 additions & 2 deletions thinkCausal/R/mod_analysis_define_causal_question.R
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,8 @@ mod_analysis_causal_question_server <- function(id, store){
choices = c(
"",
'Observational Study (Treatment not Randomized)',
'Completely Randomized Experement',
'Block Randomized Experement'
'Completely Randomized Experiment',
'Block Randomized Experiment'
)
),
HTML('<details><summary>Advanced options (random effects & survey weights)</summary>'),
Expand Down
2 changes: 1 addition & 1 deletion thinkCausal/R/mod_analysis_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ mod_analysis_model_server <- function(id, store){
.weights = store$column_assignments$weight,
ran_eff = store$column_assignments$ran_eff,
.estimand = base::tolower(input$analysis_model_estimand),
rct = store$analysis_select_design == 'Completely Randomized Experement'
rct = store$analysis_select_design == 'Completely Randomized Experiment'
)
store$analysis$model$model <- bart_model

Expand Down
146 changes: 123 additions & 23 deletions thinkCausal/R/mod_analysis_visualize.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,57 @@ mod_analysis_visualize_ui <- function(id){
width = 3,
collapsible = FALSE,
title = "Explore your data visually",

selectInput(
inputId = ns("analysis_eda_select_plot_type"),
label = "Plot type:",
multiple = FALSE,
choices = c("Scatter", "Histogram", "Barplot", "Density", "Boxplot"),
#"Pairs"
selected = "Scatter"
choices = c("Balance", "Scatter", "Histogram", "Barplot", "Density", "Boxplot")
),
conditionalPanel("input.analysis_eda_select_plot_type == 'Balance'",
ns = ns,
selectInput(
inputId = ns("analysis_balance_estimand"),
label = 'Check balance for the:',
choices = c('ATE', 'ATT', 'ATC'),
selected = 'ATE'
),
selectInput(
inputId = ns("analysis_balance_type"),
label = "Plot balance of:",
choices = c('means', 'variance', 'covariance'),
selected = 'means'
),
selectInput(inputId = ns('analysis_balance_select'),
label = 'Filter variables in balance plot:',
choices = c('Plot variables with most imbalance',
'Manually select variables to plot')
),
conditionalPanel(condition = "input.analysis_balance_select == 'Plot variables with most imbalance'",
ns = ns,
sliderInput(ns('analysis_balance_cat'),
label = 'Categorical variables in balance plot:',
value = 10,
min = 0,
max = 25),
sliderInput(ns('analysis_balance_cont'),
label = 'Continuous variables in balance plot:',
value = 10,
min = 0,
max = 25 )
),
conditionalPanel(condition = "input.analysis_balance_select == 'Manually select variables to plot'",
ns = ns,
selectInput(
inputId = ns("analysis_balance_select_var"),
label = "Select variables included in balance plot:",
multiple = TRUE,
choices = NULL,
selected = NULL
)
)
),
conditionalPanel(
condition = "input.analysis_eda_select_plot_type == 'Pairs'",
ns = ns,
selectInput(
inputId = ns("analysis_eda_variable_pairs_vars"),
label = "Columns to plot",
multiple = TRUE,
choices = NULL,
selected = NULL
)
),
conditionalPanel(
condition = "input.analysis_eda_select_plot_type != 'Pairs'",
condition = "input.analysis_eda_select_plot_type != 'Balance'",
ns = ns,
# selectInput(
# inputId = ns("analysis_eda_variable_x"),
Expand Down Expand Up @@ -205,7 +234,37 @@ mod_analysis_visualize_server <- function(id, store){

# next button
observeEvent(input$analysis_plots_descriptive_button_next, {
bs4Dash::updateTabItems(store$session_global, inputId = 'sidebar', selected = 'analysis_balance')
bs4Dash::updateTabItems(store$session_global, inputId = 'sidebar', selected = 'analysis_overlap')
})


# dynamic updates for balance plot
observeEvent(store$analysis$data$verify$analysis_verify_data_save,{
X <- store$verified_df
X <- clean_to_indicator(X)
treatment_col <- grep("^Z_", names(X), value = TRUE)

# get covariates
new_col_names <- colnames(clean_to_indicator(store$verified_df))
X_cols <- grep("^X_", new_col_names, value = TRUE)

# send them off to the UI
updateSelectInput(session = session,
inputId = 'analysis_balance_select_var',
choices = X_cols,
selected = NULL
)

updateSliderInput(session = session,
inputId = 'analysis_balance_cat',
max = ncol(X) - length(store$column_types$continuous)
)

updateSliderInput(session = session,
inputId = 'analysis_balance_cont',
max = length(store$column_types$continuous)
)

})

# update variables on the eda page once the save button on the verify data page is clicked
Expand All @@ -220,11 +279,11 @@ mod_analysis_visualize_server <- function(id, store){
choices = new_col_names,
selected = new_col_names
)
updateSelectInput(
session = session,
inputId = "analysis_eda_select_plot_type",
selected = store$analysis$data$verify$plot_vars$plot_type
)
# updateSelectInput(
# session = session,
# inputId = "analysis_eda_select_plot_type",
# selected = store$analysis$data$verify$plot_vars$plot_type
# )
# updateSelectInput(
# session = session,
# inputId = "analysis_eda_variable_x",
Expand Down Expand Up @@ -266,7 +325,6 @@ mod_analysis_visualize_server <- function(id, store){
inputId = "analysis_eda_variable_facet",
choices = c("None", cols_categorical)
)

# # update selects on balance plots
# X_cols <- grep("^X_", new_col_names, value = TRUE)
# X_cols_continuous <- grep("^X_", cols_continuous, value = TRUE)
Expand Down Expand Up @@ -330,6 +388,7 @@ mod_analysis_visualize_server <- function(id, store){
}
})


output$render_analysis_eda_variable_x <- renderUI({
new_col_names <- colnames(store$verified_df)
cols_categorical <- store$column_types$categorical
Expand Down Expand Up @@ -384,6 +443,46 @@ mod_analysis_visualize_server <- function(id, store){

# stop here if data hasn't been uploaded and selected
validate_data_verified(store)

if(input$analysis_eda_select_plot_type == 'Balance'){
# plot it
X <- store$verified_df
X <- clean_to_indicator(X)
treatment_col <- grep("^Z_", names(X), value = TRUE)
outcome_col <- grep("^Y_", names(X), value = TRUE)

if(input$analysis_balance_select == 'Plot variables with most imbalance'){
.confounders <- colnames(X)[colnames(X) %notin% c(treatment_col, outcome_col)]
# stop here if there are no columns selected
validate(need(length(.confounders) > 0,
"No columns available or currently selected"))
p <- plotBart::plot_balance(.data = X,
treatment = treatment_col,
confounders = .confounders,
compare = input$analysis_balance_type,
estimand = input$analysis_balance_estimand,
limit_catagorical = input$analysis_balance_cat + 1,
limit_continuous = input$analysis_balance_cont + 1
)
}else{
.confounders <- input$analysis_balance_select_var
# stop here if there are no columns selected
validate(need(length(.confounders) > 0,
"No columns available or currently selected"))
p <- plotBart::plot_balance(.data = X,
treatment = treatment_col,
confounders = .confounders,
compare = input$analysis_balance_type,
estimand = input$analysis_balance_estimand
)
}



# add theme
p <- p & store$options$theme_custom + ggplot2::theme(legend.position = 'none')

}else{
p <- tryCatch({
plot_exploration(
.data = store$verified_df,
Expand Down Expand Up @@ -411,6 +510,7 @@ mod_analysis_visualize_server <- function(id, store){

# add theme
p <- p + store$options$theme_custom
}

return(p)
})
Expand Down

0 comments on commit d3fce32

Please sign in to comment.