Wildmeshing Toolkit
Loading...
Searching...
No Matches
AutoDiffUtils.hpp
Go to the documentation of this file.
1#pragma once
2#include "autodiff.h"
3
4namespace wmtk::function::utils {
5
6
7template <typename DScalarType, int Rows = Eigen::Dynamic, int Cols = Eigen::Dynamic>
8auto make_DScalar_matrix(int rows = 0, int cols = 0)
9{
10 if constexpr (Rows != Eigen::Dynamic) {
11 rows = Rows;
12 }
13 if constexpr (Cols != Eigen::Dynamic) {
14 cols = Cols;
15 }
16 assert(rows * cols == DiffScalarBase::getVariableCount());
17
18 using RetType = Eigen::Matrix<DScalarType, Rows, Cols>;
19 if constexpr (Rows != Eigen::Dynamic && Cols != Eigen::Dynamic) {
20 return RetType::NullaryExpr([](int row, int col) {
21 int index;
22 if constexpr (RetType::IsRowMajor) {
23 index = Rows * col + row;
24 } else {
25 index = Cols * row + col;
26 }
27 return DScalarType(index);
28 })
29 .eval();
30 } else {
31 return RetType::NullaryExpr(
32 rows,
33 cols,
34 [&](int row, int col) {
35 int index;
36 if constexpr (RetType::IsRowMajor) {
37 index = rows * col + row;
38 } else {
39 index = cols * row + col;
40 }
41 return DScalarType(index);
42 })
43 .eval();
44 }
45}
46
47template <typename DScalarType, typename Derived>
48auto as_DScalar(const Eigen::MatrixBase<Derived>& data)
49{
50 constexpr static int Rows = Derived::RowsAtCompileTime;
51 constexpr static int Cols = Derived::ColsAtCompileTime;
52 int rows = data.rows();
53 int cols = data.cols();
54
55 assert(rows * cols == DiffScalarBase::getVariableCount());
56
57 using RetType = Eigen::Matrix<DScalarType, Rows, Cols>;
58 if constexpr (Rows != Eigen::Dynamic && Cols != Eigen::Dynamic) {
59 return RetType::NullaryExpr([&](int row, int col) {
60 int index;
61 if constexpr (RetType::IsRowMajor) {
62 index = Rows * col + row;
63 } else {
64 index = Cols * row + col;
65 }
66 return DScalarType(index, data(row, col));
67 })
68 .eval();
69 } else {
70 return RetType::NullaryExpr(
71 rows,
72 cols,
73 [&](int row, int col) {
74 int index;
75 if constexpr (RetType::IsRowMajor) {
76 index = rows * col + row;
77 } else {
78 index = cols * row + col;
79 }
80 return DScalarType(index, data(row, col));
81 })
82 .eval();
83 }
84}
85
86
87} // namespace wmtk::function::utils
auto make_DScalar_matrix(int rows=0, int cols=0)
auto as_DScalar(const Eigen::MatrixBase< Derived > &data)
static size_t getVariableCount()
Get the variable count used by the automatic differentiation layer.
Definition autodiff.h:63