Browse code

multi-subject support

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@96 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on05/01/2009 18:25:16
Showing8 changed files
... ...
@@ -1,17 +1,17 @@
1 1
 % function [decodePerformance rawTimecourse ] = calculateDecodePerformance(des,timeLineStart, timeLineEnd, decodeDuration, svmargs, conditionList, sessionList, voxelList, classList, labelMap,normalize)
2
-function outputStruct = calculateDecodePerformance(inputStruct)
2
+function outputStruct = calculateDecodePerformance(inputStruct,SubjectID)
3 3
 
4 4
 addpath 'libsvm-mat-2.88-1';
5 5
 
6 6
 outputStruct = struct;
7 7
 
8
-des             = inputStruct.des;
8
+des             = inputStruct.(SubjectID).des;
9 9
 timeLineStart   = inputStruct.frameShiftStart;
10 10
 timeLineEnd     = inputStruct.frameShiftEnd;
11 11
 decodeDuration  = inputStruct.decodeDuration;
12 12
 svmargs         = inputStruct.svmargs;
13 13
 sessionList     = inputStruct.sessionList;
14
-voxelList       = inputStruct.voxelList;
14
+voxelList       = inputStruct.(SubjectID).voxelList;
15 15
 % classList       = inputStruct.classList;
16 16
 % labelMap        = inputStruct.labelMap;
17 17
 smoothed       = inputStruct.smoothed;
... ...
@@ -1,13 +1,12 @@
1 1
 function classify(varargin)
2 2
 
3
+PROJECT_BASE_PATH = 'D:\Analyze\Choice\24pilot';
4
+PROJECT_RESULT_PATH = 'results\SPM.mat';
5
+
3 6
 switch nargin
4
-    case 0
5
-        action = 'decode';
6
-        SubjectID = 'JZ006';
7 7
     case 1
8 8
         action = 'decode';
9 9
         paramModel = varargin{1};
10
-        SubjectID = getSubjectIDString(paramModel);
11 10
     otherwise
12 11
         error('spmtoolbox:SVMCrossVal:arginError','Please Specify action and parameter model');
13 12
 end
... ...
@@ -20,26 +19,11 @@ end
20 19
       
21 20
     case 'decode'
22 21
         
23
-        display('loading SPM.mat');
24
-%         SubjectID = 'JZ006';
25
-%         SubjectID = 'AI020';
26
-%         SubjectID = 'HG027';
27
-        spm = load(fullfile('D:\Analyze\Choice\24pilot',SubjectID,'results\SPM.mat'));
28
-
29
-        display('done.');
30 22
         
31
-        params = struct;
32
-        params.nClasses = 2;
33
-
34
-%         assignin('base','params',params);
35
-        %% calculate
36
-        display('calculating cross-validation performance time-shift');
23
+        % common params
37 24
         calculateParams  = struct;
38
-        
39
-        calculateParams.des             = spm.SPM;
40
-        
41 25
         calculateParams.smoothed        = getDouble(paramModel.txtSmoothed);
42
-        
26
+
43 27
         calculateParams.frameShiftStart = getDouble(paramModel.txtFrameShiftStart);  % -20;
44 28
         calculateParams.frameShiftEnd   = getDouble(paramModel.txtFrameShiftEnd); %15;
45 29
         calculateParams.decodeDuration  = getDouble(paramModel.txtFrameShiftDur);
... ...
@@ -48,47 +32,85 @@ end
48 32
         calculateParams.baselineStart   = getDouble(paramModel.txtBaselineStart); % -22;
49 33
         calculateParams.baselineEnd     = getDouble(paramModel.txtBaselineEnd); % -20;
50 34
 
51
-        calculateParams.voxelList       = parseVoxelList(paramModel);
52
-
53 35
         calculateParams.svmargs         = get(paramModel.txtSVMopts,'String');
54 36
         calculateParams.sessionList     = 1:3;
55 37
 
56 38
         classStruct = parseClassDef(paramModel);
57
-%         assignin('base','classStruct',classStruct);
58 39
         
59 40
         calculateParams.classList       = classStruct.label; %{'<','>'};
