Skip to content

Commit e42bdb5

Browse files
committed
extract stability clusters
1 parent 82ed4c0 commit e42bdb5

File tree

2 files changed

+296
-5
lines changed

2 files changed

+296
-5
lines changed

scripts/builtin/hdbscan.dml

Lines changed: 160 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ m_hdbscan = function(Matrix[Double] X, Integer minPts = 5, Integer minClSize = -
5858

5959
coreDistances = matrix(0, rows=n, cols=1)
6060
for(i in 1:n) {
61-
kthDist = computeKthSmallest(distances[i,], minPts)
61+
kthDist = computeKthSmallest(t(distances[i,]), minPts) # row slice is 1 x n; transpose to the column vector computeKthSmallest expects
6262
coreDistances[i] = kthDist
6363
}
6464

@@ -68,7 +68,8 @@ m_hdbscan = function(Matrix[Double] X, Integer minPts = 5, Integer minClSize = -
6868

6969
[hierarchy, clusterSizes] = buildHierarchy(mstEdges, mstWeights, n)
7070

71-
# TODO: get stable cluster with stability score
71+
[clusterMems, stabilities] = extractStableClusters(hierarchy, mstWeights, n, minClSize)
72+
7273
# TODO: build cluster model
7374

7475
# temp dummy values
@@ -184,9 +185,9 @@ union = function(Matrix[Double] parent, Matrix[Double] rank,
184185
buildHierarchy = function(Matrix[Double] edges, Matrix[Double] weights, Integer n)
185186
return (Matrix[Double] hierarchy, Matrix[Double] sizes)
186187
{
187-
# sort edges by weight in ascending order
188-
# to build the hierarchy from dense cores outward
189-
sorted = order(target=weights, by=1, decreasing=FALSE)
188+
# create indexed weights to preserve original positions after sorting
189+
indexedWeights = cbind(seq(1, nrow(weights)), weights)
190+
sorted = order(target=indexedWeights, by=2, decreasing=FALSE)
190191

191192
# parent[i] = i, meaning each point is its own parent in the beginning
192193
parent = seq(1, n)
@@ -241,3 +242,157 @@ buildHierarchy = function(Matrix[Double] edges, Matrix[Double] weights, Integer
241242
}
242243
}
243244
}
245+
246+
# Collect all original data points (leaf ids 1..n) contained in the
# subtree rooted at nodeId of the merge hierarchy.
#
# Node ids: 1..n are the original points (leaves); merge i creates the
# internal node with id n + i.
#
# INPUT:
#   hierarchy  (n-1) x 3 matrix; row i = [leftChild, rightChild, distance]
#   n          number of original data points
#   nodeId     dendrogram node whose leaf descendants are requested
# OUTPUT:
#   leaves     column vector of the leaf ids under nodeId
getLeafDescendants = function(Matrix[Double] hierarchy, Integer n, Integer nodeId)
  return (Matrix[Double] leaves)
{
  if(nodeId <= n) {
    # base case: nodeId is itself a data point
    leaves = matrix(nodeId, rows=1, cols=1)
  }
  else {
    # internal node: recurse into both children of merge (nodeId - n)
    row = nodeId - n
    childL = as.integer(as.scalar(hierarchy[row, 1]))
    childR = as.integer(as.scalar(hierarchy[row, 2]))
    leaves = rbind(getLeafDescendants(hierarchy, n, childL),
                   getLeafDescendants(hierarchy, n, childR))
  }
}
262+
263+
# Extract flat clusters from the single-linkage hierarchy by maximizing
# cluster stability (simplified HDBSCAN excess-of-mass style selection).
#
# Node ids: 1..n are the original points (leaves); merge i (1..n-1)
# creates internal node id n+i; the root is id 2n-1.
#
# INPUT:
#   hierarchy  (n-1) x 3 matrix; row i = [leftChild, rightChild, mergeDistance]
#   weights    MST edge weights
#              NOTE(review): currently unused -- distances are read from
#              hierarchy[,3] instead; confirm whether this parameter is needed
#   n          number of original data points
#   minClSize  minimum number of points for a node to qualify as a cluster
# OUTPUT:
#   labels       n x 1 vector; cluster id (1-based) per point, -1 = noise
#   stabilities  (2n-1) x 1 vector of per-node stability scores
extractStableClusters = function(Matrix[Double] hierarchy, Matrix[Double] weights,
    Integer n, Integer minClSize)
  return (Matrix[Double] labels, Matrix[Double] stabilities)
{
  numMerges = n - 1   # hierarchical tree over n points has exactly n-1 merge events
  numNodes = 2*n - 1  # total nodes in the dendrogram

  # convert merge distances to lambda (density) values: lambda = 1 / distance
  lambda = matrix(0, rows=numMerges, cols=1)
  for(i in 1:numMerges) {
    dist = as.scalar(hierarchy[i,3])
    if(dist > 0) {
      lambda[i,1] = 1.0 / dist
    } else {
      # zero distance means infinite density; use a large finite stand-in
      lambda[i,1] = 1e15
    }
  }

  lambda_birth = matrix(1e15, rows=numNodes, cols=1)
  lambda_death = matrix(0, rows=numNodes, cols=1)
  cluster_size = matrix(0, rows=numNodes, cols=1)

  # initialize the leaf nodes to have cluster size 1
  for(i in 1:n) {
    cluster_size[i,1] = 1
  }

  # walk the merges to record, for every node, the density at which it
  # appears (birth) and disappears into its parent (death), plus its size
  for(i in 1:numMerges) {
    left = as.integer(as.scalar(hierarchy[i,1]))
    right = as.integer(as.scalar(hierarchy[i,2]))
    newId = n + i
    merge_lambda = as.scalar(lambda[i,1])

    # cluster newId starts existing as its own cluster at this density level
    # and that's why the children get their death set at the same density
    lambda_birth[newId,1] = merge_lambda
    lambda_death[left,1] = merge_lambda
    lambda_death[right,1] = merge_lambda
    cluster_size[newId,1] = as.scalar(cluster_size[left,1]) + as.scalar(cluster_size[right,1])
  }

  # the root cluster exists all the way down to density 0
  rootId = 2*n - 1
  lambda_death[rootId,1] = 0

  # compute own stability for each internal node:
  # stability = size * (lambda_birth - lambda_death).
  # Only clusters of at least minClSize get a score; the longer-lived
  # (birth - death) and the larger the cluster, the more stable it is.
  stability = matrix(0, rows=numNodes, cols=1)
  for(nodeId in (n+1):numNodes) {
    size = as.scalar(cluster_size[nodeId,1])
    birth = as.scalar(lambda_birth[nodeId,1])
    death = as.scalar(lambda_death[nodeId,1])
    if(size >= minClSize) {
      stability[nodeId,1] = size * (birth - death)
    }
  }

  # compute subtree stability (best achievable from each subtree)
  subtree_stability = matrix(0, rows=numNodes, cols=1)

  # leaf nodes have 0 subtree stability
  for(i in 1:n) {
    subtree_stability[i,1] = 0
  }

  # process merges in order (bottom-up): children are always created
  # before their parent, so their subtree stability is already final
  for(i in 1:numMerges) {
    nodeId = n + i
    left = as.integer(as.scalar(hierarchy[i,1]))
    right = as.integer(as.scalar(hierarchy[i,2]))

    children_subtree = as.scalar(subtree_stability[left,1]) + as.scalar(subtree_stability[right,1])
    own_stab = as.scalar(stability[nodeId,1])

    # subtree stability is the best we can achieve from this subtree:
    # either keep this node as one cluster, or split into its children
    if(children_subtree > own_stab) {
      subtree_stability[nodeId,1] = children_subtree
    } else {
      subtree_stability[nodeId,1] = own_stab
    }
  }

  # select clusters top-down, starting from the root
  selected = matrix(0, rows=numNodes, cols=1)
  selected[rootId,1] = 1

  i = numMerges
  while(i >= 1) {
    nodeId = n + i

    if(as.scalar(selected[nodeId,1]) == 1) {
      left = as.integer(as.scalar(hierarchy[i,1]))
      right = as.integer(as.scalar(hierarchy[i,2]))

      children_subtree = as.scalar(subtree_stability[left,1]) + as.scalar(subtree_stability[right,1])
      own_stab = as.scalar(stability[nodeId,1])
      parent_size = as.scalar(cluster_size[nodeId,1])

      # replace this node by its children if they have higher combined
      # subtree stability, or if this node is too small to be a cluster
      if(parent_size < minClSize | children_subtree > own_stab) {
        selected[nodeId,1] = 0
        selected[left,1] = 1
        selected[right,1] = 1
      }
    }

    i = i - 1
  }

  # assign labels: every leaf under a selected, large-enough node gets
  # that cluster's id; all remaining points stay at -1 (noise)
  labels = matrix(-1, rows=n, cols=1)
  cluster_id = 1

  for(nodeId in 1:numNodes) {
    if(as.scalar(selected[nodeId,1]) == 1) {
      size = as.scalar(cluster_size[nodeId,1])

      if(size >= minClSize) {
        leaves = getLeafDescendants(hierarchy, n, nodeId)

        for(j in 1:nrow(leaves)) {
          leafId = as.integer(as.scalar(leaves[j,1]))
          if(leafId >= 1 & leafId <= n) {
            labels[leafId,1] = cluster_id
          }
        }

        cluster_id = cluster_id + 1
      }
    }
  }

  stabilities = stability
}

test_extract_stabe_clusters.dml

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Unit test for hdbscan.dml: verifies MST construction, hierarchy building,
# and stable-cluster extraction on a small hand-checked example.
source("scripts/builtin/hdbscan.dml") as hdb

# 6 point example with clear cluster structure:
# points 1,2,3 form tight cluster A, points 4,5,6 form tight cluster B,
# and A and B are far apart (distance 10)
n = 6
distances = matrix(10, rows=n, cols=n)

# points 1,2,3 (cluster A)
distances[1,2] = 1
distances[2,1] = 1

distances[1,3] = 2
distances[3,1] = 2

distances[2,3] = 1
distances[3,2] = 1

# points 4,5,6 (cluster B)
distances[4,5] = 1
distances[5,4] = 1

distances[4,6] = 2
distances[6,4] = 2

distances[5,6] = 1
distances[6,5] = 1

# zero diagonal (distance of each point to itself)
for(i in 1:n) {
  distances[i,i] = 0
}


print("\nBuilding MST")
expected_edges = matrix("2 1 3 2 6 3 5 6 4 5", rows=5, cols=2)
expected_weights = matrix("1 1 10 1 1", rows=5, cols=1)
[edges, weights] = hdb::buildMST(distances, n)
edges_match = (min(edges == expected_edges) == 1)
weights_match = (min(weights == expected_weights) == 1)
if (edges_match) {
  print("Pass: edges match.")
} else {
  print("Fail: edges don't match.")
}
if (weights_match) {
  print("Pass: weights match.")
} else {
  print("Fail: weights don't match.")
}
print("MST edges:\n" + toString(edges))
print("MST weights:\n" + toString(weights))


print("\nBuilding hierarchy")
[hierarchy, sizes] = hdb::buildHierarchy(edges, weights, n)
expected_hierarchy = matrix("2 1 1 3 7 1 5 6 1 4 9 1 10 8 10", rows=5, cols=3)
expected_sizes = matrix("2 3 2 3 6", rows=5, cols=1)
hierarchy_match = (min(hierarchy == expected_hierarchy) == 1)
sizes_match = (min(sizes == expected_sizes) == 1)
if (hierarchy_match) {
  print("Pass: hierarchy matches.")
} else {
  print("Fail: hierarchy doesn't match.")
}
if (sizes_match) {
  print("Pass: sizes match.")
} else {
  print("Fail: sizes don't match.")
}
print("Hierarchy:\n" + toString(hierarchy))
print("Sizes:\n" + toString(sizes))


print("\nExtracting stable clusters with minClSize=2")
[labels, stabilities] = hdb::extractStableClusters(hierarchy, weights, n, 2)
expected_labels = matrix("1 1 1 2 2 2", rows=6, cols=1)
expected_stabilities = matrix("0 0 0 0 0 0 0 2.7 0 2.7 0.6", rows=n*2-1, cols=1)
labels_match = (min(labels == expected_labels) == 1)
# stabilities are derived from 1/distance, so compare with a tolerance
tolerance = 1e-10
stabilities_match = max(abs(stabilities - expected_stabilities)) < tolerance
if (labels_match) {
  print("Pass: labels match.")
} else {
  print("Fail: labels don't match.")
}
if (stabilities_match) {
  print("Pass: stabilities match.")
} else {
  print("Fail: stabilities don't match.")
}
print("Cluster labels:\n" + toString(labels))
print("Top stabilities:\n" + toString(stabilities))



# check results (partially redundant with the exact-match checks above,
# but states the expected clustering properties explicitly)
num_clusters = max(labels)
num_noise = sum(labels == -1)

print("\nNumber of clusters found: " + num_clusters)
print("Number of noise points: " + num_noise)

# should find 2 clusters
test1 = (num_clusters == 2)
print("Found 2 clusters: " + test1)

# no points should be noise
test2 = (num_noise == 0)
print("No noise points: " + test2)

# points 1,2,3 should be in the same cluster
label1 = as.scalar(labels[1])
label2 = as.scalar(labels[2])
label3 = as.scalar(labels[3])
test3 = (label1 == label2) & (label2 == label3) & (label1 > 0)
print("points 1,2,3 in same cluster: " + test3)

# points 4,5,6 should be in the same cluster
label4 = as.scalar(labels[4])
label5 = as.scalar(labels[5])
label6 = as.scalar(labels[6])
test4 = (label4 == label5) & (label5 == label6) & (label4 > 0)
print("Points 4,5,6 in same cluster: " + test4)

# clusters A and B should be different
test5 = (label1 != label4)
print("Two clusters are different: " + test5)

test_pass = edges_match & weights_match & hierarchy_match & sizes_match & labels_match & stabilities_match & test1 & test2 & test3 & test4 & test5

if(test_pass) {
  print("\nAll tests passed\n")
} else {
  print("\nTests failed\n")
}

0 commit comments

Comments
 (0)