# Source code for borgjaxpm.forward

# SPDX-FileCopyrightText: 2025 CNRS
# SPDX-FileContributor: Guilhem Lavaux
# SPDX-FileContributor: Svyatoslav Trusov
#
# SPDX-License-Identifier: CECILL-B

import aquila_borg as borg
import aquila_borg.wrappers as borg_wrap
import jax
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt
from jax import checkpoint

from ._simulation_bullfrog import prepare_bullfrog_simulation
from ._simulation_classic import run_simulation
from ._simulation_cola import prepare_cola_simulation
from ._simulation_leapfrog import prepare_leap_frog_simulation
from ._src.painting import cic_paint


class BorgJaxPM(borg.forward.BaseForwardModel_v3):
    """BORG v3 forward model that evolves particles with a JAX integrator.

    The model reads a field from the BORG input slot, runs one of four
    integrators ("classic", "bullfrog", "leapfrog", "tcola") to obtain
    particle positions, and paints them onto a mesh with ``cic_paint``.
    When the adjoint is required, ``jax.vjp`` closures for both stages are
    kept on the instance and chained in the adjoint calls.

    Parameters
    ----------
    sim_box
        Box model; ``sim_box.N`` gives the particle mesh shape and
        ``sim_box.volume`` is used to normalise Fourier transforms.
    a_init
        Scale factor of the input field (the input is the density
        fluctuation at this epoch).
    a_final
        Scale factor handed to the integrator as the final epoch.
    forcesampling
        Per-dimension multiplier applied to ``sim_box.N`` to obtain the
        force mesh shape.
    integrator
        One of ``"classic"``, ``"bullfrog"``, ``"leapfrog"``, ``"tcola"``.
    num_steps
        Number of steps handed to the bullfrog, leapfrog and tCOLA
        integrators. The classic integrator does not receive it.
    supersampling
        Forwarded to the integrators and used as per-dimension multiplier
        of the second shape passed to ``cic_paint``.
    mult_fix
        Stored on the instance; not read anywhere in this class.
    fd
        Forwarded to ``prepare_bullfrog_simulation`` when not ``None``;
        when ``None`` that function's own default applies.

    Raises
    ------
    ValueError
        If ``integrator`` is not one of the supported names.
    """

    @borg.trace_function
    def __init__(
        self,
        sim_box: borg.forward.BoxModel,
        a_init: float,
        a_final: float = 1.0,
        forcesampling: int = 1,
        integrator: str = "bullfrog",
        num_steps: int = 50,
        supersampling: int = 1,
        mult_fix=1,
        fd: bool = None,
    ):
        super().__init__()
        self.sim_box = sim_box
        self.a_init = a_init
        self.a_final = a_final
        self.forcesampling = forcesampling
        self.supersampling = supersampling
        self.mesh_shape = tuple(sim_box.N)
        self.mesh_shape_force = tuple(self.forcesampling * n for n in sim_box.N)
        self.need_adjoint = True
        self.mult_fix = mult_fix

        # All visible devices are laid out along "x"; "y" has extent one.
        # NOTE(review): getResultForward_v3 later rebinds ``self.mesh`` to the
        # painted grid. ``self.sharding`` keeps its own reference to this
        # device mesh, so the integrators are unaffected.
        self.mesh = jax.sharding.Mesh(np.array(jax.devices())[:, None], ("x", "y"))
        self.spec = jax.sharding.PartitionSpec("x", "y")
        self.sharding = jax.sharding.NamedSharding(self.mesh, self.spec)

        self.integrator = integrator
        self.simulator = None
        # (omega_c, omega_b) the cached simulator was built with.
        self._simulator_cosmology = None
        self.num_steps = num_steps
        # Always defined so that forwardModel_v3 can read it; previously the
        # attribute was missing whenever ``fd`` was left at its default.
        self.fd = fd

        integrators = {
            "classic": ("classic", self._run_simulation_classic),
            "bullfrog": ("bullfrog", self._run_simulation_bullfrog),
            "leapfrog": ("leapfrog", self._run_simulation_leapfrog),
            "tcola": ("tCOLA", self._run_simulation_tcola),
        }
        if integrator not in integrators:
            raise ValueError(f"Unknown integrator '{integrator}'")
        label, runner = integrators[integrator]
        borg.print_msg(borg.Level.std, f"Using {label} integrator")
        self._run_simulation = runner

    def _scatter_to_mesh(self, p):
        """Paint particles ``p`` with ``cic_paint`` and subtract one.

        The subtraction presumably turns a unit-mean density into a density
        contrast -- TODO confirm against ``cic_paint``'s normalisation.
        """
        shape1 = self.mesh_shape
        shape2 = tuple(s * self.supersampling for s in shape1)
        return cic_paint(shape1, shape2, p) - 1

    def _matter_densities(self):
        """Return ``(omega_c, omega_b)`` from the current cosmology.

        ``omega_c`` is ``omega_m - omega_b`` and must be strictly positive.
        """
        pars = self.getCosmoParams()
        omega_b = pars.omega_b
        omega_c = pars.omega_m - omega_b
        assert omega_c > 0
        return omega_c, omega_b

    def _ensure_simulator(self, omega_c, omega_b):
        """Build the cached simulator, rebuilding it if the cosmology changed.

        The classic integrator has no prepared simulator: it reads the
        cosmology on every run, so nothing is cached for it.

        Raises
        ------
        ValueError
            If ``self.integrator`` was changed to an unsupported name.
        """
        if self.integrator == "classic":
            return
        cosmology = (omega_c, omega_b)
        if self.simulator is not None and self._simulator_cosmology == cosmology:
            return

        extra = {"supersampling": self.supersampling}
        if self.integrator == "leapfrog":
            prepare = prepare_leap_frog_simulation
        elif self.integrator == "bullfrog":
            prepare = prepare_bullfrog_simulation
            # Only forward ``fd`` when the caller chose a value, so that the
            # integrator's own default is used otherwise.
            if self.fd is not None:
                extra["fd"] = self.fd
        elif self.integrator == "tcola":
            prepare = prepare_cola_simulation
        else:
            raise ValueError("Unknown integrator")

        self.simulator = prepare(
            omega_c,
            omega_b,
            self.num_steps,
            self.a_init,
            self.a_final,
            **extra,
        )
        self._simulator_cosmology = cosmology

    def getOutputDescription_AG(self) -> borg.modelio.RepresentationDescriptor:
        """Return the descriptor of the adjoint-gradient output."""
        return borg.modelio.makeModelIODescriptor(
            self.sim_box, borg.modelio.ModelIOType.OUTPUT_ADJOINT, False
        )

    def getInputDescription(self) -> borg.modelio.RepresentationDescriptor:
        """Return the descriptor requested for the forward input."""
        # Give the closest thing to a Fourier space requirement for the input
        return borg.modelio.makeModelIODescriptor(
            self.sim_box, borg.modelio.ModelIOType.INPUT, False
        )

    def getOutputDescription(self) -> borg.modelio.RepresentationDescriptor:
        """Return a tiled-array descriptor sized by ``sim_box.N``."""
        return borg.modelio.makeTiledArrayDescriptor(self.sim_box.N)

    def setAdjointRequired(self, on: bool):
        """Enable or disable recording of the VJP closures."""
        self.need_adjoint = on

    # Needs to be differentiable.
    def _run_simulation_classic(self, modes):
        """Run the classic integrator and return the last position snapshot."""
        omega_c, omega_b = self._matter_densities()
        scale_factors = jnp.array([self.a_init, self.a_final])
        # Velocities are returned but not kept: this function is traced by
        # jax.vjp/checkpoint, so it must not store traced values on ``self``.
        _, positions, _velocities = run_simulation(
            modes,
            omega_c,
            omega_b,
            True,
            self.mesh_shape,
            self.mesh_shape_force,
            scale_factors,
            self.a_init,
            self.sharding,
            self.supersampling,
        )
        return positions[-1]

    # Needs to be differentiable.
    def _run_prepared_simulator(self, modes):
        """Call the cached simulator on ``modes`` and return its result."""
        return self.simulator(
            modes,
            want_vel=True,
            mesh_shape=self.mesh_shape,
            mesh_shape_force=self.mesh_shape_force,
            sharding=self.sharding,
        )

    # Needs to be differentiable.
    def _run_simulation_leapfrog(self, modes):
        """Run the prepared leapfrog simulator."""
        return self._run_prepared_simulator(modes)

    # Needs to be differentiable.
    def _run_simulation_tcola(self, modes):
        """Run the prepared tCOLA simulator."""
        return self._run_prepared_simulator(modes)

    # Needs to be differentiable.
    def _run_simulation_bullfrog(self, modes):
        """Run the prepared bullfrog simulator."""
        return self._run_prepared_simulator(modes)

    @borg.trace_function
    def forwardModel_v3(self, input_slot: borg.modelio.GInput):
        """Read the input field and run the selected integrator.

        Stores the result in ``self.particles`` (also as the single entry of
        ``self.snapshots``). When the adjoint is required, the matching VJP
        closure is stored in ``self.grad_particles``.
        """
        omega_c, omega_b = self._matter_densities()
        self._ensure_simulator(omega_c, omega_b)

        input_slot.request(self.getInputDescription())
        input_io = input_slot.getCurrent()
        assert isinstance(input_io, borg.modelio.ModelIORepresentation3)

        # Input is density fluctuation at epoch self.a_init
        self.modes = jnp.array(np.array(input_io, copy=False))
        self.is_complex = bool(
            jnp.issubdtype(self.modes.dtype, jnp.complexfloating)
        )
        if self.is_complex:
            # Complex input is brought back to real space before integrating.
            self.modes = (
                jnp.fft.irfftn(self.modes, norm="forward").real
                / self.sim_box.volume
            )

        # TODO (ST): This is the variant of the computation taken from borg_pmwd
        # realization. It is horribly inefficient in this context and I am
        # pretty convinced it can be much better and faster, and I will try to
        # resolve it.
        if self.need_adjoint:
            self.particles, self.grad_particles = jax.vjp(
                jax.checkpoint(
                    self._run_simulation,
                    policy=jax.checkpoint_policies.nothing_saveable,
                ),
                self.modes,
            )
        else:
            self.particles = checkpoint(self._run_simulation)(self.modes)

        self.snapshots = [self.particles]

    @borg.trace_function
    def getResultForward_v3(self, output_slot: borg.modelio.GOutput):
        """Paint the particles onto the mesh and write it to the output slot.

        Returns ``output_slot`` after closing the request.
        """
        output_slot.request(self.getOutputDescription())
        output = output_slot.getCurrent()

        # This was also changed to be more similar to borg_pmwd
        if self.need_adjoint:
            scatter_func = checkpoint(
                self._scatter_to_mesh,
                policy=jax.checkpoint_policies.nothing_saveable,
            )
            mesh, self.mesh_grad = jax.vjp(scatter_func, self.particles)
        else:
            borg.print_msg(
                borg.Level.std, f"Particles shape: {self.particles.shape}"
            )
            mesh = self._scatter_to_mesh(self.particles)

        np.array(output, copy=False)[()] = mesh
        # NOTE(review): this rebinds ``self.mesh`` from the device mesh created
        # in __init__ to the painted grid; kept for backward compatibility.
        self.mesh = mesh
        output_slot.close_request()
        return output_slot

    @borg.trace_function
    def adjointModel_v3(self, input_ag: borg.modelio.GInputAdjoint):
        """Pull the incoming adjoint gradient back onto the particles.

        The result is stored in ``self.part_ag``. Requires that
        ``getResultForward_v3`` ran with the adjoint enabled.
        """
        # This function basically stayed the same
        assert self.need_adjoint
        input_ag.request(
            borg.modelio.makeModelIODescriptor(
                self.sim_box, borg.modelio.ModelIOType.INPUT_ADJOINT, fourier=False
            )
        )
        (self.part_ag,) = self.mesh_grad(jnp.array(input_ag.getCurrent(), copy=False))

    @borg.trace_function
    def getResultAdjointGradient_v3(self, output_ag: borg.modelio.GOutputAdjoint):
        """Pull ``self.part_ag`` back to the input and write it to the slot.

        Returns ``output_ag`` after closing the request.
        """
        # The only difference in this function with respect to the original is
        # the ugly rfftn, because output_ag expects a Fourier.
        # NOTE(review): the descriptor is requested with fourier=False while a
        # real-to-complex transform is written, regardless of
        # ``self.is_complex`` -- confirm this matches the slot's layout.
        assert self.need_adjoint
        output_ag.request(
            borg.modelio.makeModelIODescriptor(
                self.sim_box, borg.modelio.ModelIOType.OUTPUT_ADJOINT, fourier=False
            )
        )
        (grad_whitenoise,) = self.grad_particles(self.part_ag)
        np.array(output_ag.getCurrent(), copy=False)[:] = (
            jnp.fft.rfftn(grad_whitenoise.astype("float32")) / self.sim_box.volume
        )
        output_ag.close_request()
        return output_ag

    def getParticlePositions(self) -> npt.ArrayLike:
        """Return the positions produced by the last forward pass."""
        return self.snapshots[-1]

    def getParticleVelocities(self) -> npt.ArrayLike:
        """Return the last velocity snapshot, if one has been stored.

        Raises
        ------
        AttributeError
            If ``vel_snapshots`` was never set. Nothing in this class sets
            it, so this is raised unless the attribute is provided
            externally.
        """
        vel_snapshots = getattr(self, "vel_snapshots", None)
        if vel_snapshots is None:
            raise AttributeError(
                "BorgJaxPM does not record particle velocities: "
                "'vel_snapshots' has not been set"
            )
        return vel_snapshots[-1]
def _construction_v3(box_in, box_out, kwargs):
    """Map v2-style constructor inputs onto ``BorgJaxPM``'s arguments.

    Only ``box_in`` becomes a positional argument; ``box_out`` is ignored
    and ``kwargs`` is passed through untouched.

    Returns
    -------
    tuple
        ``((box_in,), kwargs)``.
    """
    positional = (box_in,)
    return positional, kwargs


# Adapted variant of the v3 model, built by ``borg_wrap.adapt_v3_model`` with
# the helper above translating the constructor arguments.
BorgJaxPM_v2 = borg_wrap.adapt_v3_model(
    BorgJaxPM,
    constructor_helper=_construction_v3,
)

__all__ = ["BorgJaxPM", "BorgJaxPM_v2"]