AFT_Bayes_LASSO: Bayesian variable selection in the AFT model via LASSO and DP prior for complex error structure
- This code implements the Bayesian variable selection method for the accelerated failure time (AFT) model with censored survival data, proposed in: Zhang, Z., Sinha, S., Maiti, T., and Shipp, E. (2016). Bayesian variable selection in the AFT model with an application to the SEER breast cancer data. Statistical Methods in Medical Research. To appear.
- Please contact the authors if there are any questions or implementation issues: Zhen Zhang,
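As a reading aid, the simulation used in the example that follows can be written out as below. This is read off the demo code rather than quoted from the paper, and the symbols $W_i$, $U_i$ and $S(t \mid Z)$ are notation introduced here for illustration only.

```latex
\begin{aligned}
\log T_i &= 1 + \mu_i + e^{-\mu_i^2/2}\,\frac{\log W_i}{C_0},
  \qquad \mu_i = Z_i^\top \beta,
  \qquad W_i \sim \mathrm{Weibull}(\text{scale } \lambda_0,\ \text{shape } k_0), \\
C_i &= Z_{i1} + Z_{i2} + U_i, \qquad U_i \sim \mathrm{Uniform}(0, K), \\
V_i &= \min(T_i, C_i), \qquad \Delta_i = \mathbf{1}(T_i \le C_i), \\
S(t \mid Z) &= P(T > t \mid Z)
  = \exp\!\left\{ -\left( \frac{e^{C_0 x}}{\lambda_0} \right)^{k_0} \right\},
  \qquad x = e^{\mu^2/2}\,(\log t - 1 - \mu).
\end{aligned}
```

The last line follows because $T > t$ is equivalent to $W > e^{C_0 x}$, and a Weibull variable with scale $\lambda_0$ and shape $k_0$ satisfies $P(W > w) = \exp\{-(w/\lambda_0)^{k_0}\}$. It is the curve stored in `trueY` and drawn in red in the demo plots.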
An example
function [] = demo()
id = 1;
rng('default'); rng(8*id)
n = 5e3;
q = 10;
p = 4;
inds = 1:p;
betas = zeros(q,1); betas(inds) = [0.5, 0.5, 0.35, -0.35]';
Z = binornd(1, 0.5, [n,q]);
mu = Z*betas;
labs = ones(1,n);
lambda0 = 29.2672; k0 = 0.8178; C0 = 1.57;
eps = log(wblrnd(lambda0, k0,[n,1]))/C0;
eps = exp(-0.5*mu.^2).*eps;
K = 78;
T = exp(1 + mu + eps);
C = Z(:,1) + Z(:,2) + unifrnd(0,K,[n,1]);
V = min(T,C);
Delta = (T<=C);
fprintf('For this simulation, the proportion of non-censoring: %3.4f\n', mean(Delta))
n0 = 4;
Z0 = zeros(n0, q);
Z0(1, 1:p) = [1,1,1,0];
Z0(2, 1:p) = [1,0,1,1];
Z0(3, 1:p) = [1,1,1,1];
Z0(4, 1:p) = [0,0,0,1];
t0 = 0:0.2:60; len = numel(t0); trueY = zeros(len, n0);
% True survival curve of each example group, evaluated on the grid t0
for i = 1:n0
    trueMu = Z0(i,:)*betas;
    x = exp(0.5*trueMu^2)*(log(t0)-1-trueMu);
    trueY(:,i) = exp(- (exp(C0*x)/lambda0).^k0);
end
Trues.eps = eps; Trues.T = T; Trues.C = C;
Trues.beta = betas; Trues.inds = inds;
Trues.labs = labs;
tot = 200; burn = 100;
N =150;
nch = 3;
myseeds = [1,2,3];
ngammas = 2;
niter = tot-burn;
nsample = niter*nch;
matpara = zeros(nsample, q*2+1+ngammas+1 + 1);
Ps = zeros(nsample, N);
Theta = zeros(nsample, N*2);
initBetas = Trues.beta;
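% Run the chains one after another and stack their draws row-wise
for ch = 1:nch
    % Print the chain label first so it prefixes the timing line printed by the sampler
    fprintf('\nMCMC chain %d: ',ch)
    i0 = (ch-1)*niter + (1:niter);
    [matpara(i0,:), Theta(i0,:), Ps(i0,:)] = AFT_Bayes_LASSO(V, Delta, Z, N, tot, burn, initBetas, myseeds(ch));
end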
betaSamples = matpara(:, 1:q)';
lb = quantile(betaSamples,0.025,2); ub = quantile(betaSamples,0.975,2);
fprintf('\nTrue, posterior mean, lower and upper bound of 95%% credible interval, coverage of the true:\n')
disp(num2str([Trues.beta,mean(betaSamples,2), lb, ub, (lb<Trues.beta).*(ub>Trues.beta)==1], 3))
mu0 = Z0*betaSamples;
nu0 = zeros(n0, nsample);
spm = nan(len, n0); spl = nan(len, n0); spu = nan(len, n0);
fprintf('\nThe design matrix for some example groups for survival probability:\n')
disp(Z0)
for i = 1:n0
    % Scale term nu0 of group i under each posterior draw
    for j = 1:nsample
        nu0(i, j) = exp( [mu0(i,j), mu0(i,j).^2]*matpara(j, (2*q+1)+(1:ngammas))' );
    end
    % Posterior mean and 95% band of the survival probability at each time point
    for t = 1:len
        t1 = t0(t);
        tmp = 1 - sum(Ps.*normcdf(...
            ( repmat( ((log(t1) - mu0(i,:))./sqrt(nu0(i,:)))', [1,N]) - Theta(:,1:N) )./sqrt(Theta(:,N+(1:N)))...
            ), 2);
        spm(t, i) = mean(tmp); spl(t, i) = quantile(tmp, .025); spu(t, i) = quantile(tmp, .975);
    end
    subplot(2,2,i); plot(t0, trueY(:,i), 'r-')
    xlim([Z0(i,:)*betas+1, 60])
    hold on
    % Proportion of subjects in group i whose observed time V is at least j
    for j = t0
        plot(j, mean(V(~( sum((Z(:,1:p) - repmat(Z0(i,1:p), [n,1])).^2, 2) ))>=j), 'k.')
    end
    plot(t0, spm(:,i), 'b-')
    plot(t0, spl(:,i), 'b--')
    plot(t0, spu(:,i), 'b--')
    hold off
end
disp('Demo completed. For full results, need to increase tot for achieving reasonable mixing rates and convergence.')
For this simulation, the proportion of non-censoring: 0.7090
MCMC chain 1: 200 iterations are done with elapsed time 0.14 minutes.
MCMC chain 2: 200 iterations are done with elapsed time 0.13 minutes.
MCMC chain 3: 200 iterations are done with elapsed time 0.13 minutes.
True, posterior mean, lower and upper bound of 95% credible interval, coverage of the true:
0.5 0.499 0.494 0.503 1
0.5 0.499 0.494 0.503 1
0.35 0.349 0.342 0.356 1
-0.35 -0.352 -0.357 -0.347 1
0 0.00363 -0.00462 0.0115 1
0 -0.000818 -0.00794 0.0036 1
0 -0.00223 -0.00764 0.00244 1
0 -0.000749 -0.00521 0.0048 1
0 -0.00105 -0.00609 0.00368 1
0 0.0019 -0.00144 0.00574 1
The design matrix for some example groups for survival probability:
1 1 1 0 0 0 0 0 0 0
1 0 1 1 0 0 0 0 0 0
1 1 1 1 0 0 0 0 0 0
0 0 0 1 0 0 0 0 0 0
Demo completed. For full results, need to increase tot for achieving reasonable mixing rates and convergence.
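The closing message asks for a larger tot so that the chains mix and converge. One common way to check this across several chains is the Gelman-Rubin potential scale reduction factor. This is a generic diagnostic, not something the repository computes, and the sketch below runs on placeholder random numbers rather than real sampler output; the names `stacked`, `draws` and `rhat` are hypothetical.

```matlab
% Gelman-Rubin potential scale reduction factor for one parameter.
% Placeholder data: iid normal draws standing in for one column of matpara.
rng(1);
niter = 100; nch = 3;                    % post-burn-in draws per chain, number of chains
stacked = randn(niter*nch, 1);           % chains stacked row-wise, as in the demo
draws = reshape(stacked, [niter, nch]);  % one column per chain

W = mean(var(draws, 0, 1));              % mean within-chain variance
B = niter * var(mean(draws, 1));         % between-chain variance
varPlus = (niter - 1)/niter * W + B/niter;
rhat = sqrt(varPlus / W);                % values near 1 suggest the chains agree
fprintf('R-hat: %.3f\n', rhat);
```

With the demo's variables, `reshape(matpara(:,k), [niter, nch])` would give the same layout for the k-th parameter, because the demo stores chain ch in rows `(ch-1)*niter + (1:niter)`.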
Main Program for Bayesian MCMC runs
function [matpara, Theta, Ps] = AFT_Bayes_LASSO(V, Delta, Z, N, tot, burn, init_beta, randomSeed)
invtau2_eta = 1/1e-3;
phi1 = 0; phi2 = 2.5; phi3 = 1;
mean_IG = 1; var_IG = .01;
phi4 = 2+mean_IG^2/var_IG;
phi5 = mean_IG*(phi4-1);
mean_gamma = 0; var_gamma2 = 1e4;
lambda_mean = 0.1; lambda_var = 5;
b_lambda = lambda_var/lambda_mean;
a_lambda = lambda_mean/b_lambda;
b_lambda = 1/b_lambda;
alpha_mean = 1; alpha_var = 10;
b_alpha = alpha_var/alpha_mean;
a_alpha = alpha_mean/b_alpha;
verbose = 0;
reportrate = 0;
ngammas = 2;
[n,q] = size(Z);
cind = find(Delta==0);
L_cind = length(cind);
rng('default'); rng(randomSeed)
lTstar = log( V + unifrnd(0,1,[n,1]).*(Delta==0) );
alphas = 10;
lambda2 = gamrnd(a_lambda, 1/b_lambda);
eps_eta = normrnd(0,sqrt(1/invtau2_eta),[n,1]);
betas = init_beta;
gammas = zeros(1,ngammas);
ps = 1/N*ones(1,N);
labs = ones(1,n);
u = ones(1,q);
zeta = 1;
theta = zeros(N,2);
theta(:,2) = 1./gamrnd(phi2, 1/phi3, [N,1]);
theta(:,1) = normrnd(phi1*ones(N,1), sqrt(zeta*theta(:,2)));
mu = Z*betas;
eta = mu + eps_eta;
nu = exp([eta,eta.^2]*gammas');
fulltheta = theta(labs, 1);
fulltheta2 = 1./( nu.*theta(labs, 2) );
objrate = 0.44;
batchLen = min(50, tot); batchNum = 0;
batchTot = floor(tot/batchLen); % whole number of batches, so zeros() below gets an integer size
nlen = ngammas + n;
accepts = zeros(1, nlen);
rates = zeros(batchTot, nlen);
tunings = [-4*ones(1,ngammas), -0.5*ones(1,n)];
matpara = zeros(tot-burn, q*2+1+ngammas+1 + 1);
Ps = zeros(tot-burn, N);
Theta = zeros(tot-burn, N*2);
tic % start the timer read by toc after the loop
for iter = 1:tot
if verbose == 1
    fprintf('%6d', iter)
end
Ztild = repmat(sqrt(fulltheta2+invtau2_eta),[1,q]).*Z;
Sigma = diag(u) + Ztild'*Ztild;
Sigma = chol(Sigma, 'lower');
Mu = Sigma\( Z'*((lTstar - fulltheta.*sqrt(nu)).*fulltheta2 + eta*invtau2_eta) );
betas = Sigma'\( randn(size(Mu)) + Mu );
betas = betas';
% Draw each u(j) from an inverse Gaussian distribution with mean u_mu and shape lambda2
for j = 1:q
    u0 = rand(1); u_mu = sqrt(lambda2)/abs(betas(j));
    y0 = (randn(1))^2; tlambda2 = 2*lambda2;
    unew = u_mu + u_mu^2*y0/tlambda2 - u_mu/tlambda2*sqrt(2*tlambda2*u_mu*y0 + u_mu^2*y0^2);
    if u0 <= u_mu/(u_mu+unew)
        u(j) = unew;
    else
        u(j) = u_mu^2/unew;
    end
end
lambda2 = gamrnd(a_lambda+q, 1./(b_lambda + 0.5*sum(u.^2)));
err = lTstar - Z*betas';
fulltheta2 = 1./theta(labs, 2);
% Random-walk Metropolis update of each gammas(j), the coefficients in lognu
for j = 1:ngammas
    gamma1 = gammas; gamma1(j) = normrnd(gammas(j), exp(tunings(j)));
    lognu = [eta,eta.^2]*gammas';
    lognu1 = [eta,eta.^2]*gamma1';
    loglik = - 0.5*sum((err./sqrt(exp(lognu)) - fulltheta).^2.*fulltheta2) - 0.5*sum(lognu) - (gammas(j)-mean_gamma)^2./var_gamma2;
    loglik1 = - 0.5*sum((err./sqrt(exp(lognu1)) - fulltheta).^2.*fulltheta2) - 0.5*sum(lognu1) - (gamma1(j)-mean_gamma)^2./var_gamma2;
    MH = exp(loglik1 - loglik);
    u0 = rand(1);
    if u0 <= MH
        gammas = gamma1; lognu = lognu1;
        accepts(j) = accepts(j)+1;
    end
end
nu = exp(lognu);
% Impute the censored log-times from a normal distribution truncated below at log(V)
if ~isempty(cind)
    snu = sqrt(nu(cind));
    R = rand([L_cind,1]); tmp = Z(cind,:)*betas' + snu.*fulltheta(cind);
    snu = snu./sqrt(fulltheta2(cind));
    lTstar(cind) = tmp + snu.*norminv( (1-R).*min(0.9999, normcdf((log(V(cind)) - tmp)./snu)) + R );
end
mu = Z*betas';
err = lTstar - mu;
lognu = [eta, eta.^2]*gammas';
eta1 = normrnd(eta, exp(tunings(ngammas+(1:n))'));
lognu1 = [eta1, eta1.^2]*gammas';
loglik = - 0.5*(err./sqrt(exp(lognu)) - fulltheta).^2.*fulltheta2 - lognu/2 - 0.5*invtau2_eta*(eta-mu).^2;
loglik1 = - 0.5*(err./sqrt(exp(lognu1)) - fulltheta).^2.*fulltheta2 - lognu1/2 - 0.5*invtau2_eta*(eta1-mu).^2;
MH = exp(loglik1 - loglik);
u0 = rand(n,1); inds = find(u0 <= MH);
% Accept the proposed eta for the subjects in inds and count the acceptances
if numel(inds) > 0
    eta(inds) = eta1(inds);
    lognu(inds) = lognu1(inds);
    accepts(ngammas+inds) = accepts(ngammas+inds)+1;
end
nu = exp(lognu);
err = err./sqrt(nu);
ps0 = repmat(ps, [n,1]).*normpdf( (repmat(err, [1,N]) - repmat(theta(:,1)', [n,1]))./repmat(sqrt(theta(:,2)'), [n,1]) );
ps0 = cumsum(ps0./repmat(sum(ps0,2), [1,N]), 2);
myu0 = rand([n,1]);
labs = sum(ps0 <= repmat(myu0, [1,N]), 2) + 1;
ms = histc(labs, 1:N);
ps = gamrnd(alphas/N + ms, 1);
ps = ps'./sum(ps);
tmpsum = 0.0;
for j = 1:N
    if ms(j) > 0
        % Occupied component: update its mean and variance given the assigned residuals
        var_theta = zeta/(1+ms(j)*zeta);
        mean_theta = var_theta*(phi1/zeta + sum(err(labs ==j)) );
        var_theta = theta(j,2)*var_theta;
        theta(j,1) = normrnd(mean_theta, sqrt(var_theta));
        theta(j,2) = 1./gamrnd(phi2 + 0.5*1 + 0.5*ms(j), ...
            1/(phi3 + 0.5*(theta(j,1)-phi1)^2/zeta + 0.5*sum((err(labs==j) - theta(j,1)).^2) ) );
        tmpsum = tmpsum + 0.5*(theta(j,1)-phi1)^2/theta(j,2);
    else
        % Empty component: redraw its variance and mean from the prior
        theta(j,2) = 1./gamrnd(phi2, 1/phi3 );
        theta(j,1) = normrnd(phi1, sqrt(zeta*theta(j,2)));
    end
end
zeta = 1/gamrnd(phi4 + 0.5*sum(ms>0), 1/(phi5 + tmpsum));
fulltheta = theta(labs,1);
fulltheta2 = 1./( nu.*theta(labs, 2) );
alpha1 = gamrnd(a_alpha, b_alpha);
sps = sum(log(ps));
loglik =gammaln(alphas) - N*gammaln(alphas/N) +(alphas/N-1)*sps;
loglik1 =gammaln(alpha1) - N*gammaln(alpha1/N) +(alpha1/N-1)*sps;
MH = exp(loglik1 - loglik);
u0 = rand(1);
if u0 <= MH
    alphas = alpha1;
end
% Store the draws once the burn-in period is over
if iter > burn
    iter0 = iter-burn;
    matpara(iter0,:) = [betas, u, lambda2, gammas, alphas, zeta];
    Theta(iter0,:) = reshape(theta, [1, 2*N]);
    Ps(iter0,:) = ps;
end
% At the end of each batch, nudge the log step sizes toward the target acceptance rate
if ~mod(iter, batchLen)
    batchNum = batchNum+1;
    accepts = accepts./batchLen;
    rates(batchNum,:) = accepts;
    if reportrate == 1
        disp(num2str( [ min(accepts(1:ngammas)), max(accepts(1:ngammas)), ...
            min(accepts(ngammas+(1:n))), max(accepts(ngammas+(1:n))) ], 2))
    end
    tunings = tunings + sign((accepts>objrate)-0.5).*min(0.01, 1/sqrt(batchNum));
    accepts = zeros(1,nlen);
end
end
runtime = toc/60;
fprintf('%d iterations are done with elapsed time %.2f minutes.\n', tot, runtime)
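One step of the sampler that is easy to misread is the update of `lTstar(cind)`: the censored log-times are redrawn from a normal distribution truncated below at `log(V)` by inverting the CDF. The sketch below isolates that idea with made-up scalar values. `Phi` and `PhiInv` are stand-ins for `normcdf` and `norminv`, written with base `erfc` and `erfcinv` so the snippet runs without the Statistics Toolbox; `m`, `s`, `lower` and `nDraw` are hypothetical names and numbers, not taken from the program.

```matlab
% Inverse-CDF draw from N(m, s^2) truncated to [lower, Inf).
Phi    = @(x) 0.5*erfc(-x/sqrt(2));      % standard normal CDF
PhiInv = @(p) -sqrt(2)*erfcinv(2*p);     % standard normal quantile function

rng(2);
m = 1.0; s = 0.5; lower = 1.2;           % illustrative mean, sd and censoring bound
nDraw = 1000;
R = rand(nDraw, 1);

% Probability mass below the bound; capped as in the program so that the
% argument of the quantile function stays strictly below 1.
pLow = min(0.9999, Phi((lower - m)/s));

% Map R uniformly onto (pLow, 1) and invert the CDF.
x = m + s*PhiInv((1 - R).*pLow + R);

fprintf('smallest draw: %.4f (bound %.4f)\n', min(x), lower);
```

Every draw lands at or above the bound, which is what a right-censored observation requires: the unobserved event time is known to exceed the recorded time `V`.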