初期値問題用ソルバーの速度比較

初期値問題用ソルバーの速度比較 #

1. まとめ #

Non-stiff な方程式 Stiffな方程式
$N\leq 10^5$ scipy.integrate.odeint scipy.integrate.odeint
$N\geq 10^5$ torchdiffeq.odeint (No data)

(絶対誤差、相対誤差が$10^{-4}\sim 10^{-9}$の範囲)

2. 比較対象 #

本稿では、下記3種のpythonソルバーを使用して常微分方程式の初期値問題を解いた場合の速度比較を行います。対象は、刻み幅を自動的に決める適応型ソルバーに限ります。

  • scipy.integrate.solve_ivp1
  • scipy.integrate.odeint2
  • torchdiffeq.odeint3

条件は次の通りです。

  • 微分方程式の右辺以外の情報は与えない(scipyのodeintでJacobianは与えない)
  • 関数のJit化は行わない (付録参照)
  • 並列化は指定しない

実行環境は次の通りです。

  • OS: Windows 11 Pro 64-bit
  • RAM: 16GB
  • CPU: Intel Core i7-10710U
  • GPU: GeForce RTX 3060 Ti VENTUS 2X 8G OCV1 (eGPU)

2.1. 各ライブラリの特徴 #

項目 scipy.integrate.solve_ivp1 scipy.integrate.odeint2 torchdiffeq.odeint3
実装 Python FORTRAN (ODEPACK の LSODA) Python
対応する微分方程式 stiff/non-stiff(手動選択) stiff/non-stiff (自動選択) non-stiff
計算環境 CPU のみ CPU のみ GPU および CPU
主な利用シーン 柔軟な手法選択やイベント検出が求められる数値解析 従来型の数値解析、硬さに応じた自動アルゴリズムによるODE解法 微分可能プログラミング、ニューラルODEの実装・研究
特徴 複数のアルゴリズム(RK45、Radau、BDF、LSODA等)から選択可能。イベント処理が容易。 問題の硬さに応じ、non-stiffでは Adams 法、stiffでは BDF を自動切替。高速・信頼性が高い。 自動微分によりニューラルODE等の最適化に適している

2.2. IVPを解くコード例 #

各ライブラリで、ODEを解くコードを示します。 ベンチマークに用いるコードとは異なりますが、使用方法を理解するために示します。

クラス名の異なる3つのラッパー IvpSolverScipy, IvpSolverScipyOdeint, IvpSolvePytorchを用意しました。
それぞれ、SciPyのsolve_ivp, SciPyのodeint, PyTorchのodeintを使用して解くソルバーです。 要素の型はfloat64としています。

ivp solver
import torch
import torchdiffeq
import time
import scipy as sp
import numpy as np
import sys
import matplotlib.pyplot as plt


# Ivp Solver with GPU
class IvpSolverPytorch:
    def __init__(self):
        pass

    def solve(
        self,
        t: torch.Tensor,
        psi0: torch.Tensor,
        rtol: float = 1e-9,
        atol: float = 1e-9,
        method: str = "dopri5",
    ):
        res = torchdiffeq.odeint(self.func, psi0, t, rtol=rtol, atol=atol, method=method)
        return res

    def func(self, t, y: torch.Tensor):
        rhs = -y * torch.sin(t)
        return rhs


# Ivp Solver with odeint
class IvpSolverScipyOdeint:
    def __init__(self):
        pass

    def solve(
        self,
        t: np.array,
        y0: np.array,
        rtol: float = 1e-9,
        atol: float = 1e-9,
    ):
        res = sp.integrate.odeint(self.func, y0, t, atol=atol, rtol=rtol)
        return res

    def func(self, y: np.array, t: float):
        rhs = -y * np.sin(t)
        return rhs


# Ivp Solver with CPU
class IvpSolverScipy:
    def __init__(self):
        pass

    def solve(
        self,
        t: np.array,
        y0: np.array,
        rtol: float = 1e-9,
        atol: float = 1e-9,
        method: str = "RK45",
    ):
        res = sp.integrate.solve_ivp(self.func, [t[0], t[-1]], y0, t_eval=t, rtol=rtol, atol=atol, method=method)
        return res

    def func(self, t: float, y: np.array):
        rhs = -y * np.sin(t)
        return rhs


