Browse code

enabled svm classification. labelmap not working

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@143 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on 05/03/2009 16:47:04
Showing 5 changed files
... ...
@@ -1,10 +1,8 @@
1
-function outputStruct = calculateDecodePerformance(timeline,subjectStruct,model)
1
+function outputStruct = calculateDecodePerformance(timeline,subjectStruct,svmopts)
2 2
 outputStruct = struct;
3
-RANDOMIZE_DATAPOINTS = 0;
3
+RANDOMIZE_DATAPOINTS = 1;
4 4
 
5 5
 
6
-eventList       = inputStruct.eventList;
7
-
8 6
 timeLineStart   = timeline.frameShiftStart;
9 7
 timeLineEnd     = timeline.frameShiftEnd;
10 8
 
... ...
@@ -17,11 +15,8 @@ timeLineEnd     = timeline.frameShiftEnd;
17 15
 %     pst{iVoxel} = calculatePST(timeline,calculatePstOpts,rawdata);
18 16
 % end
19 17
 
20
-
21
-
22
-timePointArgs.pst = pst;
23
-
24
-timePointArgs.labelMap      = inputStruct.labelMap;
18
+timePointArgs.pst           = subjectStruct.pst;
19
+timePointArgs.labelMap      = labelMap;
25 20
 timePointArgs.eventList     = eventList;
26 21
 
27 22
 timePointMatrix = buildTimePointMatrix(timeline,timePointArgs);
28 23
new file mode 100644
... ...
@@ -0,0 +1,33 @@
1
+%% subject loop
2
+function decode = calculateMultiSubjectDecodePerformance(header,subjectdata,svmopts)
3
+
4
+
5
+
6
+decode = struct;
7
+decode.decodePerformance = [];
8
+decode.rawTimeCourse     = [];
9
+
10
+for subjectDataID = 1:size(subjectdata)
11
+%     SubjectID = cell2mat(subjectCell);
12
+    currentSubject = subjectdata{subjectDataID};
13
+
14
+    namehelper = strcat('s',currentSubject.name); %Vars can not start with numbers.
15
+
16
+    display(sprintf('calculating cross-validation performance time-shift for Subject %s. Please Wait. ...',currentSubject.name));
17
+    display('switching off all warnings');
18
+    warning_state               = warning('off','all');
19
+    display('calculating ...');
20
+    
21
+        decode.(namehelper)         = calculateDecodePerformance(header.timeline,currentSubject,svmopts);
22
+
23
+    display('... done');
24
+    display('restoring warnings');
25
+    warning(warning_state);
26
+
27
+    decode.decodePerformance    = [decode.decodePerformance decode.(namehelper).decodePerformance];
28
+    decode.rawTimeCourse        = [decode.rawTimeCourse decode.(namehelper).rawTimeCourse];
29
+
30
+    assignin('base','decode',decode);
31
+end
32
+
33
+end
... ...
@@ -1,10 +1,20 @@
1
-function main(model,task)
2
-
3
-% parse the GUI and pass parameters as structure
1
+function main(model,task,subtask)
4 2
 disp('RUN');
3
+switch task 
4
+    
5
+    case 'pre'
6
+        preprocess(model,subtask);
7
+    case 'decode'
8
+        decode(model,subtask);
9
+end
10
+end
5 11
 % disp('all warnings OFF')
6 12
 % warn = warning('off','all');
7 13
 
14
+
15
+function preprocess(model,task)
16
+% parse the GUI and pass parameters as structure
17
+
8 18
 timeLine = getTimeLineParams(model);
9 19
 subjects = getSubjectCellList(model);
10 20
 classDef = parseClassDef(model);
... ...
@@ -16,41 +26,62 @@ mask     = ['^' cell2mat(getImageFileMask(model)) '.*\.img$'];
16 26
 switch task
17 27
     case 'COORD'
18 28
         disp('COORD');
29
+       
30
+        out = struct;
31
+        out.header = struct;
32
+        out.header.type = 'COORD';
33
+        out.header.timeline = timeLine;
34
+        out.header.classDef = classDef;
35
+        
36
+        
19 37
         coordargs = struct;
