Browse code

begin SOM implementation

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@159 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on 16/03/2009 21:53:24
Showing 4 changed files
... ...
@@ -83,6 +83,11 @@ end
83 83
 function decode(model,task)
84 84
 preprocessedData = evalin('base','preprocessedData');
85 85
 
86
+if(~(isa(preprocessedData,'struct')))
87
+    display('you need to preprocess some data before this step');
88
+    return
89
+end
90
+
86 91
 header            = preprocessedData.header;
87 92
 header.frameShift = getFrameShiftParams(model);
88 93
 data              = preprocessedData.subjectdata;
... ...
@@ -102,8 +107,11 @@ switch task
102 107
         decode.header = header;
103 108
         assignin('base','decode',decode);
104 109
     case 'SOM'
105
-        disp('not implemented')
106
-
110
+        display('SOM');
111
+        somopts = '';
112
+        decode = som_combined_subject_batch(header,data,somopts);
113
+        decode.header = header;
114
+        assignin('base','decode',decode);
107 115
     case 'X-SOM'
108 116
         disp('not implemented')
109 117
         
... ...
@@ -18,7 +18,7 @@ nSubjects         = size(SubjectID,2);
18 18
 
19 19
 
20 20
     f = figure;
21
-    subplot(2,1,1);
21
+    subplot(2,1,1); 
22 22
         plotPSTH(psth,psthStart,psthEnd);
23 23
     
24 24
     % plot performance timeline
25 25
new file mode 100644
... ...
@@ -0,0 +1,74 @@
1
+%% subject loop
2
+function decode = som_combined_subject_batch(header,subjectdata,somOpts)
3
+
4
+addpath 'somtoolbox2';
5
+
6
+nSubjects = numel(subjectdata);
7
+if(nSubjects < 2) 
8
+    error('SVMCrossVal:xsvmSubjectLoop:tooFewSubjects','You need at least 2 Subjects in this Across-Subject analysis!');
9
+end
10
+
11
+RANDOMIZE_DATAPOINTS = 1;
12
+
13
+decode = struct;
14
+decode.decodePerformance = [];
15
+decode.rawTimeCourse     = [];
16
+
17
+disp(sprintf('computinig additional datastructs for %u subjects',nSubjects));
18
+
19
+timeline = header.timeline;
20
+timeline.frameShiftStart = header.frameShift.frameShiftStart;
21
+timeline.frameShiftEnd   = header.frameShift.frameShiftEnd;
22
+timeline.decodeDuration  = header.frameShift.decodeDuration;
23
+
24
+% TimePointMatrix
25
+for subjectDataID = 1:nSubjects
26
+    currentSubject = subjectdata{subjectDataID};
27
+    timePointArgs.pst           = currentSubject.pst;
28
+    timePointArgs.labelMap      = LabelMap(header.classDef.labelCells,header.classDef.conditionCells);
29
+    timePointArgs.eventList     = header.classDef.eventMatrix;
30
+
31
+    timePointMatrix{subjectDataID} = buildTimePointMatrix(timeline,timePointArgs);
32
+    
33
+    decode.rawTimeCourse = [decode.rawTimeCourse currentSubject.pst];
34
+end
35
+
36
+% timeframe x-subject validation
37
+timeLineStart   = timeline.frameShiftStart;
38
+timeLineEnd     = timeline.frameShiftEnd;
39
+
40
+display(sprintf('%u -fold cross validation for %u timeslices.\n',nSubjects,size(1:timeLineEnd-timeLineStart+1,2)));
41
+% disp(sprintf('Press ANY-Key to continue.\n Use Retrun if your Keyboard lacks the ANY-Key.'));
42
+% pause
43
+
44
+for timeIndex = 1:timeLineEnd-timeLineStart+1
45
+        svm_train_label = [];
46
+        svm_train_data  = [];
47
+        svm_validation_label = [];
48
+        svm_validation_data  = [];
49
+        for subjectDataID = 1:nSubjects
50
+            svmstruct = calculateSVMTables(timePointMatrix{subjectDataID},timeIndex);
51
+            svm_validation_label = svmstruct.svmlabel;
52
+            svm_validation_data  = svmstruct.svmdata;
53
+            svm_train_label = [svm_train_label; svmstruct.svmlabel];
54
+            svm_train_data  = [svm_train_data;  svmstruct.svmdata];
55
+        end
56
+        
57
+        if RANDOMIZE_DATAPOINTS
58
+            rndindex  = randperm(length(svm_train_label));
59
+            svm_train_data   = svm_train_data(rndindex,:);
60
+            svm_train_label  = svm_train_label(rndindex);
61
+        end
62
+
63
+        [sD sM] = train_som(svmlabel, svmdata, somOpts)
64
+        
65
+%         [plabel accuracy dvalue] = svmpredict(svm_validation_label,svm_validation_data,svmmodel,'');
66
+        cross_value = [cross_value accuracy(1)];
67
+        
68
+    end
69
+    decode.decodePerformance = [decode.decodePerformance; cross_value];
70
+    
71
+end
72
+
73
+disp('decode done');
74
+end
... ...
@@ -319,11 +319,20 @@ pSVM = uipanel(parent,'Units','normalized','Position',[0 0.4 0.5 0.4]);
319 319
     model.txtSVMnfold = createTextField(pSVM,[0.0 0.66 0.5 0.16],DEFAULT.svmnfold);
