# SPDX-FileCopyrightText: 2025 CNRS
# SPDX-FileContributor: Svyatoslav Trusov
# SPDX-FileContributor: Guilhem Lavaux
#
# SPDX-License-Identifier: CECILL-B
from collections import namedtuple
from enum import Enum
from functools import partial
import jax
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from jaxpm.distributed import fft3d
from ._src.constants import FLOAT_DTYPE
from ._src.pm import pm_forces, pm_forces_in
# Per-step coefficient table consumed by ``_run_bullfrog_simulation_internal``.
# Fields (per-step arrays of equal length unless noted):
#   drift_coeff      -- factor multiplying the velocity that is added to positions;
#   kick_coeff       -- factor multiplying the PM force during a kick;
#   kick_coeff_alpha -- factor multiplying the previous velocity during a kick;
#   a_s              -- 1 for kick steps, 0 for drift steps;
#   rescale          -- scalar, 1 / (last tabulated growth factor D), applied to
#                       the initial velocities.
# The typename must equal the binding name: pickle (and therefore copy /
# multiprocessing) looks the class up by its ``__qualname__`` in its module.
BullfrogCoef = namedtuple(
    "BullfrogCoef",
    ["drift_coeff", "kick_coeff", "kick_coeff_alpha", "a_s", "rescale"],
)
def _get_growth_ODE_tabs(cosmo, log10_amin=-4, steps=4096, eps=1e-4, quantity="D"):
    """Tabulate first- and second-order growth factors by integrating their ODEs.

    The coupled equations for the linear growth factor ``D`` and the
    second-order (2LPT) growth factor ``E`` are integrated in scale factor
    ``a`` on a log-spaced grid from ``10**log10_amin`` to ``a = 1`` with
    ``jax_cosmo``'s ``odeint``. No normalisation is applied afterwards: the
    solutions start from ``D = a``, ``E = -3/7 a**2`` at the first grid point.

    Args:
        cosmo: ``jax_cosmo`` cosmology providing ``Omega_m_a``, ``Omega_de_a``
            and ``w``.
        log10_amin: log10 of the first scale factor of the grid.
        steps: number of grid points.
        eps: accepted but unused.
        quantity: accepted but unused.

    Returns:
        List ``[a, D, E, dD/da, dE/da]`` of 1-D arrays of length ``steps``.
    """
    from jax_cosmo.background import Omega_de_a, Omega_m_a, w
    from jax_cosmo.scipy.ode import odeint

    a_grid = np.logspace(log10_amin, 0.0, steps)

    def _rhs(state, a):
        # state layout: [[D, E], [dD/da, dE/da]]
        (d1, d2), (d1_a, d2_a) = state
        omega_m = Omega_m_a(cosmo, a)
        friction = (
            2.0 - 0.5 * (omega_m + (1.0 + 3.0 * w(cosmo, a)) * Omega_de_a(cosmo, a))
        ) / a
        source = 1.5 * omega_m / a / a
        d1_aa = -friction * d1_a + source * d1
        # The second-order equation is sourced by -D**2.
        d2_aa = -friction * d2_a + source * d2 - source * d1**2
        return jnp.array([[d1_a, d2_a], [d1_aa, d2_aa]])

    a0 = a_grid[0]
    initial_state = np.array([[a0, -3.0 / 7 * a0**2], [1.0, -6.0 / 7 * a0]])
    solution = odeint(_rhs, initial_state, a_grid)

    # solution[:, 0, :] holds (D, E); solution[:, 1, :] holds (dD/da, dE/da).
    return [
        a_grid,
        solution[:, 0, 0],
        solution[:, 0, 1],
        solution[:, 1, 0],
        solution[:, 1, 1],
    ]
# Operation types of an integration programme: KICK (velocity update) and DRIFT (position update).
class BullfrogStep(Enum):
    """Kind of operation in a Bullfrog integration programme.

    ``KICK`` updates particle velocities from the PM force; ``DRIFT`` moves
    particle positions along the current velocity.
    """

    KICK = 0  # velocity update
    DRIFT = 1  # position update
