# SPDX-FileCopyrightText: 2025 CNRS
# SPDX-FileContributor: Guilhem Lavaux
# SPDX-FileContributor: Svyatoslav Trusov
#
# SPDX-License-Identifier: CECILL-B
from collections import namedtuple
from enum import Enum
from functools import partial
import aquila_borg as borg
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from ._src.constants import FLOAT_DTYPE, LOG_STEP_DEFAULT
from ._src.pm import lpt, pm_forces
# Per-operation leapfrog coefficients: ``drift_coeff`` and ``kick_coeff`` each
# hold one entry per leapfrog operation. Being a namedtuple, it can be passed
# through ``jax.jit`` as an ordinary (traced) pytree argument.
LeapfrogCoef = namedtuple("LeapfrogCoef", "drift_coeff kick_coeff")
# Leapfrog operation types: a KICK updates particle velocities, a DRIFT
# updates particle positions.
LeapFrogStep = Enum("LeapFrogStep", [("KICK", 0), ("DRIFT", 1)])
def _make_DKD(nstep):
    """Return the drift-kick-drift leapfrog programme for ``nstep`` steps.

    The programme is a list of ``(LeapFrogStep, dt)`` pairs, where ``dt`` is a
    fraction of one time step. It consists of an opening half drift,
    ``nstep - 1`` full kick/drift pairs, one last full kick and a closing half
    drift. For ``nstep >= 1``, the kicks and the drifts each add up to
    ``nstep`` full steps.
    """
    half_drift = (LeapFrogStep.DRIFT, 0.5)
    full_kick = (LeapFrogStep.KICK, 1.0)
    full_drift = (LeapFrogStep.DRIFT, 1.0)
    middle = [entry for _ in range(nstep - 1) for entry in (full_kick, full_drift)]
    return [half_drift, *middle, full_kick, half_drift]