60 41
         calculateParams.labelMap        = LabelMap(classStruct.label , classStruct.value); % LabelMap({'<','>','<+<','>+>','<+>','>+<'},{-2,-1,1,2,3,4});
61 42
         calculateParams.eventList       = classStruct.event; %[9,11,13; 10,12,14];
43
+
44
+        params = struct;
45
+        params.nClasses = 2;
62 46
         
63 47
         
64
-        assignin('base','calculateParams',calculateParams);
65
-        
66
-%         [decodeTable rawTimeCourse] = calculateDecodePerformance(spm,params.frameShiftStart,params.frameShiftEnd,params.xTimeWindow,params.svmopts,1:4,params.sessionList,params.voxelList,params.classList,params.labelMap,params.normalize);
67
-        decode = calculateDecodePerformance(calculateParams);
68
-        display(sprintf('Min CrossVal Accuracy: %g%% \t Max CrossVal Accuracy: %g%%',decode.minPerformance,decode.maxPerformance));
48
+        subjectSelection = getSubjectIDString(paramModel);
49
+        decode = struct;
50
+        decode.decodePerformance = [];
51
+        decode.rawTimeCourse     = [];
69 52
         
70
-        assignin('base','decode',decode);
53
+        for subjectCell = subjectSelection
54
+            SubjectID = cell2mat(subjectCell);
55
+
56
+            display('loading SPM.mat');
57
+            spm = load(fullfile(PROJECT_BASE_PATH,SubjectID,PROJECT_RESULT_PATH));
58
+%             display('done.');
59
+
60
+            %% calculate
61
+            display(sprintf('calculating cross-validation performance time-shift for Subject %s',SubjectID));
62
+
63
+            calculateParams.(SubjectID).des             = spm.SPM;
64
+            calculateParams.(SubjectID).voxelList       = parseVoxelList(paramModel,SubjectID);
65
+            
66
+            assignin('base','calculateParams',calculateParams);
67
+            
68
+    %         [decodeTable rawTimeCourse] = calculateDecodePerformance(spm,params.frameShiftStart,params.frameShiftEnd,params.xTimeWindow,params.svmopts,1:4,params.sessionList,params.voxelList,params.classList,params.labelMap,params.normalize);
69
+            display('switching off all warnings');
70
+            warning_state = warning('off','all');
71
+            
72
+            decode.(SubjectID) = calculateDecodePerformance(calculateParams,SubjectID);
73
+            display('restoring warnings');
74
+            warning(warning_state);
75
+            
76
+            decode.decodePerformance = [decode.decodePerformance decode.(SubjectID).decodePerformance];
77
+            decode.rawTimeCourse = [decode.rawTimeCourse decode.(SubjectID).rawTimeCourse];
78
+  
79
+            
80
+%             display(sprintf('Min CrossVal Accuracy: %g%% \t Max CrossVal Accuracy: %g%%',decode.minPerformance,decode.maxPerformance));
81
+
82
+            assignin('base','decode',decode);
83
+        end
71 84
 
72 85
         display('Finished calculations.');
73
-        display('Plotting.');
86
+        display('Plotting...');
74 87
 
75 88
         plotParams = struct;
76 89
         plotParams.psthStart = calculateParams.psthStart;
77 90
         plotParams.psthEnd   = calculateParams.psthEnd;
78 91
         plotParams.nClasses  = length(calculateParams.classList);
92
+        
79 93
         plotParams.frameShiftStart   = calculateParams.frameShiftStart;
80 94
         plotParams.frameShiftEnd     = calculateParams.frameShiftEnd;
81 95
         plotParams.decodePerformance = decode.decodePerformance;
82 96
         plotParams.rawTimeCourse     = decode.rawTimeCourse;
83
-        plotParams.SubjectID         = SubjectID;
84
-        plotParams.smoothed          = calculateParams.smoothed;
85 97
         
98
+        if numel(subjectSelection) == 1
99
+          plotParams.SubjectID         = SubjectID;
100
+        else
101
+          plotParams.SubjectID         = 'Multiple';
102
+        end
103
+
104
+        plotParams.smoothed          = boolToYesNoString(calculateParams.smoothed);
105
+         
106
+
86 107
         assignin('base','plotParams',plotParams);
