function [ab, clErr, itrErrors] = learn(ab, trnExamples, targetOuts, ...
    nStages, verbose, crossvalset, crossvalout, reqDetRate, reqErrBnd)
%LEARN  Train the boosted classifier AB with multi-class AdaBoost.
%
%   [AB, CLERR, ITRERRORS] = LEARN(AB, TRNEXAMPLES, TARGETOUTS, NSTAGES,
%       VERBOSE, CROSSVALSET, CROSSVALOUT, REQDETRATE, REQERRBND)
%
%   Each boosting stage does the following:
%     1. Trains AB.weakCl on the weighted training examples.
%     2. Appends the trained classifier to AB.trndCls.
%     3. Appends its vote weight to AB.clsWeights, where
%            alpha = log((1 - err) / err) + log(K - 1)
%        and K = getNumClasses(AB) (the SAMME rule).
%     4. Multiplies the weight of every example that classifier got wrong
%        by exp(alpha), then renormalises the weights to sum to 1.
%
%   Inputs:
%     ab          - AdaBoost object; ab.weakCl must be non-empty.
%     trnExamples - training examples. They are passed unchanged to the weak
%                   learner and to computeOutputs.
%     targetOuts  - label of each training example. When weights are
%                   initialised, labels are matched against
%                   1..getNumClasses(ab).
%     nStages     - boost until ab holds this many weak classifiers
%                   (default 10). Inf means: keep boosting until the
%                   training error of the ensemble is <= ab.errBound.
%     verbose     - print progress messages (default false).
%     crossvalset - optional cross-validation examples, forwarded to the
%                   weak learner. Must be passed together with crossvalout.
%     crossvalout - labels for crossvalset.
%     reqDetRate  - 2-class problems only: when given, adjustThreshold is
%                   called after boosting with this value.
%     reqErrBnd   - if non-empty, overrides ab.errBound.
%
%   Outputs:
%     ab        - the updated classifier. ab.lastExWeights holds the final
%                 example weights, so a later call continues boosting from
%                 where this one stopped.
%     clErr     - training error of the boosted classifier, or the value
%                 returned by adjustThreshold when reqDetRate is given.
%     itrErrors - nStages-by-1 vector. Entry i is the ensemble training
%                 error after stage i. Entries for stages trained in an
%                 earlier call stay 0. Cannot be requested when nStages
%                 is Inf.

% ---- argument handling -------------------------------------------------
if nargin < 3
    error('incorrect number of arguments');
end

if isempty(ab.weakCl)
    error('A weak classifier has not been specified');
end

if nargin < 4 || isempty(nStages)
    nStages = 10;
end

if nargin < 5 || isempty(verbose)
    verbose = false;
end

% The cross-validation set and its labels must come as a pair.
if nargin == 6
    error('cross validation output is not given');
end

useCrossValidationSet = nargin >= 7 && ...
    ~isempty(crossvalset) && ~isempty(crossvalout);

if nargin < 8 || isempty(reqDetRate)
    reqDetRate = NaN;   % NaN means "no detection-rate requirement"
end

if nargin >= 9 && ~isempty(reqErrBnd)
    ab.errBound = reqErrBnd;
end

recordErrors = nargout > 2;
if recordErrors
    if isinf(nStages)
        % Bug fix: this used to call the non-existent function `errors`.
        error('Cant record errors with infinite number of stages');
    end
    itrErrors = zeros(nStages, 1);
end

printfTrue(verbose, 'AdaBoost\n');
printfTrue(verbose, '========\n');

% ---- example weights ---------------------------------------------------
if isempty(ab.lastExWeights)
    % Fresh start: every class gets the same total weight 1/nc, shared
    % equally among that class's examples.
    wts = zeros(length(targetOuts), 1);
    nc = getNumClasses(ab);
    for c = 1:nc
        inds = targetOuts == c;
        wts(inds) = 1 / (nc * sum(inds));
    end
else
    % Continue from the weights left by a previous call.
    wts = ab.lastExWeights;
end

% ---- boosting loop -----------------------------------------------------
clErr = Inf;
% NOTE(review): ab.thresh is presumably the decision threshold used by
% computeOutputs(ab, ...); it is reset before boosting. Confirm.
ab.thresh = 0;
% True once abOuts/clErr have been computed for the current ensemble.
% The loop may run zero times, so the post-loop code must not assume
% abOuts exists.
haveOuts = false;

while (isinf(nStages) && clErr > ab.errBound) || ...
        (~isinf(nStages) && length(ab.clsWeights) < nStages)

    printfTrue(verbose, '\n\t\t\t==================\n');
    printfTrue(verbose, '\t\t\tBoosting Stage # %d\n', length(ab.clsWeights) + 1);
    printfTrue(verbose, '\t\t\t==================\n\n');

    if useCrossValidationSet
        [trndCl, err] = learn(ab.weakCl, trnExamples, targetOuts, wts, ...
            crossvalset, crossvalout);
    else
        [trndCl, err] = learn(ab.weakCl, trnExamples, targetOuts, wts);
    end

    % Clamp a (near-)perfect weak classifier so that beta > 0 and alpha
    % stays finite.
    if err < 1e-9
        err = 1e-9;
    end

    printfTrue(verbose, ...
        '\t\t\tWeak Classifier has been trained, err = %f\n', err);

    beta = err / (1 - err);

    ab.trndCls = {ab.trndCls{:}, trndCl};

    % SAMME vote weight: log((1 - err) / err) + log(K - 1).
    alpha = log(1 / beta) + log(getNumClasses(ab) - 1);
    ab.clsWeights = [ab.clsWeights, alpha];

    % Increase the weight of the examples this stage misclassified,
    % then renormalise so the weights sum to 1.
    outs = computeOutputs(trndCl, trnExamples);
    e = outs ~= targetOuts;
    wts = wts .* exp(alpha .* e);
    wts = wts / sum(wts);

    % The ensemble error is needed every stage either for the Inf-stage
    % stopping rule or to fill itrErrors.
    if isinf(nStages) || recordErrors
        abOuts = computeOutputs(ab, trnExamples);
        clErr = sum(abOuts ~= targetOuts) / length(abOuts);
        haveOuts = true;
        printfTrue(verbose, ...
            '\t\t\tCurrently, boosted classifier''s error = %f\n', clErr);
        if recordErrors
            itrErrors(length(ab.clsWeights), 1) = clErr;
        end
    end
end

% Bug fix: previously abOuts was undefined here, and clErr stayed Inf,
% whenever the loop ran zero times with nargout > 2 or nStages == Inf.
if ~haveOuts
    abOuts = computeOutputs(ab, trnExamples);
    clErr = sum(abOuts ~= targetOuts) / length(abOuts);
end

ab.lastExWeights = wts;

% ---- 2-class post-processing ---------------------------------------------
if getNumClasses(ab) == 2
    if ~isnan(reqDetRate)
        [ab, clErr] = adjustThreshold(ab, trnExamples, targetOuts, ...
            reqDetRate);
    else
        % Fraction of positive examples that the ensemble labels positive.
        posVal = getPosVal(ab);
        isPos = targetOuts == posVal;
        ab.detectionRate = sum(isPos & (abOuts == posVal)) / sum(isPos);
    end
end

printfTrue(verbose, ...
    '\t\t\tAdaBoost Training is Done with err = %f\n', clErr);

    function printfTrue(condition, varargin)
    %PRINTFTRUE  Call fprintf(varargin{:}) only when CONDITION is true.
        if condition
            fprintf(varargin{:});
        end
    end

end