Program Listing for File splitter.hpp
↰ Return to documentation for file (include/ripple/graph/splitter.hpp)
#ifndef RIPPLE_GRAPH_SPLITTER_HPP
#define RIPPLE_GRAPH_SPLITTER_HPP
#include "modifier.hpp"
#include "detail/utils_.hpp"
#include <ripple/container/block_extractor.hpp>
#include <ripple/execution/execution_traits.hpp>
#include <ripple/functional/invoke.hpp>
namespace ripple {
/*==--- [modifier application] ---------------------------------------------==*/
/**
 * Overload chosen when `shared_mod_enable_t<Mod>` is well-formed.
 *
 * Passes the result of `detail::deref_if_iter(arg)`, the result of
 * `detail::padding_if_iter(arg)` and `params` to `as_shared`, and returns
 * whatever `as_shared` returns (value category preserved by
 * `decltype(auto)`).
 *
 * NOTE(review): `arg` is forwarded to both helpers inside the same call
 * expression. This is only safe if neither helper moves from an rvalue
 * argument -- confirm against the helpers' definitions.
 *
 * @tparam Mod    Modifier type associated with the argument.
 * @tparam Arg    Deduced type of the argument.
 * @param  arg    Argument to dereference (presumably a block iterator or a
 *                plain value -- see `detail::deref_if_iter`).
 * @param  params Expansion parameters handed on to `as_shared`.
 * @return The result of `as_shared`.
 */
template <typename Mod, typename Arg, shared_mod_enable_t<Mod> = 0>
decltype(auto)
apply_modifier_after_deref(Arg&& arg, ExpansionParams params) noexcept {
return as_shared(
detail::deref_if_iter(ripple_forward(arg)),
detail::padding_if_iter(ripple_forward(arg)),
params);
}
/**
 * Overload chosen when `non_shared_mod_enable_t<Mod>` is well-formed.
 *
 * The argument is always passed through `detail::deref_if_iter`. When `Mod`
 * is an expander modifier (`is_expander_modifier_v<Mod>`), the dereferenced
 * result is additionally wrapped with `as_expansion` using `params`;
 * otherwise the dereferenced result is returned as is.
 *
 * @tparam Mod    Modifier type associated with the argument.
 * @tparam Arg    Deduced type of the argument.
 * @param  arg    Argument to dereference.
 * @param  params Expansion parameters; only used for expander modifiers.
 * @return The dereferenced argument, wrapped by `as_expansion` for expander
 *         modifiers.
 */
template <typename Mod, typename Arg, non_shared_mod_enable_t<Mod> = 0>
decltype(auto)
apply_modifier_after_deref(Arg&& arg, ExpansionParams params) noexcept {
  // The `else` must stay: dropping it would instantiate both returns and
  // break `decltype(auto)` deduction.
  if constexpr (!is_expander_modifier_v<Mod>) {
    return detail::deref_if_iter(ripple_forward(arg));
  } else {
    return as_expansion(detail::deref_if_iter(ripple_forward(arg)), params);
  }
}
/*==--- [fill indices implementation] --------------------------------------==*/
/**
 * Iterator overload: copies the first `Size` entries of `it->indices` into
 * `indices`, but only if no earlier call has already done so.
 *
 * `set` acts as a latch: when it is already true the call does nothing,
 * otherwise it is switched to true before the copy is performed.
 *
 * @tparam Size     Number of indices to copy.
 * @tparam Iterator Deduced iterator type; must report at least `Size`
 *                  dimensions through `iterator_traits_t`.
 * @param  indices  Destination array.
 * @param  set      In/out flag recording whether `indices` has been filled.
 * @param  it       Iterator whose `indices` member is read.
 */
template <size_t Size, typename Iterator, iterator_enable_t<Iterator> = 0>
auto fill_indices(
  std::array<uint32_t, Size>& indices, bool& set, Iterator&& it) noexcept
  -> void {
  static_assert(
    iterator_traits_t<Iterator>::dimensions >= Size,
    "Iterator does not have enough dimensions to fill indices!");

  if (!set) {
    set = true;
    unrolled_for<Size>(
      [&indices, &it](auto dim) { indices[dim] = it->indices[dim]; });
  }
}
/**
 * Non-iterator overload (selected through `non_iterator_enable_t`): does
 * nothing. Neither `indices` nor `set` is modified, so a later call with an
 * iterator argument can still fill the indices.
 *
 * @tparam Size     Number of entries in `indices`.
 * @tparam Iterator Deduced type of the (non-iterator) argument.
 * @param  indices  Unused.
 * @param  set      Unused.
 * @param  it       Unused.
 */
template <size_t Size, typename Iterator, non_iterator_enable_t<Iterator> = 0>
auto fill_indices(
std::array<uint32_t, Size>& indices, bool& set, Iterator&& it) noexcept
-> void {}
/*==--- [splitter implementation] ------------------------------------------==*/
/**
 * Adds nodes to a graph for an operation whose arguments may be tensors
 * that have to be processed block by block.
 *
 * The only public entry point is `split`; everything else is an
 * implementation detail.
 */
struct Splitter {
 private:
  // clang-format off
  /**
   * Emplaces a single `NodeKind::split` node onto the graph.
   *
   * The node's callable invokes `functor` through `invoke_generic` with the
   * node's arguments. Each element of `args` is passed through
   * `apply_modifier_after_deref`, using the I-th element of `Mods` as the
   * modifier type and `padding_mods[I]` as its expansion parameters.
   *
   * NOTE(review): `functor` is captured by reference in the closure handed
   * to `emplace_named`. If the graph runs the node after the caller's
   * functor has been destroyed (for example a temporary lambda passed to
   * `split`), that reference dangles. Copying the functor instead could
   * change behaviour for stateful or non-const callables, so it is left
   * as is -- confirm the intended lifetime contract.
   *
   * @tparam Mods         Tuple-like list of modifier types, one per argument.
   * @param  graph        Graph to add the node to.
   * @param  exe          Execution kind stored in the node info and used
   *                      when the functor is invoked.
   * @param  name         Name for the node.
   * @param  id           Identifier for the node.
   * @param  functor      Callable the node invokes.
   * @param  padding_mods Expansion parameters, one per argument.
   * @param  args         Arguments for the node.
   */
  template <
    typename Mods, typename Graph, typename F, typename... Args, size_t... I>
  static auto add_node(
    Graph&                                        graph,
    ExecutionKind                                 exe,
    std::string                                   name,
    size_t                                        id,
    F&&                                           functor,
    std::index_sequence<I...>,
    std::array<ExpansionParams, sizeof...(Args)>& padding_mods,
    Args&&...                                     args) noexcept -> void {
    // clang-format on
    graph.emplace_named(
      NodeInfo(name, id, NodeKind::split, exe),
      [&functor, exe](auto&&... node_args) {
        invoke_generic(
          exe, ripple_forward(functor), ripple_forward(node_args)...);
      },
      apply_modifier_after_deref<tuple_element_t<I, Mods>>(
        ripple_forward(args), padding_mods[I])...);
  }

