• Docs >
  • Program Listing for File mat.hpp
Shortcuts

Program Listing for File mat.hpp

Return to documentation for file (include/ripple/math/mat.hpp)

#ifndef RIPPLE_MATH_MAT_HPP
#define RIPPLE_MATH_MAT_HPP

#include <ripple/container/array.hpp>
#include <ripple/container/tuple.hpp>
#include <ripple/storage/polymorphic_layout.hpp>
#include <ripple/storage/storage_descriptor.hpp>
#include <ripple/storage/storage_traits.hpp>
#include <ripple/storage/struct_accessor.hpp>
#include <ripple/utility/portability.hpp>

namespace ripple {

/**
 * Forward declaration of the matrix implementation type.
 *
 * \tparam T      The type of the matrix elements.
 * \tparam Rows   A type exposing the number of rows through `Rows::value`.
 * \tparam Cols   A type exposing the number of columns through `Cols::value`.
 * \tparam Layout The storage layout tag handed to the storage descriptor.
 */
template <typename T, typename Rows, typename Cols, typename Layout>
struct MatImpl;

/**
 * Convenience alias for a matrix whose dimensions are given as integral
 * values; each dimension is wrapped in `Num<>` before being forwarded to
 * MatImpl.
 *
 * \tparam T      The type of the matrix elements.
 * \tparam Rows   The number of rows.
 * \tparam Cols   The number of columns.
 * \tparam Layout The storage layout tag, `ContiguousOwned` when omitted.
 */
template <
  typename T,
  size_t Rows,
  size_t Cols,
  typename Layout = ContiguousOwned>
using Mat = MatImpl<T, Num<Rows>, Num<Cols>, Layout>;

/**
 * Statically sized matrix whose elements are stored through a layout
 * dependent storage type.
 *
 * Elements are addressed in row-major order: the element at (row, col) lives
 * at linear index `row * Cols::value + col` of the underlying storage.
 *
 * \tparam T      The type of the matrix elements.
 * \tparam Rows   A type exposing the number of rows through `Rows::value`.
 * \tparam Cols   A type exposing the number of columns through `Cols::value`.
 * \tparam Layout The storage layout tag handed to StorageDescriptor.
 */
template <typename T, typename Rows, typename Cols, typename Layout>
struct MatImpl : public PolymorphicLayout<MatImpl<T, Rows, Cols, Layout>> {
 private:
  /** Total number of elements held by the matrix. */
  static constexpr auto elements = size_t{Rows::value * Cols::value};

  // clang-format off
  /** Describes how `elements` values of type T are stored for Layout. */
  using Descriptor = StorageDescriptor<Layout, Vector<T, elements>>;
  /** The concrete storage type selected by the descriptor. */
  using Storage    = typename Descriptor::Storage;
  /** The element type with references and cv-qualifiers stripped. */
  using Value      = std::decay_t<T>;
  // clang-format on

  /** Matrices with other parameters may access this type's internals. */
  template <typename OType, typename ORows, typename OCols, typename OLayout>
  friend struct MatImpl;

  /** LayoutTraits may access this type's internals. */
  template <typename Layable, bool IsStridable>
  friend struct LayoutTraits;

 public:
  /** The storage holding the matrix elements. */
  Storage storage_;

  /*==--- [construction] ---------------------------------------------------==*/

  /** Default constructor; does not assign any element. */
  ripple_all constexpr MatImpl() noexcept {}

  /**
   * Sets every element of the matrix to the given value.
   * \param val The value to assign to each element.
   */
  ripple_all constexpr MatImpl(T val) noexcept {
    // The index is used as a template argument, so it must be usable as a
    // compile-time constant.
    unrolled_for<elements>(
      [&](auto i) { storage_.template get<0, i>() = val; });
  }

