function [rab, clErr, itrErrors] = learn(rab, trnExamples, targetOuts, ...
    nStages, verbose, crossvalset, crossvalout, reqDetRate, reqErrBnd)
%LEARN Train (or continue training) a RealAdaBoost classifier.
%
%   [RAB, CLERR, ITRERRORS] = LEARN(RAB, TRNEXAMPLES, TARGETOUTS) trains
%   RAB up to 10 boosting stages using the weak classifier template stored
%   in RAB.weakCl.
%
%   ... = LEARN(RAB, TRNEXAMPLES, TARGETOUTS, NSTAGES, VERBOSE, ...
%               CROSSVALSET, CROSSVALOUT, REQDETRATE, REQERRBND)
%   Optional inputs:
%     NSTAGES     - stage number to train up to (default 10, also used when
%                   empty); stored in RAB.nStages. Stages are numbered after
%                   the weak classifiers already in RAB.trndCls, so a
%                   previously trained classifier only gets the missing
%                   stages. Inf trains until the training error is at or
%                   below RAB.errBound.
%     VERBOSE     - print progress messages when true (default false).
%     CROSSVALSET, CROSSVALOUT - validation data forwarded to the weak
%                   learner; both must be supplied together.
%     REQDETRATE  - accepted for interface compatibility; currently unused.
%     REQERRBND   - if non-empty, stored in RAB.errBound.
%
%   TARGETOUTS holds integer class labels in 1..K, where K is
%   getNumClasses(RAB); the labels are used as column indices into a
%   per-example class coding matrix.
%
%   Outputs:
%     RAB       - updated classifier: new weak classifiers appended to
%                 RAB.trndCls, their SAMME weights appended to
%                 RAB.clsWeights, final example weights in
%                 RAB.lastExWeights, and RAB.thresh reset to 0.
%     CLERR     - fraction of training examples misclassified by RAB.
%     ITRERRORS - NSTAGES-by-1 vector; entry i is the training error after
%                 stage i (left at 0 for stages trained in an earlier
%                 call). Requesting it with NSTAGES = Inf is an error.

if nargin < 3
    error('incorrect number of arguments');
end

if isempty(rab.weakCl)
    error('A weak classifier has not been specified');
end

if nargin < 4 || isempty(nStages)
    nStages = 10;
end
rab.nStages = nStages;

if nargin < 5 || isempty(verbose)
    verbose = false;
end

% Cross-validation data only makes sense as an (inputs, outputs) pair.
if nargin == 6
    error('cross validation output is not given');
end

useCrossValidationSet = nargin >= 7 && ~isempty(crossvalset) && ...
    ~isempty(crossvalout);

% Argument 8 (reqDetRate) is part of the interface but not used here.

if nargin >= 9 && ~isempty(reqErrBnd)
    rab.errBound = reqErrBnd;
end

recordErrors = nargout > 2;
if recordErrors
    if isinf(nStages)
        error('Cant record errors with infinite number of stages');
    end
    itrErrors = zeros(nStages, 1);
end

printfTrue(verbose, 'RealAdaBoost\n');
printfTrue(verbose, '============\n');

if isempty(rab.lastExWeights)
    printfTrue(verbose, 'RealAda boost Classifier wasnt trained before\n');
    nEx = length(targetOuts);
    wts = ones(nEx, 1) / nEx;
else
    % Resume from the example weights left by the previous call.
    printfTrue(verbose, 'Real AdaBoost has been trained before\n');
    wts = rab.lastExWeights;
end

% Evaluating the whole ensemble after every stage is only needed when the
% error drives the stopping rule (Inf stages) or per-stage errors were
% requested.
trackError = isinf(nStages) || recordErrors;

clErr = Inf;    % Inf marks "ensemble error not evaluated yet"
rab.thresh = 0;

iStage = length(rab.trndCls) + 1;
while (isinf(nStages) && clErr > rab.errBound) || ...
        (~isinf(nStages) && iStage <= nStages)

    printfTrue(verbose, '\n\t\t\t==========================\n');
    printfTrue(verbose, '\t\t\tRealAda Boosting Stage # %d\n', iStage);
    printfTrue(verbose, '\t\t\t==========================\n');

    % Train one weak classifier on the current example weights.
    if useCrossValidationSet
        [trndCl, err] = learn(rab.weakCl, trnExamples, targetOuts, wts, ...
            crossvalset, crossvalout);
    else
        [trndCl, err] = learn(rab.weakCl, trnExamples, targetOuts, wts);
    end

    % Keep err away from 0 so beta > 0 and log(1 / beta) stays finite.
    if err < 1e-9
        err = 1e-9;
    end

    printfTrue(verbose, ...
        '\t\t\tWeak Classifier has been trained, err = %f\n', err);

    rab.trndCls = {rab.trndCls{:}, trndCl};

    % SAMME classifier weight: log((1 - err) / err) + log(K - 1).
    K = getNumClasses(rab);
    beta = err / (1 - err);
    alpha = log(1 / beta) + log(K - 1);
    rab.clsWeights = [rab.clsWeights, alpha];

    [~, prob] = computeOutputs(trndCl, trnExamples);
    wts = sammeRUpdate(wts, prob, targetOuts, K);

    if trackError
        clErr = trainingError(rab, trnExamples, targetOuts);
        printfTrue(verbose, ...
            '\t\t\tCurrently, boosted classifier''s error = %f\n', clErr);
        if recordErrors
            itrErrors(iStage, 1) = clErr;
        end
    end

    iStage = iStage + 1;
end

% Evaluate the final ensemble if no stage did (error tracking was off, or
% no new stage had to be trained), so CLERR is never returned as Inf.
if isinf(clErr)
    clErr = trainingError(rab, trnExamples, targetOuts);
end

rab.lastExWeights = wts;

printfTrue(verbose, '\t\t\tRealAdaBoost Training is Done with err = %f\n', clErr);
end


function wts = sammeRUpdate(wts, prob, targetOuts, K)
% Reweight examples with the SAMME.R rule using one weak classifier's class
% probabilities PROB (one column per class), then renormalize to sum 1.
nEx = numel(targetOuts);
% Class coding: 1 for the true class, -1/(K-1) for every other class.
yCode = repmat(-1 / (K - 1), nEx, K);
yCode(sub2ind([nEx, K], (1:nEx)', targetOuts(:))) = 1;
% Floor non-positive probabilities so log() stays finite.
prob(prob <= 0) = 1e-5;
estWts = -((K - 1) / K) * sum(yCode .* log(prob), 2);
% Exponentiate only where the weight is positive or the exponent negative;
% zero-weight examples then stay exactly 0 instead of risking 0 * Inf = NaN.
boost = (estWts < 0) | (wts > 0);
wts = wts .* exp(estWts .* boost);
wts = wts ./ sum(wts);
end


function clErr = trainingError(rab, trnExamples, targetOuts)
% Fraction of training examples misclassified by the current ensemble.
abOuts = computeOutputs(rab, trnExamples);
clErr = mean(abOuts(:) ~= targetOuts(:));
end


function printfTrue(condition, varargin)
% Forward VARARGIN to fprintf only when CONDITION is true.
if condition
    fprintf(varargin{:});
end
end