# Main関数
if __name__ == "__main__":
    # Number of 1st-order differential equations
    N = 10
    rtol = 1e-4
    atol = 1e-4

    # 初期状態の設定 (各ルーチンで同一のy0を用いる)
    y0 = np.full(N, 2.0, dtype=np.float64)
    y0_gpu = torch.full((N,), 2.0, device="cuda", dtype=torch.float64)

    # 時間範囲の設定
    t_eval = np.linspace(0, 10.0, 101, endpoint=True)
    t_eval_gpu = torch.tensor(t_eval, device="cuda", dtype=torch.float64)

    # ----- scipy.integrate.odeint -----
    solver = IvpSolverScipyOdeint()
    t_start = time.time()
    solution = solver.solve(t_eval, y0, rtol=rtol, atol=atol)
    elapsed_time = time.time() - t_start
    print(f"  Elapsed time Scipy Odeint              : {elapsed_time:10.6f} (s)")
    sol_sp_odeint = solution[:, 0]

    # ----- scipy.integrate.solve_ivp -----
    method = "RK45"  # "RK23", "RK45", "DOP853", "Radau", "BDF", "LSODA"
    solver = IvpSolverScipy()
    t_start = time.time()
    solution = solver.solve(t_eval, y0, rtol=rtol, atol=atol, method=method)
    elapsed_time = time.time() - t_start
    print(f"  Elapsed time Scipy solve_ivp {method:10s}: {elapsed_time:10.6f} (s)")
    sol_sp_solve_ivp = solution.y[0, :]

    # ----- pytorch.odeint -----
    method = "dopri8"  # "dopri8", "dopri5", "bosh3", "fehlberg2", "adaptive_heun"
    solver = IvpSolverPytorch()
    t_start = time.time()
    solution = solver.solve(t_eval_gpu, y0_gpu, rtol=rtol, atol=atol, method=method)
    elapsed_time = time.time() - t_start
    print(f"  Elapsed time Pytorch odeint  {method:10s}: {elapsed_time:10.6f} (s)")
    sol_pt_odeint = solution[:, 0].cpu().numpy()

    # --- Plot ---
    plt.plot(t_eval, 2.0 * np.exp(np.cos(t_eval) - 1.0), linestyle="-", color="black", label="Exact")
    plt.plot(t_eval, sol_sp_odeint, marker="o", linestyle="-", color="C0", label="SciPy odeint")
    plt.plot(t_eval, sol_sp_solve_ivp, marker="x", linestyle="-", color="C1", label="SciPy solve_ivp")
    plt.plot(t_eval, sol_pt_odeint, marker="^", linestyle="-", color="C2", label="PyTorch odeint (cuda)")
    plt.legend()
    plt.show()

3. ベンチマーク用コード #

ベンチマークに使用したコードを下記に置きます。

ivp_benchmark_v1.00.zip (3kB)

こちらで計算した結果を別のプログラムで読み込み、グラフ作成をしました。

4. 結果 #

4.1. 実行時間比較 #

IVPを解くための実行時間は、数値手法を決めたとき大きく次のパラメータに依存します。

計算量 小 計算量 大
1ステップ当たりの精度 $\text{tol}$ 大きい 小さい
解の範囲 $[t_a, t_b]$ 狭い 広い
方程式の個数 $N$ 少ない 多い

もう一つ、計算量に関する重要な要素として解きたい方程式の硬さが関係してきます。

方程式の"硬さ"とは、一般的には解中に急激に減衰する成分と緩やかに変化する成分が共存している場合のその差を意味します。 線形の問題に対する硬さの定義は簡単ですが、線形・非線形微分方程式の両方の問題に対する硬さを定義することは困難です4

