# Source code for borgjaxpm._simulation_bullfrog

# SPDX-FileCopyrightText: 2025 CNRS
# SPDX-FileContributor: Svyatoslav Trusov
# SPDX-FileContributor: Guilhem Lavaux
#
# SPDX-License-Identifier: CECILL-B

from collections import namedtuple
from enum import Enum
from functools import partial

import jax
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from jaxpm.distributed import fft3d

from ._src.constants import FLOAT_DTYPE
from ._src.pm import pm_forces, pm_forces_in

# Precomputed per-sub-step coefficients of the integrator programme.
#   drift_coeff      : factor applied to the velocity in the position update
#   kick_coeff       : factor applied to the PM force in the velocity update
#   kick_coeff_alpha : factor applied to the previous velocity in the velocity update
#   a_s              : 1 for a kick sub-step, 0 for a drift sub-step
#   rescale          : 1 / D(a=1), applied to the initial velocity
# The typename matches the bound name so that repr() and pickling resolve
# to this class.
BullfrogCoef = namedtuple(
    "BullfrogCoef", ["drift_coeff", "kick_coeff", "kick_coeff_alpha", "a_s", "rescale"]
)


def _get_growth_ODE_tabs(cosmo, log10_amin=-4, steps=4096, eps=1e-4, quantity="D"):
    """Tabulate the first- and second-order growth functions against ``a``.

    The coupled growth ODEs are integrated with ``jax_cosmo``'s ``odeint`` on
    a log-spaced grid of scale factors from ``10**log10_amin`` to 1.

    Parameters
    ----------
    cosmo
        Cosmology object handed to the ``jax_cosmo.background`` functions.
    log10_amin
        Base-10 logarithm of the first scale factor of the grid.
    steps
        Number of grid points.
    eps, quantity
        Accepted but not used by this function.

    Returns
    -------
    list
        ``[a, D, E, D', E']`` where primes are derivatives with respect to
        ``a`` and ``E`` is the 2LPT coefficient (psi = D Psi_0 + E Psi_2).
        The tables are not normalised.
    """
    from jax_cosmo.background import Omega_de_a, Omega_m_a, w
    from jax_cosmo.scipy.ode import odeint

    # Integration grid in scale factor.
    scale_factors = np.logspace(log10_amin, 0.0, steps)
    a_init = scale_factors[0]

    def growth_rhs(state, a):
        # state[0] holds (D, E); state[1] holds their derivatives w.r.t. a.
        friction = (
            2.0
            - 0.5
            * (Omega_m_a(cosmo, a) + (1.0 + 3.0 * w(cosmo, a)) * Omega_de_a(cosmo, a))
        ) / a
        source = 1.5 * Omega_m_a(cosmo, a) / a / a
        d1, d2 = state[0]
        d1_da, d2_da = state[1]
        return jnp.array(
            [
                [d1_da, d2_da],
                [
                    -friction * d1_da + source * d1,
                    -friction * d2_da + source * d2 - source * d1**2,
                ],
            ]
        )

    # Initial conditions at the first grid point: D = a, E = -3/7 a^2.
    initial_state = np.array(
        [[a_init, -3.0 / 7 * a_init**2], [1.0, -6.0 / 7 * a_init]]
    )
    solution = odeint(growth_rhs, initial_state, scale_factors)

    # [ a, D, E, D', E' ]
    return [
        scale_factors,
        solution[:, 0, 0],
        solution[:, 0, 1],
        solution[:, 1, 0],
        solution[:, 1, 1],
    ]


# make an enum class to represent KICK and DRIFT
class BullfrogStep(Enum):
    """Kind of sub-step appearing in an integrator programme."""

    # Velocity update (the sub-step that gets non-zero kick coefficients).
    KICK = 0
    # Position update (the sub-step that gets a non-zero drift coefficient).
    DRIFT = 1


def _make_DKD(nstep):
    """Build the drift-kick-drift programme that follows the first kick.

    The programme is a list of ``(BullfrogStep, fraction)`` pairs, where
    ``fraction`` is the share of a full step covered by that sub-step. It
    opens with one half drift, followed by ``nstep - 1`` repetitions of
    half drift, full kick, half drift.
    """
    half_drift = (BullfrogStep.DRIFT, 0.5)
    full_kick = (BullfrogStep.KICK, 1.0)

    # Closing half drift of the first step: x and v are then both at step 1.
    programme = [half_drift]
    for _ in range(nstep - 1):
        # x at n+1/2, v at n+1, x at n+1
        programme.extend((half_drift, full_kick, half_drift))
    return programme


