Browse code

snapshot, classification

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@140 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on 02/03/2009 18:28:41
Showing 11 changed files
... ...
@@ -19,4 +19,7 @@ SVMCROSSVAL_VOXEL_SELECTION_MODE_DEF.roiImage     = 'use ROI image by pop-up ima
19 19
 global SVMCROSSVAL_SUBJECT_PREFIX;
20 20
 % internally used to prefix subject-ids starting with numbers.
21 21
 SVMCROSSVAL_SUBJECT_PREFIX                        = 'subject';
22
+
23
+global SVMCROSSVAL_SUBJECTSTRUCT_NAME;
24
+SVMCROSSVAL_SUBJECTSTRUCT_NAME = 'subjectStruct';
22 25
 end
23 26
\ No newline at end of file
... ...
@@ -1,74 +1,23 @@
1
-% function [decodePerformance rawTimecourse ] = calculateDecodePerformance(des,timeLineStart, timeLineEnd, decodeDuration, svmargs, conditionList, sessionList, voxelList, classList, labelMap,normalize)
2
-function outputStruct = calculateDecodePerformance(timeline,inputStruct,subjectParams)
3
-
4
-global CROSSVAL_METHOD_DEF;
5
-
6
-
7
-addpath 'libsvm-mat-2.88-1';
8
-
9
-% CROSSVAL_METHOD_DEF = inputStruct.CROSSVAL_METHOD_DEF;
10
-METHOD              = inputStruct.CROSSVAL_METHOD;
11
-
12
-RANDOMIZE_DATAPOINTS = inputStruct.RANDOMIZE;
13
-
14
-% SubjectID       = subjectParams.SubjectID;
15
-% namehelper      = subjectParams.namehelper;
16
-voxelList       = subjectParams.voxelList;
17
-des             = subjectParams.des;
18
-
1
+function outputStruct = calculateDecodePerformance(timeline,subjectStruct,model)
19 2
 outputStruct = struct;
3
+RANDOMIZE_DATAPOINTS = 0;
20 4
 
21
-svmargs         = inputStruct.svmargs;
22
-sessionList     = inputStruct.sessionList;
23 5
 
24
-% classList       = inputStruct.classList;
25
-% labelMap        = inputStruct.labelMap;
26 6
 eventList       = inputStruct.eventList;
27 7
 
28 8
 timeLineStart   = timeline.frameShiftStart;
29 9
 timeLineEnd     = timeline.frameShiftEnd;
30
-% decodeDuration  = timeline.decodeDuration;
31
-% globalStart     = timeline.psthStart;
32
-% globalEnd       = timeline.psthEnd;
33
-% baselineStart   = timeline.baselineStart;
34
-% baselineEnd     = timeline.baselineEnd;
35
-
36
-
37
-minPerformance = inf;
38
-maxPerformance = -inf;
39 10
 
40
-subjectDir = '';
41
-sessionDirList = sessionList2DirList(sessionList) ;
42
-mask = '^fandersen.*img$';
43
-imageFiles = getImageFileList(subjectDir,sessionDirList,mask);
11
+% for iVoxel = 1:nVoxel
12
+%     rawdata = [];
13
+%     for iImage = 1:length(extr);
14
+%         tmp = extr(iImage);
15
+%         rawdata = [rawdata tmp.dat(iVoxel)];
16
+%     end
17
+%     pst{iVoxel} = calculatePST(timeline,calculatePstOpts,rawdata);
18
+% end
44 19
 
45 20
 
46
-disp('press key');
47
-pause
48
-
49
-extr = calculateImageData(imageFiles,voxelList);
50
-
51
-nVoxel = size(voxelList,1);
52
-
53
-calculatePstOpts = struct;
54
-calculatePstOpts.des = des;
55
-calculatePstOpts.eventList = eventList;
56
-calculatePstOpts.sessionList = sessionList;
57
-
58
-for iVoxel = 1:nVoxel
59
-    rawdata = [];
60
-    for iImage = 1:length(extr);
61
-        tmp = extr(iImage);
62
-        rawdata = [rawdata tmp.dat(iVoxel)];
63
-    end
64
-    pst{iVoxel} = calculatePST(timeline,calculatePstOpts,rawdata);
65
-end
66
-
67
-%         for voxel = 1:size(voxelList,1)  % [[x;x],[y;y],[z;z]]
68
-%             extr        = calculateImageData(imageFiles,voxelList(voxel,:));
69
-%             rawdata     = cell2mat({extr.mean}); % Raw Data
70
-%             pst{voxel}  = calculatePST(des,globalStart,baselineStart,baselineEnd,globalEnd,eventList,rawdata,sessionList);
71
-%         end
72 21
 
73 22
 timePointArgs.pst = pst;
74 23
 
... ...
@@ -89,46 +38,9 @@ for index = 1:timeLineEnd-timeLineStart+1
89 38
         svmlabel  = svmlabel(rndindex);
90 39
     end