def _make_DKD(nstep):
    """Build the drift/kick programme that follows the hard-coded first step.

    ``_precompute_bullfrog_coefs`` emits the opening half drift and the first
    kick itself. This programme finishes that step with a half drift, then, for
    each of the remaining ``nstep - 1`` steps, appends a drift-kick-drift
    sequence: half drift, full kick, half drift.

    Args:
        nstep: total number of integration steps.

    Returns:
        List of ``(BullfrogStep, fraction)`` tuples, where ``fraction`` is the
        length of the operation in units of the uniform growth-factor increment.
    """
    half_drift = (BullfrogStep.DRIFT, 0.5)
    full_kick = (BullfrogStep.KICK, 1.0)
    dkd_steps = [
        op
        for _ in range(nstep - 1)
        for op in (half_drift, full_kick, half_drift)
    ]
    return [half_drift, *dkd_steps]
def _EFunc(tabs, D):
# atab,gtab,g2tab,ftab,f2tab
# interpolate E at D
return jnp.interp(D, tabs[1], tabs[2])
def _EPrimeFunc(tabs, D):
# interpolate E'/D' at D
return jnp.interp(D, tabs[1], tabs[4] / tabs[3])
# Alpha beta coefficients for Bullfrog
def _get_alpha_beta(tabs, prev_D, D, DeltaD):
    """Return the Bullfrog kick coefficients ``(alpha, beta)`` for one kick.

    The caller applies the kick as
    ``v_new = alpha * v_old + beta / D_nh * force(x_nh)``, with
    ``D_nh = prev_D + DeltaD / 2`` the growth factor of the positions when the
    force is evaluated (mid-point of a drift-kick-drift step). The coefficients
    are chosen so that a 2LPT trajectory ``x = q + D Psi1 + E Psi2`` with
    velocity ``dx/dD = Psi1 + E'(D) Psi2`` is reproduced:

    * ``alpha + beta = 1`` matches the first-order term;
    * ``alpha * E'(D_n) + beta * F_nh = E'(D_{n+1})`` matches the second-order
      term, giving ``alpha = (E'(D_{n+1}) - F_nh) / (E'(D_n) - F_nh)``;

    with ``F_nh = E_drifted / D_nh - D_nh``. Here
    ``E_drifted = E(D_n) + E'(D_n) * DeltaD / 2`` is the second-order
    coefficient of the positions after drifting ``DeltaD / 2`` from ``D_n``.
    E and E' must therefore be evaluated at the pre-kick time ``prev_D``, not
    at ``D``. With ``prev_D = 0`` this reduces (up to the tiny values of E and
    E' at the start of the table) to the hard-coded first kick in
    ``_precompute_bullfrog_coefs``: ``alpha = 1 + E'(dD) / (dD / 2)``.

    Args:
        tabs: ``[a, D, E, dD/da, dE/da]`` growth tables.
        prev_D: growth factor D_n before the kick.
        D: growth factor D_{n+1} after the kick.
        DeltaD: growth-factor step length.

    Returns:
        Tuple ``(alpha, beta)`` with ``beta = 1 - alpha``.
    """
    D_nh = prev_D + DeltaD / 2
    E_prime_prev = _EPrimeFunc(tabs, prev_D)
    E_prime_next = _EPrimeFunc(tabs, D)
    # Second-order coefficient of the positions after the half drift from D_n.
    E_drifted = _EFunc(tabs, prev_D) + E_prime_prev * DeltaD / 2
    F_nh = E_drifted / D_nh - D_nh
    alpha = (E_prime_next - F_nh) / (E_prime_prev - F_nh)
    beta = 1 - alpha
    return alpha, beta