def _EFunc(tabs, D):
    """Linearly interpolate the second-order growth ``E`` at growth factor ``D``.

    ``tabs`` is the ``[a, D, E, D', E']`` table list; only entries 1 and 2
    are read.
    """
    growth_d = tabs[1]
    growth_e = tabs[2]
    return jnp.interp(D, growth_d, growth_e)


def _EPrimeFunc(tabs, D):
    """Linearly interpolate ``E'/D'`` (i.e. dE/dD) at growth factor ``D``.

    ``tabs`` is the ``[a, D, E, D', E']`` table list; entries 1, 3 and 4
    are read.
    """
    growth_d = tabs[1]
    de_dd = tabs[4] / tabs[3]
    return jnp.interp(D, growth_d, de_dd)


# Alpha beta coefficients for Bullfrog
def _get_alpha_beta(tabs, prev_D, D, DeltaD):
    """Return the Bullfrog kick weights ``(alpha, beta)`` for one step.

    Parameters
    ----------
    tabs
        ``[a, D, E, D', E']`` growth tables.
    prev_D
        Growth factor at the start of the step.
    D
        Growth factor at the end of the step.
    DeltaD
        Step size in growth factor.

    Returns
    -------
    tuple
        ``(alpha, beta)`` with ``beta = 1 - alpha``.
    """
    # Growth factor at the midpoint of the step.
    d_mid = prev_D + DeltaD / 2
    e_prime_end = _EPrimeFunc(tabs, D)
    # NOTE(review): E and E' in the bracket are evaluated at the end of the
    # step (D), not at prev_D -- confirm this is the intended convention.
    f_mid = 1 / d_mid * (_EFunc(tabs, D) + e_prime_end * DeltaD / 2) - d_mid
    alpha = (e_prime_end - f_mid) / (_EPrimeFunc(tabs, prev_D) - f_mid)
    return alpha, 1 - alpha


# def Omega_m(a, Omega_m0=0.3, Omega_lambda=0.7):
#     numerator = Omega_m0 / a**3
#     denominator = numerator + Omega_lambda
#     return numerator / denominator


def _precompute_bullfrog_coefs(
    cosmo, a_start, a_final, nsteps, LF_programme, mult_fix=1
):
    """Precompute the kick and drift coefficients of a whole integration.

    Time is measured in growth factor D. The interval from D = 0 to
    D(a_final) is split into ``nsteps`` equal steps. The first two sub-steps
    (a half drift and the first kick) are hard-wired here; the remaining ones
    are read from ``LF_programme``.

    Parameters
    ----------
    cosmo
        Cosmology used to tabulate the growth functions.
    a_start
        Accepted but not used: the integration always starts at D = 0.
    a_final
        Final scale factor.
    nsteps
        Number of full steps.
    LF_programme
        Iterable of ``(BullfrogStep, fraction)`` pairs.
    mult_fix
        Extra factor applied to the kick coefficients taken from the programme
        (the hard-wired first kick is not scaled).

    Returns
    -------
    BullfrogCoef
        One array entry per sub-step, plus ``rescale = 1 / D(a=1)``.
    """
    # Solve the growth ODEs once.
    tabs = _get_growth_ODE_tabs(cosmo)
    growth_start = 0
    growth_final = jnp.interp(a_final, tabs[0], tabs[1])

    step = (growth_final - growth_start) / (nsteps)
    half_step = 0.5 * step

    # One row per sub-step: (drift, kick, alpha, is_kick).
    schedule = []
    # Kicks and drifts each keep their own clock in growth factor.
    clock = {BullfrogStep.KICK: growth_start, BullfrogStep.DRIFT: growth_start}

    # Opening half drift.
    schedule.append((half_step, 0.0, 0.0, 0))
    clock[BullfrogStep.DRIFT] = growth_start + half_step

    # First kick, with its dedicated alpha.
    first_alpha = 1 + _EPrimeFunc(tabs, growth_start + step) / half_step
    first_beta = 1 - first_alpha
    schedule.append((0.0, first_beta / half_step, first_alpha, 1))
    clock[BullfrogStep.KICK] = growth_start + step

    for op, dt in LF_programme:
        growth_before = clock[op]
        growth_after = growth_before + dt * step
        if op == BullfrogStep.KICK:
            alpha, beta = _get_alpha_beta(tabs, growth_before, growth_after, step)
            # beta is divided by the mid-step growth factor for correct scaling.
            kick = beta / (growth_before + half_step) * mult_fix
            schedule.append((0.0, kick, alpha, 1))
        elif op == BullfrogStep.DRIFT:
            # A drift is just a drift.
            schedule.append((dt * step, 0.0, 0.0, 0))
        # Advance the clock of this sub-step type.
        clock[op] = growth_after

    drift_coeff, kick_coeff, kick_coeff_alpha, a_s = (
        list(column) for column in zip(*schedule)
    )

    return BullfrogCoef(
        rescale=1 / tabs[1][-1],
        kick_coeff=jnp.array(kick_coeff),
        drift_coeff=jnp.array(drift_coeff),
        kick_coeff_alpha=jnp.array(kick_coeff_alpha),
        a_s=jnp.array(a_s),
    )