91 40
 
92
-    switch METHOD;
93
-        case CROSSVAL_METHOD_DEF.svmcrossval
94
-
95
-            performance  = svmtrain(svmlabel, svmdata, svmargs);
96
-
97
-            minPerformance = min(minPerformance,performance);
98
-            maxPerformance = max(maxPerformance,performance);
99
-
100
-            decodePerformance = [decodePerformance; performance];
101
-
102
-        case CROSSVAL_METHOD_DEF.classPerformance
103
-
104
-            newsvmopt = killCrossvalOpt(svmargs);
105
-
106
-            model = svmtrain(svmlabel,svmdata,newsvmopt);
107
-            classperformance = [];
108
-            for class = unique(svmlabel)';
109
-
110
-                filterindex = find(class == svmlabel);
111
-                testing_label = svmlabel(filterindex);
112
-                testing_data  = svmdata(filterindex);
113
-                [plabel accuracy dvalue] = svmpredict(testing_label,testing_data,model,'');
114
-
115
-                classperformance = [classperformance accuracy(1)];
116
-            end
117
-            decodePerformance = [decodePerformance; classperformance];
118
-
119
-        case CROSSVAL_METHOD_DEF.somTraining
120
-
121
-            display('SOM TRAINING');
122
-            addpath 'somtoolbox2';
123
-            sD = som_data_struct(svmdata,'label',num2str(svmlabel));
124
-            assignin('base','sD',sD);
125
-            sM = som_make(sD,'msize', [3 3],'lattice', 'hexa');
126
-
127
-            assignin('base','sD',sD);
128
-            assignin('base','sM',sM);
129
-            display('type ''figure'' before visualisation');
130
-    end
131
-
41
+    decodePerformance = [decodePerformance; svm_single_crossval(svmlabel,svmdata,svmopts)];
42
+    
43
+    
132 44
 end
133 45
 
134 46
 outputStruct.decodePerformance  = decodePerformance;
... ...
@@ -139,19 +51,5 @@ outputStruct.minPerformance     = minPerformance;
139 51
 outputStruct.maxPerformance     = maxPerformance;
140 52
 end
141 53
 
142
-function opts = killCrossvalOpt(svmopt)
143
-opts = '';
144
-idx1 = 1;
145
-for idx2=strfind(svmopt,' -')
146
-    if idx1 ~= strfind(svmopt,' -v')
147
-        opts = strcat(opts,svmopt(idx1:idx2));
148
-    end
149
-    idx1=idx2;
150
-    if idx2==max(strfind(svmopt,' -'))
151
-        opts = strcat(opts,svmopt(idx2:end));
152
-    end
153
-end
154
-end
155
-
156 54
 
157 55
 
158 56
new file mode 100644
... ...
@@ -0,0 +1,10 @@
1
+function svmargs = getSvmArgs(model,single_run)
2
+
3
+    svmargs = get(model.txtSVMopts,'String');
4
+
5
+if single_run
6
+    svmargs = [svmargs ' -v ' num2str(getSvmNFold(model))];
7
+end
8
+
9
+
10
+end
0 11
\ No newline at end of file
1 12
new file mode 100644
... ...
@@ -0,0 +1,3 @@
1
+function nfold = getSvmNFold(model)
2
+nfold = str2double(get(model.txtSVMnfold,'String'));
3
+end
0 4
\ No newline at end of file
... ...
@@ -38,13 +38,20 @@ switch task
38 38
         roiargs.sessionList = 1:3;
39 39
         roiargs.eventList   = classDef.eventMatrix;
40 40
         
41
-        
42
-        assignin('base','roiargs',roiargs);
43
-        
44 41
         runROIImageMaskMode(roiargs);
45 42
         
46 43
     case 'FBS'
47 44
         disp('FBS')
45
+        
46
+    case 'SVM'
47
+        disp('classify with svm');
48
+        
49
+    case 'X-SVM'
50
+        
51
+    case 'SOM'
52
+
53
+    case 'X-SOM'
54
+        
48 55
 end
49 56
 
50 57
 % disp('warings restored');
... ...
@@ -1,4 +1,7 @@
1 1
 function runCoordTable(args)
2
+
3
+global SVMCROSSVAL_SUBJECTSTRUCT_NAME;
4
+
2 5
     disp('run coord table')
3 6
     
4 7
     subjects = args.subjects;
... ...
@@ -37,33 +40,5 @@ function runCoordTable(args)
37 40
        disp(sprintf('done %g // %g',s,nSubjects));
38 41
     end
39 42
     
40
-    assignin('base','subjectStruct',subjectStruct);
43
+    assignin('base',SVMCROSSVAL_SUBJECTSTRUCT_NAME,subjectStruct);
41 44
 end
