• Docs >
  • Program Listing for File splitter.hpp
Shortcuts

Program Listing for File splitter.hpp

Return to documentation for file (include/ripple/graph/splitter.hpp)

#ifndef RIPPLE_GRAPH_SPLITTER_HPP
#define RIPPLE_GRAPH_SPLITTER_HPP

#include "modifier.hpp"
#include "detail/utils_.hpp"
#include <ripple/container/block_extractor.hpp>
#include <ripple/execution/execution_traits.hpp>
#include <ripple/functional/invoke.hpp>

namespace ripple {

/*==--- [modifier application] ---------------------------------------------==*/

/**
 * Applies a shared modifier to an argument.
 *
 * The argument is passed through detail::deref_if_iter and
 * detail::padding_if_iter, and both results, together with the expansion
 * parameters, are handed to as_shared.
 *
 * This overload participates in overload resolution only when
 * shared_mod_enable_t<Mod> is well formed.
 *
 * \note The argument is forwarded to two helpers within the same call
 *       expression; this presumably relies on neither helper moving from
 *       it -- TODO confirm against the helpers' implementations.
 *
 * \param  arg    The argument to apply the modifier to.
 * \param  params The expansion parameters passed on to as_shared.
 * \tparam Mod    The modifier type, which selects this overload.
 * \tparam Arg    The type of the argument.
 * \return The result of as_shared, with its value category preserved.
 */
template <typename Mod, typename Arg, shared_mod_enable_t<Mod> = 0>
decltype(auto)
apply_modifier_after_deref(Arg&& arg, ExpansionParams params) noexcept {
  return as_shared(
    detail::deref_if_iter(ripple_forward(arg)),
    detail::padding_if_iter(ripple_forward(arg)),
    params);
}

/**
 * Applies a non-shared modifier to an argument.
 *
 * The argument is always passed through detail::deref_if_iter. When the
 * modifier is an expander modifier (is_expander_modifier_v<Mod>), the result
 * is additionally wrapped by as_expansion using the expansion parameters;
 * otherwise the result of detail::deref_if_iter is returned directly and the
 * parameters are not used.
 *
 * This overload participates in overload resolution only when
 * non_shared_mod_enable_t<Mod> is well formed.
 *
 * \param  arg    The argument to apply the modifier to.
 * \param  params The expansion parameters, used only for expander modifiers.
 * \tparam Mod    The modifier type, which selects this overload.
 * \tparam Arg    The type of the argument.
 * \return The (possibly wrapped) result, with its value category preserved.
 */
template <typename Mod, typename Arg, non_shared_mod_enable_t<Mod> = 0>
decltype(auto)
apply_modifier_after_deref(Arg&& arg, ExpansionParams params) noexcept {
  if constexpr (!is_expander_modifier_v<Mod>) {
    // Not an expander: nothing to wrap, return the deref result as it is.
    return detail::deref_if_iter(ripple_forward(arg));
  } else {
    // Expander: wrap the deref result together with the expansion params.
    return as_expansion(detail::deref_if_iter(ripple_forward(arg)), params);
  }
}

/*==--- [fill indices implementation] --------------------------------------==*/

/**
 * Copies the first Size entries of `it->indices` into `indices`, but only if
 * no earlier call has already done so.
 *
 * The `set` flag records whether the indices have been filled: when it is
 * already true the call does nothing, otherwise it is set to true and the
 * indices are copied.
 *
 * This overload participates in overload resolution only when
 * iterator_enable_t<Iterator> is well formed, and fails to compile if the
 * iterator has fewer than Size dimensions.
 *
 * \param  indices  The array to fill.
 * \param  set      In-out flag, true once the indices have been filled.
 * \param  it       The iterator whose pointee provides the indices.
 * \tparam Size     The number of indices to fill.
 * \tparam Iterator The type of the iterator.
 */
template <size_t Size, typename Iterator, iterator_enable_t<Iterator> = 0>
auto fill_indices(
  std::array<uint32_t, Size>& indices, bool& set, Iterator&& it) noexcept
  -> void {
  static_assert(
    iterator_traits_t<Iterator>::dimensions >= Size,
    "Iterator does not have enough dimensions to fill indices!");

  if (!set) {
    set = true;
    unrolled_for<Size>([&](auto dim) { indices[dim] = it->indices[dim]; });
  }
}

/**
 * Overload of fill_indices for arguments which are not iterators; it does
 * nothing, leaving both `indices` and `set` untouched.
 *
 * This overload participates in overload resolution only when
 * non_iterator_enable_t<Iterator> is well formed.
 *
 * \param  indices  Unused.
 * \param  set      Unused.
 * \param  it       Unused.
 * \tparam Size     The number of indices.
 * \tparam Iterator The type of the (non-iterator) argument.
 */
template <size_t Size, typename Iterator, non_iterator_enable_t<Iterator> = 0>
auto fill_indices(
  [[maybe_unused]] std::array<uint32_t, Size>& indices,
  [[maybe_unused]] bool&                       set,
  [[maybe_unused]] Iterator&&                  it) noexcept -> void {
  // Intentionally empty: a non-iterator argument provides no indices.
}

/*==--- [splitter implementation] ------------------------------------------==*/

/**
 * Adds nodes to a graph for a functor and a set of arguments.
 *
 * The public entry point is split(). It passes each argument through
 * BlockExtractor::extract_blocks_if_tensor and then, via invoke_generic on a
 * CpuExecutor, emplaces nodes of kind NodeKind::split onto the graph. If any
 * of the arguments has a modifier, padding operation nodes are added first
 * and a new layer is started in the graph.
 */
struct Splitter {
 private:
  /**
   * Emplaces a single named node of kind NodeKind::split onto the graph.
   *
   * The node's callable invokes the functor through invoke_generic with the
   * given execution kind. Each argument is passed through
   * apply_modifier_after_deref, using the modifier type at the matching
   * position in Mods and the matching entry of padding_mods.
   *
   * \note NOTE(review): the node's callable captures the functor by
   *       reference. If the graph invokes the node after the caller's
   *       functor has been destroyed (e.g. a temporary lambda), that
   *       reference dangles -- confirm the intended functor lifetime.
   *
   * \param  graph        The graph to emplace the node onto.
   * \param  exe          The kind of execution for the node.
   * \param  name         The name of the node.
   * \param  id           The id of the node.
   * \param  functor      The functor which the node invokes.
   * \param  padding_mods The expansion parameters, one per argument.
   * \param  args         The arguments for the node.
   * \tparam Mods         Tuple-like type holding one modifier type per
   *                      argument.
   * \tparam Graph        The type of the graph.
   * \tparam F            The type of the functor.
   * \tparam Args         The types of the arguments.
   * \tparam I            Indices used to pair each argument with its modifier
   *                      type and expansion parameters.
   */
  // clang-format off
  template <
    typename Mods, typename Graph, typename F, typename... Args, size_t... I>
  static auto add_node(
    Graph&                    graph,
    ExecutionKind             exe,
    std::string               name,
    size_t                    id,
    F&&                       functor,
    std::index_sequence<I...>,
    std::array<ExpansionParams, sizeof...(Args)>& padding_mods,
    Args&&...                 args) noexcept -> void {
    // clang-format on
    graph.emplace_named(
      NodeInfo(name, id, NodeKind::split, exe),
      [&functor, exe](auto&&... node_args) {
        invoke_generic(
          exe, ripple_forward(functor), ripple_forward(node_args)...);
      },
      apply_modifier_after_deref<tuple_element_t<I, Mods>>(
        ripple_forward(args), padding_mods[I])...);
  }