87 108
 %         plotDecodePerformance(params.psthStart,params.psthEnd,params.nClasses,decode.decodeTable,params.frameShiftStart,params.frameShiftEnd,decode.rawTimeCourse);
88 109
         plotDecodePerformance(plotParams);
110
+            
111
+        display('done.');
89 112
 
90 113
     otherwise
91
-        display('give action command: clear decode');
114
+        display('give action command: clear, decode');
92 115
     end
93
-    
94
-end
95 116
\ No newline at end of file
117
+    end
96 118
\ No newline at end of file
... ...
@@ -1,32 +1,26 @@
1 1
 function plotDecodePerformance(varargin)
2 2
 % plotDecodePerformance(timeline,decodePerformance,nClasses,rawData)
3 3
 
4
-if(nargin==1)
5
-    inputStruct       = cell2mat(varargin(1));
6
-    
7
-    psthStart         = inputStruct.psthStart;
8
-    psthEnd           = inputStruct.psthEnd;
9
-    nClasses          = inputStruct.nClasses;
10
-    decodePerformance = inputStruct.decodePerformance;
11
-    frameStart        = inputStruct.frameShiftStart;
12
-    frameEnd          = inputStruct.frameShiftEnd;
13
-    psth              = inputStruct.rawTimeCourse;
14
-    SubjectID         = inputStruct.SubjectID;
15
-    smoothed          = inputStruct.smoothed;
16
-    
17
-    
18
-elseif( nargin == 7)
4
+PSTH_AXIS_MIN = -2;
5
+PSTH_AXIS_MAX = 5;
6
+
7
+switch nargin
19 8
     
20
-    psthStart   = cell2mat(varargin(1));
21
-    psthEnd     = cell2mat(varargin(2));
22
-    nClasses    = cell2mat(varargin(3));
23
-    decodePerformance = cell2mat(varargin(4));
24
-    frameStart  = cell2mat(varargin(5));
25
-    frameEnd    = cell2mat(varargin(6));
26
-    psth        = varargin(7);
27
-    psth        = psth{1};
28
-    SubjectID   = '';
29
-    smoothed    = '';
9
+    case 1
10
+        inputStruct       = cell2mat(varargin(1));
11
+
12
+        psthStart         = inputStruct.psthStart;
13
+        psthEnd           = inputStruct.psthEnd;
14
+        nClasses          = inputStruct.nClasses;
15
+        decodePerformance = inputStruct.decodePerformance;
16
+        frameStart        = inputStruct.frameShiftStart;
17
+        frameEnd          = inputStruct.frameShiftEnd;
18
+        psth              = inputStruct.rawTimeCourse;
19
+        SubjectID         = inputStruct.SubjectID;
20
+        smoothed          = inputStruct.smoothed;
21
+
22
+    otherwise
23
+        error('spmtoolbox:SVMCrossVal:plotDecodePerformance:WrongArgument','Wrong Arguments');
30 24
 end
31 25
 
32 26
     f = figure;
... ...
@@ -41,7 +35,7 @@ end
41 35
               plot(psthStart:psthEnd,psthData,[colorChooser(voxel), lineStyleChooser(label)]);
42 36
           end
43 37
       end
44
-    axis([psthStart psthEnd -2 5])
38
+    axis([psthStart psthEnd PSTH_AXIS_MIN PSTH_AXIS_MAX])
45 39
     hold off
46 40
     
47 41
     subplot(2,1,2)    
... ...
@@ -53,11 +47,15 @@ end
53 47
     plot([psthStart psthEnd],[goodPredictionLevel goodPredictionLevel],'g');
54 48
     axis([psthStart psthEnd 0 100])
55 49
     
56
-    plot(frameStart:frameEnd, decodePerformance ,'b');
57
-
50
+    plot(frameStart:frameEnd, mean(decodePerformance,2) ,'b');
51
+    se = myStdErr(decodePerformance,2);
52
+    plot(frameStart:frameEnd, mean(decodePerformance,2)+se ,'b:');
53
+    plot(frameStart:frameEnd, mean(decodePerformance,2)-se ,'b:');
54
+    
55
+    
58 56
     hold off;
