• Docs >
  • Program Listing for File block_iterator.hpp
Shortcuts

Program Listing for File block_iterator.hpp

Return to documentation for file (include/ripple/iterator/block_iterator.hpp)

#ifndef RIPPLE_CONTAINER_BLOCK_ITERATOR_HPP
#define RIPPLE_CONTAINER_BLOCK_ITERATOR_HPP

#include "iterator_traits.hpp"
#include <ripple/container/array_traits.hpp>
#include <ripple/math/math.hpp>
#include <ripple/storage/storage_traits.hpp>
#include <ripple/execution/thread_index.hpp>
#include <cassert>

namespace ripple {

/// Iterator over a multi-dimensional block of data.
///
/// The iterator stores a handle to the element it currently refers to
/// (`data_ptr_`) together with a description of the iteration space
/// (`space_`). Besides the usual dereference / member-access operators it
/// provides offsetting in a given dimension and a set of finite-difference
/// helpers (differences, gradient, normal and curvature) which are built on
/// top of the offsetting.
///
/// \tparam T     The type which is iterated over.
/// \tparam Space The type describing the space which is iterated.
template <typename T, typename Space>
class BlockIterator {
 public:
  /// Number of dimensions of the iteration space.
  static constexpr size_t dims = space_traits_t<Space>::dimensions;

 private:
  /*==--- [traits] ---------------------------------------------------------==*/

  // NOTE(review): the `Vec` alias below reuses the name of the template it
  // aliases; some compilers diagnose this ("changes meaning of 'Vec'").
  // Confirm against the supported toolchains before renaming/qualifying.

  // clang-format off
  /// Layout traits for the iterated type, all other aliases derive from it.
  using LayoutTraits = layout_traits_t<T>;
  /// Value type defined by the layout traits.
  using Value        = typename LayoutTraits::Value;
  /// Type returned when dereferencing a non-const iterator.
  using Ref          = typename LayoutTraits::IterRef;
  /// Type returned when dereferencing a const iterator.
  using ConstRef     = typename LayoutTraits::IterConstRef;
  /// Type returned by member access on a non-const iterator.
  using Ptr          = typename LayoutTraits::IterPtr;
  /// Type returned by member access on a const iterator.
  using ConstPtr     = typename LayoutTraits::IterConstPtr;
  /// Type used when a copy of the iterated data is returned.
  using CopyType     = typename LayoutTraits::IterCopy;
  /// Type of the handle to the iterated data held by the iterator.
  using Storage      = typename LayoutTraits::IterStorage;
  /// Type providing the static offset() / shift() functions for the storage.
  using Offsetter    = typename LayoutTraits::Allocator;
  /// Vector with one CopyType element per dimension.
  using Vec          = Vec<CopyType, dims, ContiguousOwned>;
  // clang-format on

  /// Value type reported by the array traits of T.
  using UnderlyingType = typename array_traits_t<T>::Value;

  /*==--- [constants] ------------------------------------------------------==*/

  /// Tag instance used to select between the *_impl overloads below, chosen
  /// from whether the layout traits report a polymorphic layout.
  static constexpr auto is_poly_layout_overload =
    PolyLayoutOverloader<LayoutTraits::is_polymorphic_layout>{};

  /*==--- [deref impl] -----------------------------------------------------==*/

  /// Dereference for the polymorphic layout overload: builds a Ref from the
  /// storage.
  ripple_all auto deref_impl(PolyLayoutOverload) noexcept -> Ref {
    return Ref{data_ptr_};
  }

  /// Const dereference for the polymorphic layout overload: builds a ConstRef
  /// from the storage.
  ripple_all auto
  deref_impl(PolyLayoutOverload) const noexcept -> ConstRef {
    return ConstRef{data_ptr_};
  }

  /// Dereference for the non-polymorphic layout overload: dereferences the
  /// storage directly.
  ripple_all auto deref_impl(NonPolyLayoutOverload) noexcept -> Ref {
    return *data_ptr_;
  }

  /// Const dereference for the non-polymorphic layout overload: dereferences
  /// the storage directly.
  ripple_all auto
  deref_impl(NonPolyLayoutOverload) const noexcept -> ConstRef {
    return *data_ptr_;
  }