具体的には次のように表現できます。

  • “硬い方程式” (Stiff) とは、ある積分変数の近傍では滑らかな解の振る舞いをしますが、別の積分変数の近傍では急激に変化するような振る舞いを持つ方程式です。
  • “硬くない方程式” (Non-Stiff) とは、解の変化量がほとんど変化しないような解の振る舞いとなる方程式です。

一般的に硬い方程式は、安定性領域が広い数値手法、例えば陰的解法を用いるのが良いです5

4.1.1. Non-stiffな問題 #

硬くない問題として、

$$ \begin{gather} \frac{d\mathbf{y}}{dt}=-\mathbf{y}\sin(t),\\ \mathbf{y}(t=0)=2,\hspace{2em}t=[0,50] \end{gather} $$

という問題を考えてみます。$\mathbf{y}=(y_0, y_1,\cdots, y_{N-1})$ としてが$N$個の要素から成るとき、計算時間の$N$依存性と数値手法による違いを調べますと、次の図のようになります。

4.1.2. Stiffな問題 #

硬い問題として、Van der Polの方程式を考えてみます6。Van der Pol方程式は次の2階微分方程式です。

$$ \begin{align} \left\{ \begin{aligned} \frac{dx}{dt}&=y, \\ \frac{dy}{dt}&=((1-x^2)y-x)/\nu \end{aligned} \right.\label{v1} \end{align} $$

$$ \nu=10^{-6},\hspace{1em}x(t=0)=2.0,\hspace{1em}y(t=0)=-0.66,\hspace{1em}t=[0,2] $$

という問題を考えてみます7。これを$(x_0, y_0, x_1, y_1\cdots, x_{N/2-1}, y_{N/2-1})$ として、各添え字ごとに式\eqref{v1}の方程式を考えます。
計算時間の$N$依存性と数値手法による違いを調べますと、次の図のようになります。

こちらの問題において陽的解法は100秒近くの計算時間が掛かったため、$N=4$だけを計算し、その他の$N$については計算をしておりません。また、GPUで計算を行うPyTorchのodeintは100秒でも計算が終わらなかったため、計算を取りやめました。

4.1.3. 小まとめ #

次の点が分かります。

  • 硬くない問題

    • $N$が小さいとき: scipy.integrate.odeint (次点でscipy.integrate.solve_ivpの"LSODA"または"RK45")
    • $N$が大きいとき: torchdiffeq.odeintの"dopri8" (次点でtorchdiffeq.odeintの"dopri5")
  • 硬い問題

    • scipy.integrate.odeint (次点でscipy.integrate.solve_ivpの"LSODA")
  • 精度が要求されない場合、“RK45” 等の高次の方法よりも “RK23” 等の低次の方法の方が良い。

  • GPUを用いる計算は、$N$が$10^5$以上で有利になる。

5. 付録 #

5.1. 高速化 #

今回は特別な高速化を実施しないでライブラリ単体の性能を見ました。 その中で特にscipy.integrate.odeintの高速化について考えてみます。

例えば高速化の方法として次の方法が挙げられます。

  • のJacobianの情報を追加で与える
  • Numbaを用いた関数のJIT化
  • 並列処理

並列処理は除いて、Jacobian, JITの効果を確かめてみました。結果は以下のようになりました。 Jacobianは精度に影響を与えるので、単純な比較は難しいです。可能であれば関数のJIT化はしておくと良いでしょう。

対象: Van der Pol方程式

平均実行時間 (ms) ($N=4$) $\mathrm{tol=10^{-4}}$ $\mathrm{tol=10^{-6}}$ $\mathrm{tol=10^{-9}}$ 高速化コードのクラス
工夫無し $10.2 \pm 0.7$ $18.2 \pm 0.6$ $17.7 \pm 1.0$ IvpSolverScipyOdeint
Jacobian追加 $10.9 \pm 1.6$ $18.2 \pm 0.5$ $17.5 \pm 0.9$ IvpSolverScipyOdeintJac
Jit化 $2.8 \pm 0.5$ $4.7 \pm 0.5$ $4.6 \pm 0.5$ IvpSolverScipyOdeintJit
Jacobian追加 & Jit化 $4.9 \pm 0.7$ $8.1 \pm 0.4$ $7.9 \pm 0.6$ IvpSolverScipyOdeintJacJit
Jacobian追加 & Jit化 2 $5.0 \pm 0.6$ $8.2 \pm 0.6$ $7.9 \pm 0.5$ IvpSolverScipyOdeintJacJit2
平均実行時間 (s) ($N=512$) $\mathrm{tol=10^{-4}}$ $\mathrm{tol=10^{-6}}$ $\mathrm{tol=10^{-9}}$ 高速化コードのクラス
工夫無し $2.2 \pm 0.4$ $3.9 \pm 0.6$ $2.8 \pm 0.5$ IvpSolverScipyOdeint
Jacobian追加 $1.9 \pm 0.3$ $2.9 \pm 0.5$ $2.6 \pm 0.4$ IvpSolverScipyOdeintJac
Jit化 $1.8 \pm 0.3$ $3.2 \pm 0.5$ $2.3 \pm 0.4$ IvpSolverScipyOdeintJit
Jacobian追加 & Jit化 $2.0 \pm 0.3$ $3.0 \pm 0.6$ $2.6 \pm 0.4$ IvpSolverScipyOdeintJacJit
Jacobian追加 & Jit化 2 $2.0 \pm 0.4$ $2.9 \pm 0.5$ $2.6 \pm 0.4$ IvpSolverScipyOdeintJacJit2
平均実行時間 (s) ($N=1024$) $\mathrm{tol=10^{-4}}$ $\mathrm{tol=10^{-6}}$ $\mathrm{tol=10^{-9}}$ 高速化コードのクラス
工夫無し $5.2 \pm 0.9$ $9.5 \pm 1.6$ $7.0 \pm 1.3$ IvpSolverScipyOdeint
Jacobian追加 $5.1 \pm 1.0$ $7.8 \pm 1.4$ $6.9 \pm 1.2$ IvpSolverScipyOdeintJac
Jit化 $4.0 \pm 0.8$ $7.4 \pm 1.5$ $5.4 \pm 1.1$ IvpSolverScipyOdeintJit
Jacobian追加 & Jit化 $5.1 \pm 1.0$ $7.8 \pm 1.4$ $6.9 \pm 1.3$ IvpSolverScipyOdeintJacJit
Jacobian追加 & Jit化 2 $5.1 \pm 1.0$ $7.9 \pm 1.4$ $6.9 \pm 1.3$ IvpSolverScipyOdeintJacJit2

5.2. 高速化コード #

上記の計算に用いた5種のソルバーは次のコードで実装しています。

opt odeint

# Ivp Solver with odeint
class IvpSolverScipyOdeint:
    def __init__(self):
        pass

    def solve(
        self,
        t: np.array,
        y0: np.array,
        rtol: float = 1e-9,
        atol: float = 1e-9,
    ):
        res = sp.integrate.odeint(self.func, y0, t, atol=atol, rtol=rtol)
        return res

    def func(self, y: np.array, t: float):
        rhs = np.empty_like(y)
        rhs[0::2] = y[1::2]
        rhs[1::2] = ((1.0 - y[0::2] ** 2) * y[1::2] - y[0::2]) / 1e-6
        return rhs


# Ivp Solver with odeint
class IvpSolverScipyOdeintJac:
    def __init__(self):
        pass

    def solve(
        self,
        t: np.array,
        y0: np.array,
        rtol: float = 1e-9,
        atol: float = 1e-9,
    ):
        res = sp.integrate.odeint(self.func, y0, t, Dfun=self.jac, atol=atol, rtol=rtol)
        return res

    def func(self, y: np.array, t: float):
        rhs = np.empty_like(y)
        rhs[0::2] = y[1::2]
        rhs[1::2] = ((1.0 - y[0::2] ** 2) * y[1::2] - y[0::2]) / 1e-6
        return rhs

    def jac(self, y: np.array, t: float):
        N = y.shape[0]  # always even
        rhs = np.zeros((N, N), dtype=y.dtype)
        idx = np.arange(N // 2)
        rhs[2 * idx, 2 * idx + 1] = 1.0
        rhs[2 * idx + 1, 2 * idx] = (-2.0 * y[0::2] * y[1::2] - 1.0) / 1e-6
        rhs[2 * idx + 1, 2 * idx + 1] = (1.0 - y[0::2] ** 2) / 1e-6
        return rhs


# Ivp Solver with odeint
class IvpSolverScipyOdeintJit:
    def __init__(self):
        pass

    def solve(
        self,
        t: np.array,
        y0: np.array,
        rtol: float = 1e-9,
        atol: float = 1e-9,
    ):
        res = sp.integrate.odeint(IvpSolverScipyOdeintJit.func, y0, t, atol=atol, rtol=rtol)
        return res

    @staticmethod
    @jit(nopython=True)
    def func(y: np.array, t: float):
        rhs = np.empty_like(y)
        rhs[0::2] = y[1::2]
        rhs[1::2] = ((1.0 - y[0::2] ** 2) * y[1::2] - y[0::2]) / 1e-6
        return rhs


# Ivp Solver with odeint
class IvpSolverScipyOdeintJacJit:
    def __init__(self):
        pass

    def solve(
        self,
        t: np.array,
        y0: np.array,
        rtol: float = 1e-9,
        atol: float = 1e-9,
    ):
        res = sp.integrate.odeint(IvpSolverScipyOdeintJacJit.func, y0, t, Dfun=self.jac, atol=atol, rtol=rtol)
        return res

    @staticmethod
    @jit(nopython=True)
    def func(y: np.array, t: float):
        rhs = np.empty_like(y)
        rhs[0::2] = y[1::2]
        rhs[1::2] = ((1.0 - y[0::2] ** 2) * y[1::2] - y[0::2]) / 1e-6
        return rhs

    def jac(self, y: np.array, t: float):
        N = y.shape[0]  # always even
        rhs = np.zeros((N, N), dtype=y.dtype)
        idx = np.arange(N // 2)

        rhs[2 * idx, 2 * idx + 1] = 1.0
        rhs[2 * idx + 1, 2 * idx] = (-2.0 * y[0::2] * y[1::2] - 1.0) / 1e-6
        rhs[2 * idx + 1, 2 * idx + 1] = (1.0 - y[0::2] ** 2) / 1e-6
        return rhs

# Ivp Solver with odeint
class IvpSolverScipyOdeintJacJit2:
    def __init__(self):
        pass

    def solve(
        self,
        t: np.array,
        y0: np.array,
        rtol: float = 1e-9,
        atol: float = 1e-9,
    ):
        res = sp.integrate.odeint(IvpSolverScipyOdeintJacJit2.func, y0, t, Dfun=IvpSolverScipyOdeintJacJit2.jac, atol=atol, rtol=rtol)

        return res

    @staticmethod
    @jit(nopython=True)
    def func(y: np.array, t: float):
        rhs = np.empty_like(y)
        rhs[0::2] = y[1::2]
        rhs[1::2] = ((1.0 - y[0::2] ** 2) * y[1::2] - y[0::2]) / 1e-6
        return rhs

    @staticmethod
    @jit(nopython=True)
    def jac(y: np.array, t: float):
        N = y.shape[0]  # Always even
        rhs = np.zeros((N, N), dtype=y.dtype)
        n2 = N // 2
        for i in range(n2):
            # even-indexed row: 2*i line
            rhs[2 * i, 2 * i] = 0.0
            rhs[2 * i, 2 * i + 1] = 1.0

            # odd-indexed row: 2*i+1 line
            rhs[2 * i + 1, 2 * i] = (-2.0 * y[2 * i] * y[2 * i + 1] - 1.0) / 1e-6
            rhs[2 * i + 1, 2 * i + 1] = (1.0 - y[2 * i] ** 2) / 1e-6
        return rhs

6. 参考文献 #