transport changes
Christoph Budziszewski

Christoph Budziszewski commited on 2009-01-25 22:54:27
Zeige 4 geänderte Dateien mit 26 Einfügungen und 214 Löschungen.


git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@114 83ab2cfd-5345-466c-8aeb-2b2739fb922d
... ...
@@ -1,100 +0,0 @@
1
-function classify(varargin)
2
-
3
-
4
-
5
-switch nargin
6
-    case 1
7
-        paramModel = varargin{1};
8
-        % PROJECT_BASE_PATH = 'D:\Analyze\Stimolos';
9
-        PROJECT_BASE_PATH = 'D:\Analyze\Choice\24pilot';
10
-        PROJECT_RESULT_PATH = 'results\SPM.mat';
11
-    otherwise
12
-        error('spmtoolbox:SVMCrossVal:arginError','Please Specify action and parameter model');
13
-end
14
-
15
-        
16
-        % common params
17
-        calculateParams  = struct;
18
-        calculateParams.smoothed        = getDouble(paramModel.txtSmoothed);
19
-
20
-        calculateParams.frameShiftStart = getDouble(paramModel.txtFrameShiftStart);  % -20;
21
-        calculateParams.frameShiftEnd   = getDouble(paramModel.txtFrameShiftEnd); %15;
22
-        calculateParams.decodeDuration  = getDouble(paramModel.txtFrameShiftDur);
23
-        calculateParams.psthStart       = getDouble(paramModel.txtPSTHStart); % -25;
24
-        calculateParams.psthEnd         = getDouble(paramModel.txtPSTHEnd); % 20;
25
-        calculateParams.baselineStart   = getDouble(paramModel.txtBaselineStart); % -22;
26
-        calculateParams.baselineEnd     = getDouble(paramModel.txtBaselineEnd); % -20;
27
-
28
-        calculateParams.svmargs         = get(paramModel.txtSVMopts,'String');
29
-        calculateParams.sessionList     = 1:3;
30
-
31
-        classStruct = parseClassDef(paramModel);
32
-        
33
-        
34
-        calculateParams.labelMap        = LabelMap(classStruct.labelCells , classStruct.conditionCells, 'auto'); % LabelMap({'<','>','<+<','>+>','<+>','>+<'},{-2,-1,1,2,3,4}); 0 is autolabel
35
-        calculateParams.classList       = getClasses(calculateParams.labelMap);
36
-        calculateParams.eventList       = classStruct.eventMatrix; %[9,11,13; 10,12,14];
37
-%         calculateParams.eventList       = getPSTEventMatrix(calculateParams.labelMap);
38
-        
39
-        subjectSelection = getSubjectIDString(paramModel);
40
-        decode = struct;
41
-        decode.decodePerformance = [];
42
-        decode.rawTimeCourse     = [];
43
-        
44
-        for subjectCell = subjectSelection
45
-            SubjectID = cell2mat(subjectCell);
46
-            namehelper = strcat('s',SubjectID); %Vars can not start with numbers.
47
-
48
-            display('loading SPM.mat ...');
49
-            spm = load(fullfile(PROJECT_BASE_PATH,SubjectID,PROJECT_RESULT_PATH));
50
-            display('... done.');
51
-
52
-            %% calculate
53
-            calculateParams.(namehelper).des             = spm.SPM;
54
-            calculateParams.(namehelper).voxelList       = parseVoxelList(paramModel,SubjectID);
55
-            assignin('base','calculateParams',calculateParams);
56
-
57
-            display(sprintf('calculating cross-validation performance time-shift for Subject %s. Please Wait. ...',SubjectID));
58
-            display('switching off all warnings');
59
-            warning_state               = warning('off','all');
60
-            display('calculating ...');
61
-            decode.(namehelper)         = calculateDecodePerformance(calculateParams,SubjectID);
62
-            display('... done');
63
-            display('restoring warnings');
64
-            warning(warning_state);
65
-            
66
-            decode.decodePerformance    = [decode.decodePerformance decode.(namehelper).decodePerformance];
67
-            decode.rawTimeCourse        = [decode.rawTimeCourse decode.(namehelper).rawTimeCourse];
68
-
69
-            assignin('base','decode',decode);
70
-        end
71
-
72
-        display('Finished calculations.');
73
-        display('Plotting...');
74
-
75
-        plotParams                  = struct;
76
-        plotParams.psthStart        = calculateParams.psthStart;
77
-        plotParams.psthEnd   =  calculateParams.psthEnd;
78
-        plotParams.nClasses  = length(calculateParams.classList);
79
-        
80
-        plotParams.frameShiftStart   = calculateParams.frameShiftStart;
81
-        plotParams.frameShiftEnd     = calculateParams.frameShiftEnd;
82
-        plotParams.decodePerformance = decode.decodePerformance;
83
-        plotParams.rawTimeCourse     = decode.rawTimeCourse;
84
-        
85
-        if numel(subjectSelection) == 1
86
-          plotParams.SubjectID         = SubjectID;
87
-        else
88
-          plotParams.SubjectID         = 'Multiple';
89
-        end
90
-
91
-        plotParams.smoothed          = boolToYesNoString(calculateParams.smoothed);
92
-         
93
-
94
-        assignin('base','plotParams',plotParams);
95
-%         plotDecodePerformance(params.psthStart,params.psthEnd,params.nClasses,decode.decodeTable,params.frameShiftStart,params.frameShiftEnd,decode.rawTimeCourse);
96
-        plotDecodePerformance(plotParams);
97
-            
98
-        display('all done.');
99
-
100
-    end
101 0
\ No newline at end of file
... ...
@@ -1,103 +0,0 @@
1
-function plotDecodePerformance(varargin)
2
-% plotDecodePerformance(timeline,decodePerformance,nClasses,rawData)
3
-
4
-PSTH_AXIS_MIN = -1;
5
-PSTH_AXIS_MAX = 1;
6
-
7
-switch nargin
8
-    
9
-    case 1
10
-        inputStruct       = cell2mat(varargin(1));
11
-
12
-        psthStart         = inputStruct.psthStart;
13
-        psthEnd           = inputStruct.psthEnd;
14
-        nClasses          = inputStruct.nClasses;
15
-        decodePerformance = inputStruct.decodePerformance;
16
-        frameStart        = inputStruct.frameShiftStart;
17
-        frameEnd          = inputStruct.frameShiftEnd;
18
-        psth              = inputStruct.rawTimeCourse;
19
-        SubjectID         = inputStruct.SubjectID;
20
-        smoothed          = inputStruct.smoothed;
21
-
22
-    otherwise
23
-        error('spmtoolbox:SVMCrossVal:plotDecodePerformance:WrongArgument','Wrong Arguments');
24
-end
25
-
26
-    f = figure;
27
-    subplot(2,1,1);
28
-    hold on;
29
-      for voxel = 1:size(psth,2)
30
-          for label = 1:size(psth{voxel},2)
31
-              psthData = [];
32
-              for timepoint = 1:size(psth{voxel}{label},2)
33
-                  psthData = nanmean(psth{voxel}{label});
34
-              end
35
-              plot(psthStart:psthEnd,psthData,[colorChooser(voxel), lineStyleChooser(label)]);
36
-          end
37
-      end
38
-    axis([psthStart psthEnd PSTH_AXIS_MIN PSTH_AXIS_MAX])
39
-    hold off
40
-    
41
-    subplot(2,1,2)    
42
-    hold on;
43
-    
44
-    chanceLevel = 100/nClasses;
45
-    goodPredictionLevel = chanceLevel*1.5;
46
-    plot([psthStart psthEnd],[chanceLevel chanceLevel],'r');
47
-    plot([psthStart psthEnd],[goodPredictionLevel goodPredictionLevel],'g');
48
-    axis([psthStart psthEnd 0 100])
49
-    
50
-    plot(frameStart:frameEnd, mean(decodePerformance,2) ,'b');
51
-    PLOT_STD_ERR = 1;
52
-    PLOT_CLASS_PERFORMANCE = 1;
53
-    if PLOT_STD_ERR 
54
-        se = myStdErr(decodePerformance,2);
55
-        plot(frameStart:frameEnd, mean(decodePerformance,2)+se ,'b:');
56
-        plot(frameStart:frameEnd, mean(decodePerformance,2)-se ,'b:');
57
-    end
58
-    if PLOT_CLASS_PERFORMANCE
59
-        for c = 1:nClasses
60
-            plot(frameStart:frameEnd, decodePerformance() ,[colorChooser(c+2) '-']);
61
-        end
62
-    end
63
-    
64
-    
65
-    hold off;
66
-
67
-    title = sprintf('Subject %s, over %g voxel, smoothed %s',SubjectID,size(psth,2),smoothed);
68
-    set(f,'Name',title);
69
-    display(sprintf('%s',title));
70
-
71
-
72
-
73
-end
74
-
75
-
76
-function color = colorChooser(n)
77
-    switch (mod(n,8))
78
-    case 0
79
-        color = 'y';
80
-    case 1
81
-        color = 'r';
82
-    case 2
83
-        color = 'b';
84
-    case 3
85
-        color = 'g';
86
-    otherwise
87
-        color = 'k';
88
-    end
89
-end
90
-
91
-function style = lineStyleChooser(n)
92
-switch(mod(n,4))
93
-    case 0
94
-      style = '--';
95
-    case 1
96
-        style = '-';
97
-    case 2 
98
-        style = ':';
99
-    case 3
100
-        style = '-.';
101
-end
102
-end
103
-
... ...
@@ -3,9 +3,10 @@ function outputStruct = calculateDecodePerformance(inputStruct,SubjectID)
3 3
 
