AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火

吴厣 459 0

用 AI 求解偏微分方程,这段时间确实有点火。

但究竟什么样的 AI 求解效果最好,却始终没有一个统一的定论。

现在,终于有人为这个领域制作了一个名叫 PDEBench 的完整基准,论文登上了 NeurIPS 2022

PDEBench 不仅能当成一个大型偏微分方程数据集,也能作为新 AI 求解偏微分方程的基准之一 ——

不少“老前辈”的预训练模型代码都能在这里找到,作为一个比对基础。

例如去年大火了一阵的 FNO,几秒钟求解出传统方法需要计算 18 个小时的偏微分方程,代码就被放进了 PDEBench 中。

这个新基准一出,LeCun 也激情转发:这领域确实很火。

AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火-第1张图片-小猪号

所以,AI 求解偏微分方程的优势是什么,这一基准具体提出了哪些评估方法?

为啥用 AI 求解偏微分方程?

偏微分方程(PDE,Partial Differential Equation),是一个生活中常见的方程。

包括预报天气、模拟飞机空气动力、预测疾病传播模型,都会用到这个方程。

目前北大数学系“韦神”韦东奕的研究方向之一,就是流体力学中的数学问题,其中就包括偏微分方程中的 Navier-Stokes 方程。

AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火-第2张图片-小猪号

所以,为啥要用 AI 来求解偏微分方程

训练 AI 的本质,是找到一种尽可能逼近真实结果的模型。

用 AI 求解偏微分方程,其实也是找到一种代理模型,来模拟偏微分方程模型。

代理模型指找到一种近似模型,在计算量更小的同时,确保计算结果与原来的偏微分方程尽可能相似。

这与传统的数值方法求解偏微分方程有着异曲同工之妙。

传统方法往往需要通过将连续问题离散化(类似在一个连续函数上切割出很多小点),来对方程进行近似求解。

然而,传统的数值方法非常复杂,计算量也很大;采用 AI 方法训练出来的模型,却模拟得又快又好 ——

继 2017 年华盛顿大学提出 PDE-FIND 后,2018 年谷歌 AI 又提出了数据驱动求解偏微分方程的方法,都比传统方法要快上不少,让更多人开始关注到 AI 求解偏微分方程这一领域。

AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火-第3张图片-小猪号

2019 年,布朗大学应用数学团队提出一种名叫 PINN (物理激发的神经网络)的方法,彻底打开了 AI 在物理学领域的广泛应用。

这篇论文在理论上虽然没有 PDE-FIND 和谷歌 AI 的方法突破性强,却给出了非常完整的代码体系,使得开发人员很容易上手,让更多研究者开发出了不同的 PINN,如今它也成为 AI 物理最常见的框架和词汇之一。

AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火-第4张图片-小猪号

PINN

去年加州理工大学和普渡大学团队发表的一项研究,更是将偏微分方程计算时间从传统求解的 18 个小时降低为 1 秒钟。

这篇论文提出了一种名为 FNO (傅里叶神经算子)的方法,基于傅里叶变换给神经网络加上“傅里叶层”,进一步节省了近似模拟算子的计算量。

AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火-第5张图片-小猪号

除此之外,也有不少研究人员通过训练一些经典 AI 模型,来求解偏微分方程,如 U-Net 等。

不过,无论是 FNO、U-Net 还是 PINN,都还是基于各自给出的基准来评估 AI 计算偏微分方程的效果。

有没有一个更统一、更通用的框架来评估这个领域的新突破?

更全面的 AI 偏微分方程基准

在这样的背景下,研究人员提出了一种名叫 PDEBench 的基准。

AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火-第6张图片-小猪号

首先是基准中包含的数据集,目前这些数据集已经全部归纳到 GitHub 中:

AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火-第7张图片-小猪号

这里面包括不少经典偏微分方程问题,如 Navier-Stokes 方程,达西流模型、浅水波模型等等。

随后,PDEBench 提出了几个指标,来从不同角度更全面地对 AI 模型进行评估:

AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火-第8张图片-小猪号

最后,PDEBench 还包含了几种经典模型的预训练模型代码,并将它们作为评估其他模型的基准之一,包括上述提到的 FNO、U-Net、PINN 等。

例如研究团队将这几个模型分别基于各数据集进行了训练,得出的均方根误差(RMSE)如下,也说明它们在不同偏微分方程问题上的表现并不一样:

AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火-第9张图片-小猪号

除此之外,团队还将数据格式进行了统一,同时针对 PDEBench 的可扩展性进行了优化,因此任何人都能参与进来,给这一基准加入更多的数据集、或是更多基准模型。

值得注意的是,团队试了试分别在 PyTorch 和 JAX 两种框架上运行几种预训练模型,发现 JAX 的速度大约是 PyTorch 的 6 倍

看来以后搞相关研究可以试试 JAX 框架了。

作者介绍

作者们来自德国斯图加特大学,欧洲 NEC 研发中心,还有澳大利亚联邦科学与工业研究组织(CSIRO)旗下的 Data61 数字创新中心。

AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火-第10张图片-小猪号

Makoto Takamoto,欧洲 NEC 研发中心高级研究员,毕业于京都大学,研究方向是图像处理、图神经网络和科学机器学习。

AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火-第11张图片-小猪号

Timothy Praditia,斯图加特大学博士研究生,研究兴趣是开发基于数据驱动和先验物理知识的神经网络模型。

AI 求解偏微分方程新基准登 NeurIPS,发现 JAX 计算速度比 PyTorch 快 6 倍,LeCun 转发:这领域确实很火-第12张图片-小猪号

论文地址:

  • https://arxiv.org/abs/2210.07182

PDEBench 地址:

  • https://github.com/pdebench/PDEBench

参考链接:

  • [1]https://twitter.com/Mniepert/status/1581010273246523393

  • [2]https://mp.weixin.qq.com/s/Rbw2QFavSn8N7pPGS05o6w

本文来自微信公众号:量子位 (ID:QbitAI),作者:萧箫

标签: 人工智能 AI 微分方程

抱歉,评论功能暂时关闭!