22# Licensed under the MIT License.
33import logging
44from pathlib import Path
5+ from tempfile import NamedTemporaryFile
56
67import mdtraj
78import numpy as np
@@ -460,13 +461,16 @@ def save_pdb_and_xtc(
460461 axis = 1 , keepdims = True
461462 ) # Center every structure at the origin
462463
463- # .pdb files contain coordinates in Angstrom
464- _write_pdb (
465- pos = pos_angstrom [0 ],
466- node_orientations = node_orientations [0 ],
467- sequence = sequence ,
468- filename = topology_path ,
469- )
464+ # save topology to tmpfile first, final topology might require filtering
465+ with NamedTemporaryFile (suffix = ".pdb" ) as tmp :
466+ # .pdb files contain coordinates in Angstrom
467+ _write_pdb (
468+ pos = pos_angstrom [0 ],
469+ node_orientations = node_orientations [0 ],
470+ sequence = sequence ,
471+ filename = tmp .name ,
472+ )
473+ topology = mdtraj .load_topology (tmp .name )
470474
471475 xyz_angstrom = []
472476 for i in range (batch_size ):
@@ -475,8 +479,6 @@ def save_pdb_and_xtc(
475479 )
476480 xyz_angstrom .append (atom_37 .view (- 1 , 3 )[atom_37_mask .flatten ()].cpu ().numpy ())
477481
478- topology = mdtraj .load_topology (topology_path )
479-
480482 traj = mdtraj .Trajectory (xyz = np .stack (xyz_angstrom ) * 0.1 , topology = topology )
481483
482484 if filter_samples :
@@ -495,6 +497,7 @@ def save_pdb_and_xtc(
495497 All unphysical samples have been saved with the suffix `_unphysical.xtc`.
496498 """
497499 )
500+
498501 else :
499502 if len (filtered_traj ) < num_samples_unfiltered :
500503 logger .info (
@@ -503,6 +506,8 @@ def save_pdb_and_xtc(
503506 )
504507 traj = filtered_traj
505508
509+ # topology is either from filtered frames or from original samples (if no filtering, or if all samples get filtered)
510+ traj [0 ].save_pdb (topology_path )
506511 traj .superpose (reference = traj , frame = 0 )
507512 traj .save_xtc (xtc_path )
508513
0 commit comments