Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ Imports:
ggplot2,
furrr,
future,
pROC
pROC,
e1071
Suggests:
testthat (>= 3.0.0),
knitr,
Expand Down
137 changes: 124 additions & 13 deletions R/api.R
Original file line number Diff line number Diff line change
Expand Up @@ -221,16 +221,17 @@ fit_models <- function(config,data) {
contrasts <- config$analysis$contrasts

results <- list()
if(is.null(contrasts)){
stop("Logistic regression with multi-class outcome requires 'contrasts'")
}

if(!is.list(contrasts)){
stop("'contrasts' must be a list of length-2 vectors")
}

#Running models and seeing which to run
if("logistic" %in% models){

if(is.null(contrasts)){
stop("Logistic regression with multi-class outcome requires 'contrasts'")
}

if(!is.list(contrasts)){
stop("'contrasts' must be a list of length-2 vectors")
}

logistic_results <- list()

Expand All @@ -241,7 +242,8 @@ fit_models <- function(config,data) {
fit_logistic_contrast,
data = data,
outcome = outcome,
predictors = predictors
predictors = predictors,
.options = furrr::furrr_options(seed = TRUE)
)

names(logistic_results) <- vapply(
Expand All @@ -253,6 +255,29 @@ fit_models <- function(config,data) {

results$logistic <- logistic_results
}
if("svm" %in% models){
svm_results <- list()
#The fimctopm os called using
svm_results <- furrr::future_pmap(
list(
contrast = contrasts
),
fit_svm_contrast,
data = data,
outcome = outcome,
predictors = predictors,
.options = furrr::furrr_options(seed = TRUE)
)

names(svm_results) <- vapply(
contrasts,
function(x) paste0(x[[1]], "_vs_", x[[2]]),
character(1)
)


results$svm <- svm_results
}

if(length(results) == 0){
stop("No supported models requested")
Expand All @@ -261,7 +286,7 @@ fit_models <- function(config,data) {
}


#' Title fit_logistic_contrast
#' fit_logistic_contrast
#'
#' @param contrast
#' @param data
Expand All @@ -276,6 +301,7 @@ fit_logistic_contrast <- function(contrast, data, outcome, predictors) {

group0 <- contrast[[1]]
group1 <- contrast[[2]]
print(group1)

subset_idx <- data[[outcome]] %in% c(group0, group1)
sub_data <- data[subset_idx, , drop = FALSE]
Expand All @@ -285,13 +311,19 @@ fit_logistic_contrast <- function(contrast, data, outcome, predictors) {
}

sub_data[[outcome]] <- ifelse(sub_data[[outcome]] == group1, 1, 0)
cat("Rows before NA removal:", nrow(sub_data), "\n")

sub_data <- sub_data[complete.cases(sub_data[, c(outcome, predictors)]), ]

cat("Rows after NA removal:", nrow(sub_data), "\n")

if(length(unique(sub_data[[outcome]])) != 2){
stop("Contrast does not produce a binary outcome")
}

fml <- stats::as.formula(
paste(outcome, "~", paste(predictors, collapse = " + "))
fml <- stats::reformulate(
termlabels = predictors,
response = outcome
)

fit <- stats::glm(
Expand All @@ -305,6 +337,71 @@ fit_logistic_contrast <- function(contrast, data, outcome, predictors) {
fitresult<- list(
model = fit,
predictions = preds,
truth=sub_data[[outcome]],
formula = fml,
contrast = c(group0, group1),
n = nrow(sub_data)
)
return(fitresult)
}

#' Title fit_SVM_contrast
#'
#' @param contrast
#' @param data
#' @param outcome
#' @param predictors
#'
#' @returns fitted list of SVM contrasts
#'
fit_svm_contrast<-function(contrast, data, outcome, predictors){
if(length(contrast) != 2){
stop("Each contrast must have exactly two outcome values")
}

group0 <- contrast[[1]]
group1 <- contrast[[2]]

subset_idx <- data[[outcome]] %in% c(group0, group1)
sub_data <- data[subset_idx, , drop = FALSE]

if(nrow(sub_data)== 0){
stop("No rows found for contrast ", group0, " vs ", group1)
}

sub_data[[outcome]] <- ifelse(sub_data[[outcome]] == group1, 1, 0)
cat("Rows before NA removal:", nrow(sub_data), "\n")

sub_data <- sub_data[complete.cases(sub_data[, c(outcome, predictors)]), ]

cat("Rows after NA removal:", nrow(sub_data), "\n")
#Needs to be factor for classification and to get probabilties
sub_data[[outcome]] <- factor(sub_data[[outcome]])
if(length(unique(sub_data[[outcome]])) != 2){
stop("Contrast does not produce a binary outcome")
}

fml <- stats::as.formula(
paste(outcome, "~", paste(predictors, collapse = " + "))
)

fit <- e1071::svm(
formula = fml,
data = sub_data,
probability = TRUE
)
#Predict gets ran differently with different models
preds <- stats::predict(fit, sub_data, probability = TRUE)
#To get the predictions object consistent with the logistic regression predict format
prob_matrix <- attr(preds, "probabilities")
attr(preds, "probabilities") <- NULL
fitresult <- list(
model = fit,
predictions = list(
class = preds,
probabilities = prob_matrix
),
truth=sub_data[[outcome]],
formula = fml,
contrast = c(group0, group1),
n = nrow(sub_data)
Expand Down Expand Up @@ -332,15 +429,29 @@ evaluate_models <- function(models){
if(length(models) == 0){
stop("models must contain at least one model", call. = FALSE)
}
#Logistic regression is default, if statemens in loop will deal with outputs of other models
#if they do not fit LR logic
for(model_name in names(models)){
model_evals <- list()

print(model_name)
for(contrast_name in names(models[[model_name]])){
res <- models[[model_name]][[contrast_name]]

preds <- res$predictions
truth <- res$model$y
truth <- res$truth

#########To handle the SVM
if(is.list(preds)){
prob_matrix <- preds$probabilities
print(prob_matrix)
if(is.null(prob_matrix)){
stop("Predictions list missing probability matrix.", call. = FALSE)
}

positive_col <- which(colnames(prob_matrix) == "1")
preds <- prob_matrix[, positive_col]
}
##############

pred_class <- ifelse(preds >= 0.5, 1, 0)

Expand Down
23 changes: 23 additions & 0 deletions man/fit_SVM_contrast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/fit_logistic_contrast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package_build.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
library(devtools)
library(roxygen2)

setwd("~/Rwork/TabularTools")
#This creates a folder with the project files in it, only call once
#create("tabularTools")

Expand Down
5 changes: 3 additions & 2 deletions runAnalysis.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ validated=validate_data(cfg,data)
pdata=preprocess_data(cfg,data)

###
plan(multisession, workers = 4)
lrmod=fit_models(cfg,data)
plan(sequential)

lrmod=fit_models(cfg,data)
plan(multisession, workers = 4)
lrmod=fit_models(cfg,data)
###
evals=evaluate_models(lrmod)
Expand Down
46 changes: 43 additions & 3 deletions tests/testthat/test-evaluate_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ test_that("evaluate_models returns expected structure", {
glm = list(
"0_vs_1" = list(
predictions = c(0.2, 0.8),
model = list(y = c(0, 1))
truth = c(0, 1)
)
)
)
Expand All @@ -37,7 +37,7 @@ test_that("evaluate_models returns valid metrics", {
glm = list(
"0_vs_1" = list(
predictions = c(0.1, 0.4, 0.7, 0.9),
model = list(y = c(0, 0, 1, 1))
truth = c(0, 0, 1, 1)
)
)
)
Expand All @@ -57,7 +57,7 @@ test_that("evaluate_models returns a confusion matrix", {
glm = list(
"0_vs_1" = list(
predictions = c(0.3, 0.6),
model = list(y = c(0, 1))
truth = c(0, 1)
)
)
)
Expand All @@ -68,3 +68,43 @@ test_that("evaluate_models returns a confusion matrix", {
expect_s3_class(conf, "table")
expect_equal(sum(conf), 2)
})
#Multiple models at once
test_that("evaluate_models works with logistic and svm models together", {

fake_models <- list(
logistic = list(
"0_vs_1" = list(
predictions = c(0.2, 0.8, 0.6, 0.1),
truth = c(0, 1, 1, 0)
)
),
svm = list(
"0_vs_1" = list(
predictions = c(0.3, 0.7, 0.4, 0.9),
truth = c(0, 1, 0, 1)
)
)
)

res <- evaluate_models(fake_models)

# Top-level names
expect_named(res, c("logistic", "svm"))

# Check structure for logistic
expect_named(res$logistic, "0_vs_1")
expect_named(res$logistic$`0_vs_1`,
c("metrics", "confusion", "roc"))

# Check structure for svm
expect_named(res$svm, "0_vs_1")
expect_named(res$svm$`0_vs_1`,
c("metrics", "confusion", "roc"))

# Check AUC values exist and are valid
expect_gte(res$logistic$`0_vs_1`$metrics$auc, 0)
expect_lte(res$logistic$`0_vs_1`$metrics$auc, 1)

expect_gte(res$svm$`0_vs_1`$metrics$auc, 0)
expect_lte(res$svm$`0_vs_1`$metrics$auc, 1)
})