20
-        coordargs.subjects = subjects;
21
-        coordargs.timeline = timeLine;
22
-        coordargs.basedir = model.baseDir;
23
-        coordargs.sessionList = 1:3;
24
-        coordargs.eventList = classDef.eventMatrix;
25
-        coordargs.coords = parseCoordinateTextField(model);
26
-        coordargs.mask   = mask;
38
+        coordargs.subjects      = subjects;
39
+        coordargs.timeline      = timeLine;
40
+        coordargs.basedir       = model.baseDir;
41
+        coordargs.sessionList   = 1:3;
42
+        coordargs.eventList     = classDef.eventMatrix;
43
+        coordargs.coords        = parseCoordinateTextField(model);
44
+        coordargs.mask          = mask;
45
+        
46
+        out.subjectdata = runCoordTable(coordargs);
27 47
         
28
-        runCoordTable(coordargs)
48
+        assignin('base','preprocessedData',out);
29 49
         
30 50
     case 'ROI'
31 51
         disp('ROI');
32 52
         roiargs = struct;
33
-        roiargs.subjects    = subjects;
34
-        roiargs.timeline    = timeLine;
35
-        roiargs.classes     = classDef;
36
-        roiargs.mask        = mask;
37
-        roiargs.basedir     = model.baseDir;
38
-        roiargs.sessionList = 1:3;
39
-        roiargs.eventList   = classDef.eventMatrix;
53
+        roiargs.subjects        = subjects;
54
+        roiargs.timeline        = timeLine;
55
+        roiargs.classes         = classDef;
56
+        roiargs.mask            = mask;
57
+        roiargs.basedir         = model.baseDir;
58
+        roiargs.sessionList     = 1:3;
59
+        roiargs.eventList       = classDef.eventMatrix;
40 60
         
41 61
         runROIImageMaskMode(roiargs);
42 62
         
43 63
     case 'FBS'
44
-        disp('FBS')
45
-        
64
+        disp('not implemented')
65
+end
66
+end
67
+
68
+function decode(model,task)
69
+switch task
46 70
     case 'SVM'
47
-        disp('classify with svm');
71
+        disp('SVM');
72
+        svmopts    = getSvmArgs(model,1);
73
+        preprocessedData = evalin('base','preprocessedData');
74
+        calculateMultiSubjectDecodePerformance(preprocessedData.header,preprocessedData.subjectdata,svmopts);
75
+        
48 76
         
49 77
     case 'X-SVM'
78
+        disp('not implemented')
50 79
         
51 80
     case 'SOM'
81
+        disp('not implemented')
52 82
 
53 83
     case 'X-SOM'
84
+        disp('not implemented')
54 85
         
55 86
 end
56 87
 
... ...
@@ -1,4 +1,4 @@
1
-function runCoordTable(args)
1
+function subjectData = runCoordTable(args)
2 2
 
3 3
 global SVMCROSSVAL_SUBJECTSTRUCT_NAME;
4 4
 
... ...
@@ -40,5 +40,6 @@ global SVMCROSSVAL_SUBJECTSTRUCT_NAME;
40 40
        disp(sprintf('done %g // %g',s,nSubjects));
41 41
     end
42 42
     
43
-    assignin('base',SVMCROSSVAL_SUBJECTSTRUCT_NAME,subjectStruct);
43
+    subjectData = subjectStruct;
44
+%     assignin('base',SVMCROSSVAL_SUBJECTSTRUCT_NAME,subjectStruct);
44 45
 end
... ...
@@ -36,18 +36,15 @@ DEFAULT.wd  = fullfile('d:','Analyze','Choice','24pilot');
36 36
 
37 37
     model.txtBaseDir = createLabel(frame,[0 0.97 1 0.03],model.baseDir);
38 38
     set(model.txtBaseDir,'BackgroundColor','w');
39
+    set(model.txtBaseDir,'ForegroundColor','b');
39 40
     
40 41
     pFirstStep   = uipanel(frame,'Title','Preprocessing','Position',[0 0.25 1 0.720]);
41 42
     set(pFirstStep,'BackgroundColor','w');
