
RSLDA

A Matlab code for "Robust and sparse linear discriminant analysis via alternating direction method of multipliers." (Right-click [Code] and choose Save to download the complete Matlab code.)


Reference

Li C N, Shao Y H, Yin W, Liu M Z. Robust and sparse linear discriminant analysis via alternating direction method of multipliers. IEEE Transactions on Neural Networks and Learning Systems, 2020, 31(3): 915-926.

Main Function

function W = RSLDA(Data, Prjdim, RSLDAPara)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Ref: "Robust and sparse linear discriminant analysis via alternating
% direction method of multipliers"
% Implements Algs 1 and 3 of the paper.
% Input:
%   Data:   Data.trainX - the training samples;
%           Data.trainY - the labels corresponding to the training samples;
%   Prjdim: the dimension to project onto;
%   RSLDAPara: the parameters of RSLDA Alg 1:
%     RSLDAPara.method - 0 for RLDA, 1 for RSLDA;
%     RSLDAPara.rho    - the augmented Lagrangian parameter;
%     RSLDAPara.lambda - the lambda in Alg 1;
%     RSLDAPara.sigm   - the sigma in Alg 1;
%     RSLDAPara.tol    - the epsilon in Alg 1.
% Output:
%   W: the projection vectors, dim x Prjdim;
%      dim: the dimension of the samples;
%      Prjdim: the number of projection vectors.
%
% Usage example:
%   Data.trainX = rand(50,10);
%   Data.trainY = [ones(25,1); -ones(25,1)];
%   Prjdim = 5;
%   RSLDAPara.method = 0;
%   RSLDAPara.rho = 5;
%   RSLDAPara.lambda = 0.5;
%   RSLDAPara.sigm = 0.05;
%   RSLDAPara.tol = 1e-3;
%   W = RSLDA(Data, Prjdim, RSLDAPara);
%
% Reference:
%   Li C N, Shao Y H, Yin W, Liu M Z. Robust and sparse linear discriminant
%   analysis via alternating direction method of multipliers. IEEE Transactions
%   on Neural Networks and Learning Systems, 2020, 31(3): 915-926.
%
% Version 2.0 -- Dec/2019
% Written by Chun-Na Li (na1013na@163.com)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% function begin
rho = RSLDAPara.rho;
lambda = RSLDAPara.lambda;
sigm = RSLDAPara.sigm;   % used below when RSLDAPara.method == 1
tol = RSLDAPara.tol;
epsi = 0.001;            % smoothing parameter
X = Data.trainX;
Y = Data.trainY;
% nsams: total number of training samples; dim: sample dimension
[nsams, dim] = size(X);
I = eye(dim);
% nsamsc: number of samples in each class
% labels: set of labels in the dataset
[nsamsc, labels] = hist(Y, unique(Y));
% nc: number of classes
nc = numel(labels);
% projection vector matrix
W = zeros(dim, Prjdim);
for k = 1:Prjdim
    if k == 1
        tstart = tic;
    end
    % mean of each class
    clsmean = zeros(nc, dim);
    % Sw: within-class scatter matrix
    Sw = zeros(dim, dim);
    for i = 1:nc
        %% calculate the mean of each class
        cls_idx = (Y == labels(i));
        clsmean(i,:) = mean(X(cls_idx,:), 1);
        Sw = Sw + (X(cls_idx,:) - repmat(clsmean(i,:), nsamsc(i), 1))' * ...
                  (X(cls_idx,:) - repmat(clsmean(i,:), nsamsc(i), 1));
    end
    Sw = Sw / nsams;   % normalize
    X0 = (clsmean - repmat(mean(X,1), nc, 1))' * diag(nsamsc);
    % random initialization
    w = rand(dim,1);
    u2 = rand(dim,1);
    y = rand(nc,1);
    u1 = rand(nc,1);
    % convergence measures:
    % eps_pri_one, eps_pri_two, eps_dual_one, eps_dual_two
    eps_pri_one = 1.0;  eps_pri_two = 1.0;
    eps_dual_one = 1.0; eps_dual_two = 1.0;
    eps_pri_one_old = 1.0;  eps_pri_two_old = 1.0;
    eps_dual_one_old = 1.0; eps_dual_two_old = 1.0;
    iter_while = 1;
    Ginv = (X0*X0' + I + 2*lambda/rho * Sw) \ I;
    while ( (eps_pri_one > tol) || (eps_pri_two > tol) || ...
            (eps_dual_one > tol) || (eps_dual_two > tol) )
        %% solve z
        z = Ginv * (X0*(y - u1) + (w - u2));   % Ginv * g
        %% solve y
        y0 = y;
        Xz = X0' * z;
        rho_y = rho;
        for k_y = 1:10
            y_inner(:,k_y) = (rho_y*(X0'*z + u1).*sqrt(y.^2 + epsi)) ./ ...
                             (rho_y*sqrt(y.^2 + epsi) - 1);
            if norm(y_inner(:,k_y) - y) < 10^(-3)
                break;
            end
            y = y_inner(:,k_y);
        end
        if k == 1
            y_iter{iter_while} = y_inner;
            a(:,iter_while) = X0'*z + u1;
            aiter_while = a(:,iter_while);
            rho_rightside(:,iter_while) = (sqrt(epsi) + 0.5*abs(aiter_while))/epsi + ...
                sqrt(abs(aiter_while)/epsi^1.5 + 0.25*aiter_while.^2/epsi^2);
            con_judge(:,iter_while) = (rho > rho_rightside(:,iter_while));
        end
        %% solve w
        w0 = w;
        w = z - u2;
        % === for RSLDA, the following soft-thresholding is also needed ===
        if RSLDAPara.method == 1
            ka = sigm/rho;
            w(w > ka) = w(w > ka) - ka;
            w(w < -ka) = w(w < -ka) + ka;
            w(w <= ka & w >= -ka) = 0;
        end
        %% update the multipliers u1 and u2
        u1 = u1 + Xz - y;
        u2 = u2 + w - z;
        eps_pri_one_old = eps_pri_one;
        eps_pri_two_old = eps_pri_two;
        eps_dual_one_old = eps_dual_one;
        eps_dual_two_old = eps_dual_two;
        eps_pri_one = norm(Xz - y);
        eps_pri_two = norm(w - z);
        eps_dual_one = norm(X0*(y - y0));
        eps_dual_two = norm(w - w0);
        if iter_while > 1000
            break;
        end
        if ( (abs(eps_pri_one - eps_pri_one_old) < 1E-3) && ...
             (abs(eps_pri_two - eps_pri_two_old) < 1E-3) && ...
             (abs(eps_dual_one - eps_dual_one_old) < 1E-3) && ...
             (abs(eps_dual_two - eps_dual_two_old) < 1E-3) )
            break;
        end
        if k == 1
            P1(iter_while,:) = eps_pri_one;
            P2(iter_while,:) = eps_pri_two;
            D1(iter_while,:) = eps_dual_one;
            D2(iter_while,:) = eps_dual_two;
        end
        iter_while = iter_while + 1;
    end % end while
    W(:,k) = w;
    X = X - (X*w)*w';   % deflation: remove the direction just found
    fprintf('.');
    if k == 1
        time = toc(tstart);
    end
end
fprintf('\n');
end
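The sparsity in the w-step comes from elementwise soft-thresholding of z - u2 with threshold ka = sigm/rho, which is the proximal operator of the l1 norm. As a language-neutral illustration, here is a minimal NumPy sketch of that one step (the function name soft_threshold is ours, not part of the Matlab code above):

```python
import numpy as np

def soft_threshold(v, kappa):
    """Elementwise soft-thresholding: prox of kappa * ||.||_1.

    Shrinks each entry of v toward zero by kappa and zeroes out
    entries whose magnitude is at most kappa, exactly as the three
    masked assignments do in the Matlab w-update.
    """
    return np.sign(v) * np.maximum(np.abs(v) - kappa, 0.0)
```

For example, with kappa = 0.5 the vector [1.0, -0.3, 0.05] maps to [0.5, 0.0, 0.0]: entries below the threshold vanish, which is what produces sparse projection vectors in RSLDA.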
Contacts


For any questions or suggestions, please email na1013na@163.com or shaoyuanhai21@163.com.