59 57
 
60
-    title = sprintf('Subject %s, over %g voxel, smoothed %g',SubjectID,size(psth,2),smoothed);
58
+    title = sprintf('Subject %s, over %g voxel, smoothed %s',SubjectID,size(psth,2),smoothed);
61 59
     set(f,'Name',title);
62 60
     display(sprintf('%s',title));
63 61
 
... ...
@@ -65,6 +63,7 @@ end
65 63
 
66 64
 end
67 65
 
66
+
68 67
 function color = colorChooser(n)
69 68
     switch (mod(n,8))
70 69
     case 0
71 70
new file mode 100644
... ...
@@ -0,0 +1,7 @@
1
+function s = boolToYesNoString(bool)
2
+    if bool
3
+        s = 'yes';
4
+    else
5
+        s = 'no';
6
+    end
7
+end
0 8
\ No newline at end of file
... ...
@@ -1,9 +1,12 @@
1
-function s = getSubjectIDString(model)
1
+function s = getSubjectIDString(model) % TODO rename to getSubjectIDStringCellArray
2 2
     tmp_sidx =  get(model.subjectSelector,'Value');
3 3
     tmp_cellList = getSubjectCellList(model.subjectMap);
4
-    s = cell2mat(tmp_cellList(tmp_sidx));
5
-    if ~ischar(s)
6
-        error('spmtoolbox:SVMCrossVal:getSubjectID:NoString','convert error');
4
+
5
+    if size(tmp_sidx,2) >= 1
6
+        s = tmp_cellList(tmp_sidx);
7
+    else
8
+        error('spmtoolbox:SVMCrossVal:getSubjectID:NoSelection','no subject selected');
7 9
     end
10
+
8 11
 end
9 12
  
10 13
\ No newline at end of file
11 14
new file mode 100644
... ...
@@ -0,0 +1,3 @@
1
+function se =myStdErr(args,dim)
2
+    se = std(args,0,dim)/sqrt(size(args,dim));
3
+end
0 4
\ No newline at end of file
... ...
@@ -1,13 +1,19 @@
1
-function voxelList = parseVoxelList(model)
1
+function voxelList = parseVoxelList(model,multisubjectid)
2 2
         voxelList = [];
3 3
 
4 4
         %<ROI Name>,<ROI Modifier>;
5 5
         txt = get(model.txtVoxelDef,'String');
6 6
         map = model.subjectMap;
7
-        SubjectID = getSubjectIDString(model);
8
-        
9
-%         assignin('base','txt',txt);
10
-        
7
+
8
+        switch nargin
9
+            case 1
10
+                SubjectID = getSubjectIDString(model);
11
+            case 2
12
+                SubjectID = multisubjectid;
13
+            otherwise
14
+                error('spmtoolbox:SVMCrossVal:parseVoxelList:nargin','wrong number of arguments given');
15
+        end
16
+         
11 17
         rows  = size(txt,1);
12 18
         
13 19
         for i = 1:rows 
... ...
@@ -19,26 +25,6 @@ function voxelList = parseVoxelList(model)
19 25
             roimod = parseModifier(line);
20 26
             voxelList = [voxelList; getCoordinate(map,SubjectID,roi)+eval(roimod)];
21 27
         end
