Skip to content

Commit d6ec4ef

Browse files
author
Kane Shenton
committed
Added debye-waller core module so it can be used via CLI
Added CLI version. Kept inlined versions of the functions to maintain WASM compatibility. These can be removed if this module becomes available via PyPI at some point.
1 parent 4eccf8f commit d6ec4ef

5 files changed

Lines changed: 1667 additions & 413 deletions

File tree

larch_cli_wrapper/cli.py

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,5 +1487,256 @@ def create_config_example(
14871487
raise typer.Exit(1) from e
14881488

14891489

1490+
@app.command("debye-waller")
def debye_waller(
    trajectory: Path = typer.Argument(
        ..., help="Path to trajectory file (any ASE-readable format)"
    ),
    prefix: str = typer.Option(
        "",
        "--prefix",
        "-p",
        help="Output file prefix (defaults to trajectory file stem)",
    ),
    skip_frames: int = typer.Option(
        0, "--skip-frames", "-s", help="Number of frames to skip at the start"
    ),
    no_align: bool = typer.Option(
        False, "--no-align", help="Skip Kabsch alignment (use raw unwrapped positions)"
    ),
    site: str = typer.Option(
        "",
        "--site",
        help=(
            "Absorber site for MSRD analysis. "
            "Formats: 'K' (all K atoms), 'K.1' (first K), 'K.1-3' (first three K), "
            "'11' (11th atom), '11-20' (atoms 11-20, 1-based)."
        ),
    ),
    cutoff: float = typer.Option(3.5, "--cutoff", "-r", help="Neighbor cutoff in Å"),
    cutoff_3body: float = typer.Option(
        0.0,
        "--cutoff-3body",
        help=(
            "Maximum absorber-to-neighbor distance (Å) used to select legs for "
            "3-body paths. This is a neighbor distance cutoff, not an Reff cutoff — "
            "it limits each individual leg length, not the total path length. "
            "Set to 0 to skip 3-body paths entirely."
        ),
    ),
    tol_dist: float = typer.Option(
        0.1, "--tol-dist", help="Distance grouping tolerance in Å"
    ),
    tol_angle: float = typer.Option(
        5.0, "--tol-angle", help="Angle grouping tolerance in degrees"
    ),
    include_hydrogen: bool = typer.Option(
        False,
        "--include-hydrogen",
        help=(
            "Include hydrogen atoms in the neighbor search "
            "for MSRD paths (excluded by default)."
        ),
    ),
) -> None:
    """Compute Debye-Waller factors and MSRD from an MD trajectory.

    Always writes:

    \b
      <prefix>_with_adp.cif            CIF with anisotropic displacement parameters
      <prefix>_bfactors.png            B-factor scatter plot per atom / element

    Additionally, when --site is given and MSRD paths are found:

    \b
      <prefix>_msrd_<site>.png         σ² vs Reff plot
      <prefix>_msrd_paths_<site>.csv   MSRD path table
    """
    # NOTE: the docstring above is the user-facing ``--help`` text (the lone
    # ``\b`` lines stop the file lists from being re-wrapped), so maintainer
    # notes live in comments instead.
    #
    # Pipeline: load trajectory -> unwrap / align -> ADP + B-factor summary ->
    # CIF + B-factor plot -> (optional, with --site) MSRD tables, CSV and plot.
    # Failures that the user can fix are reported on the console and end the
    # command with exit code 1.

    # Heavy dependencies are imported lazily so that importing the CLI module
    # (and running unrelated commands) stays cheap.
    import numpy as _np
    from rich.table import Table

    from .debye_waller_core import (
        calculate_grouped_msrd,
        compute_adp_results,
        load_trajectory,
        msrd_to_dataframe,
        parse_site_specification,
        plot_bfactors,
        plot_sigma2_vs_reff,
        process_trajectory,
        save_cif_with_adp,
    )

    align = not no_align

    # ── Resolve output prefix ─────────────────────────────────────────────
    out_prefix = prefix if prefix else trajectory.stem

    # ── Load ──────────────────────────────────────────────────────────────
    console.print(
        f"[bold]Loading trajectory:[/bold] [cyan]{trajectory}[/cyan]"
        + (f" (skipping {skip_frames} frames)" if skip_frames else "")
    )
    try:
        structures = load_trajectory(trajectory, skip_frames=skip_frames)
    except (OSError, ValueError) as exc:
        console.print(f"[red]Error loading trajectory: {exc}[/red]")
        raise typer.Exit(1) from exc

    n_frames = len(structures)
    if n_frames == 0:
        # Everything below indexes structures[0]; an empty trajectory (for
        # example --skip-frames >= number of frames) would otherwise surface
        # as a bare IndexError traceback.
        console.print(
            "[red]Error loading trajectory: no frames available"
            + (f" after skipping {skip_frames} frames" if skip_frames else "")
            + "[/red]"
        )
        raise typer.Exit(1)

    first_frame = structures[0]
    symbols = first_frame.get_chemical_symbols()
    n_atoms = len(first_frame)
    elements_str = ", ".join(sorted(set(symbols)))
    console.print(
        f"  [green]✓[/green] {n_frames} frames · {n_atoms} atoms · "
        f"elements: {elements_str}"
    )

    # ── Unwrap & align ────────────────────────────────────────────────────
    with console.status(
        "[bold]Unwrapping PBC" + (" and Kabsch-aligning" if align else "") + "…"
    ):
        unwrapped = process_trajectory(structures, align=align)
    console.print(
        f"  [green]✓[/green] Positions processed "
        f"({'aligned' if align else 'no alignment'}) · "
        f"shape: {unwrapped.shape}"
    )

    # ── ADP / B-factors ───────────────────────────────────────────────────
    with console.status("[bold]Computing ADP tensors…"):
        results = compute_adp_results(structures, unwrapped)

    # asarray is a no-op for ndarrays and makes the boolean-mask indexing
    # below safe if the core module ever returns a plain sequence.
    b_factors = _np.asarray(results["b_factors"])
    atom_names = results["atom_names"]
    unique_elements = sorted(set(atom_names))

    # Summary table: one row of B-factor statistics per element.
    summary_table = Table(title="Debye-Waller factors by element", show_header=True)
    summary_table.add_column("Element")
    summary_table.add_column("N atoms", justify="right")
    summary_table.add_column("Mean B (Ų)", justify="right")
    summary_table.add_column("Std B (Ų)", justify="right")
    summary_table.add_column("Min B (Ų)", justify="right")
    summary_table.add_column("Max B (Ų)", justify="right")
    for el in unique_elements:
        mask = _np.array([n == el for n in atom_names])
        bv = b_factors[mask]
        summary_table.add_row(
            el,
            str(int(mask.sum())),
            f"{_np.mean(bv):.4f}",
            f"{_np.std(bv):.4f}",
            f"{_np.min(bv):.4f}",
            f"{_np.max(bv):.4f}",
        )
    console.print(summary_table)
    console.print(f"  Overall mean B-factor: [bold]{_np.mean(b_factors):.4f}[/bold] Ų")

    # ── Save CIF ──────────────────────────────────────────────────────────
    cif_path = Path(f"{out_prefix}_with_adp.cif")
    cif_path.write_text(save_cif_with_adp(results))
    console.print(f"  [green]✓[/green] CIF saved → [cyan]{cif_path}[/cyan]")

    # ── B-factor plot ─────────────────────────────────────────────────────
    bfactor_plot_path = Path(f"{out_prefix}_bfactors.png")
    plot_bfactors(results, output_path=bfactor_plot_path)
    console.print(
        f"  [green]✓[/green] B-factor plot → [cyan]{bfactor_plot_path}[/cyan]"
    )

    # ── MSRD ──────────────────────────────────────────────────────────────
    if not site:
        console.print("[dim]No --site specified; skipping MSRD analysis.[/dim]")
        return

    try:
        central_indices = parse_site_specification(site, symbols)
    except ValueError as exc:
        console.print(f"[red]Site specification error: {exc}[/red]")
        raise typer.Exit(1) from exc

    n_absorbers = len(central_indices)
    if n_absorbers == 0:
        # n_absorbers is the divisor for the degeneracy columns below; guard
        # against an empty selection in case the parser returns one instead of
        # raising (its behaviour is not visible from this module).
        console.print(
            f"[red]Site specification error: no atoms match '{site}'[/red]"
        )
        raise typer.Exit(1)

    console.print(
        f"\n[bold]MSRD analysis:[/bold] site=[cyan]{site}[/cyan] "
        f"({n_absorbers} absorber(s)) · cutoff={cutoff} Å"
        + (f" · 3-body cutoff={cutoff_3body} Å" if cutoff_3body > 0 else "")
    )

    with console.status("[bold]Computing MSRD paths…"):
        res_2b, res_3b = calculate_grouped_msrd(
            structures,
            unwrapped,
            central_indices,
            site,
            cutoff=cutoff,
            tol_dist=tol_dist,
            tol_angle=tol_angle,
            # Non-positive values are normalised to 0, which the help text
            # documents as "skip 3-body paths entirely".
            cutoff_3body=cutoff_3body if cutoff_3body > 0 else 0,
            exclude_hydrogen=not include_hydrogen,
        )

    # ── Path summary tables ───────────────────────────────────────────────
    if res_2b:
        t2 = Table(title="2-Body MSRD Paths", show_header=True)
        for col, just in [
            ("Path type", "left"),
            ("Reff (Å)", "right"),
            ("σ² (Ų)", "right"),
            ("Count", "right"),
            ("Degeneracy", "right"),
        ]:
            t2.add_column(col, justify=just)
        for r in res_2b:
            t2.add_row(
                r["type"],
                f"{r['reff']:.4f}",
                f"{r['sigma2']:.6f}",
                str(r["count"]),
                f"{r['count'] / n_absorbers:.1f}",
            )
        console.print(t2)

    if res_3b:
        t3 = Table(title="3-Body MSRD Paths", show_header=True)
        for col, just in [
            ("Path type", "left"),
            ("Reff (Å)", "right"),
            ("σ² (Ų)", "right"),
            ("Angle (°)", "right"),
            ("Count", "right"),
            ("Degeneracy", "right"),
        ]:
            t3.add_column(col, justify=just)
        for r in res_3b:
            t3.add_row(
                r["type"],
                f"{r['reff']:.4f}",
                f"{r['sigma2']:.6f}",
                f"{r['angle']:.1f}",
                str(r["count"]),
                # NOTE(review): 3-body groups are weighted by 2 here —
                # presumably one per traversal direction; confirm against
                # debye_waller_core.
                f"{2 * r['count'] / n_absorbers:.1f}",
            )
        console.print(t3)

    if not res_2b and not res_3b:
        console.print(
            "[yellow]No MSRD paths found. Try increasing the cutoff.[/yellow]"
        )
        return

    console.print(
        f"\n  [green]✓[/green] Found [bold]{len(res_2b)}[/bold] two-body and "
        f"[bold]{len(res_3b)}[/bold] three-body paths."
    )

    # ── CSV ───────────────────────────────────────────────────────────────
    msrd_df = msrd_to_dataframe(res_2b, res_3b, n_absorbers=n_absorbers)
    # The site spec becomes part of the output file names; spaces are replaced
    # so the names stay shell-friendly.
    site_label = site.replace(" ", "_")
    csv_path = Path(f"{out_prefix}_msrd_paths_{site_label}.csv")
    msrd_df.to_csv(csv_path, index=False)
    console.print(f"  [green]✓[/green] MSRD CSV → [cyan]{csv_path}[/cyan]")

    # ── σ² vs Reff plot ───────────────────────────────────────────────────
    msrd_plot_path = Path(f"{out_prefix}_msrd_{site_label}.png")
    plot_sigma2_vs_reff(res_2b, res_3b, output_path=msrd_plot_path)
    console.print(f"  [green]✓[/green] σ² vs Reff plot → [cyan]{msrd_plot_path}[/cyan]")
1739+
1740+
14901741
# Run the Typer application when this module is executed as a script.
if __name__ == "__main__":
    app()

0 commit comments

Comments
 (0)