Christoph Budziszewski commited on 2009-03-12 17:15:24
Zeige 5 geänderte Dateien mit 90 Einfügungen und 4 Löschungen.
git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@149 83ab2cfd-5345-466c-8aeb-2b2739fb922d
| ... | ... |
@@ -22,4 +22,7 @@ SVMCROSSVAL_SUBJECT_PREFIX = 'subject'; |
| 22 | 22 |
|
| 23 | 23 |
global SVMCROSSVAL_SUBJECTSTRUCT_NAME; |
| 24 | 24 |
SVMCROSSVAL_SUBJECTSTRUCT_NAME = 'subjectStruct'; |
| 25 |
+ |
|
| 26 |
+global SVMCROSSVAL_PREPROCESSED_DATA_NAME; |
|
| 27 |
+SVMCROSSVAL_PREPROCESSED_DATA_NAME = 'preprocessedData'; |
|
| 25 | 28 |
end |
| 26 | 29 |
\ No newline at end of file |
| ... | ... |
@@ -72,16 +72,18 @@ end |
| 72 | 72 |
end |
| 73 | 73 |
|
| 74 | 74 |
function decode(model,task) |
| 75 |
+preprocessedData = evalin('base','preprocessedData');
|
|
| 75 | 76 |
switch task |
| 76 | 77 |
case 'SVM' |
| 77 | 78 |
disp('SVM');
|
| 78 | 79 |
svmopts = getSvmArgs(model,1); |
| 79 |
- preprocessedData = evalin('base','preprocessedData');
|
|
| 80 | 80 |
decode = calculateMultiSubjectDecodePerformance(preprocessedData.header,preprocessedData.subjectdata,svmopts); |
| 81 | 81 |
assignin('base','decode',decode);
|
| 82 |
- case 'X-SVM' |
|
| 82 |
+ case 'XSVM' |
|
| 83 | 83 |
disp('not implemented')
|
| 84 |
- |
|
| 84 |
+ svmopts = getSvmArgs(model,0); |
|
| 85 |
+ decode = xsvm_subject_loop(preprocessedData.header,preprocessedData.subjectdata,svmopts); |
|
| 86 |
+ assignin('base','decode',decode);
|
|
| 85 | 87 |
case 'SOM' |
| 86 | 88 |
disp('not implemented')
|
| 87 | 89 |
|
| ... | ... |
@@ -132,7 +132,7 @@ pSOM = uipanel(parent,'Units','normalized','Position',[0.5 0.0 0.5 1]); |
| 132 | 132 |
'Units','normalized',... |
| 133 | 133 |
'Position',[0 0.0 1 0.25]); |
| 134 | 134 |
set(btnRunXSVM,'Callback',{@cbRunDecode,model,'XSVM'}); % set here, because of model.
|
| 135 |
- set(btnRunXSVM,'Enable','off'); |
|
| 135 |
+ set(btnRunXSVM,'Enable','on'); |
|
| 136 | 136 |
|
| 137 | 137 |
btnRunSOM = uicontrol(pSOM,'String','run SOM Crossvalidation',... |
| 138 | 138 |
'Units','normalized',... |
| ... | ... |
@@ -0,0 +1,73 @@ |
| 1 |
+%% subject loop |
|
| 2 |
+function decode = xsvm_subject_loop(header,subjectdata,svmopts) |
|
| 3 |
+ |
|
| 4 |
+nSubjects = numel(subjectdata); |
|
| 5 |
+ |
|
| 6 |
+decode = struct; |
|
| 7 |
+decode.decodePerformance = []; |
|
| 8 |
+decode.rawTimeCourse = []; |
|
| 9 |
+ |
|
| 10 |
+disp(sprintf('we have %g subjects. Press ANY-Key to continue.\n Use Retrun if your Keyboard lacks the ANY-Key.',nSubjects));
|
|
| 11 |
+pause |
|
| 12 |
+ |
|
| 13 |
+ |
|
| 14 |
+timeline = header.timeline; |
|
| 15 |
+ |
|
| 16 |
+for subjectDataID = 1:nSubjects |
|
| 17 |
+ currentSubject = subjectdata{subjectDataID};
|
|
| 18 |
+ timePointArgs.pst = currentSubject.pst; |
|
| 19 |
+ timePointArgs.labelMap = LabelMap(header.classDef.labelCells,header.classDef.conditionCells); |
|
| 20 |
+ timePointArgs.eventList = header.classDef.eventMatrix; |
|
| 21 |
+ |
|
| 22 |
+ timePointMatrix{subjectDataID} = buildTimePointMatrix(timeline,timePointArgs);
|
|
| 23 |
+end |
|
| 24 |
+ |
|
| 25 |
+timeLineStart = timeline.frameShiftStart; |
|
| 26 |
+timeLineEnd = timeline.frameShiftEnd; |
|
| 27 |
+ |
|
| 28 |
+addpath 'libsvm-mat-2.88-1'; |
|
| 29 |
+ |
|
| 30 |
+for timeIndex = 1:timeLineEnd-timeLineStart+1 |
|
| 31 |
+ svm_train_label = []; |
|
| 32 |
+ svm_train_data = []; |
|
| 33 |
+ svm_validation_label = []; |
|
| 34 |
+ svm_validation_data = []; |
|
| 35 |
+ |
|
| 36 |
+ for validationSubjectID = 1:nSubjects |
|
| 37 |
+ for subjectDataID = 1:nSubjects |
|
| 38 |
+ svmstruct = calculateSVMTables(timePointMatrix{subjectDataID},timeIndex);
|
|
| 39 |
+ if subjectDataID == validationSubjectID |
|
| 40 |
+ svm_validation_label = svmstruct.svmlabel; |
|
| 41 |
+ svm_validation_data = svmstruct.svmdata; |
|
| 42 |
+ else |
|
| 43 |
+ svm_train_label = [svm_train_label; svmstruct.svmlabel]; |
|
| 44 |
+ svm_train_data = [svm_train_data; svmstruct.svmdata]; |
|
| 45 |
+ end |
|
| 46 |
+ end |
|
| 47 |
+ |
|
| 48 |
+ svmmodel = svmtrain(svm_train_label,svm_train_data,svmopts); |
|
| 49 |
+ |
|
| 50 |
+ [plabel accuracy dvalue] = svmpredict(svm_validation_label,svm_validation_data,svmmodel,''); |
|
| 51 |
+ |
|
| 52 |
+ accuracy(1) |
|
| 53 |
+ |
|
| 54 |
+% decode.(namehelper) = calculateDecodePerformance(header,currentSubject,svmopts); |
|
| 55 |
+% |
|
| 56 |
+% display('... done');
|
|
| 57 |
+% display('restoring warnings');
|
|
| 58 |
+% warning(warning_state); |
|
| 59 |
+% |
|
| 60 |
+% decode.decodePerformance = [decode.decodePerformance decode.(namehelper).decodePerformance]; |
|
| 61 |
+% decode.rawTimeCourse = [decode.rawTimeCourse decode.(namehelper).rawTimeCourse]; |
|
| 62 |
+ |
|
| 63 |
+% assignin('base','decode',decode);
|
|
| 64 |
+ |
|
| 65 |
+ % if RANDOMIZE_DATAPOINTS |
|
| 66 |
+ % rndindex = randperm(length(svmlabel)); |
|
| 67 |
+ % svmdata = svmdata(rndindex,:); |
|
| 68 |
+ % svmlabel = svmlabel(rndindex); |
|
| 69 |
+ % end |
|
| 70 |
+ |
|
| 71 |
+ end |
|
| 72 |
+end |
|
| 73 |
+end |
|
| 0 | 74 |