Browse code

moved stuff, fixing plotDecode

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@151 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on 16/03/2009 14:41:06
Showing 13 changed files
1 1
similarity index 100%
2 2
rename from @LabelMap/LabelMap.m
3 3
rename to private/@LabelMap/LabelMap.m
4 4
similarity index 100%
5 5
rename from @LabelMap/display.m
6 6
rename to private/@LabelMap/display.m
7 7
similarity index 100%
8 8
rename from @LabelMap/lm_getClasses.m
9 9
rename to private/@LabelMap/lm_getClasses.m
10 10
similarity index 100%
11 11
rename from @LabelMap/lm_getCondition.m
12 12
rename to private/@LabelMap/lm_getCondition.m
13 13
similarity index 100%
14 14
rename from @LabelMap/lm_getLabel.m
15 15
rename to private/@LabelMap/lm_getLabel.m
16 16
similarity index 100%
17 17
rename from @LabelMap/lm_getPSTEventMatrix.m
18 18
rename to private/@LabelMap/lm_getPSTEventMatrix.m
19 19
similarity index 100%
20 20
rename from @LabelMap/lm_getSVMLabel.m
21 21
rename to private/@LabelMap/lm_getSVMLabel.m
22 22
similarity index 100%
23 23
rename from @LabelMap/lm_getValue.m
24 24
rename to private/@LabelMap/lm_getValue.m
... ...
@@ -7,22 +7,21 @@ PSTH_AXIS_MAX = 1;
7 7
 
8 8
 timeline = header.timeline;
9 9
 
10
-        psthStart         = timeline.psthStart;
11
-        psthEnd           = timeline.psthEnd;
12
-        frameStart        = timeline.frameShiftStart;
13
-        frameEnd          = timeline.frameShiftEnd;
10
+psthStart         = timeline.psthStart;
11
+psthEnd           = timeline.psthEnd;
12
+frameStart        = timeline.frameShiftStart;
13
+frameEnd          = timeline.frameShiftEnd;
14 14
 
15
-        nClasses          = numel(header.classDef.labelCells);
16
-        decodePerformance = decode.decodePerformance;
17
-        psth              = decode.rawTimeCourse;
18
-        SubjectID         = subjectData;
15
+nClasses          = numel(header.classDef.labelCells);
16
+decodePerformance = decode.decodePerformance;
17
+psth              = decode.rawTimeCourse;
18
+SubjectID         = subjectData;
19 19
 
20
-        smoothed          = 'yes';
21
-        
22
-        PLOT_METHOD       = SVMCROSSVAL_CROSSVAL_METHOD_DEF.svmcrossval;
23
-        PLOT_METHOD       = SVMCROSSVAL_CROSSVAL_METHOD_DEF.classPerformance;
24
-%         CROSSVAL_METHOD_DEF = inputStruct.CROSSVAL_METHOD_DEF;
20
+smoothed          = 'yes';
25 21
 
22
+PLOT_METHOD       = SVMCROSSVAL_CROSSVAL_METHOD_DEF.svmcrossval;
23
+PLOT_METHOD       = SVMCROSSVAL_CROSSVAL_METHOD_DEF.classPerformance;
24
+PLOT_METHOD       = 'x-subject-val';
26 25
 
27 26
     f = figure;
28 27
     subplot(2,1,1);
... ...
@@ -70,6 +69,17 @@ timeline = header.timeline;
70 69
             end
71 70
 
72 71
             plot(frameStart:frameEnd, mean(decodePerformance,2) ,'b','LineWidth',2);
72
+            
73
+        case 'x-subject-val'
74
+            nSubjects = size(decodePerformance,2);
75
+            for c = 1:nSubjects
76
+                plot(frameStart:frameEnd, decodePerformance(:,c) ,[colorChooser(mod(c,nSubjects)+3) '-']);
77
+            end
78
+
79
+            plot(frameStart:frameEnd, mean(decodePerformance,2) ,'b','LineWidth',2);
80
+            se = myStdErr(decodePerformance,2);
81
+            plot(frameStart:frameEnd, mean(decodePerformance,2)+se ,'b:');
82
+            plot(frameStart:frameEnd, mean(decodePerformance,2)-se ,'b:');
73 83
 