4 4
 addpath 'libsvm-mat-2.88-1';
5 5
 
6
-METHOD = 'single subject SVM';
7
-% METHOD = 'cross subject SVM';
8
-% METHOD = 'SOM';
6
+METHOD_DEF = inputStruct.CROSSVAL_METHOD_DEF;
7
+METHOD     = inputStruct.CROSSVAL_METHOD;
8
+
9
+RANDOMIZE_DATAPOINTS = inputStruct.RANDOMIZE;
9 10
 
10 11
 outputStruct = struct;
11 12
 
... ...
@@ -49,7 +50,7 @@ maxPerformance = -inf;
49 50
 
50 51
         decodePerformance = [];
51 52
         for index = 1:timeLineEnd-timeLineStart+1
52
-            RANDOMIZE_DATAPOINTS = 0;
53
+
53 54
             svmdata      = timePointMatrix{index}(:,2:size(timePointMatrix{index},2));
54 55
             svmlabel     = timePointMatrix{index}(:,1);
55 56
             
... ...
@@ -59,16 +60,18 @@ maxPerformance = -inf;
59 60
                 svmlabel  = svmlabel(rndindex);
60 61
             end
61 62
 
62
-            SVM_METHOD = 'som training'
63
-            switch SVM_METHOD;
64
-                case 'libsvm crossval'
63
+            switch METHOD;
64
+                case CROSSVAL_METHOD_DEF.svmcrossval
65
+                    
65 66
                     performance  = svmtrain(svmlabel, svmdata, svmargs);
