Christoph Budziszewski commited on 2009-01-19 18:28:59
Zeige 5 geänderte Dateien mit 100 Einfügungen und 43 Löschungen.
git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@111 83ab2cfd-5345-466c-8aeb-2b2739fb922d
... | ... |
@@ -0,0 +1,36 @@ |
1 |
+function timePointMatrix = buildTimePointMatrix(argStruct) |
|
2 |
+ |
|
3 |
+pst = argStruct.pst; |
|
4 |
+ |
|
5 |
+timeLineStart = argStruct.timeLineStart; |
|
6 |
+timeLineEnd = argStruct.timeLineEnd; |
|
7 |
+globalStart = argStruct.globalStart; |
|
8 |
+globalEnd = argStruct.globalEnd; |
|
9 |
+decodeDuration = argStruct.decodeDuration; |
|
10 |
+eventList = argStruct.eventList; |
|
11 |
+ |
|
12 |
+labelMap = argStruct.labelMap; |
|
13 |
+ |
|
14 |
+timePointMatrix = {}; |
|
15 |
+%%% build timepoint Matrix |
|
16 |
+for timeShift = timeLineStart:1:timeLineEnd |
|
17 |
+ % center timepoint && relative shift |
|
18 |
+ frameStart = floor(-globalStart+1+timeShift - 0.5*decodeDuration); |
|
19 |
+ frameEnd = min(ceil(frameStart+decodeDuration + 0.5*decodeDuration),-globalStart+globalEnd); |
|
20 |
+ |
|
21 |
+ %build svm inputmatrix |
|
22 |
+ index = timeShift-timeLineStart+1; %Bad 1-indexing :-( |
|
23 |
+ timePointMatrix{index} =[]; |
|
24 |
+ anyvoxel = 1; |
|
25 |
+ for pstConditionGroup = 1:size(pst{1,anyvoxel},2) |
|
26 |
+ for dp = 1:size(pst{1,anyvoxel}{1,pstConditionGroup},1) % data point |
|
27 |
+ row = getSVMLabel(labelMap,eventList(pstConditionGroup,1)); |
|
28 |
+ for voxel = 1:size(pst,2) |
|
29 |
+ row = [row, pst{1,voxel}{1,pstConditionGroup}(dp,frameStart:frameEnd)]; % label,[value,value,...],[value,value,...]... |
|
30 |
+ end |
|
31 |
+ timePointMatrix{index} = [timePointMatrix{index}; row]; |
|
32 |
+ end |
|
33 |
+ end |
|
34 |
+end |
|
35 |
+ |
|
36 |
+end |
|
0 | 37 |
\ No newline at end of file |
... | ... |
@@ -3,7 +3,9 @@ function outputStruct = calculateDecodePerformance(inputStruct,SubjectID) |
3 | 3 |
|
4 | 4 |
addpath 'libsvm-mat-2.88-1'; |
5 | 5 |
|
6 |
-SINGLE = 1; |
|
6 |
+METHOD = 'single subject SVM'; |
|
7 |
+% METHOD = 'cross subject SVM'; |
|
8 |
+% METHOD = 'SOM'; |
|
7 | 9 |
|
8 | 10 |
outputStruct = struct; |
9 | 11 |
|
... | ... |
@@ -23,64 +25,67 @@ globalEnd = inputStruct.psthEnd; |
23 | 25 |
baselineStart = inputStruct.baselineStart; |
24 | 26 |
baselineEnd = inputStruct.baselineEnd; |
25 | 27 |
eventList = inputStruct.eventList; |
26 |
-labelMap = inputStruct.labelMap; |
|
27 |
- |
|
28 | 28 |
|
29 | 29 |
minPerformance = inf; |
30 | 30 |
maxPerformance = -inf; |
31 | 31 |
|
32 |
- |
|
33 |
- |
|
34 |
- %Pro Voxel PSTH TIMELINE berechnen. |
|
35 |
- % timeshift mit pst-timeline durchf�hren. |
|
36 |
- % psth-timeline -25 bis +15 zu RES Onset. |
|
37 |
- |
|
38 |
-% eventList = [9,11,13;10,12,14]; |
|
39 |
-% globalStart = -25; |
|
40 |
-% globalEnd = 15; |
|
41 |
-% baselineStart = -22; |
|
42 |
-% baselineEnd = -20; |
|
43 |
- |
|
44 |
- |
|
32 |
+ %% ERSETZEN DURCH ROI-IMAGE! |
|
45 | 33 |
for voxel = 1:size(voxelList,1) % [[x;x],[y;y],[z;z]] |
46 | 34 |
extr = calculateImageData(voxelList(voxel,:),des,smoothed); |
47 | 35 |
rawdata = cell2mat({extr.mean}); % Raw Data |
48 | 36 |
pst{voxel} = calculatePST(des,globalStart,baselineStart,baselineEnd,globalEnd,eventList,rawdata,sessionList); |
49 | 37 |
end |
50 | 38 |
|
51 |
- decodePerformance = []; |
|
52 |
- |
|
53 |
- for timeShift = timeLineStart:1:timeLineEnd |
|
54 |
- frameStart = floor(-globalStart+1+timeShift - 0.5*decodeDuration); |
|
55 |
- frameEnd = min(ceil(frameStart+decodeDuration + 0.5*decodeDuration),-globalStart+globalEnd); |
|
56 |
- |
|
57 |
- tmp =[]; |
|
58 |
- anyvoxel = 1; |
|
59 |
- for pstConditionGroup = 1:size(pst{1,anyvoxel},2) |
|
60 |
- for dp = 1:size(pst{1,anyvoxel}{1,pstConditionGroup},1) % data point |
|
61 |
- row = getSVMLabel(labelMap,eventList(pstConditionGroup,1)); |
|
62 |
- for voxel = 1:size(pst,2) |
|
63 |
- row = [row, pst{1,voxel}{1,pstConditionGroup}(dp,frameStart:frameEnd)]; % label,value,value |
|
64 |
- end |
|
65 |
- tmp = [tmp; row]; |
|
66 |
- end |
|
67 |
- end |
|
39 |
+ timePointArgs.pst = pst; |
|
40 |
+ timePointArgs.timeLineStart = timeLineStart; |
|
41 |
+ timePointArgs.timeLineEnd = timeLineEnd; |
|
42 |
+ timePointArgs.globalStart = globalStart; |
|
43 |
+ timePointArgs.globalEnd = globalEnd; |
|
44 |
+ timePointArgs.decodeDuration= decodeDuration; |
|
45 |
+ timePointArgs.labelMap = inputStruct.labelMap; |
|
46 |
+ timePointArgs.eventList = eventList; |
|
68 | 47 |
|
69 |
- svmdata = tmp(:,2:size(tmp,2)); |
|
70 |
- svmlabel = tmp(:,1); |
|
48 |
+ timePointMatrix = buildTimePointMatrix(timePointArgs); |
|
71 | 49 |
|
72 |
-% RANDOMIZE INPUT |
|
73 |
-% rndindex = randperm(length(svmlabel)); |
|
74 |
-% svmdata = svmdata(rndindex,:); |
|
75 |
-% svmlabel = svmlabel(rndindex); |
|
50 |
+ decodePerformance = []; |
|
51 |
+ for index = 1:timeLineEnd-timeLineStart+1 |
|
52 |
+ RANDOMIZE_DATAPOINTS = 0; |
|
53 |
+ svmdata = timePointMatrix{index}(:,2:size(timePointMatrix{index},2)); |
|
54 |
+ svmlabel = timePointMatrix{index}(:,1); |
|
55 |
+ |
|
56 |
+ if RANDOMIZE_DATAPOINTS |
|
57 |
+ rndindex = randperm(length(svmlabel)); |
|
58 |
+ svmdata = svmdata(rndindex,:); |
|
59 |
+ svmlabel = svmlabel(rndindex); |
|
60 |
+ end |
|
76 | 61 |
|
77 |
- if SINGLE |
|
62 |
+ SVM_METHOD = 2; |
|
63 |
+ switch SVM_METHOD; |
|
64 |
+ case 1 |
|
78 | 65 |
performance = svmtrain(svmlabel, svmdata, svmargs); |
79 | 66 |
|
80 | 67 |
minPerformance = min(minPerformance,performance); |
81 | 68 |
maxPerformance = max(maxPerformance,performance); |
82 | 69 |
|
83 | 70 |
decodePerformance = [decodePerformance; performance]; |
71 |
+ case 2 |
|
72 |
+ newsvmopt = killCrossvalOpt(svmargs); |
|
73 |
+ |
|
74 |
+ model = svmtrain(svmlabel,svmdata,newsvmopt); |
|
75 |
+ classperformance = []; |
|
76 |
+ for class = unique(svmlabel)'; |
|
77 |
+% assignin('base','uniquelabel',unique(svmlabel)); |
|
78 |
+% assignin('base','class',class); |
|
79 |
+% assignin('base','svmlabel',svmlabel); |
|
80 |
+ filterindex = find(class == svmlabel); |
|
81 |
+ testing_label = svmlabel(filterindex) |
|
82 |
+ testing_data = svmdata(filterindex) |
|
83 |
+ [plabel accuracy dvalue] = svmpredict(testing_label,testing_data,model,'') |
|
84 |
+% assignin('base','accuracy',accuracy); |
|
85 |
+ classperformance = [classperformance accuracy(1)]; |
|
86 |
+ end |
|
87 |
+ decodePerformance = [decodePerformance; classperformance]; |
|
88 |
+ |
|
84 | 89 |
end |
85 | 90 |
|
86 | 91 |
end |
... | ... |
@@ -93,3 +98,17 @@ maxPerformance = -inf; |
93 | 98 |
outputStruct.maxPerformance = maxPerformance; |
94 | 99 |
end |
95 | 100 |
|
101 |
+function opts = killCrossvalOpt(svmopt) |
|
102 |
+opts = ''; |
|
103 |
+idx1 = 1; |
|
104 |
+for idx2=strfind(svmopt,' -') |
|
105 |
+ if idx1 ~= strfind(svmopt,' -v') |
|
106 |
+ opts = strcat(opts,svmopt(idx1:idx2)); |
|
107 |
+ end |
|
108 |
+ idx1=idx2; |
|
109 |
+ if idx2==max(strfind(svmopt,' -')) |
|
110 |
+ opts = strcat(opts,svmopt(idx2:end)); |
|
111 |
+ end |
|
112 |
+end |
|
113 |
+end |
|
114 |
+ |
... | ... |
@@ -12,10 +10,9 @@ switch nargin |
12 | 10 |
error('spmtoolbox:SVMCrossVal:arginError','Please Specify action and parameter model'); |
13 | 11 |
end |
14 | 12 |
|
15 |
- |
|
16 | 13 |
% common params |
17 | 14 |
calculateParams = struct; |
18 |
- calculateParams.smoothed = getDouble(paramModel.txtSmoothed); |
|
15 |
+ calculateParams.smoothed = getChkValue(paramModel.chkSmoothed); |
|
19 | 16 |
|
20 | 17 |
calculateParams.frameShiftStart = getDouble(paramModel.txtFrameShiftStart); % -20; |
21 | 18 |
calculateParams.frameShiftEnd = getDouble(paramModel.txtFrameShiftEnd); %15; |
... | ... |
@@ -61,7 +61,7 @@ DEFAULT.svmoptstring = '-s 0 -t 0 -v 6 -c 1'; |
61 | 61 |
set(model.subjectSelector,'BackgroundColor','w'); |
62 | 62 |
|
63 | 63 |
createLabel(pSubject,[0.68*frameWidth firstRow*2 0.25*frameWidth controlElementHeight],'Smooth Data'); |
64 |
- model.txtSmoothed = uicontrol(pSubject,'Style','checkbox','Position',[0.68*frameWidth firstRow 0.25*frameWidth controlElementHeight],'Value',DEFAULT.smoothed); |
|
64 |
+ model.chkSmoothed = uicontrol(pSubject,'Style','checkbox','Position',[0.68*frameWidth firstRow 0.25*frameWidth controlElementHeight],'Value',DEFAULT.smoothed); |
|
65 | 65 |
|
66 | 66 |
createLabel(pSubject,[0.68*frameWidth firstRow*4 0.25*frameWidth controlElementHeight],'Crossvalidation'); |
67 | 67 |
model.txtMultisubject = createTextField(pSubject,[0.68*frameWidth firstRow*3 0.25*frameWidth controlElementHeight],DEFAULT.multisubject); |
68 | 68 |