SVMCrossVal toolbox init
Christoph Budziszewski

Christoph Budziszewski committed on 2008-12-17 13:45:29
Showing 43 changed files with 5861 additions and 0 deletions.


git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@90 83ab2cfd-5345-466c-8aeb-2b2739fb922d
... ...
@@ -0,0 +1,26 @@
1
+function map = LabelMap(label,value)
2
+%LabelMap(labelCellList,valueCellList) maps labels to class values suitable for
3
+%SVM
4
+
5
+if nargin == 2
6
+    if ~ (iscell(label) && iscell(value))
7
+        error('LabelMap:Constructor:argsNoCell','Arguments must be cell arrays; plain vectors are not supported yet.');
8
+    end   
9
+    if(any(size(label) ~= size(value)))
10
+        error('LabelMap:Constructor:sizeDontMatch','Label List and Value List must be the same size!');
11
+    end
12
+    
13
+    map.labelToValue = java.util.HashMap;
14
+    map.valueToLabel = java.util.HashMap;
15
+    
16
+    for i = 1:max(size(label)) % cell array may be 1-by-n or n-by-1; linear indexing works either way
17
+       map.labelToValue.put(label{i},value{i}); 
18
+       map.valueToLabel.put(value{i},label{i});
19
+    end
20
+    
21
+    map = class(map,'LabelMap');
22
+else
23
+    error('LabelMap:Constructor:noArgs','Sorry, default constructor not supported yet!');
24
+end
25
+
26
+end
0 27
\ No newline at end of file
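For orientation, a minimal usage sketch (editorial, not part of the commit) of the LabelMap class added above, using the getLabel/getValue accessors that follow in the next hunks; the labels mirror those used later in classify.m, the numeric values are illustrative:

    lm  = LabelMap({'<','>'}, {-1, 1});   % construct the label-to-class-value mapping
    val = getValue(lm, '<');              % -1, a class value suitable for libsvm
    lab = getLabel(lm, 1);                % '>'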
... ...
@@ -0,0 +1,6 @@
1
+function label = getLabel(mapping,classValue)
2
+if mapping.valueToLabel.containsKey(classValue)
3
+    label = mapping.valueToLabel.get(classValue);   
4
+else
5
+    error('LabelMap:getLabel:noSuchValue','This mapping does not contain the value %d.',classValue);
6
+end
0 7
\ No newline at end of file
... ...
@@ -0,0 +1,6 @@
1
+function value = getValue(mapping,classLabel)
2
+if mapping.labelToValue.containsKey(classLabel)
3
+    value = mapping.labelToValue.get(classLabel);   
4
+else
5
+    error('LabelMap:getValue:noSuchLabel','This mapping does not contain the label ''%s''.',classLabel);
6
+end
0 7
\ No newline at end of file
... ...
@@ -0,0 +1,147 @@
1
+function m = SubjectRoiMapping(argv)
2
+%SUBJECTROIMAPPING Subject to ROI to Coordinate Mapping Class Constructor
3
+%   m = SUBJECTROIMAPPING() creates a predefined ROI Coordinate Mapping.
4
+%   normally called without any arguments
5
+
6
+if nargin == 0
7
+    m.subject{1} ='AI020';
8
+    m.subject{2} ='BD001';
9
+    m.subject{3} ='HG027';
10
+    m.subject{4} ='IK011';
11
+    m.subject{5} ='JZ006'; % good subject
12
+    m.subject{6} ='LB001';
13
+    m.subject{7} ='SW007';
14
+    m.subject{8} ='VW005';
15
+
16
+    m.subjectNameMap = java.util.HashMap;
17
+    for subj = 1:size(m.subject,2)
18
+        m.subjectNameMap.put(m.subject{subj},subj);
19
+    end    
20
+    
21
+    
22
+    
23
+    m.roi_name{1}  ='SPL l'; % <- parietal cortex, left
24
+    m.roi_name{2}  ='SPL r'; % <- parietal cortex, right
25
+    m.roi_name{3}  ='PMd l'; 
26
+    m.roi_name{4}  ='PMd r'; 
27
+    m.roi_name{5}  ='IPSa l'; 
28
+    m.roi_name{6}  ='IPSa r'; 
29
+    m.roi_name{7}  ='SMA'; 
30
+    m.roi_name{8}  ='DLPFC'; 
31
+    m.roi_name{9}  ='V1 l'; 
32
+    m.roi_name{10} ='V1 r'; 
33
+    m.roi_name{11} ='M1 l'; % <- motor cortex, left
34
+    m.roi_name{12} ='M1 r'; % <- motor cortex, right
35
+    
36
+    m.roiNameMap = java.util.HashMap;
37
+    for roi = 1:size(m.roi_name,2)
38
+        m.roiNameMap.put(m.roi_name{roi},roi);
39
+    end
40
+
41
+    
42
+    % Coordinates of each subject A for ROI B: coordinate{A}(B,[x y z] in mm)
43
+    m.coordinate{1}(1,:)  = [-18, -78, 53];
44
+    m.coordinate{1}(2,:)  = [12, -69, 46];
45
+    m.coordinate{1}(3,:)  = [-21, -12, 49];
46
+    m.coordinate{1}(4,:)  = [30, -12, 53];
47
+    m.coordinate{1}(5,:)  = [-30, -51, 39];
48
+    m.coordinate{1}(6,:)  = [ 33, -60, 49];
49
+    m.coordinate{1}(7,:)  = [ -9, 6, 46];
50
+    m.coordinate{1}(8,:)  = [-27 27 48];
51
+    m.coordinate{1}(9,:)  = [-6, -90, -7];
52
+    m.coordinate{1}(10,:) = [12, -90, -4];
53
+    m.coordinate{1}(11,:) = [-57, -24, 49];
54
+    m.coordinate{1}(12,:) = [42, -24, 60];
55
+    m.coordinate{2}(1,:)  = [-9, -72, 56]; 
56
+    m.coordinate{2}(2,:)  = [15, -72, 60]; 
57
+    m.coordinate{2}(3,:)  = [-30, -9, 53]; 
58
+    m.coordinate{2}(4,:)  = [ 30, -9, 49]; 
59
+    m.coordinate{2}(5,:)  = [-42 -36 39]; 
60
+    m.coordinate{2}(6,:)  = [30 -36 42]; 
61
+    m.coordinate{2}(7,:)  = [ -3, 6, 53];
62
+    m.coordinate{2}(8,:)  = [-27 30 28];
63
+    m.coordinate{2}(9,:)  = [-6, -81, -7]; 
64
+    m.coordinate{2}(10,:) = [9, -78, -7];
65
+    m.coordinate{2}(11,:) = [-51, -24, 60];
66
+    m.coordinate{2}(12,:) = [48, -21, 63]; 
67
+    m.coordinate{3}(1,:)  = [-15, -72, 60];
68
+    m.coordinate{3}(2,:)  = [15, -66, 63];
69
+    m.coordinate{3}(3,:)  = [-27, -12, 56];
70
+    m.coordinate{3}(4,:)  = [24 -15 53];
71
+    m.coordinate{3}(5,:)  = [-36 -36 42];
72
+    m.coordinate{3}(6,:)  = [30 -39 35];
73
+    m.coordinate{3}(7,:)  = [-9, 3, 53]; 
74
+    m.coordinate{3}(8,:)  = [-30 30 28];
75
+    m.coordinate{3}(9,:)  = [-3, -90, 4];
76
+    m.coordinate{3}(10,:) = [15, -99, 14];
77
+    m.coordinate{3}(11,:) = [-27, -27, 74];
78
+    m.coordinate{3}(12,:) = [36, -27, 70]; 
79
+    m.coordinate{4}(1,:)  = [-21, -69, 63]; 
80
+    m.coordinate{4}(2,:)  = [21, -69, 63];
81
+    m.coordinate{4}(3,:)  = [-33 -12 53];
82
+    m.coordinate{4}(4,:)  = [12 -9 60];
83
+    m.coordinate{4}(5,:)  = [-33 -35 46];
84
+    m.coordinate{4}(6,:)  = [42 -36 39];
85
+    m.coordinate{4}(7,:)  = [-3 0 49];
86
+    m.coordinate{4}(8,:)  = [-33 33 28];
87
+    m.coordinate{4}(9,:)  = [-3, -90, -7];
88
+    m.coordinate{4}(10,:) = [9, -81, -7];
89
+    m.coordinate{4}(11,:) = [-39, -27, 53];
90
+    m.coordinate{4}(12,:) = [51, -24, 60];
91
+    m.coordinate{5}(1,:)  = [-12 -66 63];
92
+    m.coordinate{5}(2,:)  = [12, -75, 60];
93
+    m.coordinate{5}(3,:)  = [-24, -12, 53];
94
+    m.coordinate{5}(4,:)  = [27, -9, 60]; 
95
+    m.coordinate{5}(5,:)  = [-42 -42 35]; 
96
+    m.coordinate{5}(6,:)  = [33 -48 35];
97
+    m.coordinate{5}(7,:)  = [ -3, 0, 49];
98
+    m.coordinate{5}(8,:)  = [-36 33 28];
99
+    m.coordinate{5}(9,:)  = [-15, -93, -4];
100
+    m.coordinate{5}(10,:) = [15, -90, 4]; 
101
+    m.coordinate{5}(11,:) = [-39, -33, 67];
102
+    m.coordinate{5}(12,:) = [27, -18, 74];
103
+    m.coordinate{6}(1,:)  = [-21, -69, 60];
104
+    m.coordinate{6}(2,:)  = [9, -72, 63];
105
+    m.coordinate{6}(3,:)  = [-24 -12 53];
106
+    m.coordinate{6}(4,:)  = [32 -12 56]; 
107
+    m.coordinate{6}(5,:)  = [-36 -39 35];
108
+    m.coordinate{6}(6,:)  = [42 -33 46]; 
109
+    m.coordinate{6}(7,:)  = [-6 3 49]; 
110
+    m.coordinate{6}(8,:)  = [-36 33 28];
111
+    m.coordinate{6}(9,:)  = [-12, -99, 0];
112
+    m.coordinate{6}(10,:) = [9, -96, -7];
113
+    m.coordinate{6}(11,:) = [-48, -27, 60];
114
+    m.coordinate{6}(12,:) = [33, -33, 60];
115
+    m.coordinate{7}(1,:)  = [-21, -60, 56]; 
116
+    m.coordinate{7}(2,:)  = [12, -69, 60]; 
117
+    m.coordinate{7}(3,:)  = [-24, -12, 49];
118
+    m.coordinate{7}(4,:)  = [24, -6, 49]; 
119
+    m.coordinate{7}(5,:)  = [-33 -45 46]; 
120
+    m.coordinate{7}(6,:)  = [30, -51, 49];
121
+    m.coordinate{7}(7,:)  = [0, 9, 42]; 
122
+    m.coordinate{7}(8,:)  = [-30 36 35]; 
123
+    m.coordinate{7}(9,:)  = [-3, -84, -4];
124
+    m.coordinate{7}(10,:) = [18, -87, -7];
125
+    m.coordinate{7}(11,:) = [-36, -30, 63]; 
126
+    m.coordinate{7}(12,:) = [42, -27, 60];
127
+    m.coordinate{8}(1,:)  = [-27, -63, 53];
128
+    m.coordinate{8}(2,:)  = [18, -66, 56];
129
+    m.coordinate{8}(3,:)  = [-21, -6, 56];
130
+    m.coordinate{8}(4,:)  = [27 -6 53]; 
131
+    m.coordinate{8}(5,:)  = [-36, -51, 49];
132
+    m.coordinate{8}(6,:)  = [45, -39, 53];
133
+    m.coordinate{8}(7,:)  = [-9, 9, 53];
134
+    m.coordinate{8}(8,:)  = [-36 24 25]; 
135
+    m.coordinate{8}(9,:)  = [0, -90, 4]; 
136
+    m.coordinate{8}(10,:) = [0, -90, 4];
137
+    m.coordinate{8}(11,:) = [-42, -27, 67]; 
138
+    m.coordinate{8}(12,:) = [51, -27, 63]; 
139
+
140
+    m = class(m,'SubjectRoiMapping');
141
+    
142
+elseif isa(argv,'SubjectRoiMapping') % copy
143
+   m = argv;
144
+
145
+else
146
+error('SubjectRoiMapping:Constructor:NoSuchConstructor','There is no constructor matching your argv');
147
+end
0 148
\ No newline at end of file
... ...
@@ -0,0 +1,18 @@
1
+function coord = getCoordinate(mapping,subject,roi)
2
+% getCoordinate(SubjectRoiMapping,subjectID,roiID) returns the coordinate
3
+% for the given subject and the given roi. Both subjectID and roiID can
4
+% either be a valid name (see get[Subject|Roi]NameCellList(mapping)) or
5
+% the corresponding numerical ID.
6
+
7
+if ischar(subject) && ischar(roi)
8
+    coord = getCoordinate(mapping,getSubjectID(mapping,subject),getRoiID(mapping,roi));
9
+elseif isnumeric(subject) && ischar(roi)
10
+    coord = getCoordinate(mapping,subject,getRoiID(mapping,roi));
11
+elseif ischar(subject) && isnumeric(roi)
12
+    coord = getCoordinate(mapping,getSubjectID(mapping,subject),roi);
13
+elseif isnumeric(subject) && isnumeric(roi)
14
+    coord = mapping.coordinate{subject}(roi,:);
15
+else
16
+    error('SubjectRoiMapping:getCoordinate:BadArguments','Subject and ROI must be valid identifiers (either char or numeric).');
17
+end
18
+
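A brief illustration (editorial, not part of the commit) of how getCoordinate accepts names and numeric IDs interchangeably; the names come from the constructor above:

    map = SubjectRoiMapping;
    c1  = getCoordinate(map, 'JZ006', 'SPL l');   % by subject name and ROI name
    c2  = getCoordinate(map, 5, 1);               % the same coordinate via numeric IDs
    % both return the 1x3 position [x y z] in mm, here [-12 -66 63]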
... ...
@@ -0,0 +1,5 @@
1
+function roinames = getRoiNameCellList(mapping)
2
+% getRoiNameCellList(SubjectRoiMapping) returns a cell array of ROI
3
+% identifiers
4
+    roinames = mapping.roi_name;
5
+end
... ...
@@ -0,0 +1,5 @@
1
+function namelist = getSubjectCellList(mapping)
2
+% getSubjectCellList(SubjectRoiMapping) returns a cell array of subject
3
+% identifiers
4
+    namelist = mapping.subject;
5
+end
0 6
\ No newline at end of file
... ...
@@ -0,0 +1,7 @@
1
+function id = getRoiID(mapping,roiName)
2
+if mapping.roiNameMap.containsKey(roiName)
3
+    id = mapping.roiNameMap.get(roiName);   
4
+else
5
+     error('SubjectRoiMapping:getRoiID:noSuchName','this Mapping does not contain a ROI ''%s''',roiName);
6
+end
7
+end
0 8
\ No newline at end of file
... ...
@@ -0,0 +1,7 @@
1
+function id = getSubjectID(mapping,subjectName)
2
+if mapping.subjectNameMap.containsKey(subjectName)
3
+    id = mapping.subjectNameMap.get(subjectName);   
4
+else
5
+     error('SubjectRoiMapping:getSubjectID:noSuchName','this Mapping does not contain a Name ''%s''',subjectName);
6
+end
7
+end
0 8
\ No newline at end of file
... ...
@@ -0,0 +1,20 @@
1
+function vValue = VoxelValueAtTimepoint (coordinate, timepoint)
2
+% Returns the value of a single voxel at a single coordinate and time point.
3
+
4
+if(size(coordinate,2)>1)
5
+    error('VoxelValueAtTimepoint:CoordinateError','only single Coordinate permitted.');
6
+end
7
+
8
+imageNumber = timePointToImageNumber(timepoint, 's');
9
+V           = evalin('base','SPM.xY.VY'); % Memory Mapped Images
10
+center      = round(inv(V(imageNumber).mat)*[coordinate; 1]);
11
+
12
+x           = center(1,1);
13
+y           = center(2,1);
14
+z           = center(3,1);
15
+
16
+vValue      = spm_sample_vol(V(imageNumber), x, y, z, 0);
17
+
18
+end
19
+
20
+
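A hedged usage sketch (not part of the commit): VoxelValueAtTimepoint reads SPM.xY.VY from the base workspace and relies on a timePointToImageNumber helper that is not included in this hunk, so the call below only works once both are available; the coordinate and time point are purely illustrative.

    load('SPM.mat');                        % SPM must exist in the base workspace
    coord  = [-12; -66; 63];                % single [x; y; z] coordinate in mm (column vector)
    vValue = VoxelValueAtTimepoint(coord, 0);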
... ...
@@ -0,0 +1,278 @@
1
+% function [decodePerformance rawTimecourse ] = calculateDecodePerformance(des,timeLineStart, timeLineEnd, decodeDuration, svmargs, conditionList, sessionList, voxelList, classList, labelMap,normalize)
2
+function outputStruct = calculateDecodePerformance(inputStruct)
3
+
4
+addpath 'libsvm-mat-2.88-1';
5
+
6
+outputStruct = struct;
7
+
8
+des             = inputStruct.des;
9
+timeLineStart   = inputStruct.frameShiftStart;
10
+timeLineEnd     = inputStruct.frameShiftEnd;
11
+decodeDuration  = inputStruct.decodeDuration;
12
+svmargs         = inputStruct.svmargs;
13
+sessionList     = inputStruct.sessionList;
14
+voxelList       = inputStruct.voxelList;
15
+% classList       = inputStruct.classList;
16
+% labelMap        = inputStruct.labelMap;
17
+% normalize       = inputStruct.normalize;
18
+globalStart     = inputStruct.psthStart;
19
+globalEnd       = inputStruct.psthEnd;
20
+baselineStart   = inputStruct.baselineStart;
21
+baselineEnd     = inputStruct.baselineEnd;
22
+eventList       = inputStruct.eventList;
23
+
24
+
25
+minPerformance = inf;
26
+maxPerformance = -inf;
27
+
28
+
29
+        
30
+        % Compute the PSTH timeline for each voxel.
31
+        %   Apply the time shift along the PSTH timeline.
32
+        % The PSTH timeline runs from -25 to +15 relative to RES onset.
33
+        
34
+%         eventList       = [9,11,13;10,12,14];
35
+%         globalStart     = -25;
36
+%         globalEnd       = 15;
37
+%         baselineStart   = -22;
38
+%         baselineEnd     = -20;
39
+        
40
+        
41
+        for voxel = 1:size(voxelList,1)  % [[x;x],[y;y],[z;z]]
42
+                extr  = calculateImageData(voxelList(voxel,:),des);
43
+                rawdata=cell2mat({extr.mean}); % Raw Data
44
+                pst{voxel}  = calculatePST(des,globalStart,baselineStart,baselineEnd,globalEnd,eventList,rawdata,sessionList);
45
+        end
46
+
47
+        decodePerformance = [];
48
+
49
+        for timeShift   = timeLineStart:1:timeLineEnd
50
+            frameStart  = floor(-globalStart+1+timeShift - 0.5*decodeDuration);
51
+            frameEnd    = min(ceil(frameStart+decodeDuration + 0.5*decodeDuration),-globalStart+globalEnd);
52
+            
53
+            tmp =[];
54
+            anyvoxel = 1;
55
+            for label = 1:size(pst{1,anyvoxel},2) 
56
+                for dp = 1:size(pst{1,anyvoxel}{1,label},1) % data point
57
+                row = label;
58
+                    for voxel = 1:size(pst,2)
59
+                        row = [row, pst{1,voxel}{1,label}(dp,frameStart:frameEnd)]; % label,value,value
60
+                    end
61
+                tmp  = [tmp; row];
62
+                end
63
+            end 
64
+        
65
+            svmdata      = tmp(:,2:size(tmp,2));
66
+            svmlabel     = tmp(:,1);
67
+            performance  = svmtrain(svmlabel, svmdata, svmargs);
68
+
69
+            minPerformance = min(minPerformance,performance);
70
+            maxPerformance = max(maxPerformance,performance);
71
+
72
+            decodePerformance = [decodePerformance; performance];
73
+        end
74
+        
75
+        outputStruct.decodePerformance  = decodePerformance;
76
+        outputStruct.svmdata            = svmdata;
77
+        outputStruct.svmlabel           = svmlabel;
78
+        outputStruct.rawTimeCourse      = pst;
79
+        outputStruct.minPerformance     = minPerformance;
80
+        outputStruct.maxPerformance     = maxPerformance;
81
+
82
+% display(sprintf('Min CrossVal Accuracy: %g%% \t Max CrossVal Accuracy: %g%%',minPerformance,maxPerformance));
83
+end
84
+
85
+
86
+function extr = calculateImageData(voxelList,des)
87
+
88
+dtype='PSTH';
89
+
90
+switch dtype 
91
+    case 'PSTH'
92
+        V=des.xY.VY;
93
+    case 'betas'
94
+        V=des.Vbeta;
95
+end;
96
+%   for z=1:length(V) % Change Drive Letter!
97
+%       V(z).fname(1)='E';
98
+%   end;
99
+
100
+% rad = 0; % one voxel
101
+% opt = 1; % xyz coordinates [mm]
102
+
103
+
104
+vox = voxelList;
105
+nRoi = size(vox,1);
106
+
107
+nImg = numel(V);
108
+
109
+for k=1:nImg
110
+	extr(k) = struct(...
111
+        'val',   repmat(NaN, [1 nRoi]),...
112
+		'mean',  repmat(NaN, [1 nRoi]),...
113
+		'sum',   repmat(NaN, [1 nRoi]),...
114
+		'nvx',   repmat(NaN, [1 nRoi]),...
115
+		'posmm', repmat(NaN, [3 nRoi]),...
116
+		'posvx', repmat(NaN, [3 nRoi]));
117
+
118
+    roicenter = round(inv(V(k).mat)*[vox, ones(nRoi,1)]');
119
+
120
+	for l = 1:nRoi
121
+
122
+%         if rad==0
123
+            x = roicenter(1,l);
124
+            y = roicenter(2,l);
125
+            z = roicenter(3,l);
126
+%         else
127
+%             tmp = spm_imatrix(V(k).mat);
128
+%             vdim = tmp(7:9);
129
+%             vxrad = ceil((rad*ones(1,3))./(ones(nRoi,1)*vdim))';
130
+%             [x y z] = ndgrid(-vxrad(1,l):sign(vdim(1)):vxrad(1,l), ...
131
+%                       -vxrad(2,l):sign(vdim(2)):vxrad(2,l), ...
132
+%                       -vxrad(3,l):sign(vdim(3)):vxrad(3,l));
133
+%             sel = (x./vxrad(1,l)).^2 + (y./vxrad(2,l)).^2 + ...
134
+%                   (z./vxrad(3,l)).^2 <= 1;
135
+%             x = roicenter(1,l)+x(sel(:));
136
+%             y = roicenter(2,l)+y(sel(:));
137
+%             z = roicenter(3,l)+z(sel(:));
138
+%         end;
139
+		dat                 = spm_sample_vol(V(k), x, y, z,0);
140
+		[maxv maxi]         = max(dat);
141
+		tmp                 = V(k).mat*[x(maxi); y(maxi); z(maxi);1]; % Max Pos
142
+		extr(k).val(l)      = maxv;
143
+		extr(k).sum(l)      = sum(dat);
144
+		extr(k).mean(l)     = nanmean(dat);
145
+        extr(k).nvx(l)      = numel(dat);
146
+		extr(k).posmm(:,l)  = tmp(1:3);
147
+		extr(k).posvx(:,l)  = [x(maxi); y(maxi); z(maxi)]; % Max Pos
148
+	end;
149
+
150
+end;
151
+end
152
+
153
+% disp(sprintf('Extracted at %.1f %.1f %.1f [xyz(mm)], average of %i voxel(s) [%.1fmm radius Sphere]',vox,length(x),rad));
154
+
155
+function pst = calculatePST(des,globalStart,baselineStart,baselineEnd,globalEnd,eventList,data,sessionList)
156
+    bstart          = baselineStart;
157
+    bend            = baselineEnd;
158
+    edur            = 12;
159
+    pre             =  globalStart;
160
+    post            =  globalEnd;
161
+    res             = 1;
162
+
163
+    normz           = 'file';
164
+    pm              = 0;
165
+
166
+    lsess           = getNumberOfScans(des);
167
+    nSessions       = getNumberOfSessions(des);
168
+    tr              = 2;
169
+
170
+    [evntrow evntcol]=size(eventList);
171
+    
172
+
173
+    hsec=str2num(des.xsDes.High_pass_Filter(8:end-3)); % Highpass filter [sec] from SPM.mat
174
+
175
+    if strcmp(des.xBF.UNITS,'secs')
176
+        unitsecs=1;
177
+    end;
178
+
179
+    nScansPerSession=getNumberOfScans(des);
180
+    %stime=[0:tr:max(nScansPerSession)*tr+post-tr]; % Stimulus time for raw data plot
181
+    stime=0:tr:max(nScansPerSession)*tr+round(post/tr)*tr-tr; % Stimulus time for raw data plot
182
+
183
+
184
+
185
+    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
186
+    % RUN
187
+    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
188
+
189
+
190
+    % Digital Highpass
191
+    Rp=0.5;
192
+    Rs=20;
193
+    NO=1;
194
+    Wp=1/((1/2/tr)/(1/hsec));
195
+    [B, A] = ellip(NO,Rp,Rs,Wp,'high');
196
+
197
+    sdata(1:max(nScansPerSession)+round(post/tr),1:nSessions)=nan; % Open Data Matrix
198
+    for z=1:nSessions % Fill Data Matrix sessionwise
199
+        sdata(1:nScansPerSession(z),z)=data(sum(nScansPerSession(1:z))-nScansPerSession(z)+1:sum(nScansPerSession(1:z)))';
200
+    end;
201
+%         usdata=sdata; % Keep unfiltered data
202
+
203
+    sdatamean=nanmean(nanmean(sdata(:,:)));
204
+    for z=1:nSessions
205
+%             X(:,z)=[1:1:max(nScansPerSession)]'; % #Volume
206
+        sdata(1:nScansPerSession(z),z)=filtfilt(B,A,sdata(1:nScansPerSession(z),z)); %Filter Data (Highpass)
207
+    end;
208
+    sdata=sdata+sdatamean;
209
+
210
+
211
+    %%%%Parametric Modulation Modus%%%%
212
+    if pm %Find Parameters for Event of Interest
213
+        [imods modss mods erow evntrow eventList] = getParametricMappingEvents(eventList,evntrow,des,pmf);
214
+    end;
215
+    %%%%PM%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
216
+
217
+
218
+    for zr=1:evntrow
219
+        n{zr}=0;
220
+        nn{zr}=0; 
221
+        nnn{zr}=0;
222
+        sstart{zr}=1;
223
+    end;
224
+
225
+
226
+    sesst0=0; 
227
+    for sessionID=sessionList
228
+        if sessionID>1
229
+            sesst0(sessionID)=sum(lsess(1:sessionID-1))*tr;  
230
+        end;
231
+        for zr=1:evntrow  %LABEL NUMBER, EVENT GROUP
232
+            sstart{zr}=n{zr}+1;
233
+            for ze=1:evntcol % EVENT INDEX in EventList
234
+                if ze==1 || (ze>1 && eventList(zr,ze)~=eventList(zr,ze-1))
235
+                    for zz=1:length(des.Sess(sessionID).U(eventList(zr,ze)).ons) % EVENT REPETITION NUMBER
236
+                        if ~unitsecs
237
+                            des.Sess(sessionID).U(eventList(zr,ze)).ons(zz)=(des.Sess(sessionID).U(eventList(zr,ze)).ons(zz)-1)*tr;
238
+                            des.Sess(sessionID).U(eventList(zr,ze)).dur(zz)=(des.Sess(sessionID).U(eventList(zr,ze)).dur(zz)-1)*tr;
239
+                        end;
240
+
241
+                        nnn{zr}=nnn{zr}+1; % INFO for rawdataplot start
242
+                        if des.Sess(sessionID).U(eventList(zr,ze)).dur(zz)<edur
243
+                            mev{zr}(nnn{zr},1:2)=[des.Sess(sessionID).U(eventList(zr,ze)).ons(zz)+sesst0(sessionID) edur]; % modeled event [onset length]
244
+                        else
245
+                            mev{zr}(nnn{zr},1:2)=[des.Sess(sessionID).U(eventList(zr,ze)).ons(zz)+sesst0(sessionID) des.Sess(sessionID).U(eventList(zr,ze)).dur(zz)];
246
+                        end; % INFO for rawdataplot end
247
+
248
+                        n{zr}=n{zr}+1;
249
+                        pst{zr}(n{zr},:)=interp1(stime,sdata(:,sessionID),[des.Sess(sessionID).U(eventList(zr,ze)).ons(zz)+pre:res:des.Sess(sessionID).U(eventList(zr,ze)).ons(zz)+post],'linear');
250
+                        if strcmp(normz,'epoc')
251
+                            bline=nanmean(pst{zr}(n{zr},round(-pre/res+(bstart)/res+1):round(-pre/res+(bend)/res+1)));
252
+                            if isnan(bline)
253
+                                pst{zr}(n{zr},1:-pre/res+post/res+1)=nan;
254
+                            else
255
+%                                     nn{zr}=nn{zr}+1;
256
+                                pst{zr}(n{zr},:)=(pst{zr}(n{zr},:)-bline)/bline*100; % 'epoch-based' normalization
257
+                            end;
258
+                        end;
259
+                    end;
260
+                end;
261
+            end;
262
+            if ~strcmp(normz,'epoc')
263
+                bline(zr)=nanmean(nanmean(pst{zr}(sstart{zr}:n{zr},-pre/res+(bstart)/res+1:-pre/res+(bend)/res+1)));
264
+                bstd(zr)=nanmean(nanstd(pst{zr}(sstart{zr}:n{zr},-pre/res+(bstart)/res+1:-pre/res+(bend)/res+1)));
265
+                nn{zr}=n{zr};
266
+            end;
267
+        end;
268
+        if strcmp(normz,'filz')
269
+            for zr=1:evntrow
270
+                pst{zr}(sstart{zr}:n{zr},:)=(pst{zr}(sstart{zr}:n{zr},:)-mean(bline))/mean(bstd); % session-based z-score normalization
271
+            end;
272
+        elseif strcmp(normz,'file')
273
+            for zr=1:evntrow
274
+                pst{zr}(sstart{zr}:n{zr},:)=(pst{zr}(sstart{zr}:n{zr},:)-mean(bline))/mean(bline)*100; % session-based normalization
275
+            end;
276
+        end;
277
+    end;
278
+end
... ...
@@ -0,0 +1,138 @@
1
+function classify(action)
2
+
3
+if ~exist('action','var')
4
+    action='no action';
5
+end
6
+
7
+    switch(action)
8
+    case 'clear'
9
+        evalin('base','clear map lm SPM classList dataTimeLine decodeTable labelTimeLine svmopts trialProtocol voxelList xTimeEnd xTimeStart xTimeWindow');
10
+      
11
+    case 'decode'
12
+    
13
+
14
+
15
+        display('loading SPM.mat');
16
+        SubjectID = 'JZ006';
17
+%         SubjectID = 'AI020';
18
+%         SubjectID = 'HG027';
19
+        spm = load(fullfile('D:\Analyze\Choice\24pilot',SubjectID,'results\SPM.mat'));
20
+
21
+        display('done.');
22
+
23
+
24
+
25
+
26
+        map = SubjectRoiMapping;
27
+
28
+        voxelList  = [...
29
+                      getCoordinate(map,SubjectID,'SPL l')+[0,0,0];...
30
+                          getCoordinate(map,SubjectID,'SPL l')+[1,0,0];...
31
+                          getCoordinate(map,SubjectID,'SPL l')+[-1,0,0];...
32
+                          getCoordinate(map,SubjectID,'SPL l')+[0,1,0];...
33
+                          getCoordinate(map,SubjectID,'SPL l')+[0,-1,0];...
34
+                          getCoordinate(map,SubjectID,'SPL l')+[0,0,1];...
35
+                          getCoordinate(map,SubjectID,'SPL l')+[0,0,-1];...
36
+                      getCoordinate(map,SubjectID,'SPL r')+[0,0,0];...
37
+                          getCoordinate(map,SubjectID,'SPL r')+[1,0,0];...
38
+                          getCoordinate(map,SubjectID,'SPL r')+[-1,0,0];...
39
+                          getCoordinate(map,SubjectID,'SPL r')+[0,1,0];...
40
+                          getCoordinate(map,SubjectID,'SPL r')+[0,-1,0];...
41
+                          getCoordinate(map,SubjectID,'SPL r')+[0,0,1];...
42
+                          getCoordinate(map,SubjectID,'SPL r')+[0,0,-1];...
43
+                      getCoordinate(map,SubjectID,'M1 r')+[0,0,0];...
44
+                      getCoordinate(map,SubjectID,'M1 l')+[0,0,0];...
45
+                      ];
46
+        
47
+        
48
+        params = struct;
49
+        params.nClasses = 2;
50
+
51
+        assignin('base','params',params);
52
+        %% calculate
53
+        display('calculating cross-validation performance time-shift');
54
+        calculateParams  = struct;
55
+        
56
+        calculateParams.des             = spm.SPM;
57
+        calculateParams.frameShiftStart = -20;
58
+        calculateParams.frameShiftEnd   = 15;
59
+        calculateParams.decodeDuration  = 1;
60
+        calculateParams.svmargs         = '-t 0 -s 0 -v 6';
61
+        calculateParams.sessionList     = 1:3;
62
+        calculateParams.voxelList       = voxelList;
63
+        calculateParams.classList       = {'<','>'};
64
+        calculateParams.labelMap        = LabelMap({'<','>','<+<','>+>','<+>','>+<'},{-2,-1,1,2,3,4});
65
+        calculateParams.psthStart       = -25;
66
+        calculateParams.psthEnd         = 20;
67
+        calculateParams.baselineStart   = -22;
68
+        calculateParams.baselineEnd     = -20;
69
+        calculateParams.eventList       = [9,11,13; 10,12,14];
70
+        
71
+        assignin('base','calculateParams',calculateParams);
72
+        
73
+%         [decodeTable rawTimeCourse] = calculateDecodePerformance(spm,params.frameShiftStart,params.frameShiftEnd,params.xTimeWindow,params.svmopts,1:4,params.sessionList,params.voxelList,params.classList,params.labelMap,params.normalize);
74
+        decode = calculateDecodePerformance(calculateParams);
75
+        display(sprintf('Min CrossVal Accuracy: %g%% \t Max CrossVal Accuracy: %g%%',decode.minPerformance,decode.maxPerformance));
76
+        
77
+        assignin('base','decode',decode);
78
+
79
+        display('Finished calculations.');
80
+        display('Plotting.');
81
+
82
+        plotParams = struct;
83
+        plotParams.psthStart = calculateParams.psthStart;
84
+        plotParams.psthEnd   = calculateParams.psthEnd;
85
+        plotParams.nClasses  = length(calculateParams.classList);
86
+        plotParams.frameShiftStart   = calculateParams.frameShiftStart;
87
+        plotParams.frameShiftEnd     = calculateParams.frameShiftEnd;
88
+        plotParams.decodePerformance = decode.decodePerformance;
89
+        plotParams.rawTimeCourse     = decode.rawTimeCourse;
90
+        plotParams.SubjectID         = SubjectID;
91
+        
92
+        assignin('base','plotParams',plotParams);
93
+%         plotDecodePerformance(params.psthStart,params.psthEnd,params.nClasses,decode.decodeTable,params.frameShiftStart,params.frameShiftEnd,decode.rawTimeCourse);
94
+        plotDecodePerformance(plotParams);
95
+
96
+        case 'gen'
97
+            center = '[-39 -33 67]';
98
+            sessionList = '1:3';
99
+            conditionList = '1:2';
100
+            radius = 3;
101
+            normalize = 1;
102
+
103
+            cmd=sprintf(...
104
+                '[label data] = generateDataMatrix(generateVoxelList(%s,%d), generateTrialProtocol(%s,%s),%d);',...
105
+                center,radius,sessionList,conditionList,normalize);
106
+
107
+            %     assignin('base','label',label);
108
+            %     assignin('base','data',data);
109
+
110
+    case 'norm'
111
+        cmd = ['for i=1:size(data,2),'... % comma is required so the eval'd one-liner parses
112
+            'data(:,i)=data(:,i)/std(data(:,i));'...
113
+        'end;'];
114
+        
115
+    case 'xtrain'
116
+        svmargs = '-t 0'; %linear kernel
117
+        svmargs = [svmargs ' -v 4']; % leading space keeps '-t 0' and '-v 4' separate
118
+        
119
+        cmd=sprintf('model = svmtrain(label,data,''%s'')',svmargs);
120
+        
121
+    case 'train'
122
+        svmargs = '-t 0'; %linear kernel
123
+%         svmargs = [svmargs '-v 4'];
124
+        
125
+        cmd=sprintf('model = svmtrain(label,data,''%s'')',svmargs);
126
+       
127
+    case 'pred'
128
+        cmd = '[predicted_label, accuracy, decision_values] = svmpredict(label, data, model);';
129
+        
130
+    otherwise
131
+        display('Specify an action: clear decode gen norm xtrain train pred');
132
+    end
133
+    
134
+    if exist('cmd','var') 
135
+        evalin('base',cmd);
136
+    end
137
+
138
+end
0 139
\ No newline at end of file
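The classify function above is the toolbox entry point; typical calls (editorial note, not part of the commit):

    classify('decode')   % runs the hard-coded decoding pipeline for subject JZ006 and plots the result
    classify('clear')    % removes the workspace variables created by the toolbox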
... ...
@@ -0,0 +1,20 @@
1
+function voxellist = generateVoxelList(center, radius)
2
+%     cx = ROICenter(1);
3
+%     cy = ROICenter(2);
4
+%     cz = ROICenter(3);
5
+    
6
+    voxellist = [];
7
+    
8
+    cx = center(1);
9
+    cy = center(2);
10
+    cz = center(3);
11
+
12
+    for z=cz-radius+1:cz+radius-1
13
+        for y=(cy-radius+1):(cy+radius-1)
14
+            for x=(cx-radius+1):(cx+radius-1)
15
+                voxellist = [voxellist [x;y;z]];
16
+            end
17
+        end
18
+    end
19
+
20
+end
0 21
\ No newline at end of file
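Note that generateVoxelList returns a cube, not a sphere: the loops span 2*radius-1 voxels per dimension, and the coordinates are collected as columns. A small sketch (not part of the commit):

    vl = generateVoxelList([0 0 0], 2);   % 3 x 27 matrix: a 3x3x3 cube of [x;y;z] columns
    size(vl)                              % ans = [3 27]
    % calculateDecodePerformance indexes voxel lists row-wise, so vl' may be needed there.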
... ...
@@ -0,0 +1,10 @@
1
+function duration = getDuration(session, condition, repetition)
2
+%session    : the session number. 
3
+%condition  : the condition of interest (CUE1, CUE2, ...)
4
+%repetition : the repetition number; same event, different time ;)
5
+
6
+cmd = sprintf('SPM.Sess(%d,%d).U(%d,%d).dur(%d)',1,session,1,condition,repetition);
7
+
8
+duration = evalin('base',cmd);
9
+
10
+end
0 11
\ No newline at end of file
... ...
@@ -0,0 +1,4 @@
1
+function rep = getNumberOfRepetitions(session,condition)
2
+   cmd = sprintf('length(SPM.Sess(%d,%d).U(%d,%d).ons)',1,session,1,condition);
3
+   rep = evalin('base',cmd);
4
+end
0 5
\ No newline at end of file
... ...
@@ -0,0 +1,4 @@
1
+% Number of Scans per session
2
+function nScan = getNumberOfScans(des)
3
+    nScan = des.nscan;
4
+end
0 5
\ No newline at end of file
... ...
@@ -0,0 +1,4 @@
1
+function nSessions = getNumberOfSessions(des)
2
+    nSessions = length(des.Sess);
3
+%     nSessions = model.nSessions;
4
+end
0 5
\ No newline at end of file
... ...
@@ -0,0 +1,11 @@
1
+function onset = getOnset(session,condition,repetition)
2
+
3
+%session    : the session number. 
4
+%condition  : the condition of interest (CUE1, CUE2, ...)
5
+%repetition : the repetition number; same event, different time ;)
6
+
7
+cmd = sprintf('SPM.Sess(%d,%d).U(%d,%d).ons(%d)',1,session,1,condition,repetition);
8
+
9
+onset = evalin('base',cmd);
10
+
11
+end
0 12
\ No newline at end of file
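getOnset, getDuration and getNumberOfRepetitions all read from an SPM variable in the base workspace via evalin; a hedged sketch of iterating over one condition's repetitions (condition 9 is one of the events referenced in classify.m; the units follow SPM.xBF.UNITS):

    load('SPM.mat');                       % SPM must exist in the base workspace
    session   = 1;
    condition = 9;
    for rep = 1:getNumberOfRepetitions(session, condition)
        fprintf('onset %g, duration %g\n', ...
                getOnset(session, condition, rep), ...
                getDuration(session, condition, rep));
    end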
... ...
@@ -0,0 +1,38 @@
1
+import java.util.HashMap; 
2
+
3
+class LabelMap {
4
+
5
+private HashMap<String,Double> labelToValue;
6
+private HashMap<Double,String> valueToLabel;
7
+
8
+public LabelMap(){
9
+    this(2);
10
+}
11
+
12
+public LabelMap(int numberOfLabels){
13
+    labelToValue = new HashMap<String,Double>(numberOfLabels+1,1);
14
+    valueToLabel = new HashMap<Double,String>(numberOfLabels+1,1);
15
+}
16
+
17
+public void add(String label, double value){
18
+    labelToValue.put(label,value);
19
+    valueToLabel.put(value,label);
20
+}
21
+
22
+public String getLabel(double value){
23
+    return valueToLabel.get(value);
24
+}
25
+
26
+public Double getValue(String label){
27
+    return labelToValue.get(label);
28
+}
29
+
30
+public String toString(){
31
+    StringBuffer s = new StringBuffer("LabelMap: \n");
32
+    for( String key : labelToValue.keySet()){
33
+        s.append(key+'\t'+labelToValue.get(key)+"\n");
34
+    }
35
+    return s.toString();
36
+}
37
+    
38
+}
0 39
\ No newline at end of file
... ...
@@ -0,0 +1,14 @@
1
+function generateClassLabelValueMaps(filename)
2
+
3
+if exist(filename,'file')
4
+    vars = load(filename);
5
+    clMap = vars.classLabelMap;
6
+    nItems = size(clMap,1);
7
+    lm = LabelMap(nItems);
8
+    for item = 1:nItems
9
+        label = clMap(item,1);
10
+        value = cell2mat(clMap(item,2));
11
+        lm.add(label,value)
12
+    end
13
+    assignin('base','lm',lm);
14
+end
0 15
\ No newline at end of file
... ...
@@ -0,0 +1,5 @@
1
+function label = getClassLabel(key)
2
+    cmd = sprintf('lm.getLabel(%d)',key);
3
+    label = evalin('base',cmd);
4
+end
5
+ 
0 6
\ No newline at end of file
... ...
@@ -0,0 +1,4 @@
1
+function clazz = getClassValue(key)
2
+    cmd = sprintf('lm.getValue(''%s'')',key);
3
+    clazz = evalin('base',cmd);
4
+end
... ...
@@ -0,0 +1,31 @@
1
+
2
+Copyright (c) 2000-2008 Chih-Chung Chang and Chih-Jen Lin
3
+All rights reserved.
4
+
5
+Redistribution and use in source and binary forms, with or without
6
+modification, are permitted provided that the following conditions
7
+are met:
8
+
9
+1. Redistributions of source code must retain the above copyright
10
+notice, this list of conditions and the following disclaimer.
11
+
12
+2. Redistributions in binary form must reproduce the above copyright
13
+notice, this list of conditions and the following disclaimer in the
14
+documentation and/or other materials provided with the distribution.
15
+
16
+3. Neither name of copyright holders nor the names of its contributors
17
+may be used to endorse or promote products derived from this software
18
+without specific prior written permission.
19
+
20
+
21
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22
+``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24
+A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
25
+CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26
+EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27
+PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29
+LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
... ...
@@ -0,0 +1,47 @@
1
+# This Makefile is used under Linux
2
+
3
+MATLABDIR ?= /usr/local/matlab
4
+CXX ?= g++
5
+#CXX = g++-4.1
6
+CFLAGS = -Wall -O3 -fPIC -I$(MATLABDIR)/extern/include
7
+
8
+MEX = $(MATLABDIR)/bin/mex
9
+MEX_OPTION = CC\#$(CXX) CXX\#$(CXX) CFLAGS\#"$(CFLAGS)" CXXFLAGS\#"$(CFLAGS)"
10
+# comment the following line if you use MATLAB on 32-bit computer
11
+MEX_OPTION += -largeArrayDims
12
+MEX_EXT = $(shell $(MATLABDIR)/bin/mexext)
13
+
14
+OCTAVEDIR ?= /usr/include/octave
15
+OCTAVE_MEX = env CC=$(CXX) mkoctfile
16
+OCTAVE_MEX_OPTION = --mex
17
+OCTAVE_MEX_EXT = mex
18
+OCTAVE_CFLAGS = -Wall -O3 -fPIC -I$(OCTAVEDIR)
19
+
20
+all:	matlab
21
+
22
+matlab:	binary
23
+
24
+octave:
25
+	@make MEX="$(OCTAVE_MEX)" MEX_OPTION="$(OCTAVE_MEX_OPTION)" \
26
+	MEX_EXT="$(OCTAVE_MEX_EXT)" CFLAGS="$(OCTAVE_CFLAGS)" \
27
+	binary
28
+
29
+binary: svmpredict.$(MEX_EXT) svmtrain.$(MEX_EXT) read_sparse.$(MEX_EXT)
30
+
31
+svmpredict.$(MEX_EXT):     svmpredict.c svm.h svm.o svm_model_matlab.o
32
+	$(MEX) $(MEX_OPTION) svmpredict.c svm.o svm_model_matlab.o
33
+
34
+svmtrain.$(MEX_EXT):       svmtrain.c svm.h svm.o svm_model_matlab.o
35
+	$(MEX) $(MEX_OPTION) svmtrain.c svm.o svm_model_matlab.o
36
+
37
+read_sparse.$(MEX_EXT):	read_sparse.c
38
+	$(MEX) $(MEX_OPTION) read_sparse.c
39
+
40
+svm_model_matlab.o:     svm_model_matlab.c svm.h
41
+	$(CXX) $(CFLAGS) -c svm_model_matlab.c
42
+
43
+svm.o:  svm.cpp svm.h
44
+	$(CXX) $(CFLAGS) -c svm.cpp
45
+
46
+clean:
47
+	rm -f *~ *.o *.mex* *.obj
... ...
@@ -0,0 +1,210 @@
1
+-----------------------------------------
2
+--- MATLAB/OCTAVE interface of LIBSVM ---
3
+-----------------------------------------
4
+
5
+Table of Contents
6
+=================
7
+
8
+- Introduction
9
+- Installation
10
+- Usage
11
+- Returned Model Structure
12
+- Examples
13
+- Other Utilities
14
+- Additional Information
15
+
16
+
17
+Introduction
18
+============
19
+
20
+This tool provides a simple interface to LIBSVM, a library for support vector
21
+machines (http://www.csie.ntu.edu.tw/~cjlin/libsvm). It is very easy to use as
22
+the usage and the way of specifying parameters are the same as that of LIBSVM.
23
+
24
+Installation
25
+============
26
+
27
+On Unix systems, we recommend using GNU g++ as your
28
+compiler and type 'make' to build 'svmtrain.mexglx' and 'svmpredict.mexglx'.
29
+Note that we assume your MATLAB is installed in '/usr/local/matlab',
30
+if not, please change MATLABDIR in Makefile.
31
+
32
+Example:
33
+        linux> make
34
+
35
+To use Octave, type 'make octave':
36
+
37
+Example:
38
+	linux> make octave
39
+
40
+On Windows systems, pre-built 'svmtrain.mexw32' and 'svmpredict.mexw32' are
41
+included in this package, so no need to conduct installation. If you
42
+have modified the sources and would like to re-build the package, type
43
+'mex -setup' in MATLAB to choose a compiler for mex first. Then type
44
+'make' to start the installation.
45
+
46
+Starting from MATLAB 7.1 (R14SP3), the default MEX file extension is changed
47
+from .dll to .mexw32 or .mexw64 (depends on 32-bit or 64-bit Windows). If your
48
+MATLAB is older than 7.1, you have to build these files yourself.
49
+
50
+Example:
51
+        matlab> mex -setup
52
+        (ps: MATLAB will show the following messages to setup default compiler.)
53
+        Please choose your compiler for building external interface (MEX) files: 
54
+        Would you like mex to locate installed compilers [y]/n? y
55
+        Select a compiler: 
56
+        [1] Microsoft Visual C/C++ version 7.1 in C:\Program Files\Microsoft Visual Studio 
57
+        [0] None 
58
+        Compiler: 1
59
+        Please verify your choices: 
60
+        Compiler: Microsoft Visual C/C++ 7.1 
61
+        Location: C:\Program Files\Microsoft Visual Studio 
62
+        Are these correct?([y]/n): y
63
+
64
+        matlab> make
65
+
66
+
67
+Under 64-bit Windows, Visual Studio 2005 user will need "X64 Compiler and Tools".
68
+The package won't be installed by default, but you can find it in customized
69
+installation options.
70
+
71
+For list of supported/compatible compilers for MATLAB, please check the
72
+following page:
73
+
74
+http://www.mathworks.com/support/compilers/current_release/
75
+
76
+Usage
77
+=====
78
+
79
+matlab> model = svmtrain(training_label_vector, training_instance_matrix [, 'libsvm_options']);
80
+
81
+        -training_label_vector:
82
+            An m by 1 vector of training labels (type must be double).
83
+        -training_instance_matrix:
84
+            An m by n matrix of m training instances with n features.
85
+            It can be dense or sparse (type must be double).
86
+        -libsvm_options:
87
+            A string of training options in the same format as that of LIBSVM.
88
+
89
+matlab> [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model [, 'libsvm_options']);
90
+
91
+        -testing_label_vector:
92
+            An m by 1 vector of prediction labels. If labels of test
93
+            data are unknown, simply use any random values. (type must be double)
94
+        -testing_instance_matrix:
95
+            An m by n matrix of m testing instances with n features.
96
+            It can be dense or sparse. (type must be double)
97
+        -model:
98
+            The output of svmtrain.
99
+        -libsvm_options:
100
+            A string of testing options in the same format as that of LIBSVM.
101
+
102
+Returned Model Structure
103
+========================
104
+
105
+The 'svmtrain' function returns a model which can be used for future
106
+prediction.  It is a structure and is organized as [Parameters, nr_class,
107
+totalSV, rho, Label, ProbA, ProbB, nSV, sv_coef, SVs]:
108
+
109
+        -Parameters: parameters
110
+        -nr_class: number of classes; = 2 for regression/one-class svm
111
+        -totalSV: total #SV
112
+        -rho: -b of the decision function(s) wx+b
113
+        -Label: label of each class; empty for regression/one-class SVM
114
+        -ProbA: pairwise probability information; empty if -b 0 or in one-class SVM
115
+        -ProbB: pairwise probability information; empty if -b 0 or in one-class SVM
116
+        -nSV: number of SVs for each class; empty for regression/one-class SVM
117
+        -sv_coef: coefficients for SVs in decision functions
118
+        -SVs: support vectors
119
+
120
+If you do not use the option '-b 1', ProbA and ProbB are empty
121
+matrices. If the '-v' option is specified, cross validation is
122
+conducted and the returned model is just a scalar: cross-validation
123
+accuracy for classification and mean-squared error for regression.
124
+
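This '-v' behaviour is what calculateDecodePerformance above relies on: when cross-validation is requested, svmtrain returns a single accuracy value instead of a model. A minimal sketch on random data (editorial, illustrative only):

matlab> labels = [ones(20,1); -ones(20,1)];
matlab> data   = [randn(20,5)+1; randn(20,5)-1];
matlab> acc    = svmtrain(labels, data, '-t 0 -v 5');  % scalar cross-validation accuracy in percent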
125
+More details about this model can be found in LIBSVM FAQ
126
+(http://www.csie.ntu.edu.tw/~cjlin/libsvm/faq.html) and LIBSVM
127
+implementation document
128
+(http://www.csie.ntu.edu.tw/~cjlin/papers/libsvm.pdf).
129
+
130
+Result of Prediction
131
+====================
132
+
133
+The function 'svmpredict' has three outputs. The first one,
134
+predicted_label, is a vector of predicted labels. The second output,
135
+accuracy, is a vector including accuracy (for classification), mean
136
+squared error, and squared correlation coefficient (for regression).
137
+The third is a matrix containing decision values or probability
138
+estimates (if '-b 1' is specified). If k is the number of classes,
139
+for decision values, each row includes results of predicting
140
+k(k-1)/2 binary-class SVMs. For probabilities, each row contains k values
141
+indicating the probability that the testing instance is in each class.
142
+Note that the order of classes here is the same as 'Label' field
143
+in the model structure.
144
+
145
+Examples
146
+========
147
+
148
+Train and test on the provided data heart_scale:
149
+
150
+matlab> load heart_scale.mat
151
+matlab> model = svmtrain(heart_scale_label, heart_scale_inst, '-c 1 -g 0.07');
152
+matlab> [predict_label, accuracy, dec_values] = svmpredict(heart_scale_label, heart_scale_inst, model); % test the training data
153
+
154
+For probability estimates, you need '-b 1' for training and testing:
155
+
156
+matlab> load heart_scale.mat
157
+matlab> model = svmtrain(heart_scale_label, heart_scale_inst, '-c 1 -g 0.07 -b 1');
158
+matlab> load heart_scale.mat
159
+matlab> [predict_label, accuracy, prob_estimates] = svmpredict(heart_scale_label, heart_scale_inst, model, '-b 1');
160
+
161
+To use precomputed kernel, you must include sample serial number as
162
+the first column of the training and testing data (assume your kernel
163
+matrix is K, # of instances is n):
164
+
165
+matlab> K1 = [(1:n)', K]; % include sample serial number as first column
166
+matlab> model = svmtrain(label_vector, K1, '-t 4');
167
+matlab> [predict_label, accuracy, dec_values] = svmpredict(label_vector, K1, model); % test the training data
168
+
169
+Taking the linear kernel as an example, the following precomputed kernel
171
+gives exactly the same training error as the LIBSVM built-in linear kernel:
171
+
172
+matlab> load heart_scale.mat
173
+matlab> n = size(heart_scale_inst,1);
174
+matlab> K = heart_scale_inst*heart_scale_inst';
175
+matlab> K1 = [(1:n)', K];
176
+matlab> model = svmtrain(heart_scale_label, K1, '-t 4');
177
+matlab> [predict_label, accuracy, dec_values] = svmpredict(heart_scale_label, K1, model);
178
+       
179
+Note that for testing, you can put anything in the testing_label_vector.  For
180
+details of precomputed kernels, please read the section ``Precomputed
181
+Kernels'' in the README of the LIBSVM package.
182
+
183
+Other Utilities
184
+===============
185
+
186
+A matlab function read_sparse reads files in LIBSVM format: 
187
+
188
+[label_vector, instance_matrix] = read_sparse('data.txt'); 
189
+
190
+Two outputs are labels and instances, which can then be used as inputs
191
+of svmtrain or svmpredict. This code is derived from svm-train.c in
192
+LIBSVM by Rong-En Fan from National Taiwan University.
193
+
194
+Additional Information
195
+======================
196
+
197
+This interface was initially written by Jun-Cheng Chen, Kuan-Jen Peng,
198
+Chih-Yuan Yang and Chih-Huai Cheng from Department of Computer
199
+Science, National Taiwan University. The current version was prepared
200
+by Rong-En Fan and Ting-Fan Wu. If you find this tool useful, please
201
+cite LIBSVM as follows
202
+
203
+Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for
204
+support vector machines, 2001. Software available at
205
+http://www.csie.ntu.edu.tw/~cjlin/libsvm
206
+
207
+For any question, please contact Chih-Jen Lin <cjlin@csie.ntu.edu.tw>,
208
+or check the FAQ page:
209
+
210
+http://www.csie.ntu.edu.tw/~cjlin/libsvm/faq.html#/Q9:_MATLAB_interface
... ...
@@ -0,0 +1,7 @@
1
+% This make.m is used under Windows
2
+
3
+mex -O -c svm.cpp
4
+mex -O -c svm_model_matlab.c
5
+mex -O svmtrain.c svm.obj svm_model_matlab.obj
6
+mex -O svmpredict.c svm.obj svm_model_matlab.obj
7
+mex -O read_sparse.c
... ...
@@ -0,0 +1,200 @@
1
+#include <stdio.h>
2
+#include <string.h>
3
+#include <stdlib.h>
4
+#include <ctype.h>
5
+#include <errno.h>
6
+
7
+#include "mex.h"
8
+
9
+#if MX_API_VER < 0x07030000
10
+typedef int mwIndex;
11
+#endif 
12
+#define max(x,y) (((x)>(y))?(x):(y))
13
+#define min(x,y) (((x)<(y))?(x):(y))
14
+
15
+void exit_with_help()
16
+{
17
+	mexPrintf(
18
+	"Usage: [label_vector, instance_matrix] = read_sparse(fname);\n"
19
+	);
20
+}
21
+
22
+static void fake_answer(mxArray *plhs[])
23
+{
24
+	plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
25
+	plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
26
+}
27
+
28
+static char *line;
29
+static int max_line_len;
30
+
31
+static char* readline(FILE *input)
32
+{
33
+	int len;
34
+	
35
+	if(fgets(line,max_line_len,input) == NULL)
36
+		return NULL;
37
+
38
+	while(strrchr(line,'\n') == NULL)
39
+	{
40
+		max_line_len *= 2;
41
+		line = (char *) realloc(line, max_line_len);
42
+		len = (int) strlen(line);
43
+		if(fgets(line+len,max_line_len-len,input) == NULL)
44
+			break;
45
+	}
46
+	return line;
47
+}
48
+
49
+// read in a problem (in svmlight format)
50
+void read_problem(const char *filename, mxArray *plhs[])
51
+{
52
+	int max_index, min_index, inst_max_index, i;
53
+	long elements, k;
54
+	FILE *fp = fopen(filename,"r");
55
+	int l = 0;
56
+	char *endptr;
57
+	mwIndex *ir, *jc;
58
+	double *labels, *samples;
59
+	
60
+	if(fp == NULL)
61
+	{
62
+		mexPrintf("can't open input file %s\n",filename);
63
+		fake_answer(plhs);
64
+		return;
65
+	}
66
+
67
+	max_line_len = 1024;
68
+	line = (char *) malloc(max_line_len*sizeof(char));
69
+
70
+	max_index = 0;
71
+	min_index = 1; // our index starts from 1
72
+	elements = 0;
73
+	while(readline(fp) != NULL)
74
+	{
75
+		char *idx, *val;
76
+		// features
77
+		int index = 0;
78
+
79
+		inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
80
+		strtok(line," \t"); // label
81
+		while (1)
82
+		{
83
+			idx = strtok(NULL,":"); // index:value
84
+			val = strtok(NULL," \t");
85
+			if(val == NULL)
86
+				break;
87
+
88
+			errno = 0;
89
+			index = (int) strtol(idx,&endptr,10);
90
+			if(endptr == idx || errno != 0 || *endptr != '\0' || index <= inst_max_index)
91
+			{
92
+				mexPrintf("Wrong input format at line %d\n",l+1);
93
+				fake_answer(plhs);
94
+				return;
95
+			}
96
+			else
97
+				inst_max_index = index;
98
+
99
+			min_index = min(min_index, index);
100
+			elements++;
101
+		}
102
+		max_index = max(max_index, inst_max_index);
103
+		l++;
104
+	}
105
+	rewind(fp);
106
+
107
+	// y
108
+	plhs[0] = mxCreateDoubleMatrix(l, 1, mxREAL);
109
+	// x^T
110
+	if (min_index <= 0)
111
+		plhs[1] = mxCreateSparse(max_index-min_index+1, l, elements, mxREAL);
112
+	else
113
+		plhs[1] = mxCreateSparse(max_index, l, elements, mxREAL);
114
+
115
+	labels = mxGetPr(plhs[0]);
116
+	samples = mxGetPr(plhs[1]);
117
+	ir = mxGetIr(plhs[1]);
118
+	jc = mxGetJc(plhs[1]);
119
+
120
+	k=0;
121
+	for(i=0;i<l;i++)
122
+	{
123
+		char *idx, *val, *label;
124
+		jc[i] = k;
125
+
126
+		readline(fp);
127
+
128
+		label = strtok(line," \t");
129
+		labels[i] = (int)strtol(label,&endptr,10);
130
+		if(endptr == label)
131
+		{
132
+			mexPrintf("Wrong input format at line %d\n",i+1);
133
+			fake_answer(plhs);
134
+			return;
135
+		}
136
+
137
+		// features
138
+		while(1)
139
+		{
140
+			idx = strtok(NULL,":");
141
+			val = strtok(NULL," \t");
142
+			if(val == NULL)
143
+				break;
144
+
145
+			ir[k] = (mwIndex) (strtol(idx,&endptr,10) - min_index); // precomputed kernel has <index> start from 0
146
+
147
+			errno = 0;
148
+			samples[k] = strtod(val,&endptr);
149
+			if (endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
150
+			{
151
+				mexPrintf("Wrong input format at line %d\n",i+1);
152
+				fake_answer(plhs);
153
+				return;
154
+			}
155
+			++k;
156
+		}
157
+	}
158
+	jc[l] = k;
159
+
160
+	fclose(fp);
161
+	free(line);
162
+
163
+	{
164
+		mxArray *rhs[1], *lhs[1];
165
+		rhs[0] = plhs[1];
166
+		if(mexCallMATLAB(1, lhs, 1, rhs, "transpose"))
167
+		{
168
+			mexPrintf("Error: cannot transpose problem\n");
169
+			fake_answer(plhs);
170
+			return;
171
+		}
172
+		plhs[1] = lhs[0];
173
+	}
174
+}
175
+
176
+void mexFunction( int nlhs, mxArray *plhs[],
177
+		int nrhs, const mxArray *prhs[] )
178
+{
179
+	if(nrhs == 1)
180
+	{
181
+		char filename[256];
182
+
183
+		mxGetString(prhs[0], filename, mxGetN(prhs[0]) + 1);
184
+
185
+		if(filename == NULL)
186
+		{
187
+			mexPrintf("Error: filename is NULL\n");
188
+			return;
189
+		}
190
+
191
+		read_problem(filename, plhs);
192
+	}
193
+	else
194
+	{
195
+		exit_with_help();
196
+		fake_answer(plhs);
197
+		return;
198
+	}
199
+}
200
+
... ...
@@ -0,0 +1,3030 @@
1
+#include <math.h>
2
+#include <stdio.h>
3
+#include <stdlib.h>
4
+#include <ctype.h>
5
+#include <float.h>
6
+#include <string.h>
7
+#include <stdarg.h>
8
+#include "svm.h"
9
+typedef float Qfloat;
10
+typedef signed char schar;
11
+#ifndef min
12
+template <class T> inline T min(T x,T y) { return (x<y)?x:y; }
13
+#endif
14
+#ifndef max
15
+template <class T> inline T max(T x,T y) { return (x>y)?x:y; }
16
+#endif
17
+template <class T> inline void swap(T& x, T& y) { T t=x; x=y; y=t; }
18
+template <class S, class T> inline void clone(T*& dst, S* src, int n)
19
+{
20
+	dst = new T[n];
21
+	memcpy((void *)dst,(void *)src,sizeof(T)*n);
22
+}
23
+inline double powi(double base, int times)
24
+{
25
+        double tmp = base, ret = 1.0;
26
+
27
+        for(int t=times; t>0; t/=2)
28
+	{
29
+                if(t%2==1) ret*=tmp;
30
+                tmp = tmp * tmp;
31
+        }
32
+        return ret;
33
+}
34
+#define INF HUGE_VAL
35
+#define TAU 1e-12
36
+#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
37
+#if 1
38
+static void info(const char *fmt,...)
39
+{
40
+	va_list ap;
41
+	va_start(ap,fmt);
42
+	vprintf(fmt,ap);
43
+	va_end(ap);
44
+}
45
+static void info_flush()
46
+{
47
+	fflush(stdout);
48
+}
49
+#else
50
+static void info(char *fmt,...) {}
51
+static void info_flush() {}
52
+#endif
53
+
54
+//
55
+// Kernel Cache
56
+//
57
+// l is the number of total data items
58
+// size is the cache size limit in bytes
59
+//
60
+class Cache
61
+{
62
+public:
63
+	Cache(int l,long int size);
64
+	~Cache();
65
+
66
+	// request data [0,len)
67
+	// return some position p where [p,len) need to be filled
68
+	// (p >= len if nothing needs to be filled)
69
+	int get_data(const int index, Qfloat **data, int len);
70
+	void swap_index(int i, int j);	
71
+private:
72
+	int l;
73
+	long int size;
74
+	struct head_t
75
+	{
76
+		head_t *prev, *next;	// a circular list
77
+		Qfloat *data;
78
+		int len;		// data[0,len) is cached in this entry
79
+	};
80
+
81
+	head_t *head;
82
+	head_t lru_head;
83
+	void lru_delete(head_t *h);
84
+	void lru_insert(head_t *h);
85
+};
86
+
87
+Cache::Cache(int l_,long int size_):l(l_),size(size_)
88
+{
89
+	head = (head_t *)calloc(l,sizeof(head_t));	// initialized to 0
90
+	size /= sizeof(Qfloat);
91
+	size -= l * sizeof(head_t) / sizeof(Qfloat);
92
+	size = max(size, 2 * (long int) l);	// cache must be large enough for two columns
93
+	lru_head.next = lru_head.prev = &lru_head;
94
+}
95
+
96
+Cache::~Cache()
97
+{
98
+	for(head_t *h = lru_head.next; h != &lru_head; h=h->next)
99
+		free(h->data);
100
+	free(head);
101
+}
102
+
103
+void Cache::lru_delete(head_t *h)
104
+{
105
+	// delete from current location
106
+	h->prev->next = h->next;
107
+	h->next->prev = h->prev;
108
+}
109
+
110
+void Cache::lru_insert(head_t *h)
111
+{
112
+	// insert to last position
113
+	h->next = &lru_head;
114
+	h->prev = lru_head.prev;
115
+	h->prev->next = h;
116
+	h->next->prev = h;
117
+}
118
+
119
+int Cache::get_data(const int index, Qfloat **data, int len)
120
+{
121
+	head_t *h = &head[index];
122
+	if(h->len) lru_delete(h);
123
+	int more = len - h->len;
124
+
125
+	if(more > 0)
126
+	{
127
+		// free old space
128
+		while(size < more)
129
+		{
130
+			head_t *old = lru_head.next;
131
+			lru_delete(old);
132
+			free(old->data);
133
+			size += old->len;
134
+			old->data = 0;
135
+			old->len = 0;
136
+		}
137
+
138
+		// allocate new space
139
+		h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len);
140
+		size -= more;
141
+		swap(h->len,len);
142
+	}
143
+
144
+	lru_insert(h);
145
+	*data = h->data;
146
+	return len;
147
+}
148
+
149
+void Cache::swap_index(int i, int j)
150
+{
151
+	if(i==j) return;
152
+
153
+	if(head[i].len) lru_delete(&head[i]);
154
+	if(head[j].len) lru_delete(&head[j]);
155
+	swap(head[i].data,head[j].data);
156
+	swap(head[i].len,head[j].len);
157
+	if(head[i].len) lru_insert(&head[i]);
158
+	if(head[j].len) lru_insert(&head[j]);
159
+
160
+	if(i>j) swap(i,j);
161
+	for(head_t *h = lru_head.next; h!=&lru_head; h=h->next)
162
+	{
163
+		if(h->len > i)
164
+		{
165
+			if(h->len > j)
166
+				swap(h->data[i],h->data[j]);
167
+			else
168
+			{
169
+				// give up
170
+				lru_delete(h);
171
+				free(h->data);
172
+				size += h->len;
173
+				h->data = 0;
174
+				h->len = 0;
175
+			}
176
+		}
177
+	}
178
+}
179
+
180
+//
181
+// Kernel evaluation
182
+//
183
+// the static method k_function is for doing single kernel evaluation
184
+// the constructor of Kernel prepares to calculate the l*l kernel matrix
185
+// the member function get_Q is for getting one column from the Q Matrix
186
+//
187
+class QMatrix {
188
+public:
189
+	virtual Qfloat *get_Q(int column, int len) const = 0;
190
+	virtual Qfloat *get_QD() const = 0;
191
+	virtual void swap_index(int i, int j) const = 0;
192
+	virtual ~QMatrix() {}
193
+};
194
+
195
+class Kernel: public QMatrix {
196
+public:
197
+	Kernel(int l, svm_node * const * x, const svm_parameter& param);
198
+	virtual ~Kernel();
199
+
200
+	static double k_function(const svm_node *x, const svm_node *y,
201
+				 const svm_parameter& param);
202
+	virtual Qfloat *get_Q(int column, int len) const = 0;
203
+	virtual Qfloat *get_QD() const = 0;
204
+	virtual void swap_index(int i, int j) const	// not so const...
205
+	{
206
+		swap(x[i],x[j]);
207
+		if(x_square) swap(x_square[i],x_square[j]);
208
+	}
209
+protected:
210
+
211
+	double (Kernel::*kernel_function)(int i, int j) const;
212
+
213
+private:
214
+	const svm_node **x;
215
+	double *x_square;
216
+
217
+	// svm_parameter
218
+	const int kernel_type;
219
+	const int degree;
220
+	const double gamma;
221
+	const double coef0;
222
+
223
+	static double dot(const svm_node *px, const svm_node *py);
224
+	double kernel_linear(int i, int j) const
225
+	{
226
+		return dot(x[i],x[j]);
227
+	}
228
+	double kernel_poly(int i, int j) const
229
+	{
230
+		return powi(gamma*dot(x[i],x[j])+coef0,degree);
231
+	}
232
+	double kernel_rbf(int i, int j) const
233
+	{
234
+		return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
235
+	}
236
+	double kernel_sigmoid(int i, int j) const
237
+	{
238
+		return tanh(gamma*dot(x[i],x[j])+coef0);
239
+	}
240
+	double kernel_precomputed(int i, int j) const
241
+	{
242
+		return x[i][(int)(x[j][0].value)].value;
243
+	}
244
+};
245
+
246
+Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)
247
+:kernel_type(param.kernel_type), degree(param.degree),
248
+ gamma(param.gamma), coef0(param.coef0)
249
+{
250
+	switch(kernel_type)
251
+	{
252
+		case LINEAR:
253
+			kernel_function = &Kernel::kernel_linear;
254
+			break;
255
+		case POLY:
256
+			kernel_function = &Kernel::kernel_poly;
257
+			break;
258
+		case RBF:
259
+			kernel_function = &Kernel::kernel_rbf;
260
+			break;
261
+		case SIGMOID:
262
+			kernel_function = &Kernel::kernel_sigmoid;
263
+			break;
264
+		case PRECOMPUTED:
265
+			kernel_function = &Kernel::kernel_precomputed;
266
+			break;
267
+	}
268
+
269
+	clone(x,x_,l);
270
+
271
+	if(kernel_type == RBF)
272
+	{
273
+		x_square = new double[l];
274
+		for(int i=0;i<l;i++)
275
+			x_square[i] = dot(x[i],x[i]);
276
+	}
277
+	else
278
+		x_square = 0;
279
+}
280
+
281
+Kernel::~Kernel()
282
+{
283
+	delete[] x;
284
+	delete[] x_square;
285
+}
286
+
287
+double Kernel::dot(const svm_node *px, const svm_node *py)
288
+{
289
+	double sum = 0;
290
+	while(px->index != -1 && py->index != -1)
291
+	{
292
+		if(px->index == py->index)
293
+		{
294
+			sum += px->value * py->value;
295
+			++px;
296
+			++py;
297
+		}
298
+		else
299
+		{
300
+			if(px->index > py->index)
301
+				++py;
302
+			else
303
+				++px;
304
+		}			
305
+	}
306
+	return sum;
307
+}
308
+
309
+double Kernel::k_function(const svm_node *x, const svm_node *y,
310
+			  const svm_parameter& param)
311
+{
312
+	switch(param.kernel_type)
313
+	{
314
+		case LINEAR:
315
+			return dot(x,y);
316
+		case POLY:
317
+			return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
318
+		case RBF:
319
+		{
320
+			double sum = 0;
321
+			while(x->index != -1 && y->index !=-1)
322
+			{
323
+				if(x->index == y->index)
324
+				{
325
+					double d = x->value - y->value;
326
+					sum += d*d;
327
+					++x;
328
+					++y;
329
+				}
330
+				else
331
+				{
332
+					if(x->index > y->index)
333
+					{	
334
+						sum += y->value * y->value;
335
+						++y;
336
+					}
337
+					else
338
+					{
339
+						sum += x->value * x->value;
340
+						++x;
341
+					}
342
+				}
343
+			}
344
+
345
+			while(x->index != -1)
346
+			{
347
+				sum += x->value * x->value;
348
+				++x;
349
+			}
350
+
351
+			while(y->index != -1)
352
+			{
353
+				sum += y->value * y->value;
354
+				++y;
355
+			}
356
+			
357
+			return exp(-param.gamma*sum);
358
+		}
359
+		case SIGMOID:
360
+			return tanh(param.gamma*dot(x,y)+param.coef0);
361
+		case PRECOMPUTED:  //x: test (validation), y: SV
362
+			return x[(int)(y->value)].value;
363
+		default:
364
+			return 0;  // Unreachable 
365
+	}
366
+}
367
+
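As a concrete illustration of the "single kernel evaluation" mentioned in the comment block above, the following minimal sketch calls Kernel::k_function on two sparse vectors. It assumes it lives in this same translation unit (Kernel is file-local and not exported through svm.h); the vectors, the helper name, and the parameter values are invented for the example, and only the fields read by the RBF branch are set.

	// hypothetical helper, for illustration only
	static double example_rbf_evaluation()
	{
		// sparse vectors in libsvm format, terminated by index = -1
		svm_node x[] = { {1, 0.5}, {3, -1.0}, {-1, 0.0} };
		svm_node y[] = { {1, 0.4}, {2,  2.0}, {-1, 0.0} };

		svm_parameter param;
		param.kernel_type = RBF;   // exp(-gamma*||x-y||^2)
		param.gamma = 0.5;         // the only parameter the RBF branch reads

		// static call: no Kernel object and no cache involved
		return Kernel::k_function(x, y, param);
	}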
368
+// An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
369
+// Solves:
370
+//
371
+//	min 0.5(\alpha^T Q \alpha) + p^T \alpha
372
+//
373
+//		y^T \alpha = \delta
374
+//		y_i = +1 or -1
375
+//		0 <= alpha_i <= Cp for y_i = 1
376
+//		0 <= alpha_i <= Cn for y_i = -1
377
+//
378
+// Given:
379
+//
380
+//	Q, p, y, Cp, Cn, and an initial feasible point \alpha
381
+//	l is the size of vectors and matrices
382
+//	eps is the stopping tolerance
383
+//
384
+// solution will be put in \alpha, objective value will be put in obj
385
+//
386
+class Solver {
387
+public:
388
+	Solver() {};
389
+	virtual ~Solver() {};
390
+
391
+	struct SolutionInfo {
392
+		double obj;
393
+		double rho;
394
+		double upper_bound_p;
395
+		double upper_bound_n;
396
+		double r;	// for Solver_NU
397
+	};
398
+
399
+	void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
400
+		   double *alpha_, double Cp, double Cn, double eps,
401
+		   SolutionInfo* si, int shrinking);
402
+protected:
403
+	int active_size;
404
+	schar *y;
405
+	double *G;		// gradient of objective function
406
+	enum { LOWER_BOUND, UPPER_BOUND, FREE };
407
+	char *alpha_status;	// LOWER_BOUND, UPPER_BOUND, FREE
408
+	double *alpha;
409
+	const QMatrix *Q;
410
+	const Qfloat *QD;
411
+	double eps;
412
+	double Cp,Cn;
413
+	double *p;
414
+	int *active_set;
415
+	double *G_bar;		// gradient, if we treat free variables as 0
416
+	int l;
417
+	bool unshrink;	// XXX
418
+
419
+	double get_C(int i)
420
+	{
421
+		return (y[i] > 0)? Cp : Cn;
422
+	}
423
+	void update_alpha_status(int i)
424
+	{
425
+		if(alpha[i] >= get_C(i))
426
+			alpha_status[i] = UPPER_BOUND;
427
+		else if(alpha[i] <= 0)
428
+			alpha_status[i] = LOWER_BOUND;
429
+		else alpha_status[i] = FREE;
430
+	}
431
+	bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
432
+	bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
433
+	bool is_free(int i) { return alpha_status[i] == FREE; }
434
+	void swap_index(int i, int j);
435
+	void reconstruct_gradient();
436
+	virtual int select_working_set(int &i, int &j);
437
+	virtual double calculate_rho();
438
+	virtual void do_shrinking();
439
+private:
440
+	bool be_shrunk(int i, double Gmax1, double Gmax2);	
441
+};
442
+
443
+void Solver::swap_index(int i, int j)
444
+{
445
+	Q->swap_index(i,j);
446
+	swap(y[i],y[j]);
447
+	swap(G[i],G[j]);
448
+	swap(alpha_status[i],alpha_status[j]);
449
+	swap(alpha[i],alpha[j]);
450
+	swap(p[i],p[j]);
451
+	swap(active_set[i],active_set[j]);
452
+	swap(G_bar[i],G_bar[j]);
453
+}
454
+
455
+void Solver::reconstruct_gradient()
456
+{
457
+	// reconstruct inactive elements of G from G_bar and free variables
458
+
459
+	if(active_size == l) return;
460
+
461
+	int i,j;
462
+	int nr_free = 0;
463
+
464
+	for(j=active_size;j<l;j++)
465
+		G[j] = G_bar[j] + p[j];
466
+
467
+	for(j=0;j<active_size;j++)
468
+		if(is_free(j))
469
+			nr_free++;
470
+
471
+	if(2*nr_free < active_size)
472
+		info("\nWarning: using -h 0 may be faster\n");
473
+
474
+	if (nr_free*l > 2*active_size*(l-active_size))
475
+	{
476
+		for(i=active_size;i<l;i++)
477
+		{
478
+			const Qfloat *Q_i = Q->get_Q(i,active_size);
479
+			for(j=0;j<active_size;j++)
480
+				if(is_free(j))
481
+					G[i] += alpha[j] * Q_i[j];
482
+		}
483
+	}
484
+	else
485
+	{
486
+		for(i=0;i<active_size;i++)
487
+			if(is_free(i))
488
+			{
489
+				const Qfloat *Q_i = Q->get_Q(i,l);
490
+				double alpha_i = alpha[i];
491
+				for(j=active_size;j<l;j++)
492
+					G[j] += alpha_i * Q_i[j];
493
+			}
494
+	}
495
+}
496
+
497
+void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
498
+		   double *alpha_, double Cp, double Cn, double eps,
499
+		   SolutionInfo* si, int shrinking)
500
+{
501
+	this->l = l;
502
+	this->Q = &Q;
503
+	QD=Q.get_QD();
504
+	clone(p, p_,l);
505
+	clone(y, y_,l);
506
+	clone(alpha,alpha_,l);
507
+	this->Cp = Cp;
508
+	this->Cn = Cn;
509
+	this->eps = eps;
510
+	unshrink = false;
511
+
512
+	// initialize alpha_status
513
+	{
514
+		alpha_status = new char[l];
515
+		for(int i=0;i<l;i++)
516
+			update_alpha_status(i);
517
+	}
518
+
519
+	// initialize active set (for shrinking)
520
+	{
521
+		active_set = new int[l];
522
+		for(int i=0;i<l;i++)
523
+			active_set[i] = i;
524
+		active_size = l;
525
+	}
526
+
527
+	// initialize gradient
528
+	{
529
+		G = new double[l];
530
+		G_bar = new double[l];
531
+		int i;
532
+		for(i=0;i<l;i++)
533
+		{
534
+			G[i] = p[i];
535
+			G_bar[i] = 0;
536
+		}
537
+		for(i=0;i<l;i++)
538
+			if(!is_lower_bound(i))
539
+			{
540
+				const Qfloat *Q_i = Q.get_Q(i,l);
541
+				double alpha_i = alpha[i];
542
+				int j;
543
+				for(j=0;j<l;j++)
544
+					G[j] += alpha_i*Q_i[j];
545
+				if(is_upper_bound(i))
546
+					for(j=0;j<l;j++)
547
+						G_bar[j] += get_C(i) * Q_i[j];
548
+			}
549
+	}
550
+
551
+	// optimization step
552
+
553
+	int iter = 0;
554
+	int counter = min(l,1000)+1;
555
+
556
+	while(1)
557
+	{
558
+		// show progress and do shrinking
559
+
560
+		if(--counter == 0)
561
+		{
562
+			counter = min(l,1000);
563
+			if(shrinking) do_shrinking();
564
+			info("."); info_flush();
565
+		}
566
+
567
+		int i,j;
568
+		if(select_working_set(i,j)!=0)
569
+		{
570
+			// reconstruct the whole gradient
571
+			reconstruct_gradient();
572
+			// reset active set size and check
573
+			active_size = l;
574
+			info("*"); info_flush();
575
+			if(select_working_set(i,j)!=0)
576
+				break;
577
+			else
578
+				counter = 1;	// do shrinking next iteration
579
+		}
580
+		
581
+		++iter;
582
+
583
+		// update alpha[i] and alpha[j], handle bounds carefully
584
+		
585
+		const Qfloat *Q_i = Q.get_Q(i,active_size);
586
+		const Qfloat *Q_j = Q.get_Q(j,active_size);
587
+
588
+		double C_i = get_C(i);
589
+		double C_j = get_C(j);
590
+
591
+		double old_alpha_i = alpha[i];
592
+		double old_alpha_j = alpha[j];
593
+
594
+		if(y[i]!=y[j])
595
+		{
596
+			double quad_coef = Q_i[i]+Q_j[j]+2*Q_i[j];
597
+			if (quad_coef <= 0)
598
+				quad_coef = TAU;
599
+			double delta = (-G[i]-G[j])/quad_coef;
600
+			double diff = alpha[i] - alpha[j];
601
+			alpha[i] += delta;
602
+			alpha[j] += delta;
603
+			
604
+			if(diff > 0)
605
+			{
606
+				if(alpha[j] < 0)
607
+				{
608
+					alpha[j] = 0;
609
+					alpha[i] = diff;
610
+				}
611
+			}
612
+			else
613
+			{
614
+				if(alpha[i] < 0)
615
+				{
616
+					alpha[i] = 0;
617
+					alpha[j] = -diff;
618
+				}
619
+			}
620
+			if(diff > C_i - C_j)
621
+			{
622
+				if(alpha[i] > C_i)
623
+				{
624
+					alpha[i] = C_i;
625
+					alpha[j] = C_i - diff;
626
+				}
627
+			}
628
+			else
629
+			{
630
+				if(alpha[j] > C_j)
631
+				{
632
+					alpha[j] = C_j;
633
+					alpha[i] = C_j + diff;
634
+				}
635
+			}
636
+		}
637
+		else
638
+		{
639
+			double quad_coef = Q_i[i]+Q_j[j]-2*Q_i[j];
640
+			if (quad_coef <= 0)
641
+				quad_coef = TAU;
642
+			double delta = (G[i]-G[j])/quad_coef;
643
+			double sum = alpha[i] + alpha[j];
644
+			alpha[i] -= delta;
645
+			alpha[j] += delta;
646
+
647
+			if(sum > C_i)
648
+			{
649
+				if(alpha[i] > C_i)
650
+				{
651
+					alpha[i] = C_i;
652
+					alpha[j] = sum - C_i;
653
+				}
654
+			}
655
+			else
656
+			{
657
+				if(alpha[j] < 0)
658
+				{
659
+					alpha[j] = 0;
660
+					alpha[i] = sum;
661
+				}
662
+			}
663
+			if(sum > C_j)
664
+			{
665
+				if(alpha[j] > C_j)
666
+				{
667
+					alpha[j] = C_j;
668
+					alpha[i] = sum - C_j;
669
+				}
670
+			}
671
+			else
672
+			{
673
+				if(alpha[i] < 0)
674
+				{
675
+					alpha[i] = 0;
676
+					alpha[j] = sum;
677
+				}
678
+			}
679
+		}
680
+
681
+		// update G
682
+
683
+		double delta_alpha_i = alpha[i] - old_alpha_i;
684
+		double delta_alpha_j = alpha[j] - old_alpha_j;
685
+		
686
+		for(int k=0;k<active_size;k++)
687
+		{
688
+			G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
689
+		}
690
+
691
+		// update alpha_status and G_bar
692
+
693
+		{
694
+			bool ui = is_upper_bound(i);
695
+			bool uj = is_upper_bound(j);
696
+			update_alpha_status(i);
697
+			update_alpha_status(j);
698
+			int k;
699
+			if(ui != is_upper_bound(i))
700
+			{
701
+				Q_i = Q.get_Q(i,l);
702
+				if(ui)
703
+					for(k=0;k<l;k++)
704
+						G_bar[k] -= C_i * Q_i[k];
705
+				else
706
+					for(k=0;k<l;k++)
707
+						G_bar[k] += C_i * Q_i[k];
708
+			}
709
+
710
+			if(uj != is_upper_bound(j))
711
+			{
712
+				Q_j = Q.get_Q(j,l);
713
+				if(uj)
714
+					for(k=0;k<l;k++)
715
+						G_bar[k] -= C_j * Q_j[k];
716
+				else
717
+					for(k=0;k<l;k++)
718
+						G_bar[k] += C_j * Q_j[k];
719
+			}
720
+		}
721
+	}
722
+
723
+	// calculate rho
724
+
725
+	si->rho = calculate_rho();
726
+
727
+	// calculate objective value
728
+	{
729
+		double v = 0;
730
+		int i;
731
+		for(i=0;i<l;i++)
732
+			v += alpha[i] * (G[i] + p[i]);
733
+
734
+		si->obj = v/2;
735
+	}
736
+
737
+	// put back the solution
738
+	{
739
+		for(int i=0;i<l;i++)
740
+			alpha_[active_set[i]] = alpha[i];
741
+	}
742
+
743
+	// juggle everything back
744
+	/*{
745
+		for(int i=0;i<l;i++)
746
+			while(active_set[i] != i)
747
+				swap_index(i,active_set[i]);
748
+				// or Q.swap_index(i,active_set[i]);
749
+	}*/
750
+
751
+	si->upper_bound_p = Cp;
752
+	si->upper_bound_n = Cn;
753
+
754
+	info("\noptimization finished, #iter = %d\n",iter);
755
+
756
+	delete[] p;
757
+	delete[] y;
758
+	delete[] alpha;
759
+	delete[] alpha_status;
760
+	delete[] active_set;
761
+	delete[] G;
762
+	delete[] G_bar;
763
+}
764
+
765
+// return 1 if already optimal, return 0 otherwise
766
+int Solver::select_working_set(int &out_i, int &out_j)
767
+{
768
+	// return i,j such that
769
+	// i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
770
+	// j: minimizes the decrease of obj value
771
+	//    (if quadratic coefficient <= 0, replace it with tau)
772
+	//    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
773
+	
774
+	double Gmax = -INF;
775
+	double Gmax2 = -INF;
776
+	int Gmax_idx = -1;
777
+	int Gmin_idx = -1;
778
+	double obj_diff_min = INF;
779
+
780
+	for(int t=0;t<active_size;t++)
781
+		if(y[t]==+1)	
782
+		{
783
+			if(!is_upper_bound(t))
784
+				if(-G[t] >= Gmax)
785
+				{
786
+					Gmax = -G[t];
787
+					Gmax_idx = t;
788
+				}
789
+		}
790
+		else
791
+		{
792
+			if(!is_lower_bound(t))
793
+				if(G[t] >= Gmax)
794
+				{
795
+					Gmax = G[t];
796
+					Gmax_idx = t;
797
+				}
798
+		}
799
+
800
+	int i = Gmax_idx;
801
+	const Qfloat *Q_i = NULL;
802
+	if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
803
+		Q_i = Q->get_Q(i,active_size);
804
+
805
+	for(int j=0;j<active_size;j++)
806
+	{
807
+		if(y[j]==+1)
808
+		{
809
+			if (!is_lower_bound(j))
810
+			{
811
+				double grad_diff=Gmax+G[j];
812
+				if (G[j] >= Gmax2)
813
+					Gmax2 = G[j];
814
+				if (grad_diff > 0)
815
+				{
816
+					double obj_diff; 
817
+					double quad_coef=Q_i[i]+QD[j]-2.0*y[i]*Q_i[j];
818
+					if (quad_coef > 0)
819
+						obj_diff = -(grad_diff*grad_diff)/quad_coef;
820
+					else
821
+						obj_diff = -(grad_diff*grad_diff)/TAU;
822
+
823
+					if (obj_diff <= obj_diff_min)
824
+					{
825
+						Gmin_idx=j;
826
+						obj_diff_min = obj_diff;
827
+					}
828
+				}
829
+			}
830
+		}
831
+		else
832
+		{
833
+			if (!is_upper_bound(j))
834
+			{
835
+				double grad_diff= Gmax-G[j];
836
+				if (-G[j] >= Gmax2)
837
+					Gmax2 = -G[j];
838
+				if (grad_diff > 0)
839
+				{
840
+					double obj_diff; 
841
+					double quad_coef=Q_i[i]+QD[j]+2.0*y[i]*Q_i[j];
842
+					if (quad_coef > 0)
843
+						obj_diff = -(grad_diff*grad_diff)/quad_coef;
844
+					else
845
+						obj_diff = -(grad_diff*grad_diff)/TAU;
846
+
847
+					if (obj_diff <= obj_diff_min)
848
+					{
849
+						Gmin_idx=j;
850
+						obj_diff_min = obj_diff;
851
+					}
852
+				}
853
+			}
854
+		}
855
+	}
856
+
857
+	if(Gmax+Gmax2 < eps)
858
+		return 1;
859
+
860
+	out_i = Gmax_idx;
861
+	out_j = Gmin_idx;
862
+	return 0;
863
+}
864
+
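For readers cross-checking the loops above against Fan, Chen and Lin (JMLR 6, 2005), the selection they implement can be restated as follows (same notation as the comments in this file; Gmax and Gmax2 in the code track the two extrema):

	i \in \arg\max_t \{ -y_t \nabla f(\alpha)_t : t \in I_up(\alpha) \}
	j \in \arg\min_t \{ -b_{it}^2 / a_{it} : t \in I_low(\alpha), -y_t \nabla f(\alpha)_t < -y_i \nabla f(\alpha)_i \}

	with  b_{it} = -y_i \nabla f(\alpha)_i + y_t \nabla f(\alpha)_t > 0
	      a_{it} = K_{ii} + K_{tt} - 2 K_{it}   (replaced by \tau when a_{it} <= 0)

The solver reports optimality (returns 1) when max_{t in I_up} -y_t \nabla f(\alpha)_t - min_{t in I_low} -y_t \nabla f(\alpha)_t < eps, which is the Gmax + Gmax2 < eps test in the code.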
865
+bool Solver::be_shrunk(int i, double Gmax1, double Gmax2)
866
+{
867
+	if(is_upper_bound(i))
868
+	{
869
+		if(y[i]==+1)
870
+			return(-G[i] > Gmax1);
871
+		else
872
+			return(-G[i] > Gmax2);
873
+	}
874
+	else if(is_lower_bound(i))
875
+	{
876
+		if(y[i]==+1)
877
+			return(G[i] > Gmax2);
878
+		else	
879
+			return(G[i] > Gmax1);
880
+	}
881
+	else
882
+		return(false);
883
+}
884
+
885
+void Solver::do_shrinking()
886
+{
887
+	int i;
888
+	double Gmax1 = -INF;		// max { -y_i * grad(f)_i | i in I_up(\alpha) }
889
+	double Gmax2 = -INF;		// max { y_i * grad(f)_i | i in I_low(\alpha) }
890
+
891
+	// find maximal violating pair first
892
+	for(i=0;i<active_size;i++)
893
+	{
894
+		if(y[i]==+1)	
895
+		{
896
+			if(!is_upper_bound(i))	
897
+			{
898
+				if(-G[i] >= Gmax1)
899
+					Gmax1 = -G[i];
900
+			}
901
+			if(!is_lower_bound(i))	
902
+			{
903
+				if(G[i] >= Gmax2)
904
+					Gmax2 = G[i];
905
+			}
906
+		}
907
+		else	
908
+		{
909
+			if(!is_upper_bound(i))	
910
+			{
911
+				if(-G[i] >= Gmax2)
912
+					Gmax2 = -G[i];
913
+			}
914
+			if(!is_lower_bound(i))	
915
+			{
916
+				if(G[i] >= Gmax1)
917
+					Gmax1 = G[i];
918
+			}
919
+		}
920
+	}
921
+
922
+	if(unshrink == false && Gmax1 + Gmax2 <= eps*10) 
923
+	{
924
+		unshrink = true;
925
+		reconstruct_gradient();
926
+		active_size = l;
927
+		info("*"); info_flush();
928
+	}
929
+
930
+	for(i=0;i<active_size;i++)
931
+		if (be_shrunk(i, Gmax1, Gmax2))
932
+		{
933
+			active_size--;
934
+			while (active_size > i)
935
+			{
936
+				if (!be_shrunk(active_size, Gmax1, Gmax2))
937
+				{
938
+					swap_index(i,active_size);
939
+					break;
940
+				}
941
+				active_size--;
942
+			}
943
+		}
944
+}
945
+
946
+double Solver::calculate_rho()
947
+{
948
+	double r;
949
+	int nr_free = 0;
950
+	double ub = INF, lb = -INF, sum_free = 0;
951
+	for(int i=0;i<active_size;i++)
952
+	{
953
+		double yG = y[i]*G[i];
954
+
955
+		if(is_upper_bound(i))
956
+		{
957
+			if(y[i]==-1)
958
+				ub = min(ub,yG);
959
+			else
960
+				lb = max(lb,yG);
961
+		}
962
+		else if(is_lower_bound(i))
963
+		{
964
+			if(y[i]==+1)
965
+				ub = min(ub,yG);
966
+			else
967
+				lb = max(lb,yG);
968
+		}
969
+		else
970
+		{
971
+			++nr_free;
972
+			sum_free += yG;
973
+		}
974
+	}
975
+
976
+	if(nr_free>0)
977
+		r = sum_free/nr_free;
978
+	else
979
+		r = (ub+lb)/2;
980
+
981
+	return r;
982
+}
983
+
984
+//
985
+// Solver for nu-svm classification and regression
986
+//
987
+// additional constraint: e^T \alpha = constant
988
+//
989
+class Solver_NU : public Solver
990
+{
991
+public:
992
+	Solver_NU() {}
993
+	void Solve(int l, const QMatrix& Q, const double *p, const schar *y,
994
+		   double *alpha, double Cp, double Cn, double eps,
995
+		   SolutionInfo* si, int shrinking)
996
+	{
997
+		this->si = si;
998
+		Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
999
+	}
1000
+private:
1001
+	SolutionInfo *si;
1002
+	int select_working_set(int &i, int &j);
1003
+	double calculate_rho();
1004
+	bool be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4);
1005
+	void do_shrinking();
1006
+};
1007
+
1008
+// return 1 if already optimal, return 0 otherwise
1009
+int Solver_NU::select_working_set(int &out_i, int &out_j)
1010
+{
1011
+	// return i,j such that y_i = y_j and
1012
+	// i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
1013
+	// j: minimizes the decrease of obj value
1014
+	//    (if quadratic coefficient <= 0, replace it with tau)
1015
+	//    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
1016
+
1017
+	double Gmaxp = -INF;
1018
+	double Gmaxp2 = -INF;
1019
+	int Gmaxp_idx = -1;
1020
+
1021
+	double Gmaxn = -INF;
1022
+	double Gmaxn2 = -INF;
1023
+	int Gmaxn_idx = -1;
1024
+
1025
+	int Gmin_idx = -1;
1026
+	double obj_diff_min = INF;
1027
+
1028
+	for(int t=0;t<active_size;t++)
1029
+		if(y[t]==+1)
1030
+		{
1031
+			if(!is_upper_bound(t))
1032
+				if(-G[t] >= Gmaxp)
1033
+				{
1034
+					Gmaxp = -G[t];
1035
+					Gmaxp_idx = t;
1036
+				}
1037
+		}
1038
+		else
1039
+		{
1040
+			if(!is_lower_bound(t))
1041
+				if(G[t] >= Gmaxn)
1042
+				{
1043
+					Gmaxn = G[t];
1044
+					Gmaxn_idx = t;
1045
+				}
1046
+		}
1047
+
1048
+	int ip = Gmaxp_idx;
1049
+	int in = Gmaxn_idx;
1050
+	const Qfloat *Q_ip = NULL;
1051
+	const Qfloat *Q_in = NULL;
1052
+	if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1
1053
+		Q_ip = Q->get_Q(ip,active_size);
1054
+	if(in != -1)
1055
+		Q_in = Q->get_Q(in,active_size);
1056
+
1057
+	for(int j=0;j<active_size;j++)
1058
+	{
1059
+		if(y[j]==+1)
1060
+		{
1061
+			if (!is_lower_bound(j))	
1062
+			{
1063
+				double grad_diff=Gmaxp+G[j];
1064
+				if (G[j] >= Gmaxp2)
1065
+					Gmaxp2 = G[j];
1066
+				if (grad_diff > 0)
1067
+				{
1068
+					double obj_diff; 
1069
+					double quad_coef = Q_ip[ip]+QD[j]-2*Q_ip[j];
1070
+					if (quad_coef > 0)
1071
+						obj_diff = -(grad_diff*grad_diff)/quad_coef;
1072
+					else
1073
+						obj_diff = -(grad_diff*grad_diff)/TAU;
1074
+
1075
+					if (obj_diff <= obj_diff_min)
1076
+					{
1077
+						Gmin_idx=j;
1078
+						obj_diff_min = obj_diff;
1079
+					}
1080
+				}
1081
+			}
1082
+		}
1083
+		else
1084
+		{
1085
+			if (!is_upper_bound(j))
1086
+			{
1087
+				double grad_diff=Gmaxn-G[j];
1088
+				if (-G[j] >= Gmaxn2)
1089
+					Gmaxn2 = -G[j];
1090
+				if (grad_diff > 0)
1091
+				{
1092
+					double obj_diff; 
1093
+					double quad_coef = Q_in[in]+QD[j]-2*Q_in[j];
1094
+					if (quad_coef > 0)
1095
+						obj_diff = -(grad_diff*grad_diff)/quad_coef;
1096
+					else
1097
+						obj_diff = -(grad_diff*grad_diff)/TAU;
1098
+
1099
+					if (obj_diff <= obj_diff_min)
1100
+					{
1101
+						Gmin_idx=j;
1102
+						obj_diff_min = obj_diff;
1103
+					}
1104
+				}
1105
+			}
1106
+		}
1107
+	}
1108
+
1109
+	if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps)
1110
+ 		return 1;
1111
+
1112
+	if (y[Gmin_idx] == +1)
1113
+		out_i = Gmaxp_idx;
1114
+	else
1115
+		out_i = Gmaxn_idx;
1116
+	out_j = Gmin_idx;
1117
+
1118
+	return 0;
1119
+}
1120
+
1121
+bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
1122
+{
1123
+	if(is_upper_bound(i))
1124
+	{
1125
+		if(y[i]==+1)
1126
+			return(-G[i] > Gmax1);
1127
+		else	
1128
+			return(-G[i] > Gmax4);
1129
+	}
1130
+	else if(is_lower_bound(i))
1131
+	{
1132
+		if(y[i]==+1)
1133
+			return(G[i] > Gmax2);
1134
+		else	
1135
+			return(G[i] > Gmax3);
1136
+	}
1137
+	else
1138
+		return(false);
1139
+}
1140
+
1141
+void Solver_NU::do_shrinking()
1142
+{
1143
+	double Gmax1 = -INF;	// max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
1144
+	double Gmax2 = -INF;	// max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
1145
+	double Gmax3 = -INF;	// max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
1146
+	double Gmax4 = -INF;	// max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
1147
+
1148
+	// find maximal violating pair first
1149
+	int i;
1150
+	for(i=0;i<active_size;i++)
1151
+	{
1152
+		if(!is_upper_bound(i))
1153
+		{
1154
+			if(y[i]==+1)
1155
+			{
1156
+				if(-G[i] > Gmax1) Gmax1 = -G[i];
1157
+			}
1158
+			else	if(-G[i] > Gmax4) Gmax4 = -G[i];
1159
+		}
1160
+		if(!is_lower_bound(i))
1161
+		{
1162
+			if(y[i]==+1)
1163
+			{	
1164
+				if(G[i] > Gmax2) Gmax2 = G[i];
1165
+			}
1166
+			else	if(G[i] > Gmax3) Gmax3 = G[i];
1167
+		}
1168
+	}
1169
+
1170
+	if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) 
1171
+	{
1172
+		unshrink = true;
1173
+		reconstruct_gradient();
1174
+		active_size = l;
1175
+	}
1176
+
1177
+	for(i=0;i<active_size;i++)
1178
+		if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
1179
+		{
1180
+			active_size--;
1181
+			while (active_size > i)
1182
+			{
1183
+				if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
1184
+				{
1185
+					swap_index(i,active_size);
1186
+					break;
1187
+				}
1188
+				active_size--;
1189
+			}
1190
+		}
1191
+}
1192
+
1193
+double Solver_NU::calculate_rho()
1194
+{
1195
+	int nr_free1 = 0,nr_free2 = 0;
1196
+	double ub1 = INF, ub2 = INF;
1197
+	double lb1 = -INF, lb2 = -INF;
1198
+	double sum_free1 = 0, sum_free2 = 0;
1199
+
1200
+	for(int i=0;i<active_size;i++)
1201
+	{
1202
+		if(y[i]==+1)
1203
+		{
1204
+			if(is_upper_bound(i))
1205
+				lb1 = max(lb1,G[i]);
1206
+			else if(is_lower_bound(i))
1207
+				ub1 = min(ub1,G[i]);
1208
+			else
1209
+			{
1210
+				++nr_free1;
1211
+				sum_free1 += G[i];
1212
+			}
1213
+		}
1214
+		else
1215
+		{
1216
+			if(is_upper_bound(i))
1217
+				lb2 = max(lb2,G[i]);
1218
+			else if(is_lower_bound(i))
1219
+				ub2 = min(ub2,G[i]);
1220
+			else
1221
+			{
1222
+				++nr_free2;
1223
+				sum_free2 += G[i];
1224
+			}
1225
+		}
1226
+	}
1227
+
1228
+	double r1,r2;
1229
+	if(nr_free1 > 0)
1230
+		r1 = sum_free1/nr_free1;
1231
+	else
1232
+		r1 = (ub1+lb1)/2;
1233
+	
1234
+	if(nr_free2 > 0)
1235
+		r2 = sum_free2/nr_free2;
1236
+	else
1237
+		r2 = (ub2+lb2)/2;
1238
+	
1239
+	si->r = (r1+r2)/2;
1240
+	return (r1-r2)/2;
1241
+}
1242
+
1243
+//
1244
+// Q matrices for various formulations
1245
+//
1246
+class SVC_Q: public Kernel
1247
+{ 
1248
+public:
1249
+	SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
1250
+	:Kernel(prob.l, prob.x, param)
1251
+	{
1252
+		clone(y,y_,prob.l);
1253
+		cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1254
+		QD = new Qfloat[prob.l];
1255
+		for(int i=0;i<prob.l;i++)
1256
+			QD[i]= (Qfloat)(this->*kernel_function)(i,i);
1257
+	}
1258
+	
1259
+	Qfloat *get_Q(int i, int len) const
1260
+	{
1261
+		Qfloat *data;
1262
+		int start, j;
1263
+		if((start = cache->get_data(i,&data,len)) < len)
1264
+		{
1265
+			for(j=start;j<len;j++)
1266
+				data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
1267
+		}
1268
+		return data;
1269
+	}
1270
+
1271
+	Qfloat *get_QD() const
1272
+	{
1273
+		return QD;
1274
+	}
1275
+
1276
+	void swap_index(int i, int j) const
1277
+	{
1278
+		cache->swap_index(i,j);
1279
+		Kernel::swap_index(i,j);
1280
+		swap(y[i],y[j]);
1281
+		swap(QD[i],QD[j]);
1282
+	}
1283
+
1284
+	~SVC_Q()
1285
+	{
1286
+		delete[] y;
1287
+		delete cache;
1288
+		delete[] QD;
1289
+	}
1290
+private:
1291
+	schar *y;
1292
+	Cache *cache;
1293
+	Qfloat *QD;
1294
+};
1295
+
1296
+class ONE_CLASS_Q: public Kernel
1297
+{
1298
+public:
1299
+	ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
1300
+	:Kernel(prob.l, prob.x, param)
1301
+	{
1302
+		cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1303
+		QD = new Qfloat[prob.l];
1304
+		for(int i=0;i<prob.l;i++)
1305
+			QD[i]= (Qfloat)(this->*kernel_function)(i,i);
1306
+	}
1307
+	
1308
+	Qfloat *get_Q(int i, int len) const
1309
+	{
1310
+		Qfloat *data;
1311
+		int start, j;
1312
+		if((start = cache->get_data(i,&data,len)) < len)
1313
+		{
1314
+			for(j=start;j<len;j++)
1315
+				data[j] = (Qfloat)(this->*kernel_function)(i,j);
1316
+		}
1317
+		return data;
1318
+	}
1319
+
1320
+	Qfloat *get_QD() const
1321
+	{
1322
+		return QD;
1323
+	}
1324
+
1325
+	void swap_index(int i, int j) const
1326
+	{
1327
+		cache->swap_index(i,j);
1328
+		Kernel::swap_index(i,j);
1329
+		swap(QD[i],QD[j]);
1330
+	}
1331
+
1332
+	~ONE_CLASS_Q()
1333
+	{
1334
+		delete cache;
1335
+		delete[] QD;
1336
+	}
1337
+private:
1338
+	Cache *cache;
1339
+	Qfloat *QD;
1340
+};
1341
+
1342
+class SVR_Q: public Kernel
1343
+{ 
1344
+public:
1345
+	SVR_Q(const svm_problem& prob, const svm_parameter& param)
1346
+	:Kernel(prob.l, prob.x, param)
1347
+	{
1348
+		l = prob.l;
1349
+		cache = new Cache(l,(long int)(param.cache_size*(1<<20)));
1350
+		QD = new Qfloat[2*l];
1351
+		sign = new schar[2*l];
1352
+		index = new int[2*l];
1353
+		for(int k=0;k<l;k++)
1354
+		{
1355
+			sign[k] = 1;
1356
+			sign[k+l] = -1;
1357
+			index[k] = k;
1358
+			index[k+l] = k;
1359
+			QD[k]= (Qfloat)(this->*kernel_function)(k,k);
1360
+			QD[k+l]=QD[k];
1361
+		}
1362
+		buffer[0] = new Qfloat[2*l];
1363
+		buffer[1] = new Qfloat[2*l];
1364
+		next_buffer = 0;
1365
+	}
1366
+
1367
+	void swap_index(int i, int j) const
1368
+	{
1369
+		swap(sign[i],sign[j]);
1370
+		swap(index[i],index[j]);
1371
+		swap(QD[i],QD[j]);
1372
+	}
1373
+	
1374
+	Qfloat *get_Q(int i, int len) const
1375
+	{
1376
+		Qfloat *data;
1377
+		int j, real_i = index[i];
1378
+		if(cache->get_data(real_i,&data,l) < l)
1379
+		{
1380
+			for(j=0;j<l;j++)
1381
+				data[j] = (Qfloat)(this->*kernel_function)(real_i,j);
1382
+		}
1383
+
1384
+		// reorder and copy
1385
+		Qfloat *buf = buffer[next_buffer];
1386
+		next_buffer = 1 - next_buffer;
1387
+		schar si = sign[i];
1388
+		for(int j=0;j<len;j++)
1389
+			buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]];
1390
+		return buf;
1391
+	}
1392
+
1393
+	Qfloat *get_QD() const
1394
+	{
1395
+		return QD;
1396
+	}
1397
+
1398
+	~SVR_Q()
1399
+	{
1400
+		delete cache;
1401
+		delete[] sign;
1402
+		delete[] index;
1403
+		delete[] buffer[0];
1404
+		delete[] buffer[1];
1405
+		delete[] QD;
1406
+	}
1407
+private:
1408
+	int l;
1409
+	Cache *cache;
1410
+	schar *sign;
1411
+	int *index;
1412
+	mutable int next_buffer;
1413
+	Qfloat *buffer[2];
1414
+	Qfloat *QD;
1415
+};
1416
+
1417
+//
1418
+// construct and solve various formulations
1419
+//
1420
+static void solve_c_svc(
1421
+	const svm_problem *prob, const svm_parameter* param,
1422
+	double *alpha, Solver::SolutionInfo* si, double Cp, double Cn)
1423
+{
1424
+	int l = prob->l;
1425
+	double *minus_ones = new double[l];
1426
+	schar *y = new schar[l];
1427
+
1428
+	int i;
1429
+
1430
+	for(i=0;i<l;i++)
1431
+	{
1432
+		alpha[i] = 0;
1433
+		minus_ones[i] = -1;
1434
+		if(prob->y[i] > 0) y[i] = +1; else y[i]=-1;
1435
+	}
1436
+
1437
+	Solver s;
1438
+	s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
1439
+		alpha, Cp, Cn, param->eps, si, param->shrinking);
1440
+
1441
+	double sum_alpha=0;
1442
+	for(i=0;i<l;i++)
1443
+		sum_alpha += alpha[i];
1444
+
1445
+	if (Cp==Cn)
1446
+		info("nu = %f\n", sum_alpha/(Cp*prob->l));
1447
+
1448
+	for(i=0;i<l;i++)
1449
+		alpha[i] *= y[i];
1450
+
1451
+	delete[] minus_ones;
1452
+	delete[] y;
1453
+}
1454
+
1455
+static void solve_nu_svc(
1456
+	const svm_problem *prob, const svm_parameter *param,
1457
+	double *alpha, Solver::SolutionInfo* si)
1458
+{
1459
+	int i;
1460
+	int l = prob->l;
1461
+	double nu = param->nu;
1462
+
1463
+	schar *y = new schar[l];
1464
+
1465
+	for(i=0;i<l;i++)
1466
+		if(prob->y[i]>0)
1467
+			y[i] = +1;
1468
+		else
1469
+			y[i] = -1;
1470
+
1471
+	double sum_pos = nu*l/2;
1472
+	double sum_neg = nu*l/2;
1473
+
1474
+	for(i=0;i<l;i++)
1475
+		if(y[i] == +1)
1476
+		{
1477
+			alpha[i] = min(1.0,sum_pos);
1478
+			sum_pos -= alpha[i];
1479
+		}
1480
+		else
1481
+		{
1482
+			alpha[i] = min(1.0,sum_neg);
1483
+			sum_neg -= alpha[i];
1484
+		}
1485
+
1486
+	double *zeros = new double[l];
1487
+
1488
+	for(i=0;i<l;i++)
1489
+		zeros[i] = 0;
1490
+
1491
+	Solver_NU s;
1492
+	s.Solve(l, SVC_Q(*prob,*param,y), zeros, y,
1493
+		alpha, 1.0, 1.0, param->eps, si,  param->shrinking);
1494
+	double r = si->r;
1495
+
1496
+	info("C = %f\n",1/r);
1497
+
1498
+	for(i=0;i<l;i++)
1499
+		alpha[i] *= y[i]/r;
1500
+
1501
+	si->rho /= r;
1502
+	si->obj /= (r*r);
1503
+	si->upper_bound_p = 1/r;
1504
+	si->upper_bound_n = 1/r;
1505
+
1506
+	delete[] y;
1507
+	delete[] zeros;
1508
+}
1509
+
1510
+static void solve_one_class(
1511
+	const svm_problem *prob, const svm_parameter *param,
1512
+	double *alpha, Solver::SolutionInfo* si)
1513
+{
1514
+	int l = prob->l;
1515
+	double *zeros = new double[l];
1516
+	schar *ones = new schar[l];
1517
+	int i;
1518
+
1519
+	int n = (int)(param->nu*prob->l);	// # of alpha's at upper bound
1520
+
1521
+	for(i=0;i<n;i++)
1522
+		alpha[i] = 1;
1523
+	if(n<prob->l)
1524
+		alpha[n] = param->nu * prob->l - n;
1525
+	for(i=n+1;i<l;i++)
1526
+		alpha[i] = 0;
1527
+
1528
+	for(i=0;i<l;i++)
1529
+	{
1530
+		zeros[i] = 0;
1531
+		ones[i] = 1;
1532
+	}
1533
+
1534
+	Solver s;
1535
+	s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones,
1536
+		alpha, 1.0, 1.0, param->eps, si, param->shrinking);
1537
+
1538
+	delete[] zeros;
1539
+	delete[] ones;
1540
+}
1541
+
1542
+static void solve_epsilon_svr(
1543
+	const svm_problem *prob, const svm_parameter *param,
1544
+	double *alpha, Solver::SolutionInfo* si)
1545
+{
1546
+	int l = prob->l;
1547
+	double *alpha2 = new double[2*l];
1548
+	double *linear_term = new double[2*l];
1549
+	schar *y = new schar[2*l];
1550
+	int i;
1551
+
1552
+	for(i=0;i<l;i++)
1553
+	{
1554
+		alpha2[i] = 0;
1555
+		linear_term[i] = param->p - prob->y[i];
1556
+		y[i] = 1;
1557
+
1558
+		alpha2[i+l] = 0;
1559
+		linear_term[i+l] = param->p + prob->y[i];
1560
+		y[i+l] = -1;
1561
+	}
1562
+
1563
+	Solver s;
1564
+	s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1565
+		alpha2, param->C, param->C, param->eps, si, param->shrinking);
1566
+
1567
+	double sum_alpha = 0;
1568
+	for(i=0;i<l;i++)
1569
+	{
1570
+		alpha[i] = alpha2[i] - alpha2[i+l];
1571
+		sum_alpha += fabs(alpha[i]);
1572
+	}
1573
+	info("nu = %f\n",sum_alpha/(param->C*l));
1574
+
1575
+	delete[] alpha2;
1576
+	delete[] linear_term;
1577
+	delete[] y;
1578
+}
1579
+
1580
+static void solve_nu_svr(
1581
+	const svm_problem *prob, const svm_parameter *param,
1582
+	double *alpha, Solver::SolutionInfo* si)
1583
+{
1584
+	int l = prob->l;
1585
+	double C = param->C;
1586
+	double *alpha2 = new double[2*l];
1587
+	double *linear_term = new double[2*l];
1588
+	schar *y = new schar[2*l];
1589
+	int i;
1590
+
1591
+	double sum = C * param->nu * l / 2;
1592
+	for(i=0;i<l;i++)
1593
+	{
1594
+		alpha2[i] = alpha2[i+l] = min(sum,C);
1595
+		sum -= alpha2[i];
1596
+
1597
+		linear_term[i] = - prob->y[i];
1598
+		y[i] = 1;
1599
+
1600
+		linear_term[i+l] = prob->y[i];
1601
+		y[i+l] = -1;
1602
+	}
1603
+
1604
+	Solver_NU s;
1605
+	s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1606
+		alpha2, C, C, param->eps, si, param->shrinking);
1607
+
1608
+	info("epsilon = %f\n",-si->r);
1609
+
1610
+	for(i=0;i<l;i++)
1611
+		alpha[i] = alpha2[i] - alpha2[i+l];
1612
+
1613
+	delete[] alpha2;
1614
+	delete[] linear_term;
1615
+	delete[] y;
1616
+}
1617
+
1618
+//
1619
+// decision_function
1620
+//
1621
+struct decision_function
1622
+{
1623
+	double *alpha;
1624
+	double rho;	
1625
+};
1626
+
1627
+decision_function svm_train_one(
1628
+	const svm_problem *prob, const svm_parameter *param,
1629
+	double Cp, double Cn)
1630
+{
1631
+	double *alpha = Malloc(double,prob->l);
1632
+	Solver::SolutionInfo si;
1633
+	switch(param->svm_type)
1634
+	{
1635
+		case C_SVC:
1636
+			solve_c_svc(prob,param,alpha,&si,Cp,Cn);
1637
+			break;
1638
+		case NU_SVC:
1639
+			solve_nu_svc(prob,param,alpha,&si);
1640
+			break;
1641
+		case ONE_CLASS:
1642
+			solve_one_class(prob,param,alpha,&si);
1643
+			break;
1644
+		case EPSILON_SVR:
1645
+			solve_epsilon_svr(prob,param,alpha,&si);
1646
+			break;
1647
+		case NU_SVR:
1648
+			solve_nu_svr(prob,param,alpha,&si);
1649
+			break;
1650
+	}
1651
+
1652
+	info("obj = %f, rho = %f\n",si.obj,si.rho);
1653
+
1654
+	// output SVs
1655
+
1656
+	int nSV = 0;
1657
+	int nBSV = 0;
1658
+	for(int i=0;i<prob->l;i++)
1659
+	{
1660
+		if(fabs(alpha[i]) > 0)
1661
+		{
1662
+			++nSV;
1663
+			if(prob->y[i] > 0)
1664
+			{
1665
+				if(fabs(alpha[i]) >= si.upper_bound_p)
1666
+					++nBSV;
1667
+			}
1668
+			else
1669
+			{
1670
+				if(fabs(alpha[i]) >= si.upper_bound_n)
1671
+					++nBSV;
1672
+			}
1673
+		}
1674
+	}
1675
+
1676
+	info("nSV = %d, nBSV = %d\n",nSV,nBSV);
1677
+
1678
+	decision_function f;
1679
+	f.alpha = alpha;
1680
+	f.rho = si.rho;
1681
+	return f;
1682
+}
1683
+
1684
+/*
1685
+//
1686
+// svm_model
1687
+//
1688
+struct svm_model
1689
+{
1690
+	svm_parameter param;	// parameter
1691
+	int nr_class;		// number of classes, = 2 in regression/one class svm
1692
+	int l;			// total #SV
1693
+	svm_node **SV;		// SVs (SV[l])
1694
+	double **sv_coef;	// coefficients for SVs in decision functions (sv_coef[k-1][l])
1695
+	double *rho;		// constants in decision functions (rho[k*(k-1)/2])
1696
+	double *probA;          // pairwise probability information
1697
+	double *probB;
1698
+
1699
+	// for classification only
1700
+
1701
+	int *label;		// label of each class (label[k])
1702
+	int *nSV;		// number of SVs for each class (nSV[k])
1703
+				// nSV[0] + nSV[1] + ... + nSV[k-1] = l
1704
+	// XXX
1705
+	int free_sv;		// 1 if svm_model is created by svm_load_model
1706
+				// 0 if svm_model is created by svm_train
1707
+};
1708
+*/
1709
+// Platt's binary SVM probabilistic output, using the improved implementation of Lin et al.
1710
+void sigmoid_train(
1711
+	int l, const double *dec_values, const double *labels, 
1712
+	double& A, double& B)
1713
+{
1714
+	double prior1=0, prior0 = 0;
1715
+	int i;
1716
+
1717
+	for (i=0;i<l;i++)
1718
+		if (labels[i] > 0) prior1+=1;
1719
+		else prior0+=1;
1720
+	
1721
+	int max_iter=100; 	// Maximal number of iterations
1722
+	double min_step=1e-10;	// Minimal step taken in line search
1723
+	double sigma=1e-12;	// For numerically strict PD of Hessian
1724
+	double eps=1e-5;
1725
+	double hiTarget=(prior1+1.0)/(prior1+2.0);
1726
+	double loTarget=1/(prior0+2.0);
1727
+	double *t=Malloc(double,l);
1728
+	double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
1729
+	double newA,newB,newf,d1,d2;
1730
+	int iter; 
1731
+	
1732
+	// Initial Point and Initial Fun Value
1733
+	A=0.0; B=log((prior0+1.0)/(prior1+1.0));
1734
+	double fval = 0.0;
1735
+
1736
+	for (i=0;i<l;i++)
1737
+	{
1738
+		if (labels[i]>0) t[i]=hiTarget;
1739
+		else t[i]=loTarget;
1740
+		fApB = dec_values[i]*A+B;
1741
+		if (fApB>=0)
1742
+			fval += t[i]*fApB + log(1+exp(-fApB));
1743
+		else
1744
+			fval += (t[i] - 1)*fApB +log(1+exp(fApB));
1745
+	}
1746
+	for (iter=0;iter<max_iter;iter++)
1747
+	{
1748
+		// Update Gradient and Hessian (use H' = H + sigma I)
1749
+		h11=sigma; // numerically ensures strict PD
1750
+		h22=sigma;
1751
+		h21=0.0;g1=0.0;g2=0.0;
1752
+		for (i=0;i<l;i++)
1753
+		{
1754
+			fApB = dec_values[i]*A+B;
1755
+			if (fApB >= 0)
1756
+			{
1757
+				p=exp(-fApB)/(1.0+exp(-fApB));
1758
+				q=1.0/(1.0+exp(-fApB));
1759
+			}
1760
+			else
1761
+			{
1762
+				p=1.0/(1.0+exp(fApB));
1763
+				q=exp(fApB)/(1.0+exp(fApB));
1764
+			}
1765
+			d2=p*q;
1766
+			h11+=dec_values[i]*dec_values[i]*d2;
1767
+			h22+=d2;
1768
+			h21+=dec_values[i]*d2;
1769
+			d1=t[i]-p;
1770
+			g1+=dec_values[i]*d1;
1771
+			g2+=d1;
1772
+		}
1773
+
1774
+		// Stopping Criteria
1775
+		if (fabs(g1)<eps && fabs(g2)<eps)
1776
+			break;
1777
+
1778
+		// Finding Newton direction: -inv(H') * g
1779
+		det=h11*h22-h21*h21;
1780
+		dA=-(h22*g1 - h21 * g2) / det;
1781
+		dB=-(-h21*g1+ h11 * g2) / det;
1782
+		gd=g1*dA+g2*dB;
1783
+
1784
+
1785
+		stepsize = 1; 		// Line Search
1786
+		while (stepsize >= min_step)
1787
+		{
1788
+			newA = A + stepsize * dA;
1789
+			newB = B + stepsize * dB;
1790
+
1791
+			// New function value
1792
+			newf = 0.0;
1793
+			for (i=0;i<l;i++)
1794
+			{
1795
+				fApB = dec_values[i]*newA+newB;
1796
+				if (fApB >= 0)
1797
+					newf += t[i]*fApB + log(1+exp(-fApB));
1798
+				else
1799
+					newf += (t[i] - 1)*fApB +log(1+exp(fApB));
1800
+			}
1801
+			// Check sufficient decrease
1802
+			if (newf<fval+0.0001*stepsize*gd)
1803
+			{
1804
+				A=newA;B=newB;fval=newf;
1805
+				break;
1806
+			}
1807
+			else
1808
+				stepsize = stepsize / 2.0;
1809
+		}
1810
+
1811
+		if (stepsize < min_step)
1812
+		{
1813
+			info("Line search fails in two-class probability estimates\n");
1814
+			break;
1815
+		}
1816
+	}
1817
+
1818
+	if (iter>=max_iter)
1819
+		info("Reaching maximal iterations in two-class probability estimates\n");
1820
+	free(t);
1821
+}
1822
+
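In equation form, sigmoid_train fits Platt's sigmoid to the decision values f_i by minimizing a cross-entropy with regularized targets; the overflow-safe evaluation of log(1 + e^{+/-(A f_i + B)}) is the implementation improvement credited to Lin et al. in the comment above. With N_+ and N_- the class counts:

	P(y = 1 | f) = 1 / (1 + \exp(A f + B))

	\min_{A,B}  - \sum_i [ t_i \log p_i + (1 - t_i) \log(1 - p_i) ],   p_i = 1 / (1 + \exp(A f_i + B))

	t_i = (N_+ + 1) / (N_+ + 2)  if y_i = +1,      t_i = 1 / (N_- + 2)  if y_i = -1

A and B are found by the damped Newton iteration with backtracking line search shown above, and sigmoid_predict below evaluates the fitted sigmoid.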
1823
+double sigmoid_predict(double decision_value, double A, double B)
1824
+{
1825
+	double fApB = decision_value*A+B;
1826
+	if (fApB >= 0)
1827
+		return exp(-fApB)/(1.0+exp(-fApB));
1828
+	else
1829
+		return 1.0/(1+exp(fApB)) ;
1830
+}
1831
+
1832
+// Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
1833
+void multiclass_probability(int k, double **r, double *p)
1834
+{
1835
+	int t,j;
1836
+	int iter = 0, max_iter=max(100,k);
1837
+	double **Q=Malloc(double *,k);
1838
+	double *Qp=Malloc(double,k);
1839
+	double pQp, eps=0.005/k;
1840
+	
1841
+	for (t=0;t<k;t++)
1842
+	{
1843
+		p[t]=1.0/k;  // Valid if k = 1
1844
+		Q[t]=Malloc(double,k);
1845
+		Q[t][t]=0;
1846
+		for (j=0;j<t;j++)
1847
+		{
1848
+			Q[t][t]+=r[j][t]*r[j][t];
1849
+			Q[t][j]=Q[j][t];
1850
+		}
1851
+		for (j=t+1;j<k;j++)
1852
+		{
1853
+			Q[t][t]+=r[j][t]*r[j][t];
1854
+			Q[t][j]=-r[j][t]*r[t][j];
1855
+		}
1856
+	}
1857
+	for (iter=0;iter<max_iter;iter++)
1858
+	{
1859
+		// stopping condition, recalculate QP,pQP for numerical accuracy
1860
+		pQp=0;
1861
+		for (t=0;t<k;t++)
1862
+		{
1863
+			Qp[t]=0;
1864
+			for (j=0;j<k;j++)
1865
+				Qp[t]+=Q[t][j]*p[j];
1866
+			pQp+=p[t]*Qp[t];
1867
+		}
1868
+		double max_error=0;
1869
+		for (t=0;t<k;t++)
1870
+		{
1871
+			double error=fabs(Qp[t]-pQp);
1872
+			if (error>max_error)
1873
+				max_error=error;
1874
+		}
1875
+		if (max_error<eps) break;
1876
+		
1877
+		for (t=0;t<k;t++)
1878
+		{
1879
+			double diff=(-Qp[t]+pQp)/Q[t][t];
1880
+			p[t]+=diff;
1881
+			pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
1882
+			for (j=0;j<k;j++)
1883
+			{
1884
+				Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
1885
+				p[j]/=(1+diff);
1886
+			}
1887
+		}
1888
+	}
1889
+	if (iter>=max_iter)
1890
+		info("Exceeds max_iter in multiclass_prob\n");
1891
+	for(t=0;t<k;t++) free(Q[t]);
1892
+	free(Q);
1893
+	free(Qp);
1894
+}
1895
+
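The coupling problem that multiclass_probability solves is the "second approach" of the Wu, Lin and Weng paper cited above, restated here for orientation (r_{ij} are the pairwise probability estimates, p the class probabilities):

	\min_p  \frac{1}{2} \sum_{t=1}^{k} \sum_{j \ne t} ( r_{jt} p_t - r_{tj} p_j )^2
	subject to  \sum_{t=1}^{k} p_t = 1,  p_t \ge 0

The matrix built at the top of the function is the quadratic form of this objective, Q_{tt} = \sum_{j \ne t} r_{jt}^2 and Q_{tj} = -r_{jt} r_{tj}; at an interior optimum (all p_t > 0) the value (Qp)_t is the same for every t, which is what the max_error test checks before applying the coordinate-wise update of p.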
1896
+// Cross-validation decision values for probability estimates
1897
+void svm_binary_svc_probability(
1898
+	const svm_problem *prob, const svm_parameter *param,
1899
+	double Cp, double Cn, double& probA, double& probB)
1900
+{
1901
+	int i;
1902
+	int nr_fold = 5;
1903
+	int *perm = Malloc(int,prob->l);
1904
+	double *dec_values = Malloc(double,prob->l);
1905
+
1906
+	// random shuffle
1907
+	for(i=0;i<prob->l;i++) perm[i]=i;
1908
+	for(i=0;i<prob->l;i++)
1909
+	{
1910
+		int j = i+rand()%(prob->l-i);
1911
+		swap(perm[i],perm[j]);
1912
+	}
1913
+	for(i=0;i<nr_fold;i++)
1914
+	{
1915
+		int begin = i*prob->l/nr_fold;
1916
+		int end = (i+1)*prob->l/nr_fold;
1917
+		int j,k;
1918
+		struct svm_problem subprob;
1919
+
1920
+		subprob.l = prob->l-(end-begin);
1921
+		subprob.x = Malloc(struct svm_node*,subprob.l);
1922
+		subprob.y = Malloc(double,subprob.l);
1923
+			
1924
+		k=0;
1925
+		for(j=0;j<begin;j++)
1926
+		{
1927
+			subprob.x[k] = prob->x[perm[j]];
1928
+			subprob.y[k] = prob->y[perm[j]];
1929
+			++k;
1930
+		}
1931
+		for(j=end;j<prob->l;j++)
1932
+		{
1933
+			subprob.x[k] = prob->x[perm[j]];
1934
+			subprob.y[k] = prob->y[perm[j]];
1935
+			++k;
1936
+		}
1937
+		int p_count=0,n_count=0;
1938
+		for(j=0;j<k;j++)
1939
+			if(subprob.y[j]>0)
1940
+				p_count++;
1941
+			else
1942
+				n_count++;
1943
+
1944
+		if(p_count==0 && n_count==0)
1945
+			for(j=begin;j<end;j++)
1946
+				dec_values[perm[j]] = 0;
1947
+		else if(p_count > 0 && n_count == 0)
1948
+			for(j=begin;j<end;j++)
1949
+				dec_values[perm[j]] = 1;
1950
+		else if(p_count == 0 && n_count > 0)
1951
+			for(j=begin;j<end;j++)
1952
+				dec_values[perm[j]] = -1;
1953
+		else
1954
+		{
1955
+			svm_parameter subparam = *param;
1956
+			subparam.probability=0;
1957
+			subparam.C=1.0;
1958
+			subparam.nr_weight=2;
1959
+			subparam.weight_label = Malloc(int,2);
1960
+			subparam.weight = Malloc(double,2);
1961
+			subparam.weight_label[0]=+1;
1962
+			subparam.weight_label[1]=-1;
1963
+			subparam.weight[0]=Cp;
1964
+			subparam.weight[1]=Cn;
1965
+			struct svm_model *submodel = svm_train(&subprob,&subparam);
1966
+			for(j=begin;j<end;j++)
1967
+			{
1968
+				svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]])); 
1969
+				// ensure +1/-1 ordering of the decision values; this is why the generic CV subroutine is not reused here
1970
+				dec_values[perm[j]] *= submodel->label[0];
1971
+			}		
1972
+			svm_destroy_model(submodel);
1973
+			svm_destroy_param(&subparam);
1974
+		}
1975
+		free(subprob.x);
1976
+		free(subprob.y);
1977
+	}		
1978
+	sigmoid_train(prob->l,dec_values,prob->y,probA,probB);
1979
+	free(dec_values);
1980
+	free(perm);
1981
+}
1982
+
1983
+// Return parameter of a Laplace distribution 
1984
+double svm_svr_probability(
1985
+	const svm_problem *prob, const svm_parameter *param)
1986
+{
1987
+	int i;
1988
+	int nr_fold = 5;
1989
+	double *ymv = Malloc(double,prob->l);
1990
+	double mae = 0;
1991
+
1992
+	svm_parameter newparam = *param;
1993
+	newparam.probability = 0;
1994
+	svm_cross_validation(prob,&newparam,nr_fold,ymv);
1995
+	for(i=0;i<prob->l;i++)
1996
+	{
1997
+		ymv[i]=prob->y[i]-ymv[i];
1998
+		mae += fabs(ymv[i]);
1999
+	}		
2000
+	mae /= prob->l;
2001
+	double std=sqrt(2*mae*mae);
2002
+	int count=0;
2003
+	mae=0;
2004
+	for(i=0;i<prob->l;i++)
2005
+		if (fabs(ymv[i]) > 5*std)
2006
+			count=count+1;
2007
+		else
2008
+			mae+=fabs(ymv[i]);
2009
+	mae /= (prob->l-count);
2010
+	info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae);
2011
+	free(ymv);
2012
+	return mae;
2013
+}
2014
+
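The "parameter of a Laplace distribution" returned above is the scale sigma quoted in the info message; the mean absolute cross-validation residual is used because, for residuals z_i = y_i - \hat{y}_i drawn from the Laplace density e^{-|z|/\sigma} / (2\sigma), the maximum-likelihood estimate of the scale is

	\hat{\sigma} = \frac{1}{l} \sum_{i=1}^{l} |z_i|

Residuals beyond 5 * sqrt(2) * mae (about five standard deviations of the first-pass fit, since a Laplace with scale \sigma has standard deviation \sigma \sqrt{2}) are discarded before the final mean, which keeps a few gross outliers from inflating the estimate.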
2015
+
2016
+// label: the distinct class labels, start: start index of each class, count: #data in each class, perm: indices into the original data
2017
+// perm, length l, must be allocated before calling this subroutine
2018
+void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
2019
+{
2020
+	int l = prob->l;
2021
+	int max_nr_class = 16;
2022
+	int nr_class = 0;
2023
+	int *label = Malloc(int,max_nr_class);
2024
+	int *count = Malloc(int,max_nr_class);
2025
+	int *data_label = Malloc(int,l);	
2026
+	int i;
2027
+
2028
+	for(i=0;i<l;i++)
2029
+	{
2030
+		int this_label = (int)prob->y[i];
2031
+		int j;
2032
+		for(j=0;j<nr_class;j++)
2033
+		{
2034
+			if(this_label == label[j])
2035
+			{
2036
+				++count[j];
2037
+				break;
2038
+			}
2039
+		}
2040
+		data_label[i] = j;
2041
+		if(j == nr_class)
2042
+		{
2043
+			if(nr_class == max_nr_class)
2044
+			{
2045
+				max_nr_class *= 2;
2046
+				label = (int *)realloc(label,max_nr_class*sizeof(int));
2047
+				count = (int *)realloc(count,max_nr_class*sizeof(int));
2048
+			}
2049
+			label[nr_class] = this_label;
2050
+			count[nr_class] = 1;
2051
+			++nr_class;
2052
+		}
2053
+	}
2054
+
2055
+	int *start = Malloc(int,nr_class);
2056
+	start[0] = 0;
2057
+	for(i=1;i<nr_class;i++)
2058
+		start[i] = start[i-1]+count[i-1];
2059
+	for(i=0;i<l;i++)
2060
+	{
2061
+		perm[start[data_label[i]]] = i;
2062
+		++start[data_label[i]];
2063
+	}
2064
+	start[0] = 0;
2065
+	for(i=1;i<nr_class;i++)
2066
+		start[i] = start[i-1]+count[i-1];
2067
+
2068
+	*nr_class_ret = nr_class;
2069
+	*label_ret = label;
2070
+	*start_ret = start;
2071
+	*count_ret = count;
2072
+	free(data_label);
2073
+}
2074
+
2075
+//
2076
+// Interface functions
2077
+//
2078
+svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
2079
+{
2080
+	svm_model *model = Malloc(svm_model,1);
2081
+	model->param = *param;
2082
+	model->free_sv = 0;	// XXX
2083
+
2084
+	if(param->svm_type == ONE_CLASS ||
2085
+	   param->svm_type == EPSILON_SVR ||
2086
+	   param->svm_type == NU_SVR)
2087
+	{
2088
+		// regression or one-class-svm
2089
+		model->nr_class = 2;
2090
+		model->label = NULL;
2091
+		model->nSV = NULL;
2092
+		model->probA = NULL; model->probB = NULL;
2093
+		model->sv_coef = Malloc(double *,1);
2094
+
2095
+		if(param->probability && 
2096
+		   (param->svm_type == EPSILON_SVR ||
2097
+		    param->svm_type == NU_SVR))
2098
+		{
2099
+			model->probA = Malloc(double,1);
2100
+			model->probA[0] = svm_svr_probability(prob,param);
2101
+		}
2102
+
2103
+		decision_function f = svm_train_one(prob,param,0,0);
2104
+		model->rho = Malloc(double,1);
2105
+		model->rho[0] = f.rho;
2106
+
2107
+		int nSV = 0;
2108
+		int i;
2109
+		for(i=0;i<prob->l;i++)
2110
+			if(fabs(f.alpha[i]) > 0) ++nSV;
2111
+		model->l = nSV;
2112
+		model->SV = Malloc(svm_node *,nSV);
2113
+		model->sv_coef[0] = Malloc(double,nSV);
2114
+		int j = 0;
2115
+		for(i=0;i<prob->l;i++)
2116
+			if(fabs(f.alpha[i]) > 0)
2117
+			{
2118
+				model->SV[j] = prob->x[i];
2119
+				model->sv_coef[0][j] = f.alpha[i];
2120
+				++j;
2121
+			}		
2122
+
2123
+		free(f.alpha);
2124
+	}
2125
+	else
2126
+	{
2127
+		// classification
2128
+		int l = prob->l;
2129
+		int nr_class;
2130
+		int *label = NULL;
2131
+		int *start = NULL;
2132
+		int *count = NULL;
2133
+		int *perm = Malloc(int,l);
2134
+
2135
+		// group training data of the same class
2136
+		svm_group_classes(prob,&nr_class,&label,&start,&count,perm);		
2137
+		svm_node **x = Malloc(svm_node *,l);
2138
+		int i;
2139
+		for(i=0;i<l;i++)
2140
+			x[i] = prob->x[perm[i]];
2141
+
2142
+		// calculate weighted C
2143
+
2144
+		double *weighted_C = Malloc(double, nr_class);
2145
+		for(i=0;i<nr_class;i++)
2146
+			weighted_C[i] = param->C;
2147
+		for(i=0;i<param->nr_weight;i++)
2148
+		{	
2149
+			int j;
2150
+			for(j=0;j<nr_class;j++)
2151
+				if(param->weight_label[i] == label[j])
2152
+					break;
2153
+			if(j == nr_class)
2154
+				fprintf(stderr,"warning: class label %d specified in weight is not found\n", param->weight_label[i]);
2155
+			else
2156
+				weighted_C[j] *= param->weight[i];
2157
+		}
2158
+
2159
+		// train k*(k-1)/2 models
2160
+		
2161
+		bool *nonzero = Malloc(bool,l);
2162
+		for(i=0;i<l;i++)
2163
+			nonzero[i] = false;
2164
+		decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);
2165
+
2166
+		double *probA=NULL,*probB=NULL;
2167
+		if (param->probability)
2168
+		{
2169
+			probA=Malloc(double,nr_class*(nr_class-1)/2);
2170
+			probB=Malloc(double,nr_class*(nr_class-1)/2);
2171
+		}
2172
+
2173
+		int p = 0;
2174
+		for(i=0;i<nr_class;i++)
2175
+			for(int j=i+1;j<nr_class;j++)
2176
+			{
2177
+				svm_problem sub_prob;
2178
+				int si = start[i], sj = start[j];
2179
+				int ci = count[i], cj = count[j];
2180
+				sub_prob.l = ci+cj;
2181
+				sub_prob.x = Malloc(svm_node *,sub_prob.l);
2182
+				sub_prob.y = Malloc(double,sub_prob.l);
2183
+				int k;
2184
+				for(k=0;k<ci;k++)
2185
+				{
2186
+					sub_prob.x[k] = x[si+k];
2187
+					sub_prob.y[k] = +1;
2188
+				}
2189
+				for(k=0;k<cj;k++)
2190
+				{
2191
+					sub_prob.x[ci+k] = x[sj+k];
2192
+					sub_prob.y[ci+k] = -1;
2193
+				}
2194
+
2195
+				if(param->probability)
2196
+					svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);
2197
+
2198
+				f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);
2199
+				for(k=0;k<ci;k++)
2200
+					if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
2201
+						nonzero[si+k] = true;
2202
+				for(k=0;k<cj;k++)
2203
+					if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
2204
+						nonzero[sj+k] = true;
2205
+				free(sub_prob.x);
2206
+				free(sub_prob.y);
2207
+				++p;
2208
+			}
2209
+
2210
+		// build output
2211
+
2212
+		model->nr_class = nr_class;
2213
+		
2214
+		model->label = Malloc(int,nr_class);
2215
+		for(i=0;i<nr_class;i++)
2216
+			model->label[i] = label[i];
2217
+		
2218
+		model->rho = Malloc(double,nr_class*(nr_class-1)/2);
2219
+		for(i=0;i<nr_class*(nr_class-1)/2;i++)
2220
+			model->rho[i] = f[i].rho;
2221
+
2222
+		if(param->probability)
2223
+		{
2224
+			model->probA = Malloc(double,nr_class*(nr_class-1)/2);
2225
+			model->probB = Malloc(double,nr_class*(nr_class-1)/2);
2226
+			for(i=0;i<nr_class*(nr_class-1)/2;i++)
2227
+			{
2228
+				model->probA[i] = probA[i];
2229
+				model->probB[i] = probB[i];
2230
+			}
2231
+		}
2232
+		else
2233
+		{
2234
+			model->probA=NULL;
2235
+			model->probB=NULL;
2236
+		}
2237
+
2238
+		int total_sv = 0;
2239
+		int *nz_count = Malloc(int,nr_class);
2240
+		model->nSV = Malloc(int,nr_class);
2241
+		for(i=0;i<nr_class;i++)
2242
+		{
2243
+			int nSV = 0;
2244
+			for(int j=0;j<count[i];j++)
2245
+				if(nonzero[start[i]+j])
2246
+				{	
2247
+					++nSV;
2248
+					++total_sv;
2249
+				}
2250
+			model->nSV[i] = nSV;
2251
+			nz_count[i] = nSV;
2252
+		}
2253
+		
2254
+		info("Total nSV = %d\n",total_sv);
2255
+
2256
+		model->l = total_sv;
2257
+		model->SV = Malloc(svm_node *,total_sv);
2258
+		p = 0;
2259
+		for(i=0;i<l;i++)
2260
+			if(nonzero[i]) model->SV[p++] = x[i];
2261
+
2262
+		int *nz_start = Malloc(int,nr_class);
2263
+		nz_start[0] = 0;
2264
+		for(i=1;i<nr_class;i++)
2265
+			nz_start[i] = nz_start[i-1]+nz_count[i-1];
2266
+
2267
+		model->sv_coef = Malloc(double *,nr_class-1);
2268
+		for(i=0;i<nr_class-1;i++)
2269
+			model->sv_coef[i] = Malloc(double,total_sv);
2270
+
2271
+		p = 0;
2272
+		for(i=0;i<nr_class;i++)
2273
+			for(int j=i+1;j<nr_class;j++)
2274
+			{
2275
+				// classifier (i,j): coefficients with
2276
+				// i are in sv_coef[j-1][nz_start[i]...],
2277
+				// j are in sv_coef[i][nz_start[j]...]
2278
+
2279
+				int si = start[i];
2280
+				int sj = start[j];
2281
+				int ci = count[i];
2282
+				int cj = count[j];
2283
+				
2284
+				int q = nz_start[i];
2285
+				int k;
2286
+				for(k=0;k<ci;k++)
2287
+					if(nonzero[si+k])
2288
+						model->sv_coef[j-1][q++] = f[p].alpha[k];
2289
+				q = nz_start[j];
2290
+				for(k=0;k<cj;k++)
2291
+					if(nonzero[sj+k])
2292
+						model->sv_coef[i][q++] = f[p].alpha[ci+k];
2293
+				++p;
2294
+			}
2295
+		
2296
+		free(label);
2297
+		free(probA);
2298
+		free(probB);
2299
+		free(count);
2300
+		free(perm);
2301
+		free(start);
2302
+		free(x);
2303
+		free(weighted_C);
2304
+		free(nonzero);
2305
+		for(i=0;i<nr_class*(nr_class-1)/2;i++)
2306
+			free(f[i].alpha);
2307
+		free(f);
2308
+		free(nz_count);
2309
+		free(nz_start);
2310
+	}
2311
+	return model;
2312
+}
2313
+
2314
+// Stratified cross validation
2315
+void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
2316
+{
2317
+	int i;
2318
+	int *fold_start = Malloc(int,nr_fold+1);
2319
+	int l = prob->l;
2320
+	int *perm = Malloc(int,l);
2321
+	int nr_class;
2322
+
2323
+	// stratified CV may not reproduce the leave-one-out rate:
2324
+	// splitting each class into l folds can leave some folds empty, so plain shuffling is used when nr_fold >= l
2325
+	if((param->svm_type == C_SVC ||
2326
+	    param->svm_type == NU_SVC) && nr_fold < l)
2327
+	{
2328
+		int *start = NULL;
2329
+		int *label = NULL;
2330
+		int *count = NULL;
2331
+		svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
2332
+
2333
+		// random shuffle and then data grouped by fold using the array perm
2334
+		int *fold_count = Malloc(int,nr_fold);
2335
+		int c;
2336
+		int *index = Malloc(int,l);
2337
+		for(i=0;i<l;i++)
2338
+			index[i]=perm[i];
2339
+		for (c=0; c<nr_class; c++) 
2340
+			for(i=0;i<count[c];i++)
2341
+			{
2342
+				int j = i+rand()%(count[c]-i);
2343
+				swap(index[start[c]+j],index[start[c]+i]);
2344
+			}
2345
+		for(i=0;i<nr_fold;i++)
2346
+		{
2347
+			fold_count[i] = 0;
2348
+			for (c=0; c<nr_class;c++)
2349
+				fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
2350
+		}
2351
+		fold_start[0]=0;
2352
+		for (i=1;i<=nr_fold;i++)
2353
+			fold_start[i] = fold_start[i-1]+fold_count[i-1];
2354
+		for (c=0; c<nr_class;c++)
2355
+			for(i=0;i<nr_fold;i++)
2356
+			{
2357
+				int begin = start[c]+i*count[c]/nr_fold;
2358
+				int end = start[c]+(i+1)*count[c]/nr_fold;
2359
+				for(int j=begin;j<end;j++)
2360
+				{
2361
+					perm[fold_start[i]] = index[j];
2362
+					fold_start[i]++;
2363
+				}
2364
+			}
2365
+		fold_start[0]=0;
2366
+		for (i=1;i<=nr_fold;i++)
2367
+			fold_start[i] = fold_start[i-1]+fold_count[i-1];
2368
+		free(start);	
2369
+		free(label);
2370
+		free(count);	
2371
+		free(index);
2372
+		free(fold_count);
2373
+	}
2374
+	else
2375
+	{
2376
+		for(i=0;i<l;i++) perm[i]=i;
2377
+		for(i=0;i<l;i++)
2378
+		{
2379
+			int j = i+rand()%(l-i);
2380
+			swap(perm[i],perm[j]);
2381
+		}
2382
+		for(i=0;i<=nr_fold;i++)
2383
+			fold_start[i]=i*l/nr_fold;
2384
+	}
2385
+
2386
+	for(i=0;i<nr_fold;i++)
2387
+	{
2388
+		int begin = fold_start[i];
2389
+		int end = fold_start[i+1];
2390
+		int j,k;
2391
+		struct svm_problem subprob;
2392
+
2393
+		subprob.l = l-(end-begin);
2394
+		subprob.x = Malloc(struct svm_node*,subprob.l);
2395
+		subprob.y = Malloc(double,subprob.l);
2396
+			
2397
+		k=0;
2398
+		for(j=0;j<begin;j++)
2399
+		{
2400
+			subprob.x[k] = prob->x[perm[j]];
2401
+			subprob.y[k] = prob->y[perm[j]];
2402
+			++k;
2403
+		}
2404
+		for(j=end;j<l;j++)
2405
+		{
2406
+			subprob.x[k] = prob->x[perm[j]];
2407
+			subprob.y[k] = prob->y[perm[j]];
2408
+			++k;
2409
+		}
2410
+		struct svm_model *submodel = svm_train(&subprob,param);
2411
+		if(param->probability && 
2412
+		   (param->svm_type == C_SVC || param->svm_type == NU_SVC))
2413
+		{
2414
+			double *prob_estimates=Malloc(double,svm_get_nr_class(submodel));
2415
+			for(j=begin;j<end;j++)
2416
+				target[perm[j]] = svm_predict_probability(submodel,prob->x[perm[j]],prob_estimates);
2417
+			free(prob_estimates);			
2418
+		}
2419
+		else
2420
+			for(j=begin;j<end;j++)
2421
+				target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]);
2422
+		svm_destroy_model(submodel);
2423
+		free(subprob.x);
2424
+		free(subprob.y);
2425
+	}		
2426
+	free(fold_start);
2427
+	free(perm);	
2428
+}
2429
+
2430
+
2431
+int svm_get_svm_type(const svm_model *model)
2432
+{
2433
+	return model->param.svm_type;
2434
+}
2435
+
2436
+int svm_get_nr_class(const svm_model *model)
2437
+{
2438
+	return model->nr_class;
2439
+}
2440
+
2441
+void svm_get_labels(const svm_model *model, int* label)
2442
+{
2443
+	if (model->label != NULL)
2444
+		for(int i=0;i<model->nr_class;i++)
2445
+			label[i] = model->label[i];
2446
+}
2447
+
2448
+double svm_get_svr_probability(const svm_model *model)
2449
+{
2450
+	if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
2451
+	    model->probA!=NULL)
2452
+		return model->probA[0];
2453
+	else
2454
+	{
2455
+		info("Model doesn't contain information for SVR probability inference\n");
2456
+		return 0;
2457
+	}
2458
+}
2459
+
2460
+void svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values)
2461
+{
2462
+	if(model->param.svm_type == ONE_CLASS ||
2463
+	   model->param.svm_type == EPSILON_SVR ||
2464
+	   model->param.svm_type == NU_SVR)
2465
+	{
2466
+		double *sv_coef = model->sv_coef[0];
2467
+		double sum = 0;
2468
+		for(int i=0;i<model->l;i++)
2469
+			sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param);
2470
+		sum -= model->rho[0];
2471
+		*dec_values = sum;
2472
+	}
2473
+	else
2474
+	{
2475
+		int i;
2476
+		int nr_class = model->nr_class;
2477
+		int l = model->l;
2478
+		
2479
+		double *kvalue = Malloc(double,l);
2480
+		for(i=0;i<l;i++)
2481
+			kvalue[i] = Kernel::k_function(x,model->SV[i],model->param);
2482
+
2483
+		int *start = Malloc(int,nr_class);
2484
+		start[0] = 0;
2485
+		for(i=1;i<nr_class;i++)
2486
+			start[i] = start[i-1]+model->nSV[i-1];
2487
+
2488
+		int p=0;
2489
+		for(i=0;i<nr_class;i++)
2490
+			for(int j=i+1;j<nr_class;j++)
2491
+			{
2492
+				double sum = 0;
2493
+				int si = start[i];
2494
+				int sj = start[j];
2495
+				int ci = model->nSV[i];
2496
+				int cj = model->nSV[j];
2497
+				
2498
+				int k;
2499
+				double *coef1 = model->sv_coef[j-1];
2500
+				double *coef2 = model->sv_coef[i];
2501
+				for(k=0;k<ci;k++)
2502
+					sum += coef1[si+k] * kvalue[si+k];
2503
+				for(k=0;k<cj;k++)
2504
+					sum += coef2[sj+k] * kvalue[sj+k];
2505
+				sum -= model->rho[p];
2506
+				dec_values[p] = sum;
2507
+				p++;
2508
+			}
2509
+
2510
+		free(kvalue);
2511
+		free(start);
2512
+	}
2513
+}
2514
+
2515
+double svm_predict(const svm_model *model, const svm_node *x)
2516
+{
2517
+	if(model->param.svm_type == ONE_CLASS ||
2518
+	   model->param.svm_type == EPSILON_SVR ||
2519
+	   model->param.svm_type == NU_SVR)
2520
+	{
2521
+		double res;
2522
+		svm_predict_values(model, x, &res);
2523
+		
2524
+		if(model->param.svm_type == ONE_CLASS)
2525
+			return (res>0)?1:-1;
2526
+		else
2527
+			return res;
2528
+	}
2529
+	else
2530
+	{
2531
+		int i;
2532
+		int nr_class = model->nr_class;
2533
+		double *dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2534
+		svm_predict_values(model, x, dec_values);
2535
+
2536
+		int *vote = Malloc(int,nr_class);
2537
+		for(i=0;i<nr_class;i++)
2538
+			vote[i] = 0;
2539
+		int pos=0;
2540
+		for(i=0;i<nr_class;i++)
2541
+			for(int j=i+1;j<nr_class;j++)
2542
+			{
2543
+				if(dec_values[pos++] > 0)
2544
+					++vote[i];
2545
+				else
2546
+					++vote[j];
2547
+			}
2548
+
2549
+		int vote_max_idx = 0;
2550
+		for(i=1;i<nr_class;i++)
2551
+			if(vote[i] > vote[vote_max_idx])
2552
+				vote_max_idx = i;
2553
+		free(vote);
2554
+		free(dec_values);
2555
+		return model->label[vote_max_idx];
2556
+	}
2557
+}
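For classification, svm_predict above turns the k*(k-1)/2 pairwise decision values into one vote per class and returns the label with the most votes, with ties going to the lower class index. A hedged MATLAB sketch of that tally, assuming dec_values is ordered as (1,2),(1,3),...,(k-1,k) and labels_list holds the class labels in model order (names illustrative):

votes = zeros(nr_class,1);
pos = 1;
for i = 1:nr_class
    for j = i+1:nr_class
        if dec_values(pos) > 0
            votes(i) = votes(i) + 1;              % decision favours class i
        else
            votes(j) = votes(j) + 1;              % decision favours class j
        end
        pos = pos + 1;
    end
end
[dummy, winner] = max(votes);                     % first maximum wins, as in the C code
predicted_label = labels_list(winner);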
2558
+
2559
+double svm_predict_probability(
2560
+	const svm_model *model, const svm_node *x, double *prob_estimates)
2561
+{
2562
+	if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
2563
+	    model->probA!=NULL && model->probB!=NULL)
2564
+	{
2565
+		int i;
2566
+		int nr_class = model->nr_class;
2567
+		double *dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2568
+		svm_predict_values(model, x, dec_values);
2569
+
2570
+		double min_prob=1e-7;
2571
+		double **pairwise_prob=Malloc(double *,nr_class);
2572
+		for(i=0;i<nr_class;i++)
2573
+			pairwise_prob[i]=Malloc(double,nr_class);
2574
+		int k=0;
2575
+		for(i=0;i<nr_class;i++)
2576
+			for(int j=i+1;j<nr_class;j++)
2577
+			{
2578
+				pairwise_prob[i][j]=min(max(sigmoid_predict(dec_values[k],model->probA[k],model->probB[k]),min_prob),1-min_prob);
2579
+				pairwise_prob[j][i]=1-pairwise_prob[i][j];
2580
+				k++;
2581
+			}
2582
+		multiclass_probability(nr_class,pairwise_prob,prob_estimates);
2583
+
2584
+		int prob_max_idx = 0;
2585
+		for(i=1;i<nr_class;i++)
2586
+			if(prob_estimates[i] > prob_estimates[prob_max_idx])
2587
+				prob_max_idx = i;
2588
+		for(i=0;i<nr_class;i++)
2589
+			free(pairwise_prob[i]);
2590
+		free(dec_values);
2591
+                free(pairwise_prob);	     
2592
+		return model->label[prob_max_idx];
2593
+	}
2594
+	else 
2595
+		return svm_predict(model, x);
2596
+}
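Each pairwise decision value is first mapped to a probability by the Platt sigmoid implemented in sigmoid_predict (defined earlier in this file) and clamped away from 0 and 1 before the pairwise coupling step. A hedged MATLAB sketch for pair k with decision value f (indices and names illustrative):

min_prob = 1e-7;
p = 1 ./ (1 + exp(probA(k)*f + probB(k)));        % Platt scaling of the decision value
p = min(max(p, min_prob), 1 - min_prob);          % keep the pair probability inside (0,1)
pairwise_prob(i,j) = p;
pairwise_prob(j,i) = 1 - p;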
2597
+
2598
+const char *svm_type_table[] =
2599
+{
2600
+	"c_svc","nu_svc","one_class","epsilon_svr","nu_svr",NULL
2601
+};
2602
+
2603
+const char *kernel_type_table[]=
2604
+{
2605
+	"linear","polynomial","rbf","sigmoid","precomputed",NULL
2606
+};
2607
+
2608
+int svm_save_model(const char *model_file_name, const svm_model *model)
2609
+{
2610
+	FILE *fp = fopen(model_file_name,"w");
2611
+	if(fp==NULL) return -1;
2612
+
2613
+	const svm_parameter& param = model->param;
2614
+
2615
+	fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]);
2616
+	fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]);
2617
+
2618
+	if(param.kernel_type == POLY)
2619
+		fprintf(fp,"degree %d\n", param.degree);
2620
+
2621
+	if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID)
2622
+		fprintf(fp,"gamma %g\n", param.gamma);
2623
+
2624
+	if(param.kernel_type == POLY || param.kernel_type == SIGMOID)
2625
+		fprintf(fp,"coef0 %g\n", param.coef0);
2626
+
2627
+	int nr_class = model->nr_class;
2628
+	int l = model->l;
2629
+	fprintf(fp, "nr_class %d\n", nr_class);
2630
+	fprintf(fp, "total_sv %d\n",l);
2631
+	
2632
+	{
2633
+		fprintf(fp, "rho");
2634
+		for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2635
+			fprintf(fp," %g",model->rho[i]);
2636
+		fprintf(fp, "\n");
2637
+	}
2638
+	
2639
+	if(model->label)
2640
+	{
2641
+		fprintf(fp, "label");
2642
+		for(int i=0;i<nr_class;i++)
2643
+			fprintf(fp," %d",model->label[i]);
2644
+		fprintf(fp, "\n");
2645
+	}
2646
+
2647
+	if(model->probA) // regression has probA only
2648
+	{
2649
+		fprintf(fp, "probA");
2650
+		for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2651
+			fprintf(fp," %g",model->probA[i]);
2652
+		fprintf(fp, "\n");
2653
+	}
2654
+	if(model->probB)
2655
+	{
2656
+		fprintf(fp, "probB");
2657
+		for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2658
+			fprintf(fp," %g",model->probB[i]);
2659
+		fprintf(fp, "\n");
2660
+	}
2661
+
2662
+	if(model->nSV)
2663
+	{
2664
+		fprintf(fp, "nr_sv");
2665
+		for(int i=0;i<nr_class;i++)
2666
+			fprintf(fp," %d",model->nSV[i]);
2667
+		fprintf(fp, "\n");
2668
+	}
2669
+
2670
+	fprintf(fp, "SV\n");
2671
+	const double * const *sv_coef = model->sv_coef;
2672
+	const svm_node * const *SV = model->SV;
2673
+
2674
+	for(int i=0;i<l;i++)
2675
+	{
2676
+		for(int j=0;j<nr_class-1;j++)
2677
+			fprintf(fp, "%.16g ",sv_coef[j][i]);
2678
+
2679
+		const svm_node *p = SV[i];
2680
+
2681
+		if(param.kernel_type == PRECOMPUTED)
2682
+			fprintf(fp,"0:%d ",(int)(p->value));
2683
+		else
2684
+			while(p->index != -1)
2685
+			{
2686
+				fprintf(fp,"%d:%.8g ",p->index,p->value);
2687
+				p++;
2688
+			}
2689
+		fprintf(fp, "\n");
2690
+	}
2691
+	if (ferror(fp) != 0 || fclose(fp) != 0) return -1;
2692
+	else return 0;
2693
+}
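svm_save_model writes a plain-text header followed by one line per support vector (its coefficients, then index:value pairs). Based on the fprintf calls above, a hypothetical two-class C-SVC/RBF model would start roughly like this (all numbers are made up for illustration):

svm_type c_svc
kernel_type rbf
gamma 0.25
nr_class 2
total_sv 2
rho 0.1357
label 1 -1
nr_sv 1 1
SV
0.5 1:0.12 2:-0.34 3:0.56
-0.5 1:0.98 2:0.76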
2694
+
2695
+svm_model *svm_load_model(const char *model_file_name)
2696
+{
2697
+	FILE *fp = fopen(model_file_name,"rb");
2698
+	if(fp==NULL) return NULL;
2699
+	
2700
+	// read parameters
2701
+
2702
+	svm_model *model = Malloc(svm_model,1);
2703
+	svm_parameter& param = model->param;
2704
+	model->rho = NULL;
2705
+	model->probA = NULL;
2706
+	model->probB = NULL;
2707
+	model->label = NULL;
2708
+	model->nSV = NULL;
2709
+
2710
+	char cmd[81];
2711
+	while(1)
2712
+	{
2713
+		fscanf(fp,"%80s",cmd);
2714
+
2715
+		if(strcmp(cmd,"svm_type")==0)
2716
+		{
2717
+			fscanf(fp,"%80s",cmd);
2718
+			int i;
2719
+			for(i=0;svm_type_table[i];i++)
2720
+			{
2721
+				if(strcmp(svm_type_table[i],cmd)==0)
2722
+				{
2723
+					param.svm_type=i;
2724
+					break;
2725
+				}
2726
+			}
2727
+			if(svm_type_table[i] == NULL)
2728
+			{
2729
+				fprintf(stderr,"unknown svm type.\n");
2730
+				free(model->rho);
2731
+				free(model->label);
2732
+				free(model->nSV);
2733
+				free(model);
2734
+				return NULL;
2735
+			}
2736
+		}
2737
+		else if(strcmp(cmd,"kernel_type")==0)
2738
+		{		
2739
+			fscanf(fp,"%80s",cmd);
2740
+			int i;
2741
+			for(i=0;kernel_type_table[i];i++)
2742
+			{
2743
+				if(strcmp(kernel_type_table[i],cmd)==0)
2744
+				{
2745
+					param.kernel_type=i;
2746
+					break;
2747
+				}
2748
+			}
2749
+			if(kernel_type_table[i] == NULL)
2750
+			{
2751
+				fprintf(stderr,"unknown kernel function.\n");
2752
+				free(model->rho);
2753
+				free(model->label);
2754
+				free(model->nSV);
2755
+				free(model);
2756
+				return NULL;
2757
+			}
2758
+		}
2759
+		else if(strcmp(cmd,"degree")==0)
2760
+			fscanf(fp,"%d",&param.degree);
2761
+		else if(strcmp(cmd,"gamma")==0)
2762
+			fscanf(fp,"%lf",&param.gamma);
2763
+		else if(strcmp(cmd,"coef0")==0)
2764
+			fscanf(fp,"%lf",&param.coef0);
2765
+		else if(strcmp(cmd,"nr_class")==0)
2766
+			fscanf(fp,"%d",&model->nr_class);
2767
+		else if(strcmp(cmd,"total_sv")==0)
2768
+			fscanf(fp,"%d",&model->l);
2769
+		else if(strcmp(cmd,"rho")==0)
2770
+		{
2771
+			int n = model->nr_class * (model->nr_class-1)/2;
2772
+			model->rho = Malloc(double,n);
2773
+			for(int i=0;i<n;i++)
2774
+				fscanf(fp,"%lf",&model->rho[i]);
2775
+		}
2776
+		else if(strcmp(cmd,"label")==0)
2777
+		{
2778
+			int n = model->nr_class;
2779
+			model->label = Malloc(int,n);
2780
+			for(int i=0;i<n;i++)
2781
+				fscanf(fp,"%d",&model->label[i]);
2782
+		}
2783
+		else if(strcmp(cmd,"probA")==0)
2784
+		{
2785
+			int n = model->nr_class * (model->nr_class-1)/2;
2786
+			model->probA = Malloc(double,n);
2787
+			for(int i=0;i<n;i++)
2788
+				fscanf(fp,"%lf",&model->probA[i]);
2789
+		}
2790
+		else if(strcmp(cmd,"probB")==0)
2791
+		{
2792
+			int n = model->nr_class * (model->nr_class-1)/2;
2793
+			model->probB = Malloc(double,n);
2794
+			for(int i=0;i<n;i++)
2795
+				fscanf(fp,"%lf",&model->probB[i]);
2796
+		}
2797
+		else if(strcmp(cmd,"nr_sv")==0)
2798
+		{
2799
+			int n = model->nr_class;
2800
+			model->nSV = Malloc(int,n);
2801
+			for(int i=0;i<n;i++)
2802
+				fscanf(fp,"%d",&model->nSV[i]);
2803
+		}
2804
+		else if(strcmp(cmd,"SV")==0)
2805
+		{
2806
+			while(1)
2807
+			{
2808
+				int c = getc(fp);
2809
+				if(c==EOF || c=='\n') break;	
2810
+			}
2811
+			break;
2812
+		}
2813
+		else
2814
+		{
2815
+			fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
2816
+			free(model->rho);
2817
+			free(model->label);
2818
+			free(model->nSV);
2819
+			free(model);
2820
+			return NULL;
2821
+		}
2822
+	}
2823
+
2824
+	// read sv_coef and SV
2825
+
2826
+	int elements = 0;
2827
+	long pos = ftell(fp);
2828
+
2829
+	while(1)
2830
+	{
2831
+		int c = fgetc(fp);
2832
+		switch(c)
2833
+		{
2834
+			case '\n':
2835
+				// count the '-1' element
2836
+			case ':':
2837
+				++elements;
2838
+				break;
2839
+			case EOF:
2840
+				goto out;
2841
+			default:
2842
+				;
2843
+		}
2844
+	}
2845
+out:
2846
+	fseek(fp,pos,SEEK_SET);
2847
+
2848
+	int m = model->nr_class - 1;
2849
+	int l = model->l;
2850
+	model->sv_coef = Malloc(double *,m);
2851
+	int i;
2852
+	for(i=0;i<m;i++)
2853
+		model->sv_coef[i] = Malloc(double,l);
2854
+	model->SV = Malloc(svm_node*,l);
2855
+	svm_node *x_space=NULL;
2856
+	if(l>0) x_space = Malloc(svm_node,elements);
2857
+
2858
+	int j=0;
2859
+	for(i=0;i<l;i++)
2860
+	{
2861
+		model->SV[i] = &x_space[j];
2862
+		for(int k=0;k<m;k++)
2863
+			fscanf(fp,"%lf",&model->sv_coef[k][i]);
2864
+		while(1)
2865
+		{
2866
+			int c;
2867
+			do {
2868
+				c = getc(fp);
2869
+				if(c=='\n') goto out2;
2870
+			} while(isspace(c));
2871
+			ungetc(c,fp);
2872
+			fscanf(fp,"%d:%lf",&(x_space[j].index),&(x_space[j].value));
2873
+			++j;
2874
+		}	
2875
+out2:
2876
+		x_space[j++].index = -1;
2877
+	}
2878
+	if (ferror(fp) != 0 || fclose(fp) != 0) return NULL;
2879
+
2880
+	model->free_sv = 1;	// XXX
2881
+	return model;
2882
+}
2883
+
2884
+void svm_destroy_model(svm_model* model)
2885
+{
2886
+	if(model->free_sv && model->l > 0)
2887
+		free((void *)(model->SV[0]));
2888
+	for(int i=0;i<model->nr_class-1;i++)
2889
+		free(model->sv_coef[i]);
2890
+	free(model->SV);
2891
+	free(model->sv_coef);
2892
+	free(model->rho);
2893
+	free(model->label);
2894
+	free(model->probA);
2895
+	free(model->probB);
2896
+	free(model->nSV);
2897
+	free(model);
2898
+}
2899
+
2900
+void svm_destroy_param(svm_parameter* param)
2901
+{
2902
+	free(param->weight_label);
2903
+	free(param->weight);
2904
+}
2905
+
2906
+const char *svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
2907
+{
2908
+	// svm_type
2909
+
2910
+	int svm_type = param->svm_type;
2911
+	if(svm_type != C_SVC &&
2912
+	   svm_type != NU_SVC &&
2913
+	   svm_type != ONE_CLASS &&
2914
+	   svm_type != EPSILON_SVR &&
2915
+	   svm_type != NU_SVR)
2916
+		return "unknown svm type";
2917
+	
2918
+	// kernel_type, degree
2919
+	
2920
+	int kernel_type = param->kernel_type;
2921
+	if(kernel_type != LINEAR &&
2922
+	   kernel_type != POLY &&
2923
+	   kernel_type != RBF &&
2924
+	   kernel_type != SIGMOID &&
2925
+	   kernel_type != PRECOMPUTED)
2926
+		return "unknown kernel type";
2927
+
2928
+	if(param->degree < 0)
2929
+		return "degree of polynomial kernel < 0";
2930
+
2931
+	// cache_size,eps,C,nu,p,shrinking
2932
+
2933
+	if(param->cache_size <= 0)
2934
+		return "cache_size <= 0";
2935
+
2936
+	if(param->eps <= 0)
2937
+		return "eps <= 0";
2938
+
2939
+	if(svm_type == C_SVC ||
2940
+	   svm_type == EPSILON_SVR ||
2941
+	   svm_type == NU_SVR)
2942
+		if(param->C <= 0)
2943
+			return "C <= 0";
2944
+
2945
+	if(svm_type == NU_SVC ||
2946
+	   svm_type == ONE_CLASS ||
2947
+	   svm_type == NU_SVR)
2948
+		if(param->nu <= 0 || param->nu > 1)
2949
+			return "nu <= 0 or nu > 1";
2950
+
2951
+	if(svm_type == EPSILON_SVR)
2952
+		if(param->p < 0)
2953
+			return "p < 0";
2954
+
2955
+	if(param->shrinking != 0 &&
2956
+	   param->shrinking != 1)
2957
+		return "shrinking != 0 and shrinking != 1";
2958
+
2959
+	if(param->probability != 0 &&
2960
+	   param->probability != 1)
2961
+		return "probability != 0 and probability != 1";
2962
+
2963
+	if(param->probability == 1 &&
2964
+	   svm_type == ONE_CLASS)
2965
+		return "one-class SVM probability output not supported yet";
2966
+
2967
+
2968
+	// check whether nu-svc is feasible
2969
+	
2970
+	if(svm_type == NU_SVC)
2971
+	{
2972
+		int l = prob->l;
2973
+		int max_nr_class = 16;
2974
+		int nr_class = 0;
2975
+		int *label = Malloc(int,max_nr_class);
2976
+		int *count = Malloc(int,max_nr_class);
2977
+
2978
+		int i;
2979
+		for(i=0;i<l;i++)
2980
+		{
2981
+			int this_label = (int)prob->y[i];
2982
+			int j;
2983
+			for(j=0;j<nr_class;j++)
2984
+				if(this_label == label[j])
2985
+				{
2986
+					++count[j];
2987
+					break;
2988
+				}
2989
+			if(j == nr_class)
2990
+			{
2991
+				if(nr_class == max_nr_class)
2992
+				{
2993
+					max_nr_class *= 2;
2994
+					label = (int *)realloc(label,max_nr_class*sizeof(int));
2995
+					count = (int *)realloc(count,max_nr_class*sizeof(int));
2996
+				}
2997
+				label[nr_class] = this_label;
2998
+				count[nr_class] = 1;
2999
+				++nr_class;
3000
+			}
3001
+		}
3002
+	
3003
+		for(i=0;i<nr_class;i++)
3004
+		{
3005
+			int n1 = count[i];
3006
+			for(int j=i+1;j<nr_class;j++)
3007
+			{
3008
+				int n2 = count[j];
3009
+				if(param->nu*(n1+n2)/2 > min(n1,n2))
3010
+				{
3011
+					free(label);
3012
+					free(count);
3013
+					return "specified nu is infeasible";
3014
+				}
3015
+			}
3016
+		}
3017
+		free(label);
3018
+		free(count);
3019
+	}
3020
+
3021
+	return NULL;
3022
+}
3023
+
3024
+int svm_check_probability_model(const svm_model *model)
3025
+{
3026
+	return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
3027
+		model->probA!=NULL && model->probB!=NULL) ||
3028
+		((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
3029
+		 model->probA!=NULL);
3030
+}
... ...
@@ -0,0 +1,93 @@
1
+#ifndef _LIBSVM_H
2
+#define _LIBSVM_H
3
+
4
+#define LIBSVM_VERSION 288
5
+
6
+#ifdef __cplusplus
7
+extern "C" {
8
+#endif
9
+
10
+struct svm_node
11
+{
12
+	int index;
13
+	double value;
14
+};
15
+
16
+struct svm_problem
17
+{
18
+	int l;
19
+	double *y;
20
+	struct svm_node **x;
21
+};
22
+
23
+enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR };	/* svm_type */
24
+enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED }; /* kernel_type */
25
+
26
+struct svm_parameter
27
+{
28
+	int svm_type;
29
+	int kernel_type;
30
+	int degree;	/* for poly */
31
+	double gamma;	/* for poly/rbf/sigmoid */
32
+	double coef0;	/* for poly/sigmoid */
33
+
34
+	/* these are for training only */
35
+	double cache_size; /* in MB */
36
+	double eps;	/* stopping criteria */
37
+	double C;	/* for C_SVC, EPSILON_SVR and NU_SVR */
38
+	int nr_weight;		/* for C_SVC */
39
+	int *weight_label;	/* for C_SVC */
40
+	double* weight;		/* for C_SVC */
41
+	double nu;	/* for NU_SVC, ONE_CLASS, and NU_SVR */
42
+	double p;	/* for EPSILON_SVR */
43
+	int shrinking;	/* use the shrinking heuristics */
44
+	int probability; /* do probability estimates */
45
+};
46
+
47
+struct svm_model
48
+{
49
+	struct svm_parameter param;	// parameter
50
+	int nr_class;		// number of classes, = 2 in regression/one class svm
51
+	int l;			// total #SV
52
+	struct svm_node **SV;		// SVs (SV[l])
53
+	double **sv_coef;	// coefficients for SVs in decision functions (sv_coef[k-1][l])
54
+	double *rho;		// constants in decision functions (rho[k*(k-1)/2])
55
+	double *probA;          // pairwise probability information
56
+	double *probB;
57
+
58
+	// for classification only
59
+
60
+	int *label;		// label of each class (label[k])
61
+	int *nSV;		// number of SVs for each class (nSV[k])
62
+				// nSV[0] + nSV[1] + ... + nSV[k-1] = l
63
+	// XXX
64
+	int free_sv;		// 1 if svm_model is created by svm_load_model
65
+				// 0 if svm_model is created by svm_train
66
+};
67
+
68
+struct svm_model *svm_train(const struct svm_problem *prob, const struct svm_parameter *param);
69
+void svm_cross_validation(const struct svm_problem *prob, const struct svm_parameter *param, int nr_fold, double *target);
70
+
71
+int svm_save_model(const char *model_file_name, const struct svm_model *model);
72
+struct svm_model *svm_load_model(const char *model_file_name);
73
+
74
+int svm_get_svm_type(const struct svm_model *model);
75
+int svm_get_nr_class(const struct svm_model *model);
76
+void svm_get_labels(const struct svm_model *model, int *label);
77
+double svm_get_svr_probability(const struct svm_model *model);
78
+
79
+void svm_predict_values(const struct svm_model *model, const struct svm_node *x, double* dec_values);
80
+double svm_predict(const struct svm_model *model, const struct svm_node *x);
81
+double svm_predict_probability(const struct svm_model *model, const struct svm_node *x, double* prob_estimates);
82
+
83
+void svm_destroy_model(struct svm_model *model);
84
+void svm_destroy_param(struct svm_parameter *param);
85
+
86
+const char *svm_check_parameter(const struct svm_problem *prob, const struct svm_parameter *param);
87
+int svm_check_probability_model(const struct svm_model *model);
88
+
89
+#ifdef __cplusplus
90
+}
91
+#endif
92
+
93
+#endif /* _LIBSVM_H */
... ...
@@ -0,0 +1,349 @@
1
+#include <stdlib.h>
2
+#include <string.h>
3
+#include "svm.h"
4
+
5
+#include "mex.h"
6
+
7
+#if MX_API_VER < 0x07030000
8
+typedef int mwIndex;
9
+#endif 
10
+
11
+#define NUM_OF_RETURN_FIELD 10
12
+
13
+#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
14
+
15
+static const char *field_names[] = {
16
+	"Parameters",
17
+	"nr_class",
18
+	"totalSV",
19
+	"rho",
20
+	"Label",
21
+	"ProbA",
22
+	"ProbB",
23
+	"nSV",
24
+	"sv_coef",
25
+	"SVs"
26
+};
27
+
28
+const char *model_to_matlab_structure(mxArray *plhs[], int num_of_feature, struct svm_model *model)
29
+{
30
+	int i, j, n;
31
+	double *ptr;
32
+	mxArray *return_model, **rhs;
33
+	int out_id = 0;
34
+
35
+	rhs = (mxArray **)mxMalloc(sizeof(mxArray *)*NUM_OF_RETURN_FIELD);
36
+
37
+	// Parameters
38
+	rhs[out_id] = mxCreateDoubleMatrix(5, 1, mxREAL);
39
+	ptr = mxGetPr(rhs[out_id]);
40
+	ptr[0] = model->param.svm_type;
41
+	ptr[1] = model->param.kernel_type;
42
+	ptr[2] = model->param.degree;
43
+	ptr[3] = model->param.gamma;
44
+	ptr[4] = model->param.coef0;
45
+	out_id++;
46
+
47
+	// nr_class
48
+	rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
49
+	ptr = mxGetPr(rhs[out_id]);
50
+	ptr[0] = model->nr_class;
51
+	out_id++;
52
+
53
+	// total SV
54
+	rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
55
+	ptr = mxGetPr(rhs[out_id]);
56
+	ptr[0] = model->l;
57
+	out_id++;
58
+
59
+	// rho
60
+	n = model->nr_class*(model->nr_class-1)/2;
61
+	rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
62
+	ptr = mxGetPr(rhs[out_id]);
63
+	for(i = 0; i < n; i++)
64
+		ptr[i] = model->rho[i];
65
+	out_id++;
66
+
67
+	// Label
68
+	if(model->label)
69
+	{
70
+		rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
71
+		ptr = mxGetPr(rhs[out_id]);
72
+		for(i = 0; i < model->nr_class; i++)
73
+			ptr[i] = model->label[i];
74
+	}
75
+	else
76
+		rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
77
+	out_id++;
78
+
79
+	// probA
80
+	if(model->probA != NULL)
81
+	{
82
+		rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
83
+		ptr = mxGetPr(rhs[out_id]);
84
+		for(i = 0; i < n; i++)
85
+			ptr[i] = model->probA[i];
86
+	}
87
+	else
88
+		rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
89
+	out_id ++;
90
+
91
+	// probB
92
+	if(model->probB != NULL)
93
+	{
94
+		rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
95
+		ptr = mxGetPr(rhs[out_id]);
96
+		for(i = 0; i < n; i++)
97
+			ptr[i] = model->probB[i];
98
+	}
99
+	else
100
+		rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
101
+	out_id++;
102
+
103
+	// nSV
104
+	if(model->nSV)
105
+	{
106
+		rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
107
+		ptr = mxGetPr(rhs[out_id]);
108
+		for(i = 0; i < model->nr_class; i++)
109
+			ptr[i] = model->nSV[i];
110
+	}
111
+	else
112
+		rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
113
+	out_id++;
114
+
115
+	// sv_coef
116
+	rhs[out_id] = mxCreateDoubleMatrix(model->l, model->nr_class-1, mxREAL);
117
+	ptr = mxGetPr(rhs[out_id]);
118
+	for(i = 0; i < model->nr_class-1; i++)
119
+		for(j = 0; j < model->l; j++)
120
+			ptr[(i*(model->l))+j] = model->sv_coef[i][j];
121
+	out_id++;
122
+
123
+	// SVs
124
+	{
125
+		int ir_index, nonzero_element;
126
+		mwIndex *ir, *jc;
127
+		mxArray *pprhs[1], *pplhs[1];	
128
+
129
+		if(model->param.kernel_type == PRECOMPUTED)
130
+		{
131
+			nonzero_element = model->l;
132
+			num_of_feature = 1;
133
+		}
134
+		else
135
+		{
136
+			nonzero_element = 0;
137
+			for(i = 0; i < model->l; i++) {
138
+				j = 0;
139
+				while(model->SV[i][j].index != -1) 
140
+				{
141
+					nonzero_element++;
142
+					j++;
143
+				}
144
+			}
145
+		}
146
+
147
+		// SV in column, easier accessing
148
+		rhs[out_id] = mxCreateSparse(num_of_feature, model->l, nonzero_element, mxREAL);
149
+		ir = mxGetIr(rhs[out_id]);
150
+		jc = mxGetJc(rhs[out_id]);
151
+		ptr = mxGetPr(rhs[out_id]);
152
+		ir_index = jc[0] = 0;
153
+		for(i = 0;i < model->l; i++)
154
+		{
155
+			if(model->param.kernel_type == PRECOMPUTED)
156
+			{
157
+				// make a (1 x model->l) matrix
158
+				ir[ir_index] = 0; 
159
+				ptr[ir_index] = model->SV[i][0].value;
160
+				ir_index++;
161
+				jc[i+1] = jc[i] + 1;
162
+			}
163
+			else
164
+			{
165
+				int x_index = 0;
166
+				while (model->SV[i][x_index].index != -1)
167
+				{
168
+					ir[ir_index] = model->SV[i][x_index].index - 1; 
169
+					ptr[ir_index] = model->SV[i][x_index].value;
170
+					ir_index++, x_index++;
171
+				}
172
+				jc[i+1] = jc[i] + x_index;
173
+			}
174
+		}
175
+		// transpose back to SV in row
176
+		pprhs[0] = rhs[out_id];
177
+		if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
178
+			return "cannot transpose SV matrix";
179
+		rhs[out_id] = pplhs[0];
180
+		out_id++;
181
+	}
182
+
183
+	/* Create a struct matrix contains NUM_OF_RETURN_FIELD fields */
184
+	return_model = mxCreateStructMatrix(1, 1, NUM_OF_RETURN_FIELD, field_names);
185
+
186
+	/* Fill struct matrix with input arguments */
187
+	for(i = 0; i < NUM_OF_RETURN_FIELD; i++)
188
+		mxSetField(return_model,0,field_names[i],mxDuplicateArray(rhs[i]));
189
+	/* return */
190
+	plhs[0] = return_model;
191
+	mxFree(rhs);
192
+
193
+	return NULL;
194
+}
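model_to_matlab_structure gives svmtrain's MATLAB return value its shape: a struct whose fields follow the field_names table above. A hedged sketch of how the result looks from the MATLAB side (variable names illustrative):

model = svmtrain(train_labels, train_data, '-s 0 -t 2 -c 1');
% model.Parameters : 5x1 vector [svm_type; kernel_type; degree; gamma; coef0]
% model.nr_class   : number of classes
% model.totalSV    : total number of support vectors
% model.rho        : nr_class*(nr_class-1)/2 bias terms
% model.Label      : class labels (empty for regression / one-class)
% model.ProbA, model.ProbB : pairwise probability parameters (empty without '-b 1')
% model.nSV        : support vectors per class
% model.sv_coef    : totalSV x (nr_class-1) coefficient matrix
% model.SVs        : sparse totalSV x n_features matrix of support vectors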
195
+
196
+struct svm_model *matlab_matrix_to_model(const mxArray *matlab_struct, const char **msg)
197
+{
198
+	int i, j, n, num_of_fields;
199
+	double *ptr;
200
+	int id = 0;
201
+	struct svm_node *x_space;
202
+	struct svm_model *model;
203
+	mxArray **rhs;
204
+
205
+	num_of_fields = mxGetNumberOfFields(matlab_struct);
206
+	if(num_of_fields != NUM_OF_RETURN_FIELD) 
207
+	{
208
+		*msg = "number of return field is not correct";
209
+		return NULL;
210
+	}
211
+	rhs = (mxArray **) mxMalloc(sizeof(mxArray *)*num_of_fields);
212
+
213
+	for(i=0;i<num_of_fields;i++)
214
+		rhs[i] = mxGetFieldByNumber(matlab_struct, 0, i);
215
+
216
+	model = Malloc(struct svm_model, 1);
217
+	model->rho = NULL;
218
+	model->probA = NULL;
219
+	model->probB = NULL;
220
+	model->label = NULL;
221
+	model->nSV = NULL;
222
+	model->free_sv = 1; // XXX
223
+
224
+	ptr = mxGetPr(rhs[id]);
225
+	model->param.svm_type	  = (int)ptr[0];
226
+	model->param.kernel_type  = (int)ptr[1];
227
+	model->param.degree	  = (int)ptr[2];
228
+	model->param.gamma	  = ptr[3];
229
+	model->param.coef0	  = ptr[4];
230
+	id++;
231
+
232
+	ptr = mxGetPr(rhs[id]);
233
+	model->nr_class = (int)ptr[0];
234
+	id++;
235
+
236
+	ptr = mxGetPr(rhs[id]);
237
+	model->l = (int)ptr[0];
238
+	id++;
239
+
240
+	// rho
241
+	n = model->nr_class * (model->nr_class-1)/2;
242
+	model->rho = (double*) malloc(n*sizeof(double));
243
+	ptr = mxGetPr(rhs[id]);
244
+	for(i=0;i<n;i++)
245
+		model->rho[i] = ptr[i];
246
+	id++;
247
+
248
+	// label
249
+	if(mxIsEmpty(rhs[id]) == 0)
250
+	{
251
+		model->label = (int*) malloc(model->nr_class*sizeof(int));
252
+		ptr = mxGetPr(rhs[id]);
253
+		for(i=0;i<model->nr_class;i++)
254
+			model->label[i] = (int)ptr[i];
255
+	}
256
+	id++;
257
+
258
+	// probA
259
+	if(mxIsEmpty(rhs[id]) == 0)
260
+	{
261
+		model->probA = (double*) malloc(n*sizeof(double));
262
+		ptr = mxGetPr(rhs[id]);
263
+		for(i=0;i<n;i++)
264
+			model->probA[i] = ptr[i];
265
+	}
266
+	id++;
267
+
268
+	// probB
269
+	if(mxIsEmpty(rhs[id]) == 0)
270
+	{
271
+		model->probB = (double*) malloc(n*sizeof(double));
272
+		ptr = mxGetPr(rhs[id]);
273
+		for(i=0;i<n;i++)
274
+			model->probB[i] = ptr[i];
275
+	}
276
+	id++;
277
+
278
+	// nSV
279
+	if(mxIsEmpty(rhs[id]) == 0)
280
+	{
281
+		model->nSV = (int*) malloc(model->nr_class*sizeof(int));
282
+		ptr = mxGetPr(rhs[id]);
283
+		for(i=0;i<model->nr_class;i++)
284
+			model->nSV[i] = (int)ptr[i];
285
+	}
286
+	id++;
287
+
288
+	// sv_coef
289
+	ptr = mxGetPr(rhs[id]);
290
+	model->sv_coef = (double**) malloc((model->nr_class-1)*sizeof(double *));
291
+	for( i=0 ; i< model->nr_class -1 ; i++ )
292
+		model->sv_coef[i] = (double*) malloc((model->l)*sizeof(double));
293
+	for(i = 0; i < model->nr_class - 1; i++)
294
+		for(j = 0; j < model->l; j++)
295
+			model->sv_coef[i][j] = ptr[i*(model->l)+j];
296
+	id++;
297
+
298
+	// SV
299
+	{
300
+		int sr, sc, elements;
301
+		int num_samples;
302
+		mwIndex *ir, *jc;
303
+		mxArray *pprhs[1], *pplhs[1];
304
+
305
+		// transpose SV
306
+		pprhs[0] = rhs[id];
307
+		if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose")) 
308
+		{
309
+			svm_destroy_model(model);
310
+			*msg = "cannot transpose SV matrix";
311
+			return NULL;
312
+		}
313
+		rhs[id] = pplhs[0];
314
+
315
+		sr = mxGetN(rhs[id]);
316
+		sc = mxGetM(rhs[id]);
317
+
318
+		ptr = mxGetPr(rhs[id]);
319
+		ir = mxGetIr(rhs[id]);
320
+		jc = mxGetJc(rhs[id]);
321
+
322
+		num_samples = mxGetNzmax(rhs[id]);
323
+
324
+		elements = num_samples + sr;
325
+
326
+		model->SV = (struct svm_node **) malloc(sr * sizeof(struct svm_node *));
327
+		x_space = (struct svm_node *)malloc(elements * sizeof(struct svm_node));
328
+
329
+		// SV is in column
330
+		for(i=0;i<sr;i++)
331
+		{
332
+			int low = jc[i], high = jc[i+1];
333
+			int x_index = 0;
334
+			model->SV[i] = &x_space[low+i];
335
+			for(j=low;j<high;j++)
336
+			{
337
+				model->SV[i][x_index].index = ir[j] + 1; 
338
+				model->SV[i][x_index].value = ptr[j];
339
+				x_index++;
340
+			}
341
+			model->SV[i][x_index].index = -1;
342
+		}
343
+
344
+		id++;
345
+	}
346
+	mxFree(rhs);
347
+
348
+	return model;
349
+}
... ...
@@ -0,0 +1,2 @@
1
+const char *model_to_matlab_structure(mxArray *plhs[], int num_of_feature, struct svm_model *model);
2
+struct svm_model *matlab_matrix_to_model(const mxArray *matlab_struct, const char **error_message);
... ...
@@ -0,0 +1,339 @@
1
+#include <stdio.h>
2
+#include <stdlib.h>
3
+#include <string.h>
4
+#include "svm.h"
5
+
6
+#include "mex.h"
7
+#include "svm_model_matlab.h"
8
+
9
+#if MX_API_VER < 0x07030000
10
+typedef int mwIndex;
11
+#endif 
12
+
13
+#define CMD_LEN 2048
14
+
15
+void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
16
+{
17
+	int i, j, low, high;
18
+	mwIndex *ir, *jc;
19
+	double *samples;
20
+
21
+	ir = mxGetIr(prhs);
22
+	jc = mxGetJc(prhs);
23
+	samples = mxGetPr(prhs);
24
+
25
+	// each column is one instance
26
+	j = 0;
27
+	low = jc[index], high = jc[index+1];
28
+	for(i=low;i<high;i++)
29
+	{
30
+		x[j].index = ir[i] + 1;
31
+		x[j].value = samples[i];
32
+		j++;
33
+ 	}
34
+	x[j].index = -1;
35
+}
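read_sparse_instance walks one column of the already transposed sparse instance matrix and copies its nonzeros into 1-based index:value pairs, terminated by an index of -1. The MATLAB-side equivalent of that extraction is simply find on the column (names illustrative):

col = instances_transposed(:, instance_index);   % one test instance as a sparse column
[featIdx, dummy, featVal] = find(col);           % row numbers become the 1-based feature indices
% featIdx(k)/featVal(k) correspond to x[j].index/x[j].value in the C code;
% the C side additionally appends the terminating index of -1.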
36
+
37
+static void fake_answer(mxArray *plhs[])
38
+{
39
+	plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
40
+	plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
41
+	plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
42
+}
43
+
44
+void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
45
+{
46
+	int label_vector_row_num, label_vector_col_num;
47
+	int feature_number, testing_instance_number;
48
+	int instance_index;
49
+	double *ptr_instance, *ptr_label, *ptr_predict_label; 
50
+	double *ptr_prob_estimates, *ptr_dec_values, *ptr;
51
+	struct svm_node *x;
52
+	mxArray *pplhs[1]; // transposed instance sparse matrix
53
+
54
+	int correct = 0;
55
+	int total = 0;
56
+	double error = 0;
57
+	double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
58
+
59
+	int svm_type=svm_get_svm_type(model);
60
+	int nr_class=svm_get_nr_class(model);
61
+	double *prob_estimates=NULL;
62
+
63
+	// prhs[1] = testing instance matrix
64
+	feature_number = mxGetN(prhs[1]);
65
+	testing_instance_number = mxGetM(prhs[1]);
66
+	label_vector_row_num = mxGetM(prhs[0]);
67
+	label_vector_col_num = mxGetN(prhs[0]);
68
+
69
+	if(label_vector_row_num!=testing_instance_number)
70
+	{
71
+		mexPrintf("Length of label vector does not match # of instances.\n");
72
+		fake_answer(plhs);
73
+		return;
74
+	}
75
+	if(label_vector_col_num!=1)
76
+	{
77
+		mexPrintf("label (1st argument) should be a vector (# of columns is 1).\n");
78
+		fake_answer(plhs);
79
+		return;
80
+	}
81
+
82
+	ptr_instance = mxGetPr(prhs[1]);
83
+	ptr_label    = mxGetPr(prhs[0]);
84
+	
85
+	// transpose instance matrix
86
+	if(mxIsSparse(prhs[1]))
87
+	{
88
+		if(model->param.kernel_type == PRECOMPUTED)
89
+		{
90
+			// precomputed kernel requires dense matrix, so we make one
91
+			mxArray *rhs[1], *lhs[1];
92
+			rhs[0] = mxDuplicateArray(prhs[1]);
93
+			if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
94
+			{
95
+				mexPrintf("Error: cannot generate a full testing instance matrix\n");
96
+				fake_answer(plhs);
97
+				return;
98
+			}
99
+			ptr_instance = mxGetPr(lhs[0]);
100
+			mxDestroyArray(rhs[0]);
101
+		}
102
+		else
103
+		{
104
+			mxArray *pprhs[1];
105
+			pprhs[0] = mxDuplicateArray(prhs[1]);
106
+			if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
107
+			{
108
+				mexPrintf("Error: cannot transpose testing instance matrix\n");
109
+				fake_answer(plhs);
110
+				return;
111
+			}
112
+		}
113
+	}
114
+
115
+	if(predict_probability)
116
+	{
117
+		if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
118
+			mexPrintf("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
119
+		else
120
+			prob_estimates = (double *) malloc(nr_class*sizeof(double));
121
+	}
122
+
123
+	plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
124
+	if(predict_probability)
125
+	{
126
+		// prob estimates are in plhs[2]
127
+		if(svm_type==C_SVC || svm_type==NU_SVC)
128
+			plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
129
+		else
130
+			plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
131
+	}
132
+	else
133
+	{
134
+		// decision values are in plhs[2]
135
+		if(svm_type == ONE_CLASS ||
136
+		   svm_type == EPSILON_SVR ||
137
+		   svm_type == NU_SVR)
138
+			plhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
139
+		else
140
+			plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
141
+	}
142
+
143
+	ptr_predict_label = mxGetPr(plhs[0]);
144
+	ptr_prob_estimates = mxGetPr(plhs[2]);
145
+	ptr_dec_values = mxGetPr(plhs[2]);
146
+	x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node) );
147
+	for(instance_index=0;instance_index<testing_instance_number;instance_index++)
148
+	{
149
+		int i;
150
+		double target,v;
151
+
152
+		target = ptr_label[instance_index];
153
+
154
+		if(mxIsSparse(prhs[1]) && model->param.kernel_type != PRECOMPUTED) // prhs[1]^T is still sparse
155
+			read_sparse_instance(pplhs[0], instance_index, x);
156
+		else
157
+		{
158
+			for(i=0;i<feature_number;i++)
159
+			{
160
+				x[i].index = i+1;
161
+				x[i].value = ptr_instance[testing_instance_number*i+instance_index];
162
+			}
163
+			x[feature_number].index = -1;
164
+		}
165
+
166
+		if(predict_probability) 
167
+		{
168
+			if(svm_type==C_SVC || svm_type==NU_SVC)
169
+			{
170
+				v = svm_predict_probability(model, x, prob_estimates);
171
+				ptr_predict_label[instance_index] = v;
172
+				for(i=0;i<nr_class;i++)
173
+					ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
174
+			} else {
175
+				v = svm_predict(model,x);
176
+				ptr_predict_label[instance_index] = v;
177
+			}
178
+		}
179
+		else
180
+		{
181
+			v = svm_predict(model,x);
182
+			ptr_predict_label[instance_index] = v;
183
+
184
+			if(svm_type == ONE_CLASS ||
185
+			   svm_type == EPSILON_SVR ||
186
+			   svm_type == NU_SVR)
187
+			{
188
+				double res;
189
+				svm_predict_values(model, x, &res);
190
+				ptr_dec_values[instance_index] = res;
191
+			}
192
+			else
193
+			{
194
+				double *dec_values = (double *) malloc(sizeof(double) * nr_class*(nr_class-1)/2);
195
+				svm_predict_values(model, x, dec_values);
196
+				for(i=0;i<(nr_class*(nr_class-1))/2;i++)
197
+					ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
198
+				free(dec_values);
199
+			}
200
+		}
201
+
202
+		if(v == target)
203
+			++correct;
204
+		error += (v-target)*(v-target);
205
+		sumv += v;
206
+		sumy += target;
207
+		sumvv += v*v;
208
+		sumyy += target*target;
209
+		sumvy += v*target;
210
+		++total;
211
+	}
212
+	if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
213
+	{
214
+		mexPrintf("Mean squared error = %g (regression)\n",error/total);
215
+		mexPrintf("Squared correlation coefficient = %g (regression)\n",
216
+			((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/
217
+			((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))
218
+			);
219
+	}
220
+	else
221
+		mexPrintf("Accuracy = %g%% (%d/%d) (classification)\n",
222
+			(double)correct/total*100,correct,total);
223
+
224
+	// return accuracy, mean squared error, squared correlation coefficient
225
+	plhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
226
+	ptr = mxGetPr(plhs[1]);
227
+	ptr[0] = (double)correct/total*100;
228
+	ptr[1] = error/total;
229
+	ptr[2] = ((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/
230
+				((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy));
231
+
232
+	free(x);
233
+	if(prob_estimates != NULL)
234
+		free(prob_estimates);
235
+}
236
+
237
+void exit_with_help()
238
+{
239
+	mexPrintf(
240
+	"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
241
+	"libsvm_options:\n"
242
+	"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
243
+	);
244
+}
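A hedged usage sketch matching the help text above (variable names are illustrative). The third output holds decision values, or per-class probability estimates when '-b 1' is passed and the model was trained with '-b 1':

[predicted_label, accuracy, dec_or_prob] = ...
    svmpredict(test_labels, test_data, model, '-b 1');
% accuracy(1) = classification accuracy in percent,
% accuracy(2) = mean squared error, accuracy(3) = squared correlation coefficient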
245
+
246
+void mexFunction( int nlhs, mxArray *plhs[],
247
+		 int nrhs, const mxArray *prhs[] )
248
+{
249
+	int prob_estimate_flag = 0;
250
+	struct svm_model *model;
251
+
252
+	if(nrhs > 4 || nrhs < 3)
253
+	{
254
+		exit_with_help();
255
+		fake_answer(plhs);
256
+		return;
257
+	}
258
+
259
+	if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
260
+		mexPrintf("Error: label vector and instance matrix must be double\n");
261
+		fake_answer(plhs);
262
+		return;
263
+	}
264
+
265
+	if(mxIsStruct(prhs[2]))
266
+	{
267
+		const char *error_msg;
268
+
269
+		// parse options
270
+		if(nrhs==4)
271
+		{
272
+			int i, argc = 1;
273
+			char cmd[CMD_LEN], *argv[CMD_LEN/2];
274
+
275
+			// put options in argv[]
276
+			mxGetString(prhs[3], cmd,  mxGetN(prhs[3]) + 1);
277
+			if((argv[argc] = strtok(cmd, " ")) != NULL)
278
+				while((argv[++argc] = strtok(NULL, " ")) != NULL)
279
+					;
280
+
281
+			for(i=1;i<argc;i++)
282
+			{
283
+				if(argv[i][0] != '-') break;
284
+				if(++i>=argc)
285
+				{
286
+					exit_with_help();
287
+					fake_answer(plhs);
288
+					return;
289
+				}
290
+				switch(argv[i-1][1])
291
+				{
292
+					case 'b':
293
+						prob_estimate_flag = atoi(argv[i]);
294
+						break;
295
+					default:
296
+						mexPrintf("Unknown option: -%c\n", argv[i-1][1]);
297
+						exit_with_help();
298
+						fake_answer(plhs);
299
+						return;
300
+				}
301
+			}
302
+		}
303
+
304
+		model = matlab_matrix_to_model(prhs[2], &error_msg);
305
+		if (model == NULL)
306
+		{
307
+			mexPrintf("Error: can't read model: %s\n", error_msg);
308
+			fake_answer(plhs);
309
+			return;
310
+		}
311
+
312
+		if(prob_estimate_flag)
313
+		{
314
+			if(svm_check_probability_model(model)==0)
315
+			{
316
+				mexPrintf("Model does not support probability estimates\n");
317
+				fake_answer(plhs);
318
+				svm_destroy_model(model);
319
+				return;
320
+			}
321
+		}
322
+		else
323
+		{
324
+			if(svm_check_probability_model(model)!=0)
325
+				mexPrintf("Model supports probability estimates, but disabled in prediction.\n");
326
+		}
327
+
328
+		predict(plhs, prhs, model, prob_estimate_flag);
329
+		// destroy model
330
+		svm_destroy_model(model);
331
+	}
332
+	else
333
+	{
334
+		mexPrintf("model file should be a struct array\n");
335
+		fake_answer(plhs);
336
+	}
337
+
338
+	return;
339
+}
... ...
@@ -0,0 +1,458 @@
1
+#include <stdio.h>
2
+#include <stdlib.h>
3
+#include <string.h>
4
+#include <ctype.h>
5
+#include "svm.h"
6
+
7
+#include "mex.h"
8
+#include "svm_model_matlab.h"
9
+
10
+#if MX_API_VER < 0x07030000
11
+typedef int mwIndex;
12
+#endif 
13
+
14
+#define CMD_LEN 2048
15
+#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
16
+
17
+void exit_with_help()
18
+{
19
+	mexPrintf(
20
+	"Usage: model = svmtrain(training_label_vector, training_instance_matrix, 'libsvm_options');\n"
21
+	"libsvm_options:\n"
22
+	"-s svm_type : set type of SVM (default 0)\n"
23
+	"	0 -- C-SVC\n"
24
+	"	1 -- nu-SVC\n"
25
+	"	2 -- one-class SVM\n"
26
+	"	3 -- epsilon-SVR\n"
27
+	"	4 -- nu-SVR\n"
28
+	"-t kernel_type : set type of kernel function (default 2)\n"
29
+	"	0 -- linear: u'*v\n"
30
+	"	1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
31
+	"	2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
32
+	"	3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
33
+	"	4 -- precomputed kernel (kernel values in training_instance_matrix)\n"
34
+	"-d degree : set degree in kernel function (default 3)\n"
35
+	"-g gamma : set gamma in kernel function (default 1/k)\n"
36
+	"-r coef0 : set coef0 in kernel function (default 0)\n"
37
+	"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
38
+	"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
39
+	"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
40
+	"-m cachesize : set cache memory size in MB (default 100)\n"
41
+	"-e epsilon : set tolerance of termination criterion (default 0.001)\n"
42
+	"-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
43
+	"-b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
44
+	"-wi weight: set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
45
+	"-v n: n-fold cross validation mode\n"
46
+	);
47
+}
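A hedged usage sketch for the option string parsed below: the first call trains a C-SVC with an RBF kernel and probability estimates, the second returns the 5-fold cross-validation accuracy instead of a model (variable names illustrative):

model  = svmtrain(train_labels, train_data, '-s 0 -t 2 -c 1 -g 0.125 -b 1');
cv_acc = svmtrain(train_labels, train_data, '-s 0 -t 2 -c 1 -v 5');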
48
+
49
+// svm arguments
50
+struct svm_parameter param;		// set by parse_command_line
51
+struct svm_problem prob;		// set by read_problem
52
+struct svm_model *model;
53
+struct svm_node *x_space;
54
+int cross_validation;
55
+int nr_fold;
56
+
57
+double do_cross_validation()
58
+{
59
+	int i;
60
+	int total_correct = 0;
61
+	double total_error = 0;
62
+	double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
63
+	double *target = Malloc(double,prob.l);
64
+	double retval = 0.0;
65
+
66
+	svm_cross_validation(&prob,&param,nr_fold,target);
67
+	if(param.svm_type == EPSILON_SVR ||
68
+	   param.svm_type == NU_SVR)
69
+	{
70
+		for(i=0;i<prob.l;i++)
71
+		{
72
+			double y = prob.y[i];
73
+			double v = target[i];
74
+			total_error += (v-y)*(v-y);
75
+			sumv += v;
76
+			sumy += y;
77
+			sumvv += v*v;
78
+			sumyy += y*y;
79
+			sumvy += v*y;
80
+		}
81
+		mexPrintf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
82
+		mexPrintf("Cross Validation Squared correlation coefficient = %g\n",
83
+			((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
84
+			((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
85
+			);
86
+		retval = total_error/prob.l;
87
+	}
88
+	else
89
+	{
90
+		for(i=0;i<prob.l;i++)
91
+			if(target[i] == prob.y[i])
92
+				++total_correct;
93
+		mexPrintf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
94
+		retval = 100.0*total_correct/prob.l;
95
+	}
96
+	free(target);
97
+	return retval;
98
+}
99
+
100
+// nrhs should be 3
101
+int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
102
+{
103
+	int i, argc = 1;
104
+	char cmd[CMD_LEN];
105
+	char *argv[CMD_LEN/2];
106
+
107
+	// default values
108
+	param.svm_type = C_SVC;
109
+	param.kernel_type = RBF;
110
+	param.degree = 3;
111
+	param.gamma = 0;	// 1/k
112
+	param.coef0 = 0;
113
+	param.nu = 0.5;
114
+	param.cache_size = 100;
115
+	param.C = 1;
116
+	param.eps = 1e-3;
117
+	param.p = 0.1;
118
+	param.shrinking = 1;
119
+	param.probability = 0;
120
+	param.nr_weight = 0;
121
+	param.weight_label = NULL;
122
+	param.weight = NULL;
123
+	cross_validation = 0;
124
+
125
+	if(nrhs <= 1)
126
+		return 1;
127
+
128
+	if(nrhs > 2)
129
+	{
130
+		// put options in argv[]
131
+		mxGetString(prhs[2], cmd, mxGetN(prhs[2]) + 1);
132
+		if((argv[argc] = strtok(cmd, " ")) != NULL)
133
+			while((argv[++argc] = strtok(NULL, " ")) != NULL)
134
+				;
135
+	}
136
+
137
+	// parse options
138
+	for(i=1;i<argc;i++)
139
+	{
140
+		if(argv[i][0] != '-') break;
141
+		if(++i>=argc)
142
+			return 1;
143
+		switch(argv[i-1][1])
144
+		{
145
+			case 's':
146
+				param.svm_type = atoi(argv[i]);
147
+				break;
148
+			case 't':
149
+				param.kernel_type = atoi(argv[i]);
150
+				break;
151
+			case 'd':
152
+				param.degree = atoi(argv[i]);
153
+				break;
154
+			case 'g':
155
+				param.gamma = atof(argv[i]);
156
+				break;
157
+			case 'r':
158
+				param.coef0 = atof(argv[i]);
159
+				break;
160
+			case 'n':
161
+				param.nu = atof(argv[i]);
162
+				break;
163
+			case 'm':
164
+				param.cache_size = atof(argv[i]);
165
+				break;
166
+			case 'c':
167
+				param.C = atof(argv[i]);
168
+				break;
169
+			case 'e':
170
+				param.eps = atof(argv[i]);
171
+				break;
172
+			case 'p':
173
+				param.p = atof(argv[i]);
174
+				break;
175
+			case 'h':
176
+				param.shrinking = atoi(argv[i]);
177
+				break;
178
+			case 'b':
179
+				param.probability = atoi(argv[i]);
180
+				break;
181
+			case 'v':
182
+				cross_validation = 1;
183
+				nr_fold = atoi(argv[i]);
184
+				if(nr_fold < 2)
185
+				{
186
+					mexPrintf("n-fold cross validation: n must be >= 2\n");
187
+					return 1;
188
+				}
189
+				break;
190
+			case 'w':
191
+				++param.nr_weight;
192
+				param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight);
193
+				param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight);
194
+				param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
195
+				param.weight[param.nr_weight-1] = atof(argv[i]);
196
+				break;
197
+			default:
198
+				mexPrintf("Unknown option -%c\n", argv[i-1][1]);
199
+				return 1;
200
+		}
201
+	}
202
+	return 0;
203
+}
204
+
205
+// read in a problem (in svmlight format)
206
+int read_problem_dense(const mxArray *label_vec, const mxArray *instance_mat)
207
+{
208
+	int i, j, k;
209
+	int elements, max_index, sc, label_vector_row_num;
210
+	double *samples, *labels;
211
+
212
+	prob.x = NULL;
213
+	prob.y = NULL;
214
+	x_space = NULL;
215
+
216
+	labels = mxGetPr(label_vec);
217
+	samples = mxGetPr(instance_mat);
218
+	sc = mxGetN(instance_mat);
219
+
220
+	elements = 0;
221
+	// the number of instance
222
+	prob.l = mxGetM(instance_mat);
223
+	label_vector_row_num = mxGetM(label_vec);
224
+
225
+	if(label_vector_row_num!=prob.l)
226
+	{
227
+		mexPrintf("Length of label vector does not match # of instances.\n");
228
+		return -1;
229
+	}
230
+
231
+	if(param.kernel_type == PRECOMPUTED)
232
+		elements = prob.l * (sc + 1);
233
+	else
234
+	{
235
+		for(i = 0; i < prob.l; i++)
236
+		{
237
+			for(k = 0; k < sc; k++)
238
+				if(samples[k * prob.l + i] != 0)
239
+					elements++;
240
+			// count the '-1' element
241
+			elements++;
242
+		}
243
+	}
244
+
245
+	prob.y = Malloc(double,prob.l);
246
+	prob.x = Malloc(struct svm_node *,prob.l);
247
+	x_space = Malloc(struct svm_node, elements);
248
+
249
+	max_index = sc;
250
+	j = 0;
251
+	for(i = 0; i < prob.l; i++)
252
+	{
253
+		prob.x[i] = &x_space[j];
254
+		prob.y[i] = labels[i];
255
+
256
+		for(k = 0; k < sc; k++)
257
+		{
258
+			if(param.kernel_type == PRECOMPUTED || samples[k * prob.l + i] != 0)
259
+			{
260
+				x_space[j].index = k + 1;
261
+				x_space[j].value = samples[k * prob.l + i];
262
+				j++;
263
+			}
264
+		}
265
+		x_space[j++].index = -1;
266
+	}
267
+
268
+	if(param.gamma == 0)
269
+		param.gamma = 1.0/max_index;
270
+
271
+	if(param.kernel_type == PRECOMPUTED)
272
+		for(i=0;i<prob.l;i++)
273
+		{
274
+			if((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
275
+			{
276
+				mexPrintf("Wrong input format: sample_serial_number out of range\n");
277
+				return -1;
278
+			}
279
+		}
280
+
281
+	return 0;
282
+}
283
+
284
+int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat)
285
+{
286
+	int i, j, k, low, high;
287
+	mwIndex *ir, *jc;
288
+	int elements, max_index, num_samples, label_vector_row_num;
289
+	double *samples, *labels;
290
+	mxArray *instance_mat_col; // transposed instance sparse matrix
291
+
292
+	prob.x = NULL;
293
+	prob.y = NULL;
294
+	x_space = NULL;
295
+
296
+	// transpose instance matrix
297
+	{
298
+		mxArray *prhs[1], *plhs[1];
299
+		prhs[0] = mxDuplicateArray(instance_mat);
300
+		if(mexCallMATLAB(1, plhs, 1, prhs, "transpose"))
301
+		{
302
+			mexPrintf("Error: cannot transpose training instance matrix\n");
303
+			return -1;
304
+		}
305
+		instance_mat_col = plhs[0];
306
+		mxDestroyArray(prhs[0]);
307
+	}
308
+
309
+	// each column is one instance
310
+	labels = mxGetPr(label_vec);
311
+	samples = mxGetPr(instance_mat_col);
312
+	ir = mxGetIr(instance_mat_col);
313
+	jc = mxGetJc(instance_mat_col);
314
+
315
+	num_samples = mxGetNzmax(instance_mat_col);
316
+
317
+	// the number of instance
318
+	prob.l = mxGetN(instance_mat_col);
319
+	label_vector_row_num = mxGetM(label_vec);
320
+
321
+	if(label_vector_row_num!=prob.l)
322
+	{
323
+		mexPrintf("Length of label vector does not match # of instances.\n");
324
+		return -1;
325
+	}
326
+
327
+	elements = num_samples + prob.l;
328
+	max_index = mxGetM(instance_mat_col);
329
+
330
+	prob.y = Malloc(double,prob.l);
331
+	prob.x = Malloc(struct svm_node *,prob.l);
332
+	x_space = Malloc(struct svm_node, elements);
333
+
334
+	j = 0;
335
+	for(i=0;i<prob.l;i++)
336
+	{
337
+		prob.x[i] = &x_space[j];
338
+		prob.y[i] = labels[i];
339
+		low = jc[i], high = jc[i+1];
340
+		for(k=low;k<high;k++)
341
+		{
342
+			x_space[j].index = ir[k] + 1;
343
+			x_space[j].value = samples[k];
344
+			j++;
345
+	 	}
346
+		x_space[j++].index = -1;
347
+	}
348
+
349
+	if(param.gamma == 0)
350
+		param.gamma = 1.0/max_index;
351
+
352
+	return 0;
353
+}
354
+
355
+static void fake_answer(mxArray *plhs[])
356
+{
357
+	plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
358
+}
359
+
360
+// Interface function of matlab
361
+// now assume prhs[0]: label prhs[1]: features
362
+void mexFunction( int nlhs, mxArray *plhs[],
363
+		int nrhs, const mxArray *prhs[] )
364
+{
365
+	const char *error_msg;
366
+
367
+	// fix random seed to have same results for each run
368
+	// (for cross validation and probability estimation)
369
+	srand(1);
370
+
371
+	// Transform the input Matrix to libsvm format
372
+	if(nrhs > 0 && nrhs < 4)
373
+	{
374
+		int err;
375
+
376
+		if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
377
+			mexPrintf("Error: label vector and instance matrix must be double\n");
378
+			fake_answer(plhs);
379
+			return;
380
+		}
381
+
382
+		if(parse_command_line(nrhs, prhs, NULL))
383
+		{
384
+			exit_with_help();
385
+			svm_destroy_param(&param);
386
+			fake_answer(plhs);
387
+			return;
388
+		}
389
+
390
+		if(mxIsSparse(prhs[1]))
391
+		{
392
+			if(param.kernel_type == PRECOMPUTED)
393
+			{
394
+				// precomputed kernel requires dense matrix, so we make one
395
+				mxArray *rhs[1], *lhs[1];
396
+
397
+				rhs[0] = mxDuplicateArray(prhs[1]);
398
+				if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
399
+				{
400
+					mexPrintf("Error: cannot generate a full training instance matrix\n");
401
+					svm_destroy_param(&param);
402
+					fake_answer(plhs);
403
+					return;
404
+				}
405
+				err = read_problem_dense(prhs[0], lhs[0]);
406
+				mxDestroyArray(lhs[0]);
407
+				mxDestroyArray(rhs[0]);
408
+			}
409
+			else
410
+				err = read_problem_sparse(prhs[0], prhs[1]);
411
+		}
412
+		else
413
+			err = read_problem_dense(prhs[0], prhs[1]);
414
+
415
+		// svmtrain's original code
416
+		error_msg = svm_check_parameter(&prob, &param);
417
+
418
+		if(err || error_msg)
419
+		{
420
+			if (error_msg != NULL)
421
+				mexPrintf("Error: %s\n", error_msg);
422
+			svm_destroy_param(&param);
423
+			free(prob.y);
424
+			free(prob.x);
425
+			free(x_space);
426
+			fake_answer(plhs);
427
+			return;
428
+		}
429
+
430
+		if(cross_validation)
431
+		{
432
+			double *ptr;
433
+			plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL);
434
+			ptr = mxGetPr(plhs[0]);
435
+			ptr[0] = do_cross_validation();
436
+		}
437
+		else
438
+		{
439
+			int nr_feat = mxGetN(prhs[1]);
440
+			const char *error_msg;
441
+			model = svm_train(&prob, &param);
442
+			error_msg = model_to_matlab_structure(plhs, nr_feat, model);
443
+			if(error_msg)
444
+				mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg);
445
+			svm_destroy_model(model);
446
+		}
447
+		svm_destroy_param(&param);
448
+		free(prob.y);
449
+		free(prob.x);
450
+		free(x_space);
451
+	}
452
+	else
453
+	{
454
+		exit_with_help();
455
+		fake_answer(plhs);
456
+		return;
457
+	}
458
+}
... ...
@@ -0,0 +1,91 @@
1
+function plotDecodePerformance(varargin)
2
+% plotDecodePerformance(timeline,decodePerformance,nClasses,rawData)
3
+
4
+if(nargin==1)
5
+    inputStruct       = varargin{1};
6
+    
7
+    psthStart         = inputStruct.psthStart;
8
+    psthEnd           = inputStruct.psthEnd;
9
+    nClasses          = inputStruct.nClasses;
10
+    decodePerformance = inputStruct.decodePerformance;
11
+    frameStart        = inputStruct.frameShiftStart;
12
+    frameEnd          = inputStruct.frameShiftEnd;
13
+    psth              = inputStruct.rawTimeCourse;
14
+    SubjectID         = inputStruct.SubjectID;
15
+    
16
+    
17
+elseif( nargin == 7)
18
+    
19
+    psthStart   = varargin{1};
20
+    psthEnd     = varargin{2};
21
+    nClasses    = varargin{3};
22
+    decodePerformance = varargin{4};
23
+    frameStart  = varargin{5};
24
+    frameEnd    = varargin{6};
25
+    psth        = varargin{7};
27
+    SubjectID   = '';
28
+end
29
+
30
+    f = figure;
31
+    subplot(2,1,1);
32
+    hold on;
33
+      for voxel = 1:size(psth,2)
34
+          for label = 1:size(psth{voxel},2)
35
+              psthData = nanmean(psth{voxel}{label})+voxel/100; % mean over trials, offset per voxel for visibility
39
+              plot(psthStart:psthEnd,psthData,[colorChooser(voxel), lineStyleChooser(label)]);
40
+          end
41
+      end
42
+%     axis([psthStart psthEnd 0 0])
43
+    hold off
44
+    
45
+    subplot(2,1,2)    
46
+    hold on;
47
+    plot(frameStart:frameEnd, decodePerformance ,'b');
48
+    chanceLevel = 100/nClasses;
49
+    goodPredictionLevel = chanceLevel*1.5;
50
+    plot([psthStart psthEnd],[chanceLevel chanceLevel],'r');
51
+    plot([psthStart psthEnd],[goodPredictionLevel goodPredictionLevel],'g');
52
+    axis([psthStart psthEnd 0 100])
53
+
54
+    hold off;
55
+
56
+    figTitle = sprintf('Subject %s, over %g voxels',SubjectID,size(psth,2)); % avoid shadowing the built-in title()
57
+    set(f,'Name',figTitle);
58
+    disp(figTitle);
59
+
60
+
61
+
62
+end
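A hedged calling sketch for the single-struct form handled at the top of this function; the field names follow the reads above, all values are illustrative:

in.psthStart         = -5;
in.psthEnd           = 45;
in.nClasses          = 2;
in.decodePerformance = perf;        % vector of decode accuracies in percent
in.frameShiftStart   = -5;
in.frameShiftEnd     = 20;
in.rawTimeCourse     = psthCells;   % nested cell array {voxel}{label} of trial-by-time data
in.SubjectID         = 'XY001';     % placeholder subject code
plotDecodePerformance(in);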
63
+
64
+function color = colorChooser(n)
65
+    switch (mod(n,8))
66
+    case 0
67
+        color = 'y';
68
+    case 1
69
+        color = 'r';
70
+    case 2
71
+        color = 'b';
72
+    case 3
73
+        color = 'g';
74
+    otherwise
75
+        color = 'k';
76
+    end
77
+end
78
+
79
+function style = lineStyleChooser(n)
80
+switch(mod(n,4))
81
+    case 0
82
+      style = '--';
83
+    case 1
84
+        style = '-';
85
+    case 2 
86
+        style = ':';
87
+    case 3
88
+        style = '-.';   % ':-' is not a valid MATLAB line style; dash-dot is '-.'
89
+end
90
+end
91
+
... ...
@@ -0,0 +1,12 @@
1
+function sortedList = sortedAdd(element, list)
2
+    if(isempty(list))
3
+        sortedList = element;
4
+        return;
5
+    end
6
+    head = list(1);
7
+    if element.id < head.id
8
+        sortedList = [element list];
9
+    else 
10
+        sortedList = [head sortedAdd(element, list(2:length(list)))];
11
+    end
12
+end
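sortedAdd keeps a struct array ordered by its id field through recursive insertion; a minimal usage sketch:

list = [];
list = sortedAdd(struct('id', 3), list);
list = sortedAdd(struct('id', 1), list);
list = sortedAdd(struct('id', 2), list);
% [list.id] now returns 1 2 3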
0 13
\ No newline at end of file
... ...
@@ -0,0 +1,197 @@
1
+function spm_SVMCrossVal
2
+
3
+
4
+%  Initialize and hide the GUI as it is being constructed.
5
+    frameWidth=450;
6
+    frameHeight=450;
7
+    frame = figure('Visible','off','Position',[0,0,frameWidth,frameHeight]);
8
+    movegui(frame,'west'); % get this thing visible on smaller displays.
9
+    
10
+    set(frame,'Name','SVMCrossVal Decode Performance 4 SPM');
11
+    set(frame,'NumberTitle','off');
12
+    set(frame,'MenuBar','none');
13
+    set(frame,'Color',get(0,'defaultUicontrolBackgroundColor'));
14
+    set(frame,'Resize','off');
15
+    set(frame,'Units','normalized');
16
+
17
+
18
+    optionLineHeight = 1.0/16.0;
19
+    controlElementHeight=optionLineHeight*(1.0/1.5)*frameHeight;
20
+    pMain        = uipanel(frame,'Title','Main Panel',             'Position',[0 optionLineHeight*10 frameWidth optionLineHeight*6]); 
21
+    pAdvanced    = uipanel(frame,'Title','Advanced Options',       'Position',[0 optionLineHeight*5  frameWidth optionLineHeight*5]); 
22
+    pDisplay     = uipanel(frame,'Title','Display Options',        'Position',[0 optionLineHeight*1  frameWidth optionLineHeight*4]); 
23
+    btnRunButton = uicontrol(frame,'Tag','run','String','Run PSTH','Position',[0 optionLineHeight*0  frameWidth frameHeight/16]);
24
+    
25
+    %Main
26
+     firstColumn  = 0.00*frameWidth;
27
+     secondColumn = 0.33*frameWidth;
28
+     thirdColumn  = 0.66*frameWidth;
29
+     
30
+     firstRow  = 6.3*controlElementHeight;
31
+     secondRow = 5.3*controlElementHeight;
32
+     thirdRow  = 4.3*controlElementHeight;
33
+     fourthRow = 3.3*controlElementHeight;
34
+     fifthRow  = 2.3*controlElementHeight;
35
+     sixthRow  = 1.0*controlElementHeight;
36
+     
37
+     createLabel(pMain, [firstColumn firstRow  0.33*frameWidth controlElementHeight],'Position'); %     lPosition 
38
+     createLabel(pMain, [firstColumn secondRow 0.33*frameWidth controlElementHeight],'Voxel Sphere Radius' );%lRadius
39
+     lEvents            = createLabel(pMain, [firstColumn thirdRow  0.33*frameWidth controlElementHeight],'Event List' );
40
+     lSessions          = createLabel(pMain, [firstColumn fourthRow 0.33*frameWidth controlElementHeight],'Session List'  );
41
+     lNormalize         = createLabel(pMain, [firstColumn fifthRow  0.33*frameWidth controlElementHeight],'Normalization Method' );
42
+     lParametric        = createLabel(pMain, [firstColumn sixthRow  0.25*frameWidth controlElementHeight],'Parametric Modulation');
43
+     lParametricFactor  = createLabel(pMain, [(secondColumn+0.33*frameWidth*0.2) sixthRow 0.33*frameWidth*0.8 controlElementHeight],'Modulation Factor');
44
+
45
+     model.txtPosition  = createTextField(pMain, [secondColumn firstRow  0.33*frameWidth controlElementHeight],'0 0 0');
46
+     btnParseHReg       = createButton(pMain,    [thirdColumn  firstRow  0.33*frameWidth controlElementHeight],'hReg', 'parse hReg',model.txtPosition);
47
+
48
+     model.txtRadius    = createTextField(pMain, [secondColumn secondRow 0.33*frameWidth controlElementHeight],'3');
49
+
50
+     model.txtEvents    = createTextField(pMain, [secondColumn thirdRow  0.33*frameWidth controlElementHeight],'');
51
+     btnEvents          = createButton(pMain,    [thirdColumn  thirdRow  0.33*frameWidth controlElementHeight],'events', 'show Event List',model.txtEvents);
52
+        set(btnEvents,'Enable','off');
53
+
54
+     model.txtSessions  = createTextField(pMain, [secondColumn fourthRow 0.33*frameWidth controlElementHeight],'');
55
+
56
+     model.normalization = createDropDown(pMain, [secondColumn fifthRow  0.33*frameWidth controlElementHeight],...
57
+                           defaults.tools.psth4spm.normalizeSelectionModel);
58
+
59
+     model.chkParametric = uicontrol(pMain,'Position',[secondColumn sixthRow 0.33*frameWidth*0.2 controlElementHeight],'Style','checkbox');
60
+     model.txtParametricMappingFactor = createTextField(pMain,     [thirdColumn  sixthRow 0.33*frameWidth controlElementHeight],'1.0');
61
+        set(model.txtParametricMappingFactor,'Enable','off');
62
+        set(model.chkParametric,'Callback',{@cbToggleEnableTarget,model.txtParametricMappingFactor});
63
+
64
+     %Advanced
65
+    firstColumn  = 0.00*frameWidth;
66
+    secondColumn = 0.33*frameWidth;
67
+    thirdColumn  = 0.66*frameWidth;
68
+    fourthColumn = 0.84*frameWidth;
69
+    
70
+    firstRow    = 5.5*controlElementHeight;
71
+    secondRow   = 4.5*controlElementHeight;
72
+    thirdRow    = 3.5*controlElementHeight;
73
+    fourthRow   = 2*controlElementHeight;
74
+    
75
+    lStart  = createLabel(pAdvanced, [secondColumn firstRow 0.33*frameWidth controlElementHeight],'Start [sec]');
76
+    lEnd    = createLabel(pAdvanced, [thirdColumn  firstRow 0.33*frameWidth controlElementHeight],'End [sec]');
77
+    lBaseline = createLabel(pAdvanced,[firstColumn secondRow 0.33*frameWidth controlElementHeight],'Baseline');
78
+    lTimeRange = createLabel(pAdvanced,[firstColumn thirdRow 0.33*frameWidth controlElementHeight],'Time Range (X-Axis)');
79
+    lTemporalResolutionMultiplier = createLabel(pAdvanced, [firstColumn fourthRow 0.33*frameWidth controlElementHeight],'TR Factor');
80
+    
81
+    
82
+    model.txtBaselineStart = createTextField(pAdvanced,[secondColumn secondRow 0.25*frameWidth controlElementHeight],'-3.0');
83
+    model.txtBaselineEnd = createTextField(pAdvanced,[thirdColumn secondRow 0.25*frameWidth controlElementHeight],'-1.0');
84
+    model.txtTimeRangeStart = createTextField(pAdvanced,[secondColumn thirdRow 0.25*frameWidth controlElementHeight],'-5.0');
85
+    model.txtTimeRangeEnd = createTextField(pAdvanced,[thirdColumn thirdRow 0.25*frameWidth controlElementHeight],'45.0');
86
+    
87
+
88
+    model.txtTemporalResolution = createTextField(pAdvanced,[thirdColumn fourthRow 0.18*frameWidth controlElementHeight],'');
89
+        set(model.txtTemporalResolution,'Enable','inactive');
90
+    try
91
+        tr = evalin('base','SPM.xsDes.Interscan_interval(1:end-3)');
92
+        set(model.txtTemporalResolution,'String',tr);
93
+    catch
94
+        btnParseTemporalResolution = createButton(pAdvanced,[fourthColumn fourthRow 0.15*frameWidth controlElementHeight],'TR','parse TR',model.txtTemporalResolution);
95
+    end
96
+    model.txtTemporalResolutionFactor = createTextField(pAdvanced,[secondColumn fourthRow 0.25*frameWidth controlElementHeight],'0.5');
97
+   
98
+    %Display
99
+    firstColumn  = 0.00*frameWidth;
100
+    secondColumn = 0.33*frameWidth;
101
+    thirdColumn  = 0.66*frameWidth;
102
+    
103
+    firstRow    = 4*controlElementHeight;
104
+    secondRow   = 3*controlElementHeight;
105
+    thirdRow    = 2*controlElementHeight;
106
+    fourthRow   = 0.5*controlElementHeight;
107
+    
108
+    lAxisUpper      = createLabel(pDisplay, [firstColumn firstRow 0.33*frameWidth controlElementHeight],'Y-Axis Upper Bound');
109
+    lAxisLower      = createLabel(pDisplay, [firstColumn secondRow 0.33*frameWidth controlElementHeight],'Y-Axis Lower Bound');
110
+    lColorScheme    = createLabel(pDisplay, [firstColumn thirdRow 0.33*frameWidth controlElementHeight],'Color Scheme');
111
+    lShowLegend     = createLabel(pDisplay, [secondColumn+0.33*frameWidth*0.2 fourthRow 0.33*frameWidth controlElementHeight],'Show Legend');
112
+    lShowFiltered   = createLabel(pDisplay, [thirdColumn+0.33*frameWidth*0.2 fourthRow 0.33*frameWidth controlElementHeight],'Show Filtered');
113
+    
114
+    model.txtYAxisUpper = createTextField(pDisplay,[secondColumn firstRow 0.33*frameWidth controlElementHeight],'0');
115
+    model.txtYAxisLower = createTextField(pDisplay,[secondColumn secondRow 0.33*frameWidth controlElementHeight],'0');
116
+    
117
+    model.colorScheme = createDropDown(pDisplay,[secondColumn thirdRow 0.33*frameWidth controlElementHeight],defaults.tools.psth4spm.colorschemeSelectionModel);
118
+    
119
+    model.chkShowLegend = uicontrol(pDisplay,'Position',[secondColumn fourthRow 0.33*frameWidth*0.1 controlElementHeight],'Style','checkbox','Value',1);
120
+    model.chkShowUnfiltered = uicontrol(pDisplay,'Position',[thirdColumn fourthRow 0.33*frameWidth*0.1 controlElementHeight],'Style','checkbox','Value',1);
121
+
122
+    set(btnRunButton,'Callback',{@cbRunPSTH,model});
123
+    set(frame,'Visible','on');
124
+end
125
+
126
+% checkbox callback: toggles the 'Enable' state of the target uicontrol
127
+function cbToggleEnableTarget(src,eventData,target)
128
+    if(strcmp(get(target,'Enable'),'off'))
129
+%         display('is off. set on');
130
+        set(target,'Enable','on');
131
+    else
132
+%         display('is on, set off');
133
+        set(target,'Enable','off');
134
+    end
135
+end
136
+
137
+function cbParseVariable(src,evnt,target)
138
+% display('button pressed');
139
+    switch(get(src,'Tag'))
140
+        case 'hReg'
141
+            pos = num2str(evalin('base','spm_XYZreg(''GetCoords'',hReg)')');
142
+            set(target,'String',pos);
143
+        case 'TR'
144
+            tr = evalin('base','SPM.xsDes.Interscan_interval(1:end-3)');
145
+            set(target,'String',tr);
146
+%             set(src,'Enable','off');
147
+            set(target,'Visible','on');
148
+        otherwise 
149
+            display(['no parse rule for button tagged ' get(src,'Tag')]);
150
+    end
151
+end
152
+
153
+function label = createLabel(parent,  pos, labelText)
154
+    label = uicontrol(parent,'Style','text','String',labelText,'Position',pos);
155
+    set(label,'HorizontalAlignment','left');
156
+    set(label,'Units','characters');
157
+%     set(label,'BackgroundColor','r');
158
+end
159
+
160
+function btn = createButton(parent,pos,tag,labelText,cbArgs)
161
+    btn = uicontrol(parent,'Position',pos,'String',labelText,'tag',tag);
162
+     set(btn,'Callback',{@cbParseVariable,cbArgs});
163
+%     set(btn,'BackgroundColor','b');
164
+end
165
+
166
+function txt = createTextField(parent,pos,initialString)
167
+    txt = uicontrol(parent,'Style','edit','String',initialString,'Position',pos);
168
+    set(txt,'BackgroundColor','w');
169
+end
170
+
171
+function drpField = createDropDown(parent,pos,selectionModel)
172
+ drpField = uicontrol(parent,'Style','popupmenu','Position',pos);
173
+  set(drpField,'String',selectionModel.Strings);
174
+  set(drpField,'BackgroundColor','w');
175
+end
176
+
177
+
178
+function cbRunPSTH(src,evnt,model)
179
+
180
+    % TODO test parameter values
181
+    
182
+    if isSane(model)
183
+        set(0,'userdata',model);
184
+%         set(src,'Enable','off');
185
+        evalin('base','runPSTH4SPM(SPM)');
186
+%         set(src,'Enable','on');
187
+    else
188
+        %todo error beep!
189
+        error('spmtoolbox:SVMCrossVal:paramcheck','please verify all parameters');
190
+    end
191
+                                        
192
+end
193
+
194
+
195
+
196
+
197
+
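Editor's note (not part of the commit): cbRunPSTH hands the GUI state to the base workspace by stashing the handle struct in the root object's userdata before calling runPSTH4SPM there. A minimal sketch of how a consumer might read that state back; runPSTH4SPM itself is defined elsewhere in the toolbox, so the read-back code below is an assumption, though the field names are the ones built in this file:

    model  = get(0,'userdata');                           % struct of uicontrol handles
    pos    = str2num(get(model.txtPosition,'String'));    %#ok<ST2NM> e.g. [0 0 0]
    radius = str2double(get(model.txtRadius,'String'));   % voxel sphere radius
    doParametric = get(model.chkParametric,'Value') == 1; % checkbox state
    factor = str2double(get(model.txtParametricMappingFactor,'String'));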
... ...
@@ -0,0 +1,18 @@
1
+% timePointToImageNumber(timepoint, type) maps a time point onto an image
+% (scan) number; 'type' is optional and defaults to 'ms'
2
+function imgNumber = timePointToImageNumber(timepoint, type) % timepoint in ms by default
3
+    if nargin < 2, type = 'ms'; end % make 'type' truly optional, as the header promises
+    switch type
4
+        case 's'
5
+            imgNumber = timePointToImageNumber(timepoint*1000,'ms');
6
+            return;
7
+        case 'ms'
8
+            imageTimeResolution   = 2000; % ms; hard-coded scan repetition time (TR)
9
+            imgNumber = round(timepoint/imageTimeResolution);
10
+            return;
11
+        case 'image'
12
+            imgNumber = timepoint;
13
+            return;
14
+        otherwise
15
+            imgNumber = timePointToImageNumber(timepoint,'ms');
16
+            return;
17
+    end
18
+end
0 19
\ No newline at end of file
1 20
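Editor's note (not part of the commit): a few example calls showing the mapping with the hard-coded 2000 ms image resolution above:

    timePointToImageNumber(4000, 'ms')    % -> 2
    timePointToImageNumber(4,    's')     % -> 2 (seconds are converted to ms first)
    timePointToImageNumber(7, 'image')    % -> 7 (already an image index, passed through)
    timePointToImageNumber(-3000, 'ms')   % -> -2 (negative = before event onset)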