Christoph Budziszewski commited on 2009-01-25 22:54:27
Zeige 4 geänderte Dateien mit 26 Einfügungen und 214 Löschungen.
git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@114 83ab2cfd-5345-466c-8aeb-2b2739fb922d
... | ... |
@@ -1,100 +0,0 @@ |
1 |
-function classify(varargin) |
|
2 |
- |
|
3 |
- |
|
4 |
- |
|
5 |
-switch nargin |
|
6 |
- case 1 |
|
7 |
- paramModel = varargin{1}; |
|
8 |
- % PROJECT_BASE_PATH = 'D:\Analyze\Stimolos'; |
|
9 |
- PROJECT_BASE_PATH = 'D:\Analyze\Choice\24pilot'; |
|
10 |
- PROJECT_RESULT_PATH = 'results\SPM.mat'; |
|
11 |
- otherwise |
|
12 |
- error('spmtoolbox:SVMCrossVal:arginError','Please Specify action and parameter model'); |
|
13 |
-end |
|
14 |
- |
|
15 |
- |
|
16 |
- % common params |
|
17 |
- calculateParams = struct; |
|
18 |
- calculateParams.smoothed = getDouble(paramModel.txtSmoothed); |
|
19 |
- |
|
20 |
- calculateParams.frameShiftStart = getDouble(paramModel.txtFrameShiftStart); % -20; |
|
21 |
- calculateParams.frameShiftEnd = getDouble(paramModel.txtFrameShiftEnd); %15; |
|
22 |
- calculateParams.decodeDuration = getDouble(paramModel.txtFrameShiftDur); |
|
23 |
- calculateParams.psthStart = getDouble(paramModel.txtPSTHStart); % -25; |
|
24 |
- calculateParams.psthEnd = getDouble(paramModel.txtPSTHEnd); % 20; |
|
25 |
- calculateParams.baselineStart = getDouble(paramModel.txtBaselineStart); % -22; |
|
26 |
- calculateParams.baselineEnd = getDouble(paramModel.txtBaselineEnd); % -20; |
|
27 |
- |
|
28 |
- calculateParams.svmargs = get(paramModel.txtSVMopts,'String'); |
|
29 |
- calculateParams.sessionList = 1:3; |
|
30 |
- |
|
31 |
- classStruct = parseClassDef(paramModel); |
|
32 |
- |
|
33 |
- |
|
34 |
- calculateParams.labelMap = LabelMap(classStruct.labelCells , classStruct.conditionCells, 'auto'); % LabelMap({'<','>','<+<','>+>','<+>','>+<'},{-2,-1,1,2,3,4}); 0 is autolabel |
|
35 |
- calculateParams.classList = getClasses(calculateParams.labelMap); |
|
36 |
- calculateParams.eventList = classStruct.eventMatrix; %[9,11,13; 10,12,14]; |
|
37 |
-% calculateParams.eventList = getPSTEventMatrix(calculateParams.labelMap); |
|
38 |
- |
|
39 |
- subjectSelection = getSubjectIDString(paramModel); |
|
40 |
- decode = struct; |
|
41 |
- decode.decodePerformance = []; |
|
42 |
- decode.rawTimeCourse = []; |
|
43 |
- |
|
44 |
- for subjectCell = subjectSelection |
|
45 |
- SubjectID = cell2mat(subjectCell); |
|
46 |
- namehelper = strcat('s',SubjectID); %Vars can not start with numbers. |
|
47 |
- |
|
48 |
- display('loading SPM.mat ...'); |
|
49 |
- spm = load(fullfile(PROJECT_BASE_PATH,SubjectID,PROJECT_RESULT_PATH)); |
|
50 |
- display('... done.'); |
|
51 |
- |
|
52 |
- %% calculate |
|
53 |
- calculateParams.(namehelper).des = spm.SPM; |
|
54 |
- calculateParams.(namehelper).voxelList = parseVoxelList(paramModel,SubjectID); |
|
55 |
- assignin('base','calculateParams',calculateParams); |
|
56 |
- |
|
57 |
- display(sprintf('calculating cross-validation performance time-shift for Subject %s. Please Wait. ...',SubjectID)); |
|
58 |
- display('switching off all warnings'); |
|
59 |
- warning_state = warning('off','all'); |
|
60 |
- display('calculating ...'); |
|
61 |
- decode.(namehelper) = calculateDecodePerformance(calculateParams,SubjectID); |
|
62 |
- display('... done'); |
|
63 |
- display('restoring warnings'); |
|
64 |
- warning(warning_state); |
|
65 |
- |
|
66 |
- decode.decodePerformance = [decode.decodePerformance decode.(namehelper).decodePerformance]; |
|
67 |
- decode.rawTimeCourse = [decode.rawTimeCourse decode.(namehelper).rawTimeCourse]; |
|
68 |
- |
|
69 |
- assignin('base','decode',decode); |
|
70 |
- end |
|
71 |
- |
|
72 |
- display('Finished calculations.'); |
|
73 |
- display('Plotting...'); |
|
74 |
- |
|
75 |
- plotParams = struct; |
|
76 |
- plotParams.psthStart = calculateParams.psthStart; |
|
77 |
- plotParams.psthEnd = calculateParams.psthEnd; |
|
78 |
- plotParams.nClasses = length(calculateParams.classList); |
|
79 |
- |
|
80 |
- plotParams.frameShiftStart = calculateParams.frameShiftStart; |
|
81 |
- plotParams.frameShiftEnd = calculateParams.frameShiftEnd; |
|
82 |
- plotParams.decodePerformance = decode.decodePerformance; |
|
83 |
- plotParams.rawTimeCourse = decode.rawTimeCourse; |
|
84 |
- |
|
85 |
- if numel(subjectSelection) == 1 |
|
86 |
- plotParams.SubjectID = SubjectID; |
|
87 |
- else |
|
88 |
- plotParams.SubjectID = 'Multiple'; |
|
89 |
- end |
|
90 |
- |
|
91 |
- plotParams.smoothed = boolToYesNoString(calculateParams.smoothed); |
|
92 |
- |
|
93 |
- |
|
94 |
- assignin('base','plotParams',plotParams); |
|
95 |
-% plotDecodePerformance(params.psthStart,params.psthEnd,params.nClasses,decode.decodeTable,params.frameShiftStart,params.frameShiftEnd,decode.rawTimeCourse); |
|
96 |
- plotDecodePerformance(plotParams); |
|
97 |
- |
|
98 |
- display('all done.'); |
|
99 |
- |
|
100 |
- end |
|
101 | 0 |
\ No newline at end of file |
... | ... |
@@ -1,103 +0,0 @@ |
1 |
-function plotDecodePerformance(varargin) |
|
2 |
-% plotDecodePerformance(timeline,decodePerformance,nClasses,rawData) |
|
3 |
- |
|
4 |
-PSTH_AXIS_MIN = -1; |
|
5 |
-PSTH_AXIS_MAX = 1; |
|
6 |
- |
|
7 |
-switch nargin |
|
8 |
- |
|
9 |
- case 1 |
|
10 |
- inputStruct = cell2mat(varargin(1)); |
|
11 |
- |
|
12 |
- psthStart = inputStruct.psthStart; |
|
13 |
- psthEnd = inputStruct.psthEnd; |
|
14 |
- nClasses = inputStruct.nClasses; |
|
15 |
- decodePerformance = inputStruct.decodePerformance; |
|
16 |
- frameStart = inputStruct.frameShiftStart; |
|
17 |
- frameEnd = inputStruct.frameShiftEnd; |
|
18 |
- psth = inputStruct.rawTimeCourse; |
|
19 |
- SubjectID = inputStruct.SubjectID; |
|
20 |
- smoothed = inputStruct.smoothed; |
|
21 |
- |
|
22 |
- otherwise |
|
23 |
- error('spmtoolbox:SVMCrossVal:plotDecodePerformance:WrongArgument','Wrong Arguments'); |
|
24 |
-end |
|
25 |
- |
|
26 |
- f = figure; |
|
27 |
- subplot(2,1,1); |
|
28 |
- hold on; |
|
29 |
- for voxel = 1:size(psth,2) |
|
30 |
- for label = 1:size(psth{voxel},2) |
|
31 |
- psthData = []; |
|
32 |
- for timepoint = 1:size(psth{voxel}{label},2) |
|
33 |
- psthData = nanmean(psth{voxel}{label}); |
|
34 |
- end |
|
35 |
- plot(psthStart:psthEnd,psthData,[colorChooser(voxel), lineStyleChooser(label)]); |
|
36 |
- end |
|
37 |
- end |
|
38 |
- axis([psthStart psthEnd PSTH_AXIS_MIN PSTH_AXIS_MAX]) |
|
39 |
- hold off |
|
40 |
- |
|
41 |
- subplot(2,1,2) |
|
42 |
- hold on; |
|
43 |
- |
|
44 |
- chanceLevel = 100/nClasses; |
|
45 |
- goodPredictionLevel = chanceLevel*1.5; |
|
46 |
- plot([psthStart psthEnd],[chanceLevel chanceLevel],'r'); |
|
47 |
- plot([psthStart psthEnd],[goodPredictionLevel goodPredictionLevel],'g'); |
|
48 |
- axis([psthStart psthEnd 0 100]) |
|
49 |
- |
|
50 |
- plot(frameStart:frameEnd, mean(decodePerformance,2) ,'b'); |
|
51 |
- PLOT_STD_ERR = 1; |
|
52 |
- PLOT_CLASS_PERFORMANCE = 1; |
|
53 |
- if PLOT_STD_ERR |
|
54 |
- se = myStdErr(decodePerformance,2); |
|
55 |
- plot(frameStart:frameEnd, mean(decodePerformance,2)+se ,'b:'); |
|
56 |
- plot(frameStart:frameEnd, mean(decodePerformance,2)-se ,'b:'); |
|
57 |
- end |
|
58 |
- if PLOT_CLASS_PERFORMANCE |
|
59 |
- for c = 1:nClasses |
|
60 |
- plot(frameStart:frameEnd, decodePerformance() ,[colorChooser(c+2) '-']); |
|
61 |
- end |
|
62 |
- end |
|
63 |
- |
|
64 |
- |
|
65 |
- hold off; |
|
66 |
- |
|
67 |
- title = sprintf('Subject %s, over %g voxel, smoothed %s',SubjectID,size(psth,2),smoothed); |
|
68 |
- set(f,'Name',title); |
|
69 |
- display(sprintf('%s',title)); |
|
70 |
- |
|
71 |
- |
|
72 |
- |
|
73 |
-end |
|
74 |
- |
|
75 |
- |
|
76 |
-function color = colorChooser(n) |
|
77 |
- switch (mod(n,8)) |
|
78 |
- case 0 |
|
79 |
- color = 'y'; |
|
80 |
- case 1 |
|
81 |
- color = 'r'; |
|
82 |
- case 2 |
|
83 |
- color = 'b'; |
|
84 |
- case 3 |
|
85 |
- color = 'g'; |
|
86 |
- otherwise |
|
87 |
- color = 'k'; |
|
88 |
- end |
|
89 |
-end |
|
90 |
- |
|
91 |
-function style = lineStyleChooser(n) |
|
92 |
-switch(mod(n,4)) |
|
93 |
- case 0 |
|
94 |
- style = '--'; |
|
95 |
- case 1 |
|
96 |
- style = '-'; |
|
97 |
- case 2 |
|
98 |
- style = ':'; |
|
99 |
- case 3 |
|
100 |
- style = '-.'; |
|
101 |
-end |
|
102 |
-end |
|
103 |
- |
... | ... |
@@ -3,9 +3,10 @@ function outputStruct = calculateDecodePerformance(inputStruct,SubjectID) |
3 | 3 |
|
4 | 4 |
addpath 'libsvm-mat-2.88-1'; |
5 | 5 |
|
6 |
-METHOD = 'single subject SVM'; |
|
7 |
-% METHOD = 'cross subject SVM'; |
|
8 |
-% METHOD = 'SOM'; |
|
6 |
+METHOD_DEF = inputStruct.CROSSVAL_METHOD_DEF; |
|
7 |
+METHOD = inputStruct.CROSSVAL_METHOD; |
|
8 |
+ |
|
9 |
+RANDOMIZE_DATAPOINTS = inputStruct.RANDOMIZE; |
|
9 | 10 |
|
10 | 11 |
outputStruct = struct; |
11 | 12 |
|
... | ... |
@@ -49,7 +50,7 @@ maxPerformance = -inf; |
49 | 50 |
|
50 | 51 |
decodePerformance = []; |
51 | 52 |
for index = 1:timeLineEnd-timeLineStart+1 |
52 |
- RANDOMIZE_DATAPOINTS = 0; |
|
53 |
+ |
|
53 | 54 |
svmdata = timePointMatrix{index}(:,2:size(timePointMatrix{index},2)); |
54 | 55 |
svmlabel = timePointMatrix{index}(:,1); |
55 | 56 |
|
... | ... |
@@ -59,16 +60,18 @@ maxPerformance = -inf; |
59 | 60 |
svmlabel = svmlabel(rndindex); |
60 | 61 |
end |
61 | 62 |
|
62 |
- SVM_METHOD = 'som training' |
|
63 |
- switch SVM_METHOD; |
|
64 |
- case 'libsvm crossval' |
|
63 |
+ switch METHOD; |
|
64 |
+ case CROSSVAL_METHOD_DEF.svmcrossval |
|
65 |
+ |
|
65 | 66 |
performance = svmtrain(svmlabel, svmdata, svmargs); |
66 | 67 |
|
67 | 68 |
minPerformance = min(minPerformance,performance); |
68 | 69 |
maxPerformance = max(maxPerformance,performance); |
69 | 70 |
|
70 | 71 |
decodePerformance = [decodePerformance; performance]; |
71 |
- case 'class performance' |
|
72 |
+ |
|
73 |
+ case CROSSVAL_METHOD_DEF.classPerformance |
|
74 |
+ |
|
72 | 75 |
newsvmopt = killCrossvalOpt(svmargs); |
73 | 76 |
|
74 | 77 |
model = svmtrain(svmlabel,svmdata,newsvmopt); |
... | ... |
@@ -84,7 +87,8 @@ maxPerformance = -inf; |
84 | 87 |
end |
85 | 88 |
decodePerformance = [decodePerformance; classperformance]; |
86 | 89 |
|
87 |
- case 'som training' |
|
90 |
+ case CROSSVAL_METHOD_DEF.somTraining |
|
91 |
+ |
|
88 | 92 |
display('SOM TRAINING'); |
89 | 93 |
addpath 'somtoolbox2'; |
90 | 94 |
sD = som_data_struct(svmdata,'label',num2str(svmlabel)); |
... | ... |
@@ -93,7 +97,7 @@ maxPerformance = -inf; |
93 | 97 |
|
94 | 98 |
assignin('base','sD',sD); |
95 | 99 |
assignin('base','sM',sM); |
96 |
- |
|
100 |
+ display('type ''figure'' before visualisation'); |
|
97 | 101 |
end |
98 | 102 |
|
99 | 103 |
end |
... | ... |
@@ -1,5 +1,10 @@ |
1 | 1 |
function classify(varargin) |
2 | 2 |
|
3 |
+CROSSVAL_METHOD_DEF.svmcrossval = 'svm crossval'; |
|
4 |
+CROSSVAL_METHOD_DEF.classPerformance = 'svm class performance'; |
|
5 |
+CROSSVAL_METHOD_DEF.crossSubject = 'svm cross subject testing'; |
|
6 |
+CROSSVAL_METHOD_DEF.somTraining = 'som Training'; |
|
7 |
+ |
|
3 | 8 |
switch nargin |
4 | 9 |
case 1 |
5 | 10 |
paramModel = varargin{1}; |
... | ... |
@@ -12,6 +17,12 @@ end |
12 | 17 |
|
13 | 18 |
% common params |
14 | 19 |
calculateParams = struct; |
20 |
+ |
|
21 |
+ calculateParams.CROSSVAL_METHOD_DEF = CROSSVAL_METHOD_DEF; |
|
22 |
+ calculateParams.CROSSVAL_METHOD = CROSSVAL_METHOD_DEF.svmcrossval; |
|
23 |
+ |
|
24 |
+ calculateParams.RANDOMIZE = 0; |
|
25 |
+ |
|
15 | 26 |
calculateParams.smoothed = getChkValue(paramModel.chkSmoothed); |
16 | 27 |
|
17 | 28 |
calculateParams.frameShiftStart = getDouble(paramModel.txtFrameShiftStart); % -20; |
... | ... |
@@ -46,7 +57,7 @@ end |
46 | 57 |
spm = load(fullfile(PROJECT_BASE_PATH,SubjectID,PROJECT_RESULT_PATH)); |
47 | 58 |
display('... done.'); |
48 | 59 |
|
49 |
- %% calculate |
|
60 |
+ % calculate |
|
50 | 61 |
calculateParams.(namehelper).des = spm.SPM; |
51 | 62 |
calculateParams.(namehelper).voxelList = parseVoxelList(paramModel,SubjectID); |
52 | 63 |
assignin('base','calculateParams',calculateParams); |
53 | 64 |