# Source code for borgjaxpm._simulation_leapfrog

# SPDX-FileCopyrightText: 2025 CNRS
# SPDX-FileContributor: Guilhem Lavaux
# SPDX-FileContributor: Svyatoslav Trusov
#
# SPDX-License-Identifier: CECILL-B

from collections import namedtuple
from enum import Enum
from functools import partial

import aquila_borg as borg
import jax
import jax.numpy as jnp
import jax_cosmo as jc

from ._src.constants import FLOAT_DTYPE, LOG_STEP_DEFAULT
from ._src.pm import lpt, pm_forces

# Container for the precomputed leapfrog coefficients: ``drift_coeff`` and
# ``kick_coeff`` each hold one entry per operation of a leapfrog programme.
LeapfrogCoef = namedtuple("LeapfrogCoef", "drift_coeff kick_coeff")


# make an enum class to represent KICK and DRIFT
class LeapFrogStep(Enum):
    KICK = 0
    DRIFT = 1


def _make_DKD(nstep):
    """Build a drift-kick-drift (DKD) leapfrog programme.

    Args:
        nstep: number of leapfrog steps.

    Returns:
        A list of ``(LeapFrogStep, fraction)`` pairs, where ``fraction`` is
        expressed in units of one time step. The list contains a half drift,
        then ``nstep - 1`` pairs of (full kick, full drift), then a last full
        kick and a closing half drift.
    """
    opening = (LeapFrogStep.DRIFT, 0.5)
    closing = (LeapFrogStep.DRIFT, 0.5)
    # Interior steps alternate a full kick and a full drift; range() yields
    # nothing when nstep <= 1.
    interior = [
        op
        for _ in range(nstep - 1)
        for op in ((LeapFrogStep.KICK, 1.0), (LeapFrogStep.DRIFT, 1.0))
    ]
    return [opening, *interior, (LeapFrogStep.KICK, 1.0), closing]


