added some batch files, workaround for NAN - Bug
Christoph Budziszewski

Christoph Budziszewski commited on 2009-04-23 16:48:20
Zeige 8 geänderte Dateien mit 50 Einfügungen und 12 Löschungen.


git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@178 83ab2cfd-5345-466c-8aeb-2b2739fb922d
... ...
@@ -1,6 +1,7 @@
1 1
 function outputStruct = calculateDecodePerformance(header,subjectStruct,svmopts)
2 2
 outputStruct = struct;
3 3
 RANDOMIZE_DATAPOINTS = header.svmrnd;
4
+NAN_AS_ZERO = header.nantozero;
4 5
 
5 6
 timeline = header.timeline;
6 7
 timeline.frameShiftStart = header.frameShift.frameShiftStart;
... ...
@@ -37,6 +38,11 @@ for index = 1:timeLineEnd-timeLineStart+1
37 38
         svmlabel  = svmlabel(rndindex);
38 39
     end
39 40
 
41
+    if NAN_AS_ZERO
42
+        svmdata(isnan(svmdata))=0;
43
+    end
44
+        
45
+    
40 46
     decodePerformance = [decodePerformance; svm_single_crossval(svmlabel,svmdata,svmopts)];
41 47
     
42 48
     
... ...
@@ -122,17 +122,22 @@ switch task
122 122
         disp('SVM');
123 123
         svmopts    = getSvmArgs(model,1);
124 124
         header.svmrnd = getSvmRnd(model);
125
+        header.nantozero = 1;
125 126
         decode = calculateMultiSubjectDecodePerformance(header,data,svmopts);
126 127
         decode.header = header;
127 128
         assignin('base','decode',decode);
128 129
     case 'XSVM'
129 130
         disp('XSVM')
130 131
         svmopts  = getSvmArgs(model,0);
132
+        header.svmrnd = getSvmRnd(model);
133
+        header.nantozero = 1;
131 134
         decode = xsvm_subject_loop(header,data,svmopts);
132 135
         decode.header = header;
133 136
         assignin('base','decode',decode);
134 137
     case 'SOM'
135 138
         display('SOM');
139
+        somopts.rnd = 1;
140
+        somopts.nantozero = 1;
136 141
         somopts.size = [3 3];
137 142
         somopts.lattice = 'rect';
138 143
         somopts.nfold = 6;
... ...
@@ -141,6 +146,8 @@ switch task
141 146
         assignin('base','decode',decode);
142 147
     case 'XSOM'
143 148
         display('XSOM');
149
+        somopts.rnd = 1;
150
+        somopts.nantozero = 1;
144 151
         somopts.size = [3 3];
145 152
         somopts.lattice = 'rect';
146 153
         decode = som_xsubject_performance(header,data,somopts);
... ...
@@ -129,7 +129,7 @@ for s = 1:nSubjects
129 129
     for timeShiftIdx = 1:nSamplePoints
130 130
     % center timepoint && relative shift
131 131
     frameStartIdx  = floor(-globalStart+1+timeShiftIdx - 0.5*decodeDuration);
132
-    frameEndIdx    = min(ceil(frameStart+decodeDuration + 0.5*decodeDuration),-globalStart+globalEnd);
132
+    frameEndIdx    = min(ceil(frameStartIdx+decodeDuration + 0.5*decodeDuration),-globalStart+globalEnd);
133 133
 
134 134
         img3D = zeros(size(mask_image)); %output image prepare
135 135
 
... ...
@@ -3,7 +3,8 @@ function decode = som_subject_batch(header,subjectdata,somOpts)
3 3
 
4 4
 addpath 'somtoolbox2';
5 5
 
6
-RANDOMIZE_DATAPOINTS = 1;
6
+RANDOMIZE_DATAPOINTS = somOpts.rnd;
7
+NAN_AS_ZERO = somOpts.nantozero;
7 8
 
8 9
 decode = struct;
9 10
 decode.decodePerformance = [];
... ...
@@ -62,9 +63,18 @@ for subjectDataID = 1:nSubjects
62 63
             svm_train_label(chunkstart:chunkend) = []; %del test set
63 64
             svm_train_data(chunkstart:chunkend,:) = [];% del test set
64 65
 
65
-            [sD sM] = som_train(svm_train_label, svm_train_data, somOpts);
66
+            if NAN_AS_ZERO
67
+                svm_train_data(isnan(svm_train_data))=0;
68
+                svm_validation_data(isnan(svm_validation_data))=0;
69
+                display('NaN to 0');
70
+            end
66 71
             
