• Docs >
  • Program Listing for File invoke_generic_impl_.hpp
Shortcuts

Program Listing for File invoke_generic_impl_.hpp

Return to documentation for file (include/ripple/functional/kernel/invoke_generic_impl_.hpp)

#ifndef RIPPLE_FUNCTIONAL_KERNEL_INVOKE_GENRIC__HPP
#define RIPPLE_FUNCTIONAL_KERNEL_INVOKE_GENRIC__HPP

#include "invoke_utils_.hpp"
#include <ripple/algorithm/max_element.hpp>

namespace ripple::kernel::cpu {

/**
 * Gets an iterator from an object by calling its `begin()` member.
 *
 * This overload only participates in overload resolution when
 * `block_enable_t<T>` is well-formed.
 *
 * \param  host_block The object to get the iterator from.
 * \tparam T          The type of the object.
 * \return The result of `host_block.begin()`, returned by value.
 */
template <typename T, block_enable_t<T> = 0>
auto get_iterator(T&& host_block) noexcept {
  auto block_iter = host_block.begin();
  return block_iter;
}

/**
 * Gets an iterator from an object by calling its `host_iterator()` member.
 *
 * This overload only participates in overload resolution when
 * `multiblock_enable_t<T>` is well-formed.
 *
 * \param  block The object to get the iterator from.
 * \tparam T     The type of the object.
 * \return The result of `block.host_iterator()`, returned by value.
 */
template <typename T, multiblock_enable_t<T> = 0>
auto get_iterator(T&& block) noexcept {
  auto host_iter = block.host_iterator();
  return host_iter;
}

/*
template <typename T>
using block_shared_enable_t =
  std::enable_if_t<is_shared_wrapper_v<T> && is_any_block_v<T>, int>;

template <typename T>
using block_non_shared_enable_t =
  std::enable_if_t<!is_shared_wrapper_v<T> && is_any_block_v<T>, int>;

template <typename T>
using non_block_shared_enable_t =
  std::enable_if_t<is_shared_wrapper_v<T> && !is_any_block_v<T>, int>;
*/

template <typename T>
using non_any_block_or_shared_or_expansion_enable_t = std::enable_if_t<
  !is_any_block_v<T> && !is_shared_wrapper_v<T> && !is_expansion_wrapper_v<T>,
  int>;

/**
 * Gets an iterator for \p block and shifts it by the given offsets.
 *
 * The n-th offset is applied with `shift(Dimension<n>(), offset)`, so one
 * shift is performed per offset supplied.
 *
 * This overload only participates in overload resolution when
 * `any_block_enable_t<T>` is well-formed.
 *
 * \param  block   The object to get the iterator from.
 * \param  offsets The amount to shift the iterator by, one per dimension.
 * \tparam T       The type of the block.
 * \tparam Offsets The types of the offsets.
 * \return The shifted iterator, by value.
 */
template <typename T, typename... Offsets, any_block_enable_t<T> = 0>
auto shift_if_iterator(T&& block, Offsets&&... offsets) noexcept
  -> decltype(auto) {
  auto shifted      = get_iterator(block);
  auto offset_tuple = make_tuple(ripple_forward(offsets)...);

  // The index passed to the callback is used as a template argument, both to
  // select the dimension and to pick the matching offset out of the tuple.
  auto shift_one_dim = [&](auto dim_idx) {
    shifted.shift(Dimension<dim_idx>(), get<dim_idx>(offset_tuple));
  };
  unrolled_for<sizeof...(Offsets)>(shift_one_dim);

  // Unparenthesized name: decltype(auto) deduces a value, not a reference.
  return shifted;
}

/**
 * Gets an iterator for the object held by a shared wrapper and shifts it by
 * the given offsets.
 *
 * The n-th offset is applied with `shift(Dimension<n>(), offset)`, so one
 * shift is performed per offset supplied.
 *
 * This overload only participates in overload resolution when
 * `any_block_enable_t<T>` is well-formed.
 *
 * \param  wrapper The wrapper whose `wrapped` member provides the iterator.
 * \param  offsets The amount to shift the iterator by, one per dimension.
 * \tparam T       The type of the wrapped object.
 * \tparam Offsets The types of the offsets.
 * \return The shifted iterator, by value.
 */
template <typename T, typename... Offsets, any_block_enable_t<T> = 0>
decltype(auto)
shift_if_iterator(SharedWrapper<T>& wrapper, Offsets&&... offsets) noexcept {
  auto iter = get_iterator(wrapper.wrapped);
  // Pack the offsets so that each one can be looked up by its index below;
  // this uses the same `make_tuple` helper as the other block overloads.
  auto offs = make_tuple(ripple_forward(offsets)...);
  unrolled_for<sizeof...(Offsets)>(
    [&](auto i) { iter.shift(Dimension<i>(), get<i>(offs)); });
  return iter;
}

/**
 * Gets an iterator for the object held by an expansion wrapper and shifts it
 * by the given offsets.
 *
 * The n-th offset is applied with `shift(Dimension<n>(), offset)`, so one
 * shift is performed per offset supplied.
 *
 * This overload only participates in overload resolution when
 * `any_block_enable_t<T>` is well-formed.
 *
 * \param  wrapper The wrapper whose `wrapped` member provides the iterator.
 * \param  offsets The amount to shift the iterator by, one per dimension.
 * \tparam T       The type of the wrapped object.
 * \tparam Offsets The types of the offsets.
 * \return The shifted iterator, by value.
 */
template <typename T, typename... Offsets, any_block_enable_t<T> = 0>
auto shift_if_iterator(
  ExpansionWrapper<T>& wrapper, Offsets&&... offsets) noexcept
  -> decltype(auto) {
  // TODO: Add expansion factor ...
  auto shifted      = get_iterator(wrapper.wrapped);
  auto offset_tuple = make_tuple(ripple_forward(offsets)...);

  // The index passed to the callback is used as a template argument, both to
  // select the dimension and to pick the matching offset out of the tuple.
  auto shift_one_dim = [&](auto dim_idx) {
    shifted.shift(Dimension<dim_idx>(), get<dim_idx>(offset_tuple));
  };
  unrolled_for<sizeof...(Offsets)>(shift_one_dim);

  // Unparenthesized name: decltype(auto) deduces a value, not a reference.
  return shifted;
}

/**
 * Overload for arguments which are not blocks, shared wrappers or expansion
 * wrappers: the argument is forwarded through untouched and the offsets are
 * ignored.
 *
 * This overload only participates in overload resolution when
 * `non_any_block_or_shared_or_expansion_enable_t<T>` is well-formed.
 *
 * \param  non_block The argument to pass through.
 * \param  offsets   Unused.
 * \tparam T         The type of the argument.
 * \tparam Offsets   The types of the offsets.
 * \return Whatever `ripple_forward(non_block)` evaluates to, with its value
 *         category preserved by `decltype(auto)`.
 */
template <
  typename T,
  typename... Offsets,
  non_any_block_or_shared_or_expansion_enable_t<T> = 0>
auto shift_if_iterator(T&& non_block, Offsets&&... offsets) noexcept
  -> decltype(auto) {
  return ripple_forward(non_block);
}

/**
 * Overload for a shared wrapper around something that is not a block: the
 * wrapped object is returned and the offsets are ignored.
 *
 * This overload only participates in overload resolution when
 * `non_any_block_enable_t<T>` is well-formed.
 *
 * \param  wrapper The wrapper to unwrap.
 * \param  offsets Unused.
 * \tparam T       The type of the wrapped object.
 * \tparam Offsets The types of the offsets.
 * \return `wrapper.wrapped`; since the member access is unparenthesized,
 *         `decltype(auto)` deduces the declared type of the `wrapped` member.
 */
template <typename T, typename... Offsets, non_any_block_enable_t<T> = 0>
auto shift_if_iterator(SharedWrapper<T>& wrapper, Offsets&&... offsets) noexcept
  -> decltype(auto) {
  return wrapper.wrapped;
}

/**
 * Overload for an expansion wrapper around something that is not a block:
 * the wrapped object is returned and the offsets are ignored.
 *
 * This overload only participates in overload resolution when
 * `non_any_block_enable_t<T>` is well-formed.
 *
 * \param  wrapper The wrapper to unwrap.
 * \param  offsets Unused.
 * \tparam T       The type of the wrapped object.
 * \tparam Offsets The types of the offsets.
 * \return `wrapper.wrapped`; since the member access is unparenthesized,
 *         `decltype(auto)` deduces the declared type of the `wrapped` member.
 */
template <typename T, typename... Offsets, non_any_block_enable_t<T> = 0>
auto shift_if_iterator(
  ExpansionWrapper<T>& wrapper, Offsets&&... offsets) noexcept
  -> decltype(auto) {
  return wrapper.wrapped;
}

/**
 * Primary template for the generic invoke implementation.
 *
 * It is intentionally empty: only the explicit specializations provide an
 * `invoke` function. `invoke_generic_impl` instantiates this with the number
 * of dimensions minus one.
 *
 * \tparam MaxDimIndex The index of the last dimension to iterate over.
 */
template <size_t MaxDimIndex>
struct InvokeGenericImpl {};

/**
 * Specialization which iterates over the x, y and z dimensions.
 */
template <>
struct InvokeGenericImpl<2> {
  /**
   * Invokes \p callable once for every (x, y, z) index within \p sizes, with
   * x iterated in the innermost loop and z in the outermost.
   *
   * Before each call, every argument is passed through `shift_if_iterator`
   * together with the current x, y and z indices.
   *
   * \param  callable The callable to invoke.
   * \param  sizes    The number of iterations in each dimension.
   * \param  args     The arguments for the callable.
   * \tparam Callable The type of the callable.
   * \tparam Args     The types of the arguments.
   */
  template <typename Callable, typename... Args>
  static auto
  invoke(Callable&& callable, const DimSizes& sizes, Args&&... args) -> void {
    const auto extent_x = sizes[dimx()];
    const auto extent_y = sizes[dimy()];
    const auto extent_z = sizes[dimz()];

    for (size_t z = 0; z < extent_z; ++z) {
      for (size_t y = 0; y < extent_y; ++y) {
        for (size_t x = 0; x < extent_x; ++x) {
          callable(shift_if_iterator(ripple_forward(args), x, y, z)...);
        }
      }
    }
  }
};

/**
 * Specialization which iterates over the x and y dimensions.
 */
template <>
struct InvokeGenericImpl<1> {
  /**
   * Invokes \p callable once for every (x, y) index within \p sizes, with x
   * iterated in the inner loop and y in the outer loop.
   *
   * Before each call, every argument is passed through `shift_if_iterator`
   * together with the current x and y indices.
   *
   * \param  callable The callable to invoke.
   * \param  sizes    The number of iterations in each dimension.
   * \param  args     The arguments for the callable.
   * \tparam Callable The type of the callable.
   * \tparam Args     The types of the arguments.
   */
  template <typename Callable, typename... Args>
  static auto
  invoke(Callable&& callable, const DimSizes& sizes, Args&&... args) -> void {
    const auto extent_x = sizes[dimx()];
    const auto extent_y = sizes[dimy()];

    for (size_t y = 0; y < extent_y; ++y) {
      for (size_t x = 0; x < extent_x; ++x) {
        callable(shift_if_iterator(ripple_forward(args), x, y)...);
      }
    }
  }
};

/**
 * Specialization which iterates over the x dimension only.
 */
template <>
struct InvokeGenericImpl<0> {
  /**
   * Invokes \p callable once for every x index within \p sizes.
   *
   * Before each call, every argument is passed through `shift_if_iterator`
   * together with the current x index.
   *
   * \param  callable The callable to invoke.
   * \param  sizes    The number of iterations in each dimension; only the x
   *                  entry is read.
   * \param  args     The arguments for the callable.
   * \tparam Callable The type of the callable.
   * \tparam Args     The types of the arguments.
   */
  template <typename Callable, typename... Args>
  static auto
  invoke(Callable&& callable, const DimSizes& sizes, Args&&... args) -> void {
    const auto extent_x = sizes[dimx()];

    for (size_t x = 0; x < extent_x; ++x) {
      callable(shift_if_iterator(ripple_forward(args), x)...);
    }
  }
};

/**
 * Invokes the \p invocable over an iteration space derived from the
 * arguments.
 *
 * The number of dimensions is the largest `any_block_traits_t<Arg>::dimensions`
 * over all arguments, and the size of each dimension is the largest value
 * returned by `get_size_if_block` for that dimension, with a minimum of one.
 * The work itself is done by the `InvokeGenericImpl` specialization for the
 * computed number of dimensions.
 *
 * \param  invocable The object to invoke.
 * \param  args      The arguments for the invocable.
 * \tparam Invocable The type of the invocable.
 * \tparam Args      The types of the arguments.
 */
template <typename Invocable, typename... Args>
auto invoke_generic_impl(Invocable&& invocable, Args&&... args) noexcept
  -> void {
  constexpr size_t dims = max_element(any_block_traits_t<Args>::dimensions...);

  // `dims - 1` is unsigned, so a value of zero would wrap around and select
  // the empty primary template; fail with a clear message instead.
  static_assert(
    dims >= 1,
    "invoke_generic_impl requires an argument with at least one dimension");

  // Find the grid sizes:
  const auto sizes = DimSizes{
    max_element(size_t{1}, get_size_if_block(args, dimx())...),
    max_element(size_t{1}, get_size_if_block(args, dimy())...),
    max_element(size_t{1}, get_size_if_block(args, dimz())...)};

  InvokeGenericImpl<dims - 1>::invoke(
    ripple_forward(invocable), sizes, ripple_forward(args)...);
}

} // namespace ripple::kernel::cpu

#endif // RIPPLE_FUNCTIONAL_KERNEL_INVOKE_GENRIC__HPP

Docs

Access comprehensive developer documentation for Ripple

View Docs

Tutorials

Get tutorials to help you understand all features

View Tutorials

Examples

Find examples to help get started

View Examples