Home > demos > svm_demo.m

svm_demo

PURPOSE ^

SVM Classifier Demo

SYNOPSIS ^

This is a script file.

DESCRIPTION ^

SVM Classifier Demo
 Several ways to use the SVM classifier (first step: load a dataset).

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

0001 %% SVM Classifier Demo
0002 % Several ways to use svm classifier
0003 %%
0004 %% Load dataset
0005
0006 % Name of matlab dataset
0007 % Available names : [cancer, wine, iris, crab, glass, simpleclass, thyroid]
0008 datasetName = 'iris';
0009
0010 % This is just an example of loading the MATLAB sample datasets; you can
0011 % load data from any source in any way, as long as you provide x, the
0012 % training instances, as a matrix of size (Number of Instances, Number of
0013 % Features), and y, the labels, as a matrix of size (Number of Instances, 1).
0014 % The .mat file is loaded into a struct and its fields are read with
0015 % dynamic field names, which avoids building code strings for eval.
0016 dataset = load(strcat(datasetName, '_dataset'));
0017 x = dataset.([datasetName 'Inputs']);
0018 y = dataset.([datasetName 'Targets']);
0019 % Transpose so that each row is one instance.
0020 x = x';
0021 y = y';
0022 % After the transpose y has one column per class: keep the class count,
0023 % then replace each row by the index of its largest entry (the class label).
0024 numClasses = size(y, 2);
0025 [~,y] = max(y,[],2);
0026
0027 numFeatures = size(x, 2);
0028 numInstances = size(x, 1);
0029
0030 % display dataset info
0031 disp(['Dataset Name ' datasetName]);
0032 disp(['Number of Classes ' num2str(numClasses)]);
0033 disp(['Number of Instances ' num2str(numInstances)]);
0034 disp(['Number of Features ' num2str(numFeatures)]);
0035
0036
0037 %% Basic Usage
0038
0039 fprintf('===========\n');
0040 fprintf('Basic Usage\n');
0041 fprintf('===========\n');
0042
0043 % create svm classifier
0044 svmcl = SVMClassifier(numClasses);
0045
0046 % train svm classifier; learnErr is reported as the learning (training) error
0047 disp('Training Classifier');
0048 [svmcl, learnErr] = learn(svmcl, x, y);
0049 fprintf('Learning Error %f\n', learnErr);
0050
0051 disp('Testing Classifier');
0052 outs = computeOutputs(svmcl, x);
0053 % other way to calculate error either on learning dataset or other dataset:
0054 % the fraction of instances whose predicted label differs from the target
0055 err = sum(outs ~= y) / numInstances;
0056 fprintf('Learning Error %f\n', err);
0057
0058 % comparing outputs of first 5 instances in predicted and target outputs
0059 disp('[predicted outputs : correct outputs]');
0060 disp([outs(1:5, :) y(1:5 , :)]);
0061
0062 %% Instances Weights
0063
0064 fprintf('=================\n');
0065 fprintf('Instances Weights\n');
0066 fprintf('=================\n');
0067
0068 svmcl = SVMClassifier(numClasses);
0069 % uniform weights: every instance gets 1/numInstances, so they sum to 1
0070 wts = ones(numInstances, 1) / numInstances;
0071
0072 % training svm with given weights
0073 svmcl = learn(svmcl, x, y, wts);
0074
0075 outs = computeOutputs(svmcl, x);
0076
0077 err = sum(outs ~= y) / numInstances;
0078 fprintf('Error %f\n', err);
0079
0080 %% Passing arguments and probability estimates
0081
0082 fprintf('====================================\n');
0083 fprintf('Passing arguments and prob estimates\n');
0084 fprintf('====================================\n');
0085 % Option strings use LIBSVM syntax: -c sets the cost parameter C, -g sets
0086 % gamma in the kernel function, and -b 1 enables probability estimates
0087 % (the probability of each instance belonging to each class).
0088 % NOTE(review): the second string appears to be the training options and
0089 % the third the prediction options -- confirm in SVMClassifier.
0090 % for complete list of argument see http://www.csie.ntu.edu.tw/~cjlin/libsvm/
0091 svmcl = SVMClassifier(numClasses,'-c 10 -g 1 -b 1','-b 1');
0092 svmcl = learn(svmcl, x, y);
0093
0094 % prob holds the class probability estimates for each instance
0095 [outs, prob] = computeOutputs(svmcl, x);
0096
0097 disp('[predicted output : correct output : class probabilities]');
0098 disp([outs(1:5, :) y(1:5 , :) prob(1:5, :) ]);
0099
0100
0101 %% Displaying The Classifier
0102
0103 fprintf('=========================\n');
0104 fprintf('Displaying The Classifier\n');
0105 fprintf('=========================\n');
0106
0107
0108 svmcl = SVMClassifier(numClasses);
0109
0110 disp('Display before training');
0111 display(svmcl);
0112
0113 disp('-----------------------');
0114
0115 svmcl = learn(svmcl, x, y);
0116 disp('Display after training');
0117 display(svmcl);
0118
0119
0120
0121 %% Save And Load Classifier To/From A File
0122
0123 fprintf('=======================================\n');
0124 fprintf('Save And Load Classifier To/From A File\n');
0125 fprintf('=======================================\n');
0126
0127 % Recreate the uniform weights here so this section also runs on its own
0128 % (e.g. with "Run Section") instead of relying on an earlier section.
0129 wts = ones(numInstances, 1) / numInstances;
0130 svmcl = SVMClassifier(numClasses);
0131 svmcl = learn(svmcl, x, y, wts);
0132
0133 % save classifier to a uniquely named temporary file rather than
0134 % overwriting test.bin in the current folder and leaving it behind
0135 modelFile = [tempname '.bin'];
0136 saveToFile(svmcl, modelFile);
0137
0138 % load classifier from file, then remove the temporary file
0139 svmcl2 = loadFromFile(SVMClassifier, modelFile);
0140 delete(modelFile);
0141
0142 outs1 = computeOutputs(svmcl, x);
0143 err1 = sum(outs1 ~= y) / numInstances;
0144 outs2 = computeOutputs(svmcl2, x);
0145 err2 = sum(outs2 ~= y) / numInstances;
0146
0147 fprintf('Error Before Save %f, Error After Save %f\n', err1, err2);
0148
0149 %% Using kfold with svm
0150 fprintf('=========================\n');
0151 fprintf('Using Kfold with SVM\n');
0152 fprintf('=========================\n');
0153
0154 % 10-fold cross-validation of a fresh classifier on the whole dataset
0155 % NOTE(review): cp looks like a classperf object (CorrectRate) -- confirm in kfold.
0156 svmcl = SVMClassifier(numClasses);
0157 cp = kfold(x, y, 10, svmcl);
0158
0159 fprintf('Accuracy of 10 fold-cross validation %f\n', cp.CorrectRate * 100);

Generated on Sun 29-Sep-2013 01:25:24 by m2html © 2005