fmatvec  0.0.0
checkderivative.h
1#ifndef _FMATVEC_CHECKDERIVATIVE_H_
2#define _FMATVEC_CHECKDERIVATIVE_H_
3
4#include <map>
5#include <functional>
6#include <string>
7#include <iostream>
8#include <fmatvec/linear_algebra.h>
9
10namespace fmatvec {
11
// Convenience macros which call fmatvec::checkDerivative with an automatically generated idName of the form
// "<__FILE__>:<__LINE__>:<text of the anaDiff expression>".
// - FMATVEC_CHECKDERIVATIVE_MEMBERFUNC passes "this" as idPtr, so it can only be expanded where "this" is valid.
// - FMATVEC_CHECKDERIVATIVE_GLOBALFUNC passes nullptr as idPtr.
// Everything after "value" is forwarded unchanged via __VA_ARGS__: as the commented-out "indep" indicates, the first
// forwarded argument is the independent variable; it may be followed by the optional eps and delta arguments.
12 #define FMATVEC_CHECKDERIVATIVE_MEMBERFUNC(func, anaDiff, value, /*indep,*/ ...) \
13 fmatvec::checkDerivative(this, std::string(__FILE__)+":"+std::to_string(__LINE__)+":"+#anaDiff, func, anaDiff, value, /*indep,*/ __VA_ARGS__)
14 #define FMATVEC_CHECKDERIVATIVE_GLOBALFUNC(func, anaDiff, value, /*indep,*/ ...) \
15 fmatvec::checkDerivative(nullptr, std::string(__FILE__)+":"+std::to_string(__LINE__)+":"+#anaDiff, func, anaDiff, value, /*indep,*/ __VA_ARGS__)
16
17 // Check a analytical derivative against the finite difference.
18 // Any local variable for which a analytical derivative is given can be checked.
19 // - idPtr and idName must, together, be a unique identifier for this call. Usually idPtr is "this" and idName is the name of the variable to be checked.
20 // - func is the callers function with the independent variable as input.
21 // - anaDiff is the analytical derivative which should be checked.
22 // - value is the value for which the analytical derivative should be checked.
23 // - indep is the independent variable with respect to which the derivative is calculated.
24 // - eps is the tolerance. If anaDiff differs more than this value a error is printed.
25 // - delta is the finite difference for the finite derivative calculation.
26 template<class Value>
27 void checkDerivative(const void *idPtr, const std::string &idName, const std::function<void(double)> &func, const Value &anaDiff, const Value &value, double indep, double eps=1e-6, double delta=1e-8) {
28 #ifdef NDEBUG
29 #error "fmatvec::checkDerivative should not be active for release builds"
30 #endif
31
32 static int distributedCount = 0;
33 static std::map<std::pair<const void*, std::string>, std::pair<bool, Value>> ele;
34
35 auto &[disturbed, disturbedValue] = ele.emplace(std::make_pair(idPtr, idName), std::make_pair(false, Value())).first->second;
36
37 if(disturbed) { // this is a distributed call (for the key in "ele") ..
38 // ... save the distributed value (in "ele" map)
39 if constexpr (std::is_same_v<Value, double>)
40 disturbedValue = value;
41 else
42 disturbedValue <<= value;
43 }
44 else if(distributedCount == 0) { // this is a normal call (not distributed) ...
45 std::array<Value, 2> disturbedValueRightLeft;
46
47 // .. make a distributed call for the current key in "ele"
48 disturbed = true;
49 distributedCount++;
50
51 func(indep + delta); // right disturbed
52 // save the distributed value (scalar or fmatvec vec/mat)
53 if constexpr (std::is_same_v<Value, double>)
54 disturbedValueRightLeft[0] = disturbedValue;
55 else
56 disturbedValueRightLeft[0] <<= disturbedValue;
57
58 func(indep - delta); // left disturbed
59 // save the distributed value (scalar or fmatvec vec/mat)
60 if constexpr (std::is_same_v<Value, double>)
61 disturbedValueRightLeft[1] = disturbedValue;
62 else
63 disturbedValueRightLeft[1] <<= disturbedValue;
64
65 func(indep); // undisturbed again to reset everything
66
67 disturbed = false;
68 distributedCount--;
69
70 // calculate the right and left finite difference and its norm
71 std::array<double, 2> dist;
72 std::array<Value, 2> finiteDiffRightLeft;
73 for(int i=0; i<2; ++i) {
74 if constexpr (std::is_same_v<Value, double>)
75 finiteDiffRightLeft[i] = (i==0?1.0:-1.0) * (disturbedValueRightLeft[i] - value)/delta;
76 else
77 finiteDiffRightLeft[i] <<= (i==0?1.0:-1.0) * (disturbedValueRightLeft[i] - value)/delta;
78
79 if constexpr (std::is_same_v<Value, double>)
80 dist[i] = std::abs(finiteDiffRightLeft[i] - anaDiff);
81 else if constexpr (Value::isVector)
82 dist[i] = fmatvec::nrmInf(finiteDiffRightLeft[i] - anaDiff);
83 else {
84 dist[i] = 0;
85 for(int c=0; c<anaDiff.cols(); c++)
86 dist[i] = std::max(dist[i], fmatvec::nrm2(finiteDiffRightLeft[i].col(c) - anaDiff.col(c)));
87 }
88 }
89
90 // if both, the right and the left finite difference, differs from analytically derivative value print a error message
91 double dist0 = dist[0]; // MSVS cannot compile "dist[0] > eps" directly
92 double dist1 = dist[1];
93 if(dist0 > eps && dist1 > eps) {
94 std::stringstream str;
95 str<<"checkDerivative failed in "<<idPtr<<": ID = "<<idName<<":"<<std::endl
96 <<"- anaDiff = "<<anaDiff<<std::endl
97 <<"- finiteDiffRight = "<<finiteDiffRightLeft[0]<<std::endl
98 <<"- finiteDiffLeft = "<<finiteDiffRightLeft[1]<<std::endl
99 <<"- indep = "<<indep<<std::endl
100 <<"- nrmInf (right/left) = "<<dist[0]<<" "<<dist[1];
101 //throw std::runtime_error(str.str());
102 std::cerr<<str.str()<<std::endl;
103 assert(0);
104 throw std::runtime_error(str.str());
105 }
106 }
107 }
108
109}
110
111#endif
Namespace fmatvec.
Definition: _memory.cc:28