Commit c2553e00 by Praetorius, Simon

### symbolic differentiation

parent 88f6fb85
This diff is collapsed.
 ... @@ -57,6 +57,7 @@ class: tud-light, typo ... @@ -57,6 +57,7 @@ class: tud-light, typo {% include concepts.md %} {% include concepts.md %} {% include metaprogramming.md %} {% include metaprogramming.md %} {% include further-topics.md %} {% include further-topics.md %} {% include symbolic-differentiation.md %}
... ...
 ... @@ -24,7 +24,7 @@ template<> struct Prime_print<1> { ... @@ -24,7 +24,7 @@ template<> struct Prime_print<1> { #define LAST 18 #define LAST 18 #endif #endif main() { int main() { Prime_print a; Prime_print a; a.f(); a.f(); } }
 #include #include #define SIMPLIFY 1 struct Plus { template auto operator() (A const& a, B const& b) const { return a + b; } }; struct Mult { template auto operator() (A const& a, B const& b) const { return a * b; } }; struct Exp { template auto operator() (A const& a) const { return std::exp(a); } }; // some basic expressions: template struct Variable { double operator() (double const* x) const { return x[varID]; } }; // The integer class is both, an expression and a functor template struct Integer { template double operator() (Ts const&...) const { return double(value); } }; struct Number { double value_; Number (double value) : value_(value) {} double operator() (double const* x) const { return value_; } }; // Composition of F(A) template struct UnaryExpr { F f_; A a_; UnaryExpr (F f, A const& a) : f_(f), a_(a) {} double operator() (double const* x) const { return f_(a_(x)); } }; // Exp(F) -> UnaryExpr template auto exp (F const& f) { return UnaryExpr{Exp{},f}; } // Composition of F(A1,A2) template struct BinaryExpr { F f_; A1 a1_; A2 a2_; BinaryExpr (F f, A1 const& a1, A2 const& a2) : f_(f), a1_(a1), a2_(a2) {} double operator() (double const* x) const { return f_(a1_(x), a2_(x)); } }; // F + G -> BinaryExpr template auto operator+ (F const& f, G const& g) { return BinaryExpr{Plus{},f,g}; } // F * G -> BinaryExpr template auto operator* (F const& f, G const& g) { return BinaryExpr{Mult{},f,g}; } // Simplification of expressions template struct Simplify { using type = Expr; static Expr const& generate (Expr const& expr) { return expr; } }; template using Simplify_t = typename Simplify::type; template auto simplify (Expr const& expr) { return Simplify::generate(expr); } #if SIMPLIFY // basic simplification rules: recursively template struct Simplify> { using type = UnaryExpr>; static type generate (UnaryExpr const& expr) { return {expr.f_, simplify(expr.a_)}; } }; // basic simplification rules: recursively template struct Simplify> { using type = BinaryExpr,Simplify_t>; static type generate (BinaryExpr const& expr) { return {expr.f_, simplify(expr.a1_), simplify(expr.a2_)}; } }; #endif // Generator for derivative expressions template struct Derivative; template // alias for the generated derivative type using Derivative_t = typename Derivative::type; template // note the different order of the arguments auto derivative (F const& f) { return Derivative::generate(f); } // Derivative of a general unary expression template struct Derivative, varID> { using dF = Derivative_t; // functor using dA = Derivative_t; // expression using type = Simplify_t, dA> >>; static type generate(UnaryExpr const& expr) { return simplify(simplify( UnaryExpr(derivative<0>(expr.f_), expr.a_) * derivative(expr.a_) )); // f1'(A) * dA/dx_i } }; // Derivative of a general binary expression template struct Derivative, varID > { using dF0 = Derivative_t; // functors using dF1 = Derivative_t; using dA1 = Derivative_t; // expressions using dA2 = Derivative_t; using type = Simplify_t, dA1> >, Simplify_t, dA2> > > >>; static type generate (BinaryExpr const& expr) { return simplify(simplify( simplify(BinaryExpr(derivative<0>(expr.f_), expr.a1_, expr.a2_) * derivative(expr.a1_) ) + simplify(BinaryExpr(derivative<1>(expr.f_), expr.a1_, expr.a2_) * derivative(expr.a2_) ) )); } }; // ----------------------------------- // Derivative of a plus expression template struct Derivative, varID > { using dA1 = Derivative_t; // expressions using dA2 = Derivative_t; using type = Simplify_t >>; static type generate (BinaryExpr const& expr) { return simplify(simplify( derivative(expr.a1_) + derivative(expr.a2_) )); } }; // derivatives of some basic expressions template struct Derivative, varID> { using type = Integer<0>; static type generate (Integer) { return {}; } }; template struct Derivative { using type = Integer<0>; static type generate (Number) { return {}; } }; template struct Derivative, j> { using type = Integer<0>; static type generate (Variable) { return {}; } }; template struct Derivative, i> // specialization for i == j { using type = Integer<1>; static type generate (Variable) { return {}; } }; template <> struct Derivative { using type = Exp; static type generate (Exp) { return {}; } }; #if SIMPLIFY template <> struct Simplify>> { using type = Integer<1>; static type generate (UnaryExpr> const&) { return type{}; } }; template <> struct Simplify>> { using type = Number; static type generate (UnaryExpr> const&) { return Number{M_E}; } }; #endif template <> struct Derivative { using type = Integer<1>; static type generate (Plus) { return {}; } }; template <> struct Derivative { using type = Integer<1>; static type generate (Plus) { return {}; } }; #if SIMPLIFY template struct Simplify>> // E + 0 { using type = E; static type generate (BinaryExpr> const& expr) { return expr.a1_; } }; template struct Simplify,E>> // 0 + E { using type = E; static type generate (BinaryExpr,E> const& expr) { return expr.a2_; } }; template <> struct Simplify,Integer<0>>> // 0 + 0 { using type = Integer<0>; static type generate (BinaryExpr,Integer<0>> const&) { return {}; } }; template struct Simplify,Integer>> // n + m { using type = Integer; static type generate (BinaryExpr,Integer> const&) { return {}; } }; template struct Simplify,Integer<0>>> // n + 0 { using type = Integer; static type generate (BinaryExpr,Integer<0>> const&) { return {}; } }; template struct Simplify,Integer>> // n + m { using type = Integer; static type generate (BinaryExpr,Integer> const&) { return {}; } }; template struct Simplify,A1,A2>> { using type = Integer; static type generate (BinaryExpr,A1,A2> const&) { return {}; } }; template struct Simplify,A>> { using type = Integer; static type generate (UnaryExpr,A> const&) { return {}; } }; #endif // functor representing the i'th argument template struct Arg { template auto operator() (A const& a, B const& b) const { if constexpr (i == 0) return a; else return b; } }; template struct Derivative,j> { using type = Integer<0>; static type generate (Arg) { return {}; } }; template struct Derivative,i> { using type = Integer<1>; static type generate (Arg) { return {}; } }; #if SIMPLIFY template struct Simplify,A1,A2>> { using type = A1; static type generate (BinaryExpr,A1,A2> const& expr) { return expr.a1_; } }; template struct Simplify,A1,A2>> { using type = A2; static type generate (BinaryExpr,A1,A2> const& expr) { return expr.a2_; } }; #endif template <> struct Derivative { using type = Arg<1>; static type generate (Mult) { return {}; } }; template <> struct Derivative { using type = Arg<0>; static type generate (Mult) { return {}; } }; #if SIMPLIFY template struct Simplify>> // E * 0 { using type = Integer<0>; static type generate (BinaryExpr> const&) { return {}; } }; template struct Simplify,E>> // 0 * E { using type = Integer<0>; static type generate (BinaryExpr,E> const&) { return {}; } }; template <> struct Simplify,Integer<0>>> // 0 * 0 { using type = Integer<0>; static type generate (BinaryExpr,Integer<0>> const&) { return {}; } }; template struct Simplify>> // E * 1 { using type = E; static type generate (BinaryExpr> const& expr) { return expr.a1_; } }; template struct Simplify,E>> // 1 * E { using type = E; static type generate (BinaryExpr,E> const& expr) { return expr.a2_; } }; template <> struct Simplify,Integer<1>>> // 1 * 1 { using type = Integer<1>; static type generate (BinaryExpr,Integer<1>> const&) { return {}; } }; template <> struct Simplify,Integer<1>>> // 0 * 1 { using type = Integer<0>; static type generate (BinaryExpr,Integer<1>> const&) { return {}; } }; template <> struct Simplify,Integer<0>>> // 1 * 0 { using type = Integer<0>; static type generate (BinaryExpr,Integer<0>> const&) { return {}; } }; template struct Simplify,Integer>> // n * m { using type = Integer; static type generate (BinaryExpr,Integer> const&) { return {}; } }; template struct Simplify,Integer<0>>> // n * 0 { using type = Integer<0>; static type generate (BinaryExpr,Integer<0>> const&) { return {}; } }; template struct Simplify,Integer>> // 0 * m { using type = Integer<0>; static type generate (BinaryExpr,Integer> const&) { return {}; } }; template struct Simplify,Integer<1>>> // n * 1 { using type = Integer; static type generate (BinaryExpr,Integer<1>> const&) { return {}; } }; template struct Simplify,Integer>> // 1 * m { using type = Integer; static type generate (BinaryExpr,Integer> const&) { return {}; } }; template struct Simplify,BinaryExpr,E>>> // n * (m * E) { using type = BinaryExpr,E>; static type generate (BinaryExpr,BinaryExpr,E>> const& expr) { return {Mult{}, Integer{}, expr.a2_.a2_}; } }; template struct Simplify,BinaryExpr>>> // n * (E * m) { using type = BinaryExpr,E>; static type generate (BinaryExpr,BinaryExpr>> const& expr) { return {Mult{}, Integer{}, expr.a2_.a1_}; } }; template struct Simplify,BinaryExpr,E>>> // 0 * (m * E) { using type = Integer<0>; static type generate (BinaryExpr,BinaryExpr,E>> const& expr) { return {}; } }; template struct Simplify,BinaryExpr>>> // 0 * (E * m) { using type = Integer<0>; static type generate (BinaryExpr,BinaryExpr>> const& expr) { return {}; } }; template struct Simplify,BinaryExpr>>> // 1 * (E * m) { using type = BinaryExpr,E>; static type generate (BinaryExpr,BinaryExpr>> const& expr) { return {Mult{}, Integer{}, expr.a2_.a1_}; } }; template struct Simplify,BinaryExpr,E>>> // 1 * (m * E) { using type = BinaryExpr,E>; static type generate (BinaryExpr,BinaryExpr,E>> const& expr) { return {Mult{}, Integer{}, expr.a2_.a2_}; } }; #endif template void print (Expr const&) { std::cout << __PRETTY_FUNCTION__ << std::endl; } int main() { Variable<0> x0; Variable<1> x1; Variable<2> x2; Integer<1> _1; Integer<2> _2;