74 84
     end
75 85
     
76 86
similarity index 82%
77 87
rename from private/svm_crossval.m
78 88
rename to private/svm_class_performance.m
... ...
@@ -1,4 +1,4 @@
1
-function decodePerformance = svm_crossval(svmlabel,svmdata,svmopts)
1
+function decodePerformance = svm_class_performance(svmlabel,svmdata,svmopts)
2 2
 addpath 'libsvm-mat-2.88-1';
3 3
 
4 4
 svmmodel = svmtrain(svmlabel,svmdata,svmopts);
... ...
@@ -66,87 +66,6 @@ DEFAULT.wd  = fullfile('d:','Analyze','Choice','24pilot');
66 66
     assignin('base','model',model);
67 67
 end
68 68
 
69
-function model = mcb_cd(src,evnt,model)
70
-disp('CD');
71
-directory_name = uigetdir(model.baseDir,'Select Study Base Directory ...');
72
-model.baseDir = directory_name;
73
-model = scanDirs(model);
74
-end
75
-
76
-function mcb_save(src,evnt,model)
77
-disp('SAVE');
78
-baseDir  = model.baseDir;
79
-timeLine = getTimeLineParams(model);
80
-classDefString = getClassDefString(model);
81
-coordDefString = getCoordDefString(model);
82
-
83
-[file path] = uiputfile('*.mat','Save current Params ...',model.baseDir);
84
-save( fullfile(path,file),'baseDir','timeLine','classDefString','coordDefString') ;
85
-end
86
-
87
-function model = mcb_load(src,evnt,model)
88
-disp('LOAD');
89
-[file path] = uigetfile('*.mat','Load Params ...',model.baseDir);
90
-l = load(fullfile(path,file));
91
-% assignin('base','l',l);
92
-model = setTimeLineParams(model,l.timeLine);
93
-model = setClassDefString(model,l.classDefString);
94
-model = setCoordDefString(model,l.coordDefString);
95
-model.baseDir = l.baseDir;
96
-model = scanDirs(model);
97
-
98
-end
99
-
100
-function model = createSecondStepPanel(model,parent,DEFAULT,basecolor)
101
-    
102
-pSVM = uipanel(parent,'Units','normalized','Position',[0 0.0 0.5 1]);
103
-    set(pSVM,'Title','SVM Classification');
104
-    set(pSVM,'BackgroundColor',basecolor);
105
-
106
-    model.txtSVMopts = createTextField(pSVM,[0 0.75 1 0.25],DEFAULT.svmoptstring);
107
-    set(model.txtSVMopts,'HorizontalAlignment','left');
108
-    
109
-    model.txtSVMnfold = createTextField(pSVM,[0.0 0.50 0.5 0.25],DEFAULT.svmnfold);
110
-    createLabel(pSVM,[0.5 0.50 0.5 0.25 ],'-Fold CrossVal');
111
-    
112
-pSOM = uipanel(parent,'Units','normalized','Position',[0.5 0.0 0.5 1]);
113
-    set(pSOM,'Title','SOM Classification');
114
-    set(pSOM,'BackgroundColor',basecolor);
115
-
116
-    model.txtSOMopts = createTextField(pSOM,[0 0.75 1 0.25],'4x3 rect');
117
-    set(model.txtSOMopts,'HorizontalAlignment','left');
118
-        set(model.txtSOMopts,'Enable','off');
119
-
120
-    model.txtSOMnfold = createTextField(pSOM,[0.0 0.50 0.5 0.25],DEFAULT.svmnfold);
121
-        set(model.txtSOMnfold,'Enable','off');
122
-    createLabel(pSOM,[0.5 0.50 0.5 0.25 ],'-Fold CrossVal');
123
-
124
-% buttons
125
-    btnRunSVM = uicontrol(pSVM,'String','run SVM Crossvalidation',...
126
-        'Units','normalized',...
127
-        'Position',[0 0.25 1 0.25]);
128
-    set(btnRunSVM,'Callback',{@cbRunDecode,model,'SVM'}); % set here, because of model.
129
-    set(btnRunSVM,'Enable','on');
130
-    
131
-    btnRunXSVM = uicontrol(pSVM,'String','run SVM X-Subject validation',...
132
-        'Units','normalized',...
133
-        'Position',[0 0.0 1 0.25]);
134
-    set(btnRunXSVM,'Callback',{@cbRunDecode,model,'XSVM'}); % set here, because of model.
135
-    set(btnRunXSVM,'Enable','on');
136
-    
137
-    btnRunSOM = uicontrol(pSOM,'String','run SOM Crossvalidation',...
138
-        'Units','normalized',...
139
-    'Position',[0.0 0.25 1 0.25]);
140
-    set(btnRunSOM,'Callback',{@cbRunDecode,model,'SOM'}); % set here, because of model.
141
-    set(btnRunSOM,'Enable','off');
142
-
143
-    btnRunXSOM = uicontrol(pSOM,'String','run SOM X-Subject validation',...
144
-        'Units','normalized',...
145
-        'Position',[0.0 0.0 1 0.25]);
146
-    set(btnRunXSOM,'Callback',{@cbRunDecode,model,'XSOM'}); % set here, because of model.
147
-    set(btnRunXSOM,'Enable','off');
148
-end
149
-
150 69
 function model = createFirstStepPanel(model,parent,DEFAULT)
