| Title: | Pathwise Estimation of Covariate Balancing Propensity Scores |
|---|---|
| Description: | Provides pathwise estimation of regularized logistic propensity score models using covariate balancing loss functions rather than maximum likelihood. Regularization paths are fit via the 'adelie' elastic-net solver with a 'glmnet'-like interface, yielding balancing weights that target covariate balance for the ATE and ATT. Under lasso penalization, lambda bounds the maximum covariate imbalance, so the regularization path traces a sequence of decreasing imbalance tolerances. For details, see Sverdrup & Hastie (2026) <doi:10.48550/arXiv.2602.18577>. |
| Authors: | Erik Sverdrup [aut, cre], Trevor Hastie [aut], James Yang [ctb] (Author of the bundled adelie C++ library.) |
| Maintainer: | Erik Sverdrup <[email protected]> |
| License: | MIT + file LICENSE |
| Version: | 0.0.3 |
| Built: | 2026-05-26 08:00:10 UTC |
| Source: | https://github.com/cran/balnet |
Fits regularized logistic regression models using covariate balancing loss functions, yielding balancing weights targeting the ATE, ATT, or treated/control means.
balnet(
  X,
  W,
  target = c("ATE", "ATT", "treated", "control"),
  sample.weights = NULL,
  max.imbalance = NULL,
  nlambda = 100L,
  lambda.min.ratio = 0.01,
  lambda = NULL,
  penalty.factor = NULL,
  groups = NULL,
  alpha = 1,
  standardize = TRUE,
  tol = 1e-07,
  maxit = as.integer(1e+05),
  verbose = FALSE,
  num.threads = 1L,
  ...
)
X |
A numeric matrix or data frame with pre-treatment covariates. |
W |
Treatment vector (0 = control, 1 = treated). |
target |
The target estimand. Default is "ATE". |
sample.weights |
Optional sample weights. If |
max.imbalance |
Optional upper bound on the standardized covariate imbalance. For lasso penalization
( |
nlambda |
Number of values for lambda. Default is 100. |
lambda.min.ratio |
Ratio of smallest to largest lambda. Default is 1e-2. |
lambda |
Optional |
penalty.factor |
Penalty factor per feature. Default is 1 (i.e., each feature receives the same penalty). If groups are specified, the penalty factors default to the square root of each group size. |
groups |
Optional list of group indices for group penalization. |
alpha |
Elastic net mixing parameter. Default is 1 (lasso), 0 corresponds to ridge.
For |
standardize |
Whether to standardize the input matrix. Should only be |
tol |
Coordinate descent convergence tolerance. Default is 1e-7. |
maxit |
Maximum number of coordinate descent iterations. Default is 1e5. |
verbose |
Whether to display information during fitting. Default is FALSE. |
num.threads |
Number of threads to use. Default is 1. |
... |
Additional internal arguments passed to the solver. |
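To make the roles of nlambda and lambda.min.ratio concrete, the following base R sketch builds a log-spaced, decreasing lambda grid. It assumes the usual 'glmnet' convention for such grids; the value of `lambda.max` is hypothetical, and the grid the package actually generates may differ.

```r
# Sketch of a glmnet-style lambda grid (assumed convention, not package code).
lambda.max <- 0.8          # hypothetical largest lambda
nlambda <- 100L            # number of grid points
lambda.min.ratio <- 0.01   # ratio of smallest to largest lambda

# Log-spaced values running from lambda.max down to lambda.max * lambda.min.ratio.
lambda.path <- exp(seq(
  log(lambda.max),
  log(lambda.max * lambda.min.ratio),
  length.out = nlambda
))

range(lambda.path)
```

Under lasso penalization each step down this grid corresponds to a tighter imbalance tolerance.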
This function aims to find balancing weights, based on logistic propensity scores, that balance the weighted covariate means to a target vector.
With lasso regularization (alpha = 1), balance holds approximately: the imbalance in each covariate is allowed an absolute slack of at most lambda.
For target = "ATE", two logistic models are fit, one per arm, and the weights in each arm are formed from the fitted propensity score for that arm.
For target = "ATT", the weights reweight the control units so that their covariate means match those of the treated units.
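The balancing property can be illustrated in base R without the package. The sketch below minimizes a calibration-type loss for the treated arm with `optim()`. The exact loss used by balnet is an assumption here, not taken from the package source. The gradient of this loss is the difference between the full-sample covariate means and the inverse-probability-weighted treated means, so an unpenalized minimizer balances the covariates exactly, and a lasso penalty would instead bound each gradient entry by lambda.

```r
set.seed(1)
n <- 500
X <- matrix(rnorm(n * 3), n, 3)
W <- rbinom(n, 1, plogis(0.5 * X[, 1] - 0.5 * X[, 2]))
Z <- cbind(1, X)  # design matrix with an intercept column

# Assumed calibration-type balancing loss for the treated arm.
loss <- function(theta) {
  eta <- drop(Z %*% theta)
  mean(W * exp(-eta) + (1 - W) * eta)
}

# Gradient: mean of (1 - W / e) * Z, where e = plogis(eta).
grad <- function(theta) {
  eta <- drop(Z %*% theta)
  colMeans((1 - W * (1 + exp(-eta))) * Z)
}

opt <- optim(rep(0, ncol(Z)), loss, grad, method = "BFGS",
             control = list(reltol = 1e-14, maxit = 1000))

# Inverse probability weights for the treated arm.
e.hat <- plogis(drop(Z %*% opt$par))
gamma <- W / e.hat

# Weighted treated covariate means minus the full-sample means.
imbalance <- colMeans(gamma * X) - colMeans(X)
max.abs.imbalance <- max(abs(imbalance))
max.abs.imbalance
```

A maximum likelihood fit with `glm()` would generally leave a nonzero imbalance on the same data, which is the motivation for using a balancing loss.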
A fit balnet object.
Sverdrup, Erik and Trevor Hastie. "balnet: Pathwise Estimation of Covariate Balancing Propensity Scores". arXiv preprint, arXiv:2602.18577, 2026.
# Simulate data with confounding.
n <- 2000
p <- 10
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1.5 + exp(X[, 2] + X[, 3])))
Y <- W + 2 * log(1 + exp(X[, 1] + X[, 2] + X[, 3])) + rnorm(n)

# Fit model targeting the ATE = E[Y(1)] - E[Y(0)].
# Two logistic models are fit: one for treated, one for control.
fit <- balnet(X, W, target = "ATE")

# Print path summary.
print(fit)

# Visualize the path.
plot(fit)
plot(fit, lambda = 0)

# Predict propensity scores at end of lambda path.
W.hat <- predict(fit, X, lambda = 0)

# Get balancing weights at end of lambda path.
ipw.weights <- balweights(fit, lambda = 0)

# Estimate ATE using balancing weights.
mean(Y * (ipw.weights$treated - ipw.weights$control))
Convenience method for extracting the estimated balancing weights.
Under unconfoundedness, these correspond to inverse probability weights (IPW)
for standard treatment effect estimands and are computed from the fitted
covariate balancing propensity scores.
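As a reminder of how inverse probability weights relate to propensity scores, here is a minimal base R sketch using the textbook formulas. The helper `ipw_weights()` is hypothetical and not part of the package. It uses a single propensity vector `e` for simplicity, whereas a dual-arm fit has one fitted model per arm, and the scaling of the weights returned by balweights() may differ.

```r
# Hypothetical helper: textbook IPW weights from propensity scores e = P(W = 1 | X).
ipw_weights <- function(e, W, target = c("ATE", "ATT")) {
  target <- match.arg(target)
  if (target == "ATE") {
    # One weight vector per arm.
    list(treated = W / e, control = (1 - W) / (1 - e))
  } else {
    # ATT: controls are reweighted by the odds e / (1 - e).
    (1 - W) * e / (1 - e)
  }
}

e <- c(0.5, 0.25, 0.8)
W <- c(1, 0, 1)

w.ate <- ipw_weights(e, W, "ATE")
w.att <- ipw_weights(e, W, "ATT")
```

With weights on this scale, an ATE estimate takes the form `mean(Y * (w.ate$treated - w.ate$control))`.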
balweights(object, lambda = NULL, ...)

## S3 method for class 'balnet'
balweights(object, lambda = NULL, ...)

## S3 method for class 'balweights.contrast'
print(x, ...)

## S3 method for class 'balweights.contrast'
summary(object, ...)

## S3 method for class 'cv.balnet'
balweights(object, lambda = "lambda.min", ...)
object |
A |
lambda |
Value(s) of the penalty parameter
|
... |
Additional arguments (currently ignored). |
x |
A |
Estimated balancing weights (for dual-arm fits, returns a list with entries for each arm).
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))

# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")

# Extract balancing weights over fit lambda sequence.
wts <- balweights(fit)

# Extract balancing weights at specified lambda.
wts <- balweights(fit, lambda = 0)
Extract coefficients from a balnet object.
## S3 method for class 'balnet'
coef(object, lambda = NULL, ...)

## S3 method for class 'coef.balnet.contrast'
print(x, ...)

## S3 method for class 'coef.balnet.contrast'
summary(object, ...)
object |
A |
lambda |
Value(s) of the penalty parameter
|
... |
Additional arguments (currently ignored). |
x |
A |
Estimated logistic coefficients (for dual-arm fits, returns a list with entries for each arm).
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))

# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")

# Extract coefficients over fit lambda sequence.
coefs <- coef(fit)

# Extract coefficients at specified lambda.
coefs <- coef(fit, lambda = 0)
Extract coefficients from a cv.balnet object.
## S3 method for class 'cv.balnet'
coef(object, lambda = "lambda.min", ...)
object |
A cv.balnet object. |
lambda |
The lambda to use. Defaults to the cross-validated lambda. |
... |
Additional arguments (currently ignored). |
Estimated logistic coefficients (for dual-arm fits, returns a list with entries for each arm).
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))

# Fit an ATT model.
cv.fit <- cv.balnet(X, W, target = "ATT")

# Extract coefficients at cross-validated lambda.
coefs <- coef(cv.fit)
Cross-validation for balnet.
cv.balnet(
  X,
  W,
  type.measure = c("balance.loss"),
  nfolds = 10,
  foldid = NULL,
  ...
)
X |
A numeric matrix or data frame with pre-treatment covariates. |
W |
Treatment vector (0: control, 1: treated). |
type.measure |
The loss to minimize for cross-validation. Default is balance loss. |
nfolds |
The number of folds used for cross-validation, default is 10. |
foldid |
An optional |
... |
Arguments for balnet. |
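When folds need to be reproducible or shared across several fits, a fold assignment can be built by hand. The sketch below assumes foldid follows the 'glmnet' convention of an integer vector with values in 1, ..., nfolds and one entry per observation. The description of foldid above is incomplete, so treat this as an assumption.

```r
set.seed(42)
n <- 100
nfolds <- 10

# Balanced random assignment of observations to folds (assumed foldid format).
foldid <- sample(rep(seq_len(nfolds), length.out = n))

table(foldid)
```

Such a vector would then be supplied through the foldid argument of cv.balnet.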
A fit cv.balnet object.
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))

# Fit an ATE model.
cv.fit <- cv.balnet(X, W)

# Print CV summary.
print(cv.fit)

# Plot at cross-validated lambda.
plot(cv.fit)

# Predict at cross-validated lambda.
W.hat <- predict(cv.fit, X)
Plot diagnostics for a balnet object.
Shows effective sample size (ESS) and percent bias reduction (PBR; reduction
in mean absolute imbalance) along the regularization path, computed from balancing
weights and normalized to percentages. The right-hand axis maps these values
to the coefficient of variation (CV) of the weights.
Supplying the lambda argument instead displays the standardized covariate imbalance,
computed using the balancing weights at the specified lambda.
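The path diagnostics can be reproduced by hand under common definitions. The sketch below assumes the Kish effective sample size, ESS = (sum w)^2 / sum(w^2), and takes PBR as the percentage reduction in mean absolute imbalance. The helper names and the example numbers are hypothetical, and the package may normalize differently. Under these definitions ESS as a percentage of n equals 100 / (1 + CV^2), which is one way a percentage axis can be mapped to the CV of the weights.

```r
# Kish effective sample size (assumed definition).
ess <- function(w) sum(w)^2 / sum(w^2)

# Coefficient of variation of the weights (population form).
cv <- function(w) sqrt(mean(w^2) / mean(w)^2 - 1)

# Percent bias reduction from unweighted to weighted imbalance (assumed definition).
pbr <- function(before, after) 100 * (1 - mean(abs(after)) / mean(abs(before)))

w <- c(1, 1, 2, 4)  # hypothetical weights
ess.pct <- 100 * ess(w) / length(w)

# ESS percentage and CV carry the same information.
all.equal(ess.pct, 100 / (1 + cv(w)^2))

# Hypothetical imbalances before and after weighting.
pbr(before = c(0.4, -0.2, 0.1), after = c(0.04, -0.02, 0.01))
```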
## S3 method for class 'balnet'
plot(x, lambda = NULL, groups = NULL, max = NULL, ...)
x |
A balnet object. |
lambda |
If NULL (default), diagnostics over the lambda path are shown. Otherwise, covariate balance at the provided lambda value is shown (if target = "ATE", lambda can be a 2-vector, one value for arm 0 and one for arm 1). |
groups |
Optional named list of contiguous covariate index ranges to
aggregate into a single variable before computing covariate imbalance
(e.g., |
max |
The number of covariates to display in covariate balance plot. Defaults to all covariates. |
... |
Additional arguments. |
Invisibly returns the information underlying the plot.
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))

# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")

# Plot the five covariates with the largest unweighted imbalance.
plot(fit, lambda = 0, max = 5)
Plot diagnostics for a cv.balnet object.
## S3 method for class 'cv.balnet'
plot(x, lambda = "lambda.min", ...)
x |
A cv.balnet object. |
lambda |
The lambda to use. Defaults to the cross-validated lambda. |
... |
Additional arguments. |
Invisibly returns the information underlying the plot.
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))

# Fit an ATT model.
cv.fit <- cv.balnet(X, W, target = "ATT")

# Plot at cross-validated lambda.
plot(cv.fit)
Predict using a balnet object.
## S3 method for class 'balnet'
predict(object, newdata, lambda = NULL, type = c("response", "link"), ...)
object |
A balnet object. |
newdata |
A numeric matrix. |
lambda |
Value(s) of the penalty parameter
|
type |
The type of predictions. Default is "response" (propensity scores). |
... |
Additional arguments (currently ignored). |
Estimated predictions (for dual-arm fits, returns a list with entries for each arm).
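The two prediction types are related through the logistic link. The base R sketch below uses hypothetical coefficients for a single logistic model at one lambda and shows that "response" is the inverse logit of "link". It does not call the package.

```r
# Hypothetical intercept and coefficients for one logistic model.
intercept <- -0.3
beta <- c(0.8, -0.5)

# Three new observations with two covariates each.
newdata <- matrix(c(0, 1, 2, 0, -1, 1), nrow = 3)

# type = "link": the linear predictor.
link <- drop(intercept + newdata %*% beta)

# type = "response": propensity scores via the inverse logit.
response <- plogis(link)

cbind(link, response)
```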
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))

# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")

# Predict propensity scores over fit lambda sequence.
W.hat <- predict(fit, X)
Predict using a cv.balnet object.
## S3 method for class 'cv.balnet'
predict(object, newdata, lambda = "lambda.min", type = c("response"), ...)
object |
A cv.balnet object. |
newdata |
A numeric matrix. |
lambda |
The lambda to use. Defaults to the cross-validated lambda. |
type |
The type of predictions. Default is "response" (propensity scores). |
... |
Additional arguments (currently ignored). |
Estimated predictions (for dual-arm fits, returns a list with entries for each arm).
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))

# Fit an ATT model.
cv.fit <- cv.balnet(X, W, target = "ATT")

# Predict propensity scores at cross-validated lambda.
W.hat <- predict(cv.fit, X)
Print a balnet object.
## S3 method for class 'balnet'
print(x, digits = max(3L, getOption("digits") - 3L), max = 3, ...)
x |
A balnet object. |
digits |
Number of digits to print. |
max |
Total number of rows to show from the beginning and end of the path |
... |
Additional print arguments. |
Invisibly returns the printed information.
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))

# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")

# Print path summary.
print(fit)
Print a cv.balnet object.
## S3 method for class 'cv.balnet'
print(x, digits = max(3L, getOption("digits") - 3L), ...)
x |
A cv.balnet object. |
digits |
Number of digits to print. |
... |
Additional print arguments. |
Invisibly returns the printed information.
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))

# Fit an ATT model.
cv.fit <- cv.balnet(X, W, target = "ATT")

# Print CV summary.
print(cv.fit)
Summarize a balnet object.
## S3 method for class 'balnet'
summary(object, ...)
object |
A balnet object. |
... |
Additional summary arguments. |
Returns the printed information.
Summarize a cv.balnet object.
## S3 method for class 'cv.balnet'
summary(object, ...)
object |
A cv.balnet object. |
... |
Additional summary arguments. |
Returns the printed information.