diff --git a/DESCRIPTION b/DESCRIPTION index 447b5a1..1217f16 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -19,7 +19,8 @@ Imports: ggplot2, furrr, future, - pROC + pROC, + e1071 Suggests: testthat (>= 3.0.0), knitr, diff --git a/R/api.R b/R/api.R index 999b35a..221dddc 100644 --- a/R/api.R +++ b/R/api.R @@ -221,16 +221,17 @@ fit_models <- function(config,data) { contrasts <- config$analysis$contrasts results <- list() + if(is.null(contrasts)){ + stop("Logistic regression with multi-class outcome requires 'contrasts'") + } + + if(!is.list(contrasts)){ + stop("'contrasts' must be a list of length-2 vectors") + } + #Running models and seeing which to run if("logistic" %in% models){ - if(is.null(contrasts)){ - stop("Logistic regression with multi-class outcome requires 'contrasts'") - } - - if(!is.list(contrasts)){ - stop("'contrasts' must be a list of length-2 vectors") - } logistic_results <- list() @@ -241,7 +242,8 @@ fit_models <- function(config,data) { fit_logistic_contrast, data = data, outcome = outcome, - predictors = predictors + predictors = predictors, + .options = furrr::furrr_options(seed = TRUE) ) names(logistic_results) <- vapply( @@ -253,6 +255,29 @@ fit_models <- function(config,data) { results$logistic <- logistic_results } + if("svm" %in% models){ + svm_results <- list() + #The fitting function is called using furrr::future_pmap + svm_results <- furrr::future_pmap( + list( + contrast = contrasts + ), + fit_svm_contrast, + data = data, + outcome = outcome, + predictors = predictors, + .options = furrr::furrr_options(seed = TRUE) + ) + + names(svm_results) <- vapply( + contrasts, + function(x) paste0(x[[1]], "_vs_", x[[2]]), + character(1) + ) + + + results$svm <- svm_results + } if(length(results) == 0){ stop("No supported models requested") @@ -261,7 +286,7 @@ fit_models <- function(config,data) { } -#' Title fit_logistic_contrast +#' fit_logistic_contrast #' #' @param contrast #' @param data @@ -276,6 +301,7 @@ fit_logistic_contrast <- function(contrast, data, outcome, 
predictors) { group0 <- contrast[[1]] group1 <- contrast[[2]] + print(group1) subset_idx <- data[[outcome]] %in% c(group0, group1) sub_data <- data[subset_idx, , drop = FALSE] @@ -285,13 +311,19 @@ fit_logistic_contrast <- function(contrast, data, outcome, predictors) { } sub_data[[outcome]] <- ifelse(sub_data[[outcome]] == group1, 1, 0) + cat("Rows before NA removal:", nrow(sub_data), "\n") + + sub_data <- sub_data[complete.cases(sub_data[, c(outcome, predictors)]), ] + + cat("Rows after NA removal:", nrow(sub_data), "\n") if(length(unique(sub_data[[outcome]])) != 2){ stop("Contrast does not produce a binary outcome") } - fml <- stats::as.formula( - paste(outcome, "~", paste(predictors, collapse = " + ")) + fml <- stats::reformulate( + termlabels = predictors, + response = outcome ) fit <- stats::glm( @@ -305,6 +337,71 @@ fit_logistic_contrast <- function(contrast, data, outcome, predictors) { fitresult<- list( model = fit, predictions = preds, + truth=sub_data[[outcome]], + formula = fml, + contrast = c(group0, group1), + n = nrow(sub_data) + ) + return(fitresult) +} + +#' Title fit_SVM_contrast +#' +#' @param contrast +#' @param data +#' @param outcome +#' @param predictors +#' +#' @returns fitted list of SVM contrasts +#' +fit_svm_contrast<-function(contrast, data, outcome, predictors){ + if(length(contrast) != 2){ + stop("Each contrast must have exactly two outcome values") + } + + group0 <- contrast[[1]] + group1 <- contrast[[2]] + + subset_idx <- data[[outcome]] %in% c(group0, group1) + sub_data <- data[subset_idx, , drop = FALSE] + + if(nrow(sub_data)== 0){ + stop("No rows found for contrast ", group0, " vs ", group1) + } + + sub_data[[outcome]] <- ifelse(sub_data[[outcome]] == group1, 1, 0) + cat("Rows before NA removal:", nrow(sub_data), "\n") + + sub_data <- sub_data[complete.cases(sub_data[, c(outcome, predictors)]), ] + + cat("Rows after NA removal:", nrow(sub_data), "\n") + #Needs to be factor for classification and to get probabilities + 
sub_data[[outcome]] <- factor(sub_data[[outcome]]) + if(length(unique(sub_data[[outcome]])) != 2){ + stop("Contrast does not produce a binary outcome") + } + + fml <- stats::as.formula( + paste(outcome, "~", paste(predictors, collapse = " + ")) + ) + + fit <- e1071::svm( + formula = fml, + data = sub_data, + probability = TRUE + ) + #Predict gets run differently with different models + preds <- stats::predict(fit, sub_data, probability = TRUE) + #To get the predictions object consistent with the logistic regression predict format + prob_matrix <- attr(preds, "probabilities") + attr(preds, "probabilities") <- NULL + fitresult <- list( + model = fit, + predictions = list( + class = preds, + probabilities = prob_matrix + ), + truth=sub_data[[outcome]], formula = fml, contrast = c(group0, group1), n = nrow(sub_data) @@ -332,15 +429,29 @@ evaluate_models <- function(models){ if(length(models) == 0){ stop("models must contain at least one model", call. = FALSE) } + #Logistic regression is default, if statements in loop will deal with outputs of other models + #if they do not fit LR logic for(model_name in names(models)){ model_evals <- list() - + print(model_name) for(contrast_name in names(models[[model_name]])){ res <- models[[model_name]][[contrast_name]] preds <- res$predictions - truth <- res$model$y + truth <- res$truth + #########To handle the SVM + if(is.list(preds)){ + prob_matrix <- preds$probabilities + print(prob_matrix) + if(is.null(prob_matrix)){ + stop("Predictions list missing probability matrix.", call. 
= FALSE) + } + + positive_col <- which(colnames(prob_matrix) == "1") + preds <- prob_matrix[, positive_col] + } + ############## pred_class <- ifelse(preds >= 0.5, 1, 0) diff --git a/man/fit_SVM_contrast.Rd b/man/fit_SVM_contrast.Rd new file mode 100644 index 0000000..9e5549f --- /dev/null +++ b/man/fit_SVM_contrast.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/api.R +\name{fit_svm_contrast} +\alias{fit_svm_contrast} +\title{Title fit_SVM_contrast} +\usage{ +fit_svm_contrast(contrast, data, outcome, predictors) +} +\arguments{ +\item{contrast}{} + +\item{data}{} + +\item{outcome}{} + +\item{predictors}{} +} +\value{ +fitted list of SVM contrasts +} +\description{ +Title fit_SVM_contrast +} diff --git a/man/fit_logistic_contrast.Rd b/man/fit_logistic_contrast.Rd index bfc6e5f..fb4aed9 100644 --- a/man/fit_logistic_contrast.Rd +++ b/man/fit_logistic_contrast.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/api.R \name{fit_logistic_contrast} \alias{fit_logistic_contrast} -\title{Title fit_logistic_contrast} +\title{fit_logistic_contrast} \usage{ fit_logistic_contrast(contrast, data, outcome, predictors) } @@ -19,5 +19,5 @@ fit_logistic_contrast(contrast, data, outcome, predictors) fitted list of the logistic regressiong results } \description{ -Title fit_logistic_contrast +fit_logistic_contrast } diff --git a/package_build.R b/package_build.R index 2e5614e..f9ae68e 100644 --- a/package_build.R +++ b/package_build.R @@ -1,6 +1,6 @@ library(devtools) library(roxygen2) - +setwd("~/Rwork/TabularTools") #This creates a folder with the project files in it, only call once #create("tabularTools") diff --git a/runAnalysis.R b/runAnalysis.R index ebe0838..7053df6 100644 --- a/runAnalysis.R +++ b/runAnalysis.R @@ -15,9 +15,10 @@ validated=validate_data(cfg,data) pdata=preprocess_data(cfg,data) ### -plan(multisession, workers = 4) -lrmod=fit_models(cfg,data) plan(sequential) + +lrmod=fit_models(cfg,data) +plan(multisession, 
workers = 4) lrmod=fit_models(cfg,data) ### evals=evaluate_models(lrmod) diff --git a/tests/testthat/test-evaluate_models.R b/tests/testthat/test-evaluate_models.R index af87747..a80fa8a 100644 --- a/tests/testthat/test-evaluate_models.R +++ b/tests/testthat/test-evaluate_models.R @@ -17,7 +17,7 @@ test_that("evaluate_models returns expected structure", { glm = list( "0_vs_1" = list( predictions = c(0.2, 0.8), - model = list(y = c(0, 1)) + truth = c(0, 1) ) ) ) @@ -37,7 +37,7 @@ test_that("evaluate_models returns valid metrics", { glm = list( "0_vs_1" = list( predictions = c(0.1, 0.4, 0.7, 0.9), - model = list(y = c(0, 0, 1, 1)) + truth = c(0, 0, 1, 1) ) ) ) @@ -57,7 +57,7 @@ test_that("evaluate_models returns a confusion matrix", { glm = list( "0_vs_1" = list( predictions = c(0.3, 0.6), - model = list(y = c(0, 1)) + truth = c(0, 1) ) ) ) @@ -68,3 +68,43 @@ test_that("evaluate_models returns a confusion matrix", { expect_s3_class(conf, "table") expect_equal(sum(conf), 2) }) +#Multiple models at once +test_that("evaluate_models works with logistic and svm models together", { + + fake_models <- list( + logistic = list( + "0_vs_1" = list( + predictions = c(0.2, 0.8, 0.6, 0.1), + truth = c(0, 1, 1, 0) + ) + ), + svm = list( + "0_vs_1" = list( + predictions = c(0.3, 0.7, 0.4, 0.9), + truth = c(0, 1, 0, 1) + ) + ) + ) + + res <- evaluate_models(fake_models) + + # Top-level names + expect_named(res, c("logistic", "svm")) + + # Check structure for logistic + expect_named(res$logistic, "0_vs_1") + expect_named(res$logistic$`0_vs_1`, + c("metrics", "confusion", "roc")) + + # Check structure for svm + expect_named(res$svm, "0_vs_1") + expect_named(res$svm$`0_vs_1`, + c("metrics", "confusion", "roc")) + + # Check AUC values exist and are valid + expect_gte(res$logistic$`0_vs_1`$metrics$auc, 0) + expect_lte(res$logistic$`0_vs_1`$metrics$auc, 1) + + expect_gte(res$svm$`0_vs_1`$metrics$auc, 0) + expect_lte(res$svm$`0_vs_1`$metrics$auc, 1) +}) \ No newline at end of file