151 70
 
152 71
     main_grid = cell(2,4);
... ...
@@ -317,6 +236,58 @@ function model = createFirstStepPanel(model,parent,DEFAULT)
317 236
         set(btnRunButton3,'Enable','on');
318 237
 end
319 238
 
239
+function model = createSecondStepPanel(model,parent,DEFAULT,basecolor)
240
+    
241
+pSVM = uipanel(parent,'Units','normalized','Position',[0 0.0 0.5 1]);
242
+    set(pSVM,'Title','SVM Classification');
243
+    set(pSVM,'BackgroundColor',basecolor);
244
+
245
+    model.txtSVMopts = createTextField(pSVM,[0 0.75 1 0.25],DEFAULT.svmoptstring);
246
+    set(model.txtSVMopts,'HorizontalAlignment','left');
247
+    
248
+    model.txtSVMnfold = createTextField(pSVM,[0.0 0.50 0.5 0.25],DEFAULT.svmnfold);
249
+    createLabel(pSVM,[0.5 0.50 0.5 0.25 ],'-Fold CrossVal');
250
+    
251
+pSOM = uipanel(parent,'Units','normalized','Position',[0.5 0.0 0.5 1]);
252
+    set(pSOM,'Title','SOM Classification');
253
+    set(pSOM,'BackgroundColor',basecolor);
254
+
255
+    model.txtSOMopts = createTextField(pSOM,[0 0.75 1 0.25],'4x3 rect');
256
+    set(model.txtSOMopts,'HorizontalAlignment','left');
257
+        set(model.txtSOMopts,'Enable','off');
258
+
259
+    model.txtSOMnfold = createTextField(pSOM,[0.0 0.50 0.5 0.25],DEFAULT.svmnfold);
260
+        set(model.txtSOMnfold,'Enable','off');
261
+    createLabel(pSOM,[0.5 0.50 0.5 0.25 ],'-Fold CrossVal');
262
+
263
+% buttons
264
+    btnRunSVM = uicontrol(pSVM,'String','run SVM Crossvalidation',...
265
+        'Units','normalized',...
266
+        'Position',[0 0.25 1 0.25]);
267
+    set(btnRunSVM,'Callback',{@cbRunDecode,model,'SVM'}); % set here, because of model.
268
+    set(btnRunSVM,'Enable','on');
269
+    
270
+    btnRunXSVM = uicontrol(pSVM,'String','run SVM X-Subject validation',...
271
+        'Units','normalized',...
272
+        'Position',[0 0.0 1 0.25]);
273
+    set(btnRunXSVM,'Callback',{@cbRunDecode,model,'XSVM'}); % set here, because of model.
274
+    set(btnRunXSVM,'Enable','on');
275
+    
276
+    btnRunSOM = uicontrol(pSOM,'String','run SOM Crossvalidation',...
277
+        'Units','normalized',...
278
+    'Position',[0.0 0.25 1 0.25]);
279
+    set(btnRunSOM,'Callback',{@cbRunDecode,model,'SOM'}); % set here, because of model.
280
+    set(btnRunSOM,'Enable','off');
281
+
282
+    btnRunXSOM = uicontrol(pSOM,'String','run SOM X-Subject validation',...
283
+        'Units','normalized',...
284
+        'Position',[0.0 0.0 1 0.25]);
285
+    set(btnRunXSOM,'Callback',{@cbRunDecode,model,'XSOM'}); % set here, because of model.
286
+    set(btnRunXSOM,'Enable','off');
287
+end
288
+
289
+
290
+
320 291
 function cbRunPreprocessing(src,evnt,model,task)
