|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Transfer a per-residue numeric column from a CSV into the B-factor column of a PDB. |
| 4 | +
|
| 5 | +- Prompts for CSV path, PDB path, and the CSV column to transfer. |
| 6 | +- CSV must identify residues by columns (case-insensitive): |
| 7 | + * Chain -> one of: ["chain", "Chain", "CHAIN"] (optional; blank if absent) |
| 8 | + * ResSeq -> one of: ["resseq","resseqid","resnum","res_seq","residue","Residue","resid","ResID","ResSeq"] (required) |
| 9 | + * ICode -> one of: ["icode","inscode","insertion","iCode","ICode"] (optional; blank if absent) |
| 10 | +- Column chosen must be numeric; duplicates per residue key are not allowed. |
| 11 | +- Every residue present in the PDB must be present exactly once in the CSV mapping. |
| 12 | +- Non-ATOM/HETATM (e.g., HEADER, TER, REMARK) are preserved as-is. |
| 13 | +- Output: <pdb_stem>_bfactor_from_csv_<col>.pdb |
| 14 | +""" |
| 15 | + |
| 16 | +import os |
| 17 | +import re |
| 18 | +import sys |
| 19 | +import pandas as pd |
| 20 | +from typing import Tuple, Optional |
| 21 | + |
| 22 | +# ---------------------------- Helpers ---------------------------- # |
| 23 | + |
| 24 | +def die(msg: str, code: int = 1): |
| 25 | + print(f"ERROR: {msg}", file=sys.stderr) |
| 26 | + sys.exit(code) |
| 27 | + |
| 28 | +def ask_path(prompt: str, default: Optional[str] = None) -> str: |
| 29 | + if default: |
| 30 | + s = input(f"{prompt} [{default}]: ").strip() |
| 31 | + return s or default |
| 32 | + return input(f"{prompt}: ").strip() |
| 33 | + |
| 34 | +def is_atom_record(line: str) -> bool: |
| 35 | + rec = line[:6] |
| 36 | + return rec == "ATOM " or rec == "HETATM" |
| 37 | + |
| 38 | +def residue_key_from_pdb_line(line: str) -> Tuple[str, str, str]: |
| 39 | + """Return (chain, resseq, icode) from ATOM/HETATM line.""" |
| 40 | + chain = line[21].strip() |
| 41 | + resseq = line[22:26].strip() # keep as string (handles e.g., '401') |
| 42 | + icode = line[26].strip() # insertion code |
| 43 | + return (chain, resseq, icode) |
| 44 | + |
| 45 | +def format_bfactor(line: str, b: float) -> str: |
| 46 | + """Overwrite B-factor (cols 61-66, 1-based) with {:6.3f}, preserve occupancy.""" |
| 47 | + nl = "\n" if line.endswith("\n") else "" |
| 48 | + core = line[:-1] if nl else line |
| 49 | + if len(core) < 66: |
| 50 | + core = core + " " * (66 - len(core)) |
| 51 | + return core[:60] + f"{b:6.3f}" + core[66:] + nl |
| 52 | + |
| 53 | +def find_column(df: pd.DataFrame, candidates) -> Optional[str]: |
| 54 | + for c in candidates: |
| 55 | + if c in df.columns: |
| 56 | + return c |
| 57 | + # case-insensitive |
| 58 | + lower_map = {c.lower(): c for c in df.columns} |
| 59 | + for c in candidates: |
| 60 | + if c.lower() in lower_map: |
| 61 | + return lower_map[c.lower()] |
| 62 | + return None |
| 63 | + |
| 64 | +def sanitize_for_filename(s: str) -> str: |
| 65 | + return re.sub(r"[^A-Za-z0-9._-]+", "_", s).strip("_") |
| 66 | + |
| 67 | +# ---------------------------- Main ---------------------------- # |
| 68 | + |
| 69 | +def main(): |
| 70 | + # Prompt for inputs |
| 71 | + csv_path = ask_path("CSV file", "values.csv") |
| 72 | + pdb_path = ask_path("PDB file", "structure.pdb") |
| 73 | + |
| 74 | + if not os.path.isfile(csv_path): |
| 75 | + die(f"CSV not found: {csv_path}") |
| 76 | + if not os.path.isfile(pdb_path): |
| 77 | + die(f"PDB not found: {pdb_path}") |
| 78 | + |
| 79 | + # Read CSV |
| 80 | + try: |
| 81 | + df = pd.read_csv(csv_path) |
| 82 | + except Exception as e: |
| 83 | + die(f"Failed to read CSV: {e}") |
| 84 | + |
| 85 | + # Identify residue key columns |
| 86 | + chain_col = find_column(df, ["chain", "Chain", "CHAIN"]) |
| 87 | + resseq_col = find_column(df, ["resseq","resseqid","resnum","res_seq","residue","Residue","resid","ResID","ResSeq"]) |
| 88 | + icode_col = find_column(df, ["icode","inscode","insertion","iCode","ICode"]) |
| 89 | + |
| 90 | + if resseq_col is None: |
| 91 | + die("Could not find a residue sequence column in CSV. " |
| 92 | + "Expected one of: ResSeq, resseq, resnum, residue, resid, etc.") |
| 93 | + |
| 94 | + # Normalize key columns |
| 95 | + df_keys = pd.DataFrame() |
| 96 | + df_keys["Chain"] = df[chain_col].fillna("").astype(str) if chain_col else "" |
| 97 | + df_keys["ResSeq"] = df[resseq_col].astype(str).str.strip() |
| 98 | + df_keys["ICode"] = df[icode_col].fillna("").astype(str).str.strip() if icode_col else "" |
| 99 | + |
| 100 | + # Choose value column (prompt) |
| 101 | + numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])] |
| 102 | + if not numeric_cols: |
| 103 | + die("CSV has no numeric columns to transfer.") |
| 104 | + print("\nNumeric columns available to transfer:") |
| 105 | + for i, c in enumerate(numeric_cols, 1): |
| 106 | + print(f" {i}. {c}") |
| 107 | + col_name = input("Enter the exact column name to transfer (or number): ").strip() |
| 108 | + if col_name.isdigit(): |
| 109 | + idx = int(col_name) - 1 |
| 110 | + if idx < 0 or idx >= len(numeric_cols): |
| 111 | + die("Invalid column selection.") |
| 112 | + value_col = numeric_cols[idx] |
| 113 | + else: |
| 114 | + if col_name not in df.columns: |
| 115 | + die(f"Column '{col_name}' not found in CSV.") |
| 116 | + if not pd.api.types.is_numeric_dtype(df[col_name]): |
| 117 | + die(f"Column '{col_name}' is not numeric.") |
| 118 | + value_col = col_name |
| 119 | + |
| 120 | + # Build residue->value map; require uniqueness & non-NA |
| 121 | + df_map = pd.concat([df_keys, df[[value_col]]], axis=1) |
| 122 | + if df_map[value_col].isna().any(): |
| 123 | + bad = df_map[df_map[value_col].isna()][["Chain","ResSeq","ICode"]].head(10).to_dict("records") |
| 124 | + die(f"Selected column has NA values; first 10 offending residue keys: {bad}") |
| 125 | + |
| 126 | + # Enforce unique rows per residue key |
| 127 | + dup_mask = df_map.duplicated(subset=["Chain","ResSeq","ICode"], keep=False) |
| 128 | + if dup_mask.any(): |
| 129 | + dups = df_map.loc[dup_mask, ["Chain","ResSeq","ICode"]].value_counts().head(10) |
| 130 | + die("CSV contains duplicate rows for the same residue key. " |
| 131 | + f"First 10 duplicates (key -> count):\n{dups}") |
| 132 | + |
| 133 | + # Convert to dict |
| 134 | + value_by_key = { |
| 135 | + (row.Chain, row.ResSeq, row.ICode): float(row[value_col]) |
| 136 | + for row in df_map.itertuples(index=False) |
| 137 | + } |
| 138 | + |
| 139 | + # Read PDB & collect residue keys |
| 140 | + with open(pdb_path, "r") as f: |
| 141 | + pdb_lines = f.readlines() |
| 142 | + |
| 143 | + atom_keys = [] |
| 144 | + atom_idx = [] |
| 145 | + for i, ln in enumerate(pdb_lines): |
| 146 | + if is_atom_record(ln): |
| 147 | + atom_idx.append(i) |
| 148 | + atom_keys.append(residue_key_from_pdb_line(ln)) |
| 149 | + |
| 150 | + if not atom_idx: |
| 151 | + die("No ATOM/HETATM records found in PDB.") |
| 152 | + |
| 153 | + # Compute set of residues present in the PDB |
| 154 | + pdb_residues = sorted(set(atom_keys)) |
| 155 | + |
| 156 | + # Check coverage: every PDB residue must be in CSV mapping |
| 157 | + missing = [k for k in pdb_residues if k not in value_by_key] |
| 158 | + extra = [k for k in value_by_key.keys() if k not in set(pdb_residues)] |
| 159 | + |
| 160 | + if extra: |
| 161 | + print(f"NOTE: {len(extra)} CSV residue keys not present in PDB (they will be ignored). " |
| 162 | + f"Example: {extra[:5]}") |
| 163 | + if missing: |
| 164 | + preview = missing[:10] |
| 165 | + die(f"CSV does not cover all PDB residues. Missing {len(missing)} residues. " |
| 166 | + f"First 10 missing keys: {preview}") |
| 167 | + |
| 168 | + # Write output with updated B-factors |
| 169 | + out_lines = list(pdb_lines) |
| 170 | + for i, rk in zip(atom_idx, atom_keys): |
| 171 | + bval = value_by_key[rk] |
| 172 | + out_lines[i] = format_bfactor(out_lines[i], bval) |
| 173 | + |
| 174 | + stem, _ = os.path.splitext(pdb_path) |
| 175 | + out_col = sanitize_for_filename(value_col) |
| 176 | + out_path = f"{stem}_bfactor_from_csv_{out_col}.pdb" |
| 177 | + with open(out_path, "w") as f: |
| 178 | + f.writelines(out_lines) |
| 179 | + |
| 180 | + # Summary |
| 181 | + print("\nSuccess.") |
| 182 | + print(f" Wrote: {out_path}") |
| 183 | + print(f" ATOM/HETATM updated: {len(atom_idx)}") |
| 184 | + print(f" Residues in PDB: {len(pdb_residues)}") |
| 185 | + print(f" CSV residues used: {len(value_by_key)} (extras ignored: {len(extra)})") |
| 186 | + print(f" Column transferred: {value_col}") |
| 187 | + |
| 188 | +if __name__ == "__main__": |
| 189 | + main() |
0 commit comments