320 320
     createLabel(pSVM,[0.5 0.50 0.5 0.25 ],'-Fold CrossVal');
321 321
     
322
-    model.chkSVMrnd = uicontrol(pSVM,'Style','checkbox','Units','normalized','Position',[0.1 0.50 1 0.16]);
322
+    model.chkSVMrnd = uicontrol(pSVM,'Style','checkbox','Units','normalized','Position',[0.1 0.50 0.9 0.16]);
323 323
     set(model.chkSVMrnd,'String','Randomize Datapoints');
324 324
     set(model.chkSVMrnd,'BackgroundColor','w');
325 325
     set(model.chkSVMrnd,'Value',DEFAULT.svmrnd);
326 326
     
327
+    btnRunSVM = uicontrol(pSVM,'String','run batchmode SVM Crossvalidation',...
328
+        'Units','normalized',...
329
+        'Position',[0 0.25 1 0.25]);
330
+    set(btnRunSVM,'Enable','on');
331
+    
332
+    btnRunXSVM = uicontrol(pSVM,'String','run SVM X-Subject validation',...
333
+        'Units','normalized',...
334
+        'Position',[0 0.0 1 0.25]);
335
+    set(btnRunXSVM,'Enable','on');
327 336
     
328 337
 pSOM = uipanel(parent,'Units','normalized','Position',[0.5 0.4 0.5 0.4]);
329 338
     set(pSOM,'Title','SOM Classification');
... ...
@@ -331,36 +340,27 @@ pSOM = uipanel(parent,'Units','normalized','Position',[0.5 0.4 0.5 0.4]);
331 340
 
332 341
     model.txtSOMopts = createTextField(pSOM,[0 0.75 1 0.25],'4x3 rect');
333 342
     set(model.txtSOMopts,'HorizontalAlignment','left');
334
-        set(model.txtSOMopts,'Enable','off');
343
+    set(model.txtSOMopts,'Enable','off');
335 344
 
336 345
     model.txtSOMnfold = createTextField(pSOM,[0.0 0.50 0.5 0.25],DEFAULT.svmnfold);
337
-        set(model.txtSOMnfold,'Enable','off');
346
+    set(model.txtSOMnfold,'Enable','off');
338 347
     createLabel(pSOM,[0.5 0.50 0.5 0.25 ],'-Fold CrossVal');
339 348
 
340
-% buttons
341
-    btnRunSVM = uicontrol(pSVM,'String','run SVM Crossvalidation',...
342
-        'Units','normalized',...
343
-        'Position',[0 0.25 1 0.25]);
344
-    set(btnRunSVM,'Callback',{@cbRunDecode,model,'SVM'}); % set here, because of model.
345
-    set(btnRunSVM,'Enable','on');
346
-    
347
-    btnRunXSVM = uicontrol(pSVM,'String','run SVM X-Subject validation',...
348
-        'Units','normalized',...
349
-        'Position',[0 0.0 1 0.25]);
350
-    set(btnRunXSVM,'Callback',{@cbRunDecode,model,'XSVM'}); % set here, because of model.
351
-    set(btnRunXSVM,'Enable','on');
352
-    
353 349
     btnRunSOM = uicontrol(pSOM,'String','run SOM Crossvalidation',...
354 350
         'Units','normalized',...
355
-    'Position',[0.0 0.25 1 0.25]);
356
-    set(btnRunSOM,'Callback',{@cbRunDecode,model,'SOM'}); % set here, because of model.
357
-    set(btnRunSOM,'Enable','off');
351
+        'Position',[0.0 0.25 1 0.25]);
352
+    set(btnRunSOM,'Enable','on');
358 353
 
359 354
     btnRunXSOM = uicontrol(pSOM,'String','run SOM X-Subject validation',...
360 355
         'Units','normalized',...
361 356
         'Position',[0.0 0.0 1 0.25]);
362
-    set(btnRunXSOM,'Callback',{@cbRunDecode,model,'XSOM'}); % set here, because of model.
363 357
     set(btnRunXSOM,'Enable','off');
358
+
359
+% button callbacks set here, because of model.
360
+    set(btnRunSVM, 'Callback',{@cbRunDecode,model,'SVM'}); 
361
+    set(btnRunXSVM,'Callback',{@cbRunDecode,model,'XSVM'}); 
362
+    set(btnRunSOM, 'Callback',{@cbRunDecode,model,'SOM'});
363
+    set(btnRunXSOM,'Callback',{@cbRunDecode,model,'XSOM'}); 
364 364
 end
365 365
 
366 366
 function model = createVisualStepPanel(model,parent,DEFAULT)