  /*==--- [access impl] ----------------------------------------------------==*/

  /// Member access for the polymorphic layout overload: wraps a Value built
  /// from the storage in a Ptr.
  ripple_all auto access_impl(PolyLayoutOverload) noexcept -> Ptr {
    return Ptr{Value{data_ptr_}};
  }

  /// Const member access for the polymorphic layout overload: wraps a Value
  /// built from the storage in a ConstPtr.
  ripple_all auto
  access_impl(PolyLayoutOverload) const noexcept -> ConstPtr {
    return ConstPtr{Value{data_ptr_}};
  }

  /// Member access for the non-polymorphic layout overload: returns the
  /// storage itself.
  ripple_all auto access_impl(NonPolyLayoutOverload) noexcept -> Ptr {
    return data_ptr_;
  }

  /// Const member access for the non-polymorphic layout overload: returns the
  /// storage itself.
  ripple_all auto
  access_impl(NonPolyLayoutOverload) const noexcept -> ConstPtr {
    return data_ptr_;
  }

  /*==--- [unwrap impl] ----------------------------------------------------==*/

  /// Unwrap for the polymorphic layout overload: builds the copy type from
  /// the storage.
  ripple_all auto
  unwrap_impl(PolyLayoutOverload) const noexcept -> CopyType {
    return CopyType{data_ptr_};
  }

  /// Unwrap for the non-polymorphic layout overload: builds the copy type
  /// from the dereferenced storage.
  ripple_all auto
  unwrap_impl(NonPolyLayoutOverload) const noexcept -> CopyType {
    return CopyType{*data_ptr_};
  }

 protected:
  Storage data_ptr_; //!< Handle to the data the iterator refers to.
  Space   space_;    //!< Space over which the iterator iterates.

 public:
  // clang-format off
  /// Raw pointer type returned by data().
  using RawPtr      = typename LayoutTraits::RawPtr;
  /// Const raw pointer type returned by data() const.
  using ConstRawPtr = typename LayoutTraits::ConstRawPtr;
  // clang-format on

  /// Creates the iterator from the storage handle and the iteration space.
  /// \param data_ptr The handle to the data to iterate over.
  /// \param space    The space which defines the iteration domain.
  ripple_all BlockIterator(Storage data_ptr, Space space) noexcept
  : data_ptr_{data_ptr}, space_{space} {}

  /*==--- [operator overloading] -------------------------------------------==*/

  /// Dereferences the iterator.
  /// \return A Ref to the data the iterator refers to.
  ripple_all auto operator*() noexcept -> Ref {
    return deref_impl(is_poly_layout_overload);
  }

  /// Dereferences the iterator.
  /// \return A ConstRef to the data the iterator refers to.
  ripple_all auto operator*() const noexcept -> ConstRef {
    return deref_impl(is_poly_layout_overload);
  }

  // clang-format off
  /// Member access for the data the iterator refers to.
  /// \return A Ptr through which members of the data can be accessed.
  ripple_all auto operator->() noexcept -> Ptr {
    return access_impl(is_poly_layout_overload);
  }

  /// Member access for the data the iterator refers to.
  /// \return A ConstPtr through which members of the data can be accessed.
  ripple_all auto operator->() const noexcept -> ConstPtr {
    return access_impl(is_poly_layout_overload);
  }
  // clang-format on

  /// Gets a copy of the data the iterator refers to.
  /// \return A CopyType built from the iterated data.
  ripple_all auto unwrap() const noexcept -> CopyType {
    return unwrap_impl(is_poly_layout_overload);
  }

  /*==--- [offsetting] -----------------------------------------------------==*/

  /// Creates a new iterator which is offset from this one, without modifying
  /// this iterator.
  /// \param  dim    The dimension to offset in.
  /// \param  amount The amount to offset by, forwarded to the Offsetter.
  /// \tparam Dim    The type of the dimension specifier.
  /// \return A new iterator over the same space at the offset location.
  template <typename Dim>
  ripple_all constexpr auto
  offset(Dim&& dim, int amount) const noexcept -> BlockIterator {
    return BlockIterator{
      Offsetter::offset(data_ptr_, space_, ripple_forward(dim), amount),
      space_};
  }

