single subject -> svm classification -> plotDecodePerformance working!
Christoph Budziszewski

Christoph Budziszewski commited on 2009-03-06 13:17:25
Zeige 5 geänderte Dateien mit 25 Einfügungen und 20 Löschungen.


git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@145 83ab2cfd-5345-466c-8aeb-2b2739fb922d
... ...
@@ -1,3 +1,3 @@
1 1
 function svmlabel = lm_getSVMLabel(mapping,condition)
2
-    svmlabel = getValue(mapping,lm_getLabel(mapping,condition));
2
+    svmlabel = lm_getValue(mapping,lm_getLabel(mapping,condition));
3 3
 end
4 4
\ No newline at end of file
... ...
@@ -1,7 +1,8 @@
1
-function outputStruct = calculateDecodePerformance(timeline,subjectStruct,svmopts)
1
+function outputStruct = calculateDecodePerformance(header,subjectStruct,svmopts)
2 2
 outputStruct = struct;
3 3
 RANDOMIZE_DATAPOINTS = 1;
4 4
 
5
+timeline = header.timeline;
5 6
 
6 7
 timeLineStart   = timeline.frameShiftStart;
7 8
 timeLineEnd     = timeline.frameShiftEnd;
... ...
@@ -16,8 +17,8 @@ timeLineEnd     = timeline.frameShiftEnd;
16 17
 % end
17 18
 
18 19
 timePointArgs.pst           = subjectStruct.pst;
19
-timePointArgs.labelMap      = labelMap;
20
-timePointArgs.eventList     = eventList;
20
+timePointArgs.labelMap      = LabelMap(header.classDef.labelCells,header.classDef.conditionCells);
21
+timePointArgs.eventList     = header.classDef.eventMatrix;
21 22
 
22 23
 timePointMatrix = buildTimePointMatrix(timeline,timePointArgs);
23 24
 
... ...
@@ -41,9 +42,9 @@ end
41 42
 outputStruct.decodePerformance  = decodePerformance;
42 43
 outputStruct.svmdata            = svmdata;
43 44
 outputStruct.svmlabel           = svmlabel;
44
-outputStruct.rawTimeCourse      = pst;
45
-outputStruct.minPerformance     = minPerformance;
46
-outputStruct.maxPerformance     = maxPerformance;
45
+outputStruct.rawTimeCourse      = subjectStruct.pst;
46
+% outputStruct.minPerformance     = minPerformance;
47
+% outputStruct.maxPerformance     = maxPerformance;
47 48
 end
48 49
 
49 50
 
... ...
@@ -18,7 +18,7 @@ for subjectDataID = 1:size(subjectdata)
18 18
     warning_state               = warning('off','all');
19 19
     display('calculating ...');
20 20
     
21
-        decode.(namehelper)         = calculateDecodePerformance(header.timeline,currentSubject,svmopts);
21
+        decode.(namehelper)         = calculateDecodePerformance(header,currentSubject,svmopts);
22 22
 
23 23
     display('... done');
24 24
     display('restoring warnings');
... ...
@@ -71,7 +71,7 @@ switch task
71 71
         disp('SVM');
72 72
         svmopts    = getSvmArgs(model,1);
73 73
         preprocessedData = evalin('base','preprocessedData');
74
-        calculateMultiSubjectDecodePerformance(preprocessedData.header,preprocessedData.subjectdata,svmopts);
74
+        decode = calculateMultiSubjectDecodePerformance(preprocessedData.header,preprocessedData.subjectdata,svmopts);
75 75
         
76 76
         
77 77
     case 'X-SVM'
... ...
@@ -1,21 +1,25 @@
1
-function plotDecodePerformance(timeline,inputStruct)
1
+function plotDecodePerformance(header,decode,subjectData)
2 2
 
3
-global CROSSVAL_METHOD_DEF;
3
+global SVMCROSSVAL_CROSSVAL_METHOD_DEF;
4 4
 
5 5
 PSTH_AXIS_MIN = -1;
6 6
 PSTH_AXIS_MAX = 1;
7 7
 
8
+timeline = header.timeline;
9
+
8 10
         psthStart         = timeline.psthStart;
9 11
         psthEnd           = timeline.psthEnd;
10 12
         frameStart        = timeline.frameShiftStart;
11 13
         frameEnd          = timeline.frameShiftEnd;
12 14
 
13
-        nClasses          = inputStruct.nClasses;
14
-        decodePerformance = inputStruct.decodePerformance;
15
-        psth              = inputStruct.rawTimeCourse;
16
-        SubjectID         = inputStruct.SubjectID;
17
-        smoothed          = inputStruct.smoothed;
18
-        PLOT_METHOD       = inputStruct.CROSSVAL_METHOD;
15
+        nClasses          = numel(header.classDef.labelCells);
16
+        decodePerformance = decode.decodePerformance;
17
+        psth              = decode.rawTimeCourse;
18
+        SubjectID         = subjectData;
19
+
20
+        smoothed          = 'yes';
21
+        
22
+        PLOT_METHOD       = SVMCROSSVAL_CROSSVAL_METHOD_DEF.svmcrossval;
19 23
 %         CROSSVAL_METHOD_DEF = inputStruct.CROSSVAL_METHOD_DEF;
20 24
 
21 25
 
... ...
@@ -50,13 +54,13 @@ PSTH_AXIS_MAX = 1;
50 54
     
51 55
     
52 56
     switch PLOT_METHOD
53
-        case CROSSVAL_METHOD_DEF.svmcrossval
57
+        case SVMCROSSVAL_CROSSVAL_METHOD_DEF.svmcrossval
54 58
             plot(frameStart:frameEnd, mean(decodePerformance,2) ,'b','LineWidth',2);
55 59
 
56 60
             se = myStdErr(decodePerformance,2);
57 61
             plot(frameStart:frameEnd, mean(decodePerformance,2)+se ,'b:');
58 62
             plot(frameStart:frameEnd, mean(decodePerformance,2)-se ,'b:');
59
-        case CROSSVAL_METHOD_DEF.classPerformance
63
+        case SVMCROSSVAL_CROSSVAL_METHOD_DEF.classPerformance
60 64
             for c = 1:size(decodePerformance,2)
61 65
                 plot(frameStart:frameEnd, decodePerformance(:,c) ,[colorChooser(mod(c,nClasses)+3) '-']);
62 66
             end
... ...
@@ -78,7 +82,7 @@ PSTH_AXIS_MAX = 1;
78 82
     end
79 83
     
80 84
     if nSubjects == 1
81
-        subjectName = cell2Mat(SubjectID);
85
+        subjectName = SubjectID{1}.name;
82 86
         title = sprintf('Subject %s, over %g voxel, %s',subjectName,nVoxelPerSubject,smoothedString);
83 87
     else
84 88
         title = sprintf('%g Subjects, %g Voxel per Subject, %s',nSubjects,nVoxelPerSubject,smoothedString);
85 89