Browse code

added 2nd classification method: no crossval, but class performance checks. may not plot correctly with multisubject

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@111 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on 19/01/2009 18:28:59
Showing 5 changed files
1 1
new file mode 100644
... ...
@@ -0,0 +1,36 @@
1
+function timePointMatrix = buildTimePointMatrix(argStruct)
2
+
3
+pst             = argStruct.pst;
4
+
5
+timeLineStart   = argStruct.timeLineStart;
6
+timeLineEnd     = argStruct.timeLineEnd;
7
+globalStart     = argStruct.globalStart;
8
+globalEnd       = argStruct.globalEnd;
9
+decodeDuration  = argStruct.decodeDuration;
10
+eventList       = argStruct.eventList;
11
+
12
+labelMap        = argStruct.labelMap;
13
+
14
+timePointMatrix = {};
15
+%%% build timepoint Matrix
16
+for timeShift   = timeLineStart:1:timeLineEnd
17
+    % center timepoint && relative shift
18
+    frameStart  = floor(-globalStart+1+timeShift - 0.5*decodeDuration);
19
+    frameEnd    = min(ceil(frameStart+decodeDuration + 0.5*decodeDuration),-globalStart+globalEnd);
20
+
21
+    %build svm inputmatrix
22
+    index = timeShift-timeLineStart+1; %Bad 1-indexing :-(
23
+    timePointMatrix{index} =[];
24
+    anyvoxel = 1;
25
+    for pstConditionGroup = 1:size(pst{1,anyvoxel},2)
26
+        for dp = 1:size(pst{1,anyvoxel}{1,pstConditionGroup},1) % data point
27
+            row = getSVMLabel(labelMap,eventList(pstConditionGroup,1));
28
+            for voxel = 1:size(pst,2)
29
+                row = [row, pst{1,voxel}{1,pstConditionGroup}(dp,frameStart:frameEnd)]; % label,[value,value,...],[value,value,...]...
30
+            end
31
+            timePointMatrix{index}  = [timePointMatrix{index}; row];
32
+        end
33
+    end
34
+end
35
+
36
+end
0 37
\ No newline at end of file
... ...
@@ -3,7 +3,9 @@ function outputStruct = calculateDecodePerformance(inputStruct,SubjectID)
3 3
 
4 4
 addpath 'libsvm-mat-2.88-1';
5 5
 
6
-SINGLE = 1;
6
+METHOD = 'single subject SVM';
7
+% METHOD = 'cross subject SVM';
8
+% METHOD = 'SOM';
7 9
 
8 10
 outputStruct = struct;
9 11
 
... ...
@@ -17,70 +19,73 @@ sessionList     = inputStruct.sessionList;
17 19
 voxelList       = inputStruct.(namehelper).voxelList;
18 20
 % classList       = inputStruct.classList;
19 21
 % labelMap        = inputStruct.labelMap;
20
-smoothed       = inputStruct.smoothed;
22
+smoothed        = inputStruct.smoothed;
21 23
 globalStart     = inputStruct.psthStart;
22 24
 globalEnd       = inputStruct.psthEnd;
23 25
 baselineStart   = inputStruct.baselineStart;
24 26
 baselineEnd     = inputStruct.baselineEnd;
25 27
 eventList       = inputStruct.eventList;
26
-labelMap        = inputStruct.labelMap;
27
-
28 28
 
29 29
 minPerformance = inf;
30 30
 maxPerformance = -inf;
31
-
32
-
33
-        
34
-        %Pro Voxel PSTH TIMELINE berechnen.
35
-        %   timeshift mit pst-timeline durchf�hren.
36
-        % psth-timeline -25 bis +15 zu RES Onset.
37
-        
38
-%         eventList       = [9,11,13;10,12,14];
39
-%         globalStart     = -25;
40
-%         globalEnd       = 15;
41
-%         baselineStart   = -22;
42
-%         baselineEnd     = -20;
43
-        
44 31
         
32
+        %% ERSETZEN DURCH ROI-IMAGE!
45 33
         for voxel = 1:size(voxelList,1)  % [[x;x],[y;y],[z;z]]
46
-                extr  = calculateImageData(voxelList(voxel,:),des,smoothed);
47
-                rawdata=cell2mat({extr.mean}); % Raw Data
34
+                extr        = calculateImageData(voxelList(voxel,:),des,smoothed); 
35
+                rawdata     = cell2mat({extr.mean}); % Raw Data
48 36
                 pst{voxel}  = calculatePST(des,globalStart,baselineStart,baselineEnd,globalEnd,eventList,rawdata,sessionList);
49 37
         end
50 38
         
51
-        decodePerformance = [];
52
-
53
-        for timeShift   = timeLineStart:1:timeLineEnd
54
-            frameStart  = floor(-globalStart+1+timeShift - 0.5*decodeDuration);
55
-            frameEnd    = min(ceil(frameStart+decodeDuration + 0.5*decodeDuration),-globalStart+globalEnd);
56
-            
57
-            tmp =[];
58
-            anyvoxel = 1;
59
-            for pstConditionGroup = 1:size(pst{1,anyvoxel},2) 
60
-                for dp = 1:size(pst{1,anyvoxel}{1,pstConditionGroup},1) % data point
61
-                  row = getSVMLabel(labelMap,eventList(pstConditionGroup,1));
62
-                    for voxel = 1:size(pst,2)
63
-                        row = [row, pst{1,voxel}{1,pstConditionGroup}(dp,frameStart:frameEnd)]; % label,value,value
64
-                    end
65
-                tmp  = [tmp; row];
66
-                end
67
-            end 
39
+        timePointArgs.pst = pst;
40
+        timePointArgs.timeLineStart = timeLineStart;
41
+        timePointArgs.timeLineEnd   = timeLineEnd;
42
+        timePointArgs.globalStart   = globalStart;
43
+        timePointArgs.globalEnd     = globalEnd;
44
+        timePointArgs.decodeDuration= decodeDuration;
45
+        timePointArgs.labelMap      = inputStruct.labelMap;
46
+        timePointArgs.eventList     = eventList;
68 47
         
69
-            svmdata      = tmp(:,2:size(tmp,2));
70
-            svmlabel     = tmp(:,1);
71
-            
72
-%             RANDOMIZE INPUT
73
-%             rndindex  = randperm(length(svmlabel));
74
-%             svmdata   = svmdata(rndindex,:);
75
-%             svmlabel  = svmlabel(rndindex);
48
+        timePointMatrix = buildTimePointMatrix(timePointArgs);
49
+
50
+        decodePerformance = [];
51
+        for index = 1:timeLineEnd-timeLineStart+1
52
+            RANDOMIZE_DATAPOINTS = 0;
53
+            svmdata      = timePointMatrix{index}(:,2:size(timePointMatrix{index},2));
54
+            svmlabel     = timePointMatrix{index}(:,1);
76 55
             
77
-            if SINGLE
78
-                performance  = svmtrain(svmlabel, svmdata, svmargs);
56
+            if RANDOMIZE_DATAPOINTS
57
+                rndindex  = randperm(length(svmlabel));
58
+                svmdata   = svmdata(rndindex,:);
59
+                svmlabel  = svmlabel(rndindex);
60
+            end
61
+
62
+            SVM_METHOD = 2;
63
+            switch SVM_METHOD;
64
+                case 1
65
+                    performance  = svmtrain(svmlabel, svmdata, svmargs);
79 66
 
80
-                minPerformance = min(minPerformance,performance);
81
-                maxPerformance = max(maxPerformance,performance);
67
+                    minPerformance = min(minPerformance,performance);
68
+                    maxPerformance = max(maxPerformance,performance);
82 69
 
83
-                decodePerformance = [decodePerformance; performance];
70
+                    decodePerformance = [decodePerformance; performance];
71
+                case 2
72
+                    newsvmopt = killCrossvalOpt(svmargs);
73
+                    
74
+                    model = svmtrain(svmlabel,svmdata,newsvmopt);
75
+                    classperformance = [];
76
+                    for class = unique(svmlabel)';
77
+%                         assignin('base','uniquelabel',unique(svmlabel));
78
+%                         assignin('base','class',class);
79
+%                         assignin('base','svmlabel',svmlabel);
80
+                        filterindex = find(class == svmlabel);
81
+                        testing_label = svmlabel(filterindex)
82
+                        testing_data  = svmdata(filterindex)
83
+                        [plabel accuracy dvalue] = svmpredict(testing_label,testing_data,model,'')
84
+%                         assignin('base','accuracy',accuracy);
85
+                        classperformance = [classperformance accuracy(1)];
86
+                    end
87
+                    decodePerformance = [decodePerformance; classperformance];
88
+                    
84 89
             end
85 90
             
86 91
         end
... ...
@@ -93,3 +98,17 @@ maxPerformance = -inf;
93 98
         outputStruct.maxPerformance     = maxPerformance;
94 99
 end
95 100
 
101
+function opts = killCrossvalOpt(svmopt)
102
+opts = '';
103
+idx1 = 1;
104
+for idx2=strfind(svmopt,' -')
105
+    if idx1 ~= strfind(svmopt,' -v')
106
+        opts = strcat(opts,svmopt(idx1:idx2));
107
+    end
108
+    idx1=idx2;
109
+    if idx2==max(strfind(svmopt,' -'))
110
+        opts = strcat(opts,svmopt(idx2:end));
111
+    end
112
+end
113
+end
114
+
... ...
@@ -1,7 +1,5 @@
1 1
 function classify(varargin)
2 2
 
3
-
4
-
5 3
 switch nargin
6 4
     case 1
7 5
         paramModel = varargin{1};
... ...
@@ -11,11 +9,10 @@ switch nargin
11 9
     otherwise
12 10
         error('spmtoolbox:SVMCrossVal:arginError','Please Specify action and parameter model');
13 11
 end
14
-
15
-        
12
+       
16 13
         % common params
17 14
         calculateParams  = struct;
18
-        calculateParams.smoothed        = getDouble(paramModel.txtSmoothed);
15
+        calculateParams.smoothed        = getChkValue(paramModel.chkSmoothed);
19 16
 
20 17
         calculateParams.frameShiftStart = getDouble(paramModel.txtFrameShiftStart);  % -20;
21 18
         calculateParams.frameShiftEnd   = getDouble(paramModel.txtFrameShiftEnd); %15;
... ...
@@ -59,6 +56,7 @@ end
59 56
             warning_state               = warning('off','all');
60 57
             display('calculating ...');
61 58
             decode.(namehelper)         = calculateDecodePerformance(calculateParams,SubjectID);
59
+            
62 60
             display('... done');
63 61
             display('restoring warnings');
64 62
             warning(warning_state);
... ...
@@ -89,12 +87,10 @@ end
89 87
         end
90 88
 
91 89
         plotParams.smoothed          = boolToYesNoString(calculateParams.smoothed);
92
-         
93 90
 
94 91
         assignin('base','plotParams',plotParams);
95 92
 %         plotDecodePerformance(params.psthStart,params.psthEnd,params.nClasses,decode.decodeTable,params.frameShiftStart,params.frameShiftEnd,decode.rawTimeCourse);
96 93
         plotDecodePerformance(plotParams);
97 94
             
98 95
         display('all done.');
99
-
100 96
     end
101 97
\ No newline at end of file
102 98
new file mode 100644
... ...
@@ -0,0 +1,3 @@
1
+function d = getChkValue(fieldhandle)
2
+    d = get(fieldhandle,'Value');
3
+end
0 4
\ No newline at end of file
... ...
@@ -61,7 +61,7 @@ DEFAULT.svmoptstring    = '-s 0 -t 0 -v 6 -c 1';
61 61
     set(model.subjectSelector,'BackgroundColor','w');
62 62
     
63 63
     createLabel(pSubject,[0.68*frameWidth firstRow*2  0.25*frameWidth controlElementHeight],'Smooth Data');
64
-    model.txtSmoothed     = uicontrol(pSubject,'Style','checkbox','Position',[0.68*frameWidth firstRow  0.25*frameWidth controlElementHeight],'Value',DEFAULT.smoothed);
64
+    model.chkSmoothed     = uicontrol(pSubject,'Style','checkbox','Position',[0.68*frameWidth firstRow  0.25*frameWidth controlElementHeight],'Value',DEFAULT.smoothed);
65 65
 
66 66
     createLabel(pSubject,[0.68*frameWidth firstRow*4  0.25*frameWidth controlElementHeight],'Crossvalidation');
67 67
     model.txtMultisubject = createTextField(pSubject,[0.68*frameWidth firstRow*3  0.25*frameWidth controlElementHeight],DEFAULT.multisubject);