def _precompute_leapfrog_coefs(
    cosmo, a_start, a_final, nsteps, LF_programme, log_step=LOG_STEP_DEFAULT
):
    """Integrate the kick and drift coefficient of every leapfrog operation.

    Time is tracked with a variable ``x``: ``x = ln(a)`` when ``log_step`` is
    true, ``x = a`` otherwise. The interval from ``a_start`` to ``a_final`` is
    divided into ``nsteps`` equal steps ``dx`` in that variable.

    KICK and DRIFT each keep their own clock. An ``(op, dt)`` entry of
    ``LF_programme`` advances the clock of ``op`` by ``dt * dx`` and
    integrates, in ``ln(a)``, over the sub-interval just covered:

    - KICK:  integral of dln(a) / (a * sqrt(Esqr(a)))
    - DRIFT: integral of dln(a) / (a**2 * sqrt(Esqr(a)))

    where ``Esqr`` is ``jc.background.Esqr(cosmo, a)``.

    cosmo: jax_cosmo cosmology used for ``Esqr``
    a_start: scale factor at which the integration starts
    a_final: scale factor at which the programme must end
    nsteps: number of steps ``dx`` between ``a_start`` and ``a_final``
    LF_programme: sequence of ``(LeapFrogStep, dt)`` pairs (e.g. from ``_make_DKD``)
    log_step: step uniformly in ln(a) if true, uniformly in a otherwise

    Returns a ``LeapfrogCoef`` whose arrays have one entry per programme
    operation. At each entry, the coefficient of the operation type that is
    not performed is 0.0.

    Raises AssertionError if the clock of the last operation does not end
    within 1e-3 (in ``x``) of ``a_final``. The check is skipped under
    ``python -O``.
    """
    # from scipy.integrate import quad
    # quadax provides a JAX-native Gauss-Kronrod quadrature.
    from quadax import quadgk

    # get_x: a -> x, get_a: x -> a, get_ln_a: x -> ln(a) for the chosen variable.
    if log_step:
        dx = jnp.log(a_final / a_start) / nsteps
        get_a = jnp.exp

        def get_ln_a(x):
            return x

        get_x = jnp.log
    else:
        dx = (a_final - a_start) / nsteps

        def get_a(x):
            return x

        get_ln_a = jnp.log

        def get_x(x):
            return x

    def _integrand_pos(ln_a):
        """Drift integrand in ln(a): 1 / (a**2 * sqrt(Esqr(a)))."""
        a = jnp.exp(ln_a)
        # Change of variable from a to ln(a): since dln(a)/da = 1/a,
        # I(a) da = I(a) * a * dln(a). The integrand is therefore already
        # multiplied by a.
        return jnp.array(
            1.0 / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))), dtype=FLOAT_DTYPE
        )

    def _integrand_vel(ln_a):
        """Kick integrand in ln(a): 1 / (a * sqrt(Esqr(a)))."""
        a = jnp.exp(ln_a)
        # Change of variable from a to ln(a): since dln(a)/da = 1/a,
        # I(a) da = I(a) * a * dln(a). The integrand is therefore already
        # multiplied by a.
        return jnp.array(
            1.0 / (a * jnp.sqrt(jc.background.Esqr(cosmo, a))), dtype=FLOAT_DTYPE
        )

    drift_coeff = []
    kick_coeff = []
    # Each operation type keeps its own clock. In a DKD programme the drifts
    # and kicks are offset by half a step.
    prev_a_time = {
        LeapFrogStep.KICK: get_x(a_start),
        LeapFrogStep.DRIFT: get_x(a_start),
    }
    for op, dt in LF_programme:
        prev_x = prev_a_time[op]
        x = prev_x + dt * dx
        ln_a = get_ln_a(x)
        ln_prev_a = get_ln_a(prev_x)
        a = get_a(x)
        borg.print_msg(borg.Level.std, f"op={op}, a={a}, dt={dt}")
        if op == LeapFrogStep.KICK:
            # A zero-length interval (a_start == a_final) skips the quadrature.
            if dx == 0:
                last = 0
            else:
                last = quadgk(_integrand_vel, [ln_prev_a, ln_a])[0]
            # NOTE(review): assumes borg.print_msg substitutes these keyword
            # arguments into the {...} placeholders; confirm against aquila_borg.
            borg.print_msg(
                borg.Level.std,
                "KICK ln_prev_a={ln_prev_a} ln_a={ln_a} last={last}",
                last=last,
                ln_prev_a=ln_prev_a,
                ln_a=ln_a,
            )
            kick_coeff.append(last)
            drift_coeff.append(0.0)
        elif op == LeapFrogStep.DRIFT:
            if dx == 0:
                last = 0
            else:
                last = quadgk(_integrand_pos, [ln_prev_a, ln_a])[0]
            borg.print_msg(borg.Level.std, "DRIFT last={last}", last=last)
            kick_coeff.append(0.0)
            drift_coeff.append(last)
        prev_a_time[op] = x
        last_x = x
    # Sanity check: the programme must end at a_final. last_x is undefined
    # (NameError) if LF_programme is empty.
    assert jnp.abs(get_x(a_final) - last_x) < 1e-3
    return LeapfrogCoef(
        kick_coeff=jnp.array(kick_coeff),
        drift_coeff=jnp.array(drift_coeff),
    )