 public:
  /**
   * Splits the operation described by `functor` and `args` into nodes on
   * `graph`.
   *
   * Every argument is first passed through
   * `BlockExtractor::extract_blocks_if_tensor`; the results are handed to
   * `split_impl`, which creates the nodes.
   *
   * @param graph   Graph to add nodes to.
   * @param exe     Execution kind for the created nodes.
   * @param functor Callable each node invokes.
   * @param args    Arguments for the functor.
   */
  template <typename Graph, typename F, typename... Args>
  static auto
  split(Graph& graph, ExecutionKind exe, F&& functor, Args&&... args) noexcept
    -> void {
    split_impl(
      graph,
      exe,
      ripple_forward(functor),
      BlockExtractor::extract_blocks_if_tensor(ripple_forward(args))...);
  }

 private:
  /**
   * Implementation of `split`, operating on arguments from which blocks
   * have already been extracted.
   *
   * If any argument carries a modifier, padding operation nodes are added
   * first and a new split id is recorded on the graph. Afterwards
   * `invoke_generic` is used with a `CpuExecutor` to call a lambda that adds
   * one computation node per invocation (presumably once per block --
   * depends on `invoke_generic`).
   *
   * @param graph   Graph to add nodes to.
   * @param exe     Execution kind for the created nodes.
   * @param functor Callable each node invokes.
   * @param args    Arguments after block extraction.
   */
  template <typename Graph, typename F, typename... Args>
  static auto split_impl(
    Graph& graph, ExecutionKind exe, F&& functor, Args&&... args) noexcept
    -> void {
    constexpr size_t dimensions = max_element(detail::dims_from_block<Args>...);
    using Modifiers             = Tuple<std::decay_t<Args>...>;
    using Indices               = std::array<uint32_t, dimensions>;
    using PaddingMods = std::array<ExpansionParams, sizeof...(Args)>;

    /* If any argument has a modifier, then padding nodes are needed, so add
     * them for any tensor which has the modifier and multiple partitions. */
    if constexpr (has_modifier_v<Args...>) {
      invoke_generic(
        CpuExecutor(),
        [&](auto&&... unwrapped_args) {
          detail::add_padding_op_nodes<Modifiers>(
            graph,
            exe,
            TransferKind::synchronous,
            ripple_forward(unwrapped_args)...);
        },
        unwrap_modifiers(ripple_forward(args))...);

      // Start new layer in the graph.
      graph.split_ids_.emplace_back(graph.nodes_.size());
    }

    PaddingMods padding_mods{
      get_modifier_expansion_params(ripple_forward(args))...};

    /* Add the nodes to perform the actual computation. */
    invoke_generic(
      CpuExecutor(),
      // The parameter is deliberately not called `padding_mods`, so that it
      // does not shadow the local above.
      [&](PaddingMods& node_padding_mods, auto&&... unwrapped_args) {
        // Value-initialised so that the name/id computed below never read
        // indeterminate values, even if no argument fills the indices.
        Indices indices{};
        bool    set = false;
        (fill_indices(indices, set, unwrapped_args), ...);

        /* Emplace a node onto the graph which is the functor and the
         * args, converting any of the iterators over tensor blocks into the
         * actual block that the operation will be performed on.
         */
        add_node<Modifiers>(
          graph,
          exe,
          NodeInfo::name_from_indices(indices),
          NodeInfo::id_from_indices(indices),
          ripple_forward(functor),
          std::make_index_sequence<sizeof...(Args)>(),
          node_padding_mods,
          ripple_forward(unwrapped_args)...);
      },
      padding_mods,
      unwrap_modifiers(ripple_forward(args))...);
  }
};
} // namespace ripple
#endif // RIPPLE_GRAPH_SPLITTER_HPP