  /// Shifts this iterator in place.
  /// \param  dim    The dimension to shift in.
  /// \param  amount The amount to shift by, forwarded to the Offsetter.
  /// \tparam Dim    The type of the dimension specifier.
  template <typename Dim>
  ripple_all constexpr auto
  shift(Dim&& dim, int amount) noexcept -> void {
    Offsetter::shift(data_ptr_, space_, ripple_forward(dim), amount);
  }

  /// Gets a raw pointer to the iterated data. When the storage is a storage
  /// accessor its data() is returned, otherwise the storage itself.
  ripple_all auto data() noexcept -> RawPtr {
    if constexpr (is_storage_accessor_v<Storage>) {
      return data_ptr_.data();
    } else {
      return data_ptr_;
    }
  }

  /// Gets a const raw pointer to the iterated data. When the storage is a
  /// storage accessor its data() is returned, otherwise the storage itself.
  ripple_all auto data() const noexcept -> ConstRawPtr {
    if constexpr (is_storage_accessor_v<Storage>) {
      return data_ptr_.data();
    } else {
      return data_ptr_;
    }
  }

  /// Gets the storage held by the iterator.
  ///
  /// `decltype(auto)` applied to the unparenthesized member name deduces the
  /// declared type of the member, so the storage is returned by value.
  ripple_all decltype(auto) storage() noexcept {
    return data_ptr_;
  }

  /// Gets the storage held by the iterator.
  ///
  /// `decltype(auto)` applied to the unparenthesized member name deduces the
  /// declared type of the member, so the storage is returned by value.
  ripple_all decltype(auto) storage() const noexcept {
    return data_ptr_;
  }

  /// Checks whether the global index in the given dimension is less than the
  /// size of the iteration space in that dimension.
  /// \param  dim The dimension to check.
  /// \tparam Dim The type of the dimension specifier.
  /// \return true if the global index is inside the size of the dimension.
  template <typename Dim>
  ripple_all auto is_valid(Dim&& dim) const noexcept -> bool {
    // Query the extent through the lvalue first so that the dimension is
    // forwarded exactly once, on its last use.
    const auto extent = size(dim);
    return ::ripple::global_idx(ripple_forward(dim)) < extent;
  }

  /*==--- [dimensions] -----------------------------------------------------==*/

  /// Gets the number of dimensions of the iteration space.
  ripple_all constexpr auto dimensions() const noexcept -> size_t {
    return dims;
  }

  /*==--- [gradients] ------------------------------------------------------==*/

  /// Computes the backward difference in the given dimension:
  /// the current element minus the element `amount` behind it.
  ///
  /// No bounds checks are performed here on the offset element.
  /// \param  dim    The dimension to compute the difference in.
  /// \param  amount The distance to the element behind.
  /// \tparam Dim    The type of the dimension specifier.
  template <typename Dim>
  ripple_all constexpr auto
  backward_diff(Dim&& dim, int amount = 1) const noexcept -> CopyType {
    return deref_impl(is_poly_layout_overload) -
           *offset(ripple_forward(dim), -amount);
  }

  /// Computes the forward difference in the given dimension:
  /// the element `amount` ahead minus the current element.
  ///
  /// No bounds checks are performed here on the offset element.
  /// \param  dim    The dimension to compute the difference in.
  /// \param  amount The distance to the element ahead.
  /// \tparam Dim    The type of the dimension specifier.
  template <typename Dim>
  ripple_all constexpr auto
  forward_diff(Dim&& dim, int amount = 1) const noexcept -> CopyType {
    return *offset(ripple_forward(dim), amount) -
           deref_impl(is_poly_layout_overload);
  }

