Christoph Budziszewski commited on 2009-04-23 16:48:20
Zeige 8 geänderte Dateien mit 50 Einfügungen und 12 Löschungen.
git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@178 83ab2cfd-5345-466c-8aeb-2b2739fb922d
... | ... |
@@ -1,6 +1,7 @@ |
1 | 1 |
function outputStruct = calculateDecodePerformance(header,subjectStruct,svmopts) |
2 | 2 |
outputStruct = struct; |
3 | 3 |
RANDOMIZE_DATAPOINTS = header.svmrnd; |
4 |
+NAN_AS_ZERO = header.nantozero; |
|
4 | 5 |
|
5 | 6 |
timeline = header.timeline; |
6 | 7 |
timeline.frameShiftStart = header.frameShift.frameShiftStart; |
... | ... |
@@ -37,6 +38,11 @@ for index = 1:timeLineEnd-timeLineStart+1 |
37 | 38 |
svmlabel = svmlabel(rndindex); |
38 | 39 |
end |
39 | 40 |
|
41 |
+ if NAN_AS_ZERO |
|
42 |
+ svmdata(isnan(svmdata))=0; |
|
43 |
+ end |
|
44 |
+ |
|
45 |
+ |
|
40 | 46 |
decodePerformance = [decodePerformance; svm_single_crossval(svmlabel,svmdata,svmopts)]; |
41 | 47 |
|
42 | 48 |
|
... | ... |
@@ -122,17 +122,22 @@ switch task |
122 | 122 |
disp('SVM'); |
123 | 123 |
svmopts = getSvmArgs(model,1); |
124 | 124 |
header.svmrnd = getSvmRnd(model); |
125 |
+ header.nantozero = 1; |
|
125 | 126 |
decode = calculateMultiSubjectDecodePerformance(header,data,svmopts); |
126 | 127 |
decode.header = header; |
127 | 128 |
assignin('base','decode',decode); |
128 | 129 |
case 'XSVM' |
129 | 130 |
disp('XSVM') |
130 | 131 |
svmopts = getSvmArgs(model,0); |
132 |
+ header.svmrnd = getSvmRnd(model); |
|
133 |
+ header.nantozero = 1; |
|
131 | 134 |
decode = xsvm_subject_loop(header,data,svmopts); |
132 | 135 |
decode.header = header; |
133 | 136 |
assignin('base','decode',decode); |
134 | 137 |
case 'SOM' |
135 | 138 |
display('SOM'); |
139 |
+ somopts.rnd = 1; |
|
140 |
+ somopts.nantozero = 1; |
|
136 | 141 |
somopts.size = [3 3]; |
137 | 142 |
somopts.lattice = 'rect'; |
138 | 143 |
somopts.nfold = 6; |
... | ... |
@@ -141,6 +146,8 @@ switch task |
141 | 146 |
assignin('base','decode',decode); |
142 | 147 |
case 'XSOM' |
143 | 148 |
display('XSOM'); |
149 |
+ somopts.rnd = 1; |
|
150 |
+ somopts.nantozero = 1; |
|
144 | 151 |
somopts.size = [3 3]; |
145 | 152 |
somopts.lattice = 'rect'; |
146 | 153 |
decode = som_xsubject_performance(header,data,somopts); |
... | ... |
@@ -129,7 +129,7 @@ for s = 1:nSubjects |
129 | 129 |
for timeShiftIdx = 1:nSamplePoints |
130 | 130 |
% center timepoint && relative shift |
131 | 131 |
frameStartIdx = floor(-globalStart+1+timeShiftIdx - 0.5*decodeDuration); |
132 |
- frameEndIdx = min(ceil(frameStart+decodeDuration + 0.5*decodeDuration),-globalStart+globalEnd); |
|
132 |
+ frameEndIdx = min(ceil(frameStartIdx+decodeDuration + 0.5*decodeDuration),-globalStart+globalEnd); |
|
133 | 133 |
|
134 | 134 |
img3D = zeros(size(mask_image)); %output image prepare |
135 | 135 |
|
... | ... |
@@ -3,7 +3,8 @@ function decode = som_subject_batch(header,subjectdata,somOpts) |
3 | 3 |
|
4 | 4 |
addpath 'somtoolbox2'; |
5 | 5 |
|
6 |
-RANDOMIZE_DATAPOINTS = 1; |
|
6 |
+RANDOMIZE_DATAPOINTS = somOpts.rnd; |
|
7 |
+NAN_AS_ZERO = somOpts.nantozero; |
|
7 | 8 |
|
8 | 9 |
decode = struct; |
9 | 10 |
decode.decodePerformance = []; |
... | ... |
@@ -62,9 +63,18 @@ for subjectDataID = 1:nSubjects |
62 | 63 |
svm_train_label(chunkstart:chunkend) = []; %del test set |
63 | 64 |
svm_train_data(chunkstart:chunkend,:) = [];% del test set |
64 | 65 |
|
65 |
- [sD sM] = som_train(svm_train_label, svm_train_data, somOpts); |
|
66 |
+ if NAN_AS_ZERO |
|
67 |
+ svm_train_data(isnan(svm_train_data))=0; |
|
68 |
+ svm_validation_data(isnan(svm_validation_data))=0; |
|
69 |
+ display('NaN to 0'); |
|
70 |
+ end |
|
66 | 71 |
|
72 |
+ if isempty(svm_train_data) |
|
73 |
+ performance = 0; |
|
74 |
+ else |
|
75 |
+ [sD sM] = som_train(svm_train_label, svm_train_data, somOpts); |
|
67 | 76 |
performance = som_decode(sM, svm_validation_data,svm_validation_label); |
77 |
+ end |
|
68 | 78 |
|
69 | 79 |
cross_value = [cross_value, performance]; |
70 | 80 |
end |
... | ... |
@@ -6,7 +6,7 @@ som_lattice = somOptions.lattice; |
6 | 6 |
addpath 'somtoolbox2'; |
7 | 7 |
sD = som_data_struct(svmdata,'labels',num2str(svmlabel)); |
8 | 8 |
|
9 |
-sM = som_make(sD,'msize', som_size,'lattice', som_lattice); |
|
9 |
+sM = som_make(sD,'msize', som_size,'lattice', som_lattice, 'tracking', 1, 'init', 'lininit'); |
|
10 | 10 |
sM = som_autolabel(sM,sD,'vote'); |
11 | 11 |
|
12 | 12 |
end |
13 | 13 |
\ No newline at end of file |
... | ... |
@@ -8,7 +8,8 @@ if(nSubjects < 2) |
8 | 8 |
error('SVMCrossVal:somXSubjectPerformance:tooFewSubjects','You need at least 2 Subjects in this Across-Subject analysis!'); |
9 | 9 |
end |
10 | 10 |
|
11 |
-RANDOMIZE_DATAPOINTS = 1; |
|
11 |
+RANDOMIZE_DATAPOINTS = somOpts.rnd; |
|
12 |
+NAN_AS_ZERO = somOpts.nantozero; |
|
12 | 13 |
|
13 | 14 |
decode = struct; |
14 | 15 |
decode.decodePerformance = []; |
... | ... |
@@ -65,9 +66,17 @@ for timeIndex = 1:timeLineEnd-timeLineStart+1 |
65 | 66 |
svm_train_label = svm_train_label(rndindex); |
66 | 67 |
end |
67 | 68 |
|
69 |
+ if NAN_AS_ZERO |
|
70 |
+ svm_train_data(isnan(svm_train_data))=0; |
|
71 |
+ end |
|
72 |
+ |
|
73 |
+ if isempty(svm_train_data) |
|
74 |
+ performance = 0; |
|
75 |
+ else |
|
68 | 76 |
[sD sM] = som_train(svm_train_label, svm_train_data, somOpts); |
69 | 77 |
|
70 | 78 |
performance = som_decode(sM, svm_validation_data,svm_validation_label); |
79 |
+ end |
|
71 | 80 |
|
72 | 81 |
cross_value = [cross_value performance]; |
73 | 82 |
end |
... | ... |
@@ -1,14 +1,14 @@ |
1 | 1 |
function ui_main(varargin) |
2 | 2 |
|
3 |
-DEFAULT.selectedSubject = 2; |
|
3 |
+DEFAULT.selectedSubject = [2]; |
|
4 | 4 |
|
5 |
-DEFAULT.pststart = -15; |
|
6 |
-DEFAULT.pstend = 40; |
|
7 |
-DEFAULT.baselinestart = -3; |
|
8 |
-DEFAULT.baselineend = -1; |
|
5 |
+DEFAULT.pststart = -5; |
|
6 |
+DEFAULT.pstend = 15; |
|
7 |
+DEFAULT.baselinestart = 0; |
|
8 |
+DEFAULT.baselineend = 0; |
|
9 | 9 |
DEFAULT.trfactor = 0.5; |
10 | 10 |
DEFAULT.frameshiftstart = -5; |
11 |
-DEFAULT.frameshiftend = 35; |
|
11 |
+DEFAULT.frameshiftend = 15; |
|
12 | 12 |
DEFAULT.frameshiftdur = 0; |
13 | 13 |
DEFAULT.classdefstring = 'left,\t[9,11,13]\nright,\t[10,12,14]'; |
14 | 14 |
DEFAULT.voxelstring = 'M1 l + 3 \nM1 r + 3\n'; |
... | ... |
@@ -8,7 +8,8 @@ if(nSubjects < 2) |
8 | 8 |
error('SVMCrossVal:xsvmSubjectLoop:tooFewSubjects','You need at least 2 Subjects in this Across-Subject analysis!'); |
9 | 9 |
end |
10 | 10 |
|
11 |
-RANDOMIZE_DATAPOINTS = 1; |
|
11 |
+RANDOMIZE_DATAPOINTS = header.svmrnd; |
|
12 |
+NAN_AS_ZERO = header.nantozero; |
|
12 | 13 |
|
13 | 14 |
decode = struct; |
14 | 15 |
decode.decodePerformance = []; |
... | ... |
@@ -65,6 +66,11 @@ for timeIndex = 1:timeLineEnd-timeLineStart+1 |
65 | 66 |
svm_train_label = svm_train_label(rndindex); |
66 | 67 |
end |
67 | 68 |
|
69 |
+ if NAN_AS_ZERO |
|
70 |
+ svm_train_data(isnan(svm_train_data))=0; |
|
71 |
+ end |
|
72 |
+ |
|
73 |
+ |
|
68 | 74 |
svmmodel = svmtrain(svm_train_label,svm_train_data,svmopts); |
69 | 75 |
|
70 | 76 |
[plabel accuracy dvalue] = svmpredict(svm_validation_label,svm_validation_data,svmmodel,''); |
71 | 77 |