トップダウン型の自動微分を実装する
はじめに
だいぶ前に書いていたコードですが、C++でTD型の自動微分を実装したので公開します。 自動微分については以下の記事を参考にさせていただきました。
C++のコード
C++書いたコードは以下の通り。 自分が書いたコードはC++14に対応していましたが、話題のChatGPTを用いて、C++17に対応するようにリファクタリングしてもらいました。
Variable::backward
でグラフを幅優先探索で辿ってる以外はクラスの定義や演算子オーバーロードを行っているだけなので、スッと読めると思います。
#include <iostream> #include <deque> #include <unordered_set> #include <vector> #include <memory> #include <tuple> #include <functional> template <typename T> auto enumerate(T &container) { std::vector<std::tuple<std::size_t, decltype(*std::begin(container)) &>> enumerated; std::size_t index = 0; for (auto &item : container) { enumerated.emplace_back(index, item); ++index; } return enumerated; } class Function; class Variable; class Variable { public: double data; double grad; std::shared_ptr<Function> creator; Variable(const double data) : data(data), grad(1.0), creator(nullptr) {} void set_creator(const std::shared_ptr<Function> &gen_func) { creator = gen_func; } void backward(); }; class Function : public std::enable_shared_from_this<Function> { public: std::vector<std::shared_ptr<Variable>> inputs; std::shared_ptr<Variable> output; std::shared_ptr<Variable> operator()(const std::shared_ptr<Function> &self, const std::shared_ptr<Variable> &input1, const std::shared_ptr<Variable> &input2 = nullptr) { inputs = {input1, input2}; const double y = forward(); output = std::make_shared<Variable>(y); output->set_creator(self); return output; } virtual double forward() const = 0; virtual std::vector<std::shared_ptr<Variable>> backward(const double gy) const = 0; }; class Add : public Function, public std::enable_shared_from_this<Add> { public: double forward() const override { return inputs[0]->data + inputs[1]->data; } std::vector<std::shared_ptr<Variable>> backward(const double gy) const override { return {std::make_shared<Variable>(gy), std::make_shared<Variable>(gy)}; } }; class Mul : public Function, public std::enable_shared_from_this<Mul> { public: double forward() const override { return inputs[0]->data * inputs[1]->data; } std::vector<std::shared_ptr<Variable>> backward(const double gy) const override { return {std::make_shared<Variable>(gy * inputs[1]->data), std::make_shared<Variable>(gy * inputs[0]->data)}; } }; std::shared_ptr<Variable> operator+(const std::shared_ptr<Variable> &lhs, const std::shared_ptr<Variable> &rhs) { auto add_func = std::make_shared<Add>(); return add_func->operator()(add_func, lhs, rhs); } std::shared_ptr<Variable> operator*(const std::shared_ptr<Variable> &lhs, const std::shared_ptr<Variable> &rhs) { auto mul_func = std::make_shared<Mul>(); return mul_func->operator()(mul_func, lhs, rhs); } std::shared_ptr<Variable> operator+(const std::shared_ptr<Variable> &lhs, const double rhs) { auto rhs_var = std::make_shared<Variable>(rhs); return lhs + rhs_var; } std::shared_ptr<Variable> operator+(const double lhs, const std::shared_ptr<Variable> &rhs) { auto lhs_var = std::make_shared<Variable>(lhs); return lhs_var + rhs; } std::shared_ptr<Variable> operator*(const std::shared_ptr<Variable> &lhs, const double rhs) { auto rhs_var = std::make_shared<Variable>(rhs); return lhs * rhs_var; } std::shared_ptr<Variable> operator*(const double lhs, const std::shared_ptr<Variable> &rhs) { auto lhs_var = std::make_shared<Variable>(lhs); return lhs_var * rhs; } void Variable::backward() { if (creator == nullptr) { return; } std::unordered_set<size_t> visited; std::deque<std::shared_ptr<Function>> queue{creator}; while (!queue.empty()) { auto function = queue.front(); auto output = function->output; auto gy = output->grad; auto gxs = function->backward(gy); queue.pop_front(); for (const auto &[i, gx] : enumerate(gxs)) { auto x = function->inputs[i]; if (gx == nullptr) { continue; } std::size_t id_x = std::hash<std::shared_ptr<Variable>>{}(x); if (x->creator != nullptr) { queue.push_back(x->creator); } if (visited.find(id_x) == visited.end()) { x->grad = gx->data; visited.insert(id_x); } else { x->grad += gx->data; } } } } int main() { auto x = std::make_shared<Variable>(1.0); auto y = std::make_shared<Variable>(1.0); auto z = x * x + 2 * x * y + y ; std::cout << "z = x * x + 2 * x * y + y " << std::endl; z->backward(); std::cout << "x = " << x->data << ", x.grad = " << x->grad << std::endl; std::cout << "y = " << y->data << ", y.grad = " << y->grad << std::endl; return 0; }
コンパイルと実行は以下の通りです。
$ clang++ -std=c++17 main.cpp $ ./a.out
おわりに
C++も好きな言語ではあるんですが、今後はRustを使って何か実装できたらいいなと考えています。