jaxincell
Submodules
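Attributes
- epsilon_0
- mu_0
- speed_of_light
- elementary_charge
- mass_electron
- mass_proton
- boltzmann_constant
- _MAX_FILTER_PASSES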
Functions
- set_BC_single_particle – Applies boundary conditions (BCs) to a single particle's position and velocity.
- set_BC_particles – Applies boundary conditions to all particles in parallel.
- set_BC_single_particle_positions – Applies boundary conditions to particle positions only (used for half-step updates).
- set_BC_positions – Applies boundary conditions to particle positions for all particles during a half-step update.
- diagnostics
- E_from_Gauss_1D_FFT – Solve for the electric field E = -d(phi)/dx using FFT, where phi is derived from the 1D Gauss's law equation.
- E_from_Poisson_1D_FFT – Solve for the electric field E = -d(phi)/dx using FFT, where phi is derived from the 1D Poisson equation.
- E_from_Gauss_1D_Cartesian – Solve for the electric field at t=0 (E0) using the charge density distribution and applying Gauss's law in a 1D system.
- curlE – Compute the curl of the electric field, which is related to the time derivative of the magnetic field (Faraday's law).
- curlB – Compute the curl of the magnetic field, which is related to the time derivative of the electric field (Ampère's law with Maxwell correction).
- field_update – Update the electric and magnetic fields based on Maxwell's equations.
- field_update1
- field_update2
- fields_to_particles_grid – Retrieves the electric or magnetic field values at particle positions using a field interpolation scheme.
- fields_to_particles_periodic_CN – Interpolates a field to particle positions using periodic BCs.
- rotation – Implements the Boris algorithm to rotate the particle velocity vector in the magnetic field for one time step.
- boris_step – Performs one step of the Boris algorithm for particle motion.
- boris_step_relativistic – Relativistic Boris pusher for N particles.
- plot – Production-ready plotting/animation for JAX-in-Cell outputs.
- initialize_simulation_parameters – Initialize the simulation parameters for a particle-in-cell simulation, combining user-provided values with predefined defaults.
- initialize_particles_fields – Initialize particles and electromagnetic fields for a Particle-in-Cell simulation.
- simulation – Run a plasma physics simulation using a Particle-In-Cell (PIC) method in JAX.
- load_parameters – Load parameters from a TOML file.
- get_S2_weights_and_indices_periodic_CN – Calculates weights and indices for Quadratic Spline (S2).
- charge_density_BCs – Compute the charge contribution to the boundary points based on particle positions and boundary conditions.
- single_particle_charge_density – Computes the charge density contribution of a single particle to the grid using a quadratic particle shape function.
- calculate_charge_density – Computes the total charge density on the grid by summing contributions from all particles.
- current_density – Computes the current density j on the grid from particle motion.
- current_density_periodic_CN – Deposits the current density J using periodic BCs.
- _shift_with_bc_1d – Shift x by 'shift' cells along axis=0 with boundary conditions.
- binomial_filter_3point – 3-point digital filter along axis 0 with BCs.
- _repeat_filter – JAX-safe version of the 3-point digital filter with compensation.
- filter_scalar_field – Apply a multi-pass 3-point binomial digital filter to a scalar field.
- filter_vector_field – Apply digital filter along the grid axis (axis=0) for each component.
Package Contents¶
- jaxincell.set_BC_single_particle(x_n, v_n, q, q_m, dx, grid, box_size_x, box_size_y, box_size_z, BC_left, BC_right)
Applies boundary conditions (BCs) to a single particle’s position and velocity.
- Parameters:
x_n (jnp.ndarray) – Particle position as a 1D array [x, y, z].
v_n (jnp.ndarray) – Particle velocity as a 1D array [vx, vy, vz].
q (float) – Particle charge.
q_m (float) – Charge-to-mass ratio of the particle.
dx (float) – Grid spacing.
grid (jnp.ndarray) – Discretized grid positions.
box_size_x (float) – Box size in the x-direction.
box_size_y (float) – Box size in the y-direction.
box_size_z (float) – Box size in the z-direction.
BC_left (int) – Boundary condition at the left boundary in the x-direction. 0: periodic, 1: reflective, 2: absorbing.
BC_right (int) – Boundary condition at the right boundary in the x-direction. 0: periodic, 1: reflective, 2: absorbing.
- Returns:
Updated position (x_n), velocity (v_n), charge (q), and charge-to-mass ratio (q_m).
- Return type:
tuple
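The boundary handling in x can be sketched as below. This is an illustration under stated assumptions, not jaxincell's implementation: the domain is assumed to span [-box_size_x/2, box_size_x/2], only the x-coordinate is checked, and an absorbed particle is assumed to keep its position while its charge, charge-to-mass ratio and velocity are set to zero. The helper names set_x_bc_sketch and _apply_x_bc are hypothetical.

```python
import jax.numpy as jnp

# BC codes as documented: 0 = periodic, 1 = reflective, 2 = absorbing.

def _apply_x_bc(x, v, q, q_m, out, wall, shift, bc):
    """Apply one x-boundary to a particle flagged as outside (`out`)."""
    if bc == 0:    # periodic: re-enter from the opposite side of the box
        x = x.at[0].add(jnp.where(out, shift, 0.0))
    elif bc == 1:  # reflective: mirror the position about the wall, flip vx
        x = x.at[0].set(jnp.where(out, 2.0 * wall - x[0], x[0]))
        v = v.at[0].set(jnp.where(out, -v[0], v[0]))
    elif bc == 2:  # absorbing (assumed behaviour): zero charge, q/m and velocity
        q = jnp.where(out, 0.0, q)
        q_m = jnp.where(out, 0.0, q_m)
        v = jnp.where(out, 0.0, v)
    return x, v, q, q_m

def set_x_bc_sketch(x, v, q, q_m, box_size_x, bc_left, bc_right):
    """Hypothetical helper; assumes the domain is [-box_size_x/2, box_size_x/2]."""
    half = 0.5 * box_size_x
    x, v, q, q_m = _apply_x_bc(x, v, q, q_m, x[0] < -half, -half, box_size_x, bc_left)
    x, v, q, q_m = _apply_x_bc(x, v, q, q_m, x[0] > half, half, -box_size_x, bc_right)
    return x, v, q, q_m
```

Because bc_left and bc_right are plain Python integers here, the branch is selected when the function is traced; only the particle state is a traced value.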
- jaxincell.set_BC_particles(xs_n, vs_n, qs, ms, q_ms, dx, grid, box_size_x, box_size_y, box_size_z, BC_left, BC_right)
Applies boundary conditions to all particles in parallel.
- Parameters:
xs_n (jnp.ndarray) – Positions of all particles, shape (N, 3).
vs_n (jnp.ndarray) – Velocities of all particles, shape (N, 3).
qs (jnp.ndarray) – Charges of all particles, shape (N,).
ms (jnp.ndarray) – Masses of all particles, shape (N,).
q_ms (jnp.ndarray) – Charge-to-mass ratios of all particles, shape (N,).
Other parameters – Same as for set_BC_single_particle.
- Returns:
Updated positions, velocities, charges, masses, and charge-to-mass ratios for all particles.
- Return type:
tuple
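Applying a per-particle rule to every particle at once maps naturally onto jax.vmap over the particle axis, with the box parameters broadcast. The sketch below uses a simplified, hypothetical per-particle rule (periodic wrap in x only, no masses argument) purely to show the mapping pattern; it is not the package's code.

```python
import jax
import jax.numpy as jnp

def wrap_single_particle(x, v, q, q_m, box_size_x):
    """Hypothetical per-particle rule: periodic wrap of x into [-L/2, L/2)."""
    half = 0.5 * box_size_x
    return x.at[0].set(jnp.mod(x[0] + half, box_size_x) - half), v, q, q_m

# Map over the particle axis of xs, vs, qs, q_ms; broadcast the box size.
wrap_all_particles = jax.vmap(wrap_single_particle, in_axes=(0, 0, 0, 0, None))

xs = jnp.array([[0.6, 0.0, 0.0], [-0.7, 0.0, 0.0], [0.1, 0.0, 0.0]])
vs = jnp.zeros_like(xs)
qs = jnp.array([-1.0, -1.0, 1.0])
q_ms = jnp.array([-1.0, -1.0, 1.0])
xs_new, vs_new, qs_new, q_ms_new = wrap_all_particles(xs, vs, qs, q_ms, 1.0)
```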
- jaxincell.set_BC_single_particle_positions(x_n, dx, grid, box_size_x, box_size_y, box_size_z, BC_left, BC_right)
Applies boundary conditions to particle positions only (used for half-step updates).
- Parameters:
x_n (jnp.ndarray) – Particle position as a 1D array [x, y, z].
Other parameters – Same as for set_BC_single_particle.
- Returns:
Updated particle position [x, y, z].
- Return type:
jnp.ndarray
- jaxincell.set_BC_positions(xs_n, qs, dx, grid, box_size_x, box_size_y, box_size_z, BC_left, BC_right)
Applies boundary conditions to particle positions for all particles during a half-step update.
- Parameters:
xs_n (jnp.ndarray) – Positions of all particles, shape (N, 3).
qs (jnp.ndarray) – Charges of all particles, shape (N,).
Other parameters – Same as for set_BC_single_particle.
- Returns:
Updated positions of all particles, shape (N, 3).
- Return type:
jnp.ndarray
- jaxincell.epsilon_0 = 8.85418782e-12 (vacuum permittivity, F/m)
- jaxincell.mu_0 = 1.25663706e-06 (vacuum permeability, H/m)
- jaxincell.speed_of_light = 299792458.0 (m/s)
- jaxincell.elementary_charge = 1.60217663e-19 (C)
- jaxincell.mass_electron = 9.10938371e-31 (kg)
- jaxincell.mass_proton = 1.67262193e-27 (kg)
- jaxincell.boltzmann_constant = 1.380649e-23 (J/K)
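As a quick consistency check on these constants, c should equal 1/sqrt(ε0 μ0), and together they give the electron plasma frequency ω_pe = sqrt(n_e e² / (ε0 m_e)), the time scale an explicit PIC time step usually has to resolve. The sketch copies the values locally instead of importing jaxincell; electron_plasma_frequency is a hypothetical helper, not part of the package.

```python
import math

epsilon_0 = 8.85418782e-12
mu_0 = 1.25663706e-06
speed_of_light = 299792458.0
elementary_charge = 1.60217663e-19
mass_electron = 9.10938371e-31

# c = 1/sqrt(eps0 * mu0) should hold to the precision of the stored values
c_derived = 1.0 / math.sqrt(epsilon_0 * mu_0)

def electron_plasma_frequency(n_e):
    """omega_pe in rad/s for an electron density n_e in m^-3."""
    return math.sqrt(n_e * elementary_charge**2 / (epsilon_0 * mass_electron))
```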
- jaxincell.diagnostics(output)
- jaxincell.E_from_Gauss_1D_FFT(charge_density, dx)
Solve for the electric field E = -d(phi)/dx using FFT, where phi is derived from the 1D Gauss’s law equation.
- Parameters:
charge_density – 1D numpy array, source term (right-hand side of Gauss’s law).
dx – float, grid spacing in the x-direction.
- Returns:
E, the electric field (1D numpy array).
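On a periodic grid, Gauss's law dE/dx = ρ/ε0 becomes ik Ê_k = ρ̂_k/ε0 in Fourier space, so E follows from one forward and one inverse FFT. The sketch below assumes SI units with ε0 as the normalization and sets the k = 0 mode to zero. e_from_gauss_fft_sketch is a hypothetical name, and jaxincell's normalization and sign conventions may differ.

```python
import jax.numpy as jnp

EPSILON_0 = 8.85418782e-12  # same value as jaxincell.epsilon_0

def e_from_gauss_fft_sketch(charge_density, dx, eps0=EPSILON_0):
    """Solve dE/dx = rho/eps0 on a periodic 1D grid with FFTs (sketch)."""
    n = charge_density.shape[0]
    k = 2.0 * jnp.pi * jnp.fft.fftfreq(n, d=dx)   # angular wavenumbers
    rho_hat = jnp.fft.fft(charge_density)
    # In Fourier space i k E_hat = rho_hat / eps0. The k = 0 mode (net charge)
    # has no periodic solution, so it is set to zero.
    k_safe = jnp.where(k == 0, 1.0, k)
    e_hat = jnp.where(k == 0, 0.0, rho_hat / (1j * k_safe * eps0))
    return jnp.real(jnp.fft.ifft(e_hat))
```

For ρ = cos(2πx) on a unit box with ε0 = 1, this returns sin(2πx)/(2π), whose derivative is ρ again.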
- jaxincell.E_from_Poisson_1D_FFT(charge_density, dx)
Solve for the electric field E = -d(phi)/dx using FFT, where phi is derived from the 1D Poisson equation.
- Parameters:
charge_density – 1D numpy array, source term (right-hand side of the Poisson equation).
dx – float, grid spacing in the x-direction.
- Returns:
E, the electric field (1D numpy array).
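Going through the potential gives the same field on a periodic grid. Assuming SI units and the spectral derivative d/dx → ik (assumptions, not stated in the docstring), the Poisson route reads:

```latex
\frac{d^{2}\phi}{dx^{2}} = -\frac{\rho}{\varepsilon_0}
\;\Rightarrow\;
-k^{2}\hat\phi_k = -\frac{\hat\rho_k}{\varepsilon_0}
\;\Rightarrow\;
\hat\phi_k = \frac{\hat\rho_k}{\varepsilon_0 k^{2}},
\qquad
\hat E_k = -ik\,\hat\phi_k = \frac{\hat\rho_k}{ik\,\varepsilon_0}
\quad (k \neq 0),
\qquad
\hat E_0 = 0 .
```

The last expression is Gauss's law ik Ê_k = ρ̂_k/ε0 again, so in exact arithmetic the two FFT solvers agree for every k ≠ 0.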
- jaxincell.E_from_Gauss_1D_Cartesian(charge_density, dx)
Solve for the electric field at t=0 (E0) using the charge density distribution and applying Gauss’s law in a 1D system.
- Parameters:
charge_density – 1D numpy array, source term (right-hand side of Gauss’s law)
dx – float, grid spacing in the x-direction
- Returns:
The electric field at each grid point due to the particles, shape (G,).
- Return type:
array
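Without FFTs, Gauss's law can be integrated directly on the grid, E_j ≈ E_0 + (dx/ε0) Σ_{i≤j} ρ_i. The sketch below does this with jnp.cumsum and then subtracts the mean of E, one common choice of integration constant for a periodic box. Whether jaxincell fixes the constant this way is an assumption, and the function name is hypothetical.

```python
import jax.numpy as jnp

EPSILON_0 = 8.85418782e-12  # same value as jaxincell.epsilon_0

def e_from_gauss_cartesian_sketch(charge_density, dx, eps0=EPSILON_0):
    """Integrate dE/dx = rho/eps0 in real space with a cumulative sum (sketch)."""
    E = jnp.cumsum(charge_density) * dx / eps0
    # Assumed gauge choice: remove the mean so E averages to zero (periodic box).
    return E - jnp.mean(E)
```

The finite difference of the result recovers ρ/ε0 exactly (up to rounding), while the agreement with the analytic field is first order in dx.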
- jaxincell.curlE(E_field, B_field, dx, dt, field_BC_left, field_BC_right)
Compute the curl of the electric field, which is related to the time derivative of the magnetic field in Maxwell’s equations (Faraday’s law).
- Parameters:
E_field (array) – Electric field at each grid point, shape (G, 3).
B_field (array) – Magnetic field at each grid point, shape (G, 3).
dx (float) – Grid spacing.
dt (float) – Time step.
field_BC_left (int) – Left boundary condition for fields (0: periodic, 1: reflective, 2: absorbing).
field_BC_right (int) – Right boundary condition for fields.
- Returns:
The curl of the electric field, ∇×E, which drives the magnetic field through Faraday’s law, ∂B/∂t = -∇×E.
- Return type:
array
- jaxincell.curlB(B_field, E_field, dx, dt, field_BC_left, field_BC_right)
Compute the curl of the magnetic field, which is related to the time derivative of the electric field in Maxwell’s equations (Ampère’s law with Maxwell correction).
- Parameters:
B_field (array) – Magnetic field at each grid point, shape (G, 3).
E_field (array) – Electric field at each grid point, shape (G, 3).
dx (float) – Grid spacing.
dt (float) – Time step.
field_BC_left (int) – Left boundary condition for fields.
field_BC_right (int) – Right boundary condition for fields.
- Returns:
The curl of the magnetic field, ∇×B, which drives the electric field through the Ampère-Maxwell law, ∂E/∂t = c²∇×B - j/ε0.
- Return type:
array
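In a 1D code where all fields depend on x only, the curls reduce to x-derivatives of the transverse components. As an aid to reading curlE and curlB, the standard SI forms are:

```latex
\begin{aligned}
\nabla\times\mathbf E &= \Bigl(0,\; -\frac{\partial E_z}{\partial x},\; \frac{\partial E_y}{\partial x}\Bigr),
&\qquad
\nabla\times\mathbf B &= \Bigl(0,\; -\frac{\partial B_z}{\partial x},\; \frac{\partial B_y}{\partial x}\Bigr),\\[4pt]
\frac{\partial \mathbf B}{\partial t} &= -\nabla\times\mathbf E,
&\qquad
\frac{\partial \mathbf E}{\partial t} &= c^{2}\,\nabla\times\mathbf B - \frac{\mathbf j}{\varepsilon_0},
\qquad c^{2} = \frac{1}{\mu_0\varepsilon_0}.
\end{aligned}
```

B_x is left unchanged by the curls, and E_x changes only through the current component j_x.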
- jaxincell.field_update(E_fields, B_fields, dx, dt, j, field_BC_left, field_BC_right)
Update the electric and magnetic fields based on Maxwell’s equations.
- Parameters:
E_fields (array) – Electric field at each grid point, shape (G, 3).
B_fields (array) – Magnetic field at each grid point, shape (G, 3).
dx (float) – Grid spacing.
dt (float) – Time step.
j (array) – Current density at each grid point, shape (G, 3).
field_BC_left (int) – Left boundary condition for fields.
field_BC_right (int) – Right boundary condition for fields.
- Returns:
Updated electric and magnetic fields, each of shape (G, 3).
- Return type:
tuple
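One way to combine these pieces is a staggered (Yee-style) leapfrog update: advance B with a forward-difference curl of E, then advance E with a backward-difference curl of the new B minus the current term. The sketch assumes periodic boundaries, SI units, and this particular ordering and stencil; jaxincell's field_update may order the half steps differently or use the dt and boundary arguments in other ways. All names are hypothetical.

```python
import jax.numpy as jnp

EPSILON_0 = 8.85418782e-12  # same values as jaxincell.epsilon_0 / jaxincell.mu_0
MU_0 = 1.25663706e-06

def _curl_1d_periodic(F, dx, forward):
    """curl F = (0, -dFz/dx, dFy/dx) for fields that vary only along x."""
    if forward:
        dF = (jnp.roll(F, -1, axis=0) - F) / dx   # (F[i+1] - F[i]) / dx
    else:
        dF = (F - jnp.roll(F, 1, axis=0)) / dx    # (F[i] - F[i-1]) / dx
    zero = jnp.zeros_like(F[:, 0])
    return jnp.stack([zero, -dF[:, 2], dF[:, 1]], axis=1)

def field_update_sketch(E, B, dx, dt, j, eps0=EPSILON_0, mu0=MU_0):
    """One leapfrog step of Faraday's and Ampere-Maxwell's laws (periodic sketch)."""
    B_new = B - dt * _curl_1d_periodic(E, dx, forward=True)
    E_new = E + dt * (_curl_1d_periodic(B_new, dx, forward=False) / (mu0 * eps0) - j / eps0)
    return E_new, B_new
```

As with any explicit scheme of this type, the time step must satisfy the Courant condition c·dt < dx.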
- jaxincell.field_update1(E_fields, B_fields, dx, dt, j, field_BC_left, field_BC_right)
- jaxincell.field_update2(E_fields, B_fields, dx, dt, j, field_BC_left, field_BC_right)
- jaxincell.fields_to_particles_grid(x_n, field, dx, grid, grid_start, field_BC_left, field_BC_right)
This function retrieves the electric or magnetic field values at particle positions using a field interpolation scheme. The function first adds ghost cells to the field array to handle boundary conditions, then interpolates the field based on the particle’s position in the grid.
- Parameters:
x_n (array) – The position of particles at time step n, shape (N,).
field (array) – The field values at each grid point, shape (G,).
dx (float) – The spatial grid spacing.
grid (array) – The grid positions where the field is defined, shape (G,).
grid_start (float) – The starting position of the grid (usually the left boundary).
field_BC_left (int) – Field boundary condition at the left edge of the grid (0: periodic, 1: reflective, 2: absorbing).
field_BC_right (int) – Field boundary condition at the right edge of the grid (0: periodic, 1: reflective, 2: absorbing).
- Returns:
The interpolated field values at the particle positions, shape (N,).
- Return type:
array
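To illustrate the interpolation step, here is linear (cloud-in-cell) weighting with periodic index wrapping instead of ghost cells. This is a swapped-in, simpler scheme: the docstring does not say which shape function fields_to_particles_grid uses (the charge deposit elsewhere in the package uses a quadratic one), and the function name is hypothetical.

```python
import jax.numpy as jnp

def interpolate_linear_periodic_sketch(x_n, field, dx, grid_start):
    """Linear (cloud-in-cell) interpolation of a periodic grid field to particles."""
    G = field.shape[0]
    s = (x_n - grid_start) / dx            # particle position in units of cells
    i = jnp.floor(s).astype(jnp.int32)     # grid point to the left of the particle
    w = s - i                              # fractional distance from that point
    left = field[jnp.mod(i, G)]
    right = field[jnp.mod(i + 1, G)]       # periodic wrap at the right edge
    return (1.0 - w) * left + w * right
```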
- jaxincell.fields_to_particles_periodic_CN(x_n, field, dx, grid_start)
Interpolates a field to particle positions using periodic BCs.
- Parameters:
field – The field array, one value per grid point.
grid_start – Physical position of field[0].
- jaxincell.rotation(dt, B, vsub, q_m)
This function implements the rotation step of the Boris algorithm: it rotates the particle velocity vector in the magnetic field over one time step. This step is part of the numerical solution of the Lorentz force equation.
- Parameters:
dt (float) – Time step for the simulation.
B (array) – Magnetic field at the particle’s position, shape (3,).
vsub (array) – The particle’s velocity before the rotation, shape (3,).
q_m (array) – The charge-to-mass ratio of the particle, shape (3,).
- Returns:
The updated velocity after the rotation, shape (3,).
- Return type:
array
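The standard Boris rotation, written with vsub taken as the velocity v⁻ after the first half electric kick (our reading of the parameter name), is:

```latex
\mathbf t = \frac{q}{m}\,\frac{\Delta t}{2}\,\mathbf B,
\qquad
\mathbf s = \frac{2\,\mathbf t}{1 + |\mathbf t|^{2}},
\qquad
\mathbf v' = \mathbf v^{-} + \mathbf v^{-}\times\mathbf t,
\qquad
\mathbf v^{+} = \mathbf v^{-} + \mathbf v'\times\mathbf s .
```

Because v⁺ is obtained by a pure rotation, |v⁺| = |v⁻|; the rotation angle θ satisfies tan(θ/2) = |t|.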
- jaxincell.boris_step(dt, xs_nplushalf, vs_n, q_ms, E_fields_at_x, B_fields_at_x)
This function performs one step of the Boris algorithm for particle motion. The particle velocity is updated using the electric and magnetic fields at its position, and the particle position is updated using the new velocity.
- Parameters:
dt (float) – Time step for the simulation.
xs_nplushalf (array) – The particle positions at the half-time step n+1/2, shape (N, 3).
vs_n (array) – The particle velocities at time step n, shape (N, 3).
q_ms (array) – The charge-to-mass ratio of each particle, shape (N, 1).
E_fields_at_x (array) – The interpolated electric field values at the particle positions, shape (N, 3).
B_fields_at_x (array) – The magnetic field values at the particle positions, shape (N, 3).
- Returns:
- A tuple containing:
xs_nplus3_2 (array): The updated particle positions at time step n+3/2, shape (N, 3).
vs_nplus1 (array): The updated particle velocities at time step n+1, shape (N, 3).
- Return type:
tuple
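The full non-relativistic step (half electric kick, magnetic rotation, second half kick, then a leapfrog position update from n+1/2 to n+3/2) can be sketched for N particles as follows. Array shapes follow the docstring; the function name and the exact ordering inside jaxincell are assumptions.

```python
import jax.numpy as jnp

def boris_step_sketch(dt, xs_nplushalf, vs_n, q_ms, E, B):
    """Non-relativistic Boris push for N particles; q_ms has shape (N, 1)."""
    v_minus = vs_n + 0.5 * dt * q_ms * E                 # first half electric kick
    t = 0.5 * dt * q_ms * B                              # rotation vector
    s = 2.0 * t / (1.0 + jnp.sum(t * t, axis=1, keepdims=True))
    v_prime = v_minus + jnp.cross(v_minus, t)
    v_plus = v_minus + jnp.cross(v_prime, s)             # magnetic rotation
    vs_nplus1 = v_plus + 0.5 * dt * q_ms * E             # second half electric kick
    xs_nplus3_2 = xs_nplushalf + dt * vs_nplus1          # leapfrog position update
    return xs_nplus3_2, vs_nplus1
```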
- jaxincell.boris_step_relativistic(dt, xs_nplushalf, vs_n, q_s, m_s, E_fields_at_x, B_fields_at_x)
Relativistic Boris pusher for N particles.
- Parameters:
dt – Time step
xs_nplushalf – Particle positions at t = n + 1/2, shape (N, 3)
vs_n – Velocities at time t = n, shape (N, 3)
q_s – Charges, shape (N,)
m_s – Masses, shape (N,)
E_fields_at_x – Electric fields at particle positions, shape (N, 3)
B_fields_at_x – Magnetic fields at particle positions, shape (N, 3)
c – Speed of light (default = 1.0 for normalized units); not an argument in the signature shown above.
- Returns:
- A tuple containing:
xs_nplus3_2 (array): Updated positions at t = n + 3/2, shape (N, 3).
vs_nplus1 (array): Updated velocities at t = n + 1, shape (N, 3).
- Return type:
tuple
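The relativistic variant pushes u = γv instead of v and uses the Lorentz factor of the intermediate u in the rotation vector. A minimal sketch in normalized units (c = 1 by default), with hypothetical names; jaxincell may compute γ or handle c differently.

```python
import jax.numpy as jnp

def boris_relativistic_sketch(dt, xs_nplushalf, vs_n, q_s, m_s, E, B, c=1.0):
    """Relativistic Boris push on u = gamma * v (sketch)."""
    q_m = (q_s / m_s)[:, None]                                   # (N, 1)

    def gamma_from_u(u):
        return jnp.sqrt(1.0 + jnp.sum(u ** 2, axis=1, keepdims=True) / c ** 2)

    gamma_n = 1.0 / jnp.sqrt(1.0 - jnp.sum(vs_n ** 2, axis=1, keepdims=True) / c ** 2)
    u_minus = gamma_n * vs_n + 0.5 * dt * q_m * E                # half electric kick
    t = 0.5 * dt * q_m * B / gamma_from_u(u_minus)               # rotation vector
    s = 2.0 * t / (1.0 + jnp.sum(t ** 2, axis=1, keepdims=True))
    u_prime = u_minus + jnp.cross(u_minus, t)
    u_plus = u_minus + jnp.cross(u_prime, s)                     # magnetic rotation
    u_new = u_plus + 0.5 * dt * q_m * E                          # second half kick
    vs_nplus1 = u_new / gamma_from_u(u_new)                      # back to velocity
    xs_nplus3_2 = xs_nplushalf + dt * vs_nplus1
    return xs_nplus3_2, vs_nplus1
```

Since v is recovered as u/sqrt(1 + |u|²/c²), the speed stays below c however large the electric kick is.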
- jaxincell.plot(output, direction: str = 'x', threshold: float = 1e-12, save_mp4: str | None = None, fps: int = 30, dpi: int = 150, show: bool = True, animation_interval: int = 1, save_stride: int = 1, save_dpi: int | None = None, save_crf: int | None = None, save_preset: str | None = None, save_codec: str | None = None)
Production-ready plotting/animation for JAX-in-Cell outputs.
- What you get:
  - Heatmaps (x vs. time) of E, B (nonzero components only), and charge density. Heatmap color limits are fixed over the whole run using a robust percentile, so growth or decay in time is visible (no per-frame renormalization).
  - Instantaneous E(x, t) overlay: for each plotted electric-field component, a line is drawn on top of the heatmap. The line is normalized by a single global robust scale over the whole run, so amplitude growth is visible. Overlay axes have no ticks, to avoid clashes with the colorbars.
  - Distribution functions f(v, t) in the lab frame (no drift centering), shown as line plots in their own subplot(s) (no current-density heatmap). Solid lines show the current frame; dashed lines show the initial state (t = 0), labeled “(initial)”.
  - Phase space (x vs. v) for electrons and ions, for each requested component. A robust velocity range is used per species and per component, so ion dynamics remain visible. Counts are shown with LogNorm (adding 1 internally), so low-density structure is visible.
- Multi-species:
  - Uses the legacy diagnostics() split (velocity_electrons / velocity_ions) if present.
  - Otherwise combines output["species"] by charge sign (q < 0 as electrons, q > 0 as ions).
- jaxincell.initialize_simulation_parameters(user_parameters={})
Initialize the simulation parameters for a particle-in-cell simulation, combining user-provided values with predefined defaults. This function ensures all required parameters are set and automatically calculates derived parameters based on the inputs.
The function uses lambda functions to define derived parameters that depend on other parameters. These lambda functions are evaluated after merging user-provided parameters with the defaults, ensuring derived parameters are consistent with any overrides.
Parameters:
- user_parameters : dict
Dictionary containing user-specified parameters. Any parameter not provided will default to predefined values.
Returns:
- parameters : dict
Dictionary containing all simulation parameters, with user-provided values overriding defaults.
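The lambda-based defaulting pattern described above can be sketched in a few lines: defaults that depend on other parameters are stored as callables and evaluated only after the user's values have been merged in. The keys length, number_grid_points and dx are illustrative assumptions, not jaxincell's actual parameter names, and the function name is hypothetical.

```python
def initialize_parameters_sketch(user_parameters=None):
    """Merge user values over defaults, then evaluate lambda-defined derived values."""
    defaults = {                       # hypothetical keys, for illustration only
        "length": 1.0,
        "number_grid_points": 50,
        "dx": lambda p: p["length"] / p["number_grid_points"],
    }
    parameters = {**defaults, **(user_parameters or {})}
    for key, value in parameters.items():
        if callable(value):            # derived parameter: compute from the merged dict
            parameters[key] = value(parameters)
    return parameters
```

Overriding number_grid_points changes the derived dx, while passing dx explicitly replaces the lambda altogether. This loop handles one level of derived values; a derived value that depends on another derived value would need a second pass.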
- jaxincell.initialize_particles_fields(input_parameters={}, number_grid_points=50, number_pseudoelectrons=500, number_pseudoparticles_species=None, total_steps=350, max_number_of_Picard_iterations_implicit_CN=20, number_of_particle_substeps_implicit_CN=2)
Initialize particles and electromagnetic fields for a Particle-in-Cell simulation.
This function generates particle positions, velocities, charges, masses, and charge-to-mass ratios, as well as the initial electric and magnetic fields. It combines user-provided parameters with default values and calculates derived quantities.
Parameters:
- input_parameters : dict
Dictionary of user-specified simulation parameters. Can include:
  - Physical parameters (e.g., box size, number of particles, thermal velocities).
  - Numerical parameters (e.g., grid resolution, timestep).
  - Boundary conditions and random seed for reproducibility.
Returns:
- parameters : dict
Updated dictionary containing:
  - Particle positions and velocities (electrons and ions).
  - Particle charges, masses, and charge-to-mass ratios.
  - Initial electric and magnetic fields.
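Particle initialization of this kind typically draws positions uniformly in the box and velocities from a Maxwellian. A hedged JAX sketch, assuming a 1D3V layout (positions stored as [x, 0, 0]), a domain [-box_size_x/2, box_size_x/2), and vth as the standard deviation of each velocity component; the function name and arguments are hypothetical, not jaxincell's.

```python
import jax
import jax.numpy as jnp

def init_particles_sketch(key, number_particles, box_size_x, vth):
    """Uniform positions in x and Maxwellian velocities (std dev vth) in 3D."""
    key_x, key_v = jax.random.split(key)
    x = jax.random.uniform(key_x, (number_particles,),
                           minval=-box_size_x / 2, maxval=box_size_x / 2)
    zeros = jnp.zeros_like(x)
    positions = jnp.stack([x, zeros, zeros], axis=1)     # 1D3V layout (assumed)
    velocities = vth * jax.random.normal(key_v, (number_particles, 3))
    return positions, velocities
```

Passing an explicit PRNG key makes the draw reproducible, which matches the role of the random seed mentioned above.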
- jaxincell.simulation(input_parameters={}, number_grid_points=100, number_pseudoelectrons=3000, number_pseudoparticles_species=None, total_steps=1000, field_solver=0, positions=None, velocities=None, time_evolution_algorithm=0, max_number_of_Picard_iterations_implicit_CN=20, number_of_particle_substeps_implicit_CN=2)
Run a plasma physics simulation using a Particle-In-Cell (PIC) method in JAX.
This function simulates the evolution of a plasma system by solving for particle motion (electrons and ions) and self-consistent electromagnetic fields on a grid. It uses the Boris algorithm for particle updates and a leapfrog scheme for field updates.
Parameters:
- input_parameters : dict
User-defined parameters for the simulation. These can include:
  - Physical parameters: box size, number of particles, thermal velocities.
  - Numerical parameters: grid resolution, time step size.
  - Boundary conditions for particles and fields.
  - Random seed for reproducibility.
Returns:
- output : dict
- jaxincell.load_parameters(input_file)
Load parameters from a TOML file.
Parameters:
- input_file : str
Path to the TOML file containing simulation parameters.
Returns:
- parameters : dict
Dictionary containing simulation parameters.
- jaxincell.get_S2_weights_and_indices_periodic_CN(x, dx, grid_start, grid_size)
Calculates weights and indices for the quadratic spline (S2) shape function. Applies periodic wrapping to the indices immediately.
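The quadratic spline (S2, also called triangular-shaped cloud) assigns each particle to its nearest grid point and that point's two neighbours, with weights that sum to one. A sketch under the assumptions that grid point i sits at grid_start + i·dx and that positions are 1D; the name is hypothetical.

```python
import jax.numpy as jnp

def s2_weights_and_indices_sketch(x, dx, grid_start, grid_size):
    """Quadratic-spline weights on the 3 nearest grid points, wrapped periodically."""
    s = (x - grid_start) / dx                 # position in units of cells
    i = jnp.round(s).astype(jnp.int32)        # nearest grid point
    d = s - i                                 # offset in [-0.5, 0.5]
    weights = jnp.stack([0.5 * (0.5 - d) ** 2,    # point i-1
                         0.75 - d ** 2,           # point i
                         0.5 * (0.5 + d) ** 2],   # point i+1
                        axis=-1)
    indices = jnp.mod(i[..., None] + jnp.arange(-1, 2), grid_size)
    return weights, indices
```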
- jaxincell.charge_density_BCs(particle_BC_left, particle_BC_right, position, dx, grid, charge)
Compute the charge contribution to the boundary points based on particle positions and boundary conditions.
- Parameters:
particle_BC_left (int) – Boundary condition for the left edge (0: periodic, 1: reflective, 2: absorbing).
particle_BC_right (int) – Boundary condition for the right edge (0: periodic, 1: reflective, 2: absorbing).
position (float) – Position of the particle.
dx (float) – Grid spacing.
grid (array-like) – Grid points as a 1D array.
charge (float) – Charge of the particle.
- Returns:
Charge contributions to the left and right boundaries.
- Return type:
tuple
- jaxincell.single_particle_charge_density(x, q, dx, grid, particle_BC_left, particle_BC_right)
Computes the charge density contribution of a single particle to the grid using a quadratic particle shape function.
- Parameters:
x (float) – The particle position.
q (float) – The particle charge.
dx (float) – The grid spacing.
grid (array) – The grid points.
particle_BC_left (int) – Left boundary condition type (0: periodic, 1: reflective, 2: absorbing).
particle_BC_right (int) – Right boundary condition type (0: periodic, 1: reflective, 2: absorbing).
- Returns:
The charge density contribution on the grid.
- Return type:
array
- jaxincell.calculate_charge_density(xs_n, qs, dx, grid, particle_BC_left, particle_BC_right, filter_passes=5, filter_alpha=0.5, filter_strides=(1, 2, 4), field_BC_left=0, field_BC_right=0)
Computes the total charge density on the grid by summing contributions from all particles.
- Parameters:
xs_n (array) – Particle positions at the current timestep, shape (N, 1).
qs (array) – Particle charges, shape (N, 1).
dx (float) – The grid spacing.
grid (array) – The grid points.
particle_BC_left (int) – Left particle boundary condition type (0: periodic, 1: reflective, 2: absorbing).
particle_BC_right (int) – Right particle boundary condition type (0: periodic, 1: reflective, 2: absorbing).
filter_passes (int) – Number of digital filter passes to apply (default: 5). Internally capped at 17.
filter_alpha (float) – Filter strength parameter (default: 0.5). Controls the weight of the center point in the 3-point filter.
filter_strides (tuple) – Tuple of stride values for multi-scale filtering (default: (1, 2, 4)).
field_BC_left (int) – Left boundary condition for filtering (0: periodic, 1: reflective, 2: absorbing; default: 0).
field_BC_right (int) – Right boundary condition for filtering (0: periodic, 1: reflective, 2: absorbing; default: 0).
- Returns:
Total charge density on the grid.
- Return type:
array
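Summing contributions from all particles maps naturally onto a scatter-add with .at[].add. The sketch below deposits with quadratic (S2) weights on a periodic grid and divides by dx to obtain a charge density. It assumes 1D positions of shape (N,), omits the boundary corrections and the digital filter that calculate_charge_density applies, and uses a hypothetical name.

```python
import jax.numpy as jnp

def deposit_charge_s2_sketch(xs, qs, dx, grid_start, grid_size):
    """Scatter-add each particle's charge onto 3 grid points with S2 weights (periodic)."""
    s = (xs - grid_start) / dx
    i = jnp.round(s).astype(jnp.int32)                          # nearest grid point
    d = s - i
    w = jnp.stack([0.5 * (0.5 - d) ** 2, 0.75 - d ** 2, 0.5 * (0.5 + d) ** 2], axis=1)
    idx = jnp.mod(i[:, None] + jnp.arange(-1, 2), grid_size)    # shape (N, 3)
    rho = jnp.zeros(grid_size).at[idx.ravel()].add((qs[:, None] * w).ravel())
    return rho / dx                                             # charge per unit length
```

Because the S2 weights sum to one, the deposit conserves charge: sum(rho)·dx equals sum(qs) up to rounding.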
- jaxincell.current_density(xs_nminushalf, xs_n, xs_nplushalf, vs_n, qs, dx, dt, grid, grid_start, particle_BC_left, particle_BC_right, filter_passes=5, filter_alpha=0.5, filter_strides=(1, 2, 4), field_BC_left=0, field_BC_right=0)
Computes the current density j on the grid from particle motion.
- Parameters:
xs_nminushalf (array) – Particle positions at the half timestep before the current one, shape (N, 1).
xs_n (array) – Particle positions at the current timestep, shape (N, 1).
xs_nplushalf (array) – Particle positions at the half timestep after the current one, shape (N, 1).
vs_n (array) – Particle velocities at the current timestep, shape (N, 3).
qs (array) – Particle charges, shape (N, 1).
dx (float) – The grid spacing.
dt (float) – The time step size.
grid (array) – The grid points.
grid_start (float) – The starting position of the grid.
particle_BC_left (int) – Left particle boundary condition type (0: periodic, 1: reflective, 2: absorbing).
particle_BC_right (int) – Right particle boundary condition type (0: periodic, 1: reflective, 2: absorbing).
filter_passes (int) – Number of digital filter passes to apply (default: 5). Internally capped at 17.
filter_alpha (float) – Filter strength parameter (default: 0.5). Controls the weight of the center point in the 3-point filter.
filter_strides (tuple) – Tuple of stride values for multi-scale filtering (default: (1, 2, 4)).
field_BC_left (int) – Left boundary condition for filtering (0: periodic, 1: reflective, 2: absorbing; default: 0).
field_BC_right (int) – Right boundary condition for filtering (0: periodic, 1: reflective, 2: absorbing; default: 0).
- Returns:
Current density on the grid, shape (G, 3), where G is the number of grid points.
- Return type:
array
- jaxincell.current_density_periodic_CN(xs_n, vs_n, qs, dx, grid_start, grid_size)
Deposits the current density J using periodic BCs. Unlike current_density, this function takes no xs_nminushalf or xs_nplushalf arguments: it uses the midpoint approximation (xs_n, vs_n), consistent with the implicit Crank-Nicolson (CN) scheme.
- jaxincell._MAX_FILTER_PASSES = 16
- jaxincell._shift_with_bc_1d(x, shift, bc_left, bc_right)
Shift x by 'shift' cells along axis=0 with boundary conditions:
- bc = 0: periodic
- bc = 1: reflective (clamp to boundary cell)
- bc = 2: absorbing (outside domain -> 0)
Works for arrays with shape (G, …); only axis 0 is shifted.
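A possible implementation computes a source index for every cell and resolves out-of-range indices according to the BC on that side. The sign convention y[j] = x[j - shift] (as in jnp.roll) and the function names are assumptions, not taken from jaxincell.

```python
import jax.numpy as jnp

def _source_index(bc, src, G):
    # periodic (0) wraps around; reflective (1) and absorbing (2) clamp to the edge
    return jnp.where(bc == 0, jnp.mod(src, G), jnp.clip(src, 0, G - 1))

def shift_with_bc_sketch(x, shift, bc_left, bc_right):
    """Return y with y[j] = x[j - shift] along axis 0 (assumed sign convention)."""
    G = x.shape[0]
    src = jnp.arange(G) - shift
    idx = jnp.where(src < 0, _source_index(bc_left, src, G),
                    jnp.where(src > G - 1, _source_index(bc_right, src, G), src))
    y = x[idx]
    # absorbing: cells whose source lies outside the domain are set to zero
    outside = jnp.logical_or(jnp.logical_and(src < 0, bc_left == 2),
                             jnp.logical_and(src > G - 1, bc_right == 2))
    outside = outside.reshape((G,) + (1,) * (x.ndim - 1))
    return jnp.where(outside, 0.0, y)
```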
- jaxincell.binomial_filter_3point(x, alpha=0.5, stride=1, bc_left=0, bc_right=0)
3-point digital filter along axis 0 with BCs:
x^f_j = α x_j + (1-α)/2 [ x_{j-stride} + x_{j+stride} ]
- bc_left / bc_right:
- 0: periodic
- 1: reflective (clamp)
- 2: absorbing (outside -> 0)
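For periodic boundaries the filter reduces to two jnp.roll calls. This sketch covers only bc = 0 and uses a hypothetical name.

```python
import jax.numpy as jnp

def binomial_filter_periodic_sketch(x, alpha=0.5, stride=1):
    """x^f_j = alpha x_j + (1 - alpha)/2 (x_{j-stride} + x_{j+stride}), periodic only."""
    left = jnp.roll(x, stride, axis=0)     # x_{j-stride}
    right = jnp.roll(x, -stride, axis=0)   # x_{j+stride}
    return alpha * x + 0.5 * (1.0 - alpha) * (left + right)
```

A Fourier mode with phase advance θ = k·dx per cell is multiplied by α + (1 - α)cos(θ·stride). With α = 0.5 this vanishes at θ·stride = π, so stride 1 removes the 2-cell mode and stride s removes the 2s-cell mode, while a constant field passes through unchanged.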
- jaxincell._repeat_filter(y, stride, passes, alpha, bc_left=0, bc_right=0)
JAX-safe version of the 3-point digital filter with compensation:
- If passes <= 0: return y unchanged.
- If passes > 0: apply (passes - 1) regular binomial_filter_3point passes with alpha (up to a static maximum of _MAX_FILTER_PASSES), then a final compensation pass with comp_alpha = passes - alpha*(passes - 1).
This is fully jit- and grad-safe even when passes is a traced value.
Note: The number of regular filter passes is internally capped at _MAX_FILTER_PASSES (16). If passes > _MAX_FILTER_PASSES + 1, the function applies only _MAX_FILTER_PASSES regular passes plus one compensation pass, which may not match the expected filtering behavior. For typical use cases (the default is 5 passes), this limit is not reached.
- jaxincell.filter_scalar_field(scalar_field, passes=5, alpha=0.5, strides=(1, 2, 4), bc_left=0, bc_right=0)
Apply a multi-pass 3-point binomial digital filter to a scalar field.
- Parameters:
scalar_field – Input scalar field array to be filtered.
passes – Number of filter passes (default: 5). Note: internally capped at 17 (16 regular passes + 1 compensation pass).
alpha – Filter strength parameter (default: 0.5).
strides – Tuple/list of stride values for filtering (default: (1, 2, 4)).
bc_left – Boundary condition for the left side (default: 0). 0: periodic, 1: reflective, 2: absorbing.
bc_right – Boundary condition for the right side (default: 0). 0: periodic, 1: reflective, 2: absorbing.
- Returns:
Filtered scalar field array.
- jaxincell.filter_vector_field(F, passes=5, alpha=0.5, strides=(1, 2, 4), bc_left=0, bc_right=0)
Apply digital filter along the grid axis (axis=0) for each component. F has shape (G, C), typically (grid_points, 3) for a vector field.
- Parameters:
F – Input vector field array, shape (G, C), typically (grid_points, 3).
passes – Number of filter passes (default: 5). Note: internally capped at 17 (16 regular passes + 1 compensation pass).
alpha – Filter strength parameter (default: 0.5).
strides – Tuple/list of stride values for filtering (default: (1, 2, 4)).
bc_left – Boundary condition for the left side (default: 0). 0: periodic, 1: reflective, 2: absorbing.
bc_right – Boundary condition for the right side (default: 0). 0: periodic, 1: reflective, 2: absorbing.
- Returns:
Filtered vector field array with the same shape as input.