Browse code

starting som prediction fine-tuned class-performance visualisation

git-svn-id: https://svn.discofish.de/MATLAB/spmtoolbox/SVMCrossVal@112 83ab2cfd-5345-466c-8aeb-2b2739fb922d

Christoph Budziszewski authored on 21/01/2009 16:34:25
Showing 1 changed file
1 1
new file mode 100644
... ...
@@ -0,0 +1,145 @@
1
function [codes,clusters,err] = som_kmeans(method, D, k, epochs, verbose)

% SOM_KMEANS K-means algorithm.
%
% [codes,clusters,err] = som_kmeans(method, D, k, [epochs], [verbose])
%
%  Input and output arguments ([]'s are optional):
%    method     (string) k-means algorithm type: 'batch' or 'seq'
%    D          (matrix) data matrix, one sample per row
%               (struct) data or map struct (type 'som_data' or 'som_map')
%    k          (scalar) number of centroids
%    [epochs]   (scalar) number of training epochs, default = 100
%    [verbose]  (scalar) if <> 0 display additional information
%
%    codes      (matrix) codebook vectors
%    clusters   (vector) cluster number for each sample
%    err        (scalar) total quantization error for the data set
%                        (sum of squared distances to the assigned centroid)
%
% NOTE: one-dimensional data (a single column) is delegated to
% SCALAR_KMEANS regardless of 'method'; in that case 'clusters' is a
% column vector and 'err' is the mean absolute error computed there.
%
% See also KMEANS_CLUSTERS, SOM_MAKE, SOM_BATCHTRAIN, SOM_SEQTRAIN.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Function has been renamed by Kimmo Raivio, because matlab65 also has a
% kmeans function 1.10.02
%% input arguments

if isstruct(D),
    switch D.type,
    case 'som_map',  data = D.codebook;
    case 'som_data', data = D.data;
    otherwise
        % previously 'data' was silently left undefined here
        error('Unsupported struct type: %s', D.type);
    end
else
    data = D;
end
[l dim]   = size(data);

if nargin < 4 || isempty(epochs) || isnan(epochs), epochs = 100; end
if nargin < 5, verbose = 0; end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% action

rand('state', sum(100*clock)); % init rand generator

if k > l,
    error('Number of centroids (%d) exceeds number of samples (%d).', k, l);
end

lr = 0.5;                      % learning rate for sequential k-means
temp      = randperm(l);
centroids = data(temp(1:k),:); % k distinct random samples as initial centroids
clusters  = zeros(1, l);

if dim==1,
    [codes,clusters,err] = scalar_kmeans(data,k,epochs);
    return;
end

switch method
 case 'seq',
  len = epochs * l;
  l_rate = linspace(lr,0,len); % learning rate decays linearly to zero
  order  = randperm(l);
  for iter = 1:len
    % index the numeric data matrix, not D: D may be a struct
    x  = data(order(rem(iter,l)+1),:);
    dx = x(ones(k,1),:) - centroids;
    [dist nearest] = min(sum(dx.^2,2));
    % move only the winning centroid towards the sample
    centroids(nearest,:) = centroids(nearest,:) + l_rate(iter)*dx(nearest,:);
  end
  clusters = som_kmeans_assign(data, centroids);

 case 'batch',
  iter         = 0;
  old_clusters = zeros(1, l);  % no valid label is 0, so never equal at start
  while iter<epochs

    clusters = som_kmeans_assign(data, centroids);

    for i = 1:k
      f = find(clusters==i);
      s = length(f);
      % an empty cluster keeps its previous centroid
      if s, centroids(i,:) = sum(data(f,:),1) / s; end
    end

    % converged when no sample changed its cluster since the last pass
    if isequal(old_clusters, clusters)
      if verbose, fprintf(1, 'Convergence in %d iterations\n', iter); end
      break;
    end

    old_clusters = clusters;
    iter = iter + 1;
  end

  clusters = som_kmeans_assign(data, centroids);

 otherwise,
  % no training is done: the random initial centroids are returned with
  % all-zero cluster labels and zero error
  fprintf(2, 'Unknown method\n');
end

err = 0;
for i = 1:k
  f = find(clusters==i);
  s = length(f);
  if s, err = err + sum(sum((data(f,:)-ones(s,1)*centroids(i,:)).^2,2)); end
end

codes = centroids;
return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function clusters = som_kmeans_assign(data, centroids)

% SOM_KMEANS_ASSIGN Index of the nearest centroid for each sample.
%
% Returns a 1 x l row vector. Squared Euclidean distances are expanded as
% |x|^2 + |c|^2 - 2*x*c'. The minimum is taken explicitly along dimension 2
% so that a single centroid (k == 1) still yields one label per sample.

l  = size(data,1);
k  = size(centroids,1);
d2 = sum(data.^2,2) * ones(1,k) + ...
     ones(l,1) * sum(centroids.^2,2)' - ...
     2.*(data*(centroids'));
[dummy clusters] = min(d2, [], 2);
clusters = clusters';
return;
112
+
113
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
114
+
115
function [y,bm,qe] = scalar_kmeans(x,k,maxepochs)

% SCALAR_KMEANS K-means for one-dimensional data.
%
% [y,bm,qe] = scalar_kmeans(x,k,maxepochs)
%
%    x          (vector) data values; non-finite entries are ignored
%    k          (scalar) number of centroids
%    maxepochs  (scalar) maximum number of iterations
%
%    y          (vector) k x 1 centroid values
%    bm         (vector) centroid index for each element of x,
%                        NaN where x is not finite
%    qe         (scalar) mean absolute distance between the finite
%                        samples and their centroids

    nans = ~isfinite(x);
    x(nans) = [];
    n = length(x);
    mi = min(x); ma = max(x);
    y = linspace(mi,ma,k)';    % initial centroids evenly spaced over the range
    bm = ones(n,1);
    bmold = zeros(n,1);
    i = 0;
    while ~all(bm==bmold) && i<maxepochs,
        bmold  = bm;
        % bin edges are the midpoints between neighbouring centroids
        [c bm] = histc(x,[-Inf; (y(2:end)+y(1:end-1))/2; Inf]);
        y      = full(sum(sparse(bm,1:n,x,k,n),2));   % per-bin sums
        zh     = (c(1:end-1)==0);                     % empty bins
        y(~zh) = y(~zh)./c(~zh);                      % sums -> means
        inds   = find(zh)';
        % empty bins are placed just above the preceding centroid
        for j=inds, if j==1, y(j) = mi; else y(j) = y(j-1) + eps; end, end
        i=i+1;
    end
    % the loop stopped on the epoch limit: refresh bm for the final centroids
    if i==maxepochs, [c bm] = histc(x,[-Inf; (y(2:end)+y(1:end-1))/2; Inf]); end
    if nargout>2, qe = sum(abs(x-y(bm)))/n; end
    if any(nans),
        % Only bm is per-sample, so only bm is expanded back to the original
        % length. y holds k centroids and qe is a scalar; scattering them
        % over the sample indices was a size mismatch.
        bmfull = NaN(length(nans),1);
        bmfull(~nans) = bm;
        bm = bmfull;
    end

    return;
145
+