42
-
43
-
44
-
45
-
46
-% 
47
-% %         decode = claculateMultiSubjectDecodePerformance(timelineParams,calculateParams,paramModel);
48
-% 
49
-%         display('Finished calculations.');
50
-%         display('Plotting...');
51
-% 
52
-%         plotParams                   = struct;
53
-%         
54
-% %         plotParams.SVMCROSSVAL_CROSSVAL_METHOD_DEF = SVMCROSSVAL_CROSSVAL_METHOD_DEF;
55
-%         plotParams.CROSSVAL_METHOD     = calculateParams.CROSSVAL_METHOD;
56
-%         
57
-%         plotParams.nClasses          = length(calculateParams.classList);
58
-% 
59
-%         plotParams.decodePerformance = decode.decodePerformance;
60
-%         plotParams.rawTimeCourse     = decode.rawTimeCourse;
61
-%         plotParams.SubjectID         = subjectSelection;
62
-%         plotParams.smoothed          = boolToYesNoString(calculateParams.smoothed);
63
-% 
64
-%         assignin('base','plotParams',plotParams);
65
-% %         plotDecodePerformance(params.psthStart,params.psthEnd,params.nClasses,decode.decodeTable,params.frameShiftStart,params.frameShiftEnd,decode.rawTimeCourse);
66
-%         plotDecodePerformance(timelineParams,plotParams);
67
-%             
68
-%         display('all done.');
69
-% 
... ...
@@ -1,5 +1,7 @@
1 1
 function runROIImageMaskMode(args)
2 2
 
3
+global SVMCROSSVAL_SUBJECTSTRUCT_NAME;
4
+
3 5
 subjects = args.subjects;
4 6
    
5 7
 nSubjects = size(subjects);
... ...
@@ -47,7 +49,7 @@ for s = 1:nSubjects
47 49
     disp('done');
48 50
 end
49 51
 
50
-assignin('base','subjectStruct',subjectStruct);
52
+assignin('base',SVMCROSSVAL_SUBJECTSTRUCT_NAME,subjectStruct);
51 53
 
52 54
 end
53 55
 
54 56
new file mode 100644
... ...
@@ -0,0 +1,16 @@
1
+function decodePerformance = svm_crossval(svmlabel,svmdata,svmopts)
2
+addpath 'libsvm-mat-2.88-1';
3
+
4
+svmmodel = svmtrain(svmlabel,svmdata,svmopts);
5
+classperformance = [];
6
+for class = unique(svmlabel)';
7
+
8
+    filterindex = find(class == svmlabel);
9
+    testing_label = svmlabel(filterindex);
10
+    testing_data  = svmdata(filterindex);
11
+    [plabel accuracy dvalue] = svmpredict(testing_label,testing_data,svmmodel,'');
12
+
13
+    classperformance = [classperformance accuracy(1)];
14
+end
15
+decodePerformance = [decodePerformance; classperformance];
16
+end
0 17
\ No newline at end of file
1 18
new file mode 100644
... ...
@@ -0,0 +1,4 @@
1
+function performance = svm_single_crossval(svmlabel,svmdata,svmopts)
2
+addpath 'libsvm-mat-2.88-1';
3
+performance  = svmtrain(svmlabel, svmdata, svmopts);
4
+end
0 5
\ No newline at end of file
1 6
new file mode 100644
... ...
@@ -0,0 +1,12 @@
1
+function [sD sM] = train_som(svmlabel, svmdata, somOptions)
2
+
3
+display('SOM TRAINING');
4
+addpath 'somtoolbox2';
5
+sD = som_data_struct(svmdata,'label',num2str(svmlabel));
6
+assignin('base','sD',sD);
7
+sM = som_make(sD,'msize', [3 3],'lattice', 'hexa');
8
+
9
+assignin('base','sD',sD);
10
+assignin('base','sM',sM);
11
+display('type ''figure'' before visualisation');
12
+end
0 13
\ No newline at end of file
... ...
@@ -91,7 +91,7 @@ function model = mcb_load(src,evnt,model)
91 91
 disp('LOAD');
92 92
 [file path] = uigetfile('*.mat','Load Params ...',model.baseDir);
93 93
 l = load(fullfile(path,file));
94
-assignin('base','l',l);
94
+% assignin('base','l',l);
95 95
 model = setTimeLineParams(model,l.timeLine);
96 96
 model = setClassDefString(model,l.classDefString);
97 97
 model = setCoordDefString(model,l.coordDefString);
... ...
@@ -319,12 +319,10 @@ function model = createFirstStepPanel(model,parent,DEFAULT)
319 319
             'Units','normalized','Position',[0.66 0 0.33 1]);
320 320
         set(btnRunButton3,'Callback',{@cbRunPreprocessing,model,'ROI'}); % set here, because of model.    
321 321
         set(btnRunButton3,'Enable','on');
322
-        
323
-        assignin('base','model',model);
324 322
 end
325 323
 
326 324
 function cbRunPreprocessing(src,evnt,model,task)
327
-main(model,task)
325
+main(model,task);
328 326
 end
329 327
 
330 328
 function label = createLabel(parent,  pos, labelText)