  /// Computes the central difference of a single component in the given
  /// dimension: `component(index)` of the element `amount` ahead minus
  /// `component(index)` of the element `amount` behind.
  ///
  /// No bounds checks are performed here on the offset elements.
  /// \param  dim    The dimension to compute the difference in.
  /// \param  index  The index of the component to difference.
  /// \param  amount The distance to the elements ahead and behind.
  /// \tparam Dim    The type of the dimension specifier.
  /// \tparam Index  The type of the component index.
  template <typename Dim, typename Index>
  ripple_all constexpr auto
  central_diff(Dim&& dim, Index&& index, int amount) const noexcept
    -> UnderlyingType {
    // TODO: Do compile time check that underlying type has a component member.
    // The dimension is needed twice, so only the last use forwards it.
    auto ahead  = offset(dim, amount);
    auto behind = offset(ripple_forward(dim), -amount);
    return ahead->component(index) - behind->component(index);
  }

  /// Computes the central difference in the given dimension:
  /// the element `amount` ahead minus the element `amount` behind.
  ///
  /// No bounds checks are performed here on the offset elements.
  /// \param  dim    The dimension to compute the difference in.
  /// \param  amount The distance to the elements ahead and behind.
  /// \tparam Dim    The type of the dimension specifier.
  template <typename Dim>
  ripple_all constexpr auto
  central_diff(Dim&& dim, int amount = 1) const noexcept -> CopyType {
    // The dimension is needed twice, so only the last use forwards it.
    auto ahead  = offset(dim, amount);
    auto behind = offset(ripple_forward(dim), -amount);
    return *ahead - *behind;
  }

  /// Computes the (unscaled) second difference in the given dimension:
  /// the element ahead plus the element behind minus twice the current
  /// element.
  ///
  /// No bounds checks are performed here on the offset elements.
  /// \param  dim    The dimension to compute the difference in.
  /// \param  amount The distance to the elements ahead and behind.
  /// \tparam Dim    The type of the dimension specifier.
  template <typename Dim>
  ripple_all constexpr auto
  second_diff(Dim&& dim, int amount = 1) const noexcept -> CopyType {
    // The dimension is needed twice, so only the last use forwards it.
    auto ahead  = offset(dim, amount);
    auto behind = offset(ripple_forward(dim), -amount);
    return *ahead + *behind - (T(2) * this->operator*());
  }

  /// Computes the mixed second difference in the two given dimensions from
  /// the four diagonal neighbours, scaled by 0.25:
  ///
  ///   0.25 * (f(+a,+a) - f(+a,-a) - f(-a,+a) + f(-a,-a))
  ///
  /// No bounds checks are performed here on the offset elements.
  /// \param  dim1   The first dimension.
  /// \param  dim2   The second dimension.
  /// \param  amount The distance to the neighbours in each dimension.
  /// \tparam Dim1   The type of the first dimension specifier.
  /// \tparam Dim2   The type of the second dimension specifier.
  template <typename Dim1, typename Dim2>
  ripple_all constexpr auto
  second_partial_diff(Dim1&& dim1, Dim2&& dim2, int amount = 1) const noexcept
    -> CopyType {
    const auto scale   = 0.25;
    const int  namount = -amount;
    // dim1 is forwarded only on its last use; dim2 is used four times within
    // one expression whose operand order is unspecified, so it is always
    // passed as an lvalue.
    const auto next = offset(dim1, amount);
    const auto prev = offset(ripple_forward(dim1), namount);

    return scale *
           ((*next.offset(dim2, amount)) - (*next.offset(dim2, namount)) -
            (*prev.offset(dim2, amount)) + (*prev.offset(dim2, namount)));
  }

  /// Computes the gradient, one grad_dim() result per dimension.
  ///
  /// DataType cannot be deduced from the default argument, so either `dh` or
  /// the template argument has to be given by the caller.
  /// \param  dh       The resolution of the grid.
  /// \tparam DataType The type of the resolution.
  template <typename DataType>
  ripple_all constexpr auto
  grad(DataType dh = 1) const noexcept -> Vec {
    Vec result{};
    unrolled_for<dims>([&](auto d) { result[d] = grad_dim(d, dh); });
    return result;
  }