 public:
  /**
   * Splits the work of applying the functor to the arguments into nodes on
   * the graph.
   *
   * Every argument is first passed through
   * BlockExtractor::extract_blocks_if_tensor, and the results are forwarded
   * to the implementation which adds the nodes.
   *
   * \param  graph   The graph to add the nodes to.
   * \param  exe     The kind of execution for the nodes.
   * \param  functor The functor for the nodes to invoke.
   * \param  args    The arguments for the functor.
   * \tparam Graph   The type of the graph.
   * \tparam F       The type of the functor.
   * \tparam Args    The types of the arguments.
   */
  template <typename Graph, typename F, typename... Args>
  static auto
  split(Graph& graph, ExecutionKind exe, F&& functor, Args&&... args) noexcept
    -> void {
    split_impl(
      graph,
      exe,
      ripple_forward(functor),
      BlockExtractor::extract_blocks_if_tensor(ripple_forward(args))...);
  }

 private:
  /**
   * Implementation of the split.
   *
   * If any argument has a modifier, padding operation nodes are added with a
   * synchronous transfer kind and a new layer is started in the graph. The
   * nodes which perform the actual computation are then added, each named
   * and identified from the indices of the first iterator among the
   * arguments it is invoked with.
   *
   * \note NOTE(review): the arguments are forwarded more than once in this
   *       function; this presumably relies on unwrap_modifiers and
   *       get_modifier_expansion_params not moving from them -- TODO confirm.
   *
   * \param  graph   The graph to add the nodes to.
   * \param  exe     The kind of execution for the nodes.
   * \param  functor The functor for the nodes to invoke.
   * \param  args    The arguments for the functor.
   * \tparam Graph   The type of the graph.
   * \tparam F       The type of the functor.
   * \tparam Args    The types of the arguments.
   */
  template <typename Graph, typename F, typename... Args>
  static auto split_impl(
    Graph& graph, ExecutionKind exe, F&& functor, Args&&... args) noexcept
    -> void {
    constexpr size_t dimensions = max_element(detail::dims_from_block<Args>...);
    using Modifiers             = Tuple<std::decay_t<Args>...>;
    using Indices               = std::array<uint32_t, dimensions>;
    using PaddingMods           = std::array<ExpansionParams, sizeof...(Args)>;

    /* If any argument has a modifier, then padding nodes are needed, so add
     * them for any tensor which has the modifier and multiple partitions. */
    if constexpr (has_modifier_v<Args...>) {
      invoke_generic(
        CpuExecutor(),
        [&](auto&&... unwrapped_args) {
          detail::add_padding_op_nodes<Modifiers>(
            graph,
            exe,
            TransferKind::synchronous,
            ripple_forward(unwrapped_args)...);
        },
        unwrap_modifiers(ripple_forward(args))...);

      // Start new layer in the graph.
      graph.split_ids_.emplace_back(graph.nodes_.size());
    }

    PaddingMods padding_mods{
      get_modifier_expansion_params(ripple_forward(args))...};

    /* Add the nodes to perform the actual computation. */
    invoke_generic(
      CpuExecutor(),
      [&](PaddingMods& mods, auto&&... unwrapped_args) {
        /* Value-initialize the indices: fill_indices only writes them when
         * one of the arguments is an iterator, and they are read below
         * regardless, so they must never hold indeterminate values. */
        Indices indices{};
        bool    set = false;
        (fill_indices(indices, set, unwrapped_args), ...);

        /* Emplace a node onto the graph which is the functor and the
         * args, converting any of the iterators over tensor blocks into the
         * actual block that the operation will be performed on.
         */
        add_node<Modifiers>(
          graph,
          exe,
          NodeInfo::name_from_indices(indices),
          NodeInfo::id_from_indices(indices),
          ripple_forward(functor),
          std::make_index_sequence<sizeof...(Args)>(),
          mods,
          ripple_forward(unwrapped_args)...);
      },
      padding_mods,
      unwrap_modifiers(ripple_forward(args))...);
  }
};

} // namespace ripple

#endif // RIPPLE_GRAPH_SPLITTER_HPP

Docs

Access comprehensive developer documentation for Ripple

View Docs

Tutorials

Get tutorials to help you understand all features

View Tutorials

Examples

Find examples to help get started

View Examples