
sgd

PURPOSE

Least-squares stochastic gradient descent fit.

SYNOPSIS

function [w, R2] = sgd(y,X,numtoselect,finalstepsize,convergencecriterion,checkerror,nonneg,alpha,lambda)

DESCRIPTION

 Least-squares stochastic gradient descent fit.
 
 sgd(y,X,numtoselect,finalstepsize,convergencecriterion,checkerror,nonneg)
 
 <y> is p x 1 with the data

 <X> is p x q with the regressors

 <numtoselect> is the number of data points to randomly select on each
 iteration

 <finalstepsize> is the gradient step size (e.g., 0.05)

 <convergencecriterion> is [A B C] where A is in (0,1), B is a positive
 integer, and C is a percentage.
   
   We stop if we see a series of max(B,round(A*[current total-error-check
   number])) total-error-check iterations that do not improve performance
   on the estimation set, where an improvement must reduce the error by at
   least C% of the best error found so far.
 
 <checkerror> is the number of iterations between total-error-checks

 <nonneg> if set to 1 constrains the solution to be non-negative
 
 For reference see: Kay et al. 2008 (Supplemental material)
 
 <alpha, lambda> ElasticNet parameters (optional, both default to 0).
 The ElasticNet is a regularization and variable selection algorithm.
 The EN penalized error is:
 
 sum((y - X*w).^2) + lambda * sum(alpha * abs(w) + (1-alpha) * w.^2)

 lambda sets the overall weight of the additional regularization error
 surface and alpha balances between the L1 and L2 constraints. When alpha
 is 1, the penalty reduces to the Lasso (pure L1); when alpha is 0, it
 reduces to ridge regression (pure L2).
 
 Reference: Zou & Hastie (2005). Journal of the Royal Statistical
 Society B 67, Part 2, pp. 301-320.
 See also: Hastie, Tibshirani & Friedman (2008). The Elements of
 Statistical Learning, chapter 3 (page 31, equation 3.54).
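
 As a concrete illustration of the penalty above, here is a minimal sketch
 (not part of sgd.m; the variables y, X, w, alpha, lambda are assumed to be
 in the workspace) that evaluates the elastic-net regularized error for a
 candidate weight vector w, using the same convention as the source code
 below (alpha weights the L1 term, 1-alpha the L2 term):

   residual_ssq = sum((y - X*w).^2);                             % data term
   en_penalty   = lambda * sum(alpha*abs(w) + (1-alpha)*w.^2);   % regularizer
   total_error  = residual_ssq + en_penalty;                     % penalized SSQ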


 Copyright Franco Pestilli and Kendrick Kay (2013) Vistasoft Stanford University.
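
 A minimal usage sketch (the data, noise level, and parameter values below
 are made up for illustration; only the call signature comes from the
 synopsis above):

   p = 500; q = 20;
   X = rand(p, q);                     % regressors
   w_true = max(randn(q, 1), 0);       % non-negative ground-truth weights
   y = X*w_true + 0.1*randn(p, 1);     % noisy measurements
   % 10% of the points per iteration, step size 0.05, stop after a run of
   % error checks improving by less than 5%, check error every 40 iterations,
   % non-negative solution, elastic-net penalty with alpha = 0.5, lambda = 0.1.
   [w, R2] = sgd(y, X, round(0.1*p), 0.05, [0.15 3 5], 40, 1, 0.5, 0.1);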

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SUBFUNCTIONS

function [v,len] = unitlengthfast(v,dim)

SOURCE CODE

0001 function [w, R2] = sgd(y,X,numtoselect,finalstepsize,convergencecriterion,checkerror,nonneg,alpha,lambda)
0002 % Least-squares stochastic gradient descent fit.
0003 %
0004 % sgd(y,X,numtoselect,finalstepsize,convergencecriterion,checkerror,nonneg)
0005 %
0006 % <y> is p x 1 with the data
0007 %
0008 % <X> is p x q with the regressors
0009 %
0010 % <numtoselect> is the number of data points to randomly select on each
0011 % iteration
0012 %
0013 % <finalstepsize> is the gradient step size (e.g., 0.05)
0014 %
0015 % <convergencecriterion> is [A B C] where A is in (0,1), B is a positive
0016 % integer, and C is a percentage.
0017 %
0018 %   We stop if we see a series of max(B,round(A*[current total-error-check
0019 %   number])) total-error-check iterations that do not improve performance
0020 %   on the estimation set, where an improvement must reduce the error by at
0021 %   least C% of the best error found so far.
0022 %
0023 % <checkerror> is the number of iterations between total-error-checks
0024 %
0025 % <nonneg> if set to 1 constrains the solution to be non-negative
0026 %
0027 % For reference see: Kay et al. 2008 (Supplemental material)
0028 %
0029 % <alpha, lambda> ElasticNet parameters (optional, both default to 0).
0030 % The ElasticNet is a regularization and variable selection algorithm.
0031 % The EN penalized error is:
0032 %
0033 % sum((y - X*w).^2) + lambda * sum(alpha * abs(w) + (1-alpha) * w.^2)
0034 %
0035 % lambda sets the overall weight of the additional regularization error
0036 % surface and alpha balances between the L1 and L2 constraints. When alpha
0037 % is 1, the penalty reduces to the Lasso (pure L1); when alpha is 0, it
0038 % reduces to ridge regression (pure L2).
0039 %
0040 % Reference: Zou & Hastie (2005). Journal of the Royal Statistical
0041 % Society B 67, Part 2, pp. 301-320.
0042 % See also: Hastie, Tibshirani & Friedman (2008). The Elements of
0043 % Statistical Learning, chapter 3 (page 31, equation 3.54).
0044 %
0045 %
0046 % Copyright Franco Pestilli and Kendrick Kay (2013) Vistasoft Stanford University.
0047 
0048 % Set the default for input params:
0049 if notDefined('convergencecriterion'), convergencecriterion = [.15 3 5]; end 
0050 if notDefined('numtoselect'), numtoselect = round(0.1 * size(y,1)); end
0051 if notDefined('checkerror'), checkerror=40; end
0052 if notDefined('finalstepsize'), finalstepsize=0.05; end
0053 if notDefined('nonneg'), nonneg = false; end
0054 if notDefined('coordDescent'), coordDescent = false; end
0055 
0056 % Set default values for ElasticNet (defaults to regular OLS):
0057 if notDefined('alpha'), alpha   = 0; end 
0058 if notDefined('lambda'), lambda = 0; end
0059 
0060 p = size(y,1);  % number of data points
0061 q = size(X,2);  % number of parameters
0062 orig_ssq = full(sum((y).^2)); % Sum of squares of the data
0063 
0064 % initialize various variables used in the fitting:
0065 w          = 0 .* rand(q,1); % the set of weights, initialized to all zeros
0066 w_best     = w;          % The best set of weights.
0067 est_ssq_best = inf;      % minimum estimation error found so far
0068 estbadcnt  = 0;          % number of times estimation error has gone up
0069 iter       = 1;          % the iteration number
0070 cnt        = 1;          % the total-error-check number
0071   
0072 % report
0073 fprintf('[%s] Performing fit | %d measurements | %d parameters | ',mfilename,p,q);
0074 
0075 % Start computing the fit.
0076 while 1
0077   % Indices to selected signal and model
0078   ix      = randi(p,1,numtoselect); % Slower indexing method
0079   ix2     = false(p,1);
0080   ix2(ix) = true;
0081 
0082   % select the subset of signal and model to use for fitting
0083   y0 = y(ix2);
0084   X0 = X(ix2,:);
0085   
0086   % if not the first iteration, adjust parameters
0087   if iter ~= 1
0088     % Calculate the gradient (change in error over change in parameter):
0089     grad = -((y0 - X0*w)' * X0)' + lambda * (alpha + 2*(1 - alpha)*w);
0090     
0091     % This computes the coordinate descent instead of the gradient descent.
0092     if coordDescent
0093         % Coordinate descent
0094         m    = min(grad);
0095         grad = (grad==m)*m;
0096     end
0097     
0098     % Unit-length normalize the gradient
0099     grad = unitlengthfast(grad);
0100     
0101     % Perform gradient descent
0102     w = w - finalstepsize*grad;
0103         
0104     % Non-negative constraint: set negative weights to zero
0105     if ( nonneg ), w(w<0) = 0;end
0106   end
0107   
0108   % check the total error every so often
0109   if mod(iter,checkerror) == 1
0110     % Current estimated sum of squares of the residuals (SSQ)
0111     est_ssq = sum((y - X*w).^2) + lambda*(alpha*sum(w) + (1-alpha)*sum(w.^2));
0112      
0113     % Check if the SSQ improved
0114     isimprove = est_ssq < est_ssq_best; 
0115             
0116     % We keep fitting if the best SSQ is still Inf OR the current SSQ is
0117     % smaller than the best SSQ so far by at least convergencecriterion(3) percent
0118     %keepfitting = isinf(est_ssq_best) | (est_ssq < ((est_ssq_best - min_ssq)));
0119     keepfitting = isinf(est_ssq_best) | (est_ssq < ((est_ssq_best * (1-convergencecriterion(3)/100))));
0120 
0121     % do we consider this iteration to be the best yet?
0122     if isimprove
0123       % The SSQ was smaller, the fit improved.
0124       w_best       = w;       % Set the current to be the best so far
0125       est_ssq_best = est_ssq; % The min error
0126       
0127       % OK we improved, but check whether improvement is too small to be
0128       % considered useful.
0129       if keepfitting
0130         % The fit improved by more than the minimum acceptable improvement.
0131         % Reset the counter for the bad fits, so that we start over
0132         % checking for stopping.
0133         estbadcnt  = 0;
0134         %est_ssq_best = est_ssq; %
0135       else
0136         estbadcnt = estbadcnt + 1;
0137       end
0138     else
0139       % The fit got worse (the SSQ increased); count how many bad fits we had. Stop after a certain number.
0140       estbadcnt = estbadcnt + 1;
0141     end
0142     
0143     % stop if we haven't improved in a while
0144     if estbadcnt >= max(convergencecriterion(2),round(convergencecriterion(1)*cnt))
0145       R2 = 100*(1-(est_ssq_best/orig_ssq));
0146       fprintf(' DONE fitting | SSQ=%2.3f (Original SSQ=%2.3f) | Rzero-squared %2.3f%%.\n',...
0147               est_ssq_best,orig_ssq,R2);
0148       break;
0149     end
0150     
0151     % Update the counter
0152     cnt = cnt + 1;
0153   end
0154   iter = iter + 1;
0155 end
0156 
0157 % prepare output
0158 w = w_best;
0159 
0160 function [v,len] = unitlengthfast(v,dim)
0161 
0162 % function [v,len] = unitlengthfast(v,dim)
0163 %
0164 % <v> is a vector (row or column) or a 2D matrix
0165 % <dim> (optional) is dimension along which vectors are oriented.
0166 %   if not supplied, assume that <v> is a row or column vector.
0167 %
0168 % unit-length normalize <v>.  aside from input flexibility,
0169 % the difference between this function and unitlength.m is that
0170 % we do not deal with NaNs (i.e. we assume <v> does not have NaNs),
0171 % and if a vector has 0 length, it becomes all NaNs.
0172 %
0173 % we also return <len> which is the original vector length of <v>.
0174 % when <dim> is not supplied, <len> is a scalar.  when <dim> is
0175 % supplied, <len> is the same dimensions as <v> except collapsed
0176 % along <dim>.
0177 %
0178 % note some weird cases:
0179 %   unitlengthfast([]) is [].
0180 %   unitlengthfast([0 0]) is [NaN NaN].
0181 %
0182 % example:
0183 % a = [3 0];
0184 % isequalwithequalnans(unitlengthfast(a),[1 0])
0185 
0186 if nargin==1
0187   len = sqrt(v(:).'*v(:));
0188   v = v / len;
0189 else
0190   if dim==1
0191     len = sqrt(sum(v.^2,1));
0192     v = v ./ repmat(len,[size(v,1) 1]);  % like this for speed.  maybe use the indexing trick to speed up even more??
0193   else
0194     len = sqrt(sum(v.^2,2));
0195     v = v ./ repmat(len,[1 size(v,2)]);
0196   end
0197 end

Generated on Wed 16-Jul-2014 19:56:13 by m2html © 2005