def _precompute_leapfrog_coefs(
    cosmo, a_start, a_final, nsteps, LF_programme, log_step=LOG_STEP_DEFAULT
):
    """Integrate the kick and drift coefficients of a leapfrog programme.

    The time variable ``x`` is ``ln(a)`` when ``log_step`` is true and ``a``
    otherwise. The uniform step is
    ``dx = (x(a_final) - x(a_start)) / nsteps``. Each programme entry
    ``(op, dt)`` advances the clock of ``op`` by ``dt * dx``. Kick and drift
    clocks are tracked separately and both start at ``a_start``.

    The coefficient of an entry is the integral over ``ln(a)``, from the
    previous to the new clock value of that operation, of:

    - ``1 / (a**2 E(a))`` for a drift;
    - ``1 / (a E(a))`` for a kick;

    where ``E(a) = sqrt(jc.background.Esqr(cosmo, a))``. Integration uses
    ``quadax.quadgk``.

    Args:
        cosmo: jax_cosmo cosmology passed to ``jc.background.Esqr``.
        a_start: initial scale factor.
        a_final: final scale factor.
        nsteps: number of steps used to define ``dx``; must be positive.
        LF_programme: iterable of ``(LeapFrogStep, float)`` pairs, e.g. the
            output of ``_make_DKD``.
        log_step: step uniformly in ``ln(a)`` if true, otherwise in ``a``.

    Returns:
        LeapfrogCoef holding one ``kick_coeff`` and one ``drift_coeff`` entry
        per programme operation. The coefficient of the operation that is not
        performed at that entry is 0.

    Raises:
        ValueError: if ``nsteps`` is not positive, or if the programme does
            not end within 1e-3 (in ``x``) of ``a_final``.
        KeyError: if an operation other than KICK/DRIFT appears.
    """
    # Deferred import: quadax is only needed when coefficients are computed.
    from quadax import quadgk

    if nsteps <= 0:
        # Dividing by nsteps below would otherwise yield inf/nan steps.
        raise ValueError(f"nsteps must be a positive integer, got {nsteps!r}")

    if log_step:
        dx = jnp.log(a_final / a_start) / nsteps
        get_a = jnp.exp

        def get_ln_a(x):
            return x

        get_x = jnp.log
    else:
        dx = (a_final - a_start) / nsteps

        def get_a(x):
            return x

        get_ln_a = jnp.log

        def get_x(x):
            return x

    def _integrand_pos(ln_a):
        a = jnp.exp(ln_a)
        # Drift integrand expressed in ln(a): since da = a dln(a), an
        # integrand f(a) in `a` becomes a * f(a) in ln(a). This gives
        # 1 / (a**2 E(a)) here, i.e. f(a) = 1 / (a**3 E(a)).
        return jnp.array(
            1.0 / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))), dtype=FLOAT_DTYPE
        )

    def _integrand_vel(ln_a):
        a = jnp.exp(ln_a)
        # Kick integrand expressed in ln(a) (same change of variable). This
        # gives 1 / (a E(a)) here, i.e. f(a) = 1 / (a**2 E(a)).
        return jnp.array(
            1.0 / (a * jnp.sqrt(jc.background.Esqr(cosmo, a))), dtype=FLOAT_DTYPE
        )

    # a_start == a_final: every interval is empty, so skip the quadrature.
    zero_span = bool(dx == 0)

    drift_coeff = []
    kick_coeff = []
    # Current clock value (in x, not a) of each operation kind.
    prev_x_by_op = {
        LeapFrogStep.KICK: get_x(a_start),
        LeapFrogStep.DRIFT: get_x(a_start),
    }
    # Initialised so an empty programme reaches the end-time check below
    # instead of failing with NameError.
    last_x = get_x(a_start)
    for op, dt in LF_programme:
        prev_x = prev_x_by_op[op]
        x = prev_x + dt * dx
        ln_a = get_ln_a(x)
        ln_prev_a = get_ln_a(prev_x)
        a = get_a(x)
        borg.print_msg(borg.Level.std, f"op={op}, a={a}, dt={dt}")
        if op == LeapFrogStep.KICK:
            if zero_span:
                last = 0
            else:
                last = quadgk(_integrand_vel, [ln_prev_a, ln_a])[0]
            borg.print_msg(
                borg.Level.std,
                "KICK ln_prev_a={ln_prev_a} ln_a={ln_a} last={last}",
                last=last,
                ln_prev_a=ln_prev_a,
                ln_a=ln_a,
            )
            kick_coeff.append(last)
            drift_coeff.append(0.0)
        elif op == LeapFrogStep.DRIFT:
            if zero_span:
                last = 0
            else:
                last = quadgk(_integrand_pos, [ln_prev_a, ln_a])[0]
            borg.print_msg(borg.Level.std, "DRIFT last={last}", last=last)
            kick_coeff.append(0.0)
            drift_coeff.append(last)
        prev_x_by_op[op] = x
        last_x = x

    # Explicit check (not `assert`, which is stripped under -O). It is
    # written as `not (... < tol)` so a NaN mismatch is rejected as well.
    target_x = get_x(a_final)
    if not (jnp.abs(target_x - last_x) < 1e-3):
        raise ValueError(
            f"leapfrog programme ends at x={last_x}, expected x={target_x} "
            f"(a_final={a_final}); check nsteps against LF_programme"
        )

    return LeapfrogCoef(
        kick_coeff=jnp.array(kick_coeff),
        drift_coeff=jnp.array(drift_coeff),
    )


