Numba¶
简介¶
Numba是一款可以将python函数编译为机器代码的JIT编译器,经过Numba编译的python代码(仅限数组运算),其运行速度可以接近C或FORTRAN语言。
Numba安装以及使用说明¶
思源一号上的Numba¶
先创建一个目录Numbatest并进入该目录:
mkdir Numbatest
cd Numbatest
申请计算资源并通过conda安装Numba
srun -p 64c512g -n 10 --pty /bin/bash
module load miniconda3
conda create -n numbatest
source activate numbatest
conda install -c conda-forge numba
在该目录下创建如下测试文件test.py:
import numba as nb
import numpy as np
from numba.typed import List
import time
@nb.jit('List(f4)(f4[:], f4[:], i4)', nopython=True, cache=True, parallel=False)
def fun1(a, b, len):
res = []
for i in range(len):
res.append(a[i]+b[i])
return res
@nb.jit('ListType(f4)(f4[:], f4[:], i4)', nopython=True, cache=True, parallel=False)
def fun2(a, b, len):
res = List()
for i in range(len):
res.append(a[i]+b[i])
return res
def fun3(a, b, len):
res = []
for i in range(len):
res.append(a[i]+b[i])
return res
if __name__ == '__main__':
len = 100000000
a = np.random.randn(len).astype(np.float32)
b = np.random.randn(len).astype(np.float32)
t1 = time.time()
c1 = fun1(a, b, len)
t2 = time.time()
c2 = fun2(a, b, len)
t3 = time.time()
c3 = fun3(a, b, len)
t4 = time.time()
print(f'fun1 cost: {t2-t1}s, \nfun2 cost: {t3-t2}s, \nfun3 cost: {t4-t3}s.')
在该目录下创建如下作业提交脚本numbatest.slurm:
#!/bin/bash
#BATCH --job-name=numbatest # 作业名
#SBATCH --partition=64c512g # 64c512g队列
#SBATCH --ntasks-per-node=10 # 每节点核数
#SBATCH -n 10 # 作业核心数
#SBATCH --output=%j.out
#SBATCH --error=%j.err
ulimit -s unlimited
ulimit -l unlimited
module load miniconda3
source activate numbatest
python3 test.py
使用如下命令提交作业:
sbatch numbatest.slurm
作业完成后在.out文件中可看到如下结果:
fun1 cost: 2.0397536754608154s,
fun2 cost: 1.9905965328216553s,
fun3 cost: 17.56288480758667s.
pi2.0上的Numba¶
此步骤和上文完全相同;
申请计算资源并通过conda安装numba
srun -p cpu -N 1 --ntasks-per-node 40 --pty /bin/bash
module load miniconda3
conda create -n numbatest
source activate numbatest
conda install -c conda-forge numba
此步骤和上文完全相同;
在该目录下创建如下作业提交脚本mpi4pytest.slurm:
#!/bin/bash
#BATCH --job-name=numbatest # 作业名
#SBATCH --partition=small # small队列
#SBATCH --ntasks-per-node=10 # 每节点核数
#SBATCH -n 10 # 作业核心数
#SBATCH --output=%j.out
#SBATCH --error=%j.err
ulimit -s unlimited
ulimit -l unlimited
module load miniconda3
source activate numbatest
python3 test.py
此步骤和上文完全相同;
此步骤和上文完全相同;
参考资料¶
最后更新:
2024 年 11 月 19 日