Christoph Budziszewski commited on 2009-03-12 17:15:24
Zeige 5 geänderte Dateien mit 90 Einfügungen und 4 Löschungen.
git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@149 83ab2cfd-5345-466c-8aeb-2b2739fb922d
... | ... |
@@ -22,4 +22,7 @@ SVMCROSSVAL_SUBJECT_PREFIX = 'subject'; |
22 | 22 |
|
23 | 23 |
global SVMCROSSVAL_SUBJECTSTRUCT_NAME; |
24 | 24 |
SVMCROSSVAL_SUBJECTSTRUCT_NAME = 'subjectStruct'; |
25 |
+ |
|
26 |
+global SVMCROSSVAL_PREPROCESSED_DATA_NAME; |
|
27 |
+SVMCROSSVAL_PREPROCESSED_DATA_NAME = 'preprocessedData'; |
|
25 | 28 |
end |
26 | 29 |
\ No newline at end of file |
... | ... |
@@ -72,16 +72,18 @@ end |
72 | 72 |
end |
73 | 73 |
|
74 | 74 |
function decode(model,task) |
75 |
+preprocessedData = evalin('base','preprocessedData'); |
|
75 | 76 |
switch task |
76 | 77 |
case 'SVM' |
77 | 78 |
disp('SVM'); |
78 | 79 |
svmopts = getSvmArgs(model,1); |
79 |
- preprocessedData = evalin('base','preprocessedData'); |
|
80 | 80 |
decode = calculateMultiSubjectDecodePerformance(preprocessedData.header,preprocessedData.subjectdata,svmopts); |
81 | 81 |
assignin('base','decode',decode); |
82 |
- case 'X-SVM' |
|
82 |
+ case 'XSVM' |
|
83 | 83 |
disp('not implemented') |
84 |
- |
|
84 |
+ svmopts = getSvmArgs(model,0); |
|
85 |
+ decode = xsvm_subject_loop(preprocessedData.header,preprocessedData.subjectdata,svmopts); |
|
86 |
+ assignin('base','decode',decode); |
|
85 | 87 |
case 'SOM' |
86 | 88 |
disp('not implemented') |
87 | 89 |
|
... | ... |
@@ -132,7 +132,7 @@ pSOM = uipanel(parent,'Units','normalized','Position',[0.5 0.0 0.5 1]); |
132 | 132 |
'Units','normalized',... |
133 | 133 |
'Position',[0 0.0 1 0.25]); |
134 | 134 |
set(btnRunXSVM,'Callback',{@cbRunDecode,model,'XSVM'}); % set here, because of model. |
135 |
- set(btnRunXSVM,'Enable','off'); |
|
135 |
+ set(btnRunXSVM,'Enable','on'); |
|
136 | 136 |
|
137 | 137 |
btnRunSOM = uicontrol(pSOM,'String','run SOM Crossvalidation',... |
138 | 138 |
'Units','normalized',... |
... | ... |
@@ -0,0 +1,73 @@ |
1 |
+%% subject loop |
|
2 |
+function decode = xsvm_subject_loop(header,subjectdata,svmopts) |
|
3 |
+ |
|
4 |
+nSubjects = numel(subjectdata); |
|
5 |
+ |
|
6 |
+decode = struct; |
|
7 |
+decode.decodePerformance = []; |
|
8 |
+decode.rawTimeCourse = []; |
|
9 |
+ |
|
10 |
+disp(sprintf('we have %g subjects. Press ANY-Key to continue.\n Use Retrun if your Keyboard lacks the ANY-Key.',nSubjects)); |
|
11 |
+pause |
|
12 |
+ |
|
13 |
+ |
|
14 |
+timeline = header.timeline; |
|
15 |
+ |
|
16 |
+for subjectDataID = 1:nSubjects |
|
17 |
+ currentSubject = subjectdata{subjectDataID}; |
|
18 |
+ timePointArgs.pst = currentSubject.pst; |
|
19 |
+ timePointArgs.labelMap = LabelMap(header.classDef.labelCells,header.classDef.conditionCells); |
|
20 |
+ timePointArgs.eventList = header.classDef.eventMatrix; |
|
21 |
+ |
|
22 |
+ timePointMatrix{subjectDataID} = buildTimePointMatrix(timeline,timePointArgs); |
|
23 |
+end |
|
24 |
+ |
|
25 |
+timeLineStart = timeline.frameShiftStart; |
|
26 |
+timeLineEnd = timeline.frameShiftEnd; |
|
27 |
+ |
|
28 |
+addpath 'libsvm-mat-2.88-1'; |
|
29 |
+ |
|
30 |
+for timeIndex = 1:timeLineEnd-timeLineStart+1 |
|
31 |
+ svm_train_label = []; |
|
32 |
+ svm_train_data = []; |
|
33 |
+ svm_validation_label = []; |
|
34 |
+ svm_validation_data = []; |
|
35 |
+ |
|
36 |
+ for validationSubjectID = 1:nSubjects |
|
37 |
+ for subjectDataID = 1:nSubjects |
|
38 |
+ svmstruct = calculateSVMTables(timePointMatrix{subjectDataID},timeIndex); |
|
39 |
+ if subjectDataID == validationSubjectID |
|
40 |
+ svm_validation_label = svmstruct.svmlabel; |
|
41 |
+ svm_validation_data = svmstruct.svmdata; |
|
42 |
+ else |
|
43 |
+ svm_train_label = [svm_train_label; svmstruct.svmlabel]; |
|
44 |
+ svm_train_data = [svm_train_data; svmstruct.svmdata]; |
|
45 |
+ end |
|
46 |
+ end |
|
47 |
+ |
|
48 |
+ svmmodel = svmtrain(svm_train_label,svm_train_data,svmopts); |
|
49 |
+ |
|
50 |
+ [plabel accuracy dvalue] = svmpredict(svm_validation_label,svm_validation_data,svmmodel,''); |
|
51 |
+ |
|
52 |
+ accuracy(1) |
|
53 |
+ |
|
54 |
+% decode.(namehelper) = calculateDecodePerformance(header,currentSubject,svmopts); |
|
55 |
+% |
|
56 |
+% display('... done'); |
|
57 |
+% display('restoring warnings'); |
|
58 |
+% warning(warning_state); |
|
59 |
+% |
|
60 |
+% decode.decodePerformance = [decode.decodePerformance decode.(namehelper).decodePerformance]; |
|
61 |
+% decode.rawTimeCourse = [decode.rawTimeCourse decode.(namehelper).rawTimeCourse]; |
|
62 |
+ |
|
63 |
+% assignin('base','decode',decode); |
|
64 |
+ |
|
65 |
+ % if RANDOMIZE_DATAPOINTS |
|
66 |
+ % rndindex = randperm(length(svmlabel)); |
|
67 |
+ % svmdata = svmdata(rndindex,:); |
|
68 |
+ % svmlabel = svmlabel(rndindex); |
|
69 |
+ % end |
|
70 |
+ |
|
71 |
+ end |
|
72 |
+end |
|
73 |
+end |
|
0 | 74 |