Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 16 additions & 92 deletions pairing/pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,41 +38,30 @@ def generate_direct_correlation(trj, cutoff=1.0):
return direct_corr


def generate_indirect_connectivity(direct_corr):
    """
    Generate indirect connectivity matrix from a direct correlation matrix

    Every row is visited once. All columns that hold a 1 in the visited
    row are replaced by their element-wise maximum, so particles that
    share a neighbour end up with identical columns.

    Parameters
    ----------
    direct_corr : np.ndarray, dtype=np.int32
        Square direct correlation matrix with 0/1 entries. The input is
        copied and never modified.
        NOTE(review): the single pass is expected to give the complete
        connectivity only for a symmetric matrix with ones on the
        diagonal -- confirm that callers always provide that.

    Returns
    -------
    indirect : np.ndarray, dtype=np.int32
        Indirect connectivity matrix with the shape and dtype of
        `direct_corr`.

    Raises
    ------
    ValueError
        If `direct_corr` is not a square two-dimensional matrix.
    """

    c = np.array(direct_corr, copy=True)
    if c.ndim != 2 or c.shape[0] != c.shape[1]:
        raise ValueError('Direct correlation matrix must be square')

    for i in range(c.shape[0]):
        # Read the row from `c` on every pass: merges performed for
        # earlier rows rewrite whole columns, which changes later rows.
        ones = np.flatnonzero(c[i] == 1)
        # Zero or one marked column means there is nothing to merge; this
        # also avoids reducing over an empty selection, which raises.
        if ones.size < 2:
            continue
        merged = c[:, ones].max(axis=1)
        # Broadcast the merged column into every participating column.
        c[:, ones] = merged[:, np.newaxis]

    return c


# The function used to be private; keep the old name importable so
# existing callers of `_generate_indirect_connectivity` keep working.
_generate_indirect_connectivity = generate_indirect_connectivity


def generate_clusters(indirect):
Expand Down Expand Up @@ -114,68 +103,3 @@ def analyze_clusters(clusters):
avg = np.mean(cluster_sizes)
stdev = np.std(cluster_sizes)
return avg, stdev


def _find_intersection(a, b):
    """
    Combine two arrays by taking their element-wise maximum

    Parameters
    ----------
    a : array-like
        First array to combine
    b : array-like
        Second array to combine

    Returns
    -------
    intersection : array-like
        Element-wise maximum of a and b.
        NOTE(review): for 0/1 columns this marks every entry set in
        either input, which is a union rather than an intersection
        despite the function's name.
    """

    return np.maximum(a, b)


def _check_validity(c_I):
    """
    Check whether an indirect connectivity matrix is already converged

    The matrix is passed through '_generate_indirect_connectivity' once
    more; it counts as valid when that pass leaves every entry unchanged.

    Parameters
    ----------
    c_I : np.ndarray
        indirect connectivity matrix to test

    Returns
    -------
    Boolean 'True' when the regenerated matrix equals c_I element-wise,
    'False' otherwise
    """

    regenerated = _generate_indirect_connectivity(c_I)
    unchanged = regenerated == c_I
    return unchanged.all()


def new_generate_indirect(direct_corr):
    """
    Apply '_generate_indirect_connectivity' repeatedly until the result
    passes '_check_validity'

    Parameters
    ----------
    direct_corr : np.ndarray
        direct correlation matrix

    Returns
    -------
    new_indirect : np.ndarray
        indirect connectivity matrix that '_check_validity' accepts
    """

    candidate = _generate_indirect_connectivity(direct_corr)
    while True:
        # Stop as soon as another pass would no longer change the matrix.
        if _check_validity(candidate):
            return candidate
        candidate = _generate_indirect_connectivity(candidate)
41 changes: 4 additions & 37 deletions pairing/tests/test_pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,41 +35,21 @@ def test_sevick1988():
[0, 0, 0, 1, 0],
[1, 1, 1, 0, 1]], dtype=np.int32)

assert (c_I == pairing.pairing._generate_indirect_connectivity(c_D)).all()


def test_check_validity_pass():
    """A matrix given as already converged must be reported as valid."""
    converged = np.asarray(
        [
            [1, 1, 1, 0, 1],
            [1, 1, 1, 0, 1],
            [1, 1, 1, 0, 1],
            [0, 0, 0, 1, 0],
            [1, 1, 1, 0, 1],
        ],
        dtype=np.int32,
    )
    is_valid = pairing.pairing._check_validity(converged)

    assert is_valid


def test_check_validity_fail():
    """A partially merged matrix must be reported as not valid."""
    partially_merged = np.asarray(
        [
            [1, 0, 0, 0, 1],
            [0, 1, 1, 0, 0],
            [1, 1, 1, 0, 1],
            [0, 0, 0, 1, 0],
            [1, 0, 1, 0, 1],
        ],
        dtype=np.int32,
    )
    is_valid = pairing.pairing._check_validity(partially_merged)

    assert not is_valid
assert (c_I == pairing.generate_indirect_connectivity(c_D)).all()


def test_40_atoms():
    """The indirect matrix built from 'sevick1988.gro' keeps dtype int32."""
    trj = md.load(get_fn('sevick1988.gro'))
    direct = pairing.generate_direct_correlation(trj, cutoff=0.8)
    # Only the public entry point is called; the earlier call through the
    # private name produced a value that was overwritten right away.
    indirect = pairing.generate_indirect_connectivity(direct)

    assert indirect.dtype == np.int32


def test_indirect_matrix_reduction():
trj = md.load(get_fn('sevick1988.gro'))
direct = pairing.generate_direct_correlation(trj, cutoff=0.8)
indirect = pairing.pairing._generate_indirect_connectivity(direct)
indirect = pairing.generate_indirect_connectivity(direct)

c_R = np.asarray([[0, 1],
[0, 1],
Expand All @@ -83,20 +63,7 @@ def test_indirect_matrix_reduction():
def test_cluster_analysis():
    """Check the cluster-size mean and standard deviation for 'sevick1988.gro'."""
    trj = md.load(get_fn('sevick1988.gro'))
    direct = pairing.generate_direct_correlation(trj, cutoff=0.8)
    # Only the public entry point is called; the earlier call through the
    # private name produced a value that was overwritten right away.
    indirect = pairing.generate_indirect_connectivity(direct)
    reduction = pairing.generate_clusters(indirect)

    assert pairing.analyze_clusters(reduction) == (2.5, 1.5)


def test_new_indirect():
    """The iterated generator must reproduce the reference matrix."""
    trajectory = md.load(get_fn('sevick1988.gro'))
    direct = pairing.generate_direct_correlation(trajectory, cutoff=0.8)
    expected = np.asarray(
        [
            [1, 1, 1, 0, 1],
            [1, 1, 1, 0, 1],
            [1, 1, 1, 0, 1],
            [0, 0, 0, 1, 0],
            [1, 1, 1, 0, 1],
        ],
        dtype=np.int32,
    )
    computed = pairing.new_generate_indirect(direct)

    assert (computed == expected).all()