Sound wave neural network based on partial differential equations
-
Abstract: Neural networks are an important class of machine learning algorithms whose applications in geophysics and related fields have developed rapidly, mainly owing to their powerful abilities in data modeling, signal processing, and image recognition. However, the mathematical foundations and physical interpretation of neural networks remain limited; the internal complexity of these models makes their decision process difficult to explain, which restricts their further development. Explaining the behavior of neural networks with mathematical and physical methods remains a challenging task. The goal of this study is to design a sound wave neural network (SWNN) structure starting from the acoustic wave partial differential equations and finite-difference methods. The method discretizes the first-order acoustic wave equations using a finite-difference (central difference) scheme; the resulting difference formula has approximately the same mathematical form as the propagation function of a neural network, which enables the construction of a neural network based on the physical model of acoustic wave propagation. The main features of the SWNN are (1) a pressure-velocity coupled structure with inter-layer skip connections, and (2) a main-variable/adjoint-variable two-stream network structure that mitigates the vanishing-gradient problem during training. Because it is derived from the acoustic wave partial differential equations and a finite-difference algorithm, the SWNN has a solid mathematical foundation and a clear physical explanation, which makes it feasible to improve network performance within the framework of mathematical and physical methods. Numerical results show that the SWNN clearly outperforms conventional residual neural networks (ResNet) in image classification on the CIFAR-10 and CIFAR-100 datasets. The proposed partial differential equation approach to neural network modeling can be applied to many other types of mathematical physics equations and provides mathematical and physical explanations for deep neural network algorithms.
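To make the structural analogy concrete, the following is a minimal NumPy sketch (an illustration under assumed parameters, not the paper's exact scheme) of the first-order acoustic equations dp/dt = -rho*c^2*dv/dx and dv/dt = -(1/rho)*dp/dx discretized with an explicit finite-difference step. Each update has the residual form x_{n+1} = x_n + f(x_n), the same form as the propagation function of a skip-connected (residual) neural network layer.

```python
import numpy as np

def acoustic_step(p, v, dt, dx, rho=1.0, c=1.0):
    """One explicit pressure-velocity finite-difference update on a 1-D grid.

    Both lines below have the residual form x_new = x + f(x): the identity
    term is the 'skip connection' and the difference term is the residual.
    """
    p_new = p - dt * rho * c**2 * np.gradient(v, dx)  # pressure update
    v_new = v - dt / rho * np.gradient(p_new, dx)     # uses updated pressure
    return p_new, v_new

# Propagate a Gaussian pressure pulse; each time step plays the role of
# one network layer in the SWNN analogy.
x = np.linspace(0.0, 1.0, 201)
dx = x[1] - x[0]
p = np.exp(-200.0 * (x - 0.5) ** 2)  # initial pressure pulse
v = np.zeros_like(x)                  # medium initially at rest
for _ in range(100):
    p, v = acoustic_step(p, v, dt=0.4 * dx, dx=dx)
```

The time step dt = 0.4*dx (with c = 1) is chosen to satisfy the stability condition of this explicit scheme; the grid size and pulse width are arbitrary illustrative choices.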
-
Figure 1. Schematic diagram of the sound wave neural network structure. The pressure pj is the input feature and pj+4 is the output response; the velocity vj is the adjoint variable. The neural units in the shaded region form the main-variable (pressure) network, and the units outside the shaded region form the adjoint-variable (velocity) network; the two-stream networks are coupled through the finite-difference discretization of the sound wave equation. The dashed lines indicate skip connections between the input signal and the neural unit outputs
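The coupled two-stream update described in Figure 1 could be sketched as follows. This is a hypothetical NumPy illustration: the weight shapes, the ReLU activation, and the exact coupling pattern are assumptions for demonstration, not the paper's actual layer definition.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def swnn_block(p, v, Wp, Wv):
    """One illustrative pressure-velocity coupled block.

    Each stream keeps an identity (skip) path and receives its residual
    correction from the other stream, analogous to
        p_{j+1} = p_j + f(v_j),   v_{j+1} = v_j + g(p_{j+1}).
    """
    p_next = p + relu(v @ Wp)       # pressure stream: skip path + coupling from v
    v_next = v + relu(p_next @ Wv)  # velocity stream: skip path + coupling from p
    return p_next, v_next

rng = np.random.default_rng(0)
d = 16                               # feature width (illustrative)
p = rng.standard_normal((1, d))      # input feature p_j
v = rng.standard_normal((1, d))      # adjoint variable v_j
Wp = 0.1 * rng.standard_normal((d, d))
Wv = 0.1 * rng.standard_normal((d, d))
for _ in range(4):                   # four blocks: p_j -> p_{j+4}, as in Figure 1
    p, v = swnn_block(p, v, Wp, Wv)
```

Because the gradient of each stream flows both through the identity path and through the other stream, the two-stream structure offers extra gradient routes, which is the intuition behind the vanishing-gradient improvement claimed in the abstract.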
Figure 2. Training accuracy of SWNN vs. ResNet on CIFAR-10 at different network depths (N = 3, 4, 7, 8). (a) SWNN accuracy increases with training iterations; (b) SWNN accuracy vs. ResNet accuracy. Data points above the 1:1 line indicate higher SWNN accuracy
Figure 3. Validation accuracy of SWNN vs. ResNet on CIFAR-10 at different network depths (N = 3, 4, 7, 8). (a) SWNN accuracy increases with training iterations; (b) SWNN accuracy vs. ResNet accuracy. Data points above the 1:1 line indicate higher SWNN accuracy
Figure 4. Training loss of SWNN vs. ResNet on CIFAR-10 at different network depths (N = 3, 4, 7, 8). (a) SWNN loss decreases as iterations increase; (b) SWNN vs. ResNet training loss. As iterations increase, the data points converge below the 1:1 line, indicating that the SWNN training loss gradually decreases and falls below that of ResNet
Figure 5. Training accuracy of SWNN vs. ResNet on CIFAR-100 at different network depths (N = 3, 4, 7, 8). (a) SWNN accuracy increases with training iterations; (b) SWNN accuracy vs. ResNet accuracy. Data points above the 1:1 line indicate higher SWNN accuracy
Figure 6. Validation accuracy of SWNN vs. ResNet on CIFAR-100 at different network depths (N = 3, 4, 7, 8). (a) SWNN accuracy increases with iterations; (b) SWNN accuracy vs. ResNet accuracy. Data points above the 1:1 line indicate higher SWNN accuracy
Figure 7. Training loss of SWNN vs. ResNet on CIFAR-100 at different network depths (N = 3, 4, 7, 8). (a) SWNN loss decreases as iterations increase; (b) SWNN vs. ResNet training loss. As iterations increase, the data points converge below the 1:1 line, indicating that the SWNN training loss gradually decreases and falls below that of ResNet
Table 1. Training and validation results for SWNN and ResNet on CIFAR-10. Different numbers of convolution units (N = 3, 4, 7, 8) illustrate the effect of network depth
Network type   Training accuracy   Validation accuracy   Training loss   Validation loss
SWNN3          97.16%              90.38%                0.063           0.311
SWNN4          97.81%              91.32%                0.044           0.290
SWNN7          98.66%              91.73%                0.035           0.275
SWNN8          98.85%              91.82%                0.022           0.281
ResNet3        95.16%              88.60%                0.119           0.350
ResNet4        95.34%              88.71%                0.135           0.344
ResNet7        95.24%              89.00%                0.137           0.342
ResNet8        95.29%              88.96%                0.134           0.343

Table 2. Training and validation results for SWNN and ResNet on CIFAR-100. Different numbers of convolution units (N = 3, 4, 7, 8) illustrate the effect of network depth
Network type   Training accuracy   Validation accuracy   Training loss   Validation loss
SWNN3          89.92%              74.86%                0.309           0.846
SWNN4          92.19%              76.98%                0.241           0.771
SWNN7          95.31%              77.41%                0.143           0.765
SWNN8          95.56%              78.31%                0.135           0.747
ResNet3        83.59%              72.33%                0.535           0.895
ResNet4        88.28%              72.61%                0.441           0.872
ResNet7        83.59%              73.51%                0.512           0.860
ResNet8        82.41%              72.27%                0.539           0.892

Table 3. Efficiency comparison of SWNN and ResNet on CIFAR-10
Convolution units   ResNet peak training accuracy   ResNet training steps   SWNN training steps   Computation reduction of SWNN
3                   95.16%                          29976                   24249                 19.11%
4                   95.34%                          30640                   23957                 21.81%
7                   95.24%                          30332                   23661                 21.99%
8                   95.29%                          29462                   23623                 19.82%

Table 4. Efficiency comparison of SWNN and ResNet on CIFAR-100
Convolution units   ResNet peak training accuracy   ResNet training steps   SWNN training steps   Computation reduction of SWNN
3                   83.59%                          30687                   23697                 22.78%
4                   88.28%                          30619                   23583                 22.98%
7                   83.59%                          30682                   23421                 23.67%
8                   82.41%                          29898                   23367                 21.84%
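The "computation reduction" column in Tables 3 and 4 appears to be the relative decrease in training steps, (ResNet steps − SWNN steps) / ResNet steps. A quick check against the CIFAR-10 rows of Table 3 confirms this reading:

```python
# Verify that the reduction column equals the relative decrease in training steps.
rows = [  # (ResNet steps, SWNN steps, reported reduction %) from Table 3
    (29976, 24249, 19.11),
    (30640, 23957, 21.81),
    (30332, 23661, 21.99),
    (29462, 23623, 19.82),
]
for resnet_steps, swnn_steps, reported in rows:
    reduction = 100.0 * (resnet_steps - swnn_steps) / resnet_steps
    assert abs(reduction - reported) < 0.01, (reduction, reported)
```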