added 2nd classification method: no crossval, but class performance checks. may not plot correctly with multisubject
Christoph Budziszewski

Christoph Budziszewski commited on 2009-01-19 18:28:59
Zeige 5 geänderte Dateien mit 100 Einfügungen und 43 Löschungen.


git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@111 83ab2cfd-5345-466c-8aeb-2b2739fb922d
... ...
@@ -0,0 +1,36 @@
1
+function timePointMatrix = buildTimePointMatrix(argStruct)
2
+
3
+pst             = argStruct.pst;
4
+
5
+timeLineStart   = argStruct.timeLineStart;
6
+timeLineEnd     = argStruct.timeLineEnd;
7
+globalStart     = argStruct.globalStart;
8
+globalEnd       = argStruct.globalEnd;
9
+decodeDuration  = argStruct.decodeDuration;
10
+eventList       = argStruct.eventList;
11
+
12
+labelMap        = argStruct.labelMap;
13
+
14
+timePointMatrix = {};
15
+%%% build timepoint Matrix
16
+for timeShift   = timeLineStart:1:timeLineEnd
17
+    % center timepoint && relative shift
18
+    frameStart  = floor(-globalStart+1+timeShift - 0.5*decodeDuration);
19
+    frameEnd    = min(ceil(frameStart+decodeDuration + 0.5*decodeDuration),-globalStart+globalEnd);
20
+
21
+    %build svm inputmatrix
22
+    index = timeShift-timeLineStart+1; %Bad 1-indexing :-(
23
+    timePointMatrix{index} =[];
24
+    anyvoxel = 1;
25
+    for pstConditionGroup = 1:size(pst{1,anyvoxel},2)
26
+        for dp = 1:size(pst{1,anyvoxel}{1,pstConditionGroup},1) % data point
27
+            row = getSVMLabel(labelMap,eventList(pstConditionGroup,1));
28
+            for voxel = 1:size(pst,2)
29
+                row = [row, pst{1,voxel}{1,pstConditionGroup}(dp,frameStart:frameEnd)]; % label,[value,value,...],[value,value,...]...
30
+            end
31
+            timePointMatrix{index}  = [timePointMatrix{index}; row];
32
+        end
33
+    end
34
+end
35
+
36
+end
0 37
\ No newline at end of file
... ...
@@ -3,7 +3,9 @@ function outputStruct = calculateDecodePerformance(inputStruct,SubjectID)
3 3
 
4 4
 addpath 'libsvm-mat-2.88-1';
5 5
 
6
-SINGLE = 1;
6
+METHOD = 'single subject SVM';
7
+% METHOD = 'cross subject SVM';
8
+% METHOD = 'SOM';
7 9
 
8 10
 outputStruct = struct;
9 11
 
... ...
@@ -23,64 +25,67 @@ globalEnd       = inputStruct.psthEnd;
23 25
 baselineStart   = inputStruct.baselineStart;
24 26
 baselineEnd     = inputStruct.baselineEnd;
25 27
 eventList       = inputStruct.eventList;
26
-labelMap        = inputStruct.labelMap;
27
-
28 28
 
29 29
 minPerformance = inf;
30 30
 maxPerformance = -inf;
31 31
         
32
-
33
-        
34
-        %Pro Voxel PSTH TIMELINE berechnen.
35
-        %   timeshift mit pst-timeline durchf�hren.
36
-        % psth-timeline -25 bis +15 zu RES Onset.
37
-        
38
-%         eventList       = [9,11,13;10,12,14];
39
-%         globalStart     = -25;
40
-%         globalEnd       = 15;
41
-%         baselineStart   = -22;
42
-%         baselineEnd     = -20;
43
-        
44
-        
32
+        %% ERSETZEN DURCH ROI-IMAGE!
45 33
         for voxel = 1:size(voxelList,1)  % [[x;x],[y;y],[z;z]]
46 34
                 extr        = calculateImageData(voxelList(voxel,:),des,smoothed); 
47 35
                 rawdata     = cell2mat({extr.mean}); % Raw Data
48 36
                 pst{voxel}  = calculatePST(des,globalStart,baselineStart,baselineEnd,globalEnd,eventList,rawdata,sessionList);
49 37
         end
50 38
         
51
-        decodePerformance = [];
52
-
53
-        for timeShift   = timeLineStart:1:timeLineEnd
54
-            frameStart  = floor(-globalStart+1+timeShift - 0.5*decodeDuration);
55
-            frameEnd    = min(ceil(frameStart+decodeDuration + 0.5*decodeDuration),-globalStart+globalEnd);
56
-            
57
-            tmp =[];
58
-            anyvoxel = 1;
59
-            for pstConditionGroup = 1:size(pst{1,anyvoxel},2) 
60
-                for dp = 1:size(pst{1,anyvoxel}{1,pstConditionGroup},1) % data point
61
-                  row = getSVMLabel(labelMap,eventList(pstConditionGroup,1));
62
-                    for voxel = 1:size(pst,2)
63
-                        row = [row, pst{1,voxel}{1,pstConditionGroup}(dp,frameStart:frameEnd)]; % label,value,value
64
-                    end
65
-                tmp  = [tmp; row];
66
-                end
67
-            end 
39
+        timePointArgs.pst = pst;
40
+        timePointArgs.timeLineStart = timeLineStart;
41
+        timePointArgs.timeLineEnd   = timeLineEnd;
42
+        timePointArgs.globalStart   = globalStart;
43
+        timePointArgs.globalEnd     = globalEnd;
44
+        timePointArgs.decodeDuration= decodeDuration;
45
+        timePointArgs.labelMap      = inputStruct.labelMap;
46
+        timePointArgs.eventList     = eventList;
68 47
         
69
-            svmdata      = tmp(:,2:size(tmp,2));
70
-            svmlabel     = tmp(:,1);
48
+        timePointMatrix = buildTimePointMatrix(timePointArgs);
71 49
 
72
-%             RANDOMIZE INPUT
73
-%             rndindex  = randperm(length(svmlabel));
74
-%             svmdata   = svmdata(rndindex,:);
75
-%             svmlabel  = svmlabel(rndindex);
50
+        decodePerformance = [];
51
+        for index = 1:timeLineEnd-timeLineStart+1
52
+            RANDOMIZE_DATAPOINTS = 0;
53
+            svmdata      = timePointMatrix{index}(:,2:size(timePointMatrix{index},2));
54
+            svmlabel     = timePointMatrix{index}(:,1);
55
+            
56
+            if RANDOMIZE_DATAPOINTS
57
+                rndindex  = randperm(length(svmlabel));
58
+                svmdata   = svmdata(rndindex,:);
59
+                svmlabel  = svmlabel(rndindex);
60
+            end
76 61
 
77
-            if SINGLE
62
+            SVM_METHOD = 2;
63
+            switch SVM_METHOD;
64
+                case 1
78 65
                     performance  = svmtrain(svmlabel, svmdata, svmargs);
79 66
 
80 67
                     minPerformance = min(minPerformance,performance);
81 68
                     maxPerformance = max(maxPerformance,performance);
82 69
 
83 70
                     decodePerformance = [decodePerformance; performance];
71
+                case 2
72
+                    newsvmopt = killCrossvalOpt(svmargs);
73
+                    
74
+                    model = svmtrain(svmlabel,svmdata,newsvmopt);
75
+                    classperformance = [];
76
+                    for class = unique(svmlabel)';
77
+%                         assignin('base','uniquelabel',unique(svmlabel));
78
+%                         assignin('base','class',class);
79
+%                         assignin('base','svmlabel',svmlabel);
80
+                        filterindex = find(class == svmlabel);
81
+                        testing_label = svmlabel(filterindex)
82
+                        testing_data  = svmdata(filterindex)
83
+                        [plabel accuracy dvalue] = svmpredict(testing_label,testing_data,model,'')
84
+%                         assignin('base','accuracy',accuracy);
85
+                        classperformance = [classperformance accuracy(1)];
86
+                    end
87
+                    decodePerformance = [decodePerformance; classperformance];
88
+                    
84 89
             end
85 90
             
86 91
         end
... ...
@@ -93,3 +98,17 @@ maxPerformance = -inf;
93 98
         outputStruct.maxPerformance     = maxPerformance;
94 99
 end
95 100
 
101
+function opts = killCrossvalOpt(svmopt)
102
+opts = '';
103
+idx1 = 1;
104
+for idx2=strfind(svmopt,' -')
105
+    if idx1 ~= strfind(svmopt,' -v')
106
+        opts = strcat(opts,svmopt(idx1:idx2));
107
+    end
108
+    idx1=idx2;
109
+    if idx2==max(strfind(svmopt,' -'))
110
+        opts = strcat(opts,svmopt(idx2:end));
111
+    end
112
+end
113
+end
114
+
... ...
@@ -12,10 +10,9 @@ switch nargin
12 10
         error('spmtoolbox:SVMCrossVal:arginError','Please Specify action and parameter model');
13 11
 end
14 12
        
15
-        
16 13
         % common params
17 14
         calculateParams  = struct;
18
-        calculateParams.smoothed        = getDouble(paramModel.txtSmoothed);
15
+        calculateParams.smoothed        = getChkValue(paramModel.chkSmoothed);
19 16
 
20 17
         calculateParams.frameShiftStart = getDouble(paramModel.txtFrameShiftStart);  % -20;
21 18
         calculateParams.frameShiftEnd   = getDouble(paramModel.txtFrameShiftEnd); %15;
... ...
@@ -0,0 +1,3 @@
1
+function d = getChkValue(fieldhandle)
2
+    d = get(fieldhandle,'Value');
3
+end
0 4
\ No newline at end of file
... ...
@@ -61,7 +61,7 @@ DEFAULT.svmoptstring    = '-s 0 -t 0 -v 6 -c 1';
61 61
     set(model.subjectSelector,'BackgroundColor','w');
62 62
     
63 63
     createLabel(pSubject,[0.68*frameWidth firstRow*2  0.25*frameWidth controlElementHeight],'Smooth Data');
64
-    model.txtSmoothed     = uicontrol(pSubject,'Style','checkbox','Position',[0.68*frameWidth firstRow  0.25*frameWidth controlElementHeight],'Value',DEFAULT.smoothed);
64
+    model.chkSmoothed     = uicontrol(pSubject,'Style','checkbox','Position',[0.68*frameWidth firstRow  0.25*frameWidth controlElementHeight],'Value',DEFAULT.smoothed);
65 65
 
66 66
     createLabel(pSubject,[0.68*frameWidth firstRow*4  0.25*frameWidth controlElementHeight],'Crossvalidation');
67 67
     model.txtMultisubject = createTextField(pSubject,[0.68*frameWidth firstRow*3  0.25*frameWidth controlElementHeight],DEFAULT.multisubject);
68 68