def _pad_full_freq_axis(field_hat, axis, n, m):
    """Zero-pad one full (complex-FFT) frequency axis of a spectrum from length n to m.

    The non-negative frequencies 0 .. n//2 stay at the start of the axis, and
    the strictly negative frequencies move to the end of the longer axis.

    For even ``n``, the Nyquist entry (index n//2) represents both +n/2 and
    -n/2. On the finer grid those are two distinct slots, so each slot gets
    half of the original value. Copying the full value to both slots would
    double the amplitude of that mode.
    """
    if m == n:
        return field_hat
    x = jnp.moveaxis(field_hat, axis, 0)
    n_pos = n // 2 + 1  # frequencies 0 .. n//2 (n//2 is the Nyquist when n is even)
    n_neg = n - n_pos  # strictly negative frequencies -n_neg .. -1
    out = jnp.zeros((m,) + x.shape[1:], dtype=x.dtype)
    out = out.at[:n_pos].set(x[:n_pos])
    if n_neg > 0:
        out = out.at[m - n_neg :].set(x[n_pos:])
    if n % 2 == 0:
        nyquist = 0.5 * x[n // 2]
        out = out.at[n // 2].set(nyquist)
        out = out.at[m - n // 2].set(nyquist)
    return jnp.moveaxis(out, 0, axis)


def _pad_half_freq_axis(field_hat, n, m):
    """Zero-pad the last (rfft, non-negative-only) axis from real length n to m."""
    if m == n:
        return field_hat
    if n % 2 == 0:
        # On the coarse grid, irfft counts the Nyquist entry once. On the fine
        # grid the same index is an ordinary frequency, which irfft counts
        # together with its implicit conjugate, so it is halved here.
        field_hat = field_hat.at[..., n // 2].multiply(0.5)
    pad_width = [(0, 0)] * (field_hat.ndim - 1) + [(0, (m // 2 + 1) - (n // 2 + 1))]
    return jnp.pad(field_hat, pad_width)


def supersample_ic_field(ic_field, mesh_shape, supersampling):
    """Fourier-interpolate a real 3D field onto a grid ``supersampling`` times finer.

    The spectrum of ``ic_field`` is zero-padded: no new modes are added, and
    every existing mode (including the Nyquist modes of even-sized axes) keeps
    its amplitude. As a result, the fine field coincides with the input at the
    coarse grid points, i.e. ``out[::s, ::s, ::s] == ic_field`` up to rounding.

    ic_field: real 3D array; assumed to have shape ``mesh_shape``
    mesh_shape: shape ``(n0, n1, n2)`` of ``ic_field``
    supersampling: integer refinement factor ``s`` (>= 1)

    Returns a real array of shape ``(s*n0, s*n1, s*n2)``.
    """
    fine_shape = tuple(supersampling * n for n in mesh_shape)
    field_hat = jnp.fft.rfftn(ic_field)
    field_hat = _pad_full_freq_axis(field_hat, 0, mesh_shape[0], fine_shape[0])
    field_hat = _pad_full_freq_axis(field_hat, 1, mesh_shape[1], fine_shape[1])
    field_hat = _pad_half_freq_axis(field_hat, mesh_shape[2], fine_shape[2])
    # Pass ``s`` explicitly: without it, irfftn assumes an even last axis and
    # returns one cell too few when s*n2 is odd. irfftn normalises by the
    # number of fine cells, hence the rescaling by s**3.
    return jnp.fft.irfftn(field_hat, s=fine_shape) * (supersampling**3)
@partial(
    jax.jit,
    static_argnums=(
        3,
        4,
        5,
        7,
        9,
    ),
)
def _run_leap_frog_simulation_internal(
    ic_field,
    omega_c,
    omega_b,
    want_vel,
    mesh_shape,
    mesh_shape_force,
    lf_coefficients,
    supersampling,
    a_start,
    sharding=None,
    halo_size=0,
):
    """Run a jitted particle-mesh leapfrog simulation and return particle positions.

    ic_field: initial condition mesh grid at epoch ``a_start``, in the form of
        the density contrast
    omega_c: cold dark matter density
    omega_b: baryon density
    want_vel: static; currently unused, final velocities are not returned
    mesh_shape: shape of the position mesh in Lagrangian coordinates (static)
    mesh_shape_force: shape of the force mesh (static)
    lf_coefficients: ``LeapfrogCoef`` with one kick and one drift coefficient
        per leapfrog operation; the two arrays must have the same length
    supersampling: particle-lattice refinement factor; values > 1 supersample
        ``ic_field`` with ``supersample_ic_field`` first (static)
    a_start: epoch to start the PM simulation from
    sharding: optional JAX sharding for the particle arrays (static)
    halo_size: halo size forwarded to ``pm_forces``

    Returns the final particle positions, an array of shape ``(n_particles, 3)``
    expressed in the index units of the particle lattice.
    """
    cosmo = jc.Planck15(Omega_c=omega_c, Omega_b=omega_b)
    if supersampling > 1:
        ic_field = supersample_ic_field(ic_field, mesh_shape, supersampling)
        mesh_shape_particles = ic_field.shape
    else:
        mesh_shape_particles = mesh_shape
    # One particle at every node of the (possibly supersampled) Lagrangian lattice.
    particles = jnp.stack(
        jax.device_put(
            jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape_particles]),
            device=sharding,
        ),
        axis=-1,
    ).reshape([-1, 3])
    # First-order LPT initial state (presumably displacement and momentum) at a_start.
    dx, p = lpt(cosmo, ic_field, particles, a=a_start, order=1, sharding=sharding)
    pos = particles + dx
    vel = p
    # Resolution ratio between the particle lattice and the force mesh.
    # Only axis 0 is used, which assumes cubic meshes.
    multiplier = mesh_shape_particles[0] / mesh_shape_force[0]

    @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
    def _evolver(state, istep):
        """Apply leapfrog operation ``istep``: an optional kick, then a drift."""
        step_pos, step_vel = state
        # DRIFT operations carry an exact 0.0 kick coefficient, so the costly
        # force evaluation is skipped for them. Compare with != 0 rather than
        # > 0 so that negative kicks (integration with a_final < a_start) are
        # not silently dropped while their drifts are still applied.
        step_vel = jax.lax.cond(
            lf_coefficients.kick_coeff[istep] != 0,
            jax.remat(
                lambda: step_vel
                + lf_coefficients.kick_coeff[istep]
                * pm_forces(
                    step_pos,
                    mesh_shape=mesh_shape_force,
                    pos_mesh_shape=mesh_shape_particles,
                    fd=True,
                    sharding=sharding,
                    halo_size=halo_size,
                )
                * 1.5
                * cosmo.Omega_m
                * multiplier
            ),
            jax.remat(lambda: step_vel),
        )
        # KICK operations carry a 0.0 drift coefficient, so this is a no-op for them.
        step_pos = step_pos + lf_coefficients.drift_coeff[istep] * step_vel
        return (step_pos, step_vel), None

    (pos, vel), _ = jax.lax.scan(
        _evolver, (pos, vel), jnp.arange(lf_coefficients.drift_coeff.shape[0])
    )
    return pos
def prepare_leap_frog_simulation(
    omega_c,
    omega_b,
    num_steps,
    a_start,
    a_final,
    supersampling=1,
    log_step=LOG_STEP_DEFAULT,
):
    """Precompute leapfrog coefficients and return a ready-to-run simulator.

    omega_c: cold dark matter density
    omega_b: baryon density
    num_steps: number of drift-kick-drift leapfrog steps (must be >= 1)
    a_start: epoch to start the PM simulation from
    a_final: epoch at which the simulation ends
    supersampling: particle-lattice refinement factor forwarded to the simulator
    log_step: step uniformly in ln(a) if true, uniformly in a otherwise;
        defaults to LOG_STEP_DEFAULT

    Returns a ``functools.partial`` of ``_run_leap_frog_simulation_internal``
    with ``lf_coefficients``, ``omega_c``, ``omega_b``, ``a_start`` and
    ``supersampling`` already bound. The caller supplies ``ic_field``,
    ``want_vel``, ``mesh_shape``, ``mesh_shape_force`` and optionally
    ``sharding`` / ``halo_size``. Because ``omega_c`` and ``omega_b`` are
    bound by keyword, every argument after ``ic_field`` must be passed by
    keyword.

    Raises ValueError if ``num_steps < 1``.
    """
    if num_steps < 1:
        # Otherwise the step size is a division by zero and the failure only
        # surfaces later as an opaque assertion in the coefficient integration.
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    lf_coefs = _precompute_leapfrog_coefs(
        jc.Planck15(Omega_c=omega_c, Omega_b=omega_b),
        a_start,
        a_final,
        num_steps,
        _make_DKD(num_steps),
        log_step=log_step,
    )
    return partial(
        _run_leap_frog_simulation_internal,
        lf_coefficients=lf_coefs,
        omega_c=omega_c,
        omega_b=omega_b,
        a_start=a_start,
        supersampling=supersampling,
    )