72
+            if isempty(svm_train_data)
73
+                performance = 0;
74
+            else
75
+                [sD sM] = som_train(svm_train_label, svm_train_data, somOpts);
67 76
                 performance = som_decode(sM, svm_validation_data,svm_validation_label);
77
+            end
68 78
 
69 79
             cross_value = [cross_value, performance];
70 80
         end
... ...
@@ -6,7 +6,7 @@ som_lattice = somOptions.lattice;
6 6
 addpath 'somtoolbox2';
7 7
 sD = som_data_struct(svmdata,'labels',num2str(svmlabel));
8 8
 
9
-sM = som_make(sD,'msize', som_size,'lattice', som_lattice);
9
+sM = som_make(sD,'msize', som_size,'lattice', som_lattice, 'tracking', 1, 'init', 'lininit');
10 10
 sM = som_autolabel(sM,sD,'vote');
11 11
 
12 12
 end
13 13
\ No newline at end of file
... ...
@@ -8,7 +8,8 @@ if(nSubjects < 2)
8 8
     error('SVMCrossVal:somXSubjectPerformance:tooFewSubjects','You need at least 2 Subjects in this Across-Subject analysis!');
9 9
 end
10 10
 
11
-RANDOMIZE_DATAPOINTS = 1;
11
+RANDOMIZE_DATAPOINTS = somOpts.rnd;
12
+NAN_AS_ZERO = somOpts.nantozero;
12 13
 
13 14
 decode = struct;
14 15
 decode.decodePerformance = [];
... ...
@@ -65,9 +66,17 @@ for timeIndex = 1:timeLineEnd-timeLineStart+1
65 66
             svm_train_label  = svm_train_label(rndindex);
66 67
         end
67 68
 
69
+        if NAN_AS_ZERO
70
+            svm_train_data(isnan(svm_train_data))=0;
71
+        end
72
+        
73
+        if isempty(svm_train_data)
74
+            performance = 0;
75
+        else
68 76
             [sD sM] = som_train(svm_train_label, svm_train_data, somOpts);
69 77
 
70 78
             performance = som_decode(sM, svm_validation_data,svm_validation_label);
79
+        end
71 80
         
72 81
         cross_value = [cross_value performance];
73 82
     end
... ...
@@ -1,14 +1,14 @@
1 1
 function ui_main(varargin)
2 2
 
3
-DEFAULT.selectedSubject = 2;
3
+DEFAULT.selectedSubject = [2];
4 4
 
5
-DEFAULT.pststart        = -15;
6
-DEFAULT.pstend          = 40;
7
-DEFAULT.baselinestart   = -3;
8
-DEFAULT.baselineend     = -1;
5
+DEFAULT.pststart        = -5;
6
+DEFAULT.pstend          = 15;
7
+DEFAULT.baselinestart   = 0;
8
+DEFAULT.baselineend     = 0;
9 9
 DEFAULT.trfactor        = 0.5;
10 10
 DEFAULT.frameshiftstart = -5;
11
-DEFAULT.frameshiftend   = 35;
11
+DEFAULT.frameshiftend   = 15;
12 12
 DEFAULT.frameshiftdur   = 0;
13 13
 DEFAULT.classdefstring  = 'left,\t[9,11,13]\nright,\t[10,12,14]';
14 14
 DEFAULT.voxelstring     = 'M1 l + 3 \nM1 r + 3\n';
... ...
@@ -8,7 +8,8 @@ if(nSubjects < 2)
8 8
     error('SVMCrossVal:xsvmSubjectLoop:tooFewSubjects','You need at least 2 Subjects in this Across-Subject analysis!');
9 9
 end
10 10
 
11
-RANDOMIZE_DATAPOINTS = 1;
11
+RANDOMIZE_DATAPOINTS = header.svmrnd;
12
+NAN_AS_ZERO = header.nantozero;
12 13
 
13 14
 decode = struct;
14 15
 decode.decodePerformance = [];
... ...
@@ -65,6 +66,11 @@ for timeIndex = 1:timeLineEnd-timeLineStart+1
65 66
             svm_train_label  = svm_train_label(rndindex);
66 67
         end
67 68
         
69
+        if NAN_AS_ZERO
70
+            svm_train_data(isnan(svm_train_data))=0;
71
+        end
72
+        
73
+
68 74
         svmmodel = svmtrain(svm_train_label,svm_train_data,svmopts);
69 75
         
70 76
         [plabel accuracy dvalue] = svmpredict(svm_validation_label,svm_validation_data,svmmodel,'');
71 77