Wildmeshing Toolkit
AutoDiffUtils.hpp
Go to the documentation of this file.
1 #pragma once
2 #include "autodiff.h"
3 
4 namespace wmtk::function::utils {
5 
6 
7 template <typename DScalarType, int Rows = Eigen::Dynamic, int Cols = Eigen::Dynamic>
8 auto make_DScalar_matrix(int rows = 0, int cols = 0)
9 {
10  if constexpr (Rows != Eigen::Dynamic) {
11  rows = Rows;
12  }
13  if constexpr (Cols != Eigen::Dynamic) {
14  cols = Cols;
15  }
16  assert(rows * cols == DiffScalarBase::getVariableCount());
17 
18  using RetType = Eigen::Matrix<DScalarType, Rows, Cols>;
19  if constexpr (Rows != Eigen::Dynamic && Cols != Eigen::Dynamic) {
20  return RetType::NullaryExpr([](int row, int col) {
21  int index;
22  if constexpr (RetType::IsRowMajor) {
23  index = Rows * col + row;
24  } else {
25  index = Cols * row + col;
26  }
27  return DScalarType(index);
28  })
29  .eval();
30  } else {
31  return RetType::NullaryExpr(
32  rows,
33  cols,
34  [&](int row, int col) {
35  int index;
36  if constexpr (RetType::IsRowMajor) {
37  index = rows * col + row;
38  } else {
39  index = cols * row + col;
40  }
41  return DScalarType(index);
42  })
43  .eval();
44  }
45 }
46 
47 template <typename DScalarType, typename Derived>
48 auto as_DScalar(const Eigen::MatrixBase<Derived>& data)
49 {
50  constexpr static int Rows = Derived::RowsAtCompileTime;
51  constexpr static int Cols = Derived::ColsAtCompileTime;
52  int rows = data.rows();
53  int cols = data.cols();
54 
55  assert(rows * cols == DiffScalarBase::getVariableCount());
56 
57  using RetType = Eigen::Matrix<DScalarType, Rows, Cols>;
58  if constexpr (Rows != Eigen::Dynamic && Cols != Eigen::Dynamic) {
59  return RetType::NullaryExpr([&](int row, int col) {
60  int index;
61  if constexpr (RetType::IsRowMajor) {
62  index = Rows * col + row;
63  } else {
64  index = Cols * row + col;
65  }
66  return DScalarType(index, data(row, col));
67  })
68  .eval();
69  } else {
70  return RetType::NullaryExpr(
71  rows,
72  cols,
73  [&](int row, int col) {
74  int index;
75  if constexpr (RetType::IsRowMajor) {
76  index = rows * col + row;
77  } else {
78  index = cols * row + col;
79  }
80  return DScalarType(index, data(row, col));
81  })
82  .eval();
83  }
84 }
85 
86 
87 } // namespace wmtk::function::utils
auto make_DScalar_matrix(int rows=0, int cols=0)
auto as_DScalar(const Eigen::MatrixBase< Derived > &data)
static size_t getVariableCount()
Get the variable count used by the automatic differentiation layer.
Definition: autodiff.h:63