321 292
 main(model,'pre',task);
322 293
 end
... ...
@@ -325,6 +296,37 @@ function cbRunDecode(src,evnt,model,task)
325 296
 main(model,'decode',task);
326 297
 end
327 298
 
299
+function model = mcb_cd(src,evnt,model)
300
+disp('CD');
301
+directory_name = uigetdir(model.baseDir,'Select Study Base Directory ...');
302
+model.baseDir = directory_name;
303
+model = scanDirs(model);
304
+end
305
+
306
+function mcb_save(src,evnt,model)
307
+disp('SAVE');
308
+baseDir  = model.baseDir;
309
+timeLine = getTimeLineParams(model);
310
+classDefString = getClassDefString(model);
311
+coordDefString = getCoordDefString(model);
312
+
313
+[file path] = uiputfile('*.mat','Save current Params ...',model.baseDir);
314
+save( fullfile(path,file),'baseDir','timeLine','classDefString','coordDefString') ;
315
+end
316
+
317
+function model = mcb_load(src,evnt,model)
318
+disp('LOAD');
319
+[file path] = uigetfile('*.mat','Load Params ...',model.baseDir);
320
+l = load(fullfile(path,file));
321
+% assignin('base','l',l);
322
+model = setTimeLineParams(model,l.timeLine);
323
+model = setClassDefString(model,l.classDefString);
324
+model = setCoordDefString(model,l.coordDefString);
325
+model.baseDir = l.baseDir;
326
+model = scanDirs(model);
327
+
328
+end
329
+
328 330
 function label = createLabel(parent,  pos, labelText)
329 331
     label = uicontrol(parent,'Style','text','Units','normalized','String',labelText,'Position',pos);
330 332
     set(label,'HorizontalAlignment','left');
331 333
deleted file mode 100644
... ...
@@ -1,9 +0,0 @@
1
-function decode = xsvm_single_crossval()
2
-
3
-%for each subject in test
4
-% append svmdata and svm label
5
-% train svm
6
-% test with testsubject
7
-%end
8
-
9
-end
10 0
\ No newline at end of file
... ...
@@ -1,8 +1,12 @@
1 1
 %% subject loop
2 2
 function decode = xsvm_subject_loop(header,subjectdata,svmopts)
3 3
 
4
+addpath 'libsvm-mat-2.88-1';
5
+
4 6
 nSubjects = numel(subjectdata);
5 7
 
8
+RANDOMIZE_DATAPOINTS = 0;
9
+
6 10
 decode = struct;
7 11
 decode.decodePerformance = [];
