Browse code

nSubject-fold cross validation. visualization still missing.

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@150 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on 16/03/2009 13:50:49
Showing 2 changed files
... ...
@@ -20,12 +20,15 @@ timeline = header.timeline;
20 20
         smoothed          = 'yes';
21 21
         
22 22
         PLOT_METHOD       = SVMCROSSVAL_CROSSVAL_METHOD_DEF.svmcrossval;
23
+        PLOT_METHOD       = SVMCROSSVAL_CROSSVAL_METHOD_DEF.classPerformance;
23 24
 %         CROSSVAL_METHOD_DEF = inputStruct.CROSSVAL_METHOD_DEF;
24 25
 
25 26
 
26 27
     f = figure;
27 28
     subplot(2,1,1);
28 29
     hold on;
30
+    size(psth)
31
+    if (size(psth) > 0)
29 32
       for voxel = 1:size(psth,2)
30 33
           for label = 1:size(psth{voxel},2)
31 34
               psthData = [];
... ...
@@ -35,6 +38,7 @@ timeline = header.timeline;
35 38
               plot(psthStart:psthEnd,psthData,[colorChooser(voxel), lineStyleChooser(label)]);
36 39
           end
37 40
       end
41
+    end
38 42
     axis([psthStart psthEnd PSTH_AXIS_MIN PSTH_AXIS_MAX])
39 43
     xlabel('time [sec]');
40 44
     ylabel('fMRI-signal change [%]');
... ...
@@ -7,9 +7,7 @@ decode = struct;
7 7
 decode.decodePerformance = [];
8 8
 decode.rawTimeCourse     = [];
9 9
 
10
-disp(sprintf('we have %g subjects. Press ANY-Key to continue.\n Use Retrun if your Keyboard lacks the ANY-Key.',nSubjects));
11
-pause
12
-
10
+disp(sprintf('computinig additional datastructs for %u subjects',nSubjects));
13 11
 
14 12
 timeline = header.timeline;
15 13
 
... ...
@@ -27,13 +25,16 @@ timeLineEnd     = timeline.frameShiftEnd;
27 25
 
28 26
 addpath 'libsvm-mat-2.88-1';
29 27
 
28
+display(sprintf('%u -fold cross validation for %u timeslices.\n',nSubjects,size(1:timeLineEnd-timeLineStart+1,2)));
29
+disp(sprintf('Press ANY-Key to continue.\n Use Retrun if your Keyboard lacks the ANY-Key.'));
30
+pause
30 31
 for timeIndex = 1:timeLineEnd-timeLineStart+1
31
-    svm_train_label = [];
32
-    svm_train_data  = [];
33
-    svm_validation_label = [];
34
-    svm_validation_data  = [];
35
-
32
+    cross_value = [];
36 33
     for validationSubjectID = 1:nSubjects
34
+        svm_train_label = [];
35
+        svm_train_data  = [];
36
+        svm_validation_label = [];
37
+        svm_validation_data  = [];
37 38
         for subjectDataID = 1:nSubjects
38 39
             svmstruct = calculateSVMTables(timePointMatrix{subjectDataID},timeIndex);
39 40
             if subjectDataID == validationSubjectID
... ...
@@ -45,12 +46,18 @@ for timeIndex = 1:timeLineEnd-timeLineStart+1
45 46
             end
46 47
         end
47 48
         
49
+%         display(sprintf('Time %u: validation subject: %u, validation set size %g, training set size %g with %u subjects',...
50
+%             timeIndex, validationSubjectID, numel(svm_validation_label), numel(svm_train_label),nSubjects-1));
51
+        
48 52
         svmmodel = svmtrain(svm_train_label,svm_train_data,svmopts);
49 53
         
50 54
         [plabel accuracy dvalue] = svmpredict(svm_validation_label,svm_validation_data,svmmodel,'');
55
+        cross_value = [cross_value accuracy(1)];
51 56
         
52
-        accuracy(1)
53
-
57
+    end
58
+    decode.decodePerformance = [decode.decodePerformance mean(cross_value)];
59
+%     decode.rawTimeCourse = [decode.rawTimeCourse cross_value];
60
+    
54 61
 %         decode.(namehelper)         = calculateDecodePerformance(header,currentSubject,svmopts);
55 62
 % 
56 63
 %         display('... done');
... ...
@@ -68,6 +75,6 @@ for timeIndex = 1:timeLineEnd-timeLineStart+1
68 75
         %         svmlabel  = svmlabel(rndindex);
69 76
         %         end
70 77
 
71
-    end
78
+    
72 79
 end
73 80
 end