Browse code

SVMCrossVal toolbox init

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@90 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on 17/12/2008 13:45:29
Showing 1 changed file
1 1
new file mode 100644
... ...
@@ -0,0 +1,339 @@
1
+#include <stdio.h>
2
+#include <stdlib.h>
3
+#include <string.h>
4
+#include "svm.h"
5
+
6
+#include "mex.h"
7
+#include "svm_model_matlab.h"
8
+
9
+#if MX_API_VER < 0x07030000
10
+typedef int mwIndex;
11
+#endif 
12
+
13
+#define CMD_LEN 2048
14
+
15
+void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
16
+{
17
+	int i, j, low, high;
18
+	mwIndex *ir, *jc;
19
+	double *samples;
20
+
21
+	ir = mxGetIr(prhs);
22
+	jc = mxGetJc(prhs);
23
+	samples = mxGetPr(prhs);
24
+
25
+	// each column is one instance
26
+	j = 0;
27
+	low = jc[index], high = jc[index+1];
28
+	for(i=low;i<high;i++)
29
+	{
30
+		x[j].index = ir[i] + 1;
31
+		x[j].value = samples[i];
32
+		j++;
33
+ 	}
34
+	x[j].index = -1;
35
+}
36
+
37
+static void fake_answer(mxArray *plhs[])
38
+{
39
+	plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
40
+	plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
41
+	plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
42
+}
43
+
44
+void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
45
+{
46
+	int label_vector_row_num, label_vector_col_num;
47
+	int feature_number, testing_instance_number;
48
+	int instance_index;
49
+	double *ptr_instance, *ptr_label, *ptr_predict_label; 
50
+	double *ptr_prob_estimates, *ptr_dec_values, *ptr;
51
+	struct svm_node *x;
52
+	mxArray *pplhs[1]; // transposed instance sparse matrix
53
+
54
+	int correct = 0;
55
+	int total = 0;
56
+	double error = 0;
57
+	double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
58
+
59
+	int svm_type=svm_get_svm_type(model);
60
+	int nr_class=svm_get_nr_class(model);
61
+	double *prob_estimates=NULL;
62
+
63
+	// prhs[1] = testing instance matrix
64
+	feature_number = mxGetN(prhs[1]);
65
+	testing_instance_number = mxGetM(prhs[1]);
66
+	label_vector_row_num = mxGetM(prhs[0]);
67
+	label_vector_col_num = mxGetN(prhs[0]);
68
+
69
+	if(label_vector_row_num!=testing_instance_number)
70
+	{
71
+		mexPrintf("Length of label vector does not match # of instances.\n");
72
+		fake_answer(plhs);
73
+		return;
74
+	}
75
+	if(label_vector_col_num!=1)
76
+	{
77
+		mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
78
+		fake_answer(plhs);
79
+		return;
80
+	}
81
+
82
+	ptr_instance = mxGetPr(prhs[1]);
83
+	ptr_label    = mxGetPr(prhs[0]);
84
+	
85
+	// transpose instance matrix
86
+	if(mxIsSparse(prhs[1]))
87
+	{
88
+		if(model->param.kernel_type == PRECOMPUTED)
89
+		{
90
+			// precomputed kernel requires dense matrix, so we make one
91
+			mxArray *rhs[1], *lhs[1];
92
+			rhs[0] = mxDuplicateArray(prhs[1]);
93
+			if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
94
+			{
95
+				mexPrintf("Error: cannot full testing instance matrix\n");
96
+				fake_answer(plhs);
97
+				return;
98
+			}
99
+			ptr_instance = mxGetPr(lhs[0]);
100
+			mxDestroyArray(rhs[0]);
101
+		}
102
+		else
103
+		{
104
+			mxArray *pprhs[1];
105
+			pprhs[0] = mxDuplicateArray(prhs[1]);
106
+			if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
107
+			{
108
+				mexPrintf("Error: cannot transpose testing instance matrix\n");
109
+				fake_answer(plhs);
110
+				return;
111
+			}
112
+		}
113
+	}
114
+
115
+	if(predict_probability)
116
+	{
117
+		if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
118
+			mexPrintf("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
119
+		else
120
+			prob_estimates = (double *) malloc(nr_class*sizeof(double));
121
+	}
122
+
123
+	plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
124
+	if(predict_probability)
125
+	{
126
+		// prob estimates are in plhs[2]
127
+		if(svm_type==C_SVC || svm_type==NU_SVC)
128
+			plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
129
+		else
130
+			plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
131
+	}
132
+	else
133
+	{
134
+		// decision values are in plhs[2]
135
+		if(svm_type == ONE_CLASS ||
136
+		   svm_type == EPSILON_SVR ||
137
+		   svm_type == NU_SVR)
138
+			plhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
139
+		else
140
+			plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
141
+	}
142
+
143
+	ptr_predict_label = mxGetPr(plhs[0]);
144
+	ptr_prob_estimates = mxGetPr(plhs[2]);
145
+	ptr_dec_values = mxGetPr(plhs[2]);
146
+	x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node) );
147
+	for(instance_index=0;instance_index<testing_instance_number;instance_index++)
148
+	{
149
+		int i;
150
+		double target,v;
151
+
152
+		target = ptr_label[instance_index];
153
+
154
+		if(mxIsSparse(prhs[1]) && model->param.kernel_type != PRECOMPUTED) // prhs[1]^T is still sparse
155
+			read_sparse_instance(pplhs[0], instance_index, x);
156
+		else
157
+		{
158
+			for(i=0;i<feature_number;i++)
159
+			{
160
+				x[i].index = i+1;
161
+				x[i].value = ptr_instance[testing_instance_number*i+instance_index];
162
+			}
163
+			x[feature_number].index = -1;
164
+		}
165
+
166
+		if(predict_probability) 
167
+		{
168
+			if(svm_type==C_SVC || svm_type==NU_SVC)
169
+			{
170
+				v = svm_predict_probability(model, x, prob_estimates);
171
+				ptr_predict_label[instance_index] = v;
172
+				for(i=0;i<nr_class;i++)
173
+					ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
174
+			} else {
175
+				v = svm_predict(model,x);
176
+				ptr_predict_label[instance_index] = v;
177
+			}
178
+		}
179
+		else
180
+		{
181
+			v = svm_predict(model,x);
182
+			ptr_predict_label[instance_index] = v;
183
+
184
+			if(svm_type == ONE_CLASS ||
185
+			   svm_type == EPSILON_SVR ||
186
+			   svm_type == NU_SVR)
187
+			{
188
+				double res;
189
+				svm_predict_values(model, x, &res);
190
+				ptr_dec_values[instance_index] = res;
191
+			}
192
+			else
193
+			{
194
+				double *dec_values = (double *) malloc(sizeof(double) * nr_class*(nr_class-1)/2);
195
+				svm_predict_values(model, x, dec_values);
196
+				for(i=0;i<(nr_class*(nr_class-1))/2;i++)
197
+					ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
198
+				free(dec_values);
199
+			}
200
+		}
201
+
202
+		if(v == target)
203
+			++correct;
204
+		error += (v-target)*(v-target);
205
+		sumv += v;
206
+		sumy += target;
207
+		sumvv += v*v;
208
+		sumyy += target*target;
209
+		sumvy += v*target;
210
+		++total;
211
+	}
212
+	if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
213
+	{
214
+		mexPrintf("Mean squared error = %g (regression)\n",error/total);
215
+		mexPrintf("Squared correlation coefficient = %g (regression)\n",
216
+			((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/
217
+			((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))
218
+			);
219
+	}
220
+	else
221
+		mexPrintf("Accuracy = %g%% (%d/%d) (classification)\n",
222
+			(double)correct/total*100,correct,total);
223
+
224
+	// return accuracy, mean squared error, squared correlation coefficient
225
+	plhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
226
+	ptr = mxGetPr(plhs[1]);
227
+	ptr[0] = (double)correct/total*100;
228
+	ptr[1] = error/total;
229
+	ptr[2] = ((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/
230
+				((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy));
231
+
232
+	free(x);
233
+	if(prob_estimates != NULL)
234
+		free(prob_estimates);
235
+}
236
+
237
/*
 * Print the svmpredict usage text to the MATLAB command window.
 * Despite the name, this function returns normally; callers are
 * responsible for returning afterwards.
 */
void exit_with_help()
{
	static const char usage[] =
	"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
	"libsvm_options:\n"
	"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n";

	mexPrintf("%s", usage);
}
245
+
246
+void mexFunction( int nlhs, mxArray *plhs[],
247
+		 int nrhs, const mxArray *prhs[] )
248
+{
249
+	int prob_estimate_flag = 0;
250
+	struct svm_model *model;
251
+
252
+	if(nrhs > 4 || nrhs < 3)
253
+	{
254
+		exit_with_help();
255
+		fake_answer(plhs);
256
+		return;
257
+	}
258
+
259
+	if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
260
+		mexPrintf("Error: label vector and instance matrix must be double\n");
261
+		fake_answer(plhs);
262
+		return;
263
+	}
264
+
265
+	if(mxIsStruct(prhs[2]))
266
+	{
267
+		const char *error_msg;
268
+
269
+		// parse options
270
+		if(nrhs==4)
271
+		{
272
+			int i, argc = 1;
273
+			char cmd[CMD_LEN], *argv[CMD_LEN/2];
274
+
275
+			// put options in argv[]
276
+			mxGetString(prhs[3], cmd,  mxGetN(prhs[3]) + 1);
277
+			if((argv[argc] = strtok(cmd, " ")) != NULL)
278
+				while((argv[++argc] = strtok(NULL, " ")) != NULL)
279
+					;
280
+
281
+			for(i=1;i<argc;i++)
282
+			{
283
+				if(argv[i][0] != '-') break;
284
+				if(++i>=argc)
285
+				{
286
+					exit_with_help();
287
+					fake_answer(plhs);
288
+					return;
289
+				}
290
+				switch(argv[i-1][1])
291
+				{
292
+					case 'b':
293
+						prob_estimate_flag = atoi(argv[i]);
294
+						break;
295
+					default:
296
+						mexPrintf("Unknown option: -%c\n", argv[i-1][1]);
297
+						exit_with_help();
298
+						fake_answer(plhs);
299
+						return;
300
+				}
301
+			}
302
+		}
303
+
304
+		model = matlab_matrix_to_model(prhs[2], &error_msg);
305
+		if (model == NULL)
306
+		{
307
+			mexPrintf("Error: can't read model: %s\n", error_msg);
308
+			fake_answer(plhs);
309
+			return;
310
+		}
311
+
312
+		if(prob_estimate_flag)
313
+		{
314
+			if(svm_check_probability_model(model)==0)
315
+			{
316
+				mexPrintf("Model does not support probabiliy estimates\n");
317
+				fake_answer(plhs);
318
+				svm_destroy_model(model);
319
+				return;
320
+			}
321
+		}
322
+		else
323
+		{
324
+			if(svm_check_probability_model(model)!=0)
325
+				printf("Model supports probability estimates, but disabled in predicton.\n");
326
+		}
327
+
328
+		predict(plhs, prhs, model, prob_estimate_flag);
329
+		// destroy model
330
+		svm_destroy_model(model);
331
+	}
332
+	else
333
+	{
334
+		mexPrintf("model file should be a struct array\n");
335
+		fake_answer(plhs);
336
+	}
337
+
338
+	return;
339
+}