  /// Computes the gradient in a single dimension as the central difference
  /// divided by `2 * dh`.
  /// \param  dim      The dimension to compute the gradient in.
  /// \param  dh       The resolution of the grid.
  /// \tparam Dim      The type of the dimension specifier.
  /// \tparam DataType The type of the resolution.
  template <typename Dim, typename DataType>
  ripple_all constexpr auto
  grad_dim(Dim&& dim, DataType dh = 1) const noexcept -> CopyType {
    // NOTE: Have to do something different depending on the data type
    // because the optimization for 0.5 / dh doesn't work if dh is integral
    // since it goes to zero.
    //
    // The lack of optimization in integer branch means that a division will
    // happen for each element if the iterated types has multiple elements,
    // where as for the non-integer branch, the pre-division means that the
    // divisions are turned into multiplications, which is a lot faster,
    // especially on the device.
    if constexpr (std::is_integral_v<DataType>) {
      return this->central_diff(ripple_forward(dim)) / (2 * dh);
    } else {
      return (DataType{0.5} / dh) * this->central_diff(ripple_forward(dim));
    }
  }

  /*==--- [normal] ---------------------------------------------------------==*/

  /// Computes the negated gradient and divides it by the square root of the
  /// sum of its squared elements (plus a small seed value).
  ///
  /// NOTE(review): `mag` is brace-initialized from a floating point literal,
  /// which is a narrowing conversion when DataType is integral, so the
  /// integral branch below looks unreachable in practice -- confirm.
  /// \param  dh       The resolution of the grid.
  /// \tparam DataType The type of the resolution.
  template <typename DataType>
  ripple_all constexpr auto
  norm(DataType dh = DataType(1)) const noexcept -> Vec {
    Vec      result{};
    DataType mag{1e-15}; // Presumably keeps the divisor below non-zero.

    // NOTE: Here we do not use the grad() function to save some loops.
    // If grad was used, then an extra multiplication would be required per
    // element for vector types.
    unrolled_for<dims>([&](auto d) {
      // Add the negative sign in now, to avoid an op later ...
      if constexpr (std::is_integral_v<DataType>) {
        result[d] = this->central_diff(d) / (-2 * dh);
      } else {
        result[d] = (DataType{-0.5} / dh) * this->central_diff(d);
      }
      mag += result[d] * result[d];
    });
    // assert(std::abs(std::sqrt(mag)) < 1e-22 && "Division by zero in
    // norm!");
    result /= std::sqrt(mag);
    return result;
  }

  /// Computes the negated gradient without the normalization which norm()
  /// applies.
  /// \param  dh       The resolution of the grid.
  /// \tparam DataType The type of the resolution.
  template <typename DataType>
  ripple_all constexpr auto
  norm_sd(DataType dh = 1) const noexcept -> Vec {
    // NOTE: Here we do not use the grad() function to save some loops.
    // If grad was used, then an extra multiplication would be required per
    // element for vector types.
    Vec result{};
    unrolled_for<dims>([&](auto d) {
      // Add the negative sign in now, to avoid an op later ...
      if constexpr (std::is_integral_v<DataType>) {
        result[d] = this->central_diff(d) / (-2 * dh);
      } else {
        result[d] = (DataType{-0.5} / dh) * this->central_diff(d);
      }
    });
    return result;
  }

  /*==--- [curvature] ------------------------------------------------------==*/

  /// Computes the curvature, dispatching on the number of dimensions:
  /// curvature_2d() for two dimensions, curvature_3d() for three, and the
  /// constant `DataType{1}` for any other number of dimensions.
  ///
  /// Every branch is selected at compile time so that the 2D / 3D
  /// implementations are only instantiated for spaces which actually have
  /// those dimensions.
  /// \param  dh       The resolution of the grid.
  /// \tparam DataType The type of the resolution.
  template <typename DataType>
  ripple_all constexpr auto
  curvature(DataType dh = 1) const noexcept -> DataType {
    if constexpr (dims == 2) {
      return curvature_2d(dh);
    } else if constexpr (dims == 3) {
      return curvature_3d(dh);
    } else {
      return DataType{1};
    }
  }