# def Omega_m(a, Omega_m0=0.3, Omega_lambda=0.7):
# numerator = Omega_m0 / a**3
# denominator = numerator + Omega_lambda
# return numerator / denominator
def _precompute_bullfrog_coefs(
    cosmo, a_start, a_final, nsteps, LF_programme, mult_fix=1
):
    """Turn a drift/kick programme into per-step Bullfrog coefficient arrays.

    Time is measured in the linear growth factor D, split uniformly into
    ``nsteps`` increments ``dD`` from D = 0 up to ``D(a_final)``. Two rows are
    always emitted first: a half drift from D = 0 and the first kick to
    D = dD, which uses a closed-form alpha. ``LF_programme`` then supplies the
    remaining operations. Drifts and kicks each keep their own D clock.

    Args:
        cosmo: ``jax_cosmo`` cosmology used to tabulate growth factors.
        a_start: accepted but unused; the integration always starts at D = 0.
        a_final: final scale factor.
        nsteps: number of growth-factor increments.
        LF_programme: iterable of ``(BullfrogStep, fraction)`` pairs, e.g. from
            ``_make_DKD``.
        mult_fix: extra factor applied to the force coefficient of every kick
            taken from ``LF_programme`` (not to the hard-coded first kick).

    Returns:
        ``BullfrogCoef`` with one entry per emitted step.
    """
    tabs = _get_growth_ODE_tabs(cosmo)

    # The starting growth factor is overridden to 0 (a_start is not used).
    D_start = 0
    D_final = jnp.interp(a_final, tabs[0], tabs[1])
    dD = (D_final - D_start) / (nsteps)

    # One row per step: (drift_coeff, kick_coeff, kick_coeff_alpha, a_s).
    rows = [(0.5 * dD, 0.0, 0.0, 0)]

    alpha_first = 1 + _EPrimeFunc(tabs, D_start + dD) / (0.5 * dD)
    beta_first = 1 - alpha_first
    rows.append((0.0, beta_first / (0.5 * dD), alpha_first, 1))

    # Independent D clocks for positions (drifts) and velocities (kicks).
    clock = {
        BullfrogStep.DRIFT: D_start + 0.5 * dD,
        BullfrogStep.KICK: D_start + dD,
    }

    for op, dt in LF_programme:
        start = clock[op]
        end = start + dt * dD
        if op == BullfrogStep.KICK:
            alpha, beta = _get_alpha_beta(tabs, start, end, dD)
            rows.append((0.0, beta / (start + dD * 0.5) * mult_fix, alpha, 1))
        elif op == BullfrogStep.DRIFT:
            rows.append((dt * dD, 0.0, 0.0, 0))
        clock[op] = end

    drift_col, kick_col, alpha_col, flag_col = zip(*rows)
    return BullfrogCoef(
        rescale=1 / tabs[1][-1],
        kick_coeff=jnp.array(kick_col),
        drift_coeff=jnp.array(drift_col),
        kick_coeff_alpha=jnp.array(alpha_col),
        a_s=jnp.array(flag_col),
    )
