
utils


Structs

Functions

reverse_bits_simd

reverse_bits_simd(x: SIMD[uint32, nelts[::DType]()]) -> SIMD[uint32, nelts[::DType]()]
Reverse the bits of each 32-bit unsigned integer in a SIMD vector.
Args
  • x: SIMD[uint32, nelts[::DType]()]
Returns
  • SIMD[uint32, nelts[::DType]()]
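
The per-lane operation can be illustrated with a short Python sketch. This is not the module's implementation: the helper names `reverse_bits32` and `reverse_bits_lanes` are hypothetical, and a plain list stands in for the SIMD vector. It uses the common mask-and-swap technique, exchanging adjacent bits, then pairs, nibbles, bytes, and finally the two 16-bit halves.

```python
MASK32 = 0xFFFFFFFF  # keep intermediate values inside 32 bits


def reverse_bits32(x: int) -> int:
    """Reverse the bit order of one 32-bit unsigned integer (hypothetical helper)."""
    x &= MASK32
    x = ((x >> 1) & 0x55555555) | ((x & 0x55555555) << 1)   # swap adjacent bits
    x = ((x >> 2) & 0x33333333) | ((x & 0x33333333) << 2)   # swap bit pairs
    x = ((x >> 4) & 0x0F0F0F0F) | ((x & 0x0F0F0F0F) << 4)   # swap nibbles
    x = ((x >> 8) & 0x00FF00FF) | ((x & 0x00FF00FF) << 8)   # swap bytes
    x = (x >> 16) | (x << 16)                               # swap 16-bit halves
    return x & MASK32


def reverse_bits_lanes(lanes):
    """Apply the reversal to every lane; a list stands in for a SIMD vector."""
    return [reverse_bits32(v) for v in lanes]
```

Reversing twice returns the original value, which makes a convenient sanity check for any implementation.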

bit_reversal

bit_reversal(n: Int, reordered_arr_data: UnsafePointer[SIMD[uint32, 1]])
Generate a bit reversal permutation for integers from 0 to n-1. Works for any positive integer n.
Args
  • n: Int

  • reordered_arr_data: UnsafePointer[SIMD[uint32, 1]]
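
As a rough illustration of what a bit-reversal permutation looks like, here is a minimal Python sketch. It is an assumption-laden stand-in, not the module's code: the name `bit_reversal_permutation` is hypothetical, it returns a list instead of writing through a pointer, and it only handles power-of-two `n`. How the documented function treats other values of `n` is not specified on this page.

```python
def bit_reversal_permutation(n: int) -> list[int]:
    """Return the bit-reversal permutation of 0..n-1 (sketch; power-of-two n only)."""
    if n <= 0 or n & (n - 1):
        raise ValueError("this sketch only handles powers of two")
    bits = n.bit_length() - 1  # number of bits needed to index n items
    perm = [0] * n
    for i in range(n):
        reversed_index = 0
        value = i
        for _ in range(bits):
            # Shift the lowest bit of value into the low end of reversed_index.
            reversed_index = (reversed_index << 1) | (value & 1)
            value >>= 1
        perm[i] = reversed_index
    return perm
```

For `n = 8` each index is treated as a 3-bit number, so index 1 (`001`) maps to 4 (`100`) and index 3 (`011`) maps to 6 (`110`).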

copy_complex_and_cast

copy_complex_and_cast[dst_type: DType, src_type: DType](dst: UnsafePointer[SIMD[dst_type, 1]], src: UnsafePointer[SIMD[src_type, 1]], size: Int, conjugate_and_divide: Bool = False, divisor: SIMD[dst_type, 1] = SIMD(1))
Copy complex data from one buffer to another and cast the data to a different type. Optionally conjugate and divide by a scalar (useful for the inverse FFT).
Args
  • dst: UnsafePointer[SIMD[dst_type, 1]]

  • src: UnsafePointer[SIMD[src_type, 1]]

  • size: Int

  • conjugate_and_divide: Bool (default: False)

  • divisor: SIMD[dst_type, 1] (default: SIMD(1))
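
The conjugate-and-divide option matters because an inverse FFT can be computed with a forward FFT: conjugate the input, run the forward transform, then conjugate the result and divide by the length. The Python sketch below illustrates the copy step only. It rests on assumptions that this page does not confirm: the name `copy_complex_and_cast_sketch` is hypothetical, complex values are assumed to be stored interleaved as `[re0, im0, re1, im1, ...]`, `size` is assumed to count complex elements, and a `cast` callable stands in for the `dst_type` parameter.

```python
def copy_complex_and_cast_sketch(src, size, conjugate_and_divide=False,
                                 divisor=1.0, cast=float):
    """Copy `size` interleaved complex values, casting each component.

    Assumes layout [re0, im0, re1, im1, ...]; this layout is an assumption,
    not something the documentation states.
    """
    dst = [cast(0)] * (2 * size)
    for k in range(size):
        re = cast(src[2 * k])
        im = cast(src[2 * k + 1])
        if conjugate_and_divide:
            # Conjugation negates the imaginary part; both parts are scaled.
            re = re / divisor
            im = -im / divisor
        dst[2 * k] = re
        dst[2 * k + 1] = im
    return dst
```

With the option enabled and `divisor` set to the transform length, the copy performs the final step of the conjugation-based inverse transform described above.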

get_workload

get_workload(n: Int, divisions: Int, num_workers: Int) -> Int
Calculate the workload size for each worker.
Args
  • n: Int

  • divisions: Int

  • num_workers: Int

Returns
  • Int

list_swap

list_swap(arg: List[Int], i: Int, j: Int) -> List[Int]
Args
  • arg: List[Int]

  • i: Int

  • j: Int

Returns
  • List[Int]
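
Judging only from the name and signature, `list_swap` appears to exchange the elements at positions `i` and `j` and return the resulting list. A minimal Python sketch of that behavior follows. The name `list_swap_sketch` is hypothetical, and returning a copy while leaving the input untouched is an assumption.

```python
def list_swap_sketch(arg, i, j):
    """Return a copy of `arg` with the elements at indices i and j exchanged."""
    out = list(arg)  # copy so the caller's list is left unchanged (assumption)
    out[i], out[j] = out[j], out[i]
    return out
```

A helper like this is handy for building axis permutations, for example moving the transformed dimension to the last position.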

determine_num_workers

determine_num_workers(size: Int) -> Int
Determine the number of workers to use for parallelization.
Args
  • size: Int
Returns
  • Int

fft_op_array

fft_op_array(arg0: Array, name: String, fwd: fn(mut Array, List[Array]) raises -> None, jvp: fn(List[Array], List[Array]) raises -> Array, vjp: fn(List[Array], Array, Array) raises -> List[Array], dims: List[Int], norm: String) -> Array
Args
  • arg0: Array

  • name: String

  • fwd: fn(mut Array, List[Array]) raises -> None

  • jvp: fn(List[Array], List[Array]) raises -> Array

  • vjp: fn(List[Array], Array, Array) raises -> List[Array]

  • dims: List[Int]

  • norm: String

Returns
  • Array

encode_fft_params

encode_fft_params(dims: List[Int], norm: String) -> List[Int]
Args
  • dims: List[Int]

  • norm: String

Returns
  • List[Int]

get_dims_from_encoded_params

get_dims_from_encoded_params(params: List[Int]) -> List[Int]
Args
  • params: List[Int]
Returns
  • List[Int]

get_norm_from_encoded_params

get_norm_from_encoded_params(params: List[Int]) -> String
Args
  • params: List[Int]
Returns
  • String
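
The signatures of `encode_fft_params`, `get_dims_from_encoded_params`, and `get_norm_from_encoded_params` suggest a round trip: dimensions and a normalization mode are packed into one integer list and later recovered from it. The actual encoding is not documented here, so the Python sketch below uses a purely hypothetical scheme (the dimensions followed by one integer norm code). The `_sketch` function names, the code table, and the norm names `"backward"`, `"ortho"`, and `"forward"` (borrowed from NumPy's FFT conventions) are all assumptions.

```python
# Hypothetical norm codes; the real module may use a different mapping.
NORM_CODES = {"backward": 0, "ortho": 1, "forward": 2}
NORM_NAMES = {code: name for name, code in NORM_CODES.items()}


def encode_fft_params_sketch(dims, norm):
    """Pack dims and norm into one list: [dim0, dim1, ..., norm_code]."""
    return list(dims) + [NORM_CODES[norm]]


def get_dims_from_encoded_params_sketch(params):
    """Everything except the trailing norm code is a dimension."""
    return list(params[:-1])


def get_norm_from_encoded_params_sketch(params):
    """The last entry is the norm code; map it back to its name."""
    return NORM_NAMES[params[-1]]
```

Whatever layout the module really uses, the useful property is the same: decoding an encoded list should return the original `dims` and `norm`.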