Browse code

more x-svm classification

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@149 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on 12/03/2009 17:15:24
Showing 5 changed files
... ...
@@ -21,5 +21,8 @@ global SVMCROSSVAL_SUBJECT_PREFIX;
21 21
 SVMCROSSVAL_SUBJECT_PREFIX                        = 'subject';
22 22
 
23 23
 global SVMCROSSVAL_SUBJECTSTRUCT_NAME;
24
-SVMCROSSVAL_SUBJECTSTRUCT_NAME = 'subjectStruct';
24
+SVMCROSSVAL_SUBJECTSTRUCT_NAME                    = 'subjectStruct';
25
+
26
+global SVMCROSSVAL_PREPROCESSED_DATA_NAME;
27
+SVMCROSSVAL_PREPROCESSED_DATA_NAME                = 'preprocessedData';
25 28
 end
26 29
\ No newline at end of file
27 30
new file mode 100644
... ...
@@ -0,0 +1,8 @@
1
+function svmstruct = calculateSVMTables(timePointMatrix,timeIndex)
2
+    svmstruct.svmdata  = timePointMatrix{timeIndex}(:,2:size(timePointMatrix{timeIndex},2));
3
+    svmstruct.svmlabel = timePointMatrix{timeIndex}(:,1);
4
+end
5
+
6
+
7
+
8
+
... ...
@@ -72,16 +72,18 @@ end
72 72
 end
73 73
 
74 74
 function decode(model,task)
75
+preprocessedData = evalin('base','preprocessedData');
75 76
 switch task
76 77
     case 'SVM'
77 78
         disp('SVM');
78 79
         svmopts    = getSvmArgs(model,1);
79
-        preprocessedData = evalin('base','preprocessedData');
80 80
         decode = calculateMultiSubjectDecodePerformance(preprocessedData.header,preprocessedData.subjectdata,svmopts);
81 81
         assignin('base','decode',decode);
82
-    case 'X-SVM'
82
+    case 'XSVM'
83 83
         disp('not implemented')
84
-        
84
+        svmopts  = getSvmArgs(model,0);
85
+        decode = xsvm_subject_loop(preprocessedData.header,preprocessedData.subjectdata,svmopts);
86
+        assignin('base','decode',decode);
85 87
     case 'SOM'
86 88
         disp('not implemented')
87 89
 
... ...
@@ -132,7 +132,7 @@ pSOM = uipanel(parent,'Units','normalized','Position',[0.5 0.0 0.5 1]);
132 132
         'Units','normalized',...
133 133
         'Position',[0 0.0 1 0.25]);
134 134
     set(btnRunXSVM,'Callback',{@cbRunDecode,model,'XSVM'}); % set here, because of model.
135
-    set(btnRunXSVM,'Enable','off');
135
+    set(btnRunXSVM,'Enable','on');
136 136
     
137 137
     btnRunSOM = uicontrol(pSOM,'String','run SOM Crossvalidation',...
138 138
         'Units','normalized',...
139 139
new file mode 100644
... ...
@@ -0,0 +1,73 @@
1
+%% subject loop
2
+function decode = xsvm_subject_loop(header,subjectdata,svmopts)
3
+
4
+nSubjects = numel(subjectdata);
5
+
6
+decode = struct;
7
+decode.decodePerformance = [];
8
+decode.rawTimeCourse     = [];
9
+
10
+disp(sprintf('we have %g subjects. Press ANY-Key to continue.\n Use Retrun if your Keyboard lacks the ANY-Key.',nSubjects));
11
+pause
12
+
13
+
14
+timeline = header.timeline;
15
+
16
+for subjectDataID = 1:nSubjects
17
+    currentSubject = subjectdata{subjectDataID};
18
+    timePointArgs.pst           = currentSubject.pst;
19
+    timePointArgs.labelMap      = LabelMap(header.classDef.labelCells,header.classDef.conditionCells);
20
+    timePointArgs.eventList     = header.classDef.eventMatrix;
21
+
22
+    timePointMatrix{subjectDataID} = buildTimePointMatrix(timeline,timePointArgs);
23
+end
24
+
25
+timeLineStart   = timeline.frameShiftStart;
26
+timeLineEnd     = timeline.frameShiftEnd;
27
+
28
+addpath 'libsvm-mat-2.88-1';
29
+
30
+for timeIndex = 1:timeLineEnd-timeLineStart+1
31
+    svm_train_label = [];
32
+    svm_train_data  = [];
33
+    svm_validation_label = [];
34
+    svm_validation_data  = [];
35
+
36
+    for validationSubjectID = 1:nSubjects
37
+        for subjectDataID = 1:nSubjects
38
+            svmstruct = calculateSVMTables(timePointMatrix{subjectDataID},timeIndex);
39
+            if subjectDataID == validationSubjectID
40
+                svm_validation_label = svmstruct.svmlabel;
41
+                svm_validation_data  = svmstruct.svmdata;
42
+            else
43
+                svm_train_label = [svm_train_label; svmstruct.svmlabel];
44
+                svm_train_data  = [svm_train_data;  svmstruct.svmdata];
45
+            end
46
+        end
47
+        
48
+        svmmodel = svmtrain(svm_train_label,svm_train_data,svmopts);
49
+        
50
+        [plabel accuracy dvalue] = svmpredict(svm_validation_label,svm_validation_data,svmmodel,'');
51
+        
52
+        accuracy(1)
53
+
54
+%         decode.(namehelper)         = calculateDecodePerformance(header,currentSubject,svmopts);
55
+% 
56
+%         display('... done');
57
+%         display('restoring warnings');
58
+%         warning(warning_state);
59
+% 
60
+%         decode.decodePerformance    = [decode.decodePerformance decode.(namehelper).decodePerformance];
61
+%         decode.rawTimeCourse        = [decode.rawTimeCourse decode.(namehelper).rawTimeCourse];
62
+
63
+%         assignin('base','decode',decode);
64
+
65
+        %         if RANDOMIZE_DATAPOINTS
66
+        %         rndindex  = randperm(length(svmlabel));
67
+        %         svmdata   = svmdata(rndindex,:);
68
+        %         svmlabel  = svmlabel(rndindex);
69
+        %         end
70
+
71
+    end
72
+end
73
+end