# Fourier-space supersampling of the initial-condition field (used when supersampling > 1).
def supersample_ic_field(ic_field, mesh_shape, supersampling):
    """Upsample a real 3-D field onto a grid ``supersampling`` times finer.

    The field is interpolated band-limitedly: its real FFT is zero-padded to
    the finer grid and transformed back. The result agrees with ``ic_field``
    on the original grid points (``out[::s, ::s, ::s] == ic_field`` up to
    round-off), for both even and odd axis lengths.

    Nyquist modes of even-length axes are halved before copying. On the two
    full-complex axes they are then placed in both the positive and the
    negative frequency slot. On the half-complex last axis they become an
    ordinary (Hermitian-doubled) mode of the finer grid. Copying them whole
    would double their real-space amplitude.

    Args:
        ic_field: real array of shape ``mesh_shape``.
        mesh_shape: ``(n0, n1, n2)`` shape of ``ic_field``.
        supersampling: positive integer upsampling factor ``s``.

    Returns:
        Real array of shape ``(s * n0, s * n1, s * n2)``. For ``s == 1`` the
        input is returned unchanged (as a JAX array).

    Raises:
        ValueError: if ``supersampling < 1``.
    """
    if supersampling < 1:
        raise ValueError(f"supersampling must be >= 1, got {supersampling}")
    if supersampling == 1:
        return jnp.asarray(ic_field)

    n0, n1, n2 = (int(n) for n in mesh_shape)
    out_shape = (supersampling * n0, supersampling * n1, supersampling * n2)

    ic_field_hat = jnp.fft.rfftn(ic_field)

    # Per-axis weights halving the Nyquist mode of even-length axes.
    w0 = np.ones(n0)
    w1 = np.ones(n1)
    w2 = np.ones(n2 // 2 + 1)
    if n0 % 2 == 0:
        w0[n0 // 2] = 0.5
    if n1 % 2 == 0:
        w1[n1 // 2] = 0.5
    if n2 % 2 == 0:
        w2[-1] = 0.5
    weights = w0[:, None, None] * w1[None, :, None] * w2[None, None, :]
    # Cast to the FFT's real dtype so the output precision is not promoted.
    ic_field_hat = ic_field_hat * jnp.asarray(weights, dtype=ic_field_hat.real.dtype)

    def _axis_pairs(n, m):
        # (source, destination) slice pairs for one full-complex FFT axis of
        # length n being padded to length m.
        n_pos = n // 2 + 1  # DC, positive frequencies and (even n) Nyquist
        n_neg = n // 2  # negative frequencies (even n: Nyquist again)
        pairs = [(slice(0, n_pos), slice(0, n_pos))]
        if n_neg:
            pairs.append((slice(n - n_neg, n), slice(m - n_neg, m)))
        return pairs

    half2 = n2 // 2 + 1
    resampled_hat = jnp.zeros(
        (out_shape[0], out_shape[1], out_shape[2] // 2 + 1),
        dtype=ic_field_hat.dtype,
    )
    # For s >= 2 the destination slots are disjoint, so plain assignment works.
    for src0, dst0 in _axis_pairs(n0, out_shape[0]):
        for src1, dst1 in _axis_pairs(n1, out_shape[1]):
            resampled_hat = resampled_hat.at[dst0, dst1, :half2].set(
                ic_field_hat[src0, src1, :]
            )

    # irfftn normalises by prod(out_shape); multiplying by s**3 restores the
    # input amplitude. Passing ``s`` also gives the right last-axis length
    # when s * n2 is odd.
    return jnp.fft.irfftn(resampled_hat, s=out_shape) * (supersampling**3)
@partial(
    jax.jit,
    static_argnums=(3, 4, 5, 7, 8, 9),
)
def _run_bullfrog_simulation_internal(
    ic_field,
    omega_c,
    omega_b,
    want_vel,
    mesh_shape,
    mesh_shape_force,
    lf_coefficients,
    supersampling,
    sharding=None,
    fd=False,
):
    """Run the Bullfrog particle-mesh integration and return final positions.

    Args:
        ic_field: initial-condition mesh given as a density contrast.
        omega_c: cold dark matter density; unused here (the cosmology is
            already folded into ``lf_coefficients``).
        omega_b: baryon density; unused here, see ``omega_c``.
        want_vel: static; currently ignored, only positions are returned.
        mesh_shape: static shape of the Lagrangian particle grid before
            supersampling.
        mesh_shape_force: static shape of the mesh used for PM forces.
        lf_coefficients: ``BullfrogCoef`` table, e.g. from
            ``_precompute_bullfrog_coefs``.
        supersampling: static integer; if > 1 the IC field is Fourier
            upsampled and one particle is placed per cell of the finer grid.
        sharding: static sharding forwarded to the PM force routines.
        fd: static flag forwarded to the PM force routines.

    Returns:
        Array of shape ``(n_particles, 3)`` with the final particle positions.
        Positions are initialised on the integer particle grid.
    """
    if supersampling > 1:
        ic_field = supersample_ic_field(ic_field, mesh_shape, supersampling)
        mesh_shape_particles = ic_field.shape
    else:
        mesh_shape_particles = mesh_shape

    # One particle per cell of the (possibly supersampled) Lagrangian grid.
    particles = jnp.stack(
        jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape_particles]), axis=-1
    ).reshape([-1, 3])
    pos = jnp.astype(particles, FLOAT_DTYPE)

    # Initial velocity: PM force of the IC density, scaled by the table's rescale.
    vel = (
        pm_forces_in(particles, delta=fft3d(ic_field), fd=fd, sharding=sharding)
        * lf_coefficients.rescale
    )

    # Particle-grid to force-mesh resolution ratio, applied to every kick.
    multiplier = mesh_shape_particles[0] / mesh_shape_force[0]

    @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
    def _evolver(state, istep):
        step_pos, step_vel = state
        # Kick only on steps flagged as kicks (a_s == 1). Relying on the sign
        # of kick_coeff would silently skip genuine kicks, including their
        # alpha rescaling, whenever that coefficient is <= 0 (e.g. mult_fix <= 0).
        step_vel = jax.lax.cond(
            lf_coefficients.a_s[istep] > 0,
            jax.remat(
                lambda: lf_coefficients.kick_coeff_alpha[istep] * step_vel
                + lf_coefficients.kick_coeff[istep]
                * pm_forces(
                    step_pos,
                    mesh_shape=mesh_shape_force,
                    pos_mesh_shape=mesh_shape_particles,
                    fd=fd,
                    sharding=sharding,
                )
                * multiplier
            ),
            jax.remat(lambda: step_vel),
        )
        # Drift; drift_coeff is 0 on kick steps.
        step_pos = step_pos + lf_coefficients.drift_coeff[istep] * step_vel
        return (step_pos, step_vel), None

    (pos, vel), _ = jax.lax.scan(
        _evolver, (pos, vel), jnp.arange(lf_coefficients.drift_coeff.shape[0])
    )
    return pos
def prepare_bullfrog_simulation(
    omega_c,
    omega_b,
    num_steps,
    a_start,
    a_final,
    supersampling=1,
    mult_fix=1,
    fd=False,
):
    """Precompute Bullfrog coefficients and return a ready-to-run simulator.

    The cosmology is ``jax_cosmo``'s Planck15 with ``Omega_c`` and ``Omega_b``
    overridden. Its growth tables drive the drift/kick coefficients of a
    ``num_steps``-step drift-kick-drift programme.

    Args:
        omega_c: cold dark matter density parameter.
        omega_b: baryon density parameter.
        num_steps: number of integration steps.
        a_start: forwarded to the coefficient precomputation, which currently
            always starts the integration from growth factor D = 0.
        a_final: final scale factor.
        supersampling: integer particle-grid upsampling factor.
        mult_fix: extra factor on the force coefficient of programme kicks.
        fd: flag forwarded to the PM force routines.

    Returns:
        ``functools.partial`` of ``_run_bullfrog_simulation_internal`` with
        ``lf_coefficients``, ``omega_c``, ``omega_b``, ``supersampling`` and
        ``fd`` bound. Call it with ``ic_field`` first and the remaining
        arguments (``want_vel``, ``mesh_shape``, ``mesh_shape_force``,
        optionally ``sharding``) by keyword, since ``omega_c`` is already
        bound.
    """
    cosmo = jc.Planck15(Omega_c=omega_c, Omega_b=omega_b)
    lf_coefs = _precompute_bullfrog_coefs(
        cosmo,
        a_start,
        a_final,
        num_steps,
        _make_DKD(num_steps),
        mult_fix=mult_fix,
    )
    return partial(
        _run_bullfrog_simulation_internal,
        lf_coefficients=lf_coefs,
        omega_c=omega_c,
        omega_b=omega_b,
        supersampling=supersampling,
        fd=fd,
    )