# Source code for borgjaxpm._simulation_cola

# SPDX-FileCopyrightText: 2025 CNRS
# SPDX-FileContributor: Guilhem Lavaux
# SPDX-FileContributor: Svyatoslav Trusov
#
# SPDX-License-Identifier: CECILL-B

"""
COLA (COmoving Lagrangian Acceleration) integration for particle mesh simulations.

COLA works in a frame co-moving with LPT trajectories. Instead of evolving the
full phase-space trajectory, it evolves only the *residual* displacement and
velocity with respect to LPT predictions. This allows far fewer time steps to
achieve the same large-scale accuracy as a standard PM simulation.

References:
    Tassev, Zaldarriaga & Eisenstein (2013), JCAP 06 036
    Howlett et al. (2015), Astronomy & Computing 12, 109

Equations of motion
-------------------
The residual displacement  Ψ_res = x - x_LPT  satisfies (in conformal time τ):

    Ψ_res'' = -∇Φ_PM(x_LPT + Ψ_res) + F_LPT_analytic

where primes denote d/dτ and F_LPT_analytic is the *exact* second conformal-time
derivative of the LPT reference trajectory, computed analytically from growth
factors without any PM solve.

Analytic LPT force
------------------
The nLPT displacement is  x_n(τ) = D_n(τ) Ψ_n(q), so its conformal acceleration is:

    x_n'' = D_n'' Ψ_n

Writing the growth rate as f_n = d ln D_n / d ln a and the conformal Hubble
rate as ℋ = aH, so that D_n' = ℋ f_n D_n, the conformal second derivative of
D_n follows from the chain rule:

    D_n'' = ℋ d(ℋ f_n D_n)/d ln a
          = a²H² D_n [ f_n² + f_n (1 + d ln H/d ln a) + d f_n/d ln a ]

In practice we absorb a²H² into the kick coefficient (which integrates 1/(aH)
over conformal time) and define the dimensionless scalar:

    nD_n = D_n [ f_n(f_n - 1) - (3/2) Ω_m(a)/E²(a) · (1 - δ_{n,1}) ... ]

A cleaner derivation uses the explicit LPT equations of motion, with
ℋ = aH the conformal Hubble rate and Ω_m the present-day matter density:

    D_1'' + ℋ D_1' = (3/2) H₀² Ω_m a⁻¹ D_1     (1LPT sourced by linear δ)
    D_2'' + ℋ D_2' = (3/2) H₀² Ω_m a⁻¹ D_2
                     - (3/2) H₀² Ω_m a⁻¹ D_1²   (2LPT sourced by D_1²)

Rearranging gives the analytic acceleration *in the same units as the PM force*
(i.e. after the kick coefficient integral has been applied):

    F_LPT_analytic = nD1 · Ψ₁_unit + nD2 · Ψ₂_unit

where the dimensionless nD scalars are precomputed once per kick step
(no per-particle work beyond a scalar multiply):

    nD1 = D_1(a) · [f_1² + f_1 · (d ln H / d ln a)
                    - (3/2) Ω_m(a) / E²(a)]
    nD2 = D_2(a) · [f_2² + f_2 · (d ln H / d ln a)
                    - (3/2) Ω_m(a) / E²(a)]
          + D_1²(a) · (3/2) Ω_m(a) / E²(a)

These are stored in COLACoef.nD1_kick and COLACoef.nD2_kick (one value per
leapfrog step), so the inner loop reduces to:

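    v_res += kick_coeff * (F_PM - nD1 * Ψ₁_unit - nD2 * Ψ₂_unit) * 1.5 * Ω_m * multiplier

where ``multiplier`` is the ratio of the particle-grid size to the force-mesh
size.  The LPT correction therefore costs O(1) per particle (O(N) in total)
and needs no FFTs.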
"""

from collections import namedtuple

# from enum import Enum
from functools import partial

import aquila_borg as borg
import jax
import jax.numpy as jnp
import jax_cosmo as jc

# ── Re-exported from the standard leapfrog module so callers need only import
#    this file.
from ._simulation_leapfrog import (  # noqa: F401
    LeapfrogCoef,
    LeapFrogStep,
    _make_DKD,
    _precompute_leapfrog_coefs,
    supersample_ic_field,
)
from ._src.constants import FLOAT_DTYPE, LOG_STEP_DEFAULT
from ._src.pm import lpt, pm_forces

# ---------------------------------------------------------------------------
# Named tuple carrying the extra COLA coefficients alongside the standard ones
# ---------------------------------------------------------------------------

# Per-sub-step coefficient bundle consumed by _run_cola_simulation_internal.
# Every field holds one entry per leapfrog sub-step (same length as the
# programme passed to _precompute_cola_coefs).
#
#   drift_coeff, kick_coeff
#       Standard leapfrog integrals, identical in meaning to LeapfrogCoef.
#   D1_start, D1_end, D2_start, D2_end
#       1LPT / 2LPT growth factors used to rescale the unit (growth-normalised)
#       LPT displacement fields at the beginning / end of the sub-step.
#   nD1_kick, nD2_kick
#       Analytic LPT force scalars produced by _nD_coefs.  The integrator
#       subtracts  nD1 * psi1_unit + nD2 * psi2_unit  from the PM force, so the
#       correction is an O(N) scalar multiply.  Zero on DRIFT-only sub-steps.
COLACoef = namedtuple(
    "COLACoef",
    (
        "drift_coeff",
        "kick_coeff",
        "D1_start",
        "D1_end",
        "D2_start",
        "D2_end",
        "nD1_kick",
        "nD2_kick",
    ),
)


# ---------------------------------------------------------------------------
# Growth-factor helpers (jax-cosmo already exposes Omega_m_a, but we need D1)
# ---------------------------------------------------------------------------


def _growth_factor_ODE(cosmo, a):
    """
    Linear growth factor D1(a), normalised so that D1(a=1) = 1.

    Evaluates the standard integral solution for a flat ΛCDM background,

        D1(a) ∝ E(a) ∫ da' / (a' E(a'))^3 ,

    with E(a)^2 taken from ``jc.background.Esqr``.  Both integrals start at
    a' = 1e-4 instead of 0, so the (tiny) contribution from earlier epochs is
    neglected.  The result is the ratio of the integral up to *a* and up to
    a = 1, rescaled by E(a) / E(1).
    """
    from quadax import quadgk

    def _inverse_cubed(a_prime):
        E_prime = jnp.sqrt(jc.background.Esqr(cosmo, a_prime))
        return jnp.array(1.0 / (a_prime * E_prime) ** 3, dtype=FLOAT_DTYPE)

    lower_limit = 1e-4
    integral_to_one, _ = quadgk(_inverse_cubed, [lower_limit, 1.0])
    integral_to_a, _ = quadgk(_inverse_cubed, [lower_limit, a])

    E_at_a = jnp.sqrt(jc.background.Esqr(cosmo, a))
    E_at_one = jnp.sqrt(jc.background.Esqr(cosmo, 1.0))
    return (E_at_a / E_at_one) * integral_to_a / integral_to_one


def _growth_factor_2LPT(cosmo, a):
    """
    Second-order growth factor D2(a) in the Einstein–de Sitter approximation:

        D2(a) = -3/7 * D1(a)^2

    where D1 comes from ``_growth_factor_ODE``.  For ΛCDM the exact D2 deviates
    only weakly from this form, so the cheap closed expression is used instead
    of integrating the second-order growth equation.
    """
    linear_growth = _growth_factor_ODE(cosmo, a)
    return -3.0 / 7.0 * linear_growth**2


def _log_growth_rate(cosmo, a, D_fn, eps=1e-4):
    """
    Logarithmic growth rate  f = d ln D / d ln a  for growth factor function
    *D_fn(cosmo, a)*, computed by second-order finite differences in ln a.
    """
    Dp = D_fn(cosmo, a * jnp.exp(eps))
    Dm = D_fn(cosmo, a * jnp.exp(-eps))
    return (jnp.log(jnp.abs(Dp)) - jnp.log(jnp.abs(Dm))) / (2 * eps)


def _dlnH_dlna(cosmo, a, eps=1e-4):
    """
    Logarithmic slope d ln H / d ln a of the Hubble rate at scale factor *a*.

    Computed as a centred finite difference (half-width *eps* in ln a) of
    ln E = 0.5 * ln E^2, with E^2 from ``jc.background.Esqr``.  Because
    E = H / H0, this equals d ln H / d ln a, i.e. -(1 + q) for deceleration
    parameter q.  The numerical form avoids hard-coding a background model.
    """

    def _ln_E(shift):
        return 0.5 * jnp.log(jc.background.Esqr(cosmo, a * jnp.exp(shift)))

    return (_ln_E(eps) - _ln_E(-eps)) / (2 * eps)


def _nD_coefs(cosmo, a, D1, D2):
    """
    Analytic LPT force scalars (nD1, nD2) at scale factor *a*.

    The COLA kick subtracts ``nD1 * psi1_unit + nD2 * psi2_unit`` from the PM
    force, where psi_{1,2}_unit are the growth-normalised 1LPT / 2LPT
    displacement fields stored once per run.  Inside the time loop the LPT
    correction is therefore two scalar multiplies, with no FFT or mesh work.

    This function evaluates

        nD1 = D1 * [ f1² + f1 * dlnH/dlna - 3/2 * Ωm,0 a⁻³ / E²(a) ]
        nD2 = D2 * [ f2² + f2 * dlnH/dlna - 3/2 * Ωm,0 a⁻³ / E²(a) ]
              + D1² * 3/2 * Ωm,0 a⁻³ / E²(a)

    where f1, f2 are finite-difference log growth rates of
    ``_growth_factor_ODE`` and ``_growth_factor_2LPT`` and dlnH/dlna comes
    from ``_dlnH_dlna``.

    NOTE(review): these brackets match neither the kinematic identity
    D_n'' = a²H² D_n [f_n² + f_n (1 + dlnH/dlna) + df_n/dlna] nor the
    growth-equation source (3/2) a²H² Ωm(a) D_1.  In Einstein–de Sitter
    (f1 = 1, dlnH/dlna = -3/2, Ωm(a) = 1) the nD1 bracket is -2, whereas
    D1''/(a²H² D1) = 1/2 and the source term gives +3/2.  Confirm the
    intended normalisation against the kick coefficient definition in
    ``_precompute_leapfrog_coefs`` before relying on these values.

    Parameters
    ----------
    cosmo : jax_cosmo cosmology
        Background cosmology (provides ``Omega_m`` and ``Esqr``).
    a : float
        Scale factor at which the coefficients are evaluated.
    D1, D2 : float
        1LPT and 2LPT growth factors at *a*.

    Returns
    -------
    (nD1, nD2) : tuple of 0-d arrays of dtype FLOAT_DTYPE
    """
    growth_rate_1 = _log_growth_rate(cosmo, a, _growth_factor_ODE)
    growth_rate_2 = _log_growth_rate(cosmo, a, _growth_factor_2LPT)
    hubble_slope = _dlnH_dlna(cosmo, a)

    # Ω_m(a) / E²(a) written as Ω_m,0 · a⁻³ / E²(a)
    matter_term = cosmo.Omega_m * a**-3 / jc.background.Esqr(cosmo, a)

    def _bracket(rate):
        return rate**2 + rate * hubble_slope - 1.5 * matter_term

    nD1 = D1 * _bracket(growth_rate_1)
    nD2 = D2 * _bracket(growth_rate_2) + D1**2 * 1.5 * matter_term

    return (
        jnp.array(nD1, dtype=FLOAT_DTYPE),
        jnp.array(nD2, dtype=FLOAT_DTYPE),
    )


# ---------------------------------------------------------------------------
# Pre-compute COLA coefficients
# ---------------------------------------------------------------------------


def _precompute_cola_coefs(
    cosmo,
    a_start,
    a_final,
    nsteps,
    LF_programme,
    log_step=LOG_STEP_DEFAULT,
):
    """
    Extend the standard leapfrog coefficients with COLA-specific growth-factor
    values and analytic LPT force scalars (nD1, nD2), one entry per sub-step.

    KICK and DRIFT sub-steps advance two independent time cursors (velocity
    epoch and position epoch).  Particle positions only change on DRIFT
    sub-steps, so:

    * for a DRIFT sub-step, ``D*_start`` / ``D*_end`` are the growth factors at
      the start / end of the drift interval;
    * for a KICK sub-step, ``D*_start`` and ``D*_end`` are both evaluated at
      the current *position* epoch (the DRIFT cursor).  That is where the
      integrator reconstructs the LPT displacement before computing the PM
      force.  Using the start of the kick interval would place the LPT
      reference at the wrong epoch in a staggered (DKD / KDK) scheme.

    On KICK sub-steps the nD coefficients are evaluated at that same position
    epoch, so the subtracted LPT force refers to the same epoch as the PM
    force it corrects.  On DRIFT sub-steps they are zero.

    Everything here is a small number of scalar evaluations performed once
    during preparation.

    Parameters
    ----------
    cosmo : jax_cosmo cosmology
        Background cosmology.
    a_start, a_final : float
        Initial and final scale factors.
    nsteps : int
        Number of full leapfrog steps.
    LF_programme : iterable of (LeapFrogStep, float)
        Sub-step sequence.  The float is the sub-step length in units of one
        full step.
    log_step : bool
        If True, steps are uniform in ln a; otherwise uniform in a.

    Returns
    -------
    COLACoef
        Leapfrog integrals plus per-sub-step growth factors and nD scalars.
    """
    lf_coef = _precompute_leapfrog_coefs(
        cosmo, a_start, a_final, nsteps, LF_programme, log_step=log_step
    )

    if log_step:
        dx = jnp.log(a_final / a_start) / nsteps
        get_a = jnp.exp
        get_x = jnp.log
    else:
        dx = (a_final - a_start) / nsteps
        get_a = lambda x: x  # noqa: E731
        get_x = lambda x: x  # noqa: E731

    D1_start_list, D1_end_list = [], []
    D2_start_list, D2_end_list = [], []
    nD1_list, nD2_list = [], []

    # Independent time cursors for the velocity (KICK) and position (DRIFT)
    # timelines, expressed in the stepping variable x (ln a or a).
    prev_x = {
        LeapFrogStep.KICK: get_x(a_start),
        LeapFrogStep.DRIFT: get_x(a_start),
    }

    zero = jnp.array(0.0, dtype=FLOAT_DTYPE)

    for op, dt in LF_programme:
        x_start = prev_x[op]
        x_end = x_start + dt * dx

        if op == LeapFrogStep.KICK:
            # Positions are frozen during a kick: both the LPT displacement
            # used to place particles and the analytic LPT force must refer
            # to the current position epoch, i.e. the DRIFT cursor.
            a_pos = get_a(prev_x[LeapFrogStep.DRIFT])
            D1_pos = _growth_factor_ODE(cosmo, a_pos)
            D2_pos = _growth_factor_2LPT(cosmo, a_pos)
            D1s, D1e = D1_pos, D1_pos
            D2s, D2e = D2_pos, D2_pos

            nD1, nD2 = _nD_coefs(cosmo, a_pos, D1_pos, D2_pos)
            borg.print_msg(
                borg.Level.std,
                "COLA KICK a_pos={a} nD1={n1} nD2={n2}",
                a=a_pos,
                n1=nD1,
                n2=nD2,
            )
        else:
            a_s = get_a(x_start)
            a_e = get_a(x_end)
            D1s = _growth_factor_ODE(cosmo, a_s)
            D1e = _growth_factor_ODE(cosmo, a_e)
            D2s = _growth_factor_2LPT(cosmo, a_s)
            D2e = _growth_factor_2LPT(cosmo, a_e)
            nD1, nD2 = zero, zero

        D1_start_list.append(D1s)
        D1_end_list.append(D1e)
        D2_start_list.append(D2s)
        D2_end_list.append(D2e)
        nD1_list.append(nD1)
        nD2_list.append(nD2)

        prev_x[op] = x_end

    return COLACoef(
        drift_coeff=lf_coef.drift_coeff,
        kick_coeff=lf_coef.kick_coeff,
        D1_start=jnp.array(D1_start_list, dtype=FLOAT_DTYPE),
        D1_end=jnp.array(D1_end_list, dtype=FLOAT_DTYPE),
        D2_start=jnp.array(D2_start_list, dtype=FLOAT_DTYPE),
        D2_end=jnp.array(D2_end_list, dtype=FLOAT_DTYPE),
        nD1_kick=jnp.array(nD1_list, dtype=FLOAT_DTYPE),
        nD2_kick=jnp.array(nD2_list, dtype=FLOAT_DTYPE),
    )


# ---------------------------------------------------------------------------
# Core COLA simulation (JIT-compiled)
# ---------------------------------------------------------------------------


@partial(
    jax.jit,
    static_argnums=(
        3,  # want_vel
        4,  # mesh_shape
        5,  # mesh_shape_force
        7,  # supersampling
        9,  # lpt_order
        10,  # sharding (treated as static because it carries device structure)
        11,  # halo_size
    ),
)
def _run_cola_simulation_internal(
    ic_field,
    omega_c,
    omega_b,
    want_vel,
    mesh_shape,
    mesh_shape_force,
    cola_coefficients,
    supersampling,
    a_start,
    lpt_order=2,
    sharding=None,
    halo_size=0,
):
    """
    Run a COLA particle-mesh simulation and return the final positions.

    The evolved state is the pair of residuals (x_res, v_res) with respect to
    the LPT reference trajectory:

        x_full(a) = q + D1(a) * dx1_unit + D2(a) * dx2_unit + x_res

    Both residuals start at zero because the run starts exactly on the LPT
    trajectory at ``a_start``.  For every sub-step of ``cola_coefficients``:

    * if its ``kick_coeff`` is non-zero, v_res is kicked by the PM force at
      the current full positions minus the analytic LPT force
      ``nD1_kick * dx1_unit + nD2_kick * dx2_unit``;
    * x_res is drifted by ``drift_coeff * v_res`` (a no-op when the drift
      coefficient is zero).

    Parameters
    ----------
    ic_field : array
        Initial density contrast field on the Lagrangian mesh.
    omega_c, omega_b : float
        CDM and baryon density parameters passed to ``jc.Planck15``.
    want_vel : bool
        Accepted (and static under jit) but currently unused; only positions
        are returned.
    mesh_shape : tuple of int
        Shape of the Lagrangian particle grid.
    mesh_shape_force : tuple of int
        Shape of the PM force mesh.
    cola_coefficients : COLACoef
        Pre-computed COLA drift/kick integrals, growth factors and nD scalars.
    supersampling : int
        Supersampling factor for the initial conditions.
    a_start : float
        Initial scale factor.
    lpt_order : int
        Order of LPT used as the COLA reference trajectory (1 or 2).
    sharding : optional
        JAX device sharding descriptor.
    halo_size : int
        Halo cells for the PM force mesh (distributed runs).

    Returns
    -------
    pos_full : array, shape (N, 3)
        Final particle positions, in the same grid units as the Lagrangian
        coordinates ``q``.
    """
    cosmo = jc.Planck15(Omega_c=omega_c, Omega_b=omega_b)

    # ── Optionally supersample the IC field ──────────────────────────────────
    if supersampling > 1:
        ic_field = supersample_ic_field(ic_field, mesh_shape, supersampling)
        mesh_shape_particles = ic_field.shape
    else:
        mesh_shape_particles = mesh_shape

    # ── Lagrangian (unperturbed) particle positions ──────────────────────────
    q = jnp.stack(
        jax.device_put(
            jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape_particles]),
            device=sharding,
        ),
        axis=-1,
    ).reshape([-1, 3])

    # ── LPT initial displacement and velocity at a_start ────────────────────
    dx0, p0 = lpt(cosmo, ic_field, q, a=a_start, order=lpt_order, sharding=sharding)

    # Initial COLA state: residuals are zero (we sit exactly on LPT trajectory)
    x_res = jnp.zeros_like(dx0)
    v_res = jnp.zeros_like(p0)

    # Normalisation factor for the PM force (same as standard leapfrog)
    multiplier = mesh_shape_particles[0] / mesh_shape_force[0]

    # Growth factors at a_start, used to turn the LPT displacements into
    # growth-normalised (unit-D) fields that can be rescaled cheaply:
    #   dx_lpt(a) = D1(a) * dx1_unit + D2(a) * dx2_unit
    D1_0 = _growth_factor_ODE(cosmo, a_start)
    D2_0 = _growth_factor_2LPT(cosmo, a_start)

    if lpt_order >= 2:
        # Split the combined displacement into its 1LPT and 2LPT parts by
        # running a first-order LPT at the same epoch.
        dx1, _ = lpt(cosmo, ic_field, q, a=a_start, order=1, sharding=sharding)
        dx2 = dx0 - dx1
        dx1_unit = dx1 / D1_0
        dx2_unit = dx2 / D2_0
    else:
        dx1_unit = dx0 / D1_0
        dx2_unit = jnp.zeros_like(dx0)

    @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
    def _cola_evolver(state, istep):
        x_res_i, v_res_i = state

        kick_c = cola_coefficients.kick_coeff[istep]
        drift_c = cola_coefficients.drift_coeff[istep]

        # ── KICK step (velocity update) ──────────────────────────────────────
        def do_kick(_):
            # Growth factors placing the LPT reference at the epoch of the
            # current particle positions.  For KICK sub-steps these must be
            # evaluated at the position (drift) epoch, not at the start of
            # the kick interval.
            D1s = cola_coefficients.D1_start[istep]
            D2s = cola_coefficients.D2_start[istep]
            x_lpt_current = D1s * dx1_unit + D2s * dx2_unit
            x_full = q + x_lpt_current + x_res_i

            f_pm = pm_forces(
                x_full,
                mesh_shape=mesh_shape_force,
                pos_mesh_shape=mesh_shape_particles,
                fd=True,
                sharding=sharding,
                halo_size=halo_size,
            )

            # Analytic LPT force correction: O(N) scalar multiplies, no FFT.
            f_lpt_analytic = (
                cola_coefficients.nD1_kick[istep] * dx1_unit
                + cola_coefficients.nD2_kick[istep] * dx2_unit
            )

            # Residual force = full PM force - analytic LPT acceleration
            f_res = f_pm - f_lpt_analytic

            return v_res_i + kick_c * f_res * 1.5 * cosmo.Omega_m * multiplier

        # Kick whenever the coefficient is non-zero.  Testing "!= 0" instead
        # of "> 0" keeps kicks active for programmes integrated backwards in
        # time (negative coefficients).
        v_res_new = jax.lax.cond(
            kick_c != 0,
            do_kick,
            lambda _: v_res_i,
            operand=None,
        )

        # ── DRIFT step (position update) ─────────────────────────────────────
        # Only the residual is advanced.  The LPT part of the position is
        # reconstructed analytically from the growth factors, so x_res stays
        # purely residual.
        x_res_new = x_res_i + drift_c * v_res_new

        return (x_res_new, v_res_new), None

    (x_res_final, _v_res_final), _ = jax.lax.scan(
        _cola_evolver,
        (x_res, v_res),
        jnp.arange(cola_coefficients.drift_coeff.shape[0]),
    )

    # ── Reconstruct full final position ─────────────────────────────────────
    D1_f = cola_coefficients.D1_end[-1]
    D2_f = cola_coefficients.D2_end[-1]
    x_lpt_final = D1_f * dx1_unit + D2_f * dx2_unit
    return q + x_lpt_final + x_res_final


# ---------------------------------------------------------------------------
# Public API — mirrors prepare_leap_frog_simulation
# ---------------------------------------------------------------------------


def prepare_cola_simulation(
    omega_c,
    omega_b,
    num_steps,
    a_start,
    a_final,
    supersampling=1,
    lpt_order=2,
    log_step=LOG_STEP_DEFAULT,
):
    """
    Pre-compute COLA coefficients and return a ready-to-call simulation function.

    The returned callable is meant to be used interchangeably with the one
    returned by ``prepare_leap_frog_simulation``.

    Parameters
    ----------
    omega_c : float
        Cold dark matter density parameter Ω_c.  It is passed to
        ``jc.Planck15`` as ``Omega_c``, so it is the density fraction and not
        the physical density Ω_c h².
    omega_b : float
        Baryon density parameter Ω_b, passed as ``Omega_b``.
    num_steps : int
        Number of PM time steps.  COLA typically needs ~10–30× fewer steps
        than a standard PM simulation for the same large-scale accuracy.
    a_start : float
        Initial scale factor (should match the epoch of ``ic_field``).
    a_final : float
        Final scale factor.
    supersampling : int, optional
        Supersampling factor applied to ``ic_field`` before placing particles.
        Default 1 (no supersampling).
    lpt_order : int, optional
        Order of the LPT reference trajectory, 1 or 2.  Second order (the
        default) gives better accuracy for the same number of steps.
    log_step : bool, optional
        If True, steps are uniform in ln a; otherwise uniform in a.

    Returns
    -------
    run_fn : callable
        ``run_fn(ic_field, *, mesh_shape, mesh_shape_force, want_vel=False,
        sharding=None, halo_size=0) -> positions``.  ``mesh_shape`` and
        ``mesh_shape_force`` must be passed by keyword because the cosmology
        and coefficients are bound by keyword.

    Raises
    ------
    ValueError
        If ``lpt_order`` is not 1 or 2.

    Examples
    --------
    >>> run = prepare_cola_simulation(0.25, 0.05, num_steps=10,
    ...                               a_start=0.1, a_final=1.0)
    >>> positions = run(ic_field, mesh_shape=(128, 128, 128),
    ...                 mesh_shape_force=(128, 128, 128))
    """
    # The integrator splits the LPT displacement into exactly a 1LPT and a
    # 2LPT part; higher orders would be silently mis-scaled.
    if lpt_order not in (1, 2):
        raise ValueError(f"COLA supports lpt_order 1 or 2, got {lpt_order!r}")

    cosmo = jc.Planck15(Omega_c=omega_c, Omega_b=omega_b)
    cola_coefs = _precompute_cola_coefs(
        cosmo,
        a_start,
        a_final,
        num_steps,
        _make_DKD(num_steps),
        log_step=log_step,
    )

    borg.print_msg(
        borg.Level.std,
        "COLA: pre-computed {n} sub-steps from a={a0} to a={a1} (lpt_order={o})",
        n=len(cola_coefs.drift_coeff),
        a0=a_start,
        a1=a_final,
        o=lpt_order,
    )

    # want_vel is a required parameter of the integrator; bind a default so
    # the documented call (ic_field + mesh shapes only) works.  Callers may
    # still override it with want_vel=... .
    return partial(
        _run_cola_simulation_internal,
        omega_c=omega_c,
        omega_b=omega_b,
        want_vel=False,
        cola_coefficients=cola_coefs,
        a_start=a_start,
        supersampling=supersampling,
        lpt_order=lpt_order,
    )