Fit the LiFE model.

Finds the weights for each fiber that best predict the directional
diffusion signal (dSig).

  [fit, w, R2] = feFitModel(M, dSig, fitMethod, lambda)

  dSig:      The diffusion-weighted signal measured at each voxel in each
             direction. These are extracted from the dwi data at some
             white-matter coordinates.
  M:         The LiFE diffusion model matrix, constructed by
             feConnectomeBuildModel.m.
  fitMethod:
   - 'bbnnls'           - DEFAULT and best; the fastest large-scale solver.
   - 'lsqnonneg'        - MATLAB's default non-negative least-squares solver (SLOW).
   - 'sgd', 'sgdnn'     - Stochastic gradient descent ('sgdnn' adds a non-negativity constraint).
   - 'sgdl1', 'sgdl1nn' - Stochastic gradient descent with an L1 penalty on the weights.
  lambda:    Regularization parameter passed to sgdL1; required by 'sgdl1'
             and 'sgdl1nn', unused by the other methods.

  fit:       Struct with the fitted weights (fit.weights), the fit
             parameters and summary results.
  w:         The fitted weights, one per column of M.
  R2:        R2 reported by the stochastic solvers; empty for 'lsqnonneg'
             and 'bbnnls'.

See also: feCreate.m, feConnectomeBuildModel.m, feGet.m, feSet.m

Example:
  fit = feFitModel(M, dSig, 'bbnnls');

Copyright (2013-2014), Franco Pestilli, Stanford University, pestillifranco@gmail.com.

Notes about the LiFE model:

The rows of the M matrix are nVoxels*nBvecs: the model predicts the
diffusion signal in each voxel for each direction.

The columns of the M matrix are nFibers + nVoxels. The diffusion signal in
each voxel is predicted as the weighted sum of the predictions of every
fiber that passes through that voxel, plus an isotropic (CSF) term.

In addition to M, we typically return dSig, the signal measured at each
voxel in each direction. These values are extracted from the dwi data
using the roiCoords.
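The notes above describe the layout of M. The toy sketch below uses made-up numbers and makes two assumptions. First, the rows are grouped voxel by voxel: all nBvecs directions of voxel 1, then all of voxel 2. Second, the isotropic term is a constant across directions. Both are illustrative choices only; feConnectomeBuildModel.m may order rows and model the isotropic term differently. The variable names fiberPred and isoPred are hypothetical.

```matlab
% Toy layout of a LiFE-style model matrix M (all numbers are made up).
nVoxels = 2; nBvecs = 3; nFibers = 2;

% Assumed row order: all nBvecs directions of voxel 1, then of voxel 2.
% Hypothetical per-direction predictions of each fiber in each voxel;
% fiber 1 passes through voxels 1 and 2, fiber 2 only through voxel 2.
fiberPred = zeros(nVoxels*nBvecs, nFibers);
fiberPred(1:3, 1) = [0.2; 0.5; 0.1];   % fiber 1, voxel 1
fiberPred(4:6, 1) = [0.3; 0.4; 0.2];   % fiber 1, voxel 2
fiberPred(4:6, 2) = [0.6; 0.1; 0.3];   % fiber 2, voxel 2

% Assumed isotropic (CSF) term: one column per voxel, constant across directions.
isoPred = kron(eye(nVoxels), ones(nBvecs, 1));

M = [fiberPred, isoPred];   % (nVoxels*nBvecs) rows, (nFibers + nVoxels) columns
w = [1; 2; 0.5; 0.25];      % two fiber weights, then two isotropic weights
dSigHat = M * w;            % predicted signal for every voxel/direction pair
disp(dSigHat')
```

For example, row 4 (first direction of voxel 2) is 0.3*1 + 0.6*2 + 0.25 = 1.75. Both fibers contribute there, plus voxel 2's isotropic weight. Fitting the model means solving for w given M and the measured dSig.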
```matlab
function [fit, w, R2] = feFitModel(M,dSig,fitMethod,lambda)
% Fit the LiFE model.
%
% Finds the weights for each fiber that best predict the directional
% diffusion signal (dSig).
%
%   [fit, w, R2] = feFitModel(M,dSig,fitMethod,lambda)
%
%   dSig:      The diffusion-weighted signal measured at each voxel in each
%              direction. These are extracted from the dwi data at some
%              white-matter coordinates.
%   M:         The LiFE diffusion model matrix, constructed by
%              feConnectomeBuildModel.m.
%   fitMethod:
%    - 'bbnnls'           - DEFAULT and best; the fastest large-scale solver.
%    - 'lsqnonneg'        - MATLAB's default non-negative least-squares solver (SLOW).
%    - 'sgd', 'sgdnn'     - Stochastic gradient descent ('sgdnn' adds non-negativity).
%    - 'sgdl1', 'sgdl1nn' - Stochastic gradient descent with an L1 penalty on the weights.
%   lambda:    Regularization parameter passed to sgdL1; required by
%              'sgdl1' and 'sgdl1nn', unused by the other methods.
%
%   fit:       Struct with the fitted weights, fit parameters and results.
%   w:         The fitted weights, one per column of M.
%   R2:        R2 reported by the stochastic solvers; empty for
%              'lsqnonneg' and 'bbnnls'.
%
% See also: feCreate.m, feConnectomeBuildModel.m, feGet.m, feSet.m
%
% Example:
%   fit = feFitModel(M,dSig,'bbnnls');
%
% Copyright (2013-2014), Franco Pestilli, Stanford University, pestillifranco@gmail.com.
%
% Notes about the LiFE model:
%
% The rows of the M matrix are nVoxels*nBvecs: we predict the diffusion
% signal in each voxel for each direction.
%
% The columns of the M matrix are nFibers + nVoxels. The diffusion signal
% in each voxel is predicted as the weighted sum of the predictions of
% every fiber that passes through that voxel, plus an isotropic (CSF) term.
%
% In addition to M, we typically return dSig, which is the signal measured
% at each voxel in each direction. These are extracted from the dwi data
% and knowledge of the roiCoords.

% Use the documented default solver when none is given.
if nargin < 3 || isempty(fitMethod)
  fitMethod = 'bbnnls';
end

% Fit the model by selecting the appropriate solver.
switch fitMethod
  case {'lsqnonneg'}
    tic % start the timer read by toc below
    fprintf('\nLiFE: Computing least-square minimization with LSQNONNEG...\n')
    options = optimset('lsqnonneg');
    w = lsqnonneg(M,dSig,options);
    fprintf(' ...fit process completed in %2.3fs\n',toc)
    R2 = [];

  case {'bbnnls'}
    tic
    fprintf('\nLiFE: Computing least-square minimization with BBNNLS...\n')
    opt = solopt;
    opt.maxit = 500;
    opt.use_tolo = 1;
    out_data = bbnnls(M,dSig,zeros(size(M,2),1),opt);
    fprintf('BBNNLS status: %s\nReason: %s\n',out_data.status,out_data.termReason);
    w = out_data.x;
    fprintf(' ...fit process completed in %2.3f minutes\n',toc/60)

    % Save the state of the random number generator.
    defaultStream = RandStream.getGlobalStream; % formerly RandStream.getDefaultStream
    fit.randState = defaultStream.State;

    % Save out some results.
    fit.results.R2        = [];
    fit.results.nParams   = size(M,2);
    fit.results.nMeasures = size(M,1);
    R2 = [];

  case {'sgd','sgdnn'} % Stochastic gradient descent, optionally non-negative
    tic
    % Stochastic gradient descent solves an L2 minimization problem,
    % with a non-negativity constraint for 'sgdnn'.
    %
    % It takes chunks of rows of the M matrix and solves them separately,
    % constraining the partial fits to obtain a consistent global solution.
    signalSiz = size(M,1);
    % Size of the chunks (number of rows) taken at every iteration of the solver.
    if signalSiz >= 1000000
      siz = floor(signalSiz * .1);
    elseif signalSiz > 10000
      siz = floor(signalSiz * .5);
    else
      siz = signalSiz;
    end
    stepSiz = 0.0124; % Step along the gradient; larger steps are more prone to local minima.
    stopCriteria = [.1 5 1]; % Stop signals:
                             % First, if total error has not decreased less than
                             % an XXX proportion of XXXX.
                             % Second, number of small partial fits before
                             % evaluating the quality of the large fit.
                             % Third, amount of R2 improvement judged to be
                             % useful. It used to be the percent improvement
                             % in R2 that is considered a change in quality
                             % of fit, e.g., 1 = 1%.
    n = 100; % Number of iterations after which to check the total error.
    nonneg = strcmpi(fitMethod,'sgdnn');
    fprintf('\nLiFE: Computing least-square minimization with Stochastic Gradient Descent...\n')

    % Save the state of the random number generator before fitting,
    % so that the stochastic fit can be recomputed.
    defaultStream = RandStream.getGlobalStream; % formerly RandStream.getDefaultStream
    fit.randState = defaultStream.State;

    % sgd(y,X,numtoselect,finalstepsize,convergencecriterion,checkerror,nonneg,alpha,lambda)
    [w, R2] = sgd(dSig,M,siz,stepSiz,stopCriteria,n,nonneg);

    % Save out the stochastic gradient descent parameters.
    fit.params.stepSiz      = stepSiz;
    fit.params.stopCriteria = stopCriteria;
    fit.params.numInters    = n;

    % Save out some results.
    fit.results.R2        = R2;
    fit.results.nParams   = size(M,2);
    fit.results.nMeasures = size(M,1);
    fprintf(' ...fit process completed in %2.3fs\n',toc)

  case {'sgdl1','sgdl1nn'} % Stochastic gradient descent with an L1 penalty
    tic
    % Stochastic gradient descent solving an L2 minimization problem with
    % an L1 penalty (set by lambda) on the weights. Both 'sgdl1' and
    % 'sgdl1nn' currently enforce non-negativity (nonneg = 1 below).
    %
    % It takes chunks of rows of the M matrix and solves them separately,
    % constraining the partial fits to obtain a consistent global solution.
    signalSiz = size(M,1);
    % Size of the chunks (number of rows) taken at every iteration of the solver.
    if signalSiz >= 1000000
      siz = floor(signalSiz * .1);
    elseif signalSiz > 10000
      siz = floor(signalSiz * .5);
    else
      siz = signalSiz;
    end
    stepSiz = 0.0124;        % Step along the gradient, as in the 'sgd' case.
    stopCriteria = [.1 5 1]; % Stop signals, as described in the 'sgd' case.
    n = 100;                 % Number of iterations after which to check the total error.
    nonneg = 1;
    fprintf('\nLiFE: Computing least-square minimization (L1) with Stochastic Gradient Descent...\n')

    % Save the state of the random number generator before fitting,
    % so that the stochastic fit can be recomputed.
    defaultStream = RandStream.getGlobalStream; % formerly RandStream.getDefaultStream
    fit.randState = defaultStream.State;

    %lambda = [length(dSig)*2.75];
    [w, R2] = sgdL1(dSig,M,siz,stepSiz,stopCriteria,n,nonneg,[],lambda);
    fprintf('Lambda: %2.2f | nFibers: %i | L1 penalty: %2.3f | L2 penalty: %2.3f\n', ...
      lambda, nnz(w>0), sum(w), sum(w.^2))

    % Save out the stochastic gradient descent parameters.
    fit.params.stepSiz      = stepSiz;
    fit.params.stopCriteria = stopCriteria;
    fit.params.numInters    = n;

    % Save out some results.
    fit.results.R2        = R2;
    fit.results.nParams   = size(M,2);
    fit.results.nMeasures = size(M,1);
    fit.results.l2        = sum(w.^2);
    fit.results.l1        = sum(w);

    fprintf(' ...fit process completed in %2.3fs\n',toc)

  otherwise
    error('Cannot fit LiFE model using method: %s.\n',fitMethod);
end

% Save the output structure.
fit.weights = w;
fit.params.fitMethod = fitMethod;

end
```
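The 'sgd' and 'sgdnn' solvers fit the weights from chunks of rows of M. The source of sgd.m is not shown here, so what follows is only a generic sketch of mini-batch projected stochastic gradient descent for min 0.5*||M*w - dSig||^2 subject to w >= 0. It is not the LiFE implementation. It runs on a synthetic, noise-free problem. The batch size, step size (eta) and iteration count are illustrative assumptions, not the values feFitModel uses.

```matlab
% Generic mini-batch projected SGD for non-negative least squares
% (an illustrative sketch, not the LiFE sgd.m implementation).
nRows = 200; nCols = 5;
M = rand(nRows, nCols);              % synthetic stand-in for the LiFE matrix
wTrue = [1; 0; 2; 0; 0.5];           % non-negative "ground truth" weights
dSig = M * wTrue;                    % noise-free synthetic signal

batchSiz = 50;                       % rows per chunk (assumed value)
eta = 0.5 / norm(M)^2;               % conservative step for the rescaled gradient
nIter = 3000;
w = zeros(nCols, 1);

for k = 1:nIter
  rows = randperm(nRows, batchSiz);  % random chunk of rows, without replacement
  Mb = M(rows, :);
  % Gradient of 0.5*||M*w - dSig||^2 estimated from the chunk and
  % rescaled so that it is an unbiased estimate of the full gradient.
  g = (nRows / batchSiz) * (Mb' * (Mb * w - dSig(rows)));
  w = max(w - eta * g, 0);           % gradient step, then projection onto w >= 0
end

relErr = norm(w - wTrue) / norm(wTrue);
fprintf('Relative error of recovered weights: %g\n', relErr);
```

Because the synthetic signal is exactly M*wTrue, every chunk's gradient vanishes at wTrue, so the iterates converge to it. With noisy data, a fixed step instead leaves the iterates fluctuating around the solution. For the L1 variants, replacing the update with max(w - eta*(g + lambda), 0) is one standard way to add an L1 penalty under non-negativity. The source does not say whether sgdL1.m does exactly this.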