  /**
   * Sets the elements from a list of values, where the i-th value is
   * assigned to linear index i (i.e. values are given in row-major order).
   *
   * \note Participation is controlled by `variadic_size_enable_t` --
   *       presumably requiring exactly `elements` arguments; TODO confirm.
   *
   * \param  values The values for the elements.
   * \tparam Values The types of the values.
   */
  template <typename... Values, variadic_size_enable_t<elements, Values...> = 0>
  ripple_all constexpr MatImpl(Values&&... values) noexcept {
    const auto v = Tuple<Values...>{values...};
    unrolled_for<elements>(
      [&](auto i) { storage_.template get<0, i>() = get<i>(v); });
  }

  /**
   * Constructs the matrix from a storage instance.
   * \param storage The storage to initialise the matrix storage from.
   */
  ripple_all constexpr MatImpl(Storage storage) noexcept
  : storage_{storage} {}

  /**
   * Copy constructor.
   * \param other The matrix whose storage is copied.
   */
  ripple_all constexpr MatImpl(const MatImpl& other) noexcept
  : storage_{other.storage_} {}

  /**
   * Move constructor.
   * \param other The matrix whose storage is moved from.
   */
  ripple_all constexpr MatImpl(MatImpl&& other) noexcept
  : storage_{ripple_move(other.storage_)} {}

  /**
   * Copy constructor from a matrix of the same shape with another layout.
   * \param  other       The matrix whose storage initialises this one.
   * \tparam OtherLayout The layout of the other matrix.
   */
  template <typename OtherLayout>
  ripple_all constexpr MatImpl(
    const MatImpl<T, Rows, Cols, OtherLayout>& other) noexcept
  : storage_{other.storage_} {}

  /**
   * Move constructor from a matrix of the same shape with another layout.
   *
   * Marked noexcept so that it agrees with every other constructor and
   * assignment operator of the class.
   *
   * \param  other       The matrix whose storage is moved from.
   * \tparam OtherLayout The layout of the other matrix.
   */
  template <typename OtherLayout>
  ripple_all constexpr MatImpl(
    MatImpl<T, Rows, Cols, OtherLayout>&& other) noexcept
  : storage_{ripple_move(other.storage_)} {}

  /*==--- [operator overloads] ---------------------------------------------==*/

  /**
   * Copy assignment.
   * \param  other The matrix to copy from.
   * \return A reference to this matrix.
   */
  ripple_all auto operator=(const MatImpl& other) noexcept -> MatImpl& {
    storage_ = other.storage_;
    return *this;
  }

  /**
   * Move assignment.
   * \param  other The matrix to move from.
   * \return A reference to this matrix.
   */
  ripple_all auto operator=(MatImpl&& other) noexcept -> MatImpl& {
    storage_ = ripple_move(other.storage_);
    return *this;
  }

  /**
   * Copy assignment from a matrix of the same shape with another layout;
   * the elements are copied one by one.
   *
   * \param  other       The matrix to copy from.
   * \tparam OtherLayout The layout of the other matrix.
   * \return A reference to this matrix.
   */
  template <typename OtherLayout>
  ripple_all auto
  operator=(const MatImpl<T, Rows, Cols, OtherLayout>& other) noexcept
    -> MatImpl& {
    unrolled_for<elements>([&](auto i) {
      storage_.template get<0, i>() = other.storage_.template get<0, i>();
    });
    return *this;
  }

  /**
   * Move assignment from a matrix of the same shape with another layout.
   *
   * NOTE(review): unlike the cross-layout copy assignment this assigns the
   * storage object directly rather than element by element -- confirm the
   * storage types for different layouts are assignable from one another.
   *
   * \param  other       The matrix to move from.
   * \tparam OtherLayout The layout of the other matrix.
   * \return A reference to this matrix.
   */
  template <typename OtherLayout>
  ripple_all auto
  operator=(MatImpl<T, Rows, Cols, OtherLayout>&& other) noexcept -> MatImpl& {
    storage_ = ripple_move(other.storage_);
    return *this;
  }