def _pad_full_fft_axis(hat, axis, n, new_n):
    """Zero-pad a two-sided (complex FFT) axis of length ``n`` to ``new_n``.

    Every coefficient keeps its signed wavenumber. For even ``n``, the
    Nyquist coefficient is split equally between ``+n/2`` and ``-n/2``, so
    the band-limited interpolant still passes through the original samples.
    """
    hat = jnp.moveaxis(hat, axis, 0)
    out = jnp.zeros((new_n,) + hat.shape[1:], dtype=hat.dtype)
    n_pos = (n + 1) // 2  # wavenumbers 0 .. ceil(n/2) - 1
    n_neg = (n - 1) // 2  # wavenumbers -floor((n-1)/2) .. -1
    out = out.at[:n_pos].set(hat[:n_pos])
    if n_neg:
        out = out.at[new_n - n_neg :].set(hat[n - n_neg :])
    if n % 2 == 0:
        # `.add` (not `.set`) so both halves recombine when new_n == n.
        nyquist = 0.5 * hat[n // 2]
        out = out.at[n // 2].add(nyquist)
        out = out.at[new_n - n // 2].add(nyquist)
    return jnp.moveaxis(out, 0, axis)


def _pad_rfft_axis(hat, axis, n, new_n):
    """Zero-pad the half-spectrum (rfft) axis of a real axis of length ``n``.

    The result has length ``new_n // 2 + 1``, i.e. the rfft length of a real
    axis of length ``new_n``.
    """
    hat = jnp.moveaxis(hat, axis, 0)
    out = jnp.zeros((new_n // 2 + 1,) + hat.shape[1:], dtype=hat.dtype)
    n_pos = (n + 1) // 2  # non-Nyquist wavenumbers 0 .. ceil(n/2) - 1
    out = out.at[:n_pos].set(hat[:n_pos])
    if n % 2 == 0:
        # On the coarse grid the Nyquist plane is its own conjugate. On a
        # finer grid it becomes an interior plane whose conjugate partner
        # (at -n/2) is implied by irfftn, so only half of it is stored.
        scale = 1.0 if new_n == n else 0.5
        out = out.at[n // 2].set(scale * hat[n // 2])
    return jnp.moveaxis(out, 0, axis)


def supersample_ic_field(ic_field, mesh_shape, supersampling):
    """Fourier-interpolate a real 3-D field onto a finer grid.

    The rfft of ``ic_field`` is zero-padded in Fourier space (band-limited
    interpolation) and transformed back. Every coarse sample is therefore
    reproduced on the fine grid at index ``supersampling * i``. Both even
    and odd mesh sizes are supported.

    Args:
        ic_field: real field; expected to have shape ``mesh_shape``.
        mesh_shape: ``(N0, N1, N2)`` shape of ``ic_field``.
        supersampling: positive integer refinement factor ``s``.

    Returns:
        Real array of shape ``(s*N0, s*N1, s*N2)``. It is multiplied by
        ``s**3`` to compensate irfftn's ``1/size`` normalisation on the
        larger grid, so field values are interpolated rather than diluted.
    """
    n0, n1, n2 = mesh_shape[0], mesh_shape[1], mesh_shape[2]
    new_shape = (supersampling * n0, supersampling * n1, supersampling * n2)

    ic_field_hat = jnp.fft.rfftn(ic_field)
    ic_field_hat = _pad_full_fft_axis(ic_field_hat, 0, n0, new_shape[0])
    ic_field_hat = _pad_full_fft_axis(ic_field_hat, 1, n1, new_shape[1])
    ic_field_hat = _pad_rfft_axis(ic_field_hat, 2, n2, new_shape[2])

    # `s=` is required: an odd fine last axis cannot be inferred from the
    # half-spectrum length.
    return jnp.fft.irfftn(ic_field_hat, s=new_shape) * (supersampling**3)


@partial(
    jax.jit,
    static_argnums=(
        3,
        4,
        5,
        7,
        9,
    ),
)
def _run_leap_frog_simulation_internal(
    ic_field,
    omega_c,
    omega_b,
    want_vel,
    mesh_shape,
    mesh_shape_force,
    lf_coefficients,
    supersampling,
    a_start,
    sharding=None,
    halo_size=0,
):
    """Run a particle-mesh simulation driven by precomputed leapfrog coefficients.

    Particles are placed on a regular lagrangian grid and displaced with
    first-order LPT at ``a_start``. They are then evolved with
    ``jax.lax.scan`` over the entries of ``lf_coefficients``. Each entry
    applies:

    1. a kick, only when its kick coefficient is positive, using
       ``pm_forces`` at the current positions;
    2. a drift of the positions by ``drift_coeff * velocity``.

    Args:
        ic_field: initial condition mesh at epoch ``a_start``, given as a
            density contrast on the ``mesh_shape`` grid.
        omega_c: cold dark matter density parameter (``jc.Planck15``).
        omega_b: baryon density parameter (``jc.Planck15``).
        want_vel: static flag; currently unused, since only positions are
            returned.
        mesh_shape: static shape of the position mesh in lagrangian
            coordinates.
        mesh_shape_force: static shape of the force mesh.
        lf_coefficients: ``LeapfrogCoef`` with per-step ``kick_coeff`` and
            ``drift_coeff`` arrays of equal length.
        supersampling: static integer factor; when > 1, ``ic_field`` is first
            upsampled with ``supersample_ic_field`` and one particle is placed
            per fine cell.
        a_start: scale factor at which the PM simulation starts.
        sharding: static sharding forwarded to ``jax.device_put``, ``lpt`` and
            ``pm_forces``.
        halo_size: halo size forwarded to ``pm_forces``.

    Returns:
        Final particle positions of shape ``(n_particles, 3)``, presumably in
        particle-grid cell units (the grid is built from ``jnp.arange``).
    """

    cosmo = jc.Planck15(Omega_c=omega_c, Omega_b=omega_b)

    if supersampling > 1:
        ic_field = supersample_ic_field(ic_field, mesh_shape, supersampling)
        mesh_shape_particles = ic_field.shape
    else:
        mesh_shape_particles = mesh_shape

    # One particle per cell of the (possibly supersampled) lagrangian mesh.
    particles = jnp.stack(
        jax.device_put(
            jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape_particles]),
            device=sharding,
        ),
        axis=-1,
    ).reshape([-1, 3])
    dx, p = lpt(cosmo, ic_field, particles, a=a_start, order=1, sharding=sharding)

    pos = particles + dx
    vel = p
    # Ratio of particle-mesh to force-mesh resolution (along axis 0), applied
    # to the PM forces.
    multiplier = mesh_shape_particles[0] / mesh_shape_force[0]

    # Rematerialise everything in the backward pass so only the scan carry is
    # stored per step.
    @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
    def _evolver(state, istep):
        step_pos, step_vel = state
        # Steps with a zero kick coefficient (pure drifts) skip the force
        # computation.
        step_vel = jax.lax.cond(
            lf_coefficients.kick_coeff[istep] > 0,
            jax.remat(
                lambda: step_vel
                + lf_coefficients.kick_coeff[istep]
                * pm_forces(
                    step_pos,
                    mesh_shape=mesh_shape_force,
                    pos_mesh_shape=mesh_shape_particles,
                    fd=True,
                    sharding=sharding,
                    halo_size=halo_size,
                )
                # Presumably the 3/2 Omega_m prefactor of the comoving
                # Poisson equation -- confirm against pm_forces.
                * 1.5
                * cosmo.Omega_m
                * multiplier
            ),
            jax.remat(lambda: step_vel),
        )
        step_pos = step_pos + lf_coefficients.drift_coeff[istep] * step_vel
        return (step_pos, step_vel), None

    (pos, vel), _ = jax.lax.scan(
        _evolver, (pos, vel), jnp.arange(lf_coefficients.drift_coeff.shape[0])
    )

    return pos


def prepare_leap_frog_simulation(
    omega_c,
    omega_b,
    num_steps,
    a_start,
    a_final,
    supersampling=1,
):
    """Precompute DKD leapfrog coefficients and return a ready-to-run simulator.

    The coefficients for a ``num_steps``-step drift-kick-drift programme
    between ``a_start`` and ``a_final`` are computed once, for a
    ``jc.Planck15`` cosmology with the given ``omega_c`` and ``omega_b``.

    Args:
        omega_c: cold dark matter density parameter.
        omega_b: baryon density parameter.
        num_steps: number of leapfrog steps.
        a_start: initial scale factor.
        a_final: final scale factor.
        supersampling: integer refinement factor applied to the initial
            conditions before placing particles.

    Returns:
        A ``functools.partial`` of ``_run_leap_frog_simulation_internal`` with
        ``lf_coefficients``, ``omega_c``, ``omega_b``, ``a_start`` and
        ``supersampling`` bound as keyword arguments.

        The caller supplies ``ic_field`` and must pass ``want_vel``,
        ``mesh_shape``, ``mesh_shape_force`` (and optionally ``sharding`` /
        ``halo_size``) by keyword. Because ``omega_c``/``omega_b`` are bound
        by keyword, a second positional argument would collide with them.
    """
    lf_coefs = _precompute_leapfrog_coefs(
        jc.Planck15(Omega_c=omega_c, Omega_b=omega_b),
        a_start,
        a_final,
        num_steps,
        _make_DKD(num_steps),
    )
    return partial(
        _run_leap_frog_simulation_internal,
        lf_coefficients=lf_coefs,
        omega_c=omega_c,
        omega_b=omega_b,
        a_start=a_start,
        supersampling=supersampling,
    )