
expand_op


Structs

Struct: Expand

Fields

Methods

compute_shape(mut curr: ArrayShape, args: List[ArrayShape])
Computes the shape resulting from broadcasting one array to another.
Args
  • curr: ArrayShape The ArrayShape to store the result of the computation.

  • args: List[ArrayShape] A list containing the source ArrayShape, the target ArrayShape, and the axes to ignore during broadcasting.

Constraints:

  • The number of dimensions of the source ArrayShape must be less than or equal to the number of dimensions of the target ArrayShape.
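To illustrate the shape rule, the plain-Python sketch below computes an expanded shape under NumPy-style right-aligned broadcasting. It is not Endia's implementation. Several details are assumptions: the function name `expand_shape`, shapes as lists of ints, and the convention that ignored axes are indexed against the target rank (negative values count from the end) and keep the source size.

```python
def expand_shape(src, target, ignore_axes=()):
    """Hypothetical sketch: the shape obtained by expanding `src` to `target`.

    Shapes are right-aligned. Axes in `ignore_axes` (normalized against the
    target rank) keep the source extent instead of being broadcast; this
    convention is an assumption, not Endia's documented behavior.
    """
    src, target = list(src), list(target)
    if len(src) > len(target):
        raise ValueError("source has more dimensions than target")
    if len(ignore_axes) > len(src):
        raise ValueError("more axes to ignore than source dimensions")
    offset = len(target) - len(src)  # number of new leading axes
    ignored = {a + len(target) if a < 0 else a for a in ignore_axes}
    out = []
    for axis, t in enumerate(target):
        s = src[axis - offset] if axis >= offset else 1  # missing axes act as size 1
        if axis in ignored:
            if axis < offset:
                raise ValueError(f"ignored axis {axis} is not present in the source")
            out.append(s)  # keep the source extent unchanged
        elif s == t or s == 1:
            out.append(t)  # equal or size-1 extents take the target size
        else:
            raise ValueError(f"cannot broadcast size {s} to {t} at axis {axis}")
    return out
```

With `ignore_axes=[-2, -1]`, only the leading (batch) dimension is broadcast. A batched matrix multiply needs exactly this kind of expansion, for example `expand_shape([3, 4], [2, 3, 5], ignore_axes=[-2, -1])` gives `[2, 3, 4]`.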
__call__(mut curr: Array, args: List[Array])
Performs the forward pass for the expand operation. It sets the base of the current array to the base of the argument array, so that the result is a view sharing the argument's data, and computes the shape of the current array via its dedicated ArrayShape fwd function.
Args
  • curr: Array The current array to store the result (modified in-place).

  • args: List[Array] A list containing the array on which the expanded view is created.

Note: The result of the shape computation is stored in the ArrayShape object of the curr array.
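A common way to build an expanded view without copying is to reuse the source buffer and give every broadcast axis a stride of 0. The sketch below assumes that layout. `StridedView` and `expand_view` are hypothetical helpers for illustration, and Endia's actual Array internals may differ.

```python
class StridedView:
    """Hypothetical strided view over a shared flat buffer (not Endia's Array)."""

    def __init__(self, base, shape, strides, offset=0):
        self.base = base  # flat buffer shared with other views, never copied
        self.shape = list(shape)
        self.strides = list(strides)
        self.offset = offset

    def __getitem__(self, index):
        # Translate a multi-dimensional index into a position in the flat buffer.
        return self.base[self.offset + sum(i * s for i, s in zip(index, self.strides))]


def expand_view(view, target_shape):
    """Create an expanded view that shares `view.base` (no data is copied).

    Assumption: new leading axes and size-1 axes get stride 0, so every
    index along them reads the same element.
    """
    offset = len(target_shape) - len(view.shape)
    if offset < 0:
        raise ValueError("source has more dimensions than target")
    strides = [0] * offset  # new leading axes
    for size, stride, target in zip(view.shape, view.strides, target_shape[offset:]):
        if size == target:
            strides.append(stride)  # unchanged axis keeps its stride
        elif size == 1:
            strides.append(0)  # stretched axis repeats one element
        else:
            raise ValueError(f"cannot broadcast size {size} to {target}")
    return StridedView(view.base, target_shape, strides, view.offset)
```

Because the expanded view points at the same buffer, a write to the base is visible through every row of the view.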

jvp(primals: List[Array], tangents: List[Array]) -> Array
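Computes the Jacobian-vector product for the expand operation.
Args
  • primals: List[Array] A list containing the primal input array.

  • tangents: List[Array] A list containing the tangent of the input array.

Returns
  • Array - The tangent of the output (the Jacobian-vector product).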
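Expand is linear, so its Jacobian-vector product equals the tangent expanded in the same way as the primal. The sketch below checks this numerically for a 1-D input repeated along a new leading axis. `expand_rows`, `expand_jvp`, and `finite_difference` are hypothetical names, and the data is materialized as nested lists only to keep the example clear.

```python
def expand_rows(x, rows):
    """Expand a 1-D list to shape (rows, len(x)) by repeating it (materialized)."""
    return [list(x) for _ in range(rows)]


def expand_jvp(primal, tangent, rows):
    """JVP of expand_rows at `primal` in direction `tangent`.

    expand is linear, so the result does not depend on `primal`: the
    tangent is simply expanded to the same target shape.
    """
    return expand_rows(tangent, rows)


def finite_difference(x, t, rows, eps=1e-6):
    """Forward-difference approximation (f(x + eps*t) - f(x)) / eps."""
    shifted = expand_rows([xi + eps * ti for xi, ti in zip(x, t)], rows)
    base = expand_rows(x, rows)
    return [[(a - b) / eps for a, b in zip(ra, rb)] for ra, rb in zip(shifted, base)]


x, t = [1.0, 2.0, 3.0], [0.5, -1.0, 2.0]
jvp = expand_jvp(x, t, rows=2)
fd = finite_difference(x, t, rows=2)
```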
vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
Computes the vector-Jacobian product for the expand operation.
Args
  • primals: List[Array] A list containing the primal input array.

  • grad: Array The gradient of some scalar function with respect to the output of the forward pass.

  • out: Array The output of the forward pass (unused in this function).

Returns
  • List[Array] - A list containing the gradient with respect to the input.

Note: The vector-Jacobian product for expand is computed by summing the gradient over the axes that were expanded, which reduces it back to the shape of the input.
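This reduction sums the gradient over every axis that was either added in front or stretched from size 1. The sketch below assumes row-major flat buffers and NumPy-style right alignment. `sum_to_shape` is a hypothetical helper, not part of Endia.

```python
from itertools import product


def sum_to_shape(grad, grad_shape, src_shape):
    """Sum a row-major flat gradient of shape `grad_shape` back to `src_shape`."""
    offset = len(grad_shape) - len(src_shape)  # leading axes that were added
    out_size = 1
    for d in src_shape:
        out_size *= d
    out = [0.0] * out_size
    # product() yields indices in row-major order, matching the flat layout.
    for flat, idx in enumerate(product(*(range(d) for d in grad_shape))):
        # Drop the added leading axes and collapse stretched (size-1) axes to 0.
        src_idx = [0 if s == 1 else i for i, s in zip(idx[offset:], src_shape)]
        pos = 0
        for i, s in zip(src_idx, src_shape):
            pos = pos * s + i  # row-major position in the source buffer
        out[pos] += grad[flat]
    return out


grad = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # gradient of shape (2, 3)
```

For an input of shape (3,) the rows are summed, and for an input of shape (2, 1) the columns are summed.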

fwd(arg0: Array, array_shape: ArrayShape, ignore_axes: List[Int] = List()) -> Array
Expands the input array to the given shape.
Args
  • arg0: Array The input array.

  • array_shape: ArrayShape The target shape.

  • ignore_axes: List[Int] (default: List()) The axes to ignore during expansion.

Returns
  • Array - The expanded array.

Constraints:

  • The number of dimensions of the source ArrayShape must be less than or equal to the number of dimensions of the target ArrayShape.
  • The number of axes to ignore must be less than or equal to the number of dimensions of the source ArrayShape.

Note: When performing an expand operation in eager mode, the function checks if the shape of the input array is equal to the target shape. If they are equal, the function returns the input array as is. This is done to avoid unnecessary computation.
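The sketch below combines the two Constraints checks with the eager-mode shortcut from the note. `Node`, `expand_fwd`, and the `eager` flag are hypothetical stand-ins. The sketch also simplifies one thing: it takes the output shape to be the target shape and does not model how ignored axes change it.

```python
class Node:
    """Hypothetical stand-in for an array: a shape and an optional parent."""

    def __init__(self, shape, parent=None):
        self.shape = list(shape)
        self.parent = parent  # the input this node was expanded from, if any


def expand_fwd(arg0, target_shape, ignore_axes=(), eager=True):
    # Constraint 1: the source may not have more dimensions than the target.
    if len(arg0.shape) > len(target_shape):
        raise ValueError("source has more dimensions than target")
    # Constraint 2: no more axes to ignore than the source has dimensions.
    if len(ignore_axes) > len(arg0.shape):
        raise ValueError("too many axes to ignore")
    # Eager-mode shortcut: expanding to the same shape is a no-op,
    # so the input itself is returned instead of a new view.
    if eager and arg0.shape == list(target_shape):
        return arg0
    # Simplification: output shape is the target shape (ignore_axes not modeled).
    return Node(target_shape, parent=arg0)
```

Returning the input object itself, rather than an equivalent copy, means no extra node or view has to be created when the expand is redundant.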

Functions

expand

expand(arg0: Array, shape: ArrayShape, ignore_axes: List[Int] = List()) -> Array
Expands the input array to the given shape.
Args
  • arg0: Array The input array.

  • shape: ArrayShape The target shape.

  • ignore_axes: List[Int] (default: List()) The axes to ignore during expansion.

Returns
  • Array - The expanded array.

Note: This function is a wrapper around Expand.fwd, to which it passes the target shape and the axes to ignore.

expand_as

expand_as(arg0: Array, arg1: Array) -> Array
Expands the input array to the shape of the target array.
Args
  • arg0: Array The input array.

  • arg1: Array The target array.

Returns
  • Array - A view on the input array with the shape of the target array.

Note: This function is a wrapper around the expand function with the target shape being the shape of the target array.

broadcast_to

broadcast_to(arg0: Array, shape: List[Int]) -> Array
Broadcasts the input array to the given shape.
Args
  • arg0: Array The input array.

  • shape: List[Int] The target shape.

Returns
  • Array - A view on the input array with the target shape.

Note: This function is a wrapper around the expand function, with the target shape given by the shape list.
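expand_as and broadcast_to differ from expand only in where the target shape comes from. The sketch below mirrors that delegation in plain Python, so it is not Endia's code. The `Arr` class is hypothetical, and the stand-in `expand` only validates ranks and records the shape.

```python
class Arr:
    """Hypothetical minimal array: only the shape matters for this sketch."""

    def __init__(self, shape):
        self.shape = list(shape)


def expand(arg0, shape, ignore_axes=()):
    # Stand-in for the core expansion; it only checks ranks and records the
    # target shape. ignore_axes is accepted but unused in this sketch.
    if len(arg0.shape) > len(shape):
        raise ValueError("source has more dimensions than target")
    return Arr(shape)


def expand_as(arg0, arg1):
    # The target shape is read from the second array.
    return expand(arg0, arg1.shape)


def broadcast_to(arg0, shape):
    # The target shape is given directly as a list of ints.
    return expand(arg0, list(shape))
```

A typical use is stretching a bias vector of shape (3,) to match activations of shape (8, 3) with `expand_as(bias, act)`.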
