unsqueeze_op
Structs
Struct: Unsqueeze
Fields
Methods
compute_shape(mut curr: ArrayShape, args: List[ArrayShape])
Computes the shape of an array after unsqueezing, which inserts dimensions of size 1 at the specified axes.
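The shape rule can be illustrated with a minimal Python sketch. It is not this module's actual implementation. The function name `unsqueeze_shape` is hypothetical, and the sketch assumes NumPy's `expand_dims` convention: each axis is a position in the output shape, and negative axes count from the end of the output.

```python
from typing import List, Sequence


def unsqueeze_shape(shape: Sequence[int], axes: Sequence[int]) -> List[int]:
    """Return the shape after inserting size-1 dimensions at `axes`.

    Assumption (hypothetical helper): axes index into the *output* shape,
    as in NumPy's expand_dims; negative axes count from the output's end.
    """
    out_rank = len(shape) + len(axes)
    # Normalize negative axes against the output rank.
    norm = sorted(a + out_rank if a < 0 else a for a in axes)
    if len(set(norm)) != len(norm) or any(a < 0 or a >= out_rank for a in norm):
        raise ValueError(f"invalid axes {list(axes)} for output rank {out_rank}")
    new_axes = set(norm)
    dims = iter(shape)
    result: List[int] = []
    for i in range(out_rank):
        # Positions listed in `axes` get a new size-1 dimension;
        # every other position takes the next original dimension in order.
        result.append(1 if i in new_axes else next(dims))
    return result
```

For example, `unsqueeze_shape([3, 4], [0, 2])` returns `[1, 3, 1, 4]`: the original dimensions keep their order, and the new size-1 dimensions occupy positions 0 and 2.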
__call__(mut curr: Array, args: List[Array])
Performs the forward pass for the unsqueeze operation. It sets the base of the current array to the base of the argument, so the result shares the argument's underlying data, and computes the shape of the current array via its dedicated ArrayShape fwd function.
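The idea of sharing a base can be shown with a toy Python model. `Buffer`, `ToyArray`, and `unsqueeze_view` are hypothetical names, not this module's types, and the sketch assumes axes are non-negative positions in the output shape. The result gets a new shape but reuses the same buffer object, so no data is copied.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Buffer:
    """Hypothetical stand-in for an array's underlying storage (its "base")."""
    data: List[float]


@dataclass
class ToyArray:
    """Hypothetical array: a shared buffer plus a shape describing how to read it."""
    base: Buffer
    shape: List[int]


def unsqueeze_view(arg: ToyArray, axes: List[int]) -> ToyArray:
    # Assumption: axes are non-negative positions in the output shape.
    # Inserting in ascending order places each size-1 dim at its final index.
    shape = list(arg.shape)
    for a in sorted(axes):
        shape.insert(a, 1)
    # Reuse the argument's base: only the shape metadata is new.
    return ToyArray(base=arg.base, shape=shape)
```

Because `y = unsqueeze_view(x, [0])` satisfies `y.base is x.base`, a write through `x.base.data` is also visible from `y`.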
jvp(primals: List[Array], tangents: List[Array]) -> Array
vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
Computes the vector-Jacobian product for the unsqueeze operation.
Args
- primals (List[Array]): A list containing the primal input array.
- grad (Array): The gradient of some scalar function with respect to the output.
- out (Array): The output of the forward pass (unused in this function).
Returns
- List[Array]: A list containing the gradient with respect to the input.
Note: The vector-Jacobian product for unsqueeze is computed by squeezing the gradient along the axes that the forward pass inserted, which restores the input's shape.
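Since unsqueeze only inserts size-1 axes, it preserves the number and order of elements, so its Jacobian acts as the identity on the flattened data and the cotangent only needs its shape restored. The sketch below models this in Python with the hypothetical helpers `squeeze_shape` and `unsqueeze_vjp`. It represents an array as flat data plus a shape and assumes the forward axes are non-negative positions in the output shape.

```python
from typing import List, Sequence, Tuple


def squeeze_shape(shape: Sequence[int], axes: Sequence[int]) -> List[int]:
    """Drop the size-1 dimensions at `axes` (non-negative indices into `shape`)."""
    drop = set(axes)
    for a in drop:
        if shape[a] != 1:
            raise ValueError(f"axis {a} has size {shape[a]}, expected 1")
    return [d for i, d in enumerate(shape) if i not in drop]


def unsqueeze_vjp(grad_data: List[float], grad_shape: Sequence[int],
                  axes: Sequence[int]) -> Tuple[List[float], List[int]]:
    # Hypothetical helper. Assumption: `axes` are the output positions the
    # forward pass inserted. The flat data is unchanged; squeezing those
    # axes out of the gradient's shape yields the input's shape.
    return list(grad_data), squeeze_shape(grad_shape, axes)
```

For an input of shape (2, 3) unsqueezed at axis 1, the gradient arrives with shape (2, 1, 3) and leaves with shape (2, 3) and the same six values.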
fwd(arg0: Array, axis: ArrayShape) -> Array
Unsqueezes the input array by adding axes of length 1.
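A Python sketch of the same operation on nested lists makes the length-1 axes concrete. `unsqueeze_nested` is a hypothetical helper, not this module's API, and it assumes non-negative axes that index the output shape.

```python
from typing import Any, Sequence


def unsqueeze_nested(x: Any, axes: Sequence[int], depth: int = 0) -> Any:
    """Insert length-1 axes into a nested-list array (hypothetical helper).

    Assumption: `axes` are non-negative positions in the output shape,
    following NumPy's expand_dims convention.
    """
    if depth in axes:
        # A new axis at this depth: wrap the remaining structure in a 1-element list.
        return [unsqueeze_nested(x, axes, depth + 1)]
    if isinstance(x, list):
        # An existing axis: recurse into each element at the next depth.
        return [unsqueeze_nested(e, axes, depth + 1) for e in x]
    # A scalar leaf with no further axes to insert.
    return x
```

For instance, `unsqueeze_nested([[1, 2], [3, 4]], [1])` gives `[[[1, 2]], [[3, 4]]]`, an array of shape (2, 1, 2).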
Functions
unsqueeze
unsqueeze(arg0: Array, axis: ArrayShape) -> Array