Browse code

starting som prediction fine-tuned class-performance visualisation

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@112 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on 21/01/2009 16:34:25
Showing 1 changed file
1 1
new file mode 100644
... ...
@@ -0,0 +1,251 @@
1
function [Class,P]=knn_old(Data, Proto, proto_class, K)

%KNN_OLD A K-nearest neighbor classifier using Euclidean distance
%
% [Class,P]=knn_old(Data, Proto, proto_class, K)
%
%  [sM_class,P]=knn_old(sM, sData, [], 3);
%  [sD_class,P]=knn_old(sD, sM, class);
%  [class,P]=knn_old(data, proto, class);
%  [class,P]=knn_old(sData, sM, class,5);
%
%  Input and output arguments ([]'s are optional):
%   Data   (matrix) size Nxd, vectors to be classified (=classifiees)
%          (struct) map or data struct: map codebook vectors or
%                   data vectors are considered as classifiees.
%   Proto  (matrix) size Mxd, prototype vector matrix (=prototypes)
%          (struct) map or data struct: map codebook vectors or
%                   data vectors are considered as prototypes.
%   [proto_class] (vector) size Mx1, integers 1,2,...,k indicating the
%                   classes of corresponding prototypes, default: see the
%                   explanation below.
%   [K]    (scalar) the K in KNN classifier, a positive integer not
%                   larger than M, default is 1
%
%   Class  (matrix) size Nx1, vector of 1,2, ..., k indicating the class
%                   decision according to the KNN rule
%   P      (matrix) size Nxk, the relative amount of prototypes of
%                   each class among the K closest prototypes for
%                   each classifiee.
%
% If 'proto_class' is _not_ given, 'Proto' _must_ be a labeled SOM
% Toolbox struct. The label of the data vector or the first label of
% the map model vector is considered as class label for the prototype
% vector. In this case the output 'Class' is a copy of 'Data' (map or
% data struct) relabeled according to the classification; if 'Data' is
% a plain matrix there is nothing to relabel and 'Class' is the Nx1
% vector of class numbers instead. Prototypes without a label form
% class 0, and a classifiee assigned to it gets an empty label. If input
% argument 'proto_class' _is_ given, the output argument 'Class' is
% _always_ a vector of integers 1,2,...,k indicating the class.
%
% If there is a tie between representatives of two or more classes
% among the K closest neighbors to the classifiee, the class is
% selected randomly among these candidates.
%
% IMPORTANT
%
% ** Even if prototype vectors are given in a map struct the mask _is not
%    taken into account_ when calculating Euclidean distance
% ** The function calculates the total distance matrix between all
%    classifiees and prototype vectors. This results to an MxN matrix;
%    if N is high it is recommended to divide the matrix 'Data'
%    (the classifiees) into smaller sets in order to avoid memory
%    overflow or swapping. Also, if K>1 this function uses 'sort' which is
%    considerably slower than 'min' which is used for K==1.
%
% See also KNN, SOM_LABEL, SOM_AUTOLABEL

% Contributed to SOM Toolbox 2.0, February 11th, 2000 by Johan Himberg
% Copyright (c) by Johan Himberg
% http://www.cis.hut.fi/projects/somtoolbox/

% Version 2.0beta Johan 040200

%% Init %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Stays empty unless class labels are read from a labeled struct; the
% output stage uses it to decide between struct and vector output.
classnames={};

% Check K
if nargin<4 || isempty(K)
  K=1;
end

if ~vis_valuetype(K,{'1x1'})
  error('Value for K must be a scalar.');
end

% Take classifiee and prototype vectors from struct or matrix input
data=knn_old_vectors(Data);
proto=knn_old_vectors(Proto);

% Check that inputs are matrices
if ~vis_valuetype(proto,{'nxm'}) || ~vis_valuetype(data,{'nxm'})
  error('Prototype or data input not valid.')
end

% Record data&proto sizes and check their dims
[N_data dim_data]=size(data);
[N_proto dim_proto]=size(proto);
if dim_proto ~= dim_data
  error('Data and prototype vector dimension does not match.');
end

% K neighbours are picked by indexing rows 1:K of the sorted distance
% matrix, so K has to be a valid row count.
if K<1 || K~=fix(K)
  error('Value for K must be a positive integer.');
elseif K>N_proto
  error('Value for K must not exceed the number of prototype vectors.');
end

% Check if the classes are given as labels (no class input arg.);
% if they are, take them from the prototype struct and transform them
% to integer (numerical) class labels.
if nargin<3 || isempty(proto_class)
  if ~isstruct(Proto)
    error(['If prototypes are not in labeled map or data struct, ' ...
           'class must be given.']);
  end
  [proto_class,classnames]=class2num(Proto.labels);
end

% Check class label vector: must be numerical and of integers.
% The test is element-wise: summing the fractional parts could cancel
% out (e.g. 1.5 and -1.5) and let non-integers through.
if ~vis_valuetype(proto_class,{[N_proto 1]})
  error(['Class vector is invalid: has to be a N-of-data_rows x 1' ...
         ' vector of integers']);
elseif any(~isfinite(proto_class) | fix(proto_class)~=proto_class)
  error('Class labels in vector ''Class'' must be integers.');
end

% Find all class labels
ClassIndex=unique(proto_class);
N_class=length(ClassIndex); % number of different classes

% Distances between prototypes (rows) and classifiees (columns)
d=distance(proto,data);

%%%% Classification %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% All reductions below name dimension 1 explicitly. With the default
% dimension a single prototype or a single class turns the matrix into a
% row vector and MATLAB would reduce along the classifiees instead.

if K==1   % sort distances only if K>1

  % 1NN: select the closest prototype
  [tmp,proto_index]=min(d,[],1);
  class=proto_class(proto_index);

else

  % Sort the prototypes for each classifiee according to distance
  [tmp,proto_index]=sort(d,1);

  %% Select K closest prototypes
  proto_index=proto_index(1:K,:);
  knn_class=reshape(proto_class(proto_index),K,N_data);

  % Count, per classifiee, the neighbours belonging to each class
  classcounter=zeros(N_class,N_data);
  for i=1:N_class
    classcounter(i,:)=sum(knn_class==ClassIndex(i),1);
  end

  %% Vote between classes of K neighbors
  [winner,vote_index]=max(classcounter,[],1);

  %% Handle ties

  % mark the classes that got as many votes as the winner
  equal_to_winner=(repmat(winner,N_class,1)==classcounter);

  % classifiees for which more than one class shares the top count
  tie_index=find(sum(equal_to_winner,1)>1);

  % Go through tied classes and reset vote_index randomly to one of them
  for i=1:length(tie_index)
    tie_class_index=find(equal_to_winner(:,tie_index(i)));
    fortuna=randperm(length(tie_class_index));
    vote_index(tie_index(i))=tie_class_index(fortuna(1));
  end

  class=ClassIndex(vote_index);
end

% Indexing a scalar (one prototype / one class) with a row vector gives
% a row vector; the documented output is always N x 1.
class=class(:);

%% Build output %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Relative amount of classes in K neighbors for each classifiee

if K==1
  P=zeros(N_data,N_class);
  if nargout>1   % fill P only when the caller asks for it
    for i=1:N_data
      P(i,ClassIndex==class(i))=1;
    end
  end
else
  P=classcounter'./K;
end

% Write class names into a copy of the input struct if they exist
if ~isempty(classnames) && isstruct(Data)
  Class=Data;
  for i=1:N_data
    if class(i)>0
      Class.labels{i,1}=classnames{class(i)};
    else
      % class 0 = the winning prototype(s) carried no label
      Class.labels{i,1}='';
    end
  end
else
  Class=class;
end


function vectors=knn_old_vectors(S)

% Return the vector matrix of a map struct (codebook), a data struct
% (data), or the input itself when it is not a struct.

if ~isstruct(S)
  % is already a matrix
  vectors=S;
  return
end

if ~(isfield(S,'type') && ischar(S.type))
  error('Invalid map/data struct?');
end

switch S.type
 case 'som_map'
  vectors=S.codebook;
 case 'som_data'
  vectors=S.data;
 otherwise
  % without this branch the output would be left undefined
  error('Invalid map/data struct?');
end
222
+
223
+
224
+%%% Subfunctions %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%  
225
+
226
function [nos,names] = class2num(class)

% Change string labels in map/data struct to integer numbers
%
%  class  (cell array) size M x k, labels of M vectors; only the first
%                      label (first column) of each vector is used
%
%  nos    (vector) size M x 1, class number of each vector: the index of
%                  its label in 'names', or 0 if the label is empty
%  names  (cell array) size c x 1, the distinct non-empty labels in
%                  order of first appearance

names = {};
nos = zeros(size(class,1),1);
if isempty(class)
  return
end

% Take the first column explicitly: length() together with linear
% indexing would walk down the wrong dimension whenever a vector has
% more labels than there are vectors.
first_label = class(:,1);

for i=1:size(first_label,1)
  label = first_label{i};
  if isempty(label)
    continue   % unlabeled vector keeps class number 0
  end
  index = find(strcmp(label,names));
  if isempty(index)
    % label seen for the first time: it becomes the next class
    names = cat(1,names,first_label(i));
    index = length(names);
  end
  nos(i,1) = index;
end
244
+
245
function d=distance(X,Y)

% Pairwise squared Euclidean distances between the rows of X and the
% rows of Y: d(i,j) belongs to X(i,:) and Y(j,:). No square root is
% taken. NaN components are treated as missing: they are zeroed, and the
% known-value masks drop the matching component of the other vector
% from the squared-norm terms.

x_known=~isnan(X);
y_known=~isnan(Y);
X(~x_known)=0;
Y(~y_known)=0;

% |x-y|^2 = |x|^2 + |y|^2 - 2*x*y', restricted to known components
x_sq=X.^2*y_known';
y_sq=x_known*Y'.^2;
cross=2*X*Y';
d=x_sq+y_sq-cross;