  /**
   * Gets a reference to the element at the given row and column. The
   * indices are not range checked.
   *
   * \param  row The row of the element.
   * \param  col The column of the element.
   * \return A reference to the element.
   */
  ripple_all auto
  operator()(size_t row, size_t col) noexcept -> Value& {
    return storage_.template get<0>(to_index(row, col));
  }

  /**
   * Gets a const reference to the element at the given row and column. The
   * indices are not range checked.
   *
   * \param  row The row of the element.
   * \param  col The column of the element.
   * \return A const reference to the element.
   */
  ripple_all auto
  operator()(size_t row, size_t col) const noexcept -> const Value& {
    return storage_.template get<0>(to_index(row, col));
  }

  /*==--- [interface] ------------------------------------------------------==*/

  /**
   * Gets the number of columns in the matrix.
   * \return The number of columns.
   */
  ripple_all constexpr auto columns() const noexcept -> size_t {
    return Cols::value;
  }

  /**
   * Gets the number of rows in the matrix.
   * \return The number of rows.
   */
  ripple_all constexpr auto rows() const noexcept -> size_t {
    return Rows::value;
  }

  /**
   * Gets a const reference to the element at compile-time indices, which
   * are checked against the matrix dimensions at compile time.
   *
   * \tparam Row The row of the element.
   * \tparam Col The column of the element.
   * \return A const reference to the element.
   */
  template <size_t Row, size_t Col>
  ripple_all constexpr auto at() const noexcept -> const Value& {
    // The dimensions are read from the template parameters rather than via
    // rows()/columns() so that the constant expressions do not depend on
    // `this`.
    static_assert((Row < Rows::value), "Compile time row index out of range!");
    static_assert((Col < Cols::value), "Compile time col index out of range!");
    constexpr size_t i = to_index(Row, Col);
    return storage_.template get<0, i>();
  }

  /**
   * Gets a mutable reference to the element at compile-time indices, which
   * are checked against the matrix dimensions at compile time.
   *
   * This overload is deliberately not const-qualified: if it were, it would
   * differ from the const overload only in its return type, making every
   * call to at() ambiguous.
   *
   * \tparam Row The row of the element.
   * \tparam Col The column of the element.
   * \return A reference to the element.
   */
  template <size_t Row, size_t Col>
  ripple_all constexpr auto at() noexcept -> Value& {
    static_assert((Row < Rows::value), "Compile time row index out of range!");
    static_assert((Col < Cols::value), "Compile time col index out of range!");
    constexpr size_t i = to_index(Row, Col);
    return storage_.template get<0, i>();
  }

  /**
   * Gets the total number of elements in the matrix.
   * \return The number of rows multiplied by the number of columns.
   */
  ripple_all constexpr auto size() const noexcept -> size_t {
    return elements;
  }

