feFitModel

PURPOSE

Fit the LiFE model.

SYNOPSIS

function [fit w R2] = feFitModel(M,dSig,fitMethod,lambda)

DESCRIPTION

 Fit the LiFE model.

 Finds the weights for each fiber that best predict the directional
 diffusion signal (dSig).

  [fit, w, R2] = feFitModel(M,dSig,fitMethod,lambda)

 dSig:  The diffusion-weighted signal measured at each
        voxel in each direction. These are extracted from 
        the dwi data at some white-matter coordinates.
 M:     The LiFE diffusion model matrix, constructed
        by feConnectomeBuildModel.m

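 fitMethod: 
  - 'bbnnls' - DEFAULT and best: a fast large-scale solver.
  - 'lsqnonneg' - MATLAB's built-in non-negative least-squares solver (SLOW).
  - 'sgd', 'sgdnn' - Stochastic gradient descent, unconstrained ('sgd') or non-negative ('sgdnn').
  - 'sgdl1', 'sgdl1nn' - Stochastic gradient descent with an L1 penalty on the weights
    (both variants currently enforce non-negative weights).

 lambda: Weight of the L1 penalty; required only by 'sgdl1' and 'sgdl1nn'.

 Outputs:
  fit:  Structure holding the fitted weights (fit.weights), the fitting
        parameters (fit.params) and, except for 'lsqnonneg', summary
        results (fit.results).
  w:    The fitted weights, one per column of M.
  R2:   The R2 reported by the SGD solvers; empty for 'bbnnls' and 'lsqnonneg'.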
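
A minimal sketch of what the non-negativity constraint changes, using MATLAB's built-in lsqnonneg (the 'lsqnonneg' option above) and the unconstrained backslash operator. The 3-by-2 problem is hypothetical toy data, not a LiFE model: the unconstrained fit gives the second column a negative weight, while the non-negative fit clips that weight to zero and re-fits the first one.

```matlab
% Hypothetical toy problem: 3 measurements, 2 columns (not LiFE data).
M    = [1 1; 1 2; 1 3];
dSig = [3; 2; 1];

% Unconstrained least squares: weights are free to become negative.
wLS = M \ dSig;             % exact fit here, so wLS is [4; -1] up to rounding

% Non-negative least squares with MATLAB's built-in active-set solver.
wNN = lsqnonneg(M, dSig);   % [2; 0]: weight 2 held at zero, weight 1 = mean(dSig)

fprintf('unconstrained: [%g %g]\n', wLS);
fprintf('non-negative:  [%g %g]\n', wNN);
```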

 See also: feCreate.m, feConnectomeBuildModel.m, feGet.m, feSet.m

 Example:

 Copyright (2013-2014), Franco Pestilli, Stanford University, pestillifranco@gmail.com.

 Notes about the LiFE model:

 M has nVoxels*nBvecs rows: the model predicts the diffusion signal in
 each voxel for each direction.

 M has nFibers + nVoxels columns. The diffusion signal in each voxel is
 predicted as the weighted sum of the predictions of every fiber that
 passes through that voxel, plus an isotropic (CSF) term for that voxel.
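
The layout of M described above can be sketched on a toy problem. The dimensions, the random stand-ins for fiber predictions and the weights are hypothetical, and the row ordering (all directions of voxel 1, then voxel 2, and so on) is an assumption made for illustration; the real matrix is built by feConnectomeBuildModel.m and is large and sparse.

```matlab
% Hypothetical toy dimensions (not taken from a real connectome).
nVoxels = 3;  nBvecs = 4;  nFibers = 2;
rng(1);                                  % reproducible stand-in values

% Fiber columns: each fiber's predicted signal in every voxel and direction.
% Random numbers stand in for the real fiber predictions.
Mfib = rand(nVoxels * nBvecs, nFibers);
Mfib(1:nBvecs, 2) = 0;                   % fiber 2 does not pass through voxel 1

% Isotropic (CSF) columns: one per voxel, equal to 1 on that voxel's rows.
% Assumes rows are ordered voxel by voxel.
Miso = kron(eye(nVoxels), ones(nBvecs, 1));

M = [Mfib, Miso];                        % (nVoxels*nBvecs)-by-(nFibers + nVoxels)
w = [0.5; 1.2; 0.1; 0.2; 0.3];           % fiber weights, then isotropic weights

pSig = M * w;                            % predicted signal for every voxel/direction
fprintf('M is %d-by-%d\n', size(M));     % prints: M is 12-by-5
```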

 In addition to M, the model construction typically returns dSig, the
 signal measured at each voxel in each direction. It is extracted from
 the dwi data using the ROI coordinates (roiCoords).
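
As a hedged illustration of how such a signal vector could be assembled, the sketch below pulls the directional signal of a few voxels out of a 4-D array and stacks it voxel by voxel. The array, the coordinates and the voxel-major ordering are hypothetical stand-ins, not LiFE's actual extraction code.

```matlab
% Hypothetical 4-D DWI array: x-by-y-by-z-by-direction (random stand-in values).
rng(2);
dwi    = rand(5, 5, 5, 4);
coords = [1 2 3; 4 4 1; 2 5 5];   % hypothetical white-matter voxel coordinates, one per row
nBvecs = size(dwi, 4);
nVox   = size(coords, 1);

% Linear index of each voxel within one 3-D volume.
volSiz = [size(dwi, 1), size(dwi, 2), size(dwi, 3)];
vIdx   = sub2ind(volSiz, coords(:, 1), coords(:, 2), coords(:, 3));

% One column per diffusion direction, then keep only the selected voxels.
data = reshape(dwi, [], nBvecs);  % (x*y*z)-by-nBvecs
sig  = data(vIdx, :);             % nVox-by-nBvecs

% Stack voxel by voxel (all directions of voxel 1 first) into one column.
% The voxel-major ordering is an assumption for illustration.
dSig = reshape(sig', [], 1);      % (nVox*nBvecs)-by-1
```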

SOURCE CODE

function [fit w R2] = feFitModel(M,dSig,fitMethod,lambda)
% Fit the LiFE model.
%
% Finds the weights for each fiber that best predict the directional
% diffusion signal (dSig).
%
%  [fit, w, R2] = feFitModel(M,dSig,fitMethod,lambda)
%
% dSig:  The diffusion-weighted signal measured at each
%        voxel in each direction. These are extracted from
%        the dwi data at some white-matter coordinates.
% M:     The LiFE diffusion model matrix, constructed
%        by feConnectomeBuildModel.m
%
% fitMethod:
%  - 'bbnnls' - DEFAULT and best: a fast large-scale solver.
%  - 'lsqnonneg' - MATLAB's built-in non-negative least-squares solver (SLOW).
%  - 'sgd', 'sgdnn' - Stochastic gradient descent, unconstrained ('sgd') or non-negative ('sgdnn').
%  - 'sgdl1', 'sgdl1nn' - Stochastic gradient descent with an L1 penalty on the weights
%    (both variants currently enforce non-negative weights).
%
% lambda: Weight of the L1 penalty; required only by 'sgdl1' and 'sgdl1nn'.
%
% Outputs:
%  fit:  Structure holding the fitted weights (fit.weights), the fitting
%        parameters (fit.params) and, except for 'lsqnonneg', summary
%        results (fit.results).
%  w:    The fitted weights, one per column of M.
%  R2:   The R2 reported by the SGD solvers; empty for 'bbnnls' and 'lsqnonneg'.
%
% See also: feCreate.m, feConnectomeBuildModel.m, feGet.m, feSet.m
%
% Example:
%
% Copyright (2013-2014), Franco Pestilli, Stanford University, pestillifranco@gmail.com.
%
% Notes about the LiFE model:
%
% M has nVoxels*nBvecs rows: the model predicts the diffusion signal in
% each voxel for each direction.
%
% M has nFibers + nVoxels columns. The diffusion signal in each voxel is
% predicted as the weighted sum of the predictions of every fiber that
% passes through that voxel, plus an isotropic (CSF) term for that voxel.
%
% In addition to M, the model construction typically returns dSig, the
% signal measured at each voxel in each direction. It is extracted from
% the dwi data using the ROI coordinates (roiCoords).

% Default to 'bbnnls' when no fitting method is given, as documented above.
if ~exist('fitMethod','var') || isempty(fitMethod)
  fitMethod = 'bbnnls';
end

% Fit the model with the selected solver.
switch fitMethod
  case {'lsqnonneg'}
    tic
    fprintf('\nLiFE: Computing least-square minimization with LSQNONNEG...\n')
    options = optimset('lsqnonneg');   % default lsqnonneg options
    w = lsqnonneg(M,dSig,options);
    fprintf(' ...fit process completed in %2.3fs\n',toc)
    R2=[];
  case {'bbnnls'}
    tic
    fprintf('\nLiFE: Computing least-square minimization with BBNNLS...\n')
    opt = solopt;
    opt.maxit = 500;      % maximum number of solver iterations
    opt.use_tolo = 1;
    % Start the solver from all-zero weights.
    out_data = bbnnls(M,dSig,zeros(size(M,2),1),opt);
    fprintf('BBNNLS status: %s\nReason: %s\n',out_data.status,out_data.termReason);
    w = out_data.x;
    fprintf(' ...fit process completed in %2.3f minutes\n',toc/60)
    % Save the state of the random number generator so that the fit can be recomputed.
    defaultStream = RandStream.getGlobalStream; % formerly RandStream.getDefaultStream
    fit.randState = defaultStream.State;

    % Save out some results
    fit.results.R2        = [];
    fit.results.nParams   = size(M,2);
    fit.results.nMeasures = size(M,1);
    R2=[];
  case {'sgd','sgdnn'} % stochastic gradient descent, or non-negative stochastic gradient descent
    tic
    % Stochastic gradient descent method.
    % It solves an L2 minimization problem, optionally with a non-negativity constraint.
    %
    % Basically, it takes 'chunks' of rows of the M matrix and solves those
    % separately, while constraining them to obtain a consistent global solution.
    signalSiz = size(M,1);
    if signalSiz >= 1000000
      siz     = floor(signalSiz * .1); % size of the chunks (number of rows) taken at every iteration of the solver
    elseif signalSiz > 10000 && signalSiz < 1000000
      siz     = floor(signalSiz * .5); % size of the chunks (number of rows) taken at every iteration of the solver
    elseif signalSiz <= 10000
      siz     = signalSiz; % size of the chunks (number of rows) taken at every iteration of the solver
    else
      keyboard
    end
    stepSiz      = 0.0124; % step in the direction of the gradient; the larger, the more prone to local minima
    stopCriteria = [.1 5 1]; % Stop signals:
    % First, if total error has not decreased less than
    %        an XXX proportion of XXXX.
    % Second, number of small partial fits before
    %         evaluating the quality of the large fit.
    % Third, amount of R2 improvement judged to be
    %        useful.
    %        It used to be: percent improvement in R2
    %        that is considered a change in quality
    %        of fit, e.g., 1=1%.
    n      = 100;       % Number of iterations after which to check the total error.
    nonneg = strcmpi(fitMethod(end-2:end),'dnn'); % true only for 'sgdnn'
    fprintf('\nLiFE: Computing least-square minimization with Stochastic Gradient Descent...\n')
    [w, R2] = sgd(dSig,M,siz,        stepSiz,      stopCriteria,        n,         nonneg);
             %sgd(y,   X,numtoselect,finalstepsize,convergencecriterion,checkerror,nonneg,alpha,lambda)
    % Save out the Stochastic Gradient Descent parameters
    fit.params.stepSiz      = stepSiz;
    fit.params.stopCriteria = stopCriteria;
    fit.params.numInters    = n;

    % Save the state of the random number generator so that the stochastic fit can be recomputed.
    defaultStream = RandStream.getGlobalStream; % formerly RandStream.getDefaultStream
    fit.randState = defaultStream.State;

    % Save out some results
    fit.results.R2        = R2;
    fit.results.nParams   = size(M,2);
    fit.results.nMeasures = size(M,1);
    fprintf(' ...fit process completed in %2.3fs\n',toc)

  case {'sgdl1','sgdl1nn'} % stochastic gradient descent with an L1 penalty on the weights
    tic
    % Stochastic gradient descent method.
    % It solves an L2 minimization problem with an L1 penalty on the weights
    % and a non-negativity constraint.
    %
    % Basically, it takes 'chunks' of rows of the M matrix and solves those
    % separately, while constraining them to obtain a consistent global solution.
    signalSiz = size(M,1);
    if signalSiz >= 1000000
      siz     = floor(signalSiz * .1); % size of the chunks (number of rows) taken at every iteration of the solver
    elseif signalSiz > 10000 && signalSiz < 1000000
      siz     = floor(signalSiz * .5); % size of the chunks (number of rows) taken at every iteration of the solver
    elseif signalSiz <= 10000
      siz     = signalSiz; % size of the chunks (number of rows) taken at every iteration of the solver
    else
      keyboard
    end
    stepSiz      = 0.0124; % step in the direction of the gradient; the larger, the more prone to local minima
    stopCriteria = [.1 5 1]; % Stop signals:
    % First, if total error has not decreased less than
    %        an XXX proportion of XXXX.
    % Second, number of small partial fits before
    %         evaluating the quality of the large fit.
    % Third, amount of R2 improvement judged to be
    %        useful.
    %        It used to be: percent improvement in R2
    %        that is considered a change in quality
    %        of fit, e.g., 1=1%.
    n      = 100;       % Number of iterations after which to check the total error.
    nonneg = 1;         % non-negativity is enforced for both 'sgdl1' and 'sgdl1nn'
    fprintf('\nLiFE: Computing least-square minimization (L1) with Stochastic Gradient Descent...\n')
    %lambda = [length(dSig)*2.75];
    [w, R2] = sgdL1(dSig,M,siz, stepSiz, stopCriteria, n,nonneg,[],lambda);
    fprintf('Lambda: %2.2f | nFibers: %i | L1 penalty: %2.3f | L2 penalty: %2.3f\n',lambda, length(find(w>0)),sum(w),sum(w.^2))

    % Save out the Stochastic Gradient Descent parameters
    fit.params.stepSiz      = stepSiz;
    fit.params.stopCriteria = stopCriteria;
    fit.params.numInters    = n;

    % Save the state of the random number generator so that the stochastic fit can be recomputed.
    defaultStream = RandStream.getGlobalStream; % formerly RandStream.getDefaultStream
    fit.randState = defaultStream.State;

    % Save out some results
    fit.results.R2        = R2;
    fit.results.nParams   = size(M,2);
    fit.results.nMeasures = size(M,1);
    fit.results.l2        = sum(w.^2);
    fit.results.l1        = sum(w);

    fprintf(' ...fit process completed in %2.3fs\n',toc)

  otherwise
    error('Cannot fit LiFE model using method: %s.\n',fitMethod);
end

% Save output structure.
fit.weights             = w;
fit.params.fitMethod    = fitMethod;

end
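
Two details of the SGD branches can be checked in isolation. First, the chunk-size rule: with '||' in the middle condition, every signal shorter than 1,000,000 entries falls into the 50% branch, so the "use all rows" branch for small problems is never reached; the '&&' form below gives the three intended ranges. Second, the chunked update. The sgd and sgdL1 solvers are not listed on this page, so the loop below is a swapped-in, minimal projected (proximal) gradient sketch of the objective 0.5*||dSig - M*w||^2 + lambda*sum(w) subject to w >= 0, on a hypothetical random problem. The step size, iteration count and data are assumptions, not LiFE's settings. Subtracting step*lambda and then clipping at zero is the proximal step for the L1 term restricted to non-negative weights.

```matlab
% Chunk-size rule of the SGD branches, written with '&&'.
chunkSize = @(n) (n >= 1e6) * floor(n * 0.1) + ...
                 (n > 1e4 && n < 1e6) * floor(n * 0.5) + ...
                 (n <= 1e4) * n;
% prints 5000 -> 5000, 50000 -> 25000 and 2000000 -> 200000 rows per chunk
fprintf('%d -> %d rows per chunk\n', ...
        [5e3 5e4 2e6; chunkSize(5e3) chunkSize(5e4) chunkSize(2e6)]);

% Hypothetical toy problem: sparse non-negative weights, noiseless signal.
rng(0);
M     = rand(200, 10);
wTrue = [rand(5, 1); zeros(5, 1)];
dSig  = M * wTrue;

lambda = 0;                        % L1 weight; 0 gives plain non-negative least squares
nRows  = size(M, 1);
siz    = chunkSize(nRows);         % 200 <= 1e4, so every row is used
step   = 1 / norm(M, 2)^2;         % 1/L, L = Lipschitz constant of the full gradient
w      = zeros(size(M, 2), 1);     % start from zero, as the bbnnls call does

for k = 1:20000
  idx  = randperm(nRows, siz);                            % random chunk of rows
  Mi   = M(idx, :);
  grad = (nRows / siz) * (Mi' * (Mi * w - dSig(idx)));    % rescaled chunk gradient
  w    = max(w - step * (grad + lambda), 0);              % step, L1 shrink, clip at zero
end
fprintf('max |w - wTrue| = %.2e\n', max(abs(w - wTrue)));
```

Because the toy matrix is small, the rule selects every row and each iteration is a deterministic full-gradient step, which converges with the fixed 1/L step. With chunks smaller than the whole matrix the updates become genuinely stochastic, and a fixed step no longer guarantees convergence to the exact minimizer; SGD solvers usually shrink the step over time. Since the weights are non-negative, sum(w) equals their L1 norm, which is why the 'sgdl1' branch reports sum(w) as the L1 penalty.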

Generated on Wed 16-Jul-2014 19:56:13 by m2html © 2005