From d3fce322f8bf2317787d6859489ce737bc528fb2 Mon Sep 17 00:00:00 2001 From: George Perrett Date: Sun, 30 Jun 2024 14:26:00 -0400 Subject: [PATCH] balance updates --- thinkCausal/R/app_server.R | 2 +- thinkCausal/R/app_ui.R | 18 +-- .../R/mod_analysis_define_causal_question.R | 4 +- thinkCausal/R/mod_analysis_model.R | 2 +- thinkCausal/R/mod_analysis_visualize.R | 146 +++++++++++++++--- 5 files changed, 136 insertions(+), 36 deletions(-) diff --git a/thinkCausal/R/app_server.R b/thinkCausal/R/app_server.R index c4a5f7c..352b72f 100644 --- a/thinkCausal/R/app_server.R +++ b/thinkCausal/R/app_server.R @@ -70,7 +70,7 @@ app_server <- function(input, output, session) { store <- mod_analysis_variable_selection_server(module_ids$analysis$select, store) store <- mod_analysis_verify_server(module_ids$analysis$verify, store) store <- mod_analysis_visualize_server(module_ids$analysis$visualize, store) - store <- mod_analysis_balance_server(module_ids$analysis$balance, store) + #store <- mod_analysis_balance_server(module_ids$analysis$balance, store) store <- mod_analysis_overlap_server(module_ids$analysis$overlap, store) store <- mod_analysis_model_server(module_ids$analysis$model, store) store <- mod_analysis_diagnostics_server(module_ids$analysis$diagnostics, store) diff --git a/thinkCausal/R/app_ui.R b/thinkCausal/R/app_ui.R index 6c089a0..fa64b8a 100644 --- a/thinkCausal/R/app_ui.R +++ b/thinkCausal/R/app_ui.R @@ -162,10 +162,10 @@ app_ui <- function(request) { tabName = 'analysis_visualize', mod_analysis_visualize_ui(module_ids$analysis$visualize) ), - bs4Dash::tabItem( - tabName = 'analysis_balance', - mod_analysis_balance_ui(module_ids$analysis$balance) - ), + # bs4Dash::tabItem( + # tabName = 'analysis_balance', + # mod_analysis_balance_ui(module_ids$analysis$balance) + # ), bs4Dash::tabItem( tabName = 'analysis_overlap', mod_analysis_overlap_ui(module_ids$analysis$overlap) @@ -335,11 +335,11 @@ app_ui <- function(request) { tabName = 'analysis_visualize', icon = icon('chart-bar', verify_fa = FALSE) ), - bs4Dash::menuSubItem( - text = 'Check balance', - tabName = 'analysis_balance', - icon = icon('chart-bar', verify_fa = FALSE) - ), + # bs4Dash::menuSubItem( + # text = 'Check balance', + # tabName = 'analysis_balance', + # icon = icon('chart-bar', verify_fa = FALSE) + # ), bs4Dash::menuSubItem( text = 'Check overlap', tabName = 'analysis_overlap', diff --git a/thinkCausal/R/mod_analysis_define_causal_question.R b/thinkCausal/R/mod_analysis_define_causal_question.R index 5cbb604..58b093d 100644 --- a/thinkCausal/R/mod_analysis_define_causal_question.R +++ b/thinkCausal/R/mod_analysis_define_causal_question.R @@ -370,8 +370,8 @@ mod_analysis_causal_question_server <- function(id, store){ choices = c( "", 'Observational Study (Treatment not Randomized)', - 'Completely Randomized Experement', - 'Block Randomized Experement' + 'Completely Randomized Experiment', + 'Block Randomized Experiment' ) ), HTML('
Advanced options (random effects & survey weights)'), diff --git a/thinkCausal/R/mod_analysis_model.R b/thinkCausal/R/mod_analysis_model.R index 3317856..9e3c157 100644 --- a/thinkCausal/R/mod_analysis_model.R +++ b/thinkCausal/R/mod_analysis_model.R @@ -331,7 +331,7 @@ mod_analysis_model_server <- function(id, store){ .weights = store$column_assignments$weight, ran_eff = store$column_assignments$ran_eff, .estimand = base::tolower(input$analysis_model_estimand), - rct = store$analysis_select_design == 'Completely Randomized Experement' + rct = store$analysis_select_design == 'Completely Randomized Experiment' ) store$analysis$model$model <- bart_model diff --git a/thinkCausal/R/mod_analysis_visualize.R b/thinkCausal/R/mod_analysis_visualize.R index ca2d054..b8d40c2 100644 --- a/thinkCausal/R/mod_analysis_visualize.R +++ b/thinkCausal/R/mod_analysis_visualize.R @@ -16,28 +16,57 @@ mod_analysis_visualize_ui <- function(id){ width = 3, collapsible = FALSE, title = "Explore your data visually", - selectInput( inputId = ns("analysis_eda_select_plot_type"), label = "Plot type:", multiple = FALSE, - choices = c("Scatter", "Histogram", "Barplot", "Density", "Boxplot"), - #"Pairs" - selected = "Scatter" + choices = c("Balance", "Scatter", "Histogram", "Barplot", "Density", "Boxplot") ), + conditionalPanel("input.analysis_eda_select_plot_type == 'Balance'", + ns = ns, + selectInput( + inputId = ns("analysis_balance_estimand"), + label = 'Check balance for the:', + choices = c('ATE', 'ATT', 'ATC'), + selected = 'ATE' + ), + selectInput( + inputId = ns("analysis_balance_type"), + label = "Plot balance of:", + choices = c('means', 'variance', 'covariance'), + selected = 'means' + ), + selectInput(inputId = ns('analysis_balance_select'), + label = 'Filter variables in balance plot:', + choices = c('Plot variables with most imbalance', + 'Manually select variables to plot') + ), + conditionalPanel(condition = "input.analysis_balance_select == 'Plot variables with most imbalance'", + ns = ns, + sliderInput(ns('analysis_balance_cat'), + label = 'Categorical variables in balance plot:', + value = 10, + min = 0, + max = 25), + sliderInput(ns('analysis_balance_cont'), + label = 'Continuous variables in balance plot:', + value = 10, + min = 0, + max = 25 ) + ), + conditionalPanel(condition = "input.analysis_balance_select == 'Manually select variables to plot'", + ns = ns, + selectInput( + inputId = ns("analysis_balance_select_var"), + label = "Select variables included in balance plot:", + multiple = TRUE, + choices = NULL, + selected = NULL + ) + ) + ), conditionalPanel( - condition = "input.analysis_eda_select_plot_type == 'Pairs'", - ns = ns, - selectInput( - inputId = ns("analysis_eda_variable_pairs_vars"), - label = "Columns to plot", - multiple = TRUE, - choices = NULL, - selected = NULL - ) - ), - conditionalPanel( - condition = "input.analysis_eda_select_plot_type != 'Pairs'", + condition = "input.analysis_eda_select_plot_type != 'Balance'", ns = ns, # selectInput( # inputId = ns("analysis_eda_variable_x"), @@ -205,7 +234,37 @@ mod_analysis_visualize_server <- function(id, store){ # next button observeEvent(input$analysis_plots_descriptive_button_next, { - bs4Dash::updateTabItems(store$session_global, inputId = 'sidebar', selected = 'analysis_balance') + bs4Dash::updateTabItems(store$session_global, inputId = 'sidebar', selected = 'analysis_overlap') + }) + + + # dynamic updates for balance plot + observeEvent(store$analysis$data$verify$analysis_verify_data_save,{ + X <- store$verified_df + X <- clean_to_indicator(X) + treatment_col <- grep("^Z_", names(X), value = TRUE) + + # get covariates + new_col_names <- colnames(clean_to_indicator(store$verified_df)) + X_cols <- grep("^X_", new_col_names, value = TRUE) + + # send them off to the UI + updateSelectInput(session = session, + inputId = 'analysis_balance_select_var', + choices = X_cols, + selected = NULL + ) + + updateSliderInput(session = session, + inputId = 'analysis_balance_cat', + max = ncol(X) - length(store$column_types$continuous) + ) + + updateSliderInput(session = session, + inputId = 'analysis_balance_cont', + max = length(store$column_types$continuous) + ) + }) # update variables on the eda page once the save button on the verify data page is clicked @@ -220,11 +279,11 @@ mod_analysis_visualize_server <- function(id, store){ choices = new_col_names, selected = new_col_names ) - updateSelectInput( - session = session, - inputId = "analysis_eda_select_plot_type", - selected = store$analysis$data$verify$plot_vars$plot_type - ) + # updateSelectInput( + # session = session, + # inputId = "analysis_eda_select_plot_type", + # selected = store$analysis$data$verify$plot_vars$plot_type + # ) # updateSelectInput( # session = session, # inputId = "analysis_eda_variable_x", @@ -266,7 +325,6 @@ mod_analysis_visualize_server <- function(id, store){ inputId = "analysis_eda_variable_facet", choices = c("None", cols_categorical) ) - # # update selects on balance plots # X_cols <- grep("^X_", new_col_names, value = TRUE) # X_cols_continuous <- grep("^X_", cols_continuous, value = TRUE) @@ -330,6 +388,7 @@ mod_analysis_visualize_server <- function(id, store){ } }) + output$render_analysis_eda_variable_x <- renderUI({ new_col_names <- colnames(store$verified_df) cols_categorical <- store$column_types$categorical @@ -384,6 +443,46 @@ mod_analysis_visualize_server <- function(id, store){ # stop here if data hasn't been uploaded and selected validate_data_verified(store) + + if(input$analysis_eda_select_plot_type == 'Balance'){ + # plot it + X <- store$verified_df + X <- clean_to_indicator(X) + treatment_col <- grep("^Z_", names(X), value = TRUE) + outcome_col <- grep("^Y_", names(X), value = TRUE) + + if(input$analysis_balance_select == 'Plot variables with most imbalance'){ + .confounders <- colnames(X)[colnames(X) %notin% c(treatment_col, outcome_col)] + # stop here if there are no columns selected + validate(need(length(.confounders) > 0, + "No columns available or currently selected")) + p <- plotBart::plot_balance(.data = X, + treatment = treatment_col, + confounders = .confounders, + compare = input$analysis_balance_type, + estimand = input$analysis_balance_estimand, + limit_catagorical = input$analysis_balance_cat + 1, + limit_continuous = input$analysis_balance_cont + 1 + ) + }else{ + .confounders <- input$analysis_balance_select_var + # stop here if there are no columns selected + validate(need(length(.confounders) > 0, + "No columns available or currently selected")) + p <- plotBart::plot_balance(.data = X, + treatment = treatment_col, + confounders = .confounders, + compare = input$analysis_balance_type, + estimand = input$analysis_balance_estimand + ) + } + + + + # add theme + p <- p & store$options$theme_custom + ggplot2::theme(legend.position = 'none') + + }else{ p <- tryCatch({ plot_exploration( .data = store$verified_df, @@ -411,6 +510,7 @@ mod_analysis_visualize_server <- function(id, store){ # add theme p <- p + store$options$theme_custom + } return(p) })