Skip to content

Commit 75823f1

Browse files
authored
Merge pull request #136 from H-Dempsey/master
Fixing an issue with agreement_weighted
2 parents 61fe288 + 8495090 commit 75823f1

2 files changed

Lines changed: 23 additions & 5 deletions

File tree

bct/algorithms/clustering.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def agreement_weighted(ci, wts):
7373
7474
Parameters
7575
----------
76-
ci : MxN np.ndarray
76+
ci : NxM np.ndarray
7777
set of M (possibly degenerate) partitions of N nodes
7878
wts : Mx1 np.ndarray
7979
relative weight of each partition
@@ -84,12 +84,12 @@ def agreement_weighted(ci, wts):
8484
weighted agreement matrix
8585
'''
8686
ci = np.array(ci)
87-
m, n = ci.shape
87+
n, m = ci.shape
8888
wts = np.array(wts) / np.sum(wts)
8989

9090
D = np.zeros((n, n))
9191
for i in range(m):
92-
d = dummyvar(ci[i, :].reshape(1, n))
92+
d = dummyvar(ci[:, i].reshape(n, 1))
9393
D += np.dot(d, d.T) * wts[i]
9494
return D
9595

test/clustering_test.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,26 @@ def test_transitivity_bd():
105105

106106

107107
def test_agreement_weighted():
108-
# this function is very hard to use or interpret results from
109-
pass
108+
# Test whether agreement gives the same results as
109+
# agreement_weighted when the weights are all the same
110+
ci = np.array([[1, 1, 2, 2, 3],
111+
[1, 2, 2, 3, 3],
112+
[1, 1, 2, 3, 3]]).T
113+
wts = np.ones(ci.shape[1])
114+
115+
D_agreement = bct.agreement(ci)
116+
D_weighted = bct.agreement_weighted(ci, wts)
117+
118+
# Undo the normalization and fill the diagonal with zeros
119+
# in D_weighted to get the same result as D_agreement
120+
D_weighted = D_weighted * ci.shape[1]
121+
np.fill_diagonal(D_weighted, 0)
122+
123+
print('agreement matrix:')
124+
print(D_agreement)
125+
print('weighted agreement matrix:')
126+
print(D_weighted)
127+
assert (D_agreement == D_weighted).all()
110128

111129
def test_agreement():
112130
# Case 1: nodes > partitions

0 commit comments

Comments
 (0)