view_op
Structs
Struct: Reshape
Fields
Methods
compute_shape(mut curr: ArrayShape, args: List[ArrayShape])
Computes the shape of an array after reshaping.
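As a rough illustration of what a reshape shape computation involves, the sketch below checks that the element count is unchanged and derives row-major strides for the target shape. It is written in Python rather than Mojo. `compute_reshape_shape` is a hypothetical helper and is not part of this module. It also assumes a contiguous input, which may not hold for every view.

```python
from math import prod

def compute_reshape_shape(in_shape, target):
    """Hypothetical helper: validate a reshape and return (shape, strides).

    Assumes a contiguous, row-major input, so the new strides can be
    derived from the target shape alone.
    """
    if prod(in_shape) != prod(target):
        raise ValueError(
            f"cannot reshape {list(in_shape)} ({prod(in_shape)} elements) "
            f"into {list(target)} ({prod(target)} elements)"
        )
    strides = [0] * len(target)
    step = 1
    # Walk dimensions right to left: the last axis has stride 1.
    for i in range(len(target) - 1, -1, -1):
        strides[i] = step
        step *= target[i]
    return list(target), strides
```

For example, `compute_reshape_shape([2, 3, 4], [6, 4])` yields the shape `[6, 4]` with strides `[4, 1]`. A mismatched element count raises `ValueError`.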
__call__(mut curr: Array, args: List[Array])
Performs the forward pass for the reshape operation. It sets the base of the current array to the base of the argument, so that both share the same underlying data, and computes the shape of the current array via its dedicated ArrayShape fwd function.
Args
- curr (Array): The current array that stores the result (modified in place).
- args (List[Array]): A list containing the array on which the reshape view is created.
Note: The result of the shape computation is stored in the ArrayShape object of curr.
Constraints:
- The number of elements in the input array must be equal to the number of elements in the target shape.
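The forward pass above creates a view rather than a copy: the output reuses the argument's base and only the shape metadata differs. The toy model below is a Python sketch of that idea, not the Endia implementation. `ToyArray` and `reshape_view` are hypothetical names, and the flat Python list stands in for the shared base buffer.

```python
from math import prod

class ToyArray:
    """Hypothetical stand-in for an array: a shared flat buffer plus a shape."""

    def __init__(self, base, shape):
        self.base = base          # flat storage, shared between views
        self.shape = list(shape)

    def get(self, *idx):
        # Row-major offset into the shared buffer.
        offset = 0
        for i, dim in zip(idx, self.shape):
            offset = offset * dim + i
        return self.base[offset]

def reshape_view(arg, shape):
    # Enforce the constraint: the number of elements must not change.
    if prod(arg.shape) != prod(shape):
        raise ValueError("number of elements must be unchanged by reshape")
    # The output reuses the argument's base; no data is copied.
    return ToyArray(arg.base, shape)
```

Because both arrays point at the same buffer, writing `a.base[5] = 50` after `b = reshape_view(a, [3, 2])` is visible through `b.get(2, 1)`.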
jvp(primals: List[Array], tangents: List[Array]) -> Array
vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
Computes the vector-Jacobian product for the reshape operation.
Args
- primals (List[Array]): A list containing the primal input array.
- grad (Array): The gradient of the output with respect to some scalar function.
- out (Array): The output of the forward pass.
Returns
- List[Array]: A list containing the gradient with respect to the input.
Note: The vector-Jacobian product for reshape is computed by calling the reshape operation on grad with the shape of the primal input.
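Reshape is linear and preserves the row-major order of elements, so its vector-Jacobian product simply reshapes the incoming gradient back to the input's shape. The Python sketch below illustrates this on nested lists. `flatten_nested`, `unflatten`, `reshape` and `reshape_vjp` are hypothetical helpers, not this module's API. `reshape_vjp` returns a one-element list to mirror the `List[Array]` return type.

```python
from math import prod

def flatten_nested(x):
    # Row-major traversal of a nested list.
    if not isinstance(x, list):
        return [x]
    out = []
    for item in x:
        out.extend(flatten_nested(item))
    return out

def unflatten(flat, shape):
    # Rebuild a nested list of the given shape from row-major data.
    if not shape:
        return flat[0]
    step = prod(shape[1:])
    return [unflatten(flat[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]

def reshape(x, shape):
    return unflatten(flatten_nested(x), shape)

def reshape_vjp(primal_shape, grad):
    # Reshape keeps the element order, so the output cotangent maps back
    # to the input by reshaping it to the primal's shape.
    return [reshape(grad, primal_shape)]
```

One way to sanity-check this is the adjoint identity: for any `v`, the dot product of `grad` with `reshape(v, out_shape)` equals the dot product of the VJP result with `v`.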
fwd(arg0: Array, shape: List[Int]) -> Array
Creates a view of the input array with the given shape.
Functions
reshape
reshape(arg0: Array, shape: List[Int]) -> Array
Creates a view of the input array with the given shape.
view
view(arg0: Array, shape: List[Int]) -> Array
Creates a view of the input array with the given shape.
flatten
flatten(arg0: Array) -> Array
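reshape and view share the same signature and description, while flatten takes no target shape. A common convention is that flatten behaves like a reshape to a single dimension holding every element. The shape-level Python sketch below assumes that convention; it is not confirmed by this page, and `reshape_shape` and `flatten_shape` are hypothetical helpers that track only shapes.

```python
from math import prod

def reshape_shape(shape, new_shape):
    # Hypothetical shape-level model of reshape/view: only metadata changes.
    if prod(shape) != prod(new_shape):
        raise ValueError("element count mismatch")
    return list(new_shape)

def flatten_shape(shape):
    # Assumption: flatten is a reshape to one dimension of size numel.
    return reshape_shape(shape, [prod(shape)])
```

Under this assumption, an array of shape `[2, 3, 4]` flattens to shape `[24]`.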