もちもちしている

おらんなの気まぐれブログ

トップダウン型の自動微分を実装する

はじめに

だいぶ前に書いていたコードですが、C++でTD型の自動微分を実装したので公開します。 自動微分については以下の記事を参考にさせていただきました。

kivantium.hateblo.jp

C++のコード

C++書いたコードは以下の通り。 自分が書いたコードはC++14に対応していましたが、話題のChatGPTを用いて、C++17に対応するようにリファクタリングしてもらいました。

Variable::backward でグラフを幅優先探索で辿ってる以外はクラスの定義や演算子オーバーロードを行っているだけなので、スッと読めると思います。

#include <iostream>
#include <deque>
#include <unordered_set>
#include <vector>
#include <memory>
#include <tuple>
#include <functional>

template <typename T>
auto enumerate(T &container)
{
  std::vector<std::tuple<std::size_t, decltype(*std::begin(container)) &>> enumerated;
  std::size_t index = 0;
  for (auto &item : container)
  {
    enumerated.emplace_back(index, item);
    ++index;
  }
  return enumerated;
}

class Function;
class Variable;

class Variable
{
public:
  double data;
  double grad;
  std::shared_ptr<Function> creator;

  Variable(const double data)
      : data(data), grad(1.0), creator(nullptr) {}

  void set_creator(const std::shared_ptr<Function> &gen_func)
  {
    creator = gen_func;
  }

  void backward();
};

class Function : public std::enable_shared_from_this<Function>
{
public:
  std::vector<std::shared_ptr<Variable>> inputs;
  std::shared_ptr<Variable> output;

  std::shared_ptr<Variable> operator()(const std::shared_ptr<Function> &self, const std::shared_ptr<Variable> &input1, const std::shared_ptr<Variable> &input2 = nullptr)
  {
    inputs = {input1, input2};
    const double y = forward();
    output = std::make_shared<Variable>(y);
    output->set_creator(self);

    return output;
  }

  virtual double forward() const = 0;
  virtual std::vector<std::shared_ptr<Variable>> backward(const double gy) const = 0;
};

class Add : public Function, public std::enable_shared_from_this<Add>
{
public:
  double forward() const override
  {
    return inputs[0]->data + inputs[1]->data;
  }

  std::vector<std::shared_ptr<Variable>> backward(const double gy) const override
  {
    return {std::make_shared<Variable>(gy), std::make_shared<Variable>(gy)};
  }
};

class Mul : public Function, public std::enable_shared_from_this<Mul>
{
public:
  double forward() const override
  {
    return inputs[0]->data * inputs[1]->data;
  }

  std::vector<std::shared_ptr<Variable>> backward(const double gy) const override
  {
    return {std::make_shared<Variable>(gy * inputs[1]->data),
            std::make_shared<Variable>(gy * inputs[0]->data)};
  }
};

std::shared_ptr<Variable> operator+(const std::shared_ptr<Variable> &lhs, const std::shared_ptr<Variable> &rhs)
{
  auto add_func = std::make_shared<Add>();
  return add_func->operator()(add_func, lhs, rhs);
}

std::shared_ptr<Variable> operator*(const std::shared_ptr<Variable> &lhs, const std::shared_ptr<Variable> &rhs)
{
  auto mul_func = std::make_shared<Mul>();
  return mul_func->operator()(mul_func, lhs, rhs);
}

std::shared_ptr<Variable> operator+(const std::shared_ptr<Variable> &lhs, const double rhs)
{
  auto rhs_var = std::make_shared<Variable>(rhs);
  return lhs + rhs_var;
}

std::shared_ptr<Variable> operator+(const double lhs, const std::shared_ptr<Variable> &rhs)
{
  auto lhs_var = std::make_shared<Variable>(lhs);
  return lhs_var + rhs;
}

std::shared_ptr<Variable> operator*(const std::shared_ptr<Variable> &lhs, const double rhs)
{
  auto rhs_var = std::make_shared<Variable>(rhs);
  return lhs * rhs_var;
}

std::shared_ptr<Variable> operator*(const double lhs, const std::shared_ptr<Variable> &rhs)
{
  auto lhs_var = std::make_shared<Variable>(lhs);
  return lhs_var * rhs;
}

void Variable::backward()
{
  if (creator == nullptr)
  {
    return;
  }

  std::unordered_set<size_t> visited;
  std::deque<std::shared_ptr<Function>> queue{creator};

  while (!queue.empty())
  {
    auto function = queue.front();
    auto output = function->output;
    auto gy = output->grad;
    auto gxs = function->backward(gy);

    queue.pop_front();

    for (const auto &[i, gx] : enumerate(gxs))
    {
      auto x = function->inputs[i];

      if (gx == nullptr)
      {
        continue;
      }

      std::size_t id_x = std::hash<std::shared_ptr<Variable>>{}(x);
      if (x->creator != nullptr)
      {
        queue.push_back(x->creator);
      }

      if (visited.find(id_x) == visited.end())
      {
        x->grad = gx->data;
        visited.insert(id_x);
      }
      else
      {
        x->grad += gx->data;
      }
    }
  }
}

int main()
{
  auto x = std::make_shared<Variable>(1.0);
  auto y = std::make_shared<Variable>(1.0);

  auto z = x * x + 2 * x * y + y ;
  std::cout << "z = x * x + 2 * x * y + y " << std::endl;
  z->backward();

  std::cout << "x = " << x->data << ", x.grad = " << x->grad << std::endl;
  std::cout << "y = " << y->data << ", y.grad = " << y->grad << std::endl;

  return 0;
}

コンパイルと実行は以下の通りです。

$ clang++ -std=c++17 main.cpp
$ ./a.out

おわりに

C++も好きな言語ではあるんですが、今後はRustを使って何か実装できたらいいなと考えています。