42 43
     set(pFirstStep,'Units','normalized');
43 44
     
44
-    
45
-
46
-    
47 45
     model.selectedSubject = DEFAULT.selectedSubject;
48 46
     model = createFirstStepPanel(model,pFirstStep,DEFAULT);
49 47
     
50
-    
51 48
     %Classification Step
52 49
     secondStepBaseColor = 'w';
53 50
     pSecondStep = uipanel(frame,'Title','Classification','Position',[0 0 1 0.25]);
... ...
@@ -111,8 +108,6 @@ pSVM = uipanel(parent,'Units','normalized','Position',[0 0.0 0.5 1]);
111 108
     
112 109
     model.txtSVMnfold = createTextField(pSVM,[0.0 0.50 0.5 0.25],DEFAULT.svmnfold);
113 110
     createLabel(pSVM,[0.5 0.50 0.5 0.25 ],'-Fold CrossVal');
114
-
115
-
116 111
     
117 112
 pSOM = uipanel(parent,'Units','normalized','Position',[0.5 0.0 0.5 1]);
118 113
     set(pSOM,'Title','SOM Classification');
... ...
@@ -126,28 +121,29 @@ pSOM = uipanel(parent,'Units','normalized','Position',[0.5 0.0 0.5 1]);
126 121
         set(model.txtSOMnfold,'Enable','off');
127 122
     createLabel(pSOM,[0.5 0.50 0.5 0.25 ],'-Fold CrossVal');
128 123
 
124
+% buttons
129 125
     btnRunSVM = uicontrol(pSVM,'String','run SVM Crossvalidation',...
130 126
         'Units','normalized',...
131 127
         'Position',[0 0.25 1 0.25]);
132
-    set(btnRunSVM,'Callback',{@cbRunSVM,model}); % set here, because of model.
133
-    set(btnRunSVM,'Enable','off');
128
+    set(btnRunSVM,'Callback',{@cbRunDecode,model,'SVM'}); % set here, because of model.
129
+    set(btnRunSVM,'Enable','on');
134 130
     
135 131
     btnRunXSVM = uicontrol(pSVM,'String','run SVM X-Subject validation',...
136 132
         'Units','normalized',...
137 133
         'Position',[0 0.0 1 0.25]);
138
-    set(btnRunXSVM,'Callback',{@cbRunXSVM,model}); % set here, because of model.
134
+    set(btnRunXSVM,'Callback',{@cbRunDecode,model,'XSVM'}); % set here, because of model.
139 135
     set(btnRunXSVM,'Enable','off');
140 136
     
141 137
     btnRunSOM = uicontrol(pSOM,'String','run SOM Crossvalidation',...
142 138
         'Units','normalized',...
143 139
     'Position',[0.0 0.25 1 0.25]);
144
-    set(btnRunSOM,'Callback',{@cbRunSOM,model}); % set here, because of model.
140
+    set(btnRunSOM,'Callback',{@cbRunDecode,model,'SOM'}); % set here, because of model.
145 141
     set(btnRunSOM,'Enable','off');
146 142
 
147 143
     btnRunXSOM = uicontrol(pSOM,'String','run SOM X-Subject validation',...
148 144
         'Units','normalized',...
149 145
         'Position',[0.0 0.0 1 0.25]);
150
-    set(btnRunXSOM,'Callback',{@cbRunXSOM,model}); % set here, because of model.
146
+    set(btnRunXSOM,'Callback',{@cbRunDecode,model,'XSOM'}); % set here, because of model.
151 147
     set(btnRunXSOM,'Enable','off');
152 148
 end
153 149
 
... ...
@@ -322,7 +318,11 @@ function model = createFirstStepPanel(model,parent,DEFAULT)
322 318
 end
323 319
 
324 320
 function cbRunPreprocessing(src,evnt,model,task)
325
-main(model,task);
321
+main(model,'pre',task);
322
+end
323
+
324
+function cbRunDecode(src,evnt,model,task)
325
+main(model,'decode',task);
326 326
 end
327 327
 
328 328
 function label = createLabel(parent,  pos, labelText)