# Separate function for supersampling
def supersample_ic_field(ic_field, mesh_shape, supersampling):
    """Upsample a real 3D field by zero-padding its Fourier transform.

    Parameters
    ----------
    ic_field
        Real field; its shape is expected to equal ``mesh_shape``.
    mesh_shape
        Three-element shape of ``ic_field``.
    supersampling
        Integer factor by which every axis is enlarged.

    Returns
    -------
    Real field on the enlarged grid, multiplied by ``supersampling**3`` to
    compensate for the normalisation of the larger inverse transform.
    """
    n0, n1, n2 = mesh_shape[0], mesh_shape[1], mesh_shape[2]
    coarse_hat = jnp.fft.rfftn(ic_field)

    fine_hat = jnp.zeros(
        (supersampling * n0, supersampling * n1, supersampling * n2 // 2 + 1),
        dtype=coarse_hat.dtype,
    )

    def _bands(n):
        # (source rows, destination start) for the low- and high-index halves
        # of a full-FFT axis. Index n // 2 belongs to both halves.
        return (
            (slice(0, n // 2 + 1), 0),
            (slice(n // 2, n), supersampling * n - (n - n // 2)),
        )

    # The half-spectrum axis is copied unchanged at the start of the new axis.
    kept_last = slice(0, n2 // 2 + 1)

    # Copy the four (axis 0, axis 1) corner blocks of the coarse spectrum into
    # the matching corners of the fine spectrum.
    for rows1, start1 in _bands(n1):
        for rows0, start0 in _bands(n0):
            fine_hat = jax.lax.dynamic_update_slice(
                fine_hat, coarse_hat[rows0, rows1, kept_last], (start0, start1, 0)
            )

    return jnp.fft.irfftn(fine_hat) * (supersampling**3)


# Static (compile-time) arguments by position: want_vel, mesh_shape,
# mesh_shape_force, supersampling, sharding, fd.
@partial(
    jax.jit,
    static_argnums=(3, 4, 5, 7, 8, 9),
)
def _run_bullfrog_simulation_internal(
    ic_field,
    omega_c,
    omega_b,
    want_vel,
    mesh_shape,
    mesh_shape_force,
    lf_coefficients,
    supersampling,
    sharding=None,
    fd=False,
):
    """
    Run the particle-mesh integration and return the final particle positions.

    ic_field: initial condition mesh grid, in the form of the density contrast
    omega_c: cold dark matter density (not used in this function body)
    omega_b: baryon density (not used in this function body)
    want_vel: not used in this function body; only positions are returned
    mesh_shape: shape of the position mesh in lagrangian coordinates
    mesh_shape_force: shape of the force mesh
    lf_coefficients: bullfrog coefficients for fast integration (BullfrogCoef).
    supersampling: when > 1, ic_field is upsampled by this factor and the
        particle grid takes the upsampled shape
    sharding: forwarded to pm_forces_in and pm_forces
    fd: forwarded to pm_forces_in and pm_forces

    Returns the positions after all sub-steps, one row of three coordinates
    per particle.
    """

    if supersampling > 1:
        ic_field = supersample_ic_field(ic_field, mesh_shape, supersampling)
        mesh_shape_particles = ic_field.shape

        # ic_field = jnp.fft.irfftn(ic_field_hat_new) * supersampling**3
    else:
        mesh_shape_particles = mesh_shape

    # One particle per cell of the particle grid, starting at integer grid
    # coordinates.
    particles = jnp.stack(
        jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape_particles]), axis=-1
    ).reshape([-1, 3])

    pos = jnp.astype(particles, FLOAT_DTYPE)
    # Just taking the forces from IC.
    vel = (
        pm_forces_in(particles, delta=fft3d(ic_field), fd=fd, sharding=sharding)
        * lf_coefficients.rescale
    )
    # Ratio of particle-grid size to force-grid size along the first axis;
    # it scales the PM force in every kick.
    multiplier = mesh_shape_particles[0] / mesh_shape_force[0]

    # One sub-step of the programme: an optional kick followed by a drift.
    # Intermediate values are rematerialised rather than stored.
    @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
    def _evolver(state, istep):
        step_pos, step_vel = state
        # NOTE(review): the kick branch is selected by kick_coeff > 0, whereas
        # the unrolled reference loop below tests the a_s flag -- confirm the
        # two are equivalent for every coefficient set in use.
        step_vel = jax.lax.cond(
            lf_coefficients.kick_coeff[istep] > 0,
            jax.remat(
                lambda: lf_coefficients.kick_coeff_alpha[istep] * step_vel
                + lf_coefficients.kick_coeff[istep]
                * pm_forces(
                    step_pos,
                    mesh_shape=mesh_shape_force,
                    pos_mesh_shape=mesh_shape_particles,
                    fd=fd,  # True or False ?
                    sharding=sharding,
                )
                * multiplier
            ),
            jax.remat(lambda: step_vel),
        )
        # Drift; a zero drift coefficient leaves the positions unchanged.
        step_pos = step_pos + lf_coefficients.drift_coeff[istep] * step_vel
        return (step_pos, step_vel), None

    # Apply every sub-step in order.
    (pos, vel), _ = jax.lax.scan(
        _evolver, (pos, vel), jnp.arange(lf_coefficients.drift_coeff.shape[0])
    )

    # Unrolled reference version of the scan above, kept for comparison:
    # for istep in range(lf_coefficients.drift_coeff.shape[0]):
    #     # Slight differences due to D-time integration
    #     # kick
    #     vel = jax.lax.cond(
    #         lf_coefficients.a_s[istep] > 0.0,
    #         jax.remat(
    #             lambda: lf_coefficients.kick_coeff_alpha[istep] * vel
    #             + lf_coefficients.kick_coeff[istep]
    #             * pm_forces(
    #                 pos,
    #                 mesh_shape=mesh_shape_force,
    #                 pos_mesh_shape=mesh_shape_particles,
    #                 fd=True,
    #                 sharding=sharding,
    #             )
    #             * multiplier
    #         ),
    #         lambda: vel,
    #     )
    #     # drift
    #     pos += lf_coefficients.drift_coeff[istep] * vel

    return pos


# Lbox added for future adjustments and tests.
def prepare_bullfrog_simulation(
    omega_c,
    omega_b,
    num_steps,
    a_start,
    a_final,
    supersampling=1,
    mult_fix=1,
    fd=False,
):
    """Precompute the integrator coefficients and return a ready-to-call runner.

    Parameters
    ----------
    omega_c
        Cold dark matter density, passed to ``jc.Planck15`` as ``Omega_c``.
    omega_b
        Baryon density, passed to ``jc.Planck15`` as ``Omega_b``.
    num_steps
        Number of full integration steps.
    a_start
        Forwarded to ``_precompute_bullfrog_coefs``.
    a_final
        Final scale factor of the integration.
    supersampling
        Upsampling factor bound into the returned runner.
    mult_fix
        Extra factor on the kick coefficients, forwarded to
        ``_precompute_bullfrog_coefs``.
    fd
        Flag bound into the returned runner, which forwards it to the
        force computation.

    Returns
    -------
    functools.partial
        ``_run_bullfrog_simulation_internal`` with ``lf_coefficients``,
        ``omega_c``, ``omega_b``, ``supersampling`` and ``fd`` already bound.
        Because ``omega_c`` and ``omega_b`` are bound by keyword, every
        argument after ``ic_field`` (``want_vel``, ``mesh_shape``,
        ``mesh_shape_force``, ``sharding``) must be passed by keyword.
    """
    lf_coefs = _precompute_bullfrog_coefs(
        jc.Planck15(Omega_c=omega_c, Omega_b=omega_b),
        a_start,
        a_final,
        num_steps,
        _make_DKD(num_steps),
        mult_fix=mult_fix,
    )
    return partial(
        _run_bullfrog_simulation_internal,
        lf_coefficients=lf_coefs,
        omega_c=omega_c,
        omega_b=omega_b,
        supersampling=supersampling,
        fd=fd,
    )