Browse code

SOM Single run

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@174 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on 07/04/2009 16:43:52
Showing 4 changed files
... ...
@@ -133,8 +133,12 @@ switch task
133 133
         assignin('base','decode',decode);
134 134
     case 'SOM'
135 135
         display('SOM');
136
-        disp('not implemented')
137
-
136
+        somopts.size = [3 3];
137
+        somopts.lattice = 'rect';
138
+        somopts.nfold = 6;
139
+        decode = som_subject_batch(header,data,somopts);
140
+        decode.header = header;
141
+        assignin('base','decode',decode);
138 142
     case 'XSOM'
139 143
         display('XSOM');
140 144
         somopts.size = [3 3];
141 145
deleted file mode 100644
... ...
@@ -1,80 +0,0 @@
1
-%% subject loop
2
-function decode = som_combined_subject_batch(header,subjectdata,somOpts)
3
-
4
-somOpts.size = [3 3];
5
-somOpts.lattice = 'rect';
6
-
7
-addpath 'somtoolbox2';
8
-
9
-% nSubjects = numel(subjectdata);
10
-% if(nSubjects < 2) 
11
-%     error('SVMCrossVal:xsvmSubjectLoop:tooFewSubjects','You need at least 2 Subjects in this Across-Subject analysis!');
12
-% end
13
-
14
-RANDOMIZE_DATAPOINTS = 1;
15
-
16
-decode = struct;
17
-decode.decodePerformance = [];
18
-decode.rawTimeCourse     = [];
19
-
20
-disp(sprintf('computinig additional datastructs for %u subjects',nSubjects));
21
-
22
-timeline = header.timeline;
23
-timeline.frameShiftStart = header.frameShift.frameShiftStart;
24
-timeline.frameShiftEnd   = header.frameShift.frameShiftEnd;
25
-timeline.decodeDuration  = header.frameShift.decodeDuration;
26
-
27
-% TimePointMatrix
28
-for subjectDataID = 1:nSubjects
29
-    currentSubject = subjectdata{subjectDataID};
30
-    timePointArgs.pst           = currentSubject.pst;
31
-    timePointArgs.labelMap      = LabelMap(header.classDef.labelCells,header.classDef.conditionCells);
32
-    timePointArgs.eventList     = header.classDef.eventMatrix;
33
-
34
-    timePointMatrix{subjectDataID} = buildTimePointMatrix(timeline,timePointArgs);
35
-    
36
-    decode.rawTimeCourse = [decode.rawTimeCourse currentSubject.pst];
37
-end
38
-
39
-% timeframe x-subject validation
40
-timeLineStart   = timeline.frameShiftStart;
41
-timeLineEnd     = timeline.frameShiftEnd;
42
-
43
-display(sprintf('%u -fold cross validation for %u timeslices.\n',nSubjects,size(1:timeLineEnd-timeLineStart+1,2)));
44
-% disp(sprintf('Press ANY-Key to continue.\n Use Retrun if your Keyboard lacks the ANY-Key.'));
45
-% pause
46
-
47
-for timeIndex = 1:timeLineEnd-timeLineStart+1
48
-    cross_value = [];
49
-    for validationSubjectID = 1:nSubjects
50
-        svm_train_label = [];
51
-        svm_train_data  = [];
52
-        svm_validation_label = [];
53
-        svm_validation_data  = [];
54
-        for subjectDataID = 1:nSubjects
55
-            svmstruct = calculateSVMTables(timePointMatrix{subjectDataID},timeIndex);
56
-            if subjectDataID == validationSubjectID
57
-                svm_validation_label = svmstruct.svmlabel;
58
-                svm_validation_data  = svmstruct.svmdata;
59
-            else
60
-                svm_train_label = [svm_train_label; svmstruct.svmlabel];
61
-                svm_train_data  = [svm_train_data;  svmstruct.svmdata];
62
-            end
63
-        end
64
-        
65
-        if RANDOMIZE_DATAPOINTS
66
-            rndindex  = randperm(length(svm_train_label));
67
-            svm_train_data   = svm_train_data(rndindex,:);
68
-            svm_train_label  = svm_train_label(rndindex);
69
-        end
70
-
71
-        [sD sM] = train_som(svm_train_label, svm_train_data, somOpts);
72
-
73
-        performance = som_decode(sD, sM, svm_validation_data,svm_validation_label);
74
-        
75
-        cross_value = [cross_value performance];
76
-    decode.decodePerformance = [decode.decodePerformance; cross_value];
77
-end
78
-   display('decode done'); 
79
-end
80
-
81 0
new file mode 100644
... ...
@@ -0,0 +1,76 @@
1
+%% subject loop
2
+function decode = som_subject_batch(header,subjectdata,somOpts)
3
+
4
+addpath 'somtoolbox2';
5
+
6
+RANDOMIZE_DATAPOINTS = 1;
7
+
8
+decode = struct;
9
+decode.decodePerformance = [];
10
+decode.rawTimeCourse     = [];
11
+
12
+nSubjects = numel(subjectdata);
13
+disp(sprintf('batch processing %u subjects',nSubjects));
14
+
15
+timeline = header.timeline;
16
+timeline.frameShiftStart = header.frameShift.frameShiftStart;
17
+timeline.frameShiftEnd   = header.frameShift.frameShiftEnd;
18
+timeline.decodeDuration  = header.frameShift.decodeDuration;
19
+
20
+
21
+timeLineStart   = timeline.frameShiftStart;
22
+timeLineEnd     = timeline.frameShiftEnd;
23
+
24
+% TimePointMatrix
25
+for subjectDataID = 1:nSubjects
26
+    currentSubject = subjectdata{subjectDataID};
27
+    timePointArgs.pst           = currentSubject.pst;
28
+    timePointArgs.labelMap      = LabelMap(header.classDef.labelCells,header.classDef.conditionCells);
29
+    timePointArgs.eventList     = header.classDef.eventMatrix;
30
+
31
+    timePointMatrix{subjectDataID} = buildTimePointMatrix(timeline,timePointArgs);
32
+
33
+    decode.rawTimeCourse = [decode.rawTimeCourse currentSubject.pst];
34
+
35
+
36
+    display(sprintf('%u -fold cross validation for %u timeslices.\n',somOpts.nfold,size(1:timeLineEnd-timeLineStart+1,2)));
37
+
38
+
39
+    for timeIndex = 1:timeLineEnd-timeLineStart+1
40
+        svmstruct = calculateSVMTables(timePointMatrix{subjectDataID},timeIndex);
41
+        nElements = length(svmstruct.svmlabel);
42
+
43
+        if RANDOMIZE_DATAPOINTS
44
+            rndindex  = randperm(nElements);
45
+            svmstruct.svmdata   = svmstruct.svmdata(rndindex,:);
46
+            svmstruct.svmlabel  = svmstruct.svmlabel(rndindex);
47
+        end
48
+
49
+        chunksize = nElements / somOpts.nfold;
50
+        
51
+        cross_value = 0;
52
+        for iFold = 1:somOpts.nfold
53
+            chunkstart = (iFold-1)*chunksize+1;
54
+            chunkend   = min(iFold*chunksize,nElements);
55
+
56
+            svm_train_label = svmstruct.svmlabel;
57
+            svm_train_data  = svmstruct.svmdata;
58
+
59
+            svm_validation_label = svmstruct.svmlabel(chunkstart:chunkend);
60
+            svm_validation_data  = svmstruct.svmdata(chunkstart:chunkend,:);
61
+
62
+            svm_train_label(chunkstart:chunkend) = []; %del test set
63
+            svm_train_data(chunkstart:chunkend,:) = [];% del test set
64
+
65
+            [sD sM] = som_train(svm_train_label, svm_train_data, somOpts);
66
+
67
+            performance = som_decode(sM, svm_validation_data,svm_validation_label);
68
+
69
+            cross_value = [cross_value, performance];
70
+        end
71
+        decode.decodePerformance = [decode.decodePerformance; cross_value];
72
+    end
73
+end
74
+display('decode done');
75
+end
76
+
... ...
@@ -346,7 +346,7 @@ pSOM = uipanel(parent,'Units','normalized','Position',[0.5 0.4 0.5 0.4]);
346 346
     btnRunSOM = uicontrol(pSOM,'String','run SOM Crossvalidation',...
347 347
         'Units','normalized',...
348 348
         'Position',[0.0 0.25 1 0.25]);
349
-    set(btnRunSOM,'Enable','off');
349
+    set(btnRunSOM,'Enable','on');
350 350
 
351 351
     btnRunXSOM = uicontrol(pSOM,'String','run SOM X-Subject validation',...
352 352
         'Units','normalized',...