 private:
  /**
   * Converts a (row, column) pair into a row-major linear index.
   *
   * Static so that it can be evaluated in constant expressions without
   * involving `this`.
   *
   * \param  r The row index.
   * \param  c The column index.
   * \return The linear index `r * Cols::value + c`.
   */
  ripple_all static constexpr auto
  to_index(size_t r, size_t c) noexcept -> size_t {
    return r * Cols::value + c;
  }
};

/**
 * Result type of a matrix-vector product: the `ImplType` member template of
 * the vector's array traits, instantiated with `Rows` and the
 * `ContiguousOwned` layout.
 *
 * \tparam Vec  The vector implementation type whose array traits are used.
 * \tparam Rows The first argument for `ImplType` -- presumably the number of
 *              elements in the result; the product below passes the number
 *              of matrix rows.
 */
template <typename Vec, size_t Rows>
using mat_vec_result_t =
  typename array_traits_t<Vec>::template ImplType<Rows, ContiguousOwned>;

/**
 * Computes the product of a matrix and a vector.
 *
 * The vector must have as many elements as the matrix has columns, and the
 * matrix element type must be convertible to the vector's value type; both
 * requirements are enforced at compile time.
 *
 * \param  m    The matrix on the left of the product.
 * \param  v    The vector on the right of the product.
 * \tparam T    The element type of the matrix.
 * \tparam R    The type holding the number of matrix rows.
 * \tparam C    The type holding the number of matrix columns.
 * \tparam L    The layout of the matrix.
 * \tparam Impl The implementation type of the vector.
 * \return A vector with one element per matrix row, where element r is the
 *         sum over c of `m(r, c) * v[c]`.
 */
template <typename T, typename R, typename C, typename L, typename Impl>
ripple_all auto
operator*(const MatImpl<T, R, C, L>& m, const Array<Impl>& v) noexcept
  -> mat_vec_result_t<Impl, R::value> {
  using VecTraits = array_traits_t<Impl>;
  using Value     = typename VecTraits::Value;
  using Result    = mat_vec_result_t<Impl, R::value>;

  constexpr size_t num_rows = R::value;
  constexpr size_t num_cols = C::value;

  static_assert(
    num_cols == VecTraits::size,
    "Invalid configuration for matrix vector multiplication!");
  static_assert(
    std::is_convertible_v<T, Value>,
    "Matrix and vector types must be convertible!");

  Result out;
  unrolled_for<num_rows>([&](auto row) {
    // Start each output element from zero, then accumulate the dot product
    // of the matrix row with the vector.
    out[row] = 0;
    unrolled_for<num_cols>(
      [&](auto col) { out[row] += m(row, col) * v[col]; });
  });
  return out;
}

/**
 * Computes the product of two matrices.
 *
 * The number of columns of the left matrix must match the number of rows of
 * the right matrix, which the shared `C1R2` parameter enforces. The element
 * type of the left matrix must be convertible to that of the right matrix,
 * which is checked at compile time.
 *
 * \param  a    The left matrix of the product.
 * \param  b    The right matrix of the product.
 * \tparam T1   The element type of the left matrix and of the result.
 * \tparam T2   The element type of the right matrix.
 * \tparam R1   The type holding the number of rows of the left matrix.
 * \tparam C1R2 The type holding the shared inner dimension.
 * \tparam C2   The type holding the number of columns of the right matrix.
 * \tparam L1   The layout of the left matrix.
 * \tparam L2   The layout of the right matrix.
 * \return A matrix with the rows of `a` and the columns of `b`, using the
 *         `ContiguousOwned` layout.
 */
template <
  typename T1,
  typename T2,
  typename R1,
  typename C1R2,
  typename C2,
  typename L1,
  typename L2>
ripple_all auto operator*(
  const MatImpl<T1, R1, C1R2, L1>& a,
  const MatImpl<T2, C1R2, C2, L2>& b) noexcept
  -> MatImpl<T1, R1, C2, ContiguousOwned> {
  using Result = MatImpl<T1, R1, C2, ContiguousOwned>;

  static_assert(
    std::is_convertible_v<T1, T2>,
    "Matrix multiplication requires data types which are convertible!");

  constexpr size_t out_rows = R1::value;
  constexpr size_t out_cols = C2::value;
  constexpr size_t shared   = C1R2::value;

  // Every element starts at zero and accumulates the dot product of a row of
  // `a` with a column of `b`.
  Result product{0};
  for (size_t row = 0; row < out_rows; ++row) {
    for (size_t col = 0; col < out_cols; ++col) {
      auto& cell = product(row, col);
      unrolled_for<shared>([&](auto k) { cell += a(row, k) * b(k, col); });
    }
  }
  return product;
}

} // namespace ripple

#endif // RIPPLE_MATH_MAT_HPP

Docs

Access comprehensive developer documentation for Ripple

View Docs

Tutorials

Get tutorials to help you understand all features

View Tutorials

Examples

Find examples to help get started

View Examples