Skip to content

Commit 067c36d

Browse files
authored
Saves physics-filtered frame to topology.pdb, instead of 0th sample (#186)
closes #184
1 parent 671b8ec commit 067c36d

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

src/bioemu/convert_chemgraph.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Licensed under the MIT License.
33
import logging
44
from pathlib import Path
5+
from tempfile import NamedTemporaryFile
56

67
import mdtraj
78
import numpy as np
@@ -460,13 +461,16 @@ def save_pdb_and_xtc(
460461
axis=1, keepdims=True
461462
) # Center every structure at the origin
462463

463-
# .pdb files contain coordinates in Angstrom
464-
_write_pdb(
465-
pos=pos_angstrom[0],
466-
node_orientations=node_orientations[0],
467-
sequence=sequence,
468-
filename=topology_path,
469-
)
464+
# save topology to tmpfile first, final topology might require filtering
465+
with NamedTemporaryFile(suffix=".pdb") as tmp:
466+
# .pdb files contain coordinates in Angstrom
467+
_write_pdb(
468+
pos=pos_angstrom[0],
469+
node_orientations=node_orientations[0],
470+
sequence=sequence,
471+
filename=tmp.name,
472+
)
473+
topology = mdtraj.load_topology(tmp.name)
470474

471475
xyz_angstrom = []
472476
for i in range(batch_size):
@@ -475,8 +479,6 @@ def save_pdb_and_xtc(
475479
)
476480
xyz_angstrom.append(atom_37.view(-1, 3)[atom_37_mask.flatten()].cpu().numpy())
477481

478-
topology = mdtraj.load_topology(topology_path)
479-
480482
traj = mdtraj.Trajectory(xyz=np.stack(xyz_angstrom) * 0.1, topology=topology)
481483

482484
if filter_samples:
@@ -495,6 +497,7 @@ def save_pdb_and_xtc(
495497
All unphysical samples have been saved with the suffix `_unphysical.xtc`.
496498
"""
497499
)
500+
498501
else:
499502
if len(filtered_traj) < num_samples_unfiltered:
500503
logger.info(
@@ -503,6 +506,8 @@ def save_pdb_and_xtc(
503506
)
504507
traj = filtered_traj
505508

509+
# topology is either from filtered frames or from original samples (if no filtering, or if all samples get filtered)
510+
traj[0].save_pdb(topology_path)
506511
traj.superpose(reference=traj, frame=0)
507512
traj.save_xtc(xtc_path)
508513

0 commit comments

Comments
 (0)