function [lb, clErr] = learn(lb, trnExamples, targetOuts, nStages, reqDetRate, reqErrBnd)
% LEARN  Train (or resume training of) a multi-class LogitBoost classifier.
%
%   [lb, clErr] = learn(lb, trnExamples, targetOuts, nStages, reqDetRate, reqErrBnd)
%
% Inputs:
%   lb          - LogitBoost classifier object (may already be partially
%                 trained; training resumes from lb.nStages)
%   trnExamples - training examples, a numeric matrix (rows = examples) or a
%                 cell array convertible via cell2mat
%   targetOuts  - column vector of class labels in 1..J
%   nStages     - total number of boosting stages to train up to; Inf trains
%                 until the training error drops below lb.errBound (default 10)
%   reqDetRate  - for 2-class problems: required detection rate; when given,
%                 the decision threshold is adjusted to reach it
%                 (default NaN = disabled)
%   reqErrBnd   - optional override for lb.errBound (used when nStages is Inf)
%
% Outputs:
%   lb    - updated classifier
%   clErr - training classification error after the last stage

disp('============');
disp('Logit Boost')
disp('============');

if nargin < 3
    error('incorrect number of arguments');
end

if nargin < 4 || isempty(nStages),
    nStages = 10;
end

if nargin < 5 || isempty(reqDetRate)
    reqDetRate = NaN;
end

if nargin >= 6 && ~isempty(reqErrBnd)
    lb.errBound = reqErrBnd;
end

% Resume with previously stored example weights if this classifier was trained
% before; otherwise start from uniform weights.  (These weights are only
% stored back at the end — LogitBoost recomputes its own w each iteration.)
if isempty(lb.lastExWeights)
    wts = ones(length(targetOuts), 1) * 1/length(targetOuts);
else
    disp('classifier was trained before');
    wts = lb.lastExWeights;
end;

clErr = Inf;
lb.thresh = 0;

% Normalize the training data into a plain numeric matrix.
trainData = trnExamples;
if iscell(trainData),
    trainData = cell2mat(trainData);
end

% Initialize the additive model F (one linear model plus bias per class) and
% the class-probability matrix P on the very first call.
J = getNumClasses(lb);
if lb.nStages == 0
    lb.F = zeros(J, size(trainData,2) + 1);
    lb.P = ones(size(trainData,1), J) * 1/J;  % start from uniform probabilities
end

iStage = lb.nStages;

% Stop either when the error bound is reached (nStages == Inf) or after the
% requested number of stages.
while (isinf(nStages) && clErr > lb.errBound) || ...
        (~isinf(nStages) && iStage < nStages),

    fprintf('Iteration %d\n', iStage);
    z = zeros(size(trainData,1), J);
    w = zeros(size(trainData,1), J);
    Fj = zeros(J, size(trainData,2) + 1);
    sumFj = zeros(J, size(trainData,2) + 1);
    breakLoop = false;
    for j = 1:J
        % One-vs-all encoding for class j.
        JVec = ones(size(trainData,1),1) * j;
        YStar = JVec == targetOuts;

        % Working response z and per-example weights w for class j.
        z(:,j) = (YStar - lb.P(:,j)) ./ ( (lb.P(:,j)) .* ( 1 - lb.P(:,j)) );
        w(:,j) = ( (lb.P(:,j)) .* ( 1 - lb.P(:,j)) );

        % Probabilities saturated at 0 or 1 yield NaN working responses;
        % stop this stage cleanly rather than propagating NaNs into F.
        if (sum(isnan(w(:,j))) > 0 || sum(isnan(z(:,j))) > 0)
            breakLoop = true;
            break;
        end

        % Weighted least-squares fit of the working response on the data.
        regressor = learn(lb.regressor, trainData, z(:,j), w(:,j));
        Fj(j,:) = getWeights(regressor);

        if (sum(isnan(Fj(j,:))) > 0)
            breakLoop = true;
            break;
        end

        sumFj(j,:) = sumFj(j,:) + Fj(j,:);
    end

    if (~breakLoop)
        % Center the per-class updates so they sum to zero across classes,
        % then accumulate into the additive model.
        Fj = ((J-1)/J)*(Fj - (1/J) * sumFj );
        lb.F = lb.F + Fj;
    else
        fprintf('next iterations (after %d) will have no effects due to nan\n' ...
            ,iStage+1);
    end

    % Recompute class probabilities as the softmax of F (bias column appended).
    Fvalues = [trainData ones(size(trainData,1),1)] * (lb.F)';
    eFvalues = exp(Fvalues);
    epsVal = 1e-6;  % guard against division by zero (do not shadow built-in eps)
    lb.P = bsxfun(@times, eFvalues, 1 ./ max(sum(eFvalues,2), epsVal ));

    if isinf(nStages)
        % Open-ended training: monitor the training error every iteration so
        % the while-condition can terminate on lb.errBound.
        abOuts = computeOutputs(lb, trainData);
        clErr = sum(abOuts ~= targetOuts) / length(abOuts);
        fprintf('\t\t\tCurrently, boosted classifier''s error = %f\n', clErr);
    end

    iStage = iStage + 1;
    lb.nStages = lb.nStages + 1;
end

if ~isinf(nStages)
    % Fixed stage count: compute the final training error once after the loop.
    abOuts = computeOutputs(lb, trainData);
    clErr = sum(abOuts ~= targetOuts) / length(abOuts);
end

lb.lastExWeights = wts;

% For 2-class problems, either adjust the decision threshold to achieve the
% requested detection rate, or record the detection rate actually achieved.
if getNumClasses(lb) == 2
    if ~isnan(reqDetRate),
        % BUG FIX: the original referenced undefined variables
        % `useGlobalExamples` and `trainingExamples`, which raised an
        % undefined-variable error whenever reqDetRate was supplied.
        % Use the prepared training matrix instead.
        [lb, clErr] = adjustThreshold(lb, trainData, targetOuts, reqDetRate);
    else
        lb.detectionRate = ...
            sum((targetOuts == getPosVal(lb)) & (abOuts == getPosVal(lb))) / ...
            sum((targetOuts == getPosVal(lb)));
    end
end
fprintf('\t\t\tLogit Boost Training is Done with err = %f\n', clErr);