22
-
23
-%         voxelList  = [...
24
-%                       getCoordinate(map,SubjectID,'SPL l')+[0,0,0];...
25
-%                           getCoordinate(map,SubjectID,'SPL l')+[1,0,0];...
26
-%                           getCoordinate(map,SubjectID,'SPL l')+[-1,0,0];...
27
-%                           getCoordinate(map,SubjectID,'SPL l')+[0,1,0];...
28
-%                           getCoordinate(map,SubjectID,'SPL l')+[0,-1,0];...
29
-%                           getCoordinate(map,SubjectID,'SPL l')+[0,0,1];...
30
-%                           getCoordinate(map,SubjectID,'SPL l')+[0,0,-1];...
31
-%                       getCoordinate(map,SubjectID,'SPL r')+[0,0,0];...
32
-%                           getCoordinate(map,SubjectID,'SPL r')+[1,0,0];...
33
-%                           getCoordinate(map,SubjectID,'SPL r')+[-1,0,0];...
34
-%                           getCoordinate(map,SubjectID,'SPL r')+[0,1,0];...
35
-%                           getCoordinate(map,SubjectID,'SPL r')+[0,-1,0];...
36
-%                           getCoordinate(map,SubjectID,'SPL r')+[0,0,1];...
37
-%                           getCoordinate(map,SubjectID,'SPL r')+[0,0,-1];...
38
-%                       getCoordinate(map,SubjectID,'M1 r')+[0,0,0];...
39
-%                       getCoordinate(map,SubjectID,'M1 l')+[0,0,0];...
40
-%                       ];
41
-
42 28
 end
43 29
 
44 30
 function roi = parseROIName(line)
... ...
@@ -19,10 +19,10 @@ function spm_SVMCrossVal
19 19
     nElementRows = 24;
20 20
     optionLineHeight = 1.0/nElementRows;
21 21
     controlElementHeight=optionLineHeight*(1.0/1.5)*frameHeight;
22
-    pSubject     = uipanel(frame,'Title','Subject',          'Position',[0 optionLineHeight*22 frameWidth optionLineHeight*2]);
23
-    pPSTH        = uipanel(frame,'Title','PSTH Options',     'Position',[0 optionLineHeight*17 frameWidth optionLineHeight*5]); 
24
-    pCLASS       = uipanel(frame,'Title','Class Definitions','Position',[0 optionLineHeight*11 frameWidth optionLineHeight*6]); 
25
-    pVOXEL       = uipanel(frame,'Title','Voxel Selector',   'Position',[0 optionLineHeight*3  frameWidth optionLineHeight*8]); 
22
+    pSubject     = uipanel(frame,'Title','Subject',          'Position',[0 optionLineHeight*19 frameWidth optionLineHeight*5]);
23
+    pPSTH        = uipanel(frame,'Title','PSTH Options',     'Position',[0 optionLineHeight*14 frameWidth optionLineHeight*5]); 
24
+    pCLASS       = uipanel(frame,'Title','Class Definitions','Position',[0 optionLineHeight*9  frameWidth optionLineHeight*5]); 
25
+    pVOXEL       = uipanel(frame,'Title','Voxel Selector',   'Position',[0 optionLineHeight*3  frameWidth optionLineHeight*6]); 
26 26
     pSVM         = uipanel(frame,'Title','SVM Options',      'Position',[0 optionLineHeight*1  frameWidth optionLineHeight*2]); 
27 27
     btnRunButton = uicontrol(frame,'Tag','run','String','run decode-performance visualiser','Position',[2 optionLineHeight*0  frameWidth controlElementHeight*1.6]);
28 28
 
... ...
@@ -31,10 +31,11 @@ function spm_SVMCrossVal
31 31
     firstColumn  =  5.00;
32 32
     firstRow     =  1.00 * controlElementHeight;
33 33
     
34
-    model.subjectSelector = uicontrol(pSubject,'Style','popupmenu',...
34
+    model.subjectSelector = uicontrol(pSubject,'Style','listbox',...
35
+                    'Min',1, 'Max',3,...
35 36
                     'String',getSubjectCellList(model.subjectMap),...
36
-                    'Value',5,...
37
-                    'Position',[firstColumn firstRow 0.66*frameWidth controlElementHeight]);
37
+                    'Value',5,...  % default selected item
38
+                    'Position',[firstColumn firstRow 0.66*frameWidth controlElementHeight*6]);
38 39
     set(model.subjectSelector,'BackgroundColor','w');
39 40
     
40 41
     model.txtSmoothed = createTextField(pSubject,[0.68*frameWidth firstRow  0.25*frameWidth controlElementHeight],'0');
... ...
@@ -158,6 +159,13 @@ function cbRunSVM(src,evnt,model)
158 159
 end
159 160
 
160 161
 
162
+function save(model)
163
+
164
+end
165
+
166
+function model = load()
167
+end
168
+
161 169
 
162 170
 
163 171