66 67
 
67 68
                     minPerformance = min(minPerformance,performance);
68 69
                     maxPerformance = max(maxPerformance,performance);
69 70
 
70 71
                     decodePerformance = [decodePerformance; performance];
71
-                case 'class performance'
72
+                    
73
+                case CROSSVAL_METHOD_DEF.classPerformance
74
+                    
72 75
                     newsvmopt = killCrossvalOpt(svmargs);
73 76
                     
74 77
                     model = svmtrain(svmlabel,svmdata,newsvmopt);
... ...
@@ -84,7 +87,8 @@ maxPerformance = -inf;
84 87
                     end
85 88
                     decodePerformance = [decodePerformance; classperformance];
86 89
                     
87
-                case 'som training'
90
+                case CROSSVAL_METHOD_DEF.somTraining
91
+                    
88 92
                     display('SOM TRAINING');
89 93
                     addpath 'somtoolbox2';
90 94
                     sD = som_data_struct(svmdata,'label',num2str(svmlabel));
... ...
@@ -93,7 +97,7 @@ maxPerformance = -inf;
93 97
                     
94 98
                     assignin('base','sD',sD);
95 99
                     assignin('base','sM',sM);
96
-                    
100
+                    display('type ''figure'' before visualisation');
97 101
             end
98 102
             
99 103
         end
... ...
@@ -1,5 +1,10 @@
1 1
 function classify(varargin)
2 2
 
3
+CROSSVAL_METHOD_DEF.svmcrossval       = 'svm crossval';
4
+CROSSVAL_METHOD_DEF.classPerformance  = 'svm class performance';
5
+CROSSVAL_METHOD_DEF.crossSubject      = 'svm cross subject testing';
6
+CROSSVAL_METHOD_DEF.somTraining       = 'som Training';
7
+
3 8
 switch nargin
4 9
     case 1
5 10
         paramModel = varargin{1};
... ...
@@ -12,6 +17,12 @@ end
12 17
        
13 18
         % common params
14 19
         calculateParams  = struct;
20
+        
21
+        calculateParams.CROSSVAL_METHOD_DEF = CROSSVAL_METHOD_DEF;
22
+        calculateParams.CROSSVAL_METHOD     = CROSSVAL_METHOD_DEF.svmcrossval;
23
+        
24
+        calculateParams.RANDOMIZE       = 0;
25
+        
15 26
         calculateParams.smoothed        = getChkValue(paramModel.chkSmoothed);
16 27
 
17 28
         calculateParams.frameShiftStart = getDouble(paramModel.txtFrameShiftStart);  % -20;
... ...
@@ -46,7 +57,7 @@ end
46 57
             spm = load(fullfile(PROJECT_BASE_PATH,SubjectID,PROJECT_RESULT_PATH));
47 58
             display('... done.');
48 59
 
49
-            %% calculate
60
+            % calculate
50 61
             calculateParams.(namehelper).des             = spm.SPM;
51 62
             calculateParams.(namehelper).voxelList       = parseVoxelList(paramModel,SubjectID);
52 63
             assignin('base','calculateParams',calculateParams);
53 64