0001 function [cl, clErr] = learn(cl, trnExamples, targetOuts, wts) 0002 % function [cl, err] = learn(cl, trnExamples, targetOuts) 0003 % learning function for the BestDecisionStumpClassifier 0004 % 0005 % Inputs: 0006 % trainingExamples: examples for training the classifier. 0007 % targetOuts: target classification outputs. its size must be the 0008 % same as the number of examples. 0009 % wts: weights of training examples (used to compute the weighted 0010 % classification error) 0011 % Outputs: 0012 % cl : trained BestDecisionStumpClassifier 0013 % clErr: classification error of the trained classifier 0014 0015 %% Deal with Cell Array Input 0016 if iscell(trnExamples), 0017 trnExamples = cell2mat(trnExamples); 0018 end 0019 0020 cl.trainedCl = ClassificationTree.fit(trnExamples, targetOuts, ... 0021 'weights', wts, 'MinParent', size(trnExamples,1), 'MergeLeaves', 'off'); 0022 0023 view(cl.trainedCl); 0024 cl.isTrained = true; 0025 0026 %% Compute Classification Error 0027 if nargout > 1 0028 outs = cl.trainedCl.predict(trnExamples); 0029 0030 if nargin < 4, 0031 clErr = sum(outs ~= targetOuts) / nExamples; 0032 else 0033 clErr = (outs ~= targetOuts)' * wts; 0034 end 0035 0036 if isnan(clErr), 0037 error('error is nan'); 0038 end 0039 end