.. _program_listing_file_include_ripple_graph_splitter.hpp:

Program Listing for File splitter.hpp
=====================================

|exhale_lsh| :ref:`Return to documentation for file ` (``include/ripple/graph/splitter.hpp``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   /*
    * Declares ripple::Splitter, which adds nodes to a graph for a functor and
    * its arguments, together with the free helpers it uses to apply modifiers
    * to arguments and to copy indices out of an iterator.
    */
   #ifndef RIPPLE_GRAPH_SPLITTER_HPP
   #define RIPPLE_GRAPH_SPLITTER_HPP

   #include "modifier.hpp"
   #include "detail/utils_.hpp"
   // NOTE(review): the operands of the three includes below are missing in
   // this listing, as is the text inside most template angle brackets further
   // down (for example `template = 0>` and `std::array&`). The text was
   // presumably dropped when the listing was generated; consult
   // include/ripple/graph/splitter.hpp for the complete declarations.
   #include
   #include
   #include

   namespace ripple {

   /*==--- [modifier application] ---------------------------------------------==*/

   /**
    * Overload which wraps the argument with as_shared.
    *
    * The argument is passed through detail::deref_if_iter, its padding is
    * obtained from detail::padding_if_iter, and both are given to as_shared
    * along with the expansion parameters.
    *
    * NOTE(review): the condition which selects this overload is part of the
    * truncated template head; presumably it is enabled for shared modifiers,
    * confirm against the header.
    *
    * @param arg    The argument to apply the modifier to.
    * @param params The expansion parameters forwarded to as_shared.
    * @return The result of as_shared.
    */
   template = 0>
   decltype(auto)
   apply_modifier_after_deref(Arg&& arg, ExpansionParams params) noexcept {
     return as_shared(
       detail::deref_if_iter(ripple_forward(arg)),
       detail::padding_if_iter(ripple_forward(arg)),
       params);
   }

   /**
    * Overload which either wraps the argument with as_expansion or returns it
    * unmodified.
    *
    * When is_expander_modifier_v holds, the result of detail::deref_if_iter is
    * given to as_expansion with the expansion parameters. Otherwise the result
    * of detail::deref_if_iter is returned directly and params is not used.
    *
    * NOTE(review): the condition which selects this overload, and the template
    * argument of is_expander_modifier_v, are missing from this listing.
    *
    * @param arg    The argument to apply the modifier to.
    * @param params The expansion parameters, used only in the expander branch.
    * @return The result of as_expansion, or of detail::deref_if_iter.
    */
   template = 0>
   decltype(auto)
   apply_modifier_after_deref(Arg&& arg, ExpansionParams params) noexcept {
     if constexpr (is_expander_modifier_v) {
       return as_expansion(detail::deref_if_iter(ripple_forward(arg)), params);
     } else {
       return detail::deref_if_iter(ripple_forward(arg));
     }
   }

   /*==--- [fill indices implementation] --------------------------------------==*/

   /**
    * Copies indices from what the iterator points to into the given array,
    * unless the array has already been filled.
    *
    * Compilation fails if the iterator traits report fewer dimensions than
    * Size.
    *
    * @param indices The array to copy the indices into.
    * @param set     Whether the indices were filled by an earlier call. If true
    *                on entry nothing is done, otherwise it is set to true.
    * @param it      The iterator whose `indices` member is read.
    */
   template = 0>
   auto fill_indices(
     std::array& indices, bool& set, Iterator&& it) noexcept -> void {
     static_assert(
       iterator_traits_t::dimensions >= Size,
       "Iterator does not have enough dimensions to fill indices!");
     // Only the first call which reaches this point fills the indices.
     if (set) {
       return;
     }

     set = true;
     // NOTE(review): the iteration count of unrolled_for is a template argument
     // which is missing from this listing; presumably it is Size.
     unrolled_for([&](auto i) { indices[i] = it->indices[i]; });
   }

   /**
    * Overload which does nothing: neither indices nor set is modified.
    *
    * NOTE(review): presumably selected when the argument is not an iterator;
    * the enabling condition is part of the truncated template head.
    *
    * @param indices Unused.
    * @param set     Unused.
    * @param it      Unused.
    */
   template = 0>
   auto fill_indices(
     std::array& indices, bool& set, Iterator&& it) noexcept -> void {}

   /*==--- [splitter implementation] ------------------------------------------==*/

   /**
    * Provides the static split function, which adds nodes to a graph to invoke
    * a functor with a set of arguments.
    */
   struct Splitter {
    private:
     /**
      * Emplaces a single named node of kind NodeKind::split onto the graph.
      *
      * The node is given a callable which invokes the functor through
      * invoke_generic with the execution kind, followed by each argument after
      * it has been passed through apply_modifier_after_deref with the
      * expansion parameters at the matching position in padding_mods.
      *
      * @param graph        The graph to emplace the node onto.
      * @param exe          The execution kind for the node and the invocation.
      * @param name         The name given to the node's NodeInfo.
      * @param id           The id given to the node's NodeInfo.
      * @param functor      The functor the node invokes.
      * @param padding_mods Expansion parameters, indexed per argument.
      * @param args         The arguments for the functor.
      */
     // clang-format off
     template <
       typename    Mods,
       typename    Graph,
       typename    F,
       typename... Args,
       size_t...   I>
     static auto add_node(
       Graph&        graph,
       ExecutionKind exe,
       std::string   name,
       size_t        id,
       F&&           functor,
       std::index_sequence,
       std::array& padding_mods,
       Args&&...     args) noexcept -> void {
       // clang-format on
       graph.emplace_named(
         NodeInfo(name, id, NodeKind::split, exe),
         // The functor is captured by reference, so it has to remain valid for
         // as long as this callable can be invoked.
         // NOTE(review): whether the graph runs the callable after this
         // function has returned is not visible here; confirm the lifetime of
         // the functor at the call sites.
         [&functor, exe](auto&&... node_args) {
           invoke_generic(
             exe, ripple_forward(functor), ripple_forward(node_args)...);
         },
         apply_modifier_after_deref>(
           ripple_forward(args), padding_mods[I])...);
     }

    public:
     /**
      * Adds nodes to the graph for the functor and arguments.
      *
      * Each argument is first passed through
      * BlockExtractor::extract_blocks_if_tensor, and the results are handed to
      * split_impl along with the graph, execution kind and functor.
      *
      * @param graph   The graph to add the nodes to.
      * @param exe     The execution kind for the nodes.
      * @param functor The functor for the nodes to invoke.
      * @param args    The arguments for the functor.
      */
     template
     static auto
     split(Graph& graph, ExecutionKind exe, F&& functor, Args&&... args) noexcept
       -> void {
       split_impl(
         graph,
         exe,
         ripple_forward(functor),
         BlockExtractor::extract_blocks_if_tensor(ripple_forward(args))...);
     }

    private:
     /**
      * Implementation of split, working on the arguments returned by
      * BlockExtractor::extract_blocks_if_tensor.
      *
      * If has_modifier_v holds, padding nodes are added first through
      * detail::add_padding_op_nodes and the current node count is appended to
      * graph.split_ids_. The nodes for the functor are then added through
      * add_node.
      *
      * NOTE(review): args is passed through ripple_forward several times in
      * this function, which is only safe if unwrap_modifiers and
      * get_modifier_expansion_params do not move from their argument; confirm.
      *
      * @param graph   The graph to add the nodes to.
      * @param exe     The execution kind for the nodes.
      * @param functor The functor for the nodes to invoke.
      * @param args    The arguments for the functor.
      */
     template
     static auto split_impl(
       Graph& graph, ExecutionKind exe, F&& functor, Args&&... args) noexcept
       -> void {
       // NOTE(review): the template arguments of dims_from_block, Tuple and
       // std::array in the next four declarations are missing from this
       // listing.
       constexpr size_t dimensions = max_element(detail::dims_from_block...);
       using Modifiers             = Tuple...>;
       using Indices               = std::array;
       using PaddingMods           = std::array;

       /* If any argument has a modifier, then padding nodes are needed, so add
        * them for any tensor which has the modifier and multiple partitions. */
       if constexpr (has_modifier_v) {
         invoke_generic(
           CpuExecutor(),
           [&](auto&&... unwrapped_args) {
             detail::add_padding_op_nodes(
               graph,
               exe,
               TransferKind::synchronous,
               ripple_forward(unwrapped_args)...);
           },
           unwrap_modifiers(ripple_forward(args))...);

         // Start new layer in the graph.
         graph.split_ids_.emplace_back(graph.nodes_.size());
       }

       // One set of expansion parameters per argument, in argument order.
       PaddingMods padding_mods{
         get_modifier_expansion_params(ripple_forward(args))...};

       /* Add the nodes to perform the actual computation. */
       invoke_generic(
         CpuExecutor(),
         [&](PaddingMods& padding_mods, auto&&... unwrapped_args) {
           // NOTE(review): indices is not value-initialized. If every
           // fill_indices call below resolves to the empty overload, it is
           // read without having been written; confirm that at least one
           // argument always selects the copying overload.
           Indices indices;
           bool    set = false;
           // Called for each argument in order; once set is true the
           // remaining calls leave indices untouched.
           (fill_indices(indices, set, unwrapped_args), ...);

           /* Emplace a node onto the graph which is the functor and the
            * args, converting any of the iterators over tensor blocks into the
            * actual block that the operation will be performed on. */
           add_node(
             graph,
             exe,
             NodeInfo::name_from_indices(indices),
             NodeInfo::id_from_indices(indices),
             ripple_forward(functor),
             std::make_index_sequence(),
             padding_mods,
             ripple_forward(unwrapped_args)...);
         },
         padding_mods,
         unwrap_modifiers(ripple_forward(args))...);
     }
   };

   } // namespace ripple

   #endif // RIPPLE_GRAPH_SPLITTER_HPP