fusedTree

fusedTree is a prediction model that integrates a set of low-dimensional, established clinical variables with high-dimensional, noisy omics variables. It fits (generalized) linear regression models in each leaf node of a tree built on the clinical variables, combining the interpretability of a tree with the flexibility of leaf-specific regressions on the omics variables.
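In symbols (notation introduced here for illustration, in the simplest case where the clinical variables enter only through the tree): if observation i falls in leaf k(i) and has omics vector x_i, its linear predictor is

$$
\eta_i = c_{k(i)} + \mathbf{x}_i^\top \boldsymbol{\beta}_{k(i)},
\qquad
\mathbb{E}[y_i] = \eta_i \ \text{(continuous)},
\qquad
\Pr(y_i = 1) = \frac{1}{1 + e^{-\eta_i}} \ \text{(binary)},
$$

where c_k is a leaf-specific intercept and β_k holds the leaf-specific omics effects.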

Note: Tree construction must be done externally (e.g., with the rpart package in R).

For full methodological details, see the preprint.


Installation

# CRAN (when available)
install.packages("fusedTree")

# Development version from GitHub
remotes::install_github("JeroenGoedhart/fusedTree")

Example 1: Continuous Outcomes

We illustrate the model for a continuous response. The simulated response depends nonlinearly on the clinical variables and linearly on the omics variables.

library(fusedTree)
if (!requireNamespace("rpart", quietly = TRUE)) install.packages("rpart")
library(rpart)

1. Simulate Data

set.seed(10)
p       <- 5       # Number of omics variables
p_Clin  <- 5       # Number of clinical variables
N       <- 100     # Sample size

# Nonlinear function of clinical variables
g <- function(z) {
  15 * sin(pi * z[,1] * z[,2]) +
    10 * (z[,3] - 0.5)^2 +
    2  * exp(z[,4]) +
    2  * z[,5]
}

# Clinical and omics covariates
Z      <- as.data.frame(matrix(runif(N * p_Clin), nrow = N))
X      <- matrix(rnorm(N * p), nrow = N)
betas  <- c(1, -1, 3, 2, -2)

# Response: nonlinear clinical + linear omics + noise
Y      <- as.vector(g(Z) + X %*% betas + rnorm(N))

Thus, the response is the sum of a nonlinear clinical part and a separate linear omics part: the omics effects do not vary with the clinical variables. The omics regressions in the different leaf nodes of the tree should therefore benefit from strong fusion, i.e., from shrinking their leaf-specific effects toward each other.
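To check this numerically, the sketch below simulates a larger sample from the same model, subtracts the known clinical part g(Z), and regresses the remainder on the omics variables separately in two clinical subgroups. The split at the median of the first clinical variable is an arbitrary illustrative choice, and the names N_big, Z_big, X_big, Y_big and slopes exist only for this illustration. Both columns of slopes should be close to betas.

```r
# Sketch: in this simulation, the omics effects are the same in every clinical subgroup.
set.seed(1)
N_big <- 4000
p     <- 5
betas <- c(1, -1, 3, 2, -2)
g <- function(z) {
  15 * sin(pi * z[, 1] * z[, 2]) +
    10 * (z[, 3] - 0.5)^2 +
    2  * exp(z[, 4]) +
    2  * z[, 5]
}
Z_big <- as.data.frame(matrix(runif(N_big * 5), nrow = N_big))
X_big <- matrix(rnorm(N_big * p), nrow = N_big)
Y_big <- as.vector(g(Z_big) + X_big %*% betas + rnorm(N_big))

# Subgroups defined by a clinical variable (illustrative split at the median of V1)
high_Z1 <- Z_big[, 1] > median(Z_big[, 1])

# Within each subgroup, regress Y - g(Z) on the omics variables; drop the intercept
slopes <- sapply(c(low_Z1 = FALSE, high_Z1 = TRUE), function(s) {
  idx <- high_Z1 == s
  coef(lm(Y_big[idx] - g(Z_big[idx, ]) ~ X_big[idx, ]))[-1]
})
round(slopes, 2)  # both columns should be close to betas
```

If the omics effects did depend on the clinical variables (an interaction), the two columns would differ systematically, and strong fusion would then be inappropriate.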

2. Fit a Tree on Clinical Variables

dat    <- cbind.data.frame(Y, Z)
rp     <- rpart(
  Y ~ ., data = dat,
  control = rpart.control(xval = 5, minbucket = 10),
  model = TRUE
)
# Post-prune the tree at the CP value with the lowest cross-validated error
cp     <- rp$cptable[which.min(rp$cptable[, "xerror"]), "CP"]
Treefit <- prune(rp, cp = cp)

plot(Treefit, main = "Clinical-variable Tree")
text(Treefit, use.n = TRUE)

3. (Optional) Understanding the Design Matrices

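Before fitting the model, it helps to understand how fusedTree represents the data internally to enable leaf-specific regression. Each leaf node of the tree gets its own (generalized) linear regression model. To support this, two design matrices are constructed:

- Clinical: one indicator column per leaf node (N2, N6, N7 below), encoding leaf membership; these columns carry the leaf-specific intercepts.
- Omics: one column per combination of omics variable and leaf node (x1_N2, x1_N6, ...), equal to the omics variable for observations in that leaf and 0 otherwise.
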
These matrices are created automatically during model fitting, but you can inspect them yourself using the Dat_Tree() function:

Dat_fusedTree <- Dat_Tree(Tree = Treefit, X = X, Z = Z, LinVars = FALSE)

# Clinical design matrix: indicator for node membership
head(Dat_fusedTree$Clinical)
#>   N2 N6 N7
#> 1  0  1  0
#> 2  1  0  0
#> 3  0  1  0
#> 4  0  1  0
#> 5  1  0  0
#> 6  1  0  0

# Omics design matrix: omics data distributed across nodes
head(Dat_fusedTree$Omics)
#>           x1_N2      x1_N6 x1_N7     x2_N2     x2_N6 x2_N7      x3_N2
#> [1,]  0.0000000  1.0778926     0 0.0000000 -0.886788     0  0.0000000
#> [2,]  0.9317812  0.0000000     0 1.2711460  0.000000     0 -1.5233846
#> [3,]  0.0000000 -1.4607939     0 0.0000000 -1.605085     0  0.0000000
#> [4,]  0.0000000 -0.9060756     0 0.0000000  1.122273     0  0.0000000
#> [5,] -0.6803478  0.0000000     0 2.1584386  0.000000     0 -0.2874329
#> [6,]  1.0631660  0.0000000     0 0.4282466  0.000000     0 -0.4353083
#>           x3_N6 x3_N7       x4_N2      x4_N6 x4_N7      x5_N2      x5_N6 x5_N7
#> [1,]  1.1639675     0  0.00000000 -0.3121347     0  0.0000000 -0.8658204     0
#> [2,]  0.0000000     0 -0.69877530  0.0000000     0  0.8254939  0.0000000     0
#> [3,] -2.5183351     0  0.00000000 -2.6438498     0  0.0000000 -0.8001323     0
#> [4,] -0.7075292     0  0.00000000  0.8250224     0  0.0000000  0.9758301     0
#> [5,]  0.0000000     0  0.30692631  0.0000000     0  2.7000755  0.0000000     0
#> [6,]  0.0000000     0 -0.05803946  0.0000000     0 -0.1353896  0.0000000     0

Note: You do not need to create these matrices manually; fusedTree() builds them internally. Inspecting them can, however, help you understand how the model applies fusion across the leaf-specific regressions.
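The layout shown above can be reproduced in a few lines of base R. This is a sketch of that layout, not the source of Dat_Tree(): the helper make_leaf_design() is hypothetical, and the leaf labels are supplied by hand here, whereas Dat_Tree() derives them from the fitted tree.

```r
# Hypothetical helper reproducing the block layout of the two design matrices.
make_leaf_design <- function(X, leaf) {
  leaf     <- factor(leaf)
  Clinical <- model.matrix(~ leaf - 1)          # one 0/1 indicator column per leaf
  colnames(Clinical) <- levels(leaf)
  # For omics variable j and leaf k: X[, j] for observations in leaf k, 0 elsewhere
  Omics <- do.call(cbind, lapply(seq_len(ncol(X)), function(j) {
    block <- Clinical * X[, j]                  # multiplies each indicator column by X[, j]
    colnames(block) <- paste0("x", j, "_", levels(leaf))
    block
  }))
  list(Clinical = Clinical, Omics = Omics)
}

set.seed(2)
leaf_demo <- c("N6", "N2", "N6", "N6", "N2", "N7")   # leaf label per observation (given by hand)
X_demo    <- matrix(round(rnorm(6 * 2), 2), nrow = 6)  # two omics variables
des       <- make_leaf_design(X_demo, leaf_demo)
des$Clinical
des$Omics
```

Each row of des$Omics has exactly one nonzero entry per omics variable, in the column of the observation's leaf; this is what lets a single penalized regression estimate separate omics effects per leaf.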

4. Tune Hyperparameters

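First, create cross-validation folds with CVfoldsTree(). The folds are balanced with respect to the proportion of observations in each leaf node and, for binary and survival outcomes, also with respect to the outcome. Then tune the ridge penalty lambda and the fusion penalty alpha with PenOpt(), using these folds.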
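One simple way to obtain such balance is to shuffle the observations within each stratum (a leaf, or a leaf × outcome combination) and deal them out over the folds in turn. The helper balanced_folds() below is hypothetical and only illustrates this idea; CVfoldsTree() may use a different algorithm and returns its folds in its own format.

```r
# Hypothetical helper: fold label per observation, balanced within strata.
balanced_folds <- function(leaf, y = NULL, k = 5) {
  strata <- if (is.null(y)) factor(leaf) else interaction(leaf, y, drop = TRUE)
  fold   <- integer(length(leaf))
  for (s in levels(strata)) {
    idx <- which(strata == s)
    idx <- idx[sample.int(length(idx))]                 # shuffle within the stratum
    offset <- sample.int(k, 1)                          # random start, so strata do not all begin at fold 1
    fold[idx] <- (seq_along(idx) + offset) %% k + 1     # deal folds 1..k cyclically
  }
  fold
}

set.seed(3)
leaf_demo  <- sample(c("N2", "N6", "N7"), 100, replace = TRUE, prob = c(0.5, 0.3, 0.2))
y_demo     <- rbinom(100, size = 1, prob = 0.4)
folds_demo <- balanced_folds(leaf_demo, y_demo, k = 5)
table(leaf_demo, folds_demo)  # each leaf is spread (almost) evenly over the 5 folds
```

Within every leaf × outcome stratum the fold counts differ by at most one, so each training split keeps roughly the leaf and outcome proportions of the full data.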

set.seed(11)
folds <- CVfoldsTree(Y = Y, Tree = Treefit, Z = Z, model = "linear")

optPenalties <- PenOpt(
  Tree      = Treefit,
  X         = X,
  Y         = Y,
  Z         = Z,
  model     = "linear",
  lambdaInit = 10,
  alphaInit  = 10,
  loss      = "loglik",
  LinVars   = FALSE,
  folds     = folds,
  multistart = FALSE  # TRUE yields more stable but slower results
)
#> Tuning fusedTree with fusion penalty

optPenalties
#>       lambda        alpha 
#> 1.490862e-13 3.843843e+12

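As expected, the fusion penalty alpha is tuned to a very large value: the true omics effects are the same in every leaf, so full fusion fits the data well. The standard ridge penalty lambda is very small because the simulation setting is low-dimensional (5 omics variables, 100 samples).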
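To see why a large alpha makes the leaf-specific omics effects coincide, the sketch below fits leaf-specific ridge regressions with a squared-difference fusion penalty in closed form. The objective (a ridge penalty lambda on all omics effects plus alpha times the summed squared differences between leaves, with unpenalized leaf intercepts) is an assumption for illustration; fusedTree's exact penalty and estimation procedure are described in the preprint and may differ. fused_ridge() and spread() are hypothetical helpers.

```r
# Sketch (not fusedTree's code). Assumed objective:
#   ||Y - U c - W b||^2 + lambda * ||b||^2 + alpha * sum_j sum_{k<l} (b_jk - b_jl)^2
fused_ridge <- function(X, Y, leaf, lambda, alpha) {
  leaf <- factor(leaf)
  K    <- nlevels(leaf)
  p    <- ncol(X)
  U    <- model.matrix(~ leaf - 1)                                   # leaf intercepts (unpenalized)
  W    <- do.call(cbind, lapply(seq_len(p), function(j) U * X[, j]))  # x1_leaf1..x1_leafK, x2_..., ...
  Fk   <- K * diag(K) - matrix(1, K, K)                # b' Fk b = sum_{k<l} (b_k - b_l)^2
  P    <- matrix(0, K + p * K, K + p * K)
  P[-(1:K), -(1:K)] <- lambda * diag(p * K) + alpha * kronecker(diag(p), Fk)
  D     <- cbind(U, W)
  theta <- solve(crossprod(D) + P, crossprod(D, Y))    # penalized normal equations
  list(intercepts = theta[1:K],
       betas = matrix(theta[-(1:K)], nrow = K,
                      dimnames = list(levels(leaf), paste0("x", seq_len(p)))))
}

set.seed(4)
n_f    <- 300
leaf_f <- sample(c("N2", "N6", "N7"), n_f, replace = TRUE)
X_f    <- matrix(rnorm(n_f * 3), nrow = n_f)
Y_f    <- as.vector(c(N2 = 1, N6 = 3, N7 = 5)[leaf_f] + X_f %*% c(1, -1, 2) + rnorm(n_f))

# Largest difference between leaves in the estimated effect of any omics variable
spread <- function(b) max(apply(b, 2, function(col) diff(range(col))))
fit_free  <- fused_ridge(X_f, Y_f, leaf_f, lambda = 0.01, alpha = 0)    # separate leaf fits
fit_fused <- fused_ridge(X_f, Y_f, leaf_f, lambda = 0.01, alpha = 1e6)  # strong fusion
c(free = spread(fit_free$betas), fused = spread(fit_fused$betas))
```

With alpha = 0 each leaf gets its own noisy estimates; with a very large alpha the leaf estimates collapse to a common value, which mirrors the behaviour of the tuned model in this example.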

5. Fit fusedTree

fit <- fusedTree(
  Tree    = Treefit,
  X       = X,
  Y       = Y,
  Z       = Z,
  LinVars = FALSE,
  model   = "linear",
  lambda  = optPenalties[1],
  alpha   = optPenalties[2]
)
#> Fit fusedTree with fusion penalty

# View results
fit$Effects     # Leaf-specific intercepts (N2, N6, N7) and omics effects per leaf
#>         N2         N6         N7      x1_N2      x1_N6      x1_N7      x2_N2 
#>  9.1434885 11.5976204 18.2735435  0.6656824  0.6656824  0.6656824 -0.9519646 
#>      x2_N6      x2_N7      x3_N2      x3_N6      x3_N7      x4_N2      x4_N6 
#> -0.9519646 -0.9519646  3.1750430  3.1750430  3.1750430  1.7737451  1.7737451 
#>      x4_N7      x5_N2      x5_N6      x5_N7 
#>  1.7737451 -1.9979752 -1.9979752 -1.9979752
plot(fit$Tree)  # Underlying tree structure

fit$Pars        # Model parameters
#>        Model LinVar        Alpha       Lambda
#> alpha linear  FALSE 3.843843e+12 1.490862e-13

Because of the strong fusion penalty, the estimated omics effects are identical across leaf nodes (up to the displayed precision). However, they remain somewhat biased (e.g., 0.67 for x1 and 1.77 for x4, against true values of 1 and 2) because the tree captures the nonlinear structure in the clinical variables only roughly. Since the leaf-node-specific intercepts (representing the clinical contribution) and the omics effects are estimated jointly, bias in the intercepts propagates into the omics coefficients.

6. Predict on New Data

# Simulate test set
N_test <- 50
Z_test <- as.data.frame(matrix(runif(N_test * p_Clin), nrow = N_test))
X_test <- matrix(rnorm(N_test * p), nrow = N_test)
Y_test <- as.vector(g(Z_test) + X_test %*% betas + rnorm(N_test))

# Generate predictions
Preds <- predict(fit, newX = X_test, newY = Y_test, newZ = Z_test)
PMSE  <- mean((Preds$Resp - Preds$Ypred)^2)
PMSE
#> [1] 15.03962
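Conceptually, a prediction is the intercept of the leaf an observation falls into plus that leaf's omics linear predictor (passed through the inverse logit for binary outcomes). The sketch below assumes this structure and the naming used in fit$Effects above, with LinVars = FALSE; leaf_linear_predictor() is a hypothetical helper, not part of the fusedTree API, and the effect values are made up.

```r
# Hypothetical helper: linear predictor for one observation from a named effects
# vector laid out like fit$Effects (intercepts "N2", ...; omics effects "x1_N2", ...).
leaf_linear_predictor <- function(effects, x, leaf) {
  omics_names <- paste0("x", seq_along(x), "_", leaf)
  unname(effects[leaf] + sum(effects[omics_names] * x))
}

eff <- c(N2 = 9, N6 = 11, N7 = 18,
         x1_N2 = 0.5, x1_N6 = 0.5, x1_N7 = 0.5,
         x2_N2 = -1,  x2_N6 = -1,  x2_N7 = -1)

eta <- leaf_linear_predictor(eff, x = c(2, 1), leaf = "N6")  # 11 + 0.5 * 2 - 1 * 1 = 11
eta
plogis(eta - 11)  # for a logistic model, plogis() maps a linear predictor to a probability (0.5 here)
```

The PMSE above is then the mean squared difference between the observed test responses and such predictions.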

Example 2: Binary Outcome

Below is a short example showing how to use fusedTree for binary outcomes. We simulate a binary response using a logistic model, with omics effects shared across leaf nodes.

# Load package
library(fusedTree)
if (!requireNamespace("rpart", quietly = TRUE)) install.packages("rpart")

# Settings
set.seed(13)
N <- 300
p <- 5
p_Clin <- 5

# Simulate data
Z <- as.data.frame(matrix(runif(N * p_Clin), nrow = N))  # clinical variables
X <- matrix(rnorm(N * p), nrow = N)                      # omics variables
betas  <- c(1, -1, 3, 2, -2)
eta <- 15 * sin(pi * Z[,1] * Z[,2]) - 10 * (Z[,3] - 0.5)^2 -
       2 * exp(Z[,4]) - 2 * Z[,5] + X %*% betas
prob <- 1 / (1 + exp(-eta))
Y <- rbinom(N, size = 1, prob = prob)

# Fit tree using only clinical variables
dat <- data.frame(Y = Y, Z)
rp <- rpart::rpart(Y ~ ., data = dat,
                   control = rpart::rpart.control(xval = 10, minbucket = 10),
                   method = "class", model = TRUE)
# Post-prune at the CP value with the lowest cross-validated error (xerror)
cp <- rp$cptable[which.min(rp$cptable[, "xerror"]), "CP"]
Treefit <- rpart::prune(rp, cp = cp)
plot(Treefit)

We then tune the penalties and fit the fusedTree model:

# Create folds
set.seed(30)
folds <- CVfoldsTree(Y = Y, Tree = Treefit, Z = Z, 
                     model = "logistic", nrepeat = 1)

# Tune hyperparameters
optPenalties <- PenOpt(Tree = Treefit, X = X, Y = Y, Z = Z,
                       model = "logistic",
                       lambdaInit = 10, alphaInit = 10,
                       loss = "loglik",
                       LinVars = FALSE,
                       folds = folds,
                       multistart = TRUE) # slower
#> Tuning fusedTree with fusion penalty
optPenalties
#>      lambda       alpha 
#>   0.2211904 141.7153124

# Fit fusedTree
fit_bin <- fusedTree(Tree = Treefit, X = X, Y = Y, Z = Z,
                     LinVars = FALSE, model = "logistic",
                     lambda = optPenalties[1],
                     alpha = optPenalties[2],
                     verbose = TRUE) # prints progress of IRLS algorithm
#> Fit fusedTree with fusion penalty
#> Iteration  1   log likelihood equals:   -101.096
#> Iteration  2   log likelihood equals:    -85.313
#> Iteration  3   log likelihood equals:    -81.537
#> Iteration  4   log likelihood equals:    -81.180
#> Iteration  5   log likelihood equals:    -81.176
#> Iteration  6   log likelihood equals:    -81.176
#> Iteration  7   log likelihood equals:    -81.176
#> IRLS converged at iteration  7
fit_bin$Effects
#>         N2         N6         N7      x1_N2      x1_N6      x1_N7      x2_N2 
#> -1.5850973 -2.1764195  3.4338211  0.6738147  0.6737848  0.6949362 -0.3938863 
#>      x2_N6      x2_N7      x3_N2      x3_N6      x3_N7      x4_N2      x4_N6 
#> -0.3700743 -0.3992398  1.2956706  1.2868135  1.2815384  1.2211976  1.2232870 
#>      x4_N7      x5_N2      x5_N6      x5_N7 
#>  1.2347340 -1.2410471 -1.2507165 -1.2090639

Finally, we simulate test data and evaluate classification performance:

# Simulate test data
N_test <- 50
Z_test <- as.data.frame(matrix(runif(N_test * p_Clin), nrow = N_test))
X_test <- matrix(rnorm(N_test * p), nrow = N_test)
eta_test <- 15 * sin(pi * Z_test[,1] * Z_test[,2]) - 10 * (Z_test[,3] - 0.5)^2 -
            2 * exp(Z_test[,4]) - 2 * Z_test[,5] + X_test %*% betas
prob_test <- 1 / (1 + exp(-eta_test))
Y_test <- rbinom(N_test, size = 1, prob = prob_test)

# Predict
Preds <- predict(fit_bin, newX = X_test, newY = Y_test, newZ = Z_test)

# AUC
if (!requireNamespace("pROC", quietly = TRUE)) install.packages("pROC")
library(pROC)
#> Type 'citation("pROC")' for a citation.
#> 
#> Attaching package: 'pROC'
#> The following objects are masked from 'package:stats':
#> 
#>     cov, smooth, var
auc_result <- pROC::auc(Y_test, Preds$Ypred)
#> Setting levels: control = 0, case = 1
#> Setting direction: controls < cases
auc_result
#> Area under the curve: 0.9328
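The AUC can also be computed directly from ranks (the Mann-Whitney formulation): it is the probability that a randomly chosen case receives a higher predicted value than a randomly chosen control, with ties counted as one half. auc_rank() is a hypothetical helper; with the direction shown above (controls < cases), it should agree with pROC::auc(Y_test, Preds$Ypred).

```r
# AUC via the rank-sum (Mann-Whitney) formula; average ranks give ties weight 1/2.
auc_rank <- function(y, score) {
  r  <- rank(score)
  n1 <- sum(y == 1)  # cases
  n0 <- sum(y == 0)  # controls
  (sum(r[y == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

# 3 of the 4 case-control pairs are ordered correctly, so the AUC is 0.75
auc_rank(y = c(0, 0, 1, 1), score = c(0.1, 0.4, 0.35, 0.8))
```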

This example shows how to apply fusedTree to a binary outcome: leaf-specific logistic regressions are fitted with a fusion penalty, and the fitted model is used to predict new observations.

Summary

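fusedTree provides:

- leaf-specific (generalized) linear regressions on omics variables within a tree built on clinical variables;
- a fusion penalty (alpha) that shrinks the leaf-specific omics effects toward each other, alongside a standard ridge penalty (lambda);
- balanced cross-validation folds (CVfoldsTree()) and penalty tuning (PenOpt());
- linear and logistic models, with predictions for new data via predict().
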
See the paper for applications to survival outcomes and for further methodological details.