# SPDX-FileCopyrightText: 2025 CNRS
# SPDX-FileContributor: Guilhem Lavaux
# SPDX-FileContributor: Svyatoslav Trusov
#
# SPDX-License-Identifier: CECILL-B
import aquila_borg as borg
import aquila_borg.wrappers as borg_wrap
import jax
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt
from jax import checkpoint
from ._simulation_bullfrog import prepare_bullfrog_simulation
from ._simulation_classic import run_simulation
from ._simulation_cola import prepare_cola_simulation
from ._simulation_leapfrog import prepare_leap_frog_simulation
from ._src.painting import cic_paint
class BorgJaxPM(borg.forward.BaseForwardModel_v3):
    """BORG v3 forward model backed by a JAX particle simulation.

    ``forwardModel_v3`` reads the initial field (the density fluctuation at
    scale factor ``a_init``), runs the integrator selected at construction
    time up to ``a_final`` and keeps the final particle positions.
    ``getResultForward_v3`` paints those particles with ``cic_paint`` and
    subtracts 1. When ``need_adjoint`` is set, both steps record
    ``jax.vjp`` pullbacks that ``adjointModel_v3`` and
    ``getResultAdjointGradient_v3`` chain to bring the adjoint back to the
    input field.
    """

    # integrator name -> (message printed at construction, runner method name)
    _INTEGRATORS = {
        "classic": ("Using classic integrator", "_run_simulation_classic"),
        "bullfrog": ("Using bullfrog integrator", "_run_simulation_bullfrog"),
        "leapfrog": ("Using leapfrog integrator", "_run_simulation_leapfrog"),
        "tcola": ("Using tCOLA integrator", "_run_simulation_tcola"),
    }

    @borg.trace_function
    def __init__(
        self,
        sim_box: borg.forward.BoxModel,
        a_init: float,
        a_final: float = 1.0,
        forcesampling: int = 1,
        integrator: str = "bullfrog",
        num_steps: int = 50,
        supersampling: int = 1,
        mult_fix=1,
        fd: bool = None,
    ):
        """Configure the model.

        Args:
            sim_box: simulation box; ``tuple(sim_box.N)`` is the mesh shape.
            a_init: scale factor of the input field.
            a_final: scale factor at which the simulation stops.
            forcesampling: multiplies each entry of ``sim_box.N`` to give
                ``mesh_shape_force``.
            integrator: one of ``"classic"``, ``"bullfrog"``, ``"leapfrog"``,
                ``"tcola"``.
            num_steps: number of steps passed to the non-classic simulators.
            supersampling: passed to the simulators and used to build the
                painting shape in ``_scatter_to_mesh``.
            mult_fix: stored as ``self.mult_fix``; not used by this class.
            fd: forwarded as ``fd=`` to ``prepare_bullfrog_simulation`` when
                not None; otherwise that helper's own default applies.

        Raises:
            ValueError: if ``integrator`` is not one of the names above.
        """
        super().__init__()
        self.sim_box = sim_box
        self.a_init = a_init
        self.a_final = a_final
        self.forcesampling = forcesampling
        self.supersampling = supersampling
        self.num_steps = num_steps
        self.mesh_shape = tuple(sim_box.N)
        self.mesh_shape_force = tuple(self.forcesampling * n for n in sim_box.N)
        self.need_adjoint = True
        self.mult_fix = mult_fix
        # Always defined: it used to be set only when given, so the default
        # bullfrog configuration crashed with AttributeError in forwardModel_v3.
        self.fd = fd

        # All devices along the "x" axis; "y" has size 1.
        self.device_mesh = jax.sharding.Mesh(
            np.array(jax.devices())[:, None], ("x", "y")
        )
        # Backward-compatible alias: getResultForward_v3 later stores the
        # painted density under this same name.
        self.mesh = self.device_mesh
        self.spec = jax.sharding.PartitionSpec("x", "y")
        self.sharding = jax.sharding.NamedSharding(self.device_mesh, self.spec)

        self.integrator = integrator
        self.simulator = None
        # (omega_c, omega_b) that self.simulator was prepared with, if built here.
        self._simulator_cosmo = None
        try:
            message, runner_name = self._INTEGRATORS[integrator]
        except KeyError:
            raise ValueError(f"Unknown integrator '{integrator}'") from None
        borg.print_msg(borg.Level.std, message)
        self._run_simulation = getattr(self, runner_name)

    def _scatter_to_mesh(self, positions):
        """Paint ``positions`` with ``cic_paint`` and subtract 1.

        ``cic_paint`` receives ``mesh_shape`` and the same shape multiplied by
        ``supersampling``. Presumably the result is the density contrast on
        the output mesh — TODO confirm against ``cic_paint``.
        """
        base_shape = self.mesh_shape
        painted_shape = tuple(s * self.supersampling for s in base_shape)
        return cic_paint(base_shape, painted_shape, positions) - 1

    def _cosmo_omegas(self):
        """Return ``(omega_c, omega_b)`` from the current cosmology.

        Raises:
            ValueError: if ``omega_m - omega_b`` is not positive.
        """
        pars = self.getCosmoParams()
        omega_c = pars.omega_m - pars.omega_b
        # `not > 0` also rejects NaN, as the previous assert did.
        if not omega_c > 0:
            raise ValueError(f"omega_m - omega_b must be positive, got {omega_c}")
        return omega_c, pars.omega_b

    [docs]
    def getOutputDescription_AG(self) -> borg.modelio.RepresentationDescriptor:
        """Describe the adjoint output (OUTPUT_ADJOINT, fourier flag False)."""
        return borg.modelio.makeModelIODescriptor(
            self.sim_box, borg.modelio.ModelIOType.OUTPUT_ADJOINT, False
        )

    def getOutputDescription(self) -> borg.modelio.RepresentationDescriptor:
        """Describe the forward output as a tiled array of shape ``sim_box.N``."""
        return borg.modelio.makeTiledArrayDescriptor(self.sim_box.N)

    def setAdjointRequired(self, on: bool):
        """Enable/disable recording of the VJP closures in the forward pass."""
        self.need_adjoint = on

    # Needs to be differentiable.
    def _run_simulation_classic(self, modes):
        """Run the classic integrator; return the positions at ``a_final``."""
        omega_c, omega_b = self._cosmo_omegas()
        snapshot_times = jnp.array([self.a_init, self.a_final])
        _, pos_snapshots, _ = run_simulation(
            modes,
            omega_c,
            omega_b,
            True,
            self.mesh_shape,
            self.mesh_shape_force,
            snapshot_times,
            self.a_init,
            self.sharding,
            self.supersampling,
        )
        return pos_snapshots[-1]

    # Needs to be differentiable.
    def _run_prepared_simulation(self, modes):
        """Run the simulator built by ``_prepare_simulator`` on ``modes``."""
        return self.simulator(
            modes,
            want_vel=True,
            mesh_shape=self.mesh_shape,
            mesh_shape_force=self.mesh_shape_force,
            sharding=self.sharding,
        )

    # The prepared integrators share one call convention; the per-integrator
    # names are kept so existing references keep working.
    _run_simulation_leapfrog = _run_prepared_simulation
    _run_simulation_tcola = _run_prepared_simulation
    _run_simulation_bullfrog = _run_prepared_simulation

    def _prepare_simulator(self):
        """Build ``self.simulator`` for the current cosmology when needed.

        A simulator built here is rebuilt whenever ``(omega_c, omega_b)``
        changes, so a cosmology update after the first forward call is not
        silently ignored. A simulator assigned from outside is left alone.
        The classic integrator needs no prepared simulator.

        Raises:
            ValueError: for a non-positive ``omega_c`` or an unknown integrator.
        """
        omega_c, omega_b = self._cosmo_omegas()
        if self.integrator == "classic":
            return
        cosmo_key = (omega_c, omega_b)
        built_for = getattr(self, "_simulator_cosmo", None)
        if self.simulator is not None and built_for in (None, cosmo_key):
            return
        args = (omega_c, omega_b, self.num_steps, self.a_init, self.a_final)
        if self.integrator == "leapfrog":
            simulator = prepare_leap_frog_simulation(
                *args, supersampling=self.supersampling
            )
        elif self.integrator == "bullfrog":
            # Only forward fd when set so the helper's own default applies.
            extra = {} if self.fd is None else {"fd": self.fd}
            simulator = prepare_bullfrog_simulation(
                *args, supersampling=self.supersampling, **extra
            )
        elif self.integrator == "tcola":
            simulator = prepare_cola_simulation(
                *args, supersampling=self.supersampling
            )
        else:
            raise ValueError("Unknown integrator")
        self.simulator = simulator
        self._simulator_cosmo = cosmo_key

    @borg.trace_function
    def forwardModel_v3(self, input_slot: borg.modelio.GInput):
        """Run the simulation on the input field.

        Complex input is brought to real space with
        ``irfftn(norm="forward").real / sim_box.volume``. Sets ``self.modes``,
        ``self.is_complex``, ``self.particles``, ``self.snapshots`` and, when
        ``need_adjoint``, the pullback ``self.grad_particles``.
        """
        self._prepare_simulator()

        input_slot.request(self.getInputDescription())
        input_io = input_slot.getCurrent()
        assert isinstance(input_io, borg.modelio.ModelIORepresentation3)
        # Input is the density fluctuation at epoch self.a_init.
        # np.asarray: NumPy 2 makes np.array(copy=False) raise if a copy is needed.
        modes = jnp.asarray(np.asarray(input_io))
        self.is_complex = modes.dtype in (jnp.complex64, jnp.complex128)
        if self.is_complex:
            modes = jnp.fft.irfftn(modes, norm="forward").real / self.sim_box.volume
        self.modes = modes

        # TODO (ST): This is the variant of the computation taken from borg_pmwd
        # realization. It is horribly inefficient in this context and I am
        # pretty convinced it can be much better and faster, and I will try to
        # resolve it.
        if self.need_adjoint:
            # nothing_saveable: intermediates are recomputed in the backward
            # pass instead of being kept in memory.
            self.particles, self.grad_particles = jax.vjp(
                checkpoint(
                    self._run_simulation,
                    policy=jax.checkpoint_policies.nothing_saveable,
                ),
                self.modes,
            )
        else:
            self.particles = checkpoint(self._run_simulation)(self.modes)
        self.snapshots = [self.particles]

    @borg.trace_function
    def getResultForward_v3(self, output_slot: borg.modelio.GOutput):
        """Paint the final particles into the output slot.

        When ``need_adjoint``, also stores the painting pullback in
        ``self.mesh_grad``. The painted field is kept as ``self.mesh``.
        """
        output_slot.request(self.getOutputDescription())
        output = output_slot.getCurrent()
        # This was also changed to be more similar to borg_pmwd
        if self.need_adjoint:
            scatter_func = checkpoint(
                self._scatter_to_mesh, policy=jax.checkpoint_policies.nothing_saveable
            )
            density, self.mesh_grad = jax.vjp(scatter_func, self.particles)
        else:
            density = self._scatter_to_mesh(self.particles)
        # copy=False is intentional here: the write must land in the slot's buffer.
        np.array(output, copy=False)[()] = density
        # Historical attribute name; the device mesh stays in self.device_mesh.
        self.mesh = density
        output_slot.close_request()
        return output_slot

    @borg.trace_function
    def adjointModel_v3(self, input_ag: borg.modelio.GInputAdjoint):
        """Pull the output-space adjoint back through the painting step.

        Stores the particle-space adjoint in ``self.part_ag``. Requires a
        previous forward pass with ``need_adjoint`` set.
        """
        assert self.need_adjoint
        input_ag.request(
            borg.modelio.makeModelIODescriptor(
                self.sim_box, borg.modelio.ModelIOType.INPUT_ADJOINT, fourier=False
            )
        )
        (self.part_ag,) = self.mesh_grad(jnp.array(input_ag.getCurrent(), copy=False))

    @borg.trace_function
    def getResultAdjointGradient_v3(self, output_ag: borg.modelio.GOutputAdjoint):
        """Pull ``self.part_ag`` back through the simulation and write it out.

        The gradient is cast to float32, transformed with ``rfftn`` and divided
        by ``sim_box.volume`` before being written.
        NOTE(review): the request uses ``fourier=False`` while an ``rfftn``
        result is written (the original note says output_ag expects Fourier
        data) — confirm against the modelio descriptor semantics.
        """
        assert self.need_adjoint
        output_ag.request(
            borg.modelio.makeModelIODescriptor(
                self.sim_box, borg.modelio.ModelIOType.OUTPUT_ADJOINT, fourier=False
            )
        )
        (grad_whitenoise,) = self.grad_particles(self.part_ag)
        np.array(output_ag.getCurrent(), copy=False)[:] = (
            jnp.fft.rfftn(grad_whitenoise.astype("float32")) / self.sim_box.volume
        )
        output_ag.close_request()
        return output_ag

    def getParticlePositions(self) -> npt.ArrayLike:
        """Return the particle positions from the last forward pass."""
        return self.snapshots[-1]

    def getParticleVelocities(self) -> npt.ArrayLike:
        """Return the last stored particle velocities.

        Raises:
            NotImplementedError: if no velocities were stored. This class never
                assigns ``vel_snapshots`` itself; previously this surfaced as an
                opaque AttributeError.
        """
        vel_snapshots = getattr(self, "vel_snapshots", None)
        if vel_snapshots is None:
            raise NotImplementedError("BorgJaxPM does not store particle velocities")
        return vel_snapshots[-1]
def _construction_v3(box_in, box_out, kwargs):
return (box_in,), kwargs
# Wrap the v3 model with ``borg_wrap.adapt_v3_model``; ``_construction_v3``
# maps the (box_in, box_out, kwargs) constructor arguments onto BorgJaxPM's
# (presumably exposing it through the older v2 model interface).
BorgJaxPM_v2 = borg_wrap.adapt_v3_model(
    BorgJaxPM,
    constructor_helper=_construction_v3,
)

__all__ = [
    "BorgJaxPM",
    "BorgJaxPM_v2",
]