# SPDX-FileCopyrightText: 2025 CNRS
# SPDX-FileContributor: Guilhem Lavaux
# SPDX-FileContributor: Svyatoslav Trusov
#
# SPDX-License-Identifier: CECILL-B
"""
COLA (COmoving Lagrangian Acceleration) integration for particle mesh simulations.
COLA works in a frame co-moving with LPT trajectories. Instead of evolving the
full phase-space trajectory, it evolves only the *residual* displacement and
velocity with respect to LPT predictions. This allows far fewer time steps to
achieve the same large-scale accuracy as a standard PM simulation.
References:
Tassev, Zaldarriaga & Eisenstein (2013), JCAP 06 036
Howlett et al. (2015), Astronomy & Computing 12, 109
Equations of motion
-------------------
The residual displacement Ψ_res = x - x_LPT satisfies (in conformal time τ):
Ψ_res'' = -∇Φ_PM(x_LPT + Ψ_res) + F_LPT_analytic
where primes denote d/dτ and F_LPT_analytic is the *exact* second conformal-time
derivative of the LPT reference trajectory, computed analytically from growth
factors without any PM solve.
Analytic LPT force
------------------
The nLPT displacement is x_n(τ) = D_n(τ) Ψ_n(q), so its conformal acceleration is:
x_n'' = D_n'' Ψ_n
The conformal second derivative of D_n is related to its scale-factor derivatives by
D_n'' = a²H² [ D_n'' / (a²H²) ]
= a²H² [ f_n(f_n - 1) (d ln H/d ln a + 1) D_n
+ f_n' D_n + f_n² D_n ] ... simplified ...
= a²H² D_n [ f_n² + f_n(d ln H/d ln a) + f_n' ]
In practice we absorb a²H² into the kick coefficient (which integrates 1/(aH)
over conformal time) and define the dimensionless scalar:
nD_n = D_n [ f_n(f_n - 1) - (3/2) Ω_m(a)/E²(a) · (1 - δ_{n,1}) ... ]
A cleaner derivation uses the explicit LPT equations of motion:
D_1'' + H(τ) D_1' = (3/2) H₀² Ω_m a⁻¹ D_1 (1LPT sourced by linear δ)
D_2'' + H(τ) D_2' = (3/2) H₀² Ω_m a⁻¹ D_2
- (3/2) H₀² Ω_m a⁻¹ D_1² (2LPT sourced by D_1²)
Rearranging gives the analytic acceleration *in the same units as the PM force*
(i.e. after the kick coefficient integral has been applied):
F_LPT_analytic = nD1 · Ψ₁_unit + nD2 · Ψ₂_unit
where the dimensionless nD scalars are precomputed once per kick step
(no per-particle work beyond a scalar multiply):
nD1 = D_1(a) · [f_1² + f_1 · (d ln H / d ln a)
- (3/2) Ω_m(a) / E²(a)]
nD2 = D_2(a) · [f_2² + f_2 · (d ln H / d ln a)
- (3/2) Ω_m(a) / E²(a)]
+ D_1²(a) · (3/2) Ω_m(a) / E²(a)
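These are stored in COLACoef.nD1_kick and COLACoef.nD2_kick (one value per
leapfrog sub-step, zero on DRIFT sub-steps), so the kick update reduces to:
    v_res += kick_coeff * (F_PM - nD1 * Ψ₁_unit - nD2 * Ψ₂_unit) * 1.5 * Ω_m * L
where L is the ratio of the particle-grid size to the force-mesh size along
the first axis. The LPT correction costs O(N) in the number of particles N
(O(1) per particle) and requires no FFT.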
"""
from collections import namedtuple
# from enum import Enum
from functools import partial
import aquila_borg as borg
import jax
import jax.numpy as jnp
import jax_cosmo as jc
# ── Re-exported from the standard leapfrog module so callers need only import
# this file.
from ._simulation_leapfrog import ( # noqa: F401
LeapfrogCoef,
LeapFrogStep,
_make_DKD,
_precompute_leapfrog_coefs,
supersample_ic_field,
)
from ._src.constants import FLOAT_DTYPE, LOG_STEP_DEFAULT
from ._src.pm import lpt, pm_forces
# ---------------------------------------------------------------------------
# Named tuple carrying the extra COLA coefficients alongside the standard ones
# ---------------------------------------------------------------------------
# Per-sub-step coefficient bundle consumed by the COLA time loop.  Every field
# is indexed by the leapfrog sub-step number.  Field order is part of the
# interface and must not change.
COLACoef = namedtuple(
    "COLACoef",
    (
        # -- Drift / kick integrals (same meaning as in LeapfrogCoef) --------
        "drift_coeff",
        "kick_coeff",
        # -- LPT growth factors attached to each sub-step ---------------------
        # Used to rescale the unit-growth LPT displacement fields:
        #   x_lpt = D1 * psi1_unit + D2 * psi2_unit
        "D1_start",  # first-order growth factor, beginning of the sub-step
        "D1_end",  # first-order growth factor, end of the sub-step
        "D2_start",  # second-order growth factor, beginning of the sub-step
        "D2_end",  # second-order growth factor, end of the sub-step
        # -- Analytic LPT force scalars ----------------------------------------
        # One scalar per sub-step; zero on DRIFT-only sub-steps.  The run loop
        # subtracts
        #   F_LPT = nD1_kick * psi1_unit + nD2_kick * psi2_unit
        # from the PM force, with psi_{1,2}_unit the growth-normalised
        # displacement fields.  As computed by _nD_coefs:
        #   nD1 = D1 * [f1^2 + f1*dlnH_dlna - 3/2 * Om0 * a^-3 / E^2]
        #   nD2 = D2 * [f2^2 + f2*dlnH_dlna - 3/2 * Om0 * a^-3 / E^2]
        #         + D1^2 * (3/2 * Om0 * a^-3 / E^2)
        "nD1_kick",
        "nD2_kick",
    ),
)
# ---------------------------------------------------------------------------
# Growth-factor helpers (jax-cosmo already exposes Omega_m_a, but we need D1)
# ---------------------------------------------------------------------------
def _growth_factor_ODE(cosmo, a):
    """
    Return the linear growth factor D1(a), normalised so that D1(a=1) = 1.

    Despite the name, no ODE is solved: the value is obtained from the
    integral expression

        D1(a) ∝ H(a) ∫ da' / (a' E(a'))^3

    evaluated by adaptive Gauss-Kronrod quadrature (``quadax.quadgk``).  The
    lower integration limit is 1e-4 rather than 0, for both the integral up
    to ``a`` and the normalising integral up to 1.

    Parameters
    ----------
    cosmo
        Cosmology object accepted by ``jc.background.Esqr``.
    a
        Scale factor at which to evaluate D1.

    Returns
    -------
    D1(a) / D1(1).
    """
    from quadax import quadgk

    a_lower = 1e-4  # stands in for the a' -> 0 limit of the integral

    def _inv_aE_cubed(a_prime):
        # Integrand 1 / (a' E(a'))^3 with E = sqrt(E^2).
        e_prime = jnp.sqrt(jc.background.Esqr(cosmo, a_prime))
        return jnp.array(1.0 / (a_prime * e_prime) ** 3, dtype=FLOAT_DTYPE)

    # Normalising integral first (up to a = 1), then the one up to `a`.
    integral_to_one, _ = quadgk(_inv_aE_cubed, [a_lower, 1.0])
    integral_to_a, _ = quadgk(_inv_aE_cubed, [a_lower, a])

    e_at_a = jnp.sqrt(jc.background.Esqr(cosmo, a))
    e_today = jnp.sqrt(jc.background.Esqr(cosmo, 1.0))
    return (e_at_a / e_today) * integral_to_a / integral_to_one
def _growth_factor_2LPT(cosmo, a):
    """
    Return the second-order growth factor D2(a).

    Uses the Einstein–de Sitter relation D2 = -3/7 * D1^2, with D1 taken from
    ``_growth_factor_ODE``.  The original author adopts this closed form for
    efficiency and quotes it as accurate at the ~1% level for ΛCDM.
    """
    eds_prefactor = -3.0 / 7.0
    first_order = _growth_factor_ODE(cosmo, a)
    return eds_prefactor * first_order**2
def _log_growth_rate(cosmo, a, D_fn, eps=1e-4):
"""
Logarithmic growth rate f = d ln D / d ln a for growth factor function
*D_fn(cosmo, a)*, computed by second-order finite differences in ln a.
"""
Dp = D_fn(cosmo, a * jnp.exp(eps))
Dm = D_fn(cosmo, a * jnp.exp(-eps))
return (jnp.log(jnp.abs(Dp)) - jnp.log(jnp.abs(Dm))) / (2 * eps)
def _dlnH_dlna(cosmo, a, eps=1e-4):
    """
    Return d ln H / d ln a at scale factor ``a``.

    Computed as a central finite difference in ln a (half-width ``eps``) of
    ln E = 0.5 * ln E², with E² taken from ``jc.background.Esqr``.  The
    quantity equals -1 - q, with q the deceleration parameter; the numerical
    form is used so that no particular cosmological model is hard-coded.
    """

    def _ln_E(shift):
        # ln E at the scale factor displaced by `shift` in ln a.
        return 0.5 * jnp.log(jc.background.Esqr(cosmo, a * jnp.exp(shift)))

    ahead = _ln_E(eps)
    behind = _ln_E(-eps)
    return (ahead - behind) / (2 * eps)
def _nD_coefs(cosmo, a, D1, D2):
    """
    Return the analytic LPT force scalars ``(nD1, nD2)`` at scale factor ``a``.

    The COLA kick subtracts

        F_LPT_analytic = nD1 * Ψ₁_unit + nD2 * Ψ₂_unit

    from the PM force, where Ψ_{1,2}_unit are the growth-normalised LPT
    displacement fields.  With

        f_n = d ln|D_n| / d ln a        (finite differences, _log_growth_rate)
        h   = d ln H / d ln a           (finite differences, _dlnH_dlna)
        Ω   = Ω_m,0 · a⁻³ / E²(a)

    the values computed here are

        nD1 = D1 · [ f1² + f1·h − 3/2·Ω ]
        nD2 = D2 · [ f2² + f2·h − 3/2·Ω ] + D1² · 3/2·Ω

    Parameters
    ----------
    cosmo
        Cosmology object providing ``Omega_m`` and accepted by ``jc.background``.
    a
        Scale factor at which the coefficients are evaluated.
    D1, D2
        First- and second-order growth factors at ``a`` (supplied by the
        caller; the growth *rates* are recomputed here).

    Returns
    -------
    (nD1, nD2) : pair of scalars cast to ``FLOAT_DTYPE``.

    Notes
    -----
    NOTE(review): Ω above is the time-dependent matter density parameter
    Ω_m(a); earlier comments labelled it "Ω_m(a)/E²(a)".
    NOTE(review): by the linear growth equation
    f' + f² + (2 + h)·f = 3/2·Ω_m(a), the first-order bracket equals
    −(f1' + 2·f1).  Confirm its sign and normalisation against the output of
    ``pm_forces`` — TODO verify.
    """
    rate_1 = _log_growth_rate(cosmo, a, _growth_factor_ODE)
    rate_2 = _log_growth_rate(cosmo, a, _growth_factor_2LPT)
    hubble_slope = _dlnH_dlna(cosmo, a)

    # Ω_m,0 · a⁻³ / E²(a)
    matter_term = cosmo.Omega_m * a**-3 / jc.background.Esqr(cosmo, a)

    def _bracket(rate):
        # Common factor multiplying D_n in both coefficients.
        return rate**2 + rate * hubble_slope - 1.5 * matter_term

    first_order = D1 * _bracket(rate_1)
    second_order = D2 * _bracket(rate_2) + D1**2 * 1.5 * matter_term
    return (
        jnp.array(first_order, dtype=FLOAT_DTYPE),
        jnp.array(second_order, dtype=FLOAT_DTYPE),
    )
# ---------------------------------------------------------------------------
# Pre-compute COLA coefficients
# ---------------------------------------------------------------------------
def _precompute_cola_coefs(
    cosmo,
    a_start,
    a_final,
    nsteps,
    LF_programme,
    log_step=LOG_STEP_DEFAULT,
):
    """
    Build the per-sub-step ``COLACoef`` table for a leapfrog programme.

    The drift/kick integrals come unchanged from
    ``_precompute_leapfrog_coefs``.  On top of those, for every
    ``(op, dt)`` entry of ``LF_programme`` this function records:

    * D1 and D2 at the start and end of the sub-step;
    * for KICK sub-steps, the analytic LPT force scalars (nD1, nD2) evaluated
      at the mid-point of the sub-step in the stepping variable (ln a when
      ``log_step`` is true, a otherwise); zero for any other sub-step.

    KICK and DRIFT sub-steps advance two independent time cursors, both
    starting at ``a_start``; each sub-step spans ``dt`` times the base step
    ``(x_final - x_start) / nsteps``.

    Parameters
    ----------
    cosmo
        Cosmology object passed to the growth-factor helpers.
    a_start, a_final
        Initial and final scale factors.
    nsteps
        Number of base steps between ``a_start`` and ``a_final``.
    LF_programme
        Iterable of ``(LeapFrogStep, dt)`` pairs describing the sub-steps.
    log_step
        If true, the stepping variable is ln a; otherwise a itself.

    Returns
    -------
    COLACoef
        All fields are arrays with one entry per sub-step.
    """
    base = _precompute_leapfrog_coefs(
        cosmo, a_start, a_final, nsteps, LF_programme, log_step=log_step
    )

    # Maps between the scale factor and the stepping variable x.
    if log_step:
        to_a, to_x = jnp.exp, jnp.log
        base_step = jnp.log(a_final / a_start) / nsteps
    else:

        def to_a(x):
            return x

        to_x = to_a
        base_step = (a_final - a_start) / nsteps

    # One independent cursor per operation type.
    cursor = {
        op: to_x(a_start) for op in (LeapFrogStep.KICK, LeapFrogStep.DRIFT)
    }
    no_force = jnp.array(0.0, dtype=FLOAT_DTYPE)
    columns = {
        "D1_start": [],
        "D1_end": [],
        "D2_start": [],
        "D2_end": [],
        "nD1_kick": [],
        "nD2_kick": [],
    }

    for op, dt in LF_programme:
        x_from = cursor[op]
        x_to = x_from + dt * base_step
        a_from, a_to = to_a(x_from), to_a(x_to)

        columns["D1_start"].append(_growth_factor_ODE(cosmo, a_from))
        columns["D1_end"].append(_growth_factor_ODE(cosmo, a_to))
        columns["D2_start"].append(_growth_factor_2LPT(cosmo, a_from))
        columns["D2_end"].append(_growth_factor_2LPT(cosmo, a_to))

        kick_nD1, kick_nD2 = no_force, no_force
        if op == LeapFrogStep.KICK:
            # Analytic force scalars at the mid-point of the kick interval.
            a_mid = to_a(0.5 * (x_from + x_to))
            kick_nD1, kick_nD2 = _nD_coefs(
                cosmo,
                a_mid,
                _growth_factor_ODE(cosmo, a_mid),
                _growth_factor_2LPT(cosmo, a_mid),
            )
            borg.print_msg(
                borg.Level.std,
                "COLA KICK a_mid={a} nD1={n1} nD2={n2}",
                a=a_mid,
                n1=kick_nD1,
                n2=kick_nD2,
            )
        columns["nD1_kick"].append(kick_nD1)
        columns["nD2_kick"].append(kick_nD2)

        cursor[op] = x_to

    return COLACoef(
        drift_coeff=base.drift_coeff,
        kick_coeff=base.kick_coeff,
        **{
            name: jnp.array(values, dtype=FLOAT_DTYPE)
            for name, values in columns.items()
        },
    )
# ---------------------------------------------------------------------------
# Core COLA simulation (JIT-compiled)
# ---------------------------------------------------------------------------
@partial(
    jax.jit,
    static_argnums=(
        3,  # want_vel
        4,  # mesh_shape
        5,  # mesh_shape_force
        7,  # supersampling
        9,  # lpt_order
        10,  # sharding (treated as static because it carries device structure)
        11,  # halo_size
    ),
)
def _run_cola_simulation_internal(
    ic_field,
    omega_c,
    omega_b,
    want_vel,
    mesh_shape,
    mesh_shape_force,
    cola_coefficients,
    supersampling,
    a_start,
    lpt_order=2,
    sharding=None,
    halo_size=0,
):
    """
    Run a COLA particle-mesh simulation and return the final positions.

    The evolved state is the pair of residuals (x_res, v_res) relative to the
    LPT reference trajectory:

        x_full = q + x_lpt(a) + x_res,   x_lpt = D1 * dx1_unit + D2 * dx2_unit

    Both residuals start at zero, i.e. the particles begin exactly on the LPT
    trajectory at ``a_start``.

    The positional indices in ``static_argnums`` above are tied to the
    parameter order of this signature; keep the two in sync.

    Parameters
    ----------
    ic_field : array
        Initial density contrast field on the Lagrangian mesh.
    omega_c, omega_b : float
        Passed as ``Omega_c`` / ``Omega_b`` to ``jc.Planck15``.
    want_vel : bool
        Accepted for interface compatibility; it is not read anywhere in the
        body and only positions are returned.
    mesh_shape : tuple of int
        Shape of the Lagrangian particle grid (used when ``supersampling``
        is not greater than 1, and passed to ``supersample_ic_field``
        otherwise).
    mesh_shape_force : tuple of int
        Shape of the PM force mesh.
    cola_coefficients : COLACoef
        Pre-computed drift/kick integrals, growth factors and analytic LPT
        force scalars, one entry per leapfrog sub-step.
    supersampling : int
        If greater than 1, ``ic_field`` is supersampled and the particle grid
        takes the shape of the supersampled field.
    a_start : float
        Initial scale factor.
    lpt_order : int
        Order passed to ``lpt`` for the reference trajectory.  Values >= 2
        split the displacement into first- and second-order parts; otherwise
        the second-order part is zero.
    sharding : optional
        Forwarded to ``jax.device_put``, ``lpt`` and ``pm_forces``.
    halo_size : int
        Forwarded to ``pm_forces``.

    Returns
    -------
    pos_full : array, shape (N, 3)
        Final particle positions ``q + x_lpt(a_final) + x_res``.  ``q`` is
        built from integer grid indices, so positions are presumably in
        particle-grid cell units — TODO confirm against ``lpt``/``pm_forces``.
    """
    cosmo = jc.Planck15(Omega_c=omega_c, Omega_b=omega_b)

    # ── Optionally supersample the IC field ──────────────────────────────────
    if supersampling > 1:
        ic_field = supersample_ic_field(ic_field, mesh_shape, supersampling)
        mesh_shape_particles = ic_field.shape
    else:
        mesh_shape_particles = mesh_shape

    # ── Lagrangian (unperturbed) particle positions ──────────────────────────
    # One particle per grid node; coordinates are the integer grid indices,
    # flattened to shape (N, 3).
    q = jnp.stack(
        jax.device_put(
            jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape_particles]),
            device=sharding,
        ),
        axis=-1,
    ).reshape([-1, 3])

    # ── LPT initial displacement and velocity at a_start ────────────────────
    # p0 is only used below as a shape/dtype template for v_res.
    dx0, p0 = lpt(cosmo, ic_field, q, a=a_start, order=lpt_order, sharding=sharding)

    # Initial COLA state: residuals are zero (we sit exactly on LPT trajectory)
    x_res = jnp.zeros_like(dx0)  # residual displacement
    v_res = jnp.zeros_like(p0)  # residual velocity

    # Ratio of particle-grid size to force-mesh size along the first axis;
    # it scales the PM force in the kick below.
    multiplier = mesh_shape_particles[0] / mesh_shape_force[0]

    # Growth factors at a_start, used to normalise the LPT displacements
    D1_0 = _growth_factor_ODE(cosmo, a_start)
    D2_0 = _growth_factor_2LPT(cosmo, a_start)

    # Normalised (unit-growth) LPT displacement fields so we can rescale cheaply:
    #   x_lpt(a) = D1(a) * dx1_unit + D2(a) * dx2_unit
    # The first- and second-order parts are separated by calling lpt a second
    # time at the same epoch with order=1 and subtracting it from the
    # combined displacement.
    if lpt_order >= 2:
        dx1, _ = lpt(cosmo, ic_field, q, a=a_start, order=1, sharding=sharding)
        dx2 = dx0 - dx1  # 2LPT contribution at a_start
        dx1_unit = dx1 / D1_0
        dx2_unit = dx2 / D2_0
    else:
        dx1_unit = dx0 / D1_0
        dx2_unit = jnp.zeros_like(dx0)

    # The scan body is rematerialised on the backward pass: no intermediate
    # value is saved by the checkpoint policy.
    @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
    def _cola_evolver(state, istep):
        x_res_i, v_res_i = state
        kick_c = cola_coefficients.kick_coeff[istep]
        drift_c = cola_coefficients.drift_coeff[istep]
        # Only the start-of-sub-step growth factors are read inside the loop;
        # D1_end / D2_end are used once, after the scan, for the final positions.
        D1s = cola_coefficients.D1_start[istep]
        D2s = cola_coefficients.D2_start[istep]

        # ── Current full position (needed for PM force) ──────────────────────
        # NOTE(review): _precompute_cola_coefs advances separate time cursors
        # for KICK and DRIFT sub-steps, so on a KICK sub-step D1s/D2s are the
        # growth factors at the start of the kick interval, whereas
        # nD1_kick/nD2_kick are evaluated at the kick mid-point.  Confirm that
        # the LPT reference used for the PM force is meant to sit at the
        # kick-start epoch rather than at the epoch of the particle positions.
        x_lpt_current = D1s * dx1_unit + D2s * dx2_unit
        x_full = q + x_lpt_current + x_res_i

        # ── KICK step (velocity update) ──────────────────────────────────────
        def do_kick(_):
            f_pm = pm_forces(
                x_full,
                mesh_shape=mesh_shape_force,
                pos_mesh_shape=mesh_shape_particles,
                fd=True,
                sharding=sharding,
                halo_size=halo_size,
            )
            # Analytic LPT force correction: two scalar-times-array products,
            # no FFT.  The scalars are produced by _nD_coefs; they are assumed
            # to be in the same units as the pm_forces output — TODO confirm.
            f_lpt_analytic = (
                cola_coefficients.nD1_kick[istep] * dx1_unit
                + cola_coefficients.nD2_kick[istep] * dx2_unit
            )
            # Residual force = full PM force – analytic LPT acceleration
            f_res = f_pm - f_lpt_analytic
            return v_res_i + kick_c * f_res * 1.5 * cosmo.Omega_m * multiplier

        # The PM force is only evaluated on sub-steps with a positive kick
        # coefficient; otherwise the residual velocity passes through unchanged.
        v_res_new = jax.lax.cond(
            kick_c > 0,
            do_kick,
            lambda _: v_res_i,
            operand=None,
        )

        # ── DRIFT step (position update) ─────────────────────────────────────
        # Only the residual displacement is advanced here.  The LPT part of the
        # motion is never integrated: it is re-evaluated from the stored growth
        # factors (D1_start/D2_start inside the loop, D1_end/D2_end at the end).
        # Applied on every sub-step; this presumably relies on drift_coeff being
        # zero on KICK sub-steps — verify against _precompute_leapfrog_coefs.
        x_res_new = x_res_i + drift_c * v_res_new
        return (x_res_new, v_res_new), None

    # v_res_final is carried through the scan but not used afterwards.
    (x_res_final, v_res_final), _ = jax.lax.scan(
        _cola_evolver,
        (x_res, v_res),
        jnp.arange(cola_coefficients.drift_coeff.shape[0]),
    )

    # ── Reconstruct full final position ─────────────────────────────────────
    # LPT part evaluated with the growth factors at the end of the last sub-step.
    D1_f = cola_coefficients.D1_end[-1]
    D2_f = cola_coefficients.D2_end[-1]
    x_lpt_final = D1_f * dx1_unit + D2_f * dx2_unit
    pos_full = q + x_lpt_final + x_res_final
    return pos_full
# ---------------------------------------------------------------------------
# Public API — mirrors prepare_leap_frog_simulation
# ---------------------------------------------------------------------------
def prepare_cola_simulation(
    omega_c,
    omega_b,
    num_steps,
    a_start,
    a_final,
    supersampling=1,
    lpt_order=2,
    log_step=LOG_STEP_DEFAULT,
):
    """
    Pre-compute COLA coefficients and return a ready-to-call simulation function.

    The coefficient table is built once, here, for a DKD leapfrog programme
    (``_make_DKD(num_steps)``); the returned callable is
    ``_run_cola_simulation_internal`` with the cosmology, coefficients and
    stepping options already bound.  The interface is intended to mirror
    ``prepare_leap_frog_simulation``.

    Parameters
    ----------
    omega_c : float
        Cold dark matter density parameter, forwarded as ``Omega_c`` to
        ``jax_cosmo.Planck15`` (a density fraction such as 0.25, not Ω_c h²).
    omega_b : float
        Baryon density parameter, forwarded as ``Omega_b`` to
        ``jax_cosmo.Planck15`` (a density fraction, not Ω_b h²).
    num_steps : int
        Number of PM time steps; must be at least 1.  COLA typically needs
        ~10–30× fewer steps than a standard PM simulation for the same
        large-scale accuracy.
    a_start : float
        Initial scale factor (should match the epoch of ``ic_field``); must
        be positive.
    a_final : float
        Final scale factor; must be positive.
    supersampling : int, optional
        Supersampling factor applied to ``ic_field`` before placing particles.
        Default 1 (no supersampling).
    lpt_order : int, optional
        Order of the LPT reference trajectory (1 or 2).  Second order (default)
        gives better accuracy for the same number of steps.
    log_step : bool, optional
        If true, steps are uniform in ln a; otherwise uniform in a.  Defaults
        to ``LOG_STEP_DEFAULT``.

    Returns
    -------
    run_fn : callable
        ``run_fn(ic_field, *, want_vel, mesh_shape, mesh_shape_force,
        sharding=None, halo_size=0) -> positions``.  Because the bound
        arguments are attached by keyword, everything after ``ic_field`` must
        be passed by keyword; ``want_vel``, ``mesh_shape`` and
        ``mesh_shape_force`` are required.

    Raises
    ------
    ValueError
        If ``num_steps`` is smaller than 1 or a scale factor is not positive.

    Examples
    --------
    >>> run = prepare_cola_simulation(0.25, 0.05, num_steps=10,
    ...                               a_start=0.1, a_final=1.0)
    >>> positions = run(ic_field, want_vel=False,
    ...                 mesh_shape=(128, 128, 128),
    ...                 mesh_shape_force=(128, 128, 128))
    """
    # Fail early: a zero step count or a non-positive scale factor would
    # otherwise only surface as inf/NaN coefficients deep inside the run.
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")
    if a_start <= 0 or a_final <= 0:
        raise ValueError(
            f"scale factors must be positive, got a_start={a_start}, "
            f"a_final={a_final}"
        )

    cosmo = jc.Planck15(Omega_c=omega_c, Omega_b=omega_b)
    cola_coefs = _precompute_cola_coefs(
        cosmo,
        a_start,
        a_final,
        num_steps,
        _make_DKD(num_steps),
        log_step=log_step,
    )
    borg.print_msg(
        borg.Level.std,
        "COLA: pre-computed {n} sub-steps from a={a0} to a={a1} (lpt_order={o})",
        n=len(cola_coefs.drift_coeff),
        a0=a_start,
        a1=a_final,
        o=lpt_order,
    )
    # Bind everything known at preparation time; the caller supplies the
    # initial conditions, mesh shapes and (optionally) sharding / halo size.
    return partial(
        _run_cola_simulation_internal,
        omega_c=omega_c,
        omega_b=omega_b,
        cola_coefficients=cola_coefs,
        a_start=a_start,
        supersampling=supersampling,
        lpt_order=lpt_order,
    )