Browse code

starting som prediction fine-tuned class-performance visualisation

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@112 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on 21/01/2009 16:34:25
Showing 1 changed file
1 1
new file mode 100644
... ...
@@ -0,0 +1,215 @@
1
function codebook = lvq3(codebook,data,rlen,alpha,win,epsilon)

%LVQ3 trains codebook with LVQ3 -algorithm
%
% sM = lvq3(sM,D,rlen,alpha,win,epsilon)
%
%   sM = lvq3(sM,sD,50*size(sM.codebook,1),0.05,0.2,0.3);
%
%  Input and output arguments: 
%   sM      (struct) map struct, the class information must be 
%                    present on the first column of .labels field
%   D       (struct) data struct, the class information must
%                    be present on the first column of .labels field
%   rlen    (scalar) running length, i.e. number of passes over the data
%   alpha   (scalar) learning parameter, e.g. 0.05
%   win     (scalar) window width parameter, e.g. 0.25
%   epsilon (scalar) relative learning parameter, e.g. 0.3
%
%   sM      (struct) map struct, the trained codebook
%
% NOTE: does not take mask into account.
%
% For more help, try 'type lvq3', or check out online documentation.
% See also LVQ1, SOM_SUPERVISED, SOM_SEQTRAIN.

%%%%%%%%%%%%% DETAILED DESCRIPTION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% lvq3
%
% PURPOSE
%
% Trains codebook with the LVQ3 -algorithm (described below).
%
% SYNTAX
%
% sM = lvq3(sM, data, rlen, alpha, win, epsilon)
%
% DESCRIPTION
%
% Trains codebook with the LVQ3 -algorithm. The codebook contains a number
% of vectors (mi, i=1,2,...,n) and so does the data (vectors xj,
% j=1,2,...,k). Both vector sets are classified: each vector may have a
% class, given in the first column of the map's or the data struct's
% 'labels' field. An empty label means "no class": unclassified data
% vectors are skipped, and an unclassified codebook vector never counts
% as belonging to the class of a data vector.
%
% The data vectors are presented one at a time, in their stored order,
% and the codebook is updated after every vector. For each xj the two
% closest codebook vectors mc1 and mc2 are searched (euclidean distances
% d1 and d2, computed over the components of xj that are not NaN; only
% those components are updated).
%
% If exactly one of mc1 and mc2 belongs to the same class as xj (let it
% be mc1) and xj falls into the window, i.e.
%
%    min(d1/d2, d2/d1) > s, where s = (1-win) / (1+win),
%
% the codebook is updated as follows:
%    mc1(t+1) = mc1(t) + alpha * (xj(t) - mc1(t))
%    mc2(t+1) = mc2(t) - alpha * (xj(t) - mc2(t))
% If both mc1 and mc2 belong to the same class as xj, the codebook is
% updated as follows (no window condition, as in Kohonen's formulation):
%    mc1(t+1) = mc1(t) + epsilon * alpha * (xj(t) - mc1(t))
%    mc2(t+1) = mc2(t) + epsilon * alpha * (xj(t) - mc2(t))
% Otherwise no update is performed.
%
% Argument 'rlen' tells how many times the whole training sequence (all
% data vectors) is presented.
%
% Argument 'alpha' is recommended to be smaller than 0.1 and argument
% 'epsilon' should be between 0.1 and 0.5.
%
% The codebook must contain at least two vectors.
%
% NOTE: does not take mask into account.
%
% REFERENCES
%
% Kohonen, T., "Self-Organizing Map", 2nd ed., Springer-Verlag, 
%    Berlin, 1995, pp. 181-182.
%
% See also LVQ_PAK from http://www.cis.hut.fi/research/som_lvq_pak.shtml
% 
% REQUIRED INPUT ARGUMENTS
%
%  sM                The map to be trained.
%          (struct)  A map struct.
%
%  data              The data to use in training.
%          (struct)  A data struct.
%
%  rlen    (integer) Running length of LVQ3 -algorithm (passes over data).
%                    
%  alpha   (float)   Learning rate used in training, e.g. 0.05
%
%  win     (float)   Window width, e.g. 0.25
%  
%  epsilon (float)   Relative learning parameter, e.g. 0.3
%
% OUTPUT ARGUMENTS
%
%  sM                Trained map; a 'som_train' record describing this
%          (struct)  training run is appended to its .trainhist field.
%
% EXAMPLE
%
%   lab = unique(sD.labels(:,1));         % different classes
%   mu = length(lab)*5;                   % 5 prototypes for each    
%   sM = som_randinit(sD,'msize',[mu 1]); % initial prototypes
%   sM.labels = [lab;lab;lab;lab;lab];    % their classes
%   sM = lvq1(sM,sD,50*mu,0.05);          % use LVQ1 to adjust
%                                         % the prototypes      
%   sM = lvq3(sM,sD,50*mu,0.05,0.2,0.3);  % then use LVQ3 
% 
% SEE ALSO
% 
%  lvq1             Use LVQ1 algorithm for training.
%  som_supervised   Train SOM using supervised training.
%  som_seqtrain     Train SOM with sequential algorithm.

% Contributed to SOM Toolbox vs2, February 2nd, 2000 by Juha Parhankangas
% Copyright (c) by Juha Parhankangas
% http://www.cis.hut.fi/projects/somtoolbox/

% Juha Parhankangas 310100 juuso 020200

cod = codebook.codebook;
dat = data.data;

[n_cod, y] = size(cod);   % number of prototypes, vector dimension
x = size(dat,1);          % number of data vectors

% With a single prototype the "second closest" would be the winner itself.
if n_cod < 2
  error('lvq3: the codebook must contain at least two vectors.');
end

% Number prototype and data classes in ONE call so that equal class names
% get equal numbers in both sets (numbering them separately makes the
% class comparisons below meaningless). Empty labels map to 0.
classes = class2num([codebook.labels(:,1); data.labels(:,1)]);
c_class = classes(1:n_cod);
d_class = classes(n_cod+1:end);

% Window threshold: xj is inside the window when min(d1/d2,d2/d1) > s.
s = (1-win)/(1+win);

ONES = ones(n_cod,1);

for t=1:rlen
  fprintf('\rTraining round: %d/%d',t,rlen);

  for j=1:x
    % Unclassified data vectors (class 0) take no part in training.
    if ~d_class(j)
      continue;
    end

    % Distances and updates use only the known (non-NaN) components.
    known = find(~isnan(dat(j,:)));
    if isempty(known)
      continue;
    end
    xj = dat(j,known);

    di = sqrt(sum((cod(:,known) - ONES*xj).^2,2));
    [d1, ind1] = min(di);
    di(ind1) = Inf;
    [d2, ind2] = min(di);

    same1 = (c_class(ind1) == d_class(j));
    same2 = (c_class(ind2) == d_class(j));

    % d1 <= d2, so d1/d2 equals min(d1/d2,d2/d1); if d2 == 0 the sample
    % coincides with both prototypes and lies on the window midplane.
    if d2 > 0
      ratio = d1/d2;
    else
      ratio = 1;
    end

    if same1 && same2
      % Both winners are of the sample's class: move both towards it,
      % damped by epsilon.
      both = [ind1 ind2];
      cod(both,known) = cod(both,known) + epsilon*alpha* ...
          ([xj; xj] - cod(both,known));
    elseif (same1 || same2) && ratio > s
      % Exactly one winner is correct and xj is inside the window:
      % attract the correct one (mj), repel the incorrect one (mi).
      if same1
        mj = ind1; mi = ind2;
      else
        mj = ind2; mi = ind1;
      end
      cod(mj,known) = cod(mj,known) + alpha * (xj - cod(mj,known));
      cod(mi,known) = cod(mi,known) - alpha * (xj - cod(mi,known));
    end
  end
end
fprintf(1,'\n');

% Store the trained prototypes in the returned map struct.
codebook.codebook = cod;

sTrain = som_set('som_train','algorithm','lvq3',...
		 'data_name',data.name,...
		 'neigh','',...
		 'mask',ones(y,1),...
		 'radius_ini',NaN,...
		 'radius_fin',NaN,...
		 'alpha_ini',alpha,...
		 'alpha_type','constant',...
		 'trainlen',rlen,...
		 'time',datestr(now,0));
codebook.trainhist(end+1) = sTrain;
190
+
191
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
192
+
193
function nos = class2num(class)

% CLASS2NUM  Map class names to integer class numbers.
%
%  nos = class2num(class)
%
%   class  (cell array) class names; an empty entry means "no class"
%
%   nos    (column vector) nos(i) is the position of class{i} in the list
%          of distinct non-empty names, ordered by first appearance;
%          empty entries get 0.
%
%  Numbers are only comparable between entries mapped in the SAME call.

names = {};
nos = zeros(length(class),1);

for i=1:length(class)
  if isempty(class{i})
    continue;   % unlabeled entry keeps class number 0
  end
  k = find(strcmp(class{i},names));
  if isempty(k)
    % First occurrence of this name: give it the next free number.
    names{end+1,1} = class{i};
    k = length(names);
  end
  nos(i,1) = k;
end
211
+
212
+
213
+
214
+
215
+