jaxincell._filters

Attributes

_MAX_FILTER_PASSES

Functions

_shift_with_bc_1d(x, shift, bc_left, bc_right)

Shift x by shift cells along axis 0, applying the given boundary conditions.

binomial_filter_3point(x[, alpha, stride, bc_left, ...])

3-point digital filter along axis 0 with configurable boundary conditions.

_repeat_filter(y, stride, passes, alpha[, bc_left, ...])

JAX-safe repeated 3-point digital filter with a final compensation pass.

filter_scalar_field(scalar_field[, passes, alpha, ...])

Apply a multi-pass 3-point binomial digital filter to a scalar field.

filter_vector_field(F[, passes, alpha, strides, ...])

Apply the digital filter along the grid axis (axis=0) to each component.

Module Contents

jaxincell._filters._MAX_FILTER_PASSES = 16
jaxincell._filters._shift_with_bc_1d(x, shift, bc_left, bc_right)

Shift x by ‘shift’ cells along axis=0 with boundary conditions:

  • bc = 0: periodic

  • bc = 1: reflective (clamp to boundary cell)

  • bc = 2: absorbing (outside domain -> 0)

Works for arrays with shape (G, …) – only axis 0 is shifted.
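A minimal JAX sketch of such a shift is shown below. It assumes the convention out[j] = x[j - shift] (the same sign convention as jnp.roll) and that cells pulled from outside the domain use bc_left on the low side and bc_right on the high side. The name shift_with_bc_1d is hypothetical; this is not the library's implementation.

```python
import jax.numpy as jnp


def shift_with_bc_1d(x, shift, bc_left=0, bc_right=0):
    """Hypothetical sketch: out[j] = x[j - shift] along axis 0 (assumed convention).

    bc codes: 0 periodic, 1 reflective (clamp to boundary cell), 2 absorbing (-> 0).
    """
    G = x.shape[0]
    src = jnp.arange(G) - shift            # source cell for each output cell
    periodic = jnp.mod(src, G)             # wrap around the domain
    clamped = jnp.clip(src, 0, G - 1)      # clamp to the nearest boundary cell
    below, above = src < 0, src > G - 1    # which sources fall outside the grid
    # Pick the BC code that applies to each cell (0 for interior cells).
    bc_side = jnp.where(below, bc_left, jnp.where(above, bc_right, 0))
    idx = jnp.where(bc_side == 0, periodic, clamped)
    out = x[idx]
    # Absorbing: cells sourced from outside the domain are set to zero.
    absorb = (below | above) & (bc_side == 2)
    mask = absorb.reshape((G,) + (1,) * (x.ndim - 1))  # broadcast over trailing axes
    return jnp.where(mask, jnp.zeros_like(out), out)
```

Because every branch is computed with jnp.where rather than Python if statements, the BC codes could also be traced values under jax.jit.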

jaxincell._filters.binomial_filter_3point(x, alpha=0.5, stride=1, bc_left=0, bc_right=0)

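3-point digital filter along axis 0 with boundary conditions:

x^f_j = α x_j + (1-α)/2 [ x_{j-stride} + x_{j+stride} ]

With the default α = 0.5 the weights are (1/4, 1/2, 1/4), the binomial (1, 2, 1)/4 kernel; α = 1 leaves x unchanged.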

bc_left / bc_right:

  • 0: periodic

  • 1: reflective (clamp)

  • 2: absorbing (outside -> 0)
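The filter equation can be sketched with ghost cells: pad axis 0 with stride cells on each side according to the BC codes, then read the j - stride and j + stride neighbours as slices. This sketch assumes the BC codes and stride are static Python ints with stride <= G; binomial_3point_sketch and _ghost_cells are hypothetical names, not the library's API.

```python
import jax.numpy as jnp


def _ghost_cells(x, s, bc, side):
    """Hypothetical helper: s ghost cells for one end of axis 0 (bc is a static int)."""
    if bc == 0:  # periodic: copy cells from the opposite end
        return x[-s:] if side == "left" else x[:s]
    edge = x[:1] if side == "left" else x[-1:]
    if bc == 1:  # reflective (clamp): repeat the boundary cell
        return jnp.repeat(edge, s, axis=0)
    return jnp.zeros((s,) + x.shape[1:], x.dtype)  # absorbing: outside -> 0


def binomial_3point_sketch(x, alpha=0.5, stride=1, bc_left=0, bc_right=0):
    """x^f_j = alpha x_j + (1 - alpha)/2 [x_{j-stride} + x_{j+stride}] (sketch)."""
    s, G = stride, x.shape[0]
    padded = jnp.concatenate(
        [_ghost_cells(x, s, bc_left, "left"), x, _ghost_cells(x, s, bc_right, "right")],
        axis=0,
    )
    left = padded[:G]        # padded[j] holds x_{j - stride}
    right = padded[2 * s:]   # padded[j + 2s] holds x_{j + stride}
    return alpha * x + 0.5 * (1.0 - alpha) * (left + right)
```

Two properties follow directly from the equation and make convenient checks: with periodic BCs the filter conserves the sum of x (the weights add up to 1 and every cell is counted once as a neighbour on each side), and with α = 0.5 and stride 1 the grid-scale mode (+1, -1, +1, ...) is removed entirely.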

jaxincell._filters._repeat_filter(y, stride, passes, alpha, bc_left=0, bc_right=0)

JAX-safe version of the 3-point digital filter with compensation:

  • If passes <= 0: return y unchanged.

  • If passes > 0:

      ◦ apply (passes - 1) regular binomial_filter_3point passes with alpha (up to a static maximum of _MAX_FILTER_PASSES),

      ◦ then a final compensation pass with comp_alpha = passes - alpha*(passes - 1).

This is fully jit- and grad-safe even when passes is a traced value.
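One way to see where comp_alpha comes from, assuming periodic boundaries so that a pass with weight a and stride s multiplies a Fourier mode of normalized wavenumber θ = kΔx by T_a(θ): expanding to second order, the compensation pass is chosen to cancel the leading θ² damping of the regular passes, so long wavelengths pass almost unchanged. Here n = passes and a_c = comp_alpha.

```latex
T_a(\theta) = a + (1-a)\cos(s\theta) \approx 1 - (1-a)\,\frac{(s\theta)^2}{2}

T_\alpha(\theta)^{\,n-1}\, T_{a_c}(\theta) \approx 1 - \big[(n-1)(1-\alpha) + (1-a_c)\big]\,\frac{(s\theta)^2}{2}

(n-1)(1-\alpha) + (1-a_c) = 0 \;\Longrightarrow\; a_c = n - \alpha\,(n-1)
```

For the defaults n = 5 and α = 0.5 this gives a_c = 3; for n = 1 it gives a_c = 1, so a single pass leaves the input unchanged.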

Note: The number of regular filter passes is capped internally at _MAX_FILTER_PASSES (16). If passes > _MAX_FILTER_PASSES + 1, only _MAX_FILTER_PASSES regular passes plus one compensation pass are applied, so the result generally differs from that of the requested number of passes. The default of 5 passes is well below this limit.
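The pattern described above (a static loop bound, a traced pass count, then a compensation pass) can be sketched as follows. To keep it self-contained, this sketch uses only periodic boundaries via jnp.roll; MAX_PASSES, _periodic_3pt and repeat_filter_sketch are hypothetical stand-ins, not the library's code.

```python
import jax
import jax.numpy as jnp

MAX_PASSES = 16  # hypothetical stand-in for _MAX_FILTER_PASSES


def _periodic_3pt(y, a, stride):
    """One 3-point pass with weight a, periodic along axis 0 (illustrative only)."""
    neighbours = jnp.roll(y, stride, axis=0) + jnp.roll(y, -stride, axis=0)
    return a * y + 0.5 * (1.0 - a) * neighbours


def repeat_filter_sketch(y, stride, passes, alpha):
    """Sketch: (passes - 1) regular passes, then one compensation pass."""
    passes = jnp.asarray(passes)
    n_regular = jnp.clip(passes - 1, 0, MAX_PASSES)

    def body(i, z):
        # The loop length is static; passes beyond n_regular are computed but discarded.
        return jnp.where(i < n_regular, _periodic_3pt(z, alpha, stride), z)

    z = jax.lax.fori_loop(0, MAX_PASSES, body, y)
    comp_alpha = passes - alpha * (passes - 1)
    z = _periodic_3pt(z, comp_alpha, stride)
    # passes <= 0 means "no filtering": return the input unchanged.
    return jnp.where(passes > 0, z, y)
```

Usage: jax.jit(lambda y, p: repeat_filter_sketch(y, 1, p, 0.5)) compiles once and accepts any integer p, because the loop bound is the static MAX_PASSES rather than p itself. The trade-off is that all MAX_PASSES passes are always evaluated, and masked out when not needed.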

jaxincell._filters.filter_scalar_field(scalar_field, passes=5, alpha=0.5, strides=(1, 2, 4), bc_left=0, bc_right=0)

Apply a multi-pass 3-point binomial digital filter to a scalar field.

Parameters:
  • scalar_field – Input scalar field array to be filtered.

  • passes – Number of filter passes (default: 5). Note: internally capped at 17 (16 regular passes + 1 compensation pass).

  • alpha – Filter strength parameter (default: 0.5).

  • strides – Tuple/list of stride values for filtering (default: (1, 2, 4)).

  • bc_left – Boundary condition for the left side (default: 0). 0: periodic, 1: reflective, 2: absorbing.

  • bc_right – Boundary condition for the right side (default: 0). 0: periodic, 1: reflective, 2: absorbing.

Returns:

Filtered scalar field array.
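The default strides (1, 2, 4) can be understood through the frequency response. A single 3-point pass with stride s multiplies a mode of normalized wavenumber θ = kΔx by α + (1-α) cos(sθ). With α = 0.5, that factor vanishes where sθ = π, that is, at wavelengths of 2, 4 and 8 cells for strides 1, 2 and 4. The sketch below assumes filter_scalar_field applies one compensated multi-pass filter per stride in sequence with periodic boundaries. That composition is an assumption, not something this page states, and stride_response and combined_response are hypothetical helpers.

```python
import jax.numpy as jnp


def stride_response(theta, alpha, stride, passes):
    """Response of (passes - 1) regular passes plus one compensation pass (passes >= 1)."""
    def t(a):
        return a + (1.0 - a) * jnp.cos(stride * theta)

    comp_alpha = passes - alpha * (passes - 1)  # compensation weight as documented
    return t(alpha) ** (passes - 1) * t(comp_alpha)


def combined_response(theta, passes=5, alpha=0.5, strides=(1, 2, 4)):
    # Assumption: one compensated filter per stride, applied in sequence, so the
    # per-stride responses multiply (valid for linear, periodic filters).
    out = jnp.ones_like(theta)
    for s in strides:
        out = out * stride_response(theta, alpha, s, passes)
    return out


theta = jnp.linspace(0.0, jnp.pi, 9)  # theta = k * dx from 0 up to the grid Nyquist limit
response = combined_response(theta)
```

Under these assumptions the response is 1 at θ = 0 and zero at θ = π/4, π/2 and π.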

jaxincell._filters.filter_vector_field(F, passes=5, alpha=0.5, strides=(1, 2, 4), bc_left=0, bc_right=0)

Apply digital filter along the grid axis (axis=0) for each component. F has shape (G, C), typically (grid_points, 3) for a vector field.

Parameters:
  • F – Input vector field array, shape (G, C), typically (grid_points, 3).

  • passes – Number of filter passes (default: 5). Note: internally capped at 17 (16 regular passes + 1 compensation pass).

  • alpha – Filter strength parameter (default: 0.5).

  • strides – Tuple/list of stride values for filtering (default: (1, 2, 4)).

  • bc_left – Boundary condition for the left side (default: 0). 0: periodic, 1: reflective, 2: absorbing.

  • bc_right – Boundary condition for the right side (default: 0). 0: periodic, 1: reflective, 2: absorbing.

Returns:

Filtered vector field array with the same shape as input.
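Filtering "each component along axis 0" means that the C columns of a (G, C) array are filtered independently. The sketch below shows this with jax.vmap over the component axis and checks it against filtering the whole array along axis 0 at once. filter_1d and filter_vector_sketch are hypothetical names, and only periodic boundaries with stride 1 are shown.

```python
import jax
import jax.numpy as jnp


def filter_1d(f, alpha=0.5, stride=1):
    """Periodic 3-point filter of a 1-D grid array (illustrative stand-in)."""
    return alpha * f + 0.5 * (1.0 - alpha) * (jnp.roll(f, stride) + jnp.roll(f, -stride))


# Map the 1-D filter over the component axis (axis 1); the grid axis stays axis 0.
filter_components = jax.vmap(filter_1d, in_axes=1, out_axes=1)


def filter_vector_sketch(F):
    """F has shape (G, C); each of the C components is filtered along the grid axis."""
    return filter_components(F)


F = jax.random.normal(jax.random.PRNGKey(0), (16, 3))  # e.g. (grid_points, 3)
F_filtered = filter_vector_sketch(F)
```

Because the filter acts only along axis 0, the same result can be obtained without vmap by rolling along axis=0 directly. The vmap form simply makes the per-component structure explicit.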