8 12
 decode.rawTimeCourse     = [];
... ...
@@ -11,6 +15,7 @@ disp(sprintf('computinig additional datastructs for %u subjects',nSubjects));
11 15
 
12 16
 timeline = header.timeline;
13 17
 
18
+% TimePointMatrix
14 19
 for subjectDataID = 1:nSubjects
15 20
     currentSubject = subjectdata{subjectDataID};
16 21
     timePointArgs.pst           = currentSubject.pst;
... ...
@@ -18,16 +23,18 @@ for subjectDataID = 1:nSubjects
18 23
     timePointArgs.eventList     = header.classDef.eventMatrix;
19 24
 
20 25
     timePointMatrix{subjectDataID} = buildTimePointMatrix(timeline,timePointArgs);
26
+    
27
+    decode.rawTimeCourse = [decode.rawTimeCourse currentSubject.pst];
21 28
 end
22 29
 
30
+% timeframe x-subject validation
23 31
 timeLineStart   = timeline.frameShiftStart;
24 32
 timeLineEnd     = timeline.frameShiftEnd;
25 33
 
26
-addpath 'libsvm-mat-2.88-1';
27
-
28 34
 display(sprintf('%u -fold cross validation for %u timeslices.\n',nSubjects,size(1:timeLineEnd-timeLineStart+1,2)));
29
-disp(sprintf('Press ANY-Key to continue.\n Use Retrun if your Keyboard lacks the ANY-Key.'));
30
-pause
35
+% disp(sprintf('Press ANY-Key to continue.\n Use Retrun if your Keyboard lacks the ANY-Key.'));
36
+% pause
37
+
31 38
 for timeIndex = 1:timeLineEnd-timeLineStart+1
32 39
     cross_value = [];
33 40
     for validationSubjectID = 1:nSubjects
... ...
@@ -46,35 +53,19 @@ for timeIndex = 1:timeLineEnd-timeLineStart+1
46 53
             end
47 54
         end
48 55
         
49
-%         display(sprintf('Time %u: validation subject: %u, validation set size %g, training set size %g with %u subjects',...
50
-%             timeIndex, validationSubjectID, numel(svm_validation_label), numel(svm_train_label),nSubjects-1));
51
-        
56
+        if RANDOMIZE_DATAPOINTS
57
+            rndindex  = randperm(length(svm_train_label));
58
+            svm_train_data   = svm_train_data(rndindex,:);
59
+            svm_train_label  = svm_train_label(rndindex);
60
+        end
61
+
52 62
         svmmodel = svmtrain(svm_train_label,svm_train_data,svmopts);
53 63
         
54 64
         [plabel accuracy dvalue] = svmpredict(svm_validation_label,svm_validation_data,svmmodel,'');
55 65
         cross_value = [cross_value accuracy(1)];
56 66
         
57 67
     end
58
-    decode.decodePerformance = [decode.decodePerformance mean(cross_value)];
59
-%     decode.rawTimeCourse = [decode.rawTimeCourse cross_value];
60
-    
61
-%         decode.(namehelper)         = calculateDecodePerformance(header,currentSubject,svmopts);
62
-% 
63
-%         display('... done');
64
-%         display('restoring warnings');
65
-%         warning(warning_state);
66
-% 
67
-%         decode.decodePerformance    = [decode.decodePerformance decode.(namehelper).decodePerformance];
68
-%         decode.rawTimeCourse        = [decode.rawTimeCourse decode.(namehelper).rawTimeCourse];
69
-
70
-%         assignin('base','decode',decode);
71
-
72
-        %         if RANDOMIZE_DATAPOINTS
73
-        %         rndindex  = randperm(length(svmlabel));
74
-        %         svmdata   = svmdata(rndindex,:);
75
-        %         svmlabel  = svmlabel(rndindex);
76
-        %         end
77
-
68
+    decode.decodePerformance = [decode.decodePerformance; cross_value];
78 69
     
79 70
 end
80 71
 end