  /// Computes the curvature in two dimensions from the first, second and
  /// mixed differences in x and y:
  ///
  ///   (pxx * py^2 - 2 * px * py * pxy + pyy * px^2) / (px^2 + py^2)^(3/2)
  ///
  /// where the second differences are scaled by `1 / dh^2`.
  /// \param  dh       The resolution of the grid.
  /// \tparam DataType The type of the resolution.
  template <typename DataType>
  ripple_all constexpr auto
  curvature_2d(DataType dh = 1) const noexcept -> DataType {
    const auto px      = grad_dim(dimx(), dh);
    const auto py      = grad_dim(dimy(), dh);
    const auto px2     = px * px;
    const auto py2     = py * py;
    const auto px2_py2 = px2 + py2;
    const auto dh2     = DataType(1) / (dh * dh);
    const auto pxx     = second_diff(dimx());
    const auto pyy     = second_diff(dimy());
    const auto pxy     = second_partial_diff(dimx(), dimy());

    return dh2 * (pxx * py2 - DataType{2} * py * px * pxy + pyy * px2) /
           math::sqrt(px2_py2 * px2_py2 * px2_py2);
  }

  /// Computes the curvature in three dimensions from the first, second and
  /// mixed differences in x, y and z:
  ///
  ///   ((pyy + pzz) * px^2 + (pxx + pzz) * py^2 + (pxx + pyy) * pz^2
  ///     - 2 * (px * py * pxy + px * pz * pxz + py * pz * pyz))
  ///   / (px^2 + py^2 + pz^2)^(3/2)
  ///
  /// where the second differences are scaled by `1 / dh^2`.
  /// \param  dh       The resolution of the grid.
  /// \tparam DataType The type of the resolution.
  template <typename DataType>
  ripple_all constexpr auto
  curvature_3d(DataType dh = 1) const noexcept -> DataType {
    const auto px    = grad_dim(dimx(), dh);
    const auto py    = grad_dim(dimy(), dh);
    const auto pz    = grad_dim(dimz(), dh);
    const auto px2   = px * px;
    const auto py2   = py * py;
    const auto pz2   = pz * pz;
    const auto pxyz2 = px2 + py2 + pz2;

    const auto dh2 = DataType(1) / (dh * dh);
    const auto pxx = second_diff(dimx());
    const auto pyy = second_diff(dimy());
    const auto pzz = second_diff(dimz());
    const auto pxy = second_partial_diff(dimx(), dimy());
    const auto pxz = second_partial_diff(dimx(), dimz());
    const auto pyz = second_partial_diff(dimy(), dimz());

    // clang-format off
    return dh2 * (
      (pyy + pzz) * px2 +
      (pxx + pzz) * py2 +
      (pxx + pyy) * pz2 -
      DataType{2} * (px * py * pxy + px * pz * pxz + py * pz * pyz)) /
      math::sqrt(pxyz2 * pxyz2 * pxyz2);
    // clang-format on
  }

  /*==--- [size] -----------------------------------------------------------==*/

  /// Gets the internal size of the iteration space, as reported by the space.
  ripple_all constexpr auto size() const noexcept -> size_t {
    return space_.internal_size();
  }

  /// Gets the internal size of the iteration space in the given dimension,
  /// as reported by the space.
  /// \param  dim The dimension to get the size of.
  /// \tparam Dim The type of the dimension specifier.
  template <typename Dim>
  ripple_all constexpr auto size(Dim&& dim) const noexcept -> size_t {
    return space_.internal_size(ripple_forward(dim));
  }

  /// Resizes the given dimension of the iteration space held by the iterator.
  /// \param  dim  The dimension to resize.
  /// \param  size The new size for the dimension.
  /// \tparam Dim  The type of the dimension specifier.
  template <typename Dim>
  ripple_all constexpr auto
  resize(Dim&& dim, size_t size) noexcept -> void {
    space_.resize_dim(ripple_forward(dim), size);
  }

  /// Gets the padding of the iteration space, as reported by the space.
  ripple_all constexpr auto padding() const noexcept -> size_t {
    return space_.padding();
  }
};

} // namespace ripple

#endif // RIPPLE_CONTAINER_BLOCK_ITERATOR_HPP

Docs

Access comprehensive developer documentation for Ripple

View Docs

Tutorials

Get tutorials to help you understand all of the features

View Tutorials

Examples

Find examples to help get started

View Examples