From 1179dab9a680b7578b87d53e58258e65a0f69b4a Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 09:25:26 +0900
Subject: [PATCH 01/16] try commit

---
 README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 11d0608..455f85a 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,3 @@
-# rccs-pytorch
\ No newline at end of file
+# rccs-pytorch
+
+test

From 6cabd1eeb2c7f9fd463e5f8af5671e3c3c5f4d91 Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 10:02:56 +0900
Subject: [PATCH 02/16] try commit

---
 README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 455f85a..8c5bd20 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,5 @@
 # rccs-pytorch
 
-test
+## Introduction
+
+

From 562a385f44301236f6682b817caacb1263730479 Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 10:19:47 +0900
Subject: [PATCH 03/16] test

---
 README.md | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/README.md b/README.md
index 8c5bd20..8748c22 100644
--- a/README.md
+++ b/README.md
@@ -2,4 +2,20 @@
 
 ## Introduction
 
+This document describes, as a work report, the procedure for building the AI framework PyTorch v2 on Fugaku and the procedure for verifying it with the standard test data (mnist).
+
+## Upgrading the AI framework PyTorch
+
+### Versions of PyTorch and the main modules
+
+For this PyTorch upgrade, the versions of PyTorch and the main modules to be built are listed below. This work uses Python v3.9.18, PyTorch v2.1, Numpy v1.22.4, Scipy v1.7.3, oneDNN v3.1.1, and Horovod v0.26.1.
+
+| Module | Version |
+| Python | v3.9.18 |
+| PyTorch | v2.1 |
+| Numpy | v1.22.4 |
+| Scipy | v1.7.3 |
+| oneDNN | v3.1.1 |
+| Horovod | v0.26.1 |
+
+### Setting up the build environment

From fb6a67126823f7fb502887716d5036e30f12f606 Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 10:21:21 +0900
Subject: [PATCH 04/16] test

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 8748c22..c529c0b 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@ For this PyTorch upgrade, the versions of PyTorch and the main modules to be bu
 
 | Module | Version |
+| --- | --- |
 | Python | v3.9.18 |
 | PyTorch | v2.1 |
 | Numpy | v1.22.4 |
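Once the venv built by the steps later in this document is active, a quick way to confirm that the installed modules match the version table above — a minimal sketch using only standard pip/python invocations:

```
$ pip3 list | grep -iE "^(torch|torchvision|numpy|scipy|horovod) "
$ python3 -c "import torch, numpy, scipy; print(torch.__version__, numpy.__version__, scipy.__version__)"
```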
From 8334593a1e8665893e2453a876484ff94b6e32c3 Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 10:30:12 +0900
Subject: [PATCH 05/16] test

---
 README.md | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index c529c0b..477962d 100644
--- a/README.md
+++ b/README.md
@@ -2,13 +2,13 @@
 
 ## Introduction
 
-This document describes, as a work report, the procedure for building the AI framework PyTorch v2 on Fugaku and the procedure for verifying it with the standard test data (mnist).
+This document describes the procedure for building the AI framework PyTorch v2 on Fugaku and the procedure for verifying it with the standard test data (mnist).
 
 ## Upgrading the AI framework PyTorch
 
 ### Versions of PyTorch and the main modules
 
-For this PyTorch upgrade, the versions of PyTorch and the main modules to be built are listed below. This work uses Python v3.9.18, PyTorch v2.1, Numpy v1.22.4, Scipy v1.7.3, oneDNN v3.1.1, and Horovod v0.26.1.
+The versions of PyTorch and the main modules to be built are listed below. This work uses Python v3.9.18, PyTorch v2.1, Numpy v1.22.4, Scipy v1.7.3, oneDNN v3.1.1, and Horovod v0.26.1.
 
 | Module | Version |
 | --- | --- |
@@ -20,3 +20,12 @@ | Horovod | v0.26.1 |
 
 ### Setting up the build environment
+
+The Fugaku build of PyTorch v2.1 uses the build scripts for PyTorch v1.13.1, available from "Building PyTorch on Fujitsu Supercomputer PRIMEHPC FX1000/FX700" published on the Fujitsu GitHub, and also takes in the fixes for the Fujitsu language environment made for PyTorch v1.13.1.
+This work uses tcsds-1.2.38 as the language environment.
+
+#### (1) Clone PyTorch from the Fujitsu GitHub.
+
+```
+$ git clone https://github.com/fujitsu/pytorch.git
+```

From f0f66d5b69acf9c9080fb7bb40adad4a572c3993 Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 10:35:40 +0900
Subject: [PATCH 06/16] test

---
 README.md | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/README.md b/README.md
index 477962d..48ad169 100644
--- a/README.md
+++ b/README.md
@@ -29,3 +29,53 @@
 ```
 $ git clone https://github.com/fujitsu/pytorch.git
 ```
+
+#### (2) Move into the pytorch/ directory and register the official PyTorch repository as a remote.
+
+```
+$ cd pytorch
+$ git remote add upstream https://github.com/pytorch/pytorch.git
+$ git fetch upstream v2.1.0
+```
+
+#### (3) Create a new branch based on the official v2.1.0.
+
+```
+$ git checkout -b r2.1.0_for_a64fx FETCH_HEAD
+```
+
+#### (4) Bring in the set of build scripts from Fujitsu PyTorch v1.13.1.
+
+```
+$ git cherry-pick 17afed104f0a2ac47bab78aebf584fb3c578e707
+$ git reset --mixed HEAD^
+$ git add scripts/fujitsu --all
+$ git commit -m "add fujitsu/script"
+```
+
+#### (5) Cherry-pick the two commits for the Fujitsu compiler.
+The first brings in the fix to 8x8c1x4-dq-packedA-aarch64-neon.S.
+```
+$ git cherry-pick e81f6c00acef2cebaaca9e5085fa6a2b0181ecd4
+$ git checkout --theirs aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S
+$ git add aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S
+$ git cherry-pick --continue
+ # An editor opens; exit without editing.
+```
+
+The second brings in the fixes to the cmake files.
+```
+$ git cherry-pick 2f85b96ce569a8e60eb1627746fba3ee8ba12a57
+$ git status
+ # Fix the two files left unmerged.
+```
+
+Fix the two unmerged files reported by git status as follows.
+1. Fix cmake/Dependencies.cmake
+- Delete lines 238, 243, and 258
+- On line 281, add " OR SSL2_FOUND" after "OR FlexiBLAS_FOUND"
+- Delete lines 280, 282, 283, and 284
+
+2. cmake/Modules/FindOpenMP.cmake
+- Delete lines 262, 266, and 278
+
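Step (5) above leaves two files unmerged and asks for edits at specific line numbers. A minimal sketch for locating and inspecting them before editing, using only standard git and sed (the line ranges are taken from the bullets above):

```
$ git diff --name-only --diff-filter=U          # lists the two unmerged files
$ sed -n '235,285p' cmake/Dependencies.cmake    # view the region around lines 238-284
$ sed -n '260,280p' cmake/Modules/FindOpenMP.cmake
```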
From 9e926ae9f6ba146aab81a5fe16eb4b61dbe2a621 Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 10:41:23 +0900
Subject: [PATCH 07/16] test

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 48ad169..b53e139 100644
--- a/README.md
+++ b/README.md
@@ -71,11 +71,11 @@ $ git status
 ```
 
 Fix the two unmerged files reported by git status as follows.
-1. Fix cmake/Dependencies.cmake
+**1. Fix cmake/Dependencies.cmake**
 - Delete lines 238, 243, and 258
 - On line 281, add " OR SSL2_FOUND" after "OR FlexiBLAS_FOUND"
 - Delete lines 280, 282, 283, and 284
 
-2. cmake/Modules/FindOpenMP.cmake
+**2. Fix cmake/Modules/FindOpenMP.cmake**
 - Delete lines 262, 266, and 278
 

From c0387b1b19d214fb698d34c772de2bed7baaf18c Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 10:42:13 +0900
Subject: [PATCH 08/16] test

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index b53e139..19e646c 100644
--- a/README.md
+++ b/README.md
@@ -71,6 +71,7 @@ $ git status
 ```
 
 Fix the two unmerged files reported by git status as follows.
+
 **1. Fix cmake/Dependencies.cmake**
 - Delete lines 238, 243, and 258
 - On line 281, add " OR SSL2_FOUND" after "OR FlexiBLAS_FOUND"
 - Delete lines 280, 282, 283, and 284
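Before continuing the cherry-pick, it is worth confirming that the edits above actually landed; the listings in the next patch show the expected end state. A minimal sketch:

```
$ grep -n "SSL2_FOUND" cmake/Dependencies.cmake   # the edited condition should now include OR SSL2_FOUND
$ grep -n "<<<<<<<" cmake/Dependencies.cmake cmake/Modules/FindOpenMP.cmake || echo "no conflict markers left"
```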
From c367ade85ad704cbd5d63b7cae0ae5d3c8d6ab7c Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 14:21:05 +0900
Subject: [PATCH 09/16] test

---
 README.md | 121 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 121 insertions(+)

diff --git a/README.md b/README.md
index 19e646c..d3ee2a9 100644
--- a/README.md
+++ b/README.md
@@ -76,7 +76,128 @@ Fix the two unmerged files reported by git status as follows.
 - Delete lines 238, 243, and 258
 - On line 281, add " OR SSL2_FOUND" after "OR FlexiBLAS_FOUND"
 - Delete lines 280, 282, 283, and 284
+- cmake/Dependencies.cmake after the fix
+```
+237 set(BLAS_LIBRARIES ${vecLib_LINKER_LIBS})
+238 elseif(BLAS STREQUAL "FlexiBLAS")
+239 find_package(FlexiBLAS REQUIRED)
+240 include_directories(SYSTEM ${FlexiBLAS_INCLUDE_DIR})
+241 list(APPEND Caffe2_DEPENDENCY_LIBS ${FlexiBLAS_LIB})
+242 elseif(BLAS STREQUAL "SSL2")
+243 if(CMAKE_CXX_COMPILER MATCHES ".*/FCC$"
+244 AND CMAKE_C_COMPILER MATCHES ".*/fcc$")
+245 message(STATUS "SSL2 Selected BLAS library")
+246 list(APPEND Caffe2_PUBLIC_DEPENDENCY_LIBS "fjlapackexsve.so")
+247 set(SSL2_FOUND ON)
+248 message(STATUS "set CMAKE_SHARED_LINKER_FLAGS: -SSL2 --linkfortran")
+249 set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -SSL2 --linkfortran")
+250 set(WITH_BLAS "ssl2")
+251 else()
+252 message(STATUS "Not built using fcc and FCC.")
+253 message(STATUS "CMAKE_C_COMPILER: ${CMAKE_C_COMPILER}")
+254 message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")
+255 endif()
+ :
+273 if(NOT INTERN_BUILD_MOBILE)
+274 set(AT_MKL_ENABLED 0)
+275 set(AT_MKL_SEQUENTIAL 0)
+276 set(USE_BLAS 1)
+277 if(NOT (ATLAS_FOUND OR BLIS_FOUND OR GENERIC_BLAS_FOUND OR MKL_FOUND OR OpenBLAS_FOUND OR VECLIB_FOUND OR FlexiBLAS_FOUND OR SSL2_FOUND))
+278 message(WARNING "Preferred BLAS (" ${BLAS} ") cannot be found, now searching for a general BLAS library")
+279 find_package(BLAS)
+280 if(NOT BLAS_FOUND)
+281 set(USE_BLAS 0)
+282 endif()
+283 endif()
+```
 
 **2. Fix cmake/Modules/FindOpenMP.cmake**
 - Delete lines 262, 266, and 278
+- cmake/Modules/FindOpenMP.cmake after the fix
+```
+261 set(OpenMP_libomp_LIBRARY "${MKL_OPENMP_LIBRARY}" CACHE STRING "libomp location for OpenMP")
+262 if("-fopenmp=libiomp5" IN_LIST OpenMP_${LANG}_FLAG_CANDIDATES)
+263 set(OPENMP_FLAG "-fopenmp=libiomp5")
+264 endif()
+265 elseif(CMAKE_${LANG}_COMPILER MATCHES ".*/fcc$" OR
+266 CMAKE_${LANG}_COMPILER MATCHES ".*/FCC$")
+ :
+275 endif()
+276 else()
+```
+
+#### (6) git add the two fixed files.
+```
+$ git add cmake/Dependencies.cmake cmake/Modules/FindOpenMP.cmake
+$ git cherry-pick --continue
+ # An editor opens; exit without editing.
+```
+
+### (7) Fix scripts/fujitsu/env.src.
+- Enable line 46 and specify tcsds-1.2.38 as the compiler version.
+- Comment out line 47.
+- Adjust the venv and prefix paths on lines 48 and 49 as appropriate.
+
+### (8) Fix scripts/fujitsu/3_venv.sh.
+```
+$ sed -i -e "s/pip future six wheel/pip/g" 3_venv.sh
+```
+
+### (9) Fix scripts/fujitsu/4_numpy_scipy.sh.
+```
+$ sed -i -e "s/Cython>=0.29.30/Cython>=0.29.30,<3.0/g" 4_numpy_scipy.sh
+```
+
+### (10) Fix scripts/fujitsu/5_pytorch.sh.
+```
+$ sed -i -e "s/ONEDNN_VER=v2.7/ONEDNN_VER=v3.1.1/g" 5_pytorch.sh
+$ sed -i -e "s%/third_party/oneDNN%%g" 5_pytorch.sh
+$ sed -i -e "s/CFLAGS=-O3/WITH_BLAS=ssl2 CFLAGS='-O3 -Kopenmp'/g" 5_pytorch.sh
+```
+
+### (11) Apply a patch (pytorch21_q8gemm_sparse.patch) to avoid errors with the Fujitsu compiler.
+```
+$ pwd
+(somewhere)/pytorch
+$ cd aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse
+$ patch -p1 -i (somewhere)/pytorch21_q8gemm_sparse.patch
+```
+
+### (12) Fix scripts/fujitsu/6_vision.sh. Change Vision to v0.16.0 to match the PyTorch version.
+```
+$ sed -i -e "s/TORCHVISION_VER=v0.14.1/TORCHVISION_VER=v0.16.0/g" 6_vision.sh
+```
+
+### (13) Along with the change of the Vision version, fix scripts/fujitsu/vision.patch as follows.
+- vision.patch before the fix
+```
+144 - // we want to precalculate indices and weights shared by all chanels,
+```
+- vision.patch after the fix
+```
+144 - // we want to precalculate indices and weights shared by all channels,
+```
+
+### (14) Fix scripts/fujitsu/horovod.patch as follows, inserting the hunk for C++17 just before the mpi_ops.py hunk.
+- horovod.patch after the fix
+```
+62
+63 diff --git a/horovod/torch/CMakeLists.txt b/horovod/torch/CMakeLists.txt
+64 index eecd198..b1bdee1 100644
+65 --- a/horovod/torch/CMakeLists.txt
+66 +++ b/horovod/torch/CMakeLists.txt
+67 @@ -63,7 +63,9 @@ endif()
+68 parse_version(${Pytorch_VERSION} VERSION_DEC)
+69 add_definitions(-DPYTORCH_VERSION=${VERSION_DEC} -DTORCH_API_INCLUDE_EXTENSION_H=1)
+70 set(Pytorch_CXX11 ${Pytorch_CXX11} PARENT_SCOPE)
+71 -if(NOT Pytorch_VERSION VERSION_LESS "1.5.0")
+72 +if(NOT Pytorch_VERSION VERSION_LESS "2.1.0")
+73 + set(CMAKE_CXX_STANDARD 17)
+74 +elseif(NOT Pytorch_VERSION VERSION_LESS "1.5.0")
+75 set(CMAKE_CXX_STANDARD 14)
+76 endif()
+77
+78 diff --git a/horovod/torch/mpi_ops.py b/horovod/torch/mpi_ops.py
+```
+
+## Build procedure
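Step (7) in the patch above describes the env.src changes only by line number. A minimal sketch of what the edited region might end up looking like — the variable names and the tcsds installation path are assumptions inferred from how run1proc_mnist.sh later uses env.src (VENV_PATH, PREFIX), not verbatim contents of the v1.13.1 script:

```
# scripts/fujitsu/env.src (sketch, assumed variable names and paths)
TCSDS_PATH=/opt/FJSVxtclanga/tcsds-1.2.38   # line 46: enabled, pinned to tcsds-1.2.38 (install path is an assumption)
#TCSDS_PATH=/opt/FJSVxtclanga/tcsds-latest  # line 47: commented out (assumed previous default)
VENV_PATH=$HOME/pytorch-v2.1/venv           # line 48: adjust to your environment
PREFIX=$HOME/pytorch-v2.1/prefix            # line 49: adjust to your environment
```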
From 2635bf85f496ffb4bdcaa05c4137a338f78f2169 Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 14:27:10 +0900
Subject: [PATCH 10/16] test

---
 README.md | 82 +++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 73 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index d3ee2a9..f88c3b4 100644
--- a/README.md
+++ b/README.md
@@ -133,29 +133,29 @@ $ git cherry-pick --continue
  # An editor opens; exit without editing.
 ```
 
-### (7) Fix scripts/fujitsu/env.src.
+#### (7) Fix scripts/fujitsu/env.src.
 - Enable line 46 and specify tcsds-1.2.38 as the compiler version.
 - Comment out line 47.
 - Adjust the venv and prefix paths on lines 48 and 49 as appropriate.
 
-### (8) Fix scripts/fujitsu/3_venv.sh.
+#### (8) Fix scripts/fujitsu/3_venv.sh.
 ```
 $ sed -i -e "s/pip future six wheel/pip/g" 3_venv.sh
 ```
 
-### (9) Fix scripts/fujitsu/4_numpy_scipy.sh.
+#### (9) Fix scripts/fujitsu/4_numpy_scipy.sh.
 ```
 $ sed -i -e "s/Cython>=0.29.30/Cython>=0.29.30,<3.0/g" 4_numpy_scipy.sh
 ```
 
-### (10) Fix scripts/fujitsu/5_pytorch.sh.
+#### (10) Fix scripts/fujitsu/5_pytorch.sh.
 ```
 $ sed -i -e "s/ONEDNN_VER=v2.7/ONEDNN_VER=v3.1.1/g" 5_pytorch.sh
 $ sed -i -e "s%/third_party/oneDNN%%g" 5_pytorch.sh
 $ sed -i -e "s/CFLAGS=-O3/WITH_BLAS=ssl2 CFLAGS='-O3 -Kopenmp'/g" 5_pytorch.sh
 ```
 
-### (11) Apply a patch (pytorch21_q8gemm_sparse.patch) to avoid errors with the Fujitsu compiler.
+#### (11) Apply a patch (pytorch21_q8gemm_sparse.patch) to avoid errors with the Fujitsu compiler.
 ```
 $ pwd
 (somewhere)/pytorch
 $ cd aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse
 $ patch -p1 -i (somewhere)/pytorch21_q8gemm_sparse.patch
 ```
 
-### (12) Fix scripts/fujitsu/6_vision.sh. Change Vision to v0.16.0 to match the PyTorch version.
+#### (12) Fix scripts/fujitsu/6_vision.sh. Change Vision to v0.16.0 to match the PyTorch version.
 ```
 $ sed -i -e "s/TORCHVISION_VER=v0.14.1/TORCHVISION_VER=v0.16.0/g" 6_vision.sh
 ```
 
-### (13) Along with the change of the Vision version, fix scripts/fujitsu/vision.patch as follows.
+#### (13) Along with the change of the Vision version, fix scripts/fujitsu/vision.patch as follows.
 - vision.patch before the fix
 ```
 144 - // we want to precalculate indices and weights shared by all chanels,
 ```
 - vision.patch after the fix
 ```
 144 - // we want to precalculate indices and weights shared by all channels,
 ```
 
-### (14) Fix scripts/fujitsu/horovod.patch as follows, inserting the hunk for C++17 just before the mpi_ops.py hunk.
+#### (14) Fix scripts/fujitsu/horovod.patch as follows, inserting the hunk for C++17 just before the mpi_ops.py hunk.
 - horovod.patch after the fix
 ```
 62
@@ -200,4 +200,68 @@
 78 diff --git a/horovod/torch/mpi_ops.py b/horovod/torch/mpi_ops.py
 ```
-## Build procedure
+### Build procedure
+After the build environment setup above is complete, build with the following steps in an interactive job.
+```
+$ cd (somewhere)/pytorch/scripts/fujitsu
+
+$ bash 1_python.sh
+$ bash 3_venv.sh
+$ bash 4_numpy_scipy.sh
+$ bash 5_pytorch.sh
+$ bash 6_vision.sh
+$ bash 7_horovod.sh
+$ bash 8_libtcmalloc.sh
+```
+
+The pip3 list output (pip3_list.txt) produced after running the build scripts is shown below.
+```
+Package            Version
+------------------ ------------------
+astunparse         1.6.3
+attrs              23.2.0
+beniget            0.4.1
+certifi            2024.2.2
+cffi               1.16.0
+charset-normalizer 3.3.2
+cloudpickle        3.0.0
+Cython             0.29.37
+exceptiongroup     1.2.0
+expecttest         0.2.1
+filelock           3.13.1
+fsspec             2024.2.0
+gast               0.5.4
+horovod            0.26.1
+hypothesis         6.99.6
+idna               3.6
+iniconfig          2.0.0
+Jinja2             3.1.3
+MarkupSafe         2.1.5
+mpmath             1.3.0
+networkx           3.2.1
+numpy              1.22.4
+packaging          24.0
+Pillow             7.2.0
+pip                23.0.1
+pluggy             1.4.0
+ply                3.11
+psutil             5.9.8
+pybind11           2.11.1
+pycparser          2.21
+pytest             8.1.1
+pythran            0.15.0
+PyYAML             6.0.1
+requests           2.31.0
+SciPy              1.7.3
+setuptools         69.2.0
+six                1.16.0
+sortedcontainers   2.4.0
+sympy              1.12
+tomli              2.0.1
+torch              2.1.0a0+gitd886a2e
+torchvision        0.16.0+fbb4cc5
+types-dataclasses  0.6.6
+typing_extensions  4.10.0
+urllib3            2.2.1
+wheel              0.43.0
+```
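As a quick sanity check once the build scripts above finish, the interpreter in the venv should import the freshly built torch — a minimal sketch (the torch.__config__ calls are the same ones run1proc_mnist.sh uses below; VENV_PATH is the assumed variable from env.src):

```
$ source $VENV_PATH/bin/activate   # VENV_PATH as set in scripts/fujitsu/env.src
$ python3 -c "import torch, torchvision; print(torch.__version__, torchvision.__version__)"
$ python3 -c "import torch; print(torch.__config__.show())"   # the oneDNN (mkldnn) and BLAS settings should appear here
```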
From 5eb3f53b1096c84d5fc725379ca843f89dc8f78b Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 14:42:49 +0900
Subject: [PATCH 11/16] test

---
 README.md | 241 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 241 insertions(+)

diff --git a/README.md b/README.md
index f88c3b4..dde258c 100644
--- a/README.md
+++ b/README.md
@@ -265,3 +265,244 @@ typing_extensions  4.10.0
 urllib3            2.2.1
 wheel              0.43.0
 ```
+
+### Verification with the standard test data (mnist)
+
+The build of PyTorch v2.1 was verified with "mnist", which is widely used as sample
+data when training image-recognition models in machine learning.
+The code that runs mnist was obtained from the examples in the official PyTorch GitHub.
+(https://github.com/pytorch/examples/blob/main/mnist/main.py)
+The script for running the mnist code was adapted from scripts/fujitsu/run1proc.sh.
+
+#### Setting up the mnist execution environment
+
+Create the following mnist code (mnist.py) and run script (run1proc_mnist.sh) under scripts/fujitsu/.
+
+- mnist.py
+```
+from __future__ import print_function
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.optim.lr_scheduler import StepLR
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
+
+
+def train(args, model, device, train_loader, optimizer, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % args.log_interval == 0:
+            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, batch_idx * len(data), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), loss.item()))
+            if args.dry_run:
+                break
+
+
+def test(model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
+            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)
+
+    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+        test_loss, correct, len(test_loader.dataset),
+        100. 
* correct / len(test_loader.dataset))) + + +def main(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=14, metavar='N', + help='number of epochs to train (default: 14)') + parser.add_argument('--lr', type=float, default=1.0, metavar='LR', + help='learning rate (default: 1.0)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--no-mps', action='store_true', default=False, + help='disables macOS GPU training') + parser.add_argument('--dry-run', action='store_true', default=False, + help='quickly check a single pass') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument('--save-model', action='store_true', default=False, + help='For Saving the current Model') + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + use_mps = not args.no_mps and torch.backends.mps.is_available() + + torch.manual_seed(args.seed) + + if use_cuda: + device = torch.device("cuda") + elif use_mps: + device = torch.device("mps") + else: + device = torch.device("cpu") + + train_kwargs = {'batch_size': args.batch_size} + test_kwargs = {'batch_size': args.test_batch_size} + if use_cuda: + cuda_kwargs = {'num_workers': 1, + 'pin_memory': True, + 'shuffle': True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + dataset1 = datasets.MNIST('../data', train=True, download=True, + transform=transform) + dataset2 = datasets.MNIST('../data', train=False, + transform=transform) + train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) + test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) + + model = Net().to(device) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + for epoch in range(1, args.epochs + 1): + train(args, model, device, train_loader, optimizer, epoch) + test(model, device, test_loader) + scheduler.step() + + if args.save_model: + torch.save(model.state_dict(), "mnist_cnn.pt") + + +if __name__ == '__main__': + main() +``` + +- run1proc_mnist.sh +``` +#! /bin/bash + +set -euo pipefail + +script_basedir=$(cd $(dirname $0); pwd) +source $script_basedir/env.src +[ -v VENV_PATH ] && source $VENV_PATH/bin/activate + +set -x + +#export OMP_PROC_BIND=false +export OMP_NUM_THREADS=48 + +# For oneDNN debug +# Output debug message (CSV) to stdout. +# The message begin with 'dnnl_verbose,' which is the first entry in CSV. 
+#export DNNL_VERBOSE=1 # 0: (no output), 1: (exec), 2: (1 + cache hit/miss)
+#export DNNL_VERBOSE_TIMESTAMP=1
+
+ulimit -s 8192
+
+if [ ${PMIX_RANK:-0} -eq 0 ]; then
+    env
+    pip3 list
+    KMP_SETTINGS=1 python3 -c "import torch; print(torch.__version__); print(torch.__config__.show()); print(torch.__config__.parallel_info())"
+fi
+
+LD_PRELOAD=$PREFIX/lib/libtcmalloc.so python3 -u mnist.py --epoch 2 --no-cuda --no-mps
+```
+
+#### Running mnist
+Run mnist from a compute node in an interactive job with the following commands.
+```
+$ cd (somewhere)/pytorch/scripts/fujitsu
+$ bash run1proc_mnist.sh
+```
+
+The following output confirms that mnist runs correctly on PyTorch v2.1.
+
+```
+Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
+Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
+100.0%
+Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
+ :
+Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
+
+Train Epoch: 1 [0/60000 (0%)]	Loss: 2.329474
+Train Epoch: 1 [640/60000 (1%)]	Loss: 1.425025
+Train Epoch: 1 [1280/60000 (2%)]	Loss: 0.797880
+Train Epoch: 1 [1920/60000 (3%)]	Loss: 0.536058
+Train Epoch: 1 [2560/60000 (4%)]	Loss: 0.438659
+Train Epoch: 1 [3200/60000 (5%)]	Loss: 0.272091
+ :
+Train Epoch: 1 [56960/60000 (95%)]	Loss: 0.028683
+Train Epoch: 1 [57600/60000 (96%)]	Loss: 0.158729
+Train Epoch: 1 [58240/60000 (97%)]	Loss: 0.003202
+Train Epoch: 1 [58880/60000 (98%)]	Loss: 0.009425
+Train Epoch: 1 [59520/60000 (99%)]	Loss: 0.003038
+
+Test set: Average loss: 0.0458, Accuracy: 9840/10000 (98%)
+
+Train Epoch: 2 [0/60000 (0%)]	Loss: 0.024910
+Train Epoch: 2 [640/60000 (1%)]	Loss: 0.025748
+Train Epoch: 2 [1280/60000 (2%)]	Loss: 0.074290
+Train Epoch: 2 [1920/60000 (3%)]	Loss: 0.184948
+Train Epoch: 2 [2560/60000 (4%)]	Loss: 0.053342
+Train Epoch: 2 [3200/60000 (5%)]	Loss: 0.025564
+ :
+Train Epoch: 2 [56960/60000 (95%)]	Loss: 0.032589
+Train Epoch: 2 [57600/60000 (96%)]	Loss: 0.136949
+Train Epoch: 2 [58240/60000 (97%)]	Loss: 0.031606
+Train Epoch: 2 [58880/60000 (98%)]	Loss: 0.005720
+Train Epoch: 2 [59520/60000 (99%)]	Loss: 0.002099
+
+Test set: Average loss: 0.0370, Accuracy: 9870/10000 (99%)
+```

From 4b4da77ece3efb66f1c3d8f2044b19d4145223f2 Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 15:53:38 +0900
Subject: [PATCH 12/16] test

---
 run/mnist.py          | 145 ++++++++++++++++++++++++++++++++++++++++++
 run/run1proc_mnist.sh |  28 ++++++++
 2 files changed, 173 insertions(+)
 create mode 100644 run/mnist.py
 create mode 100644 run/run1proc_mnist.sh

diff --git a/run/mnist.py b/run/mnist.py
new file mode 100644
index 0000000..29d81d6
--- /dev/null
+++ b/run/mnist.py
@@ -0,0 +1,145 @@
+from __future__ import print_function
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.optim.lr_scheduler import StepLR
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
+
+ +def train(args, model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.item())) + if args.dry_run: + break + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. * correct / len(test_loader.dataset))) + + +def main(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=14, metavar='N', + help='number of epochs to train (default: 14)') + parser.add_argument('--lr', type=float, default=1.0, metavar='LR', + help='learning rate (default: 1.0)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--no-mps', action='store_true', default=False, + help='disables macOS GPU training') + parser.add_argument('--dry-run', action='store_true', default=False, + help='quickly check a single pass') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument('--save-model', action='store_true', default=False, + help='For Saving the current Model') + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + use_mps = not args.no_mps and torch.backends.mps.is_available() + + torch.manual_seed(args.seed) + + if use_cuda: + device = torch.device("cuda") + elif use_mps: + device = torch.device("mps") + else: + device = torch.device("cpu") + + train_kwargs = {'batch_size': args.batch_size} + test_kwargs = {'batch_size': args.test_batch_size} + if use_cuda: + cuda_kwargs = {'num_workers': 1, + 'pin_memory': True, + 'shuffle': True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + dataset1 = datasets.MNIST('../data', train=True, download=True, + transform=transform) + dataset2 = datasets.MNIST('../data', train=False, + transform=transform) + train_loader = 
torch.utils.data.DataLoader(dataset1,**train_kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+    model = Net().to(device)
+    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+
+    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+    for epoch in range(1, args.epochs + 1):
+        train(args, model, device, train_loader, optimizer, epoch)
+        test(model, device, test_loader)
+        scheduler.step()
+
+    if args.save_model:
+        torch.save(model.state_dict(), "mnist_cnn.pt")
+
+
+if __name__ == '__main__':
+    main()

diff --git a/run/run1proc_mnist.sh b/run/run1proc_mnist.sh
new file mode 100644
index 0000000..2ef6dc1
--- /dev/null
+++ b/run/run1proc_mnist.sh
@@ -0,0 +1,28 @@
+#! /bin/bash
+
+set -euo pipefail
+
+script_basedir=$(cd $(dirname $0); pwd)
+source $script_basedir/env.src
+[ -v VENV_PATH ] && source $VENV_PATH/bin/activate
+
+set -x
+
+#export OMP_PROC_BIND=false
+export OMP_NUM_THREADS=48
+
+# For oneDNN debug
+# Output debug message (CSV) to stdout.
+# The message begin with 'dnnl_verbose,' which is the first entry in CSV.
+#export DNNL_VERBOSE=1 # 0: (no output), 1: (exec), 2: (1 + cache hit/miss)
+#export DNNL_VERBOSE_TIMESTAMP=1
+
+ulimit -s 8192
+
+if [ ${PMIX_RANK:-0} -eq 0 ]; then
+    env
+    pip3 list
+    KMP_SETTINGS=1 python3 -c "import torch; print(torch.__version__); print(torch.__config__.show()); print(torch.__config__.parallel_info())"
+fi
+
+LD_PRELOAD=$PREFIX/lib/libtcmalloc.so python3 -u mnist.py --epoch 2 --no-cuda --no-mps

From 77f8cb2d88b526895b238517c0aac141d6d7f693 Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 15:57:01 +0900
Subject: [PATCH 13/16] add run instructions

---
 README.md | 181 +-----------------------------------------------------
 1 file changed, 1 insertion(+), 180 deletions(-)

diff --git a/README.md b/README.md
index dde258c..ddbf918 100644
--- a/README.md
+++ b/README.md
@@ -276,188 +276,9 @@ The code that runs mnist was obtained from the examples in the official PyTorch
 
 #### Setting up the mnist execution environment
 
-Create the following mnist code (mnist.py) and run script (run1proc_mnist.sh) under scripts/fujitsu/.
-
+Copy the following two files, stored in the run/ directory, to scripts/fujitsu/.
 - mnist.py
-```
-from __future__ import print_function
-import argparse
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-from torchvision import datasets, transforms
-from torch.optim.lr_scheduler import StepLR
-
-
-class Net(nn.Module):
-    def __init__(self):
-        super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(1, 32, 3, 1)
-        self.conv2 = nn.Conv2d(32, 64, 3, 1)
-        self.dropout1 = nn.Dropout(0.25)
-        self.dropout2 = nn.Dropout(0.5)
-        self.fc1 = nn.Linear(9216, 128)
-        self.fc2 = nn.Linear(128, 10)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = F.relu(x)
-        x = self.conv2(x)
-        x = F.relu(x)
-        x = F.max_pool2d(x, 2)
-        x = self.dropout1(x)
-        x = torch.flatten(x, 1)
-        x = self.fc1(x)
-        x = F.relu(x)
-        x = self.dropout2(x)
-        x = self.fc2(x)
-        output = F.log_softmax(x, dim=1)
-        return output
-
-
-def train(args, model, device, train_loader, optimizer, epoch):
-    model.train()
-    for batch_idx, (data, target) in enumerate(train_loader):
-        data, target = data.to(device), target.to(device)
-        optimizer.zero_grad()
-        output = model(data)
-        loss = F.nll_loss(output, target)
-        loss.backward()
-        optimizer.step()
-        if batch_idx % args.log_interval == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. 
* batch_idx / len(train_loader), loss.item())) - if args.dry_run: - break - - -def test(model, device, test_loader): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) - output = model(data) - test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=14, metavar='N', - help='number of epochs to train (default: 14)') - parser.add_argument('--lr', type=float, default=1.0, metavar='LR', - help='learning rate (default: 1.0)') - parser.add_argument('--gamma', type=float, default=0.7, metavar='M', - help='Learning rate step gamma (default: 0.7)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--no-mps', action='store_true', default=False, - help='disables macOS GPU training') - parser.add_argument('--dry-run', action='store_true', default=False, - help='quickly check a single pass') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') - parser.add_argument('--save-model', action='store_true', default=False, - help='For Saving the current Model') - args = parser.parse_args() - use_cuda = not args.no_cuda and torch.cuda.is_available() - use_mps = not args.no_mps and torch.backends.mps.is_available() - - torch.manual_seed(args.seed) - - if use_cuda: - device = torch.device("cuda") - elif use_mps: - device = torch.device("mps") - else: - device = torch.device("cpu") - - train_kwargs = {'batch_size': args.batch_size} - test_kwargs = {'batch_size': args.test_batch_size} - if use_cuda: - cuda_kwargs = {'num_workers': 1, - 'pin_memory': True, - 'shuffle': True} - train_kwargs.update(cuda_kwargs) - test_kwargs.update(cuda_kwargs) - - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - dataset1 = datasets.MNIST('../data', train=True, download=True, - transform=transform) - dataset2 = datasets.MNIST('../data', train=False, - transform=transform) - train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) - test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) - - model = Net().to(device) - optimizer = optim.Adadelta(model.parameters(), lr=args.lr) - - scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - for epoch in range(1, args.epochs + 1): - train(args, model, device, train_loader, optimizer, epoch) - test(model, device, test_loader) - scheduler.step() - - if args.save_model: - torch.save(model.state_dict(), "mnist_cnn.pt") - - -if __name__ == 
'__main__':
-    main()
-```
-
-- run1proc_mnist.sh
-```
-#! /bin/bash
-
-set -euo pipefail
-
-script_basedir=$(cd $(dirname $0); pwd)
-source $script_basedir/env.src
-[ -v VENV_PATH ] && source $VENV_PATH/bin/activate
-
-set -x
-
-#export OMP_PROC_BIND=false
-export OMP_NUM_THREADS=48
-
-# For oneDNN debug
-# Output debug message (CSV) to stdout.
-# The message begin with 'dnnl_verbose,' which is the first entry in CSV.
-#export DNNL_VERBOSE=1 # 0: (no output), 1: (exec), 2: (1 + cache hit/miss)
-#export DNNL_VERBOSE_TIMESTAMP=1
-
-ulimit -s 8192
-
-if [ ${PMIX_RANK:-0} -eq 0 ]; then
-    env
-    pip3 list
-    KMP_SETTINGS=1 python3 -c "import torch; print(torch.__version__); print(torch.__config__.show()); print(torch.__config__.parallel_info())"
-fi
-
-LD_PRELOAD=$PREFIX/lib/libtcmalloc.so python3 -u mnist.py --epoch 2 --no-cuda --no-mps
-```
 
 #### Running mnist
 Run mnist from a compute node in an interactive job with the following commands.
 ```

From d1c0ec1adac7853315f1bf69a89c229417391e7e Mon Sep 17 00:00:00 2001
From: watanabek
Date: Wed, 20 Mar 2024 16:01:36 +0900
Subject: [PATCH 14/16] add patch file

---
 PATCH/pytorch21_q8gemm_sparse.patch | 3359 +++++++++++++++++++++++++++
 1 file changed, 3359 insertions(+)
 create mode 100644 PATCH/pytorch21_q8gemm_sparse.patch

diff --git a/PATCH/pytorch21_q8gemm_sparse.patch b/PATCH/pytorch21_q8gemm_sparse.patch
new file mode 100644
index 0000000..3144837
--- /dev/null
+++ b/PATCH/pytorch21_q8gemm_sparse.patch
@@ -0,0 +1,3359 @@
+diff -Naur q8gemm_sparse.orig/8x8c1x4-dq-packedA-aarch64-neon.S q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S
+--- q8gemm_sparse.orig/8x8c1x4-dq-packedA-aarch64-neon.S 2024-03-19 21:58:45.522119847 +0900
++++ q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S 2024-01-24 10:42:10.000000000 +0900
+@@ -8,6 +8,24 @@
+ 
+ #include
+ 
++#ifndef IGNORE_CODE_ALIGN_DIRECTIVES
++#define NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_5 .p2align 5
++#define NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_4 .p2align 4
++#define NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_3 .p2align 3
++#else
++#define NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_5
++#define NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_4
++#define NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_3
++#endif
++
+# Macro for separating instructions. 
For most builds, ; can be used, but for ++# ARM64 + Mach, ; begins a comment, and %% is used to separate instructions ++#if defined(__MACH__) ++#define XX %% ++#else ++#define XX ; ++#endif ++ + .macro TRANSPOSE_4X4_S32 vin0, vin1, vin2, vin3, temp0, temp1, temp2, temp3 + TRN1 \temp0\().4s, \vin0\().4s, \vin1\().4s + TRN2 \temp1\().4s, \vin0\().4s, \vin1\().4s +@@ -30,7 +48,7 @@ + # |params | 16 + # |-----------| + +-# void pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA__aarch32_neon( ++# void pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w32__aarch64_neon( + # size_t mr, + # size_t nr, + # const uint8_t* a_packed, +@@ -42,455 +60,1355 @@ + # size_t c_stride, + # size_t output_channel_index, + # const union pytorch_qnnp_conv_dynamic_quantization_params quantization_params[restrict static 1]) +-BEGIN_FUNCTION pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA__aarch64_neon ++BEGIN_FUNCTION pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w32__aarch64_neon ++ ++ STP d15, d14, [sp, -16] ++ STP d13, d12, [sp, -32] ++ STP d11, d10, [sp, -48] ++ STP d9, d8, [sp, -64] ++ ++ MOV x11, x1 ++ /* Load output channel index */ ++ LDR x10, [sp, 8] ++ /* Load params */ ++ LDR x8, [sp, 16] ++ ++ /* Load a_zero_point */ ++ LD1R {v24.8b}, [x8] ++ ADD x8, x8, 8 ++ ++ /* Load pointer to per channel zero points array */ ++ LDR x17, [x8], 8 ++ ++ /* Load pointer to per channel multiplier */ ++ LDR x13, [x8] ++ ++ /* Add offset to the base pointer */ ++ ADD x17, x17, x10 ++ /* Mul by 4 to get byte offset for multiplier */ ++ LSL x10, x10, 2 ++ /* Add offset to the base pointer for multiplier */ ++ ADD x13, x13, x10 ++ ++ /* Load b_zero_point */ ++ LD1 {v25.8b}, [x17] ++ /* Load multiplier c0123 */ ++ LD1 {v26.4s}, [x13], 16 ++ /* Load multiplier c4567 */ ++ LD1 {v30.4s}, [x13] ++ ++ EOR x12, x12, x12 ++ EOR x13, x13, x13 ++ ++ CMP x1, 1 ++ B.LO _7_w32 ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_5 ++_0_w32: ++ /* v8 := zero */ ++ EOR v8.16b, v8.16b, v8.16b ++ /* v9 := zero */ ++ EOR v9.16b, v9.16b, v9.16b ++ ++ DUP v29.8b, v25.b[0] ++ /* w12 = w_row_ptr[n], x13 = w_row_ptr[n+1] */ ++ /* x4 = x4 + W_INDEX_DTYPE_NUM_BYTES_ARG to point to next n */ ++ LDR w12, [x4], #4 ++ LDR w13, [x4] ++ /* x10 = temp_packed_w = packed_w + w_row_ptr[n] * 4 */ ++ /* This points to the first block of nonzero value */ ++ /* for the nth row. 
*/ ++ ADD x10, x3, x12, LSL #2 ++ /* x9 = temp_w_block_ids_ptr = w_block_ids_ptr (x5) + w_row_ptr[n] */ ++ /* LSL for when elements are >1 byte */ ++ /* (4 bytes: LSL #2, 2 bytes: LSL #1, 1 byte: LSL #0) */ ++ /* This points to the block id of the first block */ ++ /* It should contain x13 - x12 number of block ids */ ++ ADD x9, x5, x12, LSL #2 ++ /* x8 = num_blocks that needs to be processed */ ++ SUB x8, x13, x12 ++ SUBS x8, x8, 2 ++ B.LO _1_w32 ++ ++k_loop_w32: ++ /* b0-7 (channel 0) */ ++ LD1 {v10.8b}, [x10], 8 ++ USUBL v10.8h, v10.8b, v29.8b ++ ++ /* x12 = block_id_ptr[0] */ ++ /* x13 = block_id_ptr[1] */ ++ LDR w12, [x9], #4 ++ LDR w13, [x9], #4 ++ /* Add offset to x2 */ ++ /* Shift by 5 because each packed block is a block of 8x4 */ ++ /* which 32 bytes */ ++ ADD x16, x2, x12, LSL #5 ++ ADD x17, x2, x13, LSL #5 ++ ++ LD1 {v0.8b}, [x16], 8 ++ LD1 {v1.8b}, [x16], 8 ++ LD1 {v2.8b}, [x16], 8 ++ LD1 {v3.8b}, [x16] ++ LD1 {v4.8b}, [x17], 8 ++ LD1 {v5.8b}, [x17], 8 ++ LD1 {v6.8b}, [x17], 8 ++ LD1 {v7.8b}, [x17] ++ ++ USUBL v0.8h, v0.8b, v24.8b ++ USUBL v1.8h, v1.8b, v24.8b ++ USUBL v2.8h, v2.8b, v24.8b ++ USUBL v3.8h, v3.8b, v24.8b ++ USUBL v4.8h, v4.8b, v24.8b ++ USUBL v5.8h, v5.8b, v24.8b ++ USUBL v6.8h, v6.8b, v24.8b ++ USUBL v7.8h, v7.8b, v24.8b ++ ++ SMLAL v8.4s, v0.4h, v10.h[0] ++ SMLAL2 v9.4s, v0.8h, v10.h[0] ++ SMLAL v8.4s, v1.4h, v10.h[1] ++ SMLAL2 v9.4s, v1.8h, v10.h[1] ++ SMLAL v8.4s, v2.4h, v10.h[2] ++ SMLAL2 v9.4s, v2.8h, v10.h[2] ++ SMLAL v8.4s, v3.4h, v10.h[3] ++ SMLAL2 v9.4s, v3.8h, v10.h[3] ++ SMLAL v8.4s, v4.4h, v10.h[4] ++ SMLAL2 v9.4s, v4.8h, v10.h[4] ++ SMLAL v8.4s, v5.4h, v10.h[5] ++ SMLAL2 v9.4s, v5.8h, v10.h[5] ++ SMLAL v8.4s, v6.4h, v10.h[6] ++ SMLAL2 v9.4s, v6.8h, v10.h[6] ++ SUBS x8, x8, 2 ++ SMLAL v8.4s, v7.4h, v10.h[7] ++ SMLAL2 v9.4s, v7.8h, v10.h[7] ++ ++ ++ B.HS k_loop_w32 ++ ++_1_w32: ++ CMP x8, -2 ++ B.EQ _2_w32 ++ ++ /* b0-7 (channel 0) */ ++ LD1R {v10.4s}, [x10] ++ USUBL v10.8h, v10.8b, v29.8b ++ ++ /* x12 = block_id_ptr[0] */ ++ LDR w12, [x9] ++ /* Add offset to x2 */ ++ /* Shift by 5 because each packed block is a block of 8x4 */ ++ /* which 32 bytes */ ++ ADD x16, x2, x12, LSL #5 ++ ++ LD1 {v0.8b}, [x16], 8 ++ LD1 {v1.8b}, [x16], 8 ++ LD1 {v2.8b}, [x16], 8 ++ LD1 {v3.8b}, [x16] ++ ++ USUBL v0.8h, v0.8b, v24.8b ++ USUBL v1.8h, v1.8b, v24.8b ++ USUBL v2.8h, v2.8b, v24.8b ++ USUBL v3.8h, v3.8b, v24.8b ++ ++ SMLAL v8.4s, v0.4h, v10.h[0] ++ SMLAL2 v9.4s, v0.8h, v10.h[0] ++ SMLAL v8.4s, v1.4h, v10.h[1] ++ SMLAL2 v9.4s, v1.8h, v10.h[1] ++ SMLAL v8.4s, v2.4h, v10.h[2] ++ SMLAL2 v9.4s, v2.8h, v10.h[2] ++ SMLAL v8.4s, v3.4h, v10.h[3] ++ SMLAL2 v9.4s, v3.8h, v10.h[3] ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_4 ++_2_w32: ++ /* Store result on stack */ ++ ++ /* -64 because all d8-d15 are on stack */ ++ /* + 256 bytes of buffer when nr = 1 */ ++ /* 256 because we are doing 8x8 block with each value being 4 bytes */ ++ /* Thus 64 * 4 = 256 */ ++ /* 256 + 64 = 320 */ ++ /* This is needed because after processing all nrs we will */ ++ /* load 256 bytes from stack. */ ++ /* Thus we will load accumulators back in v8, v9, v10, v11, v12, v13, v14, v15 */ ++ /* v16, v17, v18, v19, v20, v21, v22, v23 */ ++ /* When nr < 8, say nr = 1, extra v values will be fetched from stack which may overlap */ ++ /* with other parts of stack storing local variables. To avoid that we just */ ++ /* create a buffer of 256 bytes inbetween to make sure pointer increment */ ++ /* never produces address that is beyond the stack frame of this function. 
*/ ++ SUB x9, sp, 320 ++ /* Each iteration produce 8 values each of 4 bytes */ ++ /* Thus 8 x 4 = 32 bytes 2^5 */ ++ /* In this implementation, first value will be stored at */ ++ /* 1st value: sp - 64 - r1 * 32 */ ++ /* 2nd value: sp - 12 - (r1 - 1) * 32 */ ++ /* and so on. */ ++ SUB x9, x9, x1, LSL #5 ++ ST1 {v8.4s}, [x9], 16 ++ ST1 {v9.4s}, [x9] ++ ++ /* Shift zero point vector by 8 to load */ ++ /* zero point of the next channel */ ++ SRI v25.2d, v25.2d, #8 ++ /* Check if nr >=1 */ ++ SUBS x1, x1, 1 ++ BHI _0_w32 ++_3_w32: ++ /* First load all the accumulators from stack */ ++ /* Load nr */ ++ SUB x9, sp, 320 ++ SUB x9, x9, x11, LSL #5 ++ /* Now load v8-v15 */ ++ /* This is 8x4 block (nrxmr) */ ++ /* We will transpose this to 4x8 (mrxnr) */ ++ /* v8, v9 : x00, x10, x20, x30; x40, x50, x60, x70 */ ++ /* v10, v11 : x01, x11, x21, x31; x41, x51, x61, x71 */ ++ /* v12, v13 : x02, x12, x22, x32; x42, x52, x62, x72 */ ++ /* v14, v15 : x03, x13, x23, x33; x43, x53, x63, x73 */ ++ /* */ ++ /* v16, v17 : x04, x14, x24, x34; x44, x54, x64, x74 */ ++ /* v18, v19 : x05, x15, x25, x35; x45, x55, x65, x75 */ ++ /* v20, v21 : x06, x16, x26, x36; x46, x56, x66, x76 */ ++ /* v22, v23 : x07, x17, x27, x37; x47, x57, x67, x77 */ ++ LD1 {v8.4s}, [x9], 16 ++ LD1 {v9.4s}, [x9], 16 ++ LD1 {v10.4s}, [x9], 16 ++ LD1 {v11.4s}, [x9], 16 ++ LD1 {v12.4s}, [x9], 16 ++ LD1 {v13.4s}, [x9], 16 ++ LD1 {v14.4s}, [x9], 16 ++ LD1 {v15.4s}, [x9], 16 ++ LD1 {v16.4s}, [x9], 16 ++ LD1 {v17.4s}, [x9], 16 ++ LD1 {v18.4s}, [x9], 16 ++ LD1 {v19.4s}, [x9], 16 ++ LD1 {v20.4s}, [x9], 16 ++ LD1 {v21.4s}, [x9], 16 ++ LD1 {v22.4s}, [x9], 16 ++ LD1 {v23.4s}, [x9] ++ ++ /* We can tranpose one 4x4 block using macro */ ++ /* TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 */ ++ /* After this we have */ ++ /* v8 : x00, x01, x02, x03 */ ++ /* v10 : x10, x11, x12, x13 */ ++ /* v12 : x20, x21, x22, x23 */ ++ /* v14 : x30, x31, x32, x33 */ ++ /* Then using */ ++ /* TRANSPOSE_4X4_S32 v16, v18, v20, v22, v4, v5, v6, v7 */ ++ /* We get */ ++ /* v16 : x04, x05, x06, x07 */ ++ /* v18 : x14, x15, x16, x17 */ ++ /* v20 : x24, x25, x26, x27 */ ++ /* v22 : x34, x35, x36, x37 */ ++ /* Similarly we can transpose other two 4x4 blocks and we get */ ++ /* tranposed 8x8 */ ++ ++ TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 ++ TRANSPOSE_4X4_S32 v16, v18, v20, v22, v4, v5, v6, v7 ++ TRANSPOSE_4X4_S32 v9, v11, v13, v15, v0, v1, v2, v3 ++ TRANSPOSE_4X4_S32 v17, v19, v21, v23, v4, v5, v6, v7 ++ ++ /* row 0: v8, v16 */ ++ /* row 1: v10, v18 */ ++ /* row 2: v12, v20 */ ++ /* row 3: v14, v22 */ ++ /* row 4: v9, v17 */ ++ /* row 5: v11, v19 */ ++ /* row 6: v13, v21 */ ++ /* row 7: v15, v23 */ ++ ++ /* Load c_stride & params */ ++ LDR x16, [sp] ++ LSL x16, x16, 2 ++ LD1 {v24.4s}, [x6], 16 ++ LD1 {v25.4s}, [x6] ++ ++ SCVTF v8.4s, v8.4s ++ SCVTF v9.4s, v9.4s ++ SCVTF v10.4s, v10.4s ++ SCVTF v11.4s, v11.4s ++ SCVTF v12.4s, v12.4s ++ SCVTF v13.4s, v13.4s ++ SCVTF v14.4s, v14.4s ++ SCVTF v15.4s, v15.4s ++ SCVTF v16.4s, v16.4s ++ SCVTF v17.4s, v17.4s ++ SCVTF v18.4s, v18.4s ++ SCVTF v19.4s, v19.4s ++ SCVTF v20.4s, v20.4s ++ SCVTF v21.4s, v21.4s ++ SCVTF v22.4s, v22.4s ++ SCVTF v23.4s, v23.4s ++ ++ FMUL v8.4s, v8.4s, v26.4s ++ FMUL v16.4s, v16.4s, v30.4s ++ FMUL v10.4s, v10.4s, v26.4s ++ FMUL v18.4s, v18.4s, v30.4s ++ FMUL v12.4s, v12.4s, v26.4s ++ FMUL v20.4s, v20.4s, v30.4s ++ FMUL v14.4s, v14.4s, v26.4s ++ FMUL v22.4s, v22.4s, v30.4s ++ FMUL v9.4s, v9.4s, v26.4s ++ FMUL v17.4s, v17.4s, v30.4s ++ FMUL v11.4s, v11.4s, v26.4s ++ FMUL v19.4s, v19.4s, v30.4s ++ FMUL 
v13.4s, v13.4s, v26.4s ++ FMUL v21.4s, v21.4s, v30.4s ++ FMUL v15.4s, v15.4s, v26.4s ++ FMUL v23.4s, v23.4s, v30.4s ++ ++ FADD v8.4s, v8.4s, v24.4s ++ FADD v16.4s, v16.4s, v25.4s ++ FADD v10.4s, v10.4s, v24.4s ++ FADD v18.4s, v18.4s, v25.4s ++ FADD v12.4s, v12.4s, v24.4s ++ FADD v20.4s, v20.4s, v25.4s ++ FADD v14.4s, v14.4s, v24.4s ++ FADD v22.4s, v22.4s, v25.4s ++ FADD v9.4s, v9.4s, v24.4s ++ FADD v17.4s, v17.4s, v25.4s ++ FADD v11.4s, v11.4s, v24.4s ++ FADD v19.4s, v19.4s, v25.4s ++ FADD v13.4s, v13.4s, v24.4s ++ FADD v21.4s, v21.4s, v25.4s ++ FADD v15.4s, v15.4s, v24.4s ++ FADD v23.4s, v23.4s, v25.4s ++ ++ /* Compute c0-c7 */ ++ ++ ADD x9, x7, x16 ++ CMP x0, 2 ++ CSEL x9, x7, x9, LO ++ ++ ADD x10, x9, x16 ++ CSEL x10, x9, x10, LS ++ ++ ADD x8, x10, x16 ++ CMP x0, 4 ++ CSEL x8, x10, x8, LO ++ ++ ADD x12, x8, x16 ++ CSEL x12, x8, x12, LS ++ ++ ADD x13, x12, x16 ++ CMP x0, 6 ++ CSEL x13, x12, x13, LO ++ ++ ADD x14, x13, x16 ++ CSEL x14, x13, x14, LS ++ ++ ADD x15, x14, x16 ++ CMP x0, 8 ++ CSEL x15, x14, x15, NE ++ ++ CMP x11, 8 ++ B.NE _4_w32 ++ ++ ST1 {v8.4s}, [x7], 16 ++ ST1 {v16.4s}, [x7] ++ ST1 {v10.4s}, [x9], 16 ++ ST1 {v18.4s}, [x9] ++ ST1 {v12.4s}, [x10], 16 ++ ST1 {v20.4s}, [x10] ++ ST1 {v14.4s}, [x8], 16 ++ ST1 {v22.4s}, [x8] ++ ST1 {v9.4s}, [x12], 16 ++ ST1 {v17.4s}, [x12] ++ ST1 {v11.4s}, [x13], 16 ++ ST1 {v19.4s}, [x13] ++ ST1 {v13.4s}, [x14], 16 ++ ST1 {v21.4s}, [x14] ++ ST1 {v15.4s}, [x15], 16 ++ ST1 {v23.4s}, [x15] ++ ++ LDP d9, d8, [sp, -64] ++ LDP d11, d10, [sp, -48] ++ LDP d13, d12, [sp, -32] ++ LDP d15, d14, [sp, -16] ++ ++ RET ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_3 ++_4_w32: ++ CMP x11, 4 ++ B.LO _5_w32 ++ ++ ST1 {v8.4s}, [x7], 16 ++ ST1 {v10.4s}, [x9], 16 ++ ST1 {v12.4s}, [x10], 16 ++ ST1 {v14.4s}, [x8], 16 ++ ST1 {v9.4s}, [x12], 16 ++ ST1 {v11.4s}, [x13], 16 ++ ST1 {v13.4s}, [x14], 16 ++ ST1 {v15.4s}, [x15], 16 ++ ++ SUB x11, x11, 4 ++ ++ MOV v8.16b, v16.16b ++ MOV v10.16b, v18.16b ++ MOV v12.16b, v20.16b ++ MOV v14.16b, v22.16b ++ MOV v9.16b, v17.16b ++ MOV v11.16b, v19.16b ++ MOV v13.16b, v21.16b ++ MOV v15.16b, v23.16b ++ ++_5_w32: ++ CMP x11, 2 ++ B.LO _6_w32 ++ ++ ST1 {v8.2s}, [x7], 8 ++ ST1 {v10.2s}, [x9], 8 ++ ST1 {v12.2s}, [x10], 8 ++ ST1 {v14.2s}, [x8], 8 ++ ST1 {v9.2s}, [x12], 8 ++ ST1 {v11.2s}, [x13], 8 ++ ST1 {v13.2s}, [x14], 8 ++ ST1 {v15.2s}, [x15], 8 ++ ++ SUB x11, x11, 2 ++ ++ EXT v8.16b, v8.16b, v8.16b, 8 ++ EXT v10.16b, v10.16b, v10.16b, 8 ++ EXT v12.16b, v12.16b, v12.16b, 8 ++ EXT v14.16b, v14.16b, v14.16b, 8 ++ EXT v9.16b, v9.16b, v9.16b, 8 ++ EXT v11.16b, v11.16b, v11.16b, 8 ++ EXT v13.16b, v13.16b, v13.16b, 8 ++ EXT v15.16b, v15.16b, v15.16b, 8 ++ ++_6_w32: ++ CMP x11, 1 ++ B.LO _7_w32 ++ ++ ST1 {v8.s}[0], [x7] ++ ST1 {v10.s}[0], [x9] ++ ST1 {v12.s}[0], [x10] ++ ST1 {v14.s}[0], [x8] ++ ST1 {v9.s}[0], [x12] ++ ST1 {v11.s}[0], [x13] ++ ST1 {v13.s}[0], [x14] ++ ST1 {v15.s}[0], [x15] ++ ++_7_w32: ++ LDP d9, d8, [sp, -64] ++ LDP d11, d10, [sp, -48] ++ LDP d13, d12, [sp, -32] ++ LDP d15, d14, [sp, -16] ++ ++ RET ++ ++END_FUNCTION pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w32__aarch64_neon + +- STP d15, d14, [sp, -16] +- STP d13, d12, [sp, -32] +- STP d11, d10, [sp, -48] +- STP d9, d8, [sp, -64] +- +- MOV x11, x1 +- # Load output channel index +- LDR x10, [sp, 8] +- # Load params +- LDR x8, [sp, 16] +- +- # Load a_zero_point +- LD1R {v24.8b}, [x8] +- ADD x8, x8, 8 +- +- # Load pointer to per channel zero points array +- LDR x17, [x8], 8 +- +- # Load pointer to per channel multiplier +- LDR x13, [x8] +- +- # Add offset to the base 
pointer +- ADD x17, x17, x10 +- # Mul by 4 to get byte offset for multiplier +- LSL x10, x10, 2 +- # Add offset to the base pointer for multiplier +- ADD x13, x13, x10 +- +- # Load b_zero_point +- LD1 {v25.8b}, [x17] +- # Load multiplier c0123 +- LD1 {v26.4s}, [x13], 16 +- # Load multiplier c4567 +- LD1 {v30.4s}, [x13] +- +- EOR x12, x12, x12 +- EOR x13, x13, x13 +- +- CMP x1, 1 +- B.LO 7f +- +-#ifndef IGNORE_CODE_ALIGN_DIRECTIVES +- .p2align 5 +-#endif +-0: +- # v8 := zero +- EOR v8.16b, v8.16b, v8.16b +- # v9 := zero +- EOR v9.16b, v9.16b, v9.16b +- +- DUP v29.8b, v25.b[0] +- # w12 = w_row_ptr[n], x13 = w_row_ptr[n+1] +- # x4 = x4 + 4 to point to next n +- LDR w12, [x4], #4 +- LDR w13, [x4] +- # x10 = temp_packed_w = packed_w + w_row_ptr[n] * 4 +- # This points to the first block of nonzero value +- # for the nth row. +- ADD x10, x3, x12, LSL #2 +- # x9 = temp_w_block_ids_ptr = w_block_ids_ptr (x5) + w_row_ptr[n] +- # LSL2 because each element is 4 bytes +- # This points to the block id of the first block +- # It should contain x13 - x12 number of block ids +- ADD x9, x5, x12, LSL #2 +- # x8 = num_blocks that needs to be processed +- SUB x8, x13, x12 +- SUBS x8, x8, 2 +- B.LO 1f +- +-k_loop: +- // b0-7 (channel 0) +- LD1 {v10.8b}, [x10], 8 +- USUBL v10.8h, v10.8b, v29.8b +- +- #x12 = block_id_ptr[0] +- #x13 = block_id_ptr[1] +- LDR w12, [x9], #4 +- LDR w13, [x9], #4 +- # Add offset to x2 +- # Shift by 5 because each packed block is a block of 8x4 +- # which 32 bytes +- ADD x16, x2, x12, LSL #5 +- ADD x17, x2, x13, LSL #5 +- +- LD1 {v0.8b}, [x16], 8 +- LD1 {v1.8b}, [x16], 8 +- LD1 {v2.8b}, [x16], 8 +- LD1 {v3.8b}, [x16] +- LD1 {v4.8b}, [x17], 8 +- LD1 {v5.8b}, [x17], 8 +- LD1 {v6.8b}, [x17], 8 +- LD1 {v7.8b}, [x17] +- +- USUBL v0.8h, v0.8b, v24.8b +- USUBL v1.8h, v1.8b, v24.8b +- USUBL v2.8h, v2.8b, v24.8b +- USUBL v3.8h, v3.8b, v24.8b +- USUBL v4.8h, v4.8b, v24.8b +- USUBL v5.8h, v5.8b, v24.8b +- USUBL v6.8h, v6.8b, v24.8b +- USUBL v7.8h, v7.8b, v24.8b +- +- SMLAL v8.4s, v0.4h, v10.h[0] +- SMLAL2 v9.4s, v0.8h, v10.h[0] +- SMLAL v8.4s, v1.4h, v10.h[1] +- SMLAL2 v9.4s, v1.8h, v10.h[1] +- SMLAL v8.4s, v2.4h, v10.h[2] +- SMLAL2 v9.4s, v2.8h, v10.h[2] +- SMLAL v8.4s, v3.4h, v10.h[3] +- SMLAL2 v9.4s, v3.8h, v10.h[3] +- SMLAL v8.4s, v4.4h, v10.h[4] +- SMLAL2 v9.4s, v4.8h, v10.h[4] +- SMLAL v8.4s, v5.4h, v10.h[5] +- SMLAL2 v9.4s, v5.8h, v10.h[5] +- SMLAL v8.4s, v6.4h, v10.h[6] +- SMLAL2 v9.4s, v6.8h, v10.h[6] +- SUBS x8, x8, 2 +- SMLAL v8.4s, v7.4h, v10.h[7] +- SMLAL2 v9.4s, v7.8h, v10.h[7] +- +- +- B.HS k_loop +- +-1: +- CMP x8, -2 +- B.EQ 2f +- +- // b0-7 (channel 0) +- LD1R {v10.4s}, [x10] +- USUBL v10.8h, v10.8b, v29.8b +- +- #x12 = block_id_ptr[0] +- LDR w12, [x9] +- # Add offset to x2 +- # Shift by 5 because each packed block is a block of 8x4 +- # which 32 bytes +- ADD x16, x2, x12, LSL #5 +- +- LD1 {v0.8b}, [x16], 8 +- LD1 {v1.8b}, [x16], 8 +- LD1 {v2.8b}, [x16], 8 +- LD1 {v3.8b}, [x16] +- +- USUBL v0.8h, v0.8b, v24.8b +- USUBL v1.8h, v1.8b, v24.8b +- USUBL v2.8h, v2.8b, v24.8b +- USUBL v3.8h, v3.8b, v24.8b +- +- SMLAL v8.4s, v0.4h, v10.h[0] +- SMLAL2 v9.4s, v0.8h, v10.h[0] +- SMLAL v8.4s, v1.4h, v10.h[1] +- SMLAL2 v9.4s, v1.8h, v10.h[1] +- SMLAL v8.4s, v2.4h, v10.h[2] +- SMLAL2 v9.4s, v2.8h, v10.h[2] +- SMLAL v8.4s, v3.4h, v10.h[3] +- SMLAL2 v9.4s, v3.8h, v10.h[3] +- +-#ifndef IGNORE_CODE_ALIGN_DIRECTIVES +- .p2align 4 +-#endif +-2: +- /* +- # Store result on stack +- +- # -64 because all d8-d15 are on stack +- # + 256 bytes of buffer when nr = 1 +- # 256 because we are doing 8x8 block 
with each value being 4 bytes +- # Thus 64 * 4 = 256 +- # 256 + 64 = 320 +- # This is needed because after processing all nrs we will +- # load 256 bytes from stack. +- # Thus we will load accumulators back in v8, v9, v10, v11, v12, v13, v14, v15 +- # v16, v17, v18, v19, v20, v21, v22, v23 +- # When nr < 8, say nr = 1, extra v values will be fetched from stack which may overlap +- # with other parts of stack storing local variables. To avoid that we just +- # create a buffer of 256 bytes inbetween to make sure pointer increment +- # never produces address that is beyond the stack frame of this function. +- */ +- SUB x9, sp, 320 +- /* +- # Each iteration produce 8 values each of 4 bytes +- # Thus 8 x 4 = 32 bytes 2^5 +- # In this implementation, first value will be stored at +- # 1st value: sp - 64 - r1 * 32 +- # 2nd value: sp - 12 - (r1 - 1) * 32 +- # and so on. +- */ +- SUB x9, x9, x1, LSL #5 +- ST1 {v8.4s}, [x9], 16 +- ST1 {v9.4s}, [x9] +- +- # Shift zero point vector by 8 to load +- # zero point of the next channel +- SRI v25.2d, v25.2d, #8 +- # Check if nr >=1 +- SUBS x1, x1, 1 +- BHI 0b +-3: +- # First load all the accumulators from stack +- # Load nr +- SUB x9, sp, 320 +- SUB x9, x9, x11, LSL #5 +- # Now load v8-v15 +- # This is 8x4 block (nrxmr) +- # We will transpose this to 4x8 (mrxnr) +- # v8, v9 : x00, x10, x20, x30; x40, x50, x60, x70 +- # v10, v11 : x01, x11, x21, x31; x41, x51, x61, x71 +- # v12, v13 : x02, x12, x22, x32; x42, x52, x62, x72 +- # v14, v15 : x03, x13, x23, x33; x43, x53, x63, x73 +- # +- # v16, v17 : x04, x14, x24, x34; x44, x54, x64, x74 +- # v18, v19 : x05, x15, x25, x35; x45, x55, x65, x75 +- # v20, v21 : x06, x16, x26, x36; x46, x56, x66, x76 +- # v22, v23 : x07, x17, x27, x37; x47, x57, x67, x77 +- LD1 {v8.4s}, [x9], 16 +- LD1 {v9.4s}, [x9], 16 +- LD1 {v10.4s}, [x9], 16 +- LD1 {v11.4s}, [x9], 16 +- LD1 {v12.4s}, [x9], 16 +- LD1 {v13.4s}, [x9], 16 +- LD1 {v14.4s}, [x9], 16 +- LD1 {v15.4s}, [x9], 16 +- LD1 {v16.4s}, [x9], 16 +- LD1 {v17.4s}, [x9], 16 +- LD1 {v18.4s}, [x9], 16 +- LD1 {v19.4s}, [x9], 16 +- LD1 {v20.4s}, [x9], 16 +- LD1 {v21.4s}, [x9], 16 +- LD1 {v22.4s}, [x9], 16 +- LD1 {v23.4s}, [x9] +- +- # We can tranpose one 4x4 block using macro +- # TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 +- # After this we have +- # v8 : x00, x01, x02, x03 +- # v10 : x10, x11, x12, x13 +- # v12 : x20, x21, x22, x23 +- # v14 : x30, x31, x32, x33 +- # Then using +- # TRANSPOSE_4X4_S32 v16, v18, v20, v22, v4, v5, v6, v7 +- # We get +- # v16 : x04, x05, x06, x07 +- # v18 : x14, x15, x16, x17 +- # v20 : x24, x25, x26, x27 +- # v22 : x34, x35, x36, x37 +- # Similarly we can transpose other two 4x4 blocks and we get +- # tranposed 8x8 +- +- TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 +- TRANSPOSE_4X4_S32 v16, v18, v20, v22, v4, v5, v6, v7 +- TRANSPOSE_4X4_S32 v9, v11, v13, v15, v0, v1, v2, v3 +- TRANSPOSE_4X4_S32 v17, v19, v21, v23, v4, v5, v6, v7 +- +- # row 0: v8, v16 +- # row 1: v10, v18 +- # row 2: v12, v20 +- # row 3: v14, v22 +- # row 4: v9, v17 +- # row 5: v11, v19 +- # row 6: v13, v21 +- # row 7: v15, v23 +- +- # Load c_stride & params +- LDR x16, [sp] +- LSL x16, x16, 2 +- LD1 {v24.4s}, [x6], 16 +- LD1 {v25.4s}, [x6] +- +- SCVTF v8.4s, v8.4s +- SCVTF v9.4s, v9.4s +- SCVTF v10.4s, v10.4s +- SCVTF v11.4s, v11.4s +- SCVTF v12.4s, v12.4s +- SCVTF v13.4s, v13.4s +- SCVTF v14.4s, v14.4s +- SCVTF v15.4s, v15.4s +- SCVTF v16.4s, v16.4s +- SCVTF v17.4s, v17.4s +- SCVTF v18.4s, v18.4s +- SCVTF v19.4s, v19.4s +- SCVTF v20.4s, v20.4s +- SCVTF v21.4s, v21.4s +- 
SCVTF v22.4s, v22.4s +- SCVTF v23.4s, v23.4s +- +- FMUL v8.4s, v8.4s, v26.4s +- FMUL v16.4s, v16.4s, v30.4s +- FMUL v10.4s, v10.4s, v26.4s +- FMUL v18.4s, v18.4s, v30.4s +- FMUL v12.4s, v12.4s, v26.4s +- FMUL v20.4s, v20.4s, v30.4s +- FMUL v14.4s, v14.4s, v26.4s +- FMUL v22.4s, v22.4s, v30.4s +- FMUL v9.4s, v9.4s, v26.4s +- FMUL v17.4s, v17.4s, v30.4s +- FMUL v11.4s, v11.4s, v26.4s +- FMUL v19.4s, v19.4s, v30.4s +- FMUL v13.4s, v13.4s, v26.4s +- FMUL v21.4s, v21.4s, v30.4s +- FMUL v15.4s, v15.4s, v26.4s +- FMUL v23.4s, v23.4s, v30.4s +- +- FADD v8.4s, v8.4s, v24.4s +- FADD v16.4s, v16.4s, v25.4s +- FADD v10.4s, v10.4s, v24.4s +- FADD v18.4s, v18.4s, v25.4s +- FADD v12.4s, v12.4s, v24.4s +- FADD v20.4s, v20.4s, v25.4s +- FADD v14.4s, v14.4s, v24.4s +- FADD v22.4s, v22.4s, v25.4s +- FADD v9.4s, v9.4s, v24.4s +- FADD v17.4s, v17.4s, v25.4s +- FADD v11.4s, v11.4s, v24.4s +- FADD v19.4s, v19.4s, v25.4s +- FADD v13.4s, v13.4s, v24.4s +- FADD v21.4s, v21.4s, v25.4s +- FADD v15.4s, v15.4s, v24.4s +- FADD v23.4s, v23.4s, v25.4s +- +- // Compute c0-c7 +- +- ADD x9, x7, x16 +- CMP x0, 2 +- CSEL x9, x7, x9, LO +- +- ADD x10, x9, x16 +- CSEL x10, x9, x10, LS +- +- ADD x8, x10, x16 +- CMP x0, 4 +- CSEL x8, x10, x8, LO +- +- ADD x12, x8, x16 +- CSEL x12, x8, x12, LS +- +- ADD x13, x12, x16 +- CMP x0, 6 +- CSEL x13, x12, x13, LO +- +- ADD x14, x13, x16 +- CSEL x14, x13, x14, LS +- +- ADD x15, x14, x16 +- CMP x0, 8 +- CSEL x15, x14, x15, NE +- +- CMP x11, 8 +- B.NE 4f +- +- ST1 {v8.4s}, [x7], 16 +- ST1 {v16.4s}, [x7] +- ST1 {v10.4s}, [x9], 16 +- ST1 {v18.4s}, [x9] +- ST1 {v12.4s}, [x10], 16 +- ST1 {v20.4s}, [x10] +- ST1 {v14.4s}, [x8], 16 +- ST1 {v22.4s}, [x8] +- ST1 {v9.4s}, [x12], 16 +- ST1 {v17.4s}, [x12] +- ST1 {v11.4s}, [x13], 16 +- ST1 {v19.4s}, [x13] +- ST1 {v13.4s}, [x14], 16 +- ST1 {v21.4s}, [x14] +- ST1 {v15.4s}, [x15], 16 +- ST1 {v23.4s}, [x15] +- +- LDP d9, d8, [sp, -64] +- LDP d11, d10, [sp, -48] +- LDP d13, d12, [sp, -32] +- LDP d15, d14, [sp, -16] +- +- RET +- +-#ifndef IGNORE_CODE_ALIGN_DIRECTIVES +- .p2align 3 +-#endif +-4: +- CMP x11, 4 +- B.LO 5f +- +- ST1 {v8.4s}, [x7], 16 +- ST1 {v10.4s}, [x9], 16 +- ST1 {v12.4s}, [x10], 16 +- ST1 {v14.4s}, [x8], 16 +- ST1 {v9.4s}, [x12], 16 +- ST1 {v11.4s}, [x13], 16 +- ST1 {v13.4s}, [x14], 16 +- ST1 {v15.4s}, [x15], 16 +- +- SUB x11, x11, 4 +- +- MOV v8.16b, v16.16b +- MOV v10.16b, v18.16b +- MOV v12.16b, v20.16b +- MOV v14.16b, v22.16b +- MOV v9.16b, v17.16b +- MOV v11.16b, v19.16b +- MOV v13.16b, v21.16b +- MOV v15.16b, v23.16b +- +-5: +- CMP x11, 2 +- B.LO 6f +- +- ST1 {v8.2s}, [x7], 8 +- ST1 {v10.2s}, [x9], 8 +- ST1 {v12.2s}, [x10], 8 +- ST1 {v14.2s}, [x8], 8 +- ST1 {v9.2s}, [x12], 8 +- ST1 {v11.2s}, [x13], 8 +- ST1 {v13.2s}, [x14], 8 +- ST1 {v15.2s}, [x15], 8 +- +- SUB x11, x11, 2 +- +- EXT v8.16b, v8.16b, v8.16b, 8 +- EXT v10.16b, v10.16b, v10.16b, 8 +- EXT v12.16b, v12.16b, v12.16b, 8 +- EXT v14.16b, v14.16b, v14.16b, 8 +- EXT v9.16b, v9.16b, v9.16b, 8 +- EXT v11.16b, v11.16b, v11.16b, 8 +- EXT v13.16b, v13.16b, v13.16b, 8 +- EXT v15.16b, v15.16b, v15.16b, 8 +- +-6: +- CMP x11, 1 +- B.LO 7f +- +- ST1 {v8.s}[0], [x7] +- ST1 {v10.s}[0], [x9] +- ST1 {v12.s}[0], [x10] +- ST1 {v14.s}[0], [x8] +- ST1 {v9.s}[0], [x12] +- ST1 {v11.s}[0], [x13] +- ST1 {v13.s}[0], [x14] +- ST1 {v15.s}[0], [x15] +- +-7: +- LDP d9, d8, [sp, -64] +- LDP d11, d10, [sp, -48] +- LDP d13, d12, [sp, -32] +- LDP d15, d14, [sp, -16] +- +- RET ++# void pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w16__aarch64_neon( ++# size_t mr, ++# size_t nr, ++# const uint8_t* a_packed, ++# 
const uint8_t* packed_w, ++# const uint16_t* w_row_ptr, ++# const uint16_t* w_block_ids_ptr, ++# const float* b, ++# uint8_t* restrict c, ++# size_t c_stride, ++# size_t output_channel_index, ++# const union pytorch_qnnp_conv_dynamic_quantization_params quantization_params[restrict static 1]) ++BEGIN_FUNCTION pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w16__aarch64_neon ++ ++ STP d15, d14, [sp, -16] ++ STP d13, d12, [sp, -32] ++ STP d11, d10, [sp, -48] ++ STP d9, d8, [sp, -64] ++ ++ MOV x11, x1 ++ /* Load output channel index */ ++ LDR x10, [sp, 8] ++ /* Load params */ ++ LDR x8, [sp, 16] ++ ++ /* Load a_zero_point */ ++ LD1R {v24.8b}, [x8] ++ ADD x8, x8, 8 ++ ++ /* Load pointer to per channel zero points array */ ++ LDR x17, [x8], 8 ++ ++ /* Load pointer to per channel multiplier */ ++ LDR x13, [x8] ++ ++ /* Add offset to the base pointer */ ++ ADD x17, x17, x10 ++ /* Mul by 4 to get byte offset for multiplier */ ++ LSL x10, x10, 2 ++ /* Add offset to the base pointer for multiplier */ ++ ADD x13, x13, x10 ++ ++ /* Load b_zero_point */ ++ LD1 {v25.8b}, [x17] ++ /* Load multiplier c0123 */ ++ LD1 {v26.4s}, [x13], 16 ++ /* Load multiplier c4567 */ ++ LD1 {v30.4s}, [x13] ++ ++ EOR x12, x12, x12 ++ EOR x13, x13, x13 ++ ++ CMP x1, 1 ++ B.LO _7_w16 ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_5 ++_0_w16: ++ /* v8 := zero */ ++ EOR v8.16b, v8.16b, v8.16b ++ /* v9 := zero */ ++ EOR v9.16b, v9.16b, v9.16b ++ ++ DUP v29.8b, v25.b[0] ++ /* w12 = w_row_ptr[n], x13 = w_row_ptr[n+1] */ ++ /* x4 = x4 + #2 to point to next n */ ++ LDRH w12, [x4], #2 ++ LDRH w13, [x4] ++ /* x10 = temp_packed_w = packed_w + w_row_ptr[n] * 4 */ ++ /* This points to the first block of nonzero value */ ++ /* for the nth row. */ ++ ADD x10, x3, x12, LSL #2 ++ /* x9 = temp_w_block_ids_ptr = w_block_ids_ptr (x5) + w_row_ptr[n] */ ++ /* LSL for when elements are >1 byte */ ++ /* (4 bytes: LSL #2, 2 bytes: LSL #1, 1 byte: LSL #0) */ ++ /* This points to the block id of the first block */ ++ /* It should contain x13 - x12 number of block ids */ ++ ADD x9, x5, x12, LSL #1 ++ /* x8 = num_blocks that needs to be processed */ ++ SUB x8, x13, x12 ++ SUBS x8, x8, 2 ++ B.LO _1_w16 ++ ++k_loop_w16: ++ /* b0-7 (channel 0) */ ++ LD1 {v10.8b}, [x10], 8 ++ USUBL v10.8h, v10.8b, v29.8b ++ ++ /* x12 = block_id_ptr[0] */ ++ /* x13 = block_id_ptr[1] */ ++ LDRH w12, [x9], #2 ++ LDRH w13, [x9], #2 ++ /* Add offset to x2 */ ++ /* Shift by 5 because each packed block is a block of 8x4 */ ++ /* which 32 bytes */ ++ ADD x16, x2, x12, LSL #5 ++ ADD x17, x2, x13, LSL #5 ++ ++ LD1 {v0.8b}, [x16], 8 ++ LD1 {v1.8b}, [x16], 8 ++ LD1 {v2.8b}, [x16], 8 ++ LD1 {v3.8b}, [x16] ++ LD1 {v4.8b}, [x17], 8 ++ LD1 {v5.8b}, [x17], 8 ++ LD1 {v6.8b}, [x17], 8 ++ LD1 {v7.8b}, [x17] ++ ++ USUBL v0.8h, v0.8b, v24.8b ++ USUBL v1.8h, v1.8b, v24.8b ++ USUBL v2.8h, v2.8b, v24.8b ++ USUBL v3.8h, v3.8b, v24.8b ++ USUBL v4.8h, v4.8b, v24.8b ++ USUBL v5.8h, v5.8b, v24.8b ++ USUBL v6.8h, v6.8b, v24.8b ++ USUBL v7.8h, v7.8b, v24.8b ++ ++ SMLAL v8.4s, v0.4h, v10.h[0] ++ SMLAL2 v9.4s, v0.8h, v10.h[0] ++ SMLAL v8.4s, v1.4h, v10.h[1] ++ SMLAL2 v9.4s, v1.8h, v10.h[1] ++ SMLAL v8.4s, v2.4h, v10.h[2] ++ SMLAL2 v9.4s, v2.8h, v10.h[2] ++ SMLAL v8.4s, v3.4h, v10.h[3] ++ SMLAL2 v9.4s, v3.8h, v10.h[3] ++ SMLAL v8.4s, v4.4h, v10.h[4] ++ SMLAL2 v9.4s, v4.8h, v10.h[4] ++ SMLAL v8.4s, v5.4h, v10.h[5] ++ SMLAL2 v9.4s, v5.8h, v10.h[5] ++ SMLAL v8.4s, v6.4h, v10.h[6] ++ SMLAL2 v9.4s, v6.8h, v10.h[6] ++ SUBS x8, x8, 2 ++ SMLAL v8.4s, v7.4h, v10.h[7] ++ SMLAL2 v9.4s, v7.8h, v10.h[7] ++ ++ ++ B.HS 
k_loop_w16 ++ ++_1_w16: ++ CMP x8, -2 ++ B.EQ _2_w16 ++ ++ /* b0-7 (channel 0) */ ++ LD1R {v10.4s}, [x10] ++ USUBL v10.8h, v10.8b, v29.8b ++ ++ /* x12 = block_id_ptr[0] */ ++ LDRH w12, [x9] ++ /* Add offset to x2 */ ++ /* Shift by 5 because each packed block is a block of 8x4 */ ++ /* which 32 bytes */ ++ ADD x16, x2, x12, LSL #5 ++ ++ LD1 {v0.8b}, [x16], 8 ++ LD1 {v1.8b}, [x16], 8 ++ LD1 {v2.8b}, [x16], 8 ++ LD1 {v3.8b}, [x16] ++ ++ USUBL v0.8h, v0.8b, v24.8b ++ USUBL v1.8h, v1.8b, v24.8b ++ USUBL v2.8h, v2.8b, v24.8b ++ USUBL v3.8h, v3.8b, v24.8b ++ ++ SMLAL v8.4s, v0.4h, v10.h[0] ++ SMLAL2 v9.4s, v0.8h, v10.h[0] ++ SMLAL v8.4s, v1.4h, v10.h[1] ++ SMLAL2 v9.4s, v1.8h, v10.h[1] ++ SMLAL v8.4s, v2.4h, v10.h[2] ++ SMLAL2 v9.4s, v2.8h, v10.h[2] ++ SMLAL v8.4s, v3.4h, v10.h[3] ++ SMLAL2 v9.4s, v3.8h, v10.h[3] ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_4 ++_2_w16: ++ /* Store result on stack */ ++ ++ /* -64 because all d8-d15 are on stack */ ++ /* + 256 bytes of buffer when nr = 1 */ ++ /* 256 because we are doing 8x8 block with each value being 4 bytes */ ++ /* Thus 64 * 4 = 256 */ ++ /* 256 + 64 = 320 */ ++ /* This is needed because after processing all nrs we will */ ++ /* load 256 bytes from stack. */ ++ /* Thus we will load accumulators back in v8, v9, v10, v11, v12, v13, v14, v15 */ ++ /* v16, v17, v18, v19, v20, v21, v22, v23 */ ++ /* When nr < 8, say nr = 1, extra v values will be fetched from stack which may overlap */ ++ /* with other parts of stack storing local variables. To avoid that we just */ ++ /* create a buffer of 256 bytes inbetween to make sure pointer increment */ ++ /* never produces address that is beyond the stack frame of this function. */ ++ SUB x9, sp, 320 ++ /* Each iteration produce 8 values each of 4 bytes */ ++ /* Thus 8 x 4 = 32 bytes 2^5 */ ++ /* In this implementation, first value will be stored at */ ++ /* 1st value: sp - 64 - r1 * 32 */ ++ /* 2nd value: sp - 12 - (r1 - 1) * 32 */ ++ /* and so on. 
*/ ++ SUB x9, x9, x1, LSL #5 ++ ST1 {v8.4s}, [x9], 16 ++ ST1 {v9.4s}, [x9] ++ ++ /* Shift zero point vector by 8 to load */ ++ /* zero point of the next channel */ ++ SRI v25.2d, v25.2d, #8 ++ /* Check if nr >=1 */ ++ SUBS x1, x1, 1 ++ BHI _0_w16 ++_3_w16: ++ /* First load all the accumulators from stack */ ++ /* Load nr */ ++ SUB x9, sp, 320 ++ SUB x9, x9, x11, LSL #5 ++ /* Now load v8-v15 */ ++ /* This is 8x4 block (nrxmr) */ ++ /* We will transpose this to 4x8 (mrxnr) */ ++ /* v8, v9 : x00, x10, x20, x30; x40, x50, x60, x70 */ ++ /* v10, v11 : x01, x11, x21, x31; x41, x51, x61, x71 */ ++ /* v12, v13 : x02, x12, x22, x32; x42, x52, x62, x72 */ ++ /* v14, v15 : x03, x13, x23, x33; x43, x53, x63, x73 */ ++ /* */ ++ /* v16, v17 : x04, x14, x24, x34; x44, x54, x64, x74 */ ++ /* v18, v19 : x05, x15, x25, x35; x45, x55, x65, x75 */ ++ /* v20, v21 : x06, x16, x26, x36; x46, x56, x66, x76 */ ++ /* v22, v23 : x07, x17, x27, x37; x47, x57, x67, x77 */ ++ LD1 {v8.4s}, [x9], 16 ++ LD1 {v9.4s}, [x9], 16 ++ LD1 {v10.4s}, [x9], 16 ++ LD1 {v11.4s}, [x9], 16 ++ LD1 {v12.4s}, [x9], 16 ++ LD1 {v13.4s}, [x9], 16 ++ LD1 {v14.4s}, [x9], 16 ++ LD1 {v15.4s}, [x9], 16 ++ LD1 {v16.4s}, [x9], 16 ++ LD1 {v17.4s}, [x9], 16 ++ LD1 {v18.4s}, [x9], 16 ++ LD1 {v19.4s}, [x9], 16 ++ LD1 {v20.4s}, [x9], 16 ++ LD1 {v21.4s}, [x9], 16 ++ LD1 {v22.4s}, [x9], 16 ++ LD1 {v23.4s}, [x9] ++ ++ /* We can tranpose one 4x4 block using macro */ ++ /* TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 */ ++ /* After this we have */ ++ /* v8 : x00, x01, x02, x03 */ ++ /* v10 : x10, x11, x12, x13 */ ++ /* v12 : x20, x21, x22, x23 */ ++ /* v14 : x30, x31, x32, x33 */ ++ /* Then using */ ++ /* TRANSPOSE_4X4_S32 v16, v18, v20, v22, v4, v5, v6, v7 */ ++ /* We get */ ++ /* v16 : x04, x05, x06, x07 */ ++ /* v18 : x14, x15, x16, x17 */ ++ /* v20 : x24, x25, x26, x27 */ ++ /* v22 : x34, x35, x36, x37 */ ++ /* Similarly we can transpose other two 4x4 blocks and we get */ ++ /* tranposed 8x8 */ ++ ++ TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 ++ TRANSPOSE_4X4_S32 v16, v18, v20, v22, v4, v5, v6, v7 ++ TRANSPOSE_4X4_S32 v9, v11, v13, v15, v0, v1, v2, v3 ++ TRANSPOSE_4X4_S32 v17, v19, v21, v23, v4, v5, v6, v7 ++ ++ /* row 0: v8, v16 */ ++ /* row 1: v10, v18 */ ++ /* row 2: v12, v20 */ ++ /* row 3: v14, v22 */ ++ /* row 4: v9, v17 */ ++ /* row 5: v11, v19 */ ++ /* row 6: v13, v21 */ ++ /* row 7: v15, v23 */ ++ ++ /* Load c_stride & params */ ++ LDR x16, [sp] ++ LSL x16, x16, 2 ++ LD1 {v24.4s}, [x6], 16 ++ LD1 {v25.4s}, [x6] ++ ++ SCVTF v8.4s, v8.4s ++ SCVTF v9.4s, v9.4s ++ SCVTF v10.4s, v10.4s ++ SCVTF v11.4s, v11.4s ++ SCVTF v12.4s, v12.4s ++ SCVTF v13.4s, v13.4s ++ SCVTF v14.4s, v14.4s ++ SCVTF v15.4s, v15.4s ++ SCVTF v16.4s, v16.4s ++ SCVTF v17.4s, v17.4s ++ SCVTF v18.4s, v18.4s ++ SCVTF v19.4s, v19.4s ++ SCVTF v20.4s, v20.4s ++ SCVTF v21.4s, v21.4s ++ SCVTF v22.4s, v22.4s ++ SCVTF v23.4s, v23.4s ++ ++ FMUL v8.4s, v8.4s, v26.4s ++ FMUL v16.4s, v16.4s, v30.4s ++ FMUL v10.4s, v10.4s, v26.4s ++ FMUL v18.4s, v18.4s, v30.4s ++ FMUL v12.4s, v12.4s, v26.4s ++ FMUL v20.4s, v20.4s, v30.4s ++ FMUL v14.4s, v14.4s, v26.4s ++ FMUL v22.4s, v22.4s, v30.4s ++ FMUL v9.4s, v9.4s, v26.4s ++ FMUL v17.4s, v17.4s, v30.4s ++ FMUL v11.4s, v11.4s, v26.4s ++ FMUL v19.4s, v19.4s, v30.4s ++ FMUL v13.4s, v13.4s, v26.4s ++ FMUL v21.4s, v21.4s, v30.4s ++ FMUL v15.4s, v15.4s, v26.4s ++ FMUL v23.4s, v23.4s, v30.4s ++ ++ FADD v8.4s, v8.4s, v24.4s ++ FADD v16.4s, v16.4s, v25.4s ++ FADD v10.4s, v10.4s, v24.4s ++ FADD v18.4s, v18.4s, v25.4s ++ FADD v12.4s, v12.4s, v24.4s ++ 
FADD v20.4s, v20.4s, v25.4s ++ FADD v14.4s, v14.4s, v24.4s ++ FADD v22.4s, v22.4s, v25.4s ++ FADD v9.4s, v9.4s, v24.4s ++ FADD v17.4s, v17.4s, v25.4s ++ FADD v11.4s, v11.4s, v24.4s ++ FADD v19.4s, v19.4s, v25.4s ++ FADD v13.4s, v13.4s, v24.4s ++ FADD v21.4s, v21.4s, v25.4s ++ FADD v15.4s, v15.4s, v24.4s ++ FADD v23.4s, v23.4s, v25.4s ++ ++ /* Compute c0-c7 */ ++ ++ ADD x9, x7, x16 ++ CMP x0, 2 ++ CSEL x9, x7, x9, LO ++ ++ ADD x10, x9, x16 ++ CSEL x10, x9, x10, LS ++ ++ ADD x8, x10, x16 ++ CMP x0, 4 ++ CSEL x8, x10, x8, LO ++ ++ ADD x12, x8, x16 ++ CSEL x12, x8, x12, LS ++ ++ ADD x13, x12, x16 ++ CMP x0, 6 ++ CSEL x13, x12, x13, LO ++ ++ ADD x14, x13, x16 ++ CSEL x14, x13, x14, LS ++ ++ ADD x15, x14, x16 ++ CMP x0, 8 ++ CSEL x15, x14, x15, NE ++ ++ CMP x11, 8 ++ B.NE _4_w16 ++ ++ ST1 {v8.4s}, [x7], 16 ++ ST1 {v16.4s}, [x7] ++ ST1 {v10.4s}, [x9], 16 ++ ST1 {v18.4s}, [x9] ++ ST1 {v12.4s}, [x10], 16 ++ ST1 {v20.4s}, [x10] ++ ST1 {v14.4s}, [x8], 16 ++ ST1 {v22.4s}, [x8] ++ ST1 {v9.4s}, [x12], 16 ++ ST1 {v17.4s}, [x12] ++ ST1 {v11.4s}, [x13], 16 ++ ST1 {v19.4s}, [x13] ++ ST1 {v13.4s}, [x14], 16 ++ ST1 {v21.4s}, [x14] ++ ST1 {v15.4s}, [x15], 16 ++ ST1 {v23.4s}, [x15] ++ ++ LDP d9, d8, [sp, -64] ++ LDP d11, d10, [sp, -48] ++ LDP d13, d12, [sp, -32] ++ LDP d15, d14, [sp, -16] ++ ++ RET ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_3 ++_4_w16: ++ CMP x11, 4 ++ B.LO _5_w16 ++ ++ ST1 {v8.4s}, [x7], 16 ++ ST1 {v10.4s}, [x9], 16 ++ ST1 {v12.4s}, [x10], 16 ++ ST1 {v14.4s}, [x8], 16 ++ ST1 {v9.4s}, [x12], 16 ++ ST1 {v11.4s}, [x13], 16 ++ ST1 {v13.4s}, [x14], 16 ++ ST1 {v15.4s}, [x15], 16 ++ ++ SUB x11, x11, 4 ++ ++ MOV v8.16b, v16.16b ++ MOV v10.16b, v18.16b ++ MOV v12.16b, v20.16b ++ MOV v14.16b, v22.16b ++ MOV v9.16b, v17.16b ++ MOV v11.16b, v19.16b ++ MOV v13.16b, v21.16b ++ MOV v15.16b, v23.16b ++ ++_5_w16: ++ CMP x11, 2 ++ B.LO _6_w16 ++ ++ ST1 {v8.2s}, [x7], 8 ++ ST1 {v10.2s}, [x9], 8 ++ ST1 {v12.2s}, [x10], 8 ++ ST1 {v14.2s}, [x8], 8 ++ ST1 {v9.2s}, [x12], 8 ++ ST1 {v11.2s}, [x13], 8 ++ ST1 {v13.2s}, [x14], 8 ++ ST1 {v15.2s}, [x15], 8 ++ ++ SUB x11, x11, 2 ++ ++ EXT v8.16b, v8.16b, v8.16b, 8 ++ EXT v10.16b, v10.16b, v10.16b, 8 ++ EXT v12.16b, v12.16b, v12.16b, 8 ++ EXT v14.16b, v14.16b, v14.16b, 8 ++ EXT v9.16b, v9.16b, v9.16b, 8 ++ EXT v11.16b, v11.16b, v11.16b, 8 ++ EXT v13.16b, v13.16b, v13.16b, 8 ++ EXT v15.16b, v15.16b, v15.16b, 8 ++ ++_6_w16: ++ CMP x11, 1 ++ B.LO _7_w16 ++ ++ ST1 {v8.s}[0], [x7] ++ ST1 {v10.s}[0], [x9] ++ ST1 {v12.s}[0], [x10] ++ ST1 {v14.s}[0], [x8] ++ ST1 {v9.s}[0], [x12] ++ ST1 {v11.s}[0], [x13] ++ ST1 {v13.s}[0], [x14] ++ ST1 {v15.s}[0], [x15] ++ ++_7_w16: ++ LDP d9, d8, [sp, -64] ++ LDP d11, d10, [sp, -48] ++ LDP d13, d12, [sp, -32] ++ LDP d15, d14, [sp, -16] ++ ++ RET ++ ++END_FUNCTION pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w16__aarch64_neon + +-END_FUNCTION pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA__aarch64_neon ++# void pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w8__aarch64_neon( ++# size_t mr, ++# size_t nr, ++# const uint8_t* a_packed, ++# const uint8_t* packed_w, ++# const uint8_t* w_row_ptr, ++# const uint8_t* w_block_ids_ptr, ++# const float* b, ++# uint8_t* restrict c, ++# size_t c_stride, ++# size_t output_channel_index, ++# const union pytorch_qnnp_conv_dynamic_quantization_params quantization_params[restrict static 1]) ++BEGIN_FUNCTION pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w8__aarch64_neon ++ ++ STP d15, d14, [sp, -16] ++ STP d13, d12, [sp, -32] ++ STP d11, d10, [sp, -48] ++ STP d9, d8, [sp, -64] ++ ++ MOV x11, x1 ++ /* 
Load output channel index */ ++ LDR x10, [sp, 8] ++ /* Load params */ ++ LDR x8, [sp, 16] ++ ++ /* Load a_zero_point */ ++ LD1R {v24.8b}, [x8] ++ ADD x8, x8, 8 ++ ++ /* Load pointer to per channel zero points array */ ++ LDR x17, [x8], 8 ++ ++ /* Load pointer to per channel multiplier */ ++ LDR x13, [x8] ++ ++ /* Add offset to the base pointer */ ++ ADD x17, x17, x10 ++ /* Mul by 4 to get byte offset for multiplier */ ++ LSL x10, x10, 2 ++ /* Add offset to the base pointer for multiplier */ ++ ADD x13, x13, x10 ++ ++ /* Load b_zero_point */ ++ LD1 {v25.8b}, [x17] ++ /* Load multiplier c0123 */ ++ LD1 {v26.4s}, [x13], 16 ++ /* Load multiplier c4567 */ ++ LD1 {v30.4s}, [x13] ++ ++ EOR x12, x12, x12 ++ EOR x13, x13, x13 ++ ++ CMP x1, 1 ++ B.LO _7_w8 ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_5 ++_0_w8: ++ /* v8 := zero */ ++ EOR v8.16b, v8.16b, v8.16b ++ /* v9 := zero */ ++ EOR v9.16b, v9.16b, v9.16b ++ ++ DUP v29.8b, v25.b[0] ++ /* w12 = w_row_ptr[n], x13 = w_row_ptr[n+1] */ ++ /* x4 = x4 + #1 to point to next n */ ++ LDRB w12, [x4], #1 ++ LDRB w13, [x4] ++ /* x10 = temp_packed_w = packed_w + w_row_ptr[n] * 4 */ ++ /* This points to the first block of nonzero value */ ++ /* for the nth row. */ ++ ADD x10, x3, x12, LSL #2 ++ /* x9 = temp_w_block_ids_ptr = w_block_ids_ptr (x5) + w_row_ptr[n] */ ++ /* LSL for when elements are >1 byte */ ++ /* (4 bytes: LSL #2, 2 bytes: LSL #1, 1 byte: LSL #0) */ ++ /* This points to the block id of the first block */ ++ /* It should contain x13 - x12 number of block ids */ ++ ADD x9, x5, x12, LSL #0 ++ /* x8 = num_blocks that needs to be processed */ ++ SUB x8, x13, x12 ++ SUBS x8, x8, 2 ++ B.LO _1_w8 ++ ++k_loop_w8: ++ /* b0-7 (channel 0) */ ++ LD1 {v10.8b}, [x10], 8 ++ USUBL v10.8h, v10.8b, v29.8b ++ ++ /* x12 = block_id_ptr[0] */ ++ /* x13 = block_id_ptr[1] */ ++ LDRB w12, [x9], #1 ++ LDRB w13, [x9], #1 ++ /* Add offset to x2 */ ++ /* Shift by 5 because each packed block is a block of 8x4 */ ++ /* which 32 bytes */ ++ ADD x16, x2, x12, LSL #5 ++ ADD x17, x2, x13, LSL #5 ++ ++ LD1 {v0.8b}, [x16], 8 ++ LD1 {v1.8b}, [x16], 8 ++ LD1 {v2.8b}, [x16], 8 ++ LD1 {v3.8b}, [x16] ++ LD1 {v4.8b}, [x17], 8 ++ LD1 {v5.8b}, [x17], 8 ++ LD1 {v6.8b}, [x17], 8 ++ LD1 {v7.8b}, [x17] ++ ++ USUBL v0.8h, v0.8b, v24.8b ++ USUBL v1.8h, v1.8b, v24.8b ++ USUBL v2.8h, v2.8b, v24.8b ++ USUBL v3.8h, v3.8b, v24.8b ++ USUBL v4.8h, v4.8b, v24.8b ++ USUBL v5.8h, v5.8b, v24.8b ++ USUBL v6.8h, v6.8b, v24.8b ++ USUBL v7.8h, v7.8b, v24.8b ++ ++ SMLAL v8.4s, v0.4h, v10.h[0] ++ SMLAL2 v9.4s, v0.8h, v10.h[0] ++ SMLAL v8.4s, v1.4h, v10.h[1] ++ SMLAL2 v9.4s, v1.8h, v10.h[1] ++ SMLAL v8.4s, v2.4h, v10.h[2] ++ SMLAL2 v9.4s, v2.8h, v10.h[2] ++ SMLAL v8.4s, v3.4h, v10.h[3] ++ SMLAL2 v9.4s, v3.8h, v10.h[3] ++ SMLAL v8.4s, v4.4h, v10.h[4] ++ SMLAL2 v9.4s, v4.8h, v10.h[4] ++ SMLAL v8.4s, v5.4h, v10.h[5] ++ SMLAL2 v9.4s, v5.8h, v10.h[5] ++ SMLAL v8.4s, v6.4h, v10.h[6] ++ SMLAL2 v9.4s, v6.8h, v10.h[6] ++ SUBS x8, x8, 2 ++ SMLAL v8.4s, v7.4h, v10.h[7] ++ SMLAL2 v9.4s, v7.8h, v10.h[7] ++ ++ ++ B.HS k_loop_w8 ++ ++_1_w8: ++ CMP x8, -2 ++ B.EQ _2_w8 ++ ++ /* b0-7 (channel 0) */ ++ LD1R {v10.4s}, [x10] ++ USUBL v10.8h, v10.8b, v29.8b ++ ++ /* x12 = block_id_ptr[0] */ ++ LDRB w12, [x9] ++ /* Add offset to x2 */ ++ /* Shift by 5 because each packed block is a block of 8x4 */ ++ /* which 32 bytes */ ++ ADD x16, x2, x12, LSL #5 ++ ++ LD1 {v0.8b}, [x16], 8 ++ LD1 {v1.8b}, [x16], 8 ++ LD1 {v2.8b}, [x16], 8 ++ LD1 {v3.8b}, [x16] ++ ++ USUBL v0.8h, v0.8b, v24.8b ++ USUBL v1.8h, v1.8b, v24.8b ++ USUBL v2.8h, v2.8b, 
v24.8b ++ USUBL v3.8h, v3.8b, v24.8b ++ ++ SMLAL v8.4s, v0.4h, v10.h[0] ++ SMLAL2 v9.4s, v0.8h, v10.h[0] ++ SMLAL v8.4s, v1.4h, v10.h[1] ++ SMLAL2 v9.4s, v1.8h, v10.h[1] ++ SMLAL v8.4s, v2.4h, v10.h[2] ++ SMLAL2 v9.4s, v2.8h, v10.h[2] ++ SMLAL v8.4s, v3.4h, v10.h[3] ++ SMLAL2 v9.4s, v3.8h, v10.h[3] ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_4 ++_2_w8: ++ /* Store result on stack */ ++ ++ /* -64 because all d8-d15 are on stack */ ++ /* + 256 bytes of buffer when nr = 1 */ ++ /* 256 because we are doing 8x8 block with each value being 4 bytes */ ++ /* Thus 64 * 4 = 256 */ ++ /* 256 + 64 = 320 */ ++ /* This is needed because after processing all nrs we will */ ++ /* load 256 bytes from stack. */ ++ /* Thus we will load accumulators back in v8, v9, v10, v11, v12, v13, v14, v15 */ ++ /* v16, v17, v18, v19, v20, v21, v22, v23 */ ++ /* When nr < 8, say nr = 1, extra v values will be fetched from stack which may overlap */ ++ /* with other parts of stack storing local variables. To avoid that we just */ ++ /* create a buffer of 256 bytes inbetween to make sure pointer increment */ ++ /* never produces address that is beyond the stack frame of this function. */ ++ SUB x9, sp, 320 ++ /* Each iteration produce 8 values each of 4 bytes */ ++ /* Thus 8 x 4 = 32 bytes 2^5 */ ++ /* In this implementation, first value will be stored at */ ++ /* 1st value: sp - 64 - r1 * 32 */ ++ /* 2nd value: sp - 12 - (r1 - 1) * 32 */ ++ /* and so on. */ ++ SUB x9, x9, x1, LSL #5 ++ ST1 {v8.4s}, [x9], 16 ++ ST1 {v9.4s}, [x9] ++ ++ /* Shift zero point vector by 8 to load */ ++ /* zero point of the next channel */ ++ SRI v25.2d, v25.2d, #8 ++ /* Check if nr >=1 */ ++ SUBS x1, x1, 1 ++ BHI _0_w8 ++_3_w8: ++ /* First load all the accumulators from stack */ ++ /* Load nr */ ++ SUB x9, sp, 320 ++ SUB x9, x9, x11, LSL #5 ++ /* Now load v8-v15 */ ++ /* This is 8x4 block (nrxmr) */ ++ /* We will transpose this to 4x8 (mrxnr) */ ++ /* v8, v9 : x00, x10, x20, x30; x40, x50, x60, x70 */ ++ /* v10, v11 : x01, x11, x21, x31; x41, x51, x61, x71 */ ++ /* v12, v13 : x02, x12, x22, x32; x42, x52, x62, x72 */ ++ /* v14, v15 : x03, x13, x23, x33; x43, x53, x63, x73 */ ++ /* */ ++ /* v16, v17 : x04, x14, x24, x34; x44, x54, x64, x74 */ ++ /* v18, v19 : x05, x15, x25, x35; x45, x55, x65, x75 */ ++ /* v20, v21 : x06, x16, x26, x36; x46, x56, x66, x76 */ ++ /* v22, v23 : x07, x17, x27, x37; x47, x57, x67, x77 */ ++ LD1 {v8.4s}, [x9], 16 ++ LD1 {v9.4s}, [x9], 16 ++ LD1 {v10.4s}, [x9], 16 ++ LD1 {v11.4s}, [x9], 16 ++ LD1 {v12.4s}, [x9], 16 ++ LD1 {v13.4s}, [x9], 16 ++ LD1 {v14.4s}, [x9], 16 ++ LD1 {v15.4s}, [x9], 16 ++ LD1 {v16.4s}, [x9], 16 ++ LD1 {v17.4s}, [x9], 16 ++ LD1 {v18.4s}, [x9], 16 ++ LD1 {v19.4s}, [x9], 16 ++ LD1 {v20.4s}, [x9], 16 ++ LD1 {v21.4s}, [x9], 16 ++ LD1 {v22.4s}, [x9], 16 ++ LD1 {v23.4s}, [x9] ++ ++ /* We can tranpose one 4x4 block using macro */ ++ /* TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 */ ++ /* After this we have */ ++ /* v8 : x00, x01, x02, x03 */ ++ /* v10 : x10, x11, x12, x13 */ ++ /* v12 : x20, x21, x22, x23 */ ++ /* v14 : x30, x31, x32, x33 */ ++ /* Then using */ ++ /* TRANSPOSE_4X4_S32 v16, v18, v20, v22, v4, v5, v6, v7 */ ++ /* We get */ ++ /* v16 : x04, x05, x06, x07 */ ++ /* v18 : x14, x15, x16, x17 */ ++ /* v20 : x24, x25, x26, x27 */ ++ /* v22 : x34, x35, x36, x37 */ ++ /* Similarly we can transpose other two 4x4 blocks and we get */ ++ /* tranposed 8x8 */ ++ ++ TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 ++ TRANSPOSE_4X4_S32 v16, v18, v20, v22, v4, v5, v6, v7 ++ TRANSPOSE_4X4_S32 
v9, v11, v13, v15, v0, v1, v2, v3 ++ TRANSPOSE_4X4_S32 v17, v19, v21, v23, v4, v5, v6, v7 ++ ++ /* row 0: v8, v16 */ ++ /* row 1: v10, v18 */ ++ /* row 2: v12, v20 */ ++ /* row 3: v14, v22 */ ++ /* row 4: v9, v17 */ ++ /* row 5: v11, v19 */ ++ /* row 6: v13, v21 */ ++ /* row 7: v15, v23 */ ++ ++ /* Load c_stride & params */ ++ LDR x16, [sp] ++ LSL x16, x16, 2 ++ LD1 {v24.4s}, [x6], 16 ++ LD1 {v25.4s}, [x6] ++ ++ SCVTF v8.4s, v8.4s ++ SCVTF v9.4s, v9.4s ++ SCVTF v10.4s, v10.4s ++ SCVTF v11.4s, v11.4s ++ SCVTF v12.4s, v12.4s ++ SCVTF v13.4s, v13.4s ++ SCVTF v14.4s, v14.4s ++ SCVTF v15.4s, v15.4s ++ SCVTF v16.4s, v16.4s ++ SCVTF v17.4s, v17.4s ++ SCVTF v18.4s, v18.4s ++ SCVTF v19.4s, v19.4s ++ SCVTF v20.4s, v20.4s ++ SCVTF v21.4s, v21.4s ++ SCVTF v22.4s, v22.4s ++ SCVTF v23.4s, v23.4s ++ ++ FMUL v8.4s, v8.4s, v26.4s ++ FMUL v16.4s, v16.4s, v30.4s ++ FMUL v10.4s, v10.4s, v26.4s ++ FMUL v18.4s, v18.4s, v30.4s ++ FMUL v12.4s, v12.4s, v26.4s ++ FMUL v20.4s, v20.4s, v30.4s ++ FMUL v14.4s, v14.4s, v26.4s ++ FMUL v22.4s, v22.4s, v30.4s ++ FMUL v9.4s, v9.4s, v26.4s ++ FMUL v17.4s, v17.4s, v30.4s ++ FMUL v11.4s, v11.4s, v26.4s ++ FMUL v19.4s, v19.4s, v30.4s ++ FMUL v13.4s, v13.4s, v26.4s ++ FMUL v21.4s, v21.4s, v30.4s ++ FMUL v15.4s, v15.4s, v26.4s ++ FMUL v23.4s, v23.4s, v30.4s ++ ++ FADD v8.4s, v8.4s, v24.4s ++ FADD v16.4s, v16.4s, v25.4s ++ FADD v10.4s, v10.4s, v24.4s ++ FADD v18.4s, v18.4s, v25.4s ++ FADD v12.4s, v12.4s, v24.4s ++ FADD v20.4s, v20.4s, v25.4s ++ FADD v14.4s, v14.4s, v24.4s ++ FADD v22.4s, v22.4s, v25.4s ++ FADD v9.4s, v9.4s, v24.4s ++ FADD v17.4s, v17.4s, v25.4s ++ FADD v11.4s, v11.4s, v24.4s ++ FADD v19.4s, v19.4s, v25.4s ++ FADD v13.4s, v13.4s, v24.4s ++ FADD v21.4s, v21.4s, v25.4s ++ FADD v15.4s, v15.4s, v24.4s ++ FADD v23.4s, v23.4s, v25.4s ++ ++ /* Compute c0-c7 */ ++ ++ ADD x9, x7, x16 ++ CMP x0, 2 ++ CSEL x9, x7, x9, LO ++ ++ ADD x10, x9, x16 ++ CSEL x10, x9, x10, LS ++ ++ ADD x8, x10, x16 ++ CMP x0, 4 ++ CSEL x8, x10, x8, LO ++ ++ ADD x12, x8, x16 ++ CSEL x12, x8, x12, LS ++ ++ ADD x13, x12, x16 ++ CMP x0, 6 ++ CSEL x13, x12, x13, LO ++ ++ ADD x14, x13, x16 ++ CSEL x14, x13, x14, LS ++ ++ ADD x15, x14, x16 ++ CMP x0, 8 ++ CSEL x15, x14, x15, NE ++ ++ CMP x11, 8 ++ B.NE _4_w8 ++ ++ ST1 {v8.4s}, [x7], 16 ++ ST1 {v16.4s}, [x7] ++ ST1 {v10.4s}, [x9], 16 ++ ST1 {v18.4s}, [x9] ++ ST1 {v12.4s}, [x10], 16 ++ ST1 {v20.4s}, [x10] ++ ST1 {v14.4s}, [x8], 16 ++ ST1 {v22.4s}, [x8] ++ ST1 {v9.4s}, [x12], 16 ++ ST1 {v17.4s}, [x12] ++ ST1 {v11.4s}, [x13], 16 ++ ST1 {v19.4s}, [x13] ++ ST1 {v13.4s}, [x14], 16 ++ ST1 {v21.4s}, [x14] ++ ST1 {v15.4s}, [x15], 16 ++ ST1 {v23.4s}, [x15] ++ ++ LDP d9, d8, [sp, -64] ++ LDP d11, d10, [sp, -48] ++ LDP d13, d12, [sp, -32] ++ LDP d15, d14, [sp, -16] ++ ++ RET ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_3 ++_4_w8: ++ CMP x11, 4 ++ B.LO _5_w8 ++ ++ ST1 {v8.4s}, [x7], 16 ++ ST1 {v10.4s}, [x9], 16 ++ ST1 {v12.4s}, [x10], 16 ++ ST1 {v14.4s}, [x8], 16 ++ ST1 {v9.4s}, [x12], 16 ++ ST1 {v11.4s}, [x13], 16 ++ ST1 {v13.4s}, [x14], 16 ++ ST1 {v15.4s}, [x15], 16 ++ ++ SUB x11, x11, 4 ++ ++ MOV v8.16b, v16.16b ++ MOV v10.16b, v18.16b ++ MOV v12.16b, v20.16b ++ MOV v14.16b, v22.16b ++ MOV v9.16b, v17.16b ++ MOV v11.16b, v19.16b ++ MOV v13.16b, v21.16b ++ MOV v15.16b, v23.16b ++ ++_5_w8: ++ CMP x11, 2 ++ B.LO _6_w8 ++ ++ ST1 {v8.2s}, [x7], 8 ++ ST1 {v10.2s}, [x9], 8 ++ ST1 {v12.2s}, [x10], 8 ++ ST1 {v14.2s}, [x8], 8 ++ ST1 {v9.2s}, [x12], 8 ++ ST1 {v11.2s}, [x13], 8 ++ ST1 {v13.2s}, [x14], 8 ++ ST1 {v15.2s}, [x15], 8 ++ ++ SUB x11, x11, 2 ++ ++ EXT v8.16b, v8.16b, 
v8.16b, 8 ++ EXT v10.16b, v10.16b, v10.16b, 8 ++ EXT v12.16b, v12.16b, v12.16b, 8 ++ EXT v14.16b, v14.16b, v14.16b, 8 ++ EXT v9.16b, v9.16b, v9.16b, 8 ++ EXT v11.16b, v11.16b, v11.16b, 8 ++ EXT v13.16b, v13.16b, v13.16b, 8 ++ EXT v15.16b, v15.16b, v15.16b, 8 ++ ++_6_w8: ++ CMP x11, 1 ++ B.LO _7_w8 ++ ++ ST1 {v8.s}[0], [x7] ++ ST1 {v10.s}[0], [x9] ++ ST1 {v12.s}[0], [x10] ++ ST1 {v14.s}[0], [x8] ++ ST1 {v9.s}[0], [x12] ++ ST1 {v11.s}[0], [x13] ++ ST1 {v13.s}[0], [x14] ++ ST1 {v15.s}[0], [x15] ++ ++_7_w8: ++ LDP d9, d8, [sp, -64] ++ LDP d11, d10, [sp, -48] ++ LDP d13, d12, [sp, -32] ++ LDP d15, d14, [sp, -16] ++ ++ RET ++ ++END_FUNCTION pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w8__aarch64_neon + + #ifdef __ELF__ + .section ".note.GNU-stack","",%progbits + #endif ++ ++#undef NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_5 ++#undef NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_4 ++#undef NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_3 ++#undef XX +diff -Naur q8gemm_sparse.orig/8x8c8x1-dq-packedA-aarch64-neon.S q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S +--- q8gemm_sparse.orig/8x8c8x1-dq-packedA-aarch64-neon.S 2024-03-19 21:57:24.089557428 +0900 ++++ q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S 2024-01-24 10:42:22.000000000 +0900 +@@ -37,387 +37,6 @@ + # |params | 16 + # |-----------| + +-# void pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w##W_INDEX_DTYPE_NUM_BITS##__aarch64_neon( +-# size_t mr, +-# size_t nr, +-# const uint8_t* a_packed, +-# const uint8_t* packed_w, +-# const uint##W_INDEX_DTYPE_NUM_BITS##_t* w_row_ptr, +-# const uint##W_INDEX_DTYPE_NUM_BITS##_t* w_block_ids_ptr, +-# const float* b, +-# uint8_t* restrict c, +-# size_t c_stride, +-# size_t output_channel_index, +-# const union pytorch_qnnp_conv_dynamic_quantization_params quantization_params[restrict static 1]) +-#define MAKE_PYTORCH_Q8GEMM_DQ_SPARSE_8X1_UKERNEL_8X8_PACKEDA__AARCH64_NEON(W_INDEX_DTYPE_NUM_BITS, W_INDEX_DTYPE_NUM_BYTES_ARG, W_INDEX_DTYPE_LOG_NUM_BYTES_ARG, LOAD_INDEX_INSTRUCTION) XX\ +- BEGIN_FUNCTION pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w##W_INDEX_DTYPE_NUM_BITS##__aarch64_neon XX\ +- XX\ +- STP d15, d14, [sp, -16] XX\ +- STP d13, d12, [sp, -32] XX\ +- STP d11, d10, [sp, -48] XX\ +- STP d9, d8, [sp, -64] XX\ +- XX\ +- MOV x11, x1 XX\ +- /* Load output channel index */ XX\ +- LDR x10, [sp, 8] XX\ +- /* Load params */ XX\ +- LDR x8, [sp, 16] XX\ +- XX\ +- /* Load a_zero_point */ XX\ +- LD1R {v24.8b}, [x8] XX\ +- ADD x8, x8, 8 XX\ +- XX\ +- /* Load pointer to per channel zero points array */ XX\ +- LDR x17, [x8], 8 XX\ +- XX\ +- /* Load pointer to per channel multiplier */ XX\ +- LDR x13, [x8] XX\ +- XX\ +- /* Add offset to the base pointer */ XX\ +- ADD x17, x17, x10 XX\ +- /* Mul by 4 to get byte offset for multiplier */ XX\ +- LSL x10, x10, 2 XX\ +- /* Add offset to the base pointer for multiplier */ XX\ +- ADD x13, x13, x10 XX\ +- XX\ +- /* Load b_zero_point */ XX\ +- LD1 {v25.8b}, [x17] XX\ +- /* Load multiplier c0123 */ XX\ +- LD1 {v26.4s}, [x13], 16 XX\ +- /* Load multiplier c4567 */ XX\ +- LD1 {v30.4s}, [x13] XX\ +- XX\ +- EOR x12, x12, x12 XX\ +- EOR x13, x13, x13 XX\ +- XX\ +- EOR v8.16b, v8.16b, v8.16b XX\ +- EOR v9.16b, v9.16b, v9.16b XX\ +- EOR v10.16b, v10.16b, v10.16b XX\ +- EOR v11.16b, v11.16b, v11.16b XX\ +- EOR v12.16b, v12.16b, v12.16b XX\ +- EOR v13.16b, v13.16b, v13.16b XX\ +- EOR v14.16b, v14.16b, v14.16b XX\ +- EOR v15.16b, v15.16b, v15.16b XX\ +- EOR v16.16b, v16.16b, v16.16b XX\ +- EOR v17.16b, v17.16b, v17.16b XX\ +- EOR v18.16b, v18.16b, v18.16b XX\ +- EOR v19.16b, 
v19.16b, v19.16b XX\ +- EOR v20.16b, v20.16b, v20.16b XX\ +- EOR v21.16b, v21.16b, v21.16b XX\ +- EOR v22.16b, v22.16b, v22.16b XX\ +- EOR v23.16b, v23.16b, v23.16b XX\ +- XX\ +- /* w12 = w_row_ptr[n], x13 = w_row_ptr[n+1] */ XX\ +- /* x4 = x4 + W_INDEX_DTYPE_NUM_BYTES_ARG to point to next n */ XX\ +- LOAD_INDEX_INSTRUCTION w12, [x4], W_INDEX_DTYPE_NUM_BYTES_ARG XX\ +- LOAD_INDEX_INSTRUCTION w13, [x4] XX\ +- /* x10 = temp_packed_w = packed_w + w_row_ptr[n] * 8 */ XX\ +- /* This points to the first block of nonzero value */ XX\ +- /* for the nth row. */ XX\ +- ADD x10, x3, x12, LSL #3 XX\ +- /* x9 = temp_w_block_ids_ptr = w_block_ids_ptr (x5) + w_row_ptr[n] */ XX\ +- /* LSL for when elements are >1 byte */ XX\ +- /* (4 bytes: LSL #2, 2 bytes: LSL #1, 1 byte: LSL #0) */ XX\ +- /* This points to the block id of the first block */ XX\ +- /* It should contain x13 - x12 number of block ids */ XX\ +- ADD x9, x5, x12, LSL W_INDEX_DTYPE_LOG_NUM_BYTES_ARG XX\ +- /* x8 = num_blocks that needs to be processed */ XX\ +- SUB x8, x13, x12 XX\ +- SUBS x8, x8, 2 XX\ +- B.LO _1_w##W_INDEX_DTYPE_NUM_BITS XX\ +- XX\ +- NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_5 XX\ +- k_loop_w##W_INDEX_DTYPE_NUM_BITS##: XX\ +- /* k_loop processes two k values */ XX\ +- /* Load two 8x1 blocks */ XX\ +- LD1 {v0.8b}, [x10], 8 XX\ +- LD1 {v1.8b}, [x10], 8 XX\ +- USUBL v0.8h, v0.8b, v25.8b XX\ +- USUBL v1.8h, v1.8b, v25.8b XX\ +- XX\ +- /* x12 = block_id_ptr[0] */ XX\ +- /* x13 = block_id_ptr[1] */ XX\ +- LOAD_INDEX_INSTRUCTION w12, [x9], W_INDEX_DTYPE_NUM_BYTES_ARG XX\ +- LOAD_INDEX_INSTRUCTION w13, [x9], W_INDEX_DTYPE_NUM_BYTES_ARG XX\ +- /* Add offset to x2 */ XX\ +- /* Shift by 3 because each packed block is a block of 8x1 */ XX\ +- /* which 8 bytes */ XX\ +- ADD x16, x2, x12, LSL #3 XX\ +- ADD x17, x2, x13, LSL #3 XX\ +- XX\ +- /* Load two 8x1 blocks of activation */ XX\ +- /* First 8x1 for first channel */ XX\ +- /* second 8x1 for next channel */ XX\ +- LD1 {v2.8b}, [x16] XX\ +- LD1 {v3.8b}, [x17] XX\ +- XX\ +- USUBL v2.8h, v2.8b, v24.8b XX\ +- USUBL v3.8h, v3.8b, v24.8b XX\ +- XX\ +- /* First channel */ XX\ +- SMLAL v8.4s, v0.4h, v2.h[0] XX\ +- SMLAL2 v9.4s, v0.8h, v2.h[0] XX\ +- SMLAL v10.4s, v0.4h, v2.h[1] XX\ +- SMLAL2 v11.4s, v0.8h, v2.h[1] XX\ +- SMLAL v12.4s, v0.4h, v2.h[2] XX\ +- SMLAL2 v13.4s, v0.8h, v2.h[2] XX\ +- SMLAL v14.4s, v0.4h, v2.h[3] XX\ +- SMLAL2 v15.4s, v0.8h, v2.h[3] XX\ +- SMLAL v16.4s, v0.4h, v2.h[4] XX\ +- SMLAL2 v17.4s, v0.8h, v2.h[4] XX\ +- SMLAL v18.4s, v0.4h, v2.h[5] XX\ +- SMLAL2 v19.4s, v0.8h, v2.h[5] XX\ +- SMLAL v20.4s, v0.4h, v2.h[6] XX\ +- SMLAL2 v21.4s, v0.8h, v2.h[6] XX\ +- SMLAL v22.4s, v0.4h, v2.h[7] XX\ +- SMLAL2 v23.4s, v0.8h, v2.h[7] XX\ +- XX\ +- SUBS x8, x8, 2 XX\ +- /* Second channel */ XX\ +- SMLAL v8.4s, v1.4h, v3.h[0] XX\ +- SMLAL2 v9.4s, v1.8h, v3.h[0] XX\ +- SMLAL v10.4s, v1.4h, v3.h[1] XX\ +- SMLAL2 v11.4s, v1.8h, v3.h[1] XX\ +- SMLAL v12.4s, v1.4h, v3.h[2] XX\ +- SMLAL2 v13.4s, v1.8h, v3.h[2] XX\ +- SMLAL v14.4s, v1.4h, v3.h[3] XX\ +- SMLAL2 v15.4s, v1.8h, v3.h[3] XX\ +- SMLAL v16.4s, v1.4h, v3.h[4] XX\ +- SMLAL2 v17.4s, v1.8h, v3.h[4] XX\ +- SMLAL v18.4s, v1.4h, v3.h[5] XX\ +- SMLAL2 v19.4s, v1.8h, v3.h[5] XX\ +- SMLAL v20.4s, v1.4h, v3.h[6] XX\ +- SMLAL2 v21.4s, v1.8h, v3.h[6] XX\ +- SMLAL v22.4s, v1.4h, v3.h[7] XX\ +- SMLAL2 v23.4s, v1.8h, v3.h[7] XX\ +- XX\ +- B.HS k_loop_w##W_INDEX_DTYPE_NUM_BITS XX\ +- XX\ +- _1_w##W_INDEX_DTYPE_NUM_BITS##: XX\ +- CMP x8, -2 XX\ +- B.EQ _3_w##W_INDEX_DTYPE_NUM_BITS XX\ +- XX\ +- LD1 {v0.8b}, [x10] XX\ +- USUBL v0.8h, v0.8b, v25.8b 
XX\ +- XX\ +- /* x12 = block_id_ptr[0] */ XX\ +- LOAD_INDEX_INSTRUCTION w12, [x9] XX\ +- /* Add offset to x2 */ XX\ +- ADD x16, x2, x12, LSL #3 XX\ +- XX\ +- LD1 {v2.8b}, [x16] XX\ +- USUBL v2.8h, v2.8b, v24.8b XX\ +- XX\ +- SMLAL v8.4s, v0.4h, v2.h[0] XX\ +- SMLAL2 v9.4s, v0.8h, v2.h[0] XX\ +- SMLAL v10.4s, v0.4h, v2.h[1] XX\ +- SMLAL2 v11.4s, v0.8h, v2.h[1] XX\ +- SMLAL v12.4s, v0.4h, v2.h[2] XX\ +- SMLAL2 v13.4s, v0.8h, v2.h[2] XX\ +- SMLAL v14.4s, v0.4h, v2.h[3] XX\ +- SMLAL2 v15.4s, v0.8h, v2.h[3] XX\ +- SMLAL v16.4s, v0.4h, v2.h[4] XX\ +- SMLAL2 v17.4s, v0.8h, v2.h[4] XX\ +- SMLAL v18.4s, v0.4h, v2.h[5] XX\ +- SMLAL2 v19.4s, v0.8h, v2.h[5] XX\ +- SMLAL v20.4s, v0.4h, v2.h[6] XX\ +- SMLAL2 v21.4s, v0.8h, v2.h[6] XX\ +- SMLAL v22.4s, v0.4h, v2.h[7] XX\ +- SMLAL2 v23.4s, v0.8h, v2.h[7] XX\ +- XX\ +- NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_4 XX\ +- _3_w##W_INDEX_DTYPE_NUM_BITS##: XX\ +- /* row 0: v8, v9 */ XX\ +- /* row 1: v10, v11 */ XX\ +- /* row 2: v12, v13 */ XX\ +- /* row 3: v14, v15 */ XX\ +- /* row 4: v16, v17 */ XX\ +- /* row 5: v18, v19 */ XX\ +- /* row 6: v20, v21 */ XX\ +- /* row 7: v22, v23 */ XX\ +- XX\ +- /* Load c_stride & params */ XX\ +- LDR x16, [sp] XX\ +- LSL x16, x16, 2 XX\ +- LD1 {v24.4s}, [x6], 16 XX\ +- LD1 {v25.4s}, [x6] XX\ +- XX\ +- SCVTF v8.4s, v8.4s XX\ +- SCVTF v9.4s, v9.4s XX\ +- SCVTF v10.4s, v10.4s XX\ +- SCVTF v11.4s, v11.4s XX\ +- SCVTF v12.4s, v12.4s XX\ +- SCVTF v13.4s, v13.4s XX\ +- SCVTF v14.4s, v14.4s XX\ +- SCVTF v15.4s, v15.4s XX\ +- SCVTF v16.4s, v16.4s XX\ +- SCVTF v17.4s, v17.4s XX\ +- SCVTF v18.4s, v18.4s XX\ +- SCVTF v19.4s, v19.4s XX\ +- SCVTF v20.4s, v20.4s XX\ +- SCVTF v21.4s, v21.4s XX\ +- SCVTF v22.4s, v22.4s XX\ +- SCVTF v23.4s, v23.4s XX\ +- XX\ +- FMUL v8.4s, v8.4s, v26.4s XX\ +- FMUL v9.4s, v9.4s, v30.4s XX\ +- FMUL v10.4s, v10.4s, v26.4s XX\ +- FMUL v11.4s, v11.4s, v30.4s XX\ +- FMUL v12.4s, v12.4s, v26.4s XX\ +- FMUL v13.4s, v13.4s, v30.4s XX\ +- FMUL v14.4s, v14.4s, v26.4s XX\ +- FMUL v15.4s, v15.4s, v30.4s XX\ +- FMUL v16.4s, v16.4s, v26.4s XX\ +- FMUL v17.4s, v17.4s, v30.4s XX\ +- FMUL v18.4s, v18.4s, v26.4s XX\ +- FMUL v19.4s, v19.4s, v30.4s XX\ +- FMUL v20.4s, v20.4s, v26.4s XX\ +- FMUL v21.4s, v21.4s, v30.4s XX\ +- FMUL v22.4s, v22.4s, v26.4s XX\ +- FMUL v23.4s, v23.4s, v30.4s XX\ +- XX\ +- FADD v8.4s, v8.4s, v24.4s XX\ +- FADD v9.4s, v9.4s, v25.4s XX\ +- FADD v10.4s, v10.4s, v24.4s XX\ +- FADD v11.4s, v11.4s, v25.4s XX\ +- FADD v12.4s, v12.4s, v24.4s XX\ +- FADD v13.4s, v13.4s, v25.4s XX\ +- FADD v14.4s, v14.4s, v24.4s XX\ +- FADD v15.4s, v15.4s, v25.4s XX\ +- FADD v16.4s, v16.4s, v24.4s XX\ +- FADD v17.4s, v17.4s, v25.4s XX\ +- FADD v18.4s, v18.4s, v24.4s XX\ +- FADD v19.4s, v19.4s, v25.4s XX\ +- FADD v20.4s, v20.4s, v24.4s XX\ +- FADD v21.4s, v21.4s, v25.4s XX\ +- FADD v22.4s, v22.4s, v24.4s XX\ +- FADD v23.4s, v23.4s, v25.4s XX\ +- XX\ +- /* Compute c0-c7 */ XX\ +- XX\ +- ADD x9, x7, x16 XX\ +- CMP x0, 2 XX\ +- CSEL x9, x7, x9, LO XX\ +- XX\ +- ADD x10, x9, x16 XX\ +- CSEL x10, x9, x10, LS XX\ +- XX\ +- ADD x8, x10, x16 XX\ +- CMP x0, 4 XX\ +- CSEL x8, x10, x8, LO XX\ +- XX\ +- ADD x12, x8, x16 XX\ +- CSEL x12, x8, x12, LS XX\ +- XX\ +- ADD x13, x12, x16 XX\ +- CMP x0, 6 XX\ +- CSEL x13, x12, x13, LO XX\ +- XX\ +- ADD x14, x13, x16 XX\ +- CSEL x14, x13, x14, LS XX\ +- XX\ +- ADD x15, x14, x16 XX\ +- CMP x0, 8 XX\ +- CSEL x15, x14, x15, NE XX\ +- XX\ +- CMP x11, 8 XX\ +- B.NE _4_w##W_INDEX_DTYPE_NUM_BITS XX\ +- XX\ +- ST1 {v8.4s}, [x7], 16 XX\ +- ST1 {v9.4s}, [x7] XX\ +- ST1 {v10.4s}, [x9], 16 XX\ +- ST1 {v11.4s}, [x9] 
XX\ +- ST1 {v12.4s}, [x10], 16 XX\ +- ST1 {v13.4s}, [x10] XX\ +- ST1 {v14.4s}, [x8], 16 XX\ +- ST1 {v15.4s}, [x8] XX\ +- ST1 {v16.4s}, [x12], 16 XX\ +- ST1 {v17.4s}, [x12] XX\ +- ST1 {v18.4s}, [x13], 16 XX\ +- ST1 {v19.4s}, [x13] XX\ +- ST1 {v20.4s}, [x14], 16 XX\ +- ST1 {v21.4s}, [x14] XX\ +- ST1 {v22.4s}, [x15], 16 XX\ +- ST1 {v23.4s}, [x15] XX\ +- XX\ +- LDP d9, d8, [sp, -64] XX\ +- LDP d11, d10, [sp, -48] XX\ +- LDP d13, d12, [sp, -32] XX\ +- LDP d15, d14, [sp, -16] XX\ +- XX\ +- RET XX\ +- XX\ +- NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_3 XX\ +- _4_w##W_INDEX_DTYPE_NUM_BITS##: XX\ +- CMP x11, 4 XX\ +- B.LO _5_w##W_INDEX_DTYPE_NUM_BITS XX\ +- XX\ +- ST1 {v8.4s}, [x7], 16 XX\ +- ST1 {v10.4s}, [x9], 16 XX\ +- ST1 {v12.4s}, [x10], 16 XX\ +- ST1 {v14.4s}, [x8], 16 XX\ +- ST1 {v16.4s}, [x12], 16 XX\ +- ST1 {v18.4s}, [x13], 16 XX\ +- ST1 {v20.4s}, [x14], 16 XX\ +- ST1 {v22.4s}, [x15], 16 XX\ +- XX\ +- SUB x11, x11, 4 XX\ +- XX\ +- MOV v8.16b, v9.16b XX\ +- MOV v10.16b, v11.16b XX\ +- MOV v12.16b, v13.16b XX\ +- MOV v14.16b, v15.16b XX\ +- MOV v16.16b, v17.16b XX\ +- MOV v18.16b, v19.16b XX\ +- MOV v20.16b, v21.16b XX\ +- MOV v22.16b, v23.16b XX\ +- XX\ +- _5_w##W_INDEX_DTYPE_NUM_BITS##: XX\ +- CMP x11, 2 XX\ +- B.LO _6_w##W_INDEX_DTYPE_NUM_BITS XX\ +- XX\ +- ST1 {v8.2s}, [x7], 8 XX\ +- ST1 {v10.2s}, [x9], 8 XX\ +- ST1 {v12.2s}, [x10], 8 XX\ +- ST1 {v14.2s}, [x8], 8 XX\ +- ST1 {v16.2s}, [x12], 8 XX\ +- ST1 {v18.2s}, [x13], 8 XX\ +- ST1 {v20.2s}, [x14], 8 XX\ +- ST1 {v22.2s}, [x15], 8 XX\ +- XX\ +- SUB x11, x11, 2 XX\ +- XX\ +- EXT v8.16b, v8.16b, v8.16b, 8 XX\ +- EXT v10.16b, v10.16b, v10.16b, 8 XX\ +- EXT v12.16b, v12.16b, v12.16b, 8 XX\ +- EXT v14.16b, v14.16b, v14.16b, 8 XX\ +- EXT v16.16b, v16.16b, v16.16b, 8 XX\ +- EXT v18.16b, v18.16b, v18.16b, 8 XX\ +- EXT v20.16b, v20.16b, v20.16b, 8 XX\ +- EXT v22.16b, v22.16b, v22.16b, 8 XX\ +- XX\ +- _6_w##W_INDEX_DTYPE_NUM_BITS##: XX\ +- CMP x11, 1 XX\ +- B.LO _7_w##W_INDEX_DTYPE_NUM_BITS XX\ +- XX\ +- ST1 {v8.s}[0], [x7] XX\ +- ST1 {v10.s}[0], [x9] XX\ +- ST1 {v12.s}[0], [x10] XX\ +- ST1 {v14.s}[0], [x8] XX\ +- ST1 {v16.s}[0], [x12] XX\ +- ST1 {v18.s}[0], [x13] XX\ +- ST1 {v20.s}[0], [x14] XX\ +- ST1 {v22.s}[0], [x15] XX\ +- XX\ +- _7_w##W_INDEX_DTYPE_NUM_BITS##: XX\ +- LDP d9, d8, [sp, -64] XX\ +- LDP d11, d10, [sp, -48] XX\ +- LDP d13, d12, [sp, -32] XX\ +- LDP d15, d14, [sp, -16] XX\ +- XX\ +- RET XX\ +- XX\ +- END_FUNCTION pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w##W_INDEX_DTYPE_NUM_BITS##__aarch64_neon + + # void pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w32__aarch64_neon( + # size_t mr, +@@ -431,7 +50,375 @@ + # size_t c_stride, + # size_t output_channel_index, + # const union pytorch_qnnp_conv_dynamic_quantization_params quantization_params[restrict static 1]) +-MAKE_PYTORCH_Q8GEMM_DQ_SPARSE_8X1_UKERNEL_8X8_PACKEDA__AARCH64_NEON(32, #4, #2, LDR) ++BEGIN_FUNCTION pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w32__aarch64_neon ++ ++ STP d15, d14, [sp, -16] ++ STP d13, d12, [sp, -32] ++ STP d11, d10, [sp, -48] ++ STP d9, d8, [sp, -64] ++ ++ MOV x11, x1 ++ /* Load output channel index */ ++ LDR x10, [sp, 8] ++ /* Load params */ ++ LDR x8, [sp, 16] ++ ++ /* Load a_zero_point */ ++ LD1R {v24.8b}, [x8] ++ ADD x8, x8, 8 ++ ++ /* Load pointer to per channel zero points array */ ++ LDR x17, [x8], 8 ++ ++ /* Load pointer to per channel multiplier */ ++ LDR x13, [x8] ++ ++ /* Add offset to the base pointer */ ++ ADD x17, x17, x10 ++ /* Mul by 4 to get byte offset for multiplier */ ++ LSL x10, x10, 2 ++ /* Add offset to the base 
pointer for multiplier */ ++ ADD x13, x13, x10 ++ ++ /* Load b_zero_point */ ++ LD1 {v25.8b}, [x17] ++ /* Load multiplier c0123 */ ++ LD1 {v26.4s}, [x13], 16 ++ /* Load multiplier c4567 */ ++ LD1 {v30.4s}, [x13] ++ ++ EOR x12, x12, x12 ++ EOR x13, x13, x13 ++ ++ EOR v8.16b, v8.16b, v8.16b ++ EOR v9.16b, v9.16b, v9.16b ++ EOR v10.16b, v10.16b, v10.16b ++ EOR v11.16b, v11.16b, v11.16b ++ EOR v12.16b, v12.16b, v12.16b ++ EOR v13.16b, v13.16b, v13.16b ++ EOR v14.16b, v14.16b, v14.16b ++ EOR v15.16b, v15.16b, v15.16b ++ EOR v16.16b, v16.16b, v16.16b ++ EOR v17.16b, v17.16b, v17.16b ++ EOR v18.16b, v18.16b, v18.16b ++ EOR v19.16b, v19.16b, v19.16b ++ EOR v20.16b, v20.16b, v20.16b ++ EOR v21.16b, v21.16b, v21.16b ++ EOR v22.16b, v22.16b, v22.16b ++ EOR v23.16b, v23.16b, v23.16b ++ ++ /* w12 = w_row_ptr[n], x13 = w_row_ptr[n+1] */ ++ /* x4 = x4 + #4 to point to next n */ ++ LDR w12, [x4], #4 ++ LDR w13, [x4] ++ /* x10 = temp_packed_w = packed_w + w_row_ptr[n] * 8 */ ++ /* This points to the first block of nonzero value */ ++ /* for the nth row. */ ++ ADD x10, x3, x12, LSL #3 ++ /* x9 = temp_w_block_ids_ptr = w_block_ids_ptr (x5) + w_row_ptr[n] */ ++ /* LSL for when elements are >1 byte */ ++ /* (4 bytes: LSL #2, 2 bytes: LSL #1, 1 byte: LSL #0) */ ++ /* This points to the block id of the first block */ ++ /* It should contain x13 - x12 number of block ids */ ++ ADD x9, x5, x12, LSL #2 ++ /* x8 = num_blocks that needs to be processed */ ++ SUB x8, x13, x12 ++ SUBS x8, x8, 2 ++ B.LO _1_w32 ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_5 ++k_loop_w32: ++ /* k_loop processes two k values */ ++ /* Load two 8x1 blocks */ ++ LD1 {v0.8b}, [x10], 8 ++ LD1 {v1.8b}, [x10], 8 ++ USUBL v0.8h, v0.8b, v25.8b ++ USUBL v1.8h, v1.8b, v25.8b ++ ++ /* x12 = block_id_ptr[0] */ ++ /* x13 = block_id_ptr[1] */ ++ LDR w12, [x9], #4 ++ LDR w13, [x9], #4 ++ /* Add offset to x2 */ ++ /* Shift by 3 because each packed block is a block of 8x1 */ ++ /* which 8 bytes */ ++ ADD x16, x2, x12, LSL #3 ++ ADD x17, x2, x13, LSL #3 ++ ++ /* Load two 8x1 blocks of activation */ ++ /* First 8x1 for first channel */ ++ /* second 8x1 for next channel */ ++ LD1 {v2.8b}, [x16] ++ LD1 {v3.8b}, [x17] ++ ++ USUBL v2.8h, v2.8b, v24.8b ++ USUBL v3.8h, v3.8b, v24.8b ++ ++ /* First channel */ ++ SMLAL v8.4s, v0.4h, v2.h[0] ++ SMLAL2 v9.4s, v0.8h, v2.h[0] ++ SMLAL v10.4s, v0.4h, v2.h[1] ++ SMLAL2 v11.4s, v0.8h, v2.h[1] ++ SMLAL v12.4s, v0.4h, v2.h[2] ++ SMLAL2 v13.4s, v0.8h, v2.h[2] ++ SMLAL v14.4s, v0.4h, v2.h[3] ++ SMLAL2 v15.4s, v0.8h, v2.h[3] ++ SMLAL v16.4s, v0.4h, v2.h[4] ++ SMLAL2 v17.4s, v0.8h, v2.h[4] ++ SMLAL v18.4s, v0.4h, v2.h[5] ++ SMLAL2 v19.4s, v0.8h, v2.h[5] ++ SMLAL v20.4s, v0.4h, v2.h[6] ++ SMLAL2 v21.4s, v0.8h, v2.h[6] ++ SMLAL v22.4s, v0.4h, v2.h[7] ++ SMLAL2 v23.4s, v0.8h, v2.h[7] ++ ++ SUBS x8, x8, 2 ++ /* Second channel */ ++ SMLAL v8.4s, v1.4h, v3.h[0] ++ SMLAL2 v9.4s, v1.8h, v3.h[0] ++ SMLAL v10.4s, v1.4h, v3.h[1] ++ SMLAL2 v11.4s, v1.8h, v3.h[1] ++ SMLAL v12.4s, v1.4h, v3.h[2] ++ SMLAL2 v13.4s, v1.8h, v3.h[2] ++ SMLAL v14.4s, v1.4h, v3.h[3] ++ SMLAL2 v15.4s, v1.8h, v3.h[3] ++ SMLAL v16.4s, v1.4h, v3.h[4] ++ SMLAL2 v17.4s, v1.8h, v3.h[4] ++ SMLAL v18.4s, v1.4h, v3.h[5] ++ SMLAL2 v19.4s, v1.8h, v3.h[5] ++ SMLAL v20.4s, v1.4h, v3.h[6] ++ SMLAL2 v21.4s, v1.8h, v3.h[6] ++ SMLAL v22.4s, v1.4h, v3.h[7] ++ SMLAL2 v23.4s, v1.8h, v3.h[7] ++ ++ B.HS k_loop_w32 ++ ++_1_w32: ++ CMP x8, -2 ++ B.EQ _3_w32 ++ ++ LD1 {v0.8b}, [x10] ++ USUBL v0.8h, v0.8b, v25.8b ++ ++ /* x12 = block_id_ptr[0] */ ++ LDR w12, [x9] ++ /* Add offset to x2 */ 
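++                                      /* Note: each packed activation block is 8x1 uint8 values, */
++                                      /* i.e. 8 bytes, so the byte offset into a_packed (x2) is */
++                                      /* block_id * 8, hence the LSL #3 below (same rationale as */
++                                      /* the "Shift by 3" comment in k_loop_w32 above). */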
++ ADD x16, x2, x12, LSL #3 ++ ++ LD1 {v2.8b}, [x16] ++ USUBL v2.8h, v2.8b, v24.8b ++ ++ SMLAL v8.4s, v0.4h, v2.h[0] ++ SMLAL2 v9.4s, v0.8h, v2.h[0] ++ SMLAL v10.4s, v0.4h, v2.h[1] ++ SMLAL2 v11.4s, v0.8h, v2.h[1] ++ SMLAL v12.4s, v0.4h, v2.h[2] ++ SMLAL2 v13.4s, v0.8h, v2.h[2] ++ SMLAL v14.4s, v0.4h, v2.h[3] ++ SMLAL2 v15.4s, v0.8h, v2.h[3] ++ SMLAL v16.4s, v0.4h, v2.h[4] ++ SMLAL2 v17.4s, v0.8h, v2.h[4] ++ SMLAL v18.4s, v0.4h, v2.h[5] ++ SMLAL2 v19.4s, v0.8h, v2.h[5] ++ SMLAL v20.4s, v0.4h, v2.h[6] ++ SMLAL2 v21.4s, v0.8h, v2.h[6] ++ SMLAL v22.4s, v0.4h, v2.h[7] ++ SMLAL2 v23.4s, v0.8h, v2.h[7] ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_4 ++_3_w32: ++ /* row 0: v8, v9 */ ++ /* row 1: v10, v11 */ ++ /* row 2: v12, v13 */ ++ /* row 3: v14, v15 */ ++ /* row 4: v16, v17 */ ++ /* row 5: v18, v19 */ ++ /* row 6: v20, v21 */ ++ /* row 7: v22, v23 */ ++ ++ /* Load c_stride & params */ ++ LDR x16, [sp] ++ LSL x16, x16, 2 ++ LD1 {v24.4s}, [x6], 16 ++ LD1 {v25.4s}, [x6] ++ ++ SCVTF v8.4s, v8.4s ++ SCVTF v9.4s, v9.4s ++ SCVTF v10.4s, v10.4s ++ SCVTF v11.4s, v11.4s ++ SCVTF v12.4s, v12.4s ++ SCVTF v13.4s, v13.4s ++ SCVTF v14.4s, v14.4s ++ SCVTF v15.4s, v15.4s ++ SCVTF v16.4s, v16.4s ++ SCVTF v17.4s, v17.4s ++ SCVTF v18.4s, v18.4s ++ SCVTF v19.4s, v19.4s ++ SCVTF v20.4s, v20.4s ++ SCVTF v21.4s, v21.4s ++ SCVTF v22.4s, v22.4s ++ SCVTF v23.4s, v23.4s ++ ++ FMUL v8.4s, v8.4s, v26.4s ++ FMUL v9.4s, v9.4s, v30.4s ++ FMUL v10.4s, v10.4s, v26.4s ++ FMUL v11.4s, v11.4s, v30.4s ++ FMUL v12.4s, v12.4s, v26.4s ++ FMUL v13.4s, v13.4s, v30.4s ++ FMUL v14.4s, v14.4s, v26.4s ++ FMUL v15.4s, v15.4s, v30.4s ++ FMUL v16.4s, v16.4s, v26.4s ++ FMUL v17.4s, v17.4s, v30.4s ++ FMUL v18.4s, v18.4s, v26.4s ++ FMUL v19.4s, v19.4s, v30.4s ++ FMUL v20.4s, v20.4s, v26.4s ++ FMUL v21.4s, v21.4s, v30.4s ++ FMUL v22.4s, v22.4s, v26.4s ++ FMUL v23.4s, v23.4s, v30.4s ++ ++ FADD v8.4s, v8.4s, v24.4s ++ FADD v9.4s, v9.4s, v25.4s ++ FADD v10.4s, v10.4s, v24.4s ++ FADD v11.4s, v11.4s, v25.4s ++ FADD v12.4s, v12.4s, v24.4s ++ FADD v13.4s, v13.4s, v25.4s ++ FADD v14.4s, v14.4s, v24.4s ++ FADD v15.4s, v15.4s, v25.4s ++ FADD v16.4s, v16.4s, v24.4s ++ FADD v17.4s, v17.4s, v25.4s ++ FADD v18.4s, v18.4s, v24.4s ++ FADD v19.4s, v19.4s, v25.4s ++ FADD v20.4s, v20.4s, v24.4s ++ FADD v21.4s, v21.4s, v25.4s ++ FADD v22.4s, v22.4s, v24.4s ++ FADD v23.4s, v23.4s, v25.4s ++ ++ /* Compute c0-c7 */ ++ ++ ADD x9, x7, x16 ++ CMP x0, 2 ++ CSEL x9, x7, x9, LO ++ ++ ADD x10, x9, x16 ++ CSEL x10, x9, x10, LS ++ ++ ADD x8, x10, x16 ++ CMP x0, 4 ++ CSEL x8, x10, x8, LO ++ ++ ADD x12, x8, x16 ++ CSEL x12, x8, x12, LS ++ ++ ADD x13, x12, x16 ++ CMP x0, 6 ++ CSEL x13, x12, x13, LO ++ ++ ADD x14, x13, x16 ++ CSEL x14, x13, x14, LS ++ ++ ADD x15, x14, x16 ++ CMP x0, 8 ++ CSEL x15, x14, x15, NE ++ ++ CMP x11, 8 ++ B.NE _4_w32 ++ ++ ST1 {v8.4s}, [x7], 16 ++ ST1 {v9.4s}, [x7] ++ ST1 {v10.4s}, [x9], 16 ++ ST1 {v11.4s}, [x9] ++ ST1 {v12.4s}, [x10], 16 ++ ST1 {v13.4s}, [x10] ++ ST1 {v14.4s}, [x8], 16 ++ ST1 {v15.4s}, [x8] ++ ST1 {v16.4s}, [x12], 16 ++ ST1 {v17.4s}, [x12] ++ ST1 {v18.4s}, [x13], 16 ++ ST1 {v19.4s}, [x13] ++ ST1 {v20.4s}, [x14], 16 ++ ST1 {v21.4s}, [x14] ++ ST1 {v22.4s}, [x15], 16 ++ ST1 {v23.4s}, [x15] ++ ++ LDP d9, d8, [sp, -64] ++ LDP d11, d10, [sp, -48] ++ LDP d13, d12, [sp, -32] ++ LDP d15, d14, [sp, -16] ++ ++ RET ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_3 ++_4_w32: ++ CMP x11, 4 ++ B.LO _5_w32 ++ ++ ST1 {v8.4s}, [x7], 16 ++ ST1 {v10.4s}, [x9], 16 ++ ST1 {v12.4s}, [x10], 16 ++ ST1 {v14.4s}, [x8], 16 ++ ST1 {v16.4s}, [x12], 16 ++ ST1 
{v18.4s}, [x13], 16 ++ ST1 {v20.4s}, [x14], 16 ++ ST1 {v22.4s}, [x15], 16 ++ ++ SUB x11, x11, 4 ++ ++ MOV v8.16b, v9.16b ++ MOV v10.16b, v11.16b ++ MOV v12.16b, v13.16b ++ MOV v14.16b, v15.16b ++ MOV v16.16b, v17.16b ++ MOV v18.16b, v19.16b ++ MOV v20.16b, v21.16b ++ MOV v22.16b, v23.16b ++ ++_5_w32: ++ CMP x11, 2 ++ B.LO _6_w32 ++ ++ ST1 {v8.2s}, [x7], 8 ++ ST1 {v10.2s}, [x9], 8 ++ ST1 {v12.2s}, [x10], 8 ++ ST1 {v14.2s}, [x8], 8 ++ ST1 {v16.2s}, [x12], 8 ++ ST1 {v18.2s}, [x13], 8 ++ ST1 {v20.2s}, [x14], 8 ++ ST1 {v22.2s}, [x15], 8 ++ ++ SUB x11, x11, 2 ++ ++ EXT v8.16b, v8.16b, v8.16b, 8 ++ EXT v10.16b, v10.16b, v10.16b, 8 ++ EXT v12.16b, v12.16b, v12.16b, 8 ++ EXT v14.16b, v14.16b, v14.16b, 8 ++ EXT v16.16b, v16.16b, v16.16b, 8 ++ EXT v18.16b, v18.16b, v18.16b, 8 ++ EXT v20.16b, v20.16b, v20.16b, 8 ++ EXT v22.16b, v22.16b, v22.16b, 8 ++ ++_6_w32: ++ CMP x11, 1 ++ B.LO _7_w32 ++ ++ ST1 {v8.s}[0], [x7] ++ ST1 {v10.s}[0], [x9] ++ ST1 {v12.s}[0], [x10] ++ ST1 {v14.s}[0], [x8] ++ ST1 {v16.s}[0], [x12] ++ ST1 {v18.s}[0], [x13] ++ ST1 {v20.s}[0], [x14] ++ ST1 {v22.s}[0], [x15] ++ ++_7_w32: ++ LDP d9, d8, [sp, -64] ++ LDP d11, d10, [sp, -48] ++ LDP d13, d12, [sp, -32] ++ LDP d15, d14, [sp, -16] ++ ++ RET ++ ++END_FUNCTION pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w32__aarch64_neon ++ + + # void pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w16__aarch64_neon( + # size_t mr, +@@ -445,7 +432,374 @@ + # size_t c_stride, + # size_t output_channel_index, + # const union pytorch_qnnp_conv_dynamic_quantization_params quantization_params[restrict static 1]) +-MAKE_PYTORCH_Q8GEMM_DQ_SPARSE_8X1_UKERNEL_8X8_PACKEDA__AARCH64_NEON(16, #2, #1, LDRH) ++BEGIN_FUNCTION pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w16__aarch64_neon ++ ++ STP d15, d14, [sp, -16] ++ STP d13, d12, [sp, -32] ++ STP d11, d10, [sp, -48] ++ STP d9, d8, [sp, -64] ++ ++ MOV x11, x1 ++ /* Load output channel index */ ++ LDR x10, [sp, 8] ++ /* Load params */ ++ LDR x8, [sp, 16] ++ ++ /* Load a_zero_point */ ++ LD1R {v24.8b}, [x8] ++ ADD x8, x8, 8 ++ ++ /* Load pointer to per channel zero points array */ ++ LDR x17, [x8], 8 ++ ++ /* Load pointer to per channel multiplier */ ++ LDR x13, [x8] ++ ++ /* Add offset to the base pointer */ ++ ADD x17, x17, x10 ++ /* Mul by 4 to get byte offset for multiplier */ ++ LSL x10, x10, 2 ++ /* Add offset to the base pointer for multiplier */ ++ ADD x13, x13, x10 ++ ++ /* Load b_zero_point */ ++ LD1 {v25.8b}, [x17] ++ /* Load multiplier c0123 */ ++ LD1 {v26.4s}, [x13], 16 ++ /* Load multiplier c4567 */ ++ LD1 {v30.4s}, [x13] ++ ++ EOR x12, x12, x12 ++ EOR x13, x13, x13 ++ ++ EOR v8.16b, v8.16b, v8.16b ++ EOR v9.16b, v9.16b, v9.16b ++ EOR v10.16b, v10.16b, v10.16b ++ EOR v11.16b, v11.16b, v11.16b ++ EOR v12.16b, v12.16b, v12.16b ++ EOR v13.16b, v13.16b, v13.16b ++ EOR v14.16b, v14.16b, v14.16b ++ EOR v15.16b, v15.16b, v15.16b ++ EOR v16.16b, v16.16b, v16.16b ++ EOR v17.16b, v17.16b, v17.16b ++ EOR v18.16b, v18.16b, v18.16b ++ EOR v19.16b, v19.16b, v19.16b ++ EOR v20.16b, v20.16b, v20.16b ++ EOR v21.16b, v21.16b, v21.16b ++ EOR v22.16b, v22.16b, v22.16b ++ EOR v23.16b, v23.16b, v23.16b ++ ++ /* w12 = w_row_ptr[n], x13 = w_row_ptr[n+1] */ ++ /* x4 = x4 + #2 to point to next n */ ++ LDRH w12, [x4], #2 ++ LDRH w13, [x4] ++ /* x10 = temp_packed_w = packed_w + w_row_ptr[n] * 8 */ ++ /* This points to the first block of nonzero value */ ++ /* for the nth row. 
*/ ++ ADD x10, x3, x12, LSL #3 ++ /* x9 = temp_w_block_ids_ptr = w_block_ids_ptr (x5) + w_row_ptr[n] */ ++ /* LSL for when elements are >1 byte */ ++ /* (4 bytes: LSL #2, 2 bytes: LSL #1, 1 byte: LSL #0) */ ++ /* This points to the block id of the first block */ ++ /* It should contain x13 - x12 number of block ids */ ++ ADD x9, x5, x12, LSL #1 ++ /* x8 = num_blocks that needs to be processed */ ++ SUB x8, x13, x12 ++ SUBS x8, x8, 2 ++ B.LO _1_w16 ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_5 ++k_loop_w16: ++ /* k_loop processes two k values */ ++ /* Load two 8x1 blocks */ ++ LD1 {v0.8b}, [x10], 8 ++ LD1 {v1.8b}, [x10], 8 ++ USUBL v0.8h, v0.8b, v25.8b ++ USUBL v1.8h, v1.8b, v25.8b ++ ++ /* x12 = block_id_ptr[0] */ ++ /* x13 = block_id_ptr[1] */ ++ LDRH w12, [x9], #2 ++ LDRH w13, [x9], #2 ++ /* Add offset to x2 */ ++ /* Shift by 3 because each packed block is a block of 8x1 */ ++ /* which 8 bytes */ ++ ADD x16, x2, x12, LSL #3 ++ ADD x17, x2, x13, LSL #3 ++ ++ /* Load two 8x1 blocks of activation */ ++ /* First 8x1 for first channel */ ++ /* second 8x1 for next channel */ ++ LD1 {v2.8b}, [x16] ++ LD1 {v3.8b}, [x17] ++ ++ USUBL v2.8h, v2.8b, v24.8b ++ USUBL v3.8h, v3.8b, v24.8b ++ ++ /* First channel */ ++ SMLAL v8.4s, v0.4h, v2.h[0] ++ SMLAL2 v9.4s, v0.8h, v2.h[0] ++ SMLAL v10.4s, v0.4h, v2.h[1] ++ SMLAL2 v11.4s, v0.8h, v2.h[1] ++ SMLAL v12.4s, v0.4h, v2.h[2] ++ SMLAL2 v13.4s, v0.8h, v2.h[2] ++ SMLAL v14.4s, v0.4h, v2.h[3] ++ SMLAL2 v15.4s, v0.8h, v2.h[3] ++ SMLAL v16.4s, v0.4h, v2.h[4] ++ SMLAL2 v17.4s, v0.8h, v2.h[4] ++ SMLAL v18.4s, v0.4h, v2.h[5] ++ SMLAL2 v19.4s, v0.8h, v2.h[5] ++ SMLAL v20.4s, v0.4h, v2.h[6] ++ SMLAL2 v21.4s, v0.8h, v2.h[6] ++ SMLAL v22.4s, v0.4h, v2.h[7] ++ SMLAL2 v23.4s, v0.8h, v2.h[7] ++ ++ SUBS x8, x8, 2 ++ /* Second channel */ ++ SMLAL v8.4s, v1.4h, v3.h[0] ++ SMLAL2 v9.4s, v1.8h, v3.h[0] ++ SMLAL v10.4s, v1.4h, v3.h[1] ++ SMLAL2 v11.4s, v1.8h, v3.h[1] ++ SMLAL v12.4s, v1.4h, v3.h[2] ++ SMLAL2 v13.4s, v1.8h, v3.h[2] ++ SMLAL v14.4s, v1.4h, v3.h[3] ++ SMLAL2 v15.4s, v1.8h, v3.h[3] ++ SMLAL v16.4s, v1.4h, v3.h[4] ++ SMLAL2 v17.4s, v1.8h, v3.h[4] ++ SMLAL v18.4s, v1.4h, v3.h[5] ++ SMLAL2 v19.4s, v1.8h, v3.h[5] ++ SMLAL v20.4s, v1.4h, v3.h[6] ++ SMLAL2 v21.4s, v1.8h, v3.h[6] ++ SMLAL v22.4s, v1.4h, v3.h[7] ++ SMLAL2 v23.4s, v1.8h, v3.h[7] ++ ++ B.HS k_loop_w16 ++ ++_1_w16: ++ CMP x8, -2 ++ B.EQ _3_w16 ++ ++ LD1 {v0.8b}, [x10] ++ USUBL v0.8h, v0.8b, v25.8b ++ ++ /* x12 = block_id_ptr[0] */ ++ LDRH w12, [x9] ++ /* Add offset to x2 */ ++ ADD x16, x2, x12, LSL #3 ++ ++ LD1 {v2.8b}, [x16] ++ USUBL v2.8h, v2.8b, v24.8b ++ ++ SMLAL v8.4s, v0.4h, v2.h[0] ++ SMLAL2 v9.4s, v0.8h, v2.h[0] ++ SMLAL v10.4s, v0.4h, v2.h[1] ++ SMLAL2 v11.4s, v0.8h, v2.h[1] ++ SMLAL v12.4s, v0.4h, v2.h[2] ++ SMLAL2 v13.4s, v0.8h, v2.h[2] ++ SMLAL v14.4s, v0.4h, v2.h[3] ++ SMLAL2 v15.4s, v0.8h, v2.h[3] ++ SMLAL v16.4s, v0.4h, v2.h[4] ++ SMLAL2 v17.4s, v0.8h, v2.h[4] ++ SMLAL v18.4s, v0.4h, v2.h[5] ++ SMLAL2 v19.4s, v0.8h, v2.h[5] ++ SMLAL v20.4s, v0.4h, v2.h[6] ++ SMLAL2 v21.4s, v0.8h, v2.h[6] ++ SMLAL v22.4s, v0.4h, v2.h[7] ++ SMLAL2 v23.4s, v0.8h, v2.h[7] ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_4 ++_3_w16: ++ /* row 0: v8, v9 */ ++ /* row 1: v10, v11 */ ++ /* row 2: v12, v13 */ ++ /* row 3: v14, v15 */ ++ /* row 4: v16, v17 */ ++ /* row 5: v18, v19 */ ++ /* row 6: v20, v21 */ ++ /* row 7: v22, v23 */ ++ ++ /* Load c_stride & params */ ++ LDR x16, [sp] ++ LSL x16, x16, 2 ++ LD1 {v24.4s}, [x6], 16 ++ LD1 {v25.4s}, [x6] ++ ++ SCVTF v8.4s, v8.4s ++ SCVTF v9.4s, v9.4s ++ SCVTF v10.4s, v10.4s 
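++                                      /* Note: this SCVTF/FMUL/FADD sequence dequantizes the int32 */
++                                      /* accumulators: SCVTF converts each lane to float, FMUL */
++                                      /* applies the per-channel multipliers (v26/v30), and FADD */
++                                      /* adds the float bias loaded from b (x6). Roughly, per */
++                                      /* element: c[m][n] = (float)acc[m][n] * multiplier[n] + bias[n] */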
++ SCVTF v11.4s, v11.4s ++ SCVTF v12.4s, v12.4s ++ SCVTF v13.4s, v13.4s ++ SCVTF v14.4s, v14.4s ++ SCVTF v15.4s, v15.4s ++ SCVTF v16.4s, v16.4s ++ SCVTF v17.4s, v17.4s ++ SCVTF v18.4s, v18.4s ++ SCVTF v19.4s, v19.4s ++ SCVTF v20.4s, v20.4s ++ SCVTF v21.4s, v21.4s ++ SCVTF v22.4s, v22.4s ++ SCVTF v23.4s, v23.4s ++ ++ FMUL v8.4s, v8.4s, v26.4s ++ FMUL v9.4s, v9.4s, v30.4s ++ FMUL v10.4s, v10.4s, v26.4s ++ FMUL v11.4s, v11.4s, v30.4s ++ FMUL v12.4s, v12.4s, v26.4s ++ FMUL v13.4s, v13.4s, v30.4s ++ FMUL v14.4s, v14.4s, v26.4s ++ FMUL v15.4s, v15.4s, v30.4s ++ FMUL v16.4s, v16.4s, v26.4s ++ FMUL v17.4s, v17.4s, v30.4s ++ FMUL v18.4s, v18.4s, v26.4s ++ FMUL v19.4s, v19.4s, v30.4s ++ FMUL v20.4s, v20.4s, v26.4s ++ FMUL v21.4s, v21.4s, v30.4s ++ FMUL v22.4s, v22.4s, v26.4s ++ FMUL v23.4s, v23.4s, v30.4s ++ ++ FADD v8.4s, v8.4s, v24.4s ++ FADD v9.4s, v9.4s, v25.4s ++ FADD v10.4s, v10.4s, v24.4s ++ FADD v11.4s, v11.4s, v25.4s ++ FADD v12.4s, v12.4s, v24.4s ++ FADD v13.4s, v13.4s, v25.4s ++ FADD v14.4s, v14.4s, v24.4s ++ FADD v15.4s, v15.4s, v25.4s ++ FADD v16.4s, v16.4s, v24.4s ++ FADD v17.4s, v17.4s, v25.4s ++ FADD v18.4s, v18.4s, v24.4s ++ FADD v19.4s, v19.4s, v25.4s ++ FADD v20.4s, v20.4s, v24.4s ++ FADD v21.4s, v21.4s, v25.4s ++ FADD v22.4s, v22.4s, v24.4s ++ FADD v23.4s, v23.4s, v25.4s ++ ++ /* Compute c0-c7 */ ++ ++ ADD x9, x7, x16 ++ CMP x0, 2 ++ CSEL x9, x7, x9, LO ++ ++ ADD x10, x9, x16 ++ CSEL x10, x9, x10, LS ++ ++ ADD x8, x10, x16 ++ CMP x0, 4 ++ CSEL x8, x10, x8, LO ++ ++ ADD x12, x8, x16 ++ CSEL x12, x8, x12, LS ++ ++ ADD x13, x12, x16 ++ CMP x0, 6 ++ CSEL x13, x12, x13, LO ++ ++ ADD x14, x13, x16 ++ CSEL x14, x13, x14, LS ++ ++ ADD x15, x14, x16 ++ CMP x0, 8 ++ CSEL x15, x14, x15, NE ++ ++ CMP x11, 8 ++ B.NE _4_w16 ++ ++ ST1 {v8.4s}, [x7], 16 ++ ST1 {v9.4s}, [x7] ++ ST1 {v10.4s}, [x9], 16 ++ ST1 {v11.4s}, [x9] ++ ST1 {v12.4s}, [x10], 16 ++ ST1 {v13.4s}, [x10] ++ ST1 {v14.4s}, [x8], 16 ++ ST1 {v15.4s}, [x8] ++ ST1 {v16.4s}, [x12], 16 ++ ST1 {v17.4s}, [x12] ++ ST1 {v18.4s}, [x13], 16 ++ ST1 {v19.4s}, [x13] ++ ST1 {v20.4s}, [x14], 16 ++ ST1 {v21.4s}, [x14] ++ ST1 {v22.4s}, [x15], 16 ++ ST1 {v23.4s}, [x15] ++ ++ LDP d9, d8, [sp, -64] ++ LDP d11, d10, [sp, -48] ++ LDP d13, d12, [sp, -32] ++ LDP d15, d14, [sp, -16] ++ ++ RET ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_3 ++_4_w16: ++ CMP x11, 4 ++ B.LO _5_w16 ++ ++ ST1 {v8.4s}, [x7], 16 ++ ST1 {v10.4s}, [x9], 16 ++ ST1 {v12.4s}, [x10], 16 ++ ST1 {v14.4s}, [x8], 16 ++ ST1 {v16.4s}, [x12], 16 ++ ST1 {v18.4s}, [x13], 16 ++ ST1 {v20.4s}, [x14], 16 ++ ST1 {v22.4s}, [x15], 16 ++ ++ SUB x11, x11, 4 ++ ++ MOV v8.16b, v9.16b ++ MOV v10.16b, v11.16b ++ MOV v12.16b, v13.16b ++ MOV v14.16b, v15.16b ++ MOV v16.16b, v17.16b ++ MOV v18.16b, v19.16b ++ MOV v20.16b, v21.16b ++ MOV v22.16b, v23.16b ++ ++_5_w16: ++ CMP x11, 2 ++ B.LO _6_w16 ++ ++ ST1 {v8.2s}, [x7], 8 ++ ST1 {v10.2s}, [x9], 8 ++ ST1 {v12.2s}, [x10], 8 ++ ST1 {v14.2s}, [x8], 8 ++ ST1 {v16.2s}, [x12], 8 ++ ST1 {v18.2s}, [x13], 8 ++ ST1 {v20.2s}, [x14], 8 ++ ST1 {v22.2s}, [x15], 8 ++ ++ SUB x11, x11, 2 ++ ++ EXT v8.16b, v8.16b, v8.16b, 8 ++ EXT v10.16b, v10.16b, v10.16b, 8 ++ EXT v12.16b, v12.16b, v12.16b, 8 ++ EXT v14.16b, v14.16b, v14.16b, 8 ++ EXT v16.16b, v16.16b, v16.16b, 8 ++ EXT v18.16b, v18.16b, v18.16b, 8 ++ EXT v20.16b, v20.16b, v20.16b, 8 ++ EXT v22.16b, v22.16b, v22.16b, 8 ++ ++_6_w16: ++ CMP x11, 1 ++ B.LO _7_w16 ++ ++ ST1 {v8.s}[0], [x7] ++ ST1 {v10.s}[0], [x9] ++ ST1 {v12.s}[0], [x10] ++ ST1 {v14.s}[0], [x8] ++ ST1 {v16.s}[0], [x12] ++ ST1 {v18.s}[0], [x13] ++ ST1 {v20.s}[0], 
[x14] ++ ST1 {v22.s}[0], [x15] ++ ++_7_w16: ++ LDP d9, d8, [sp, -64] ++ LDP d11, d10, [sp, -48] ++ LDP d13, d12, [sp, -32] ++ LDP d15, d14, [sp, -16] ++ ++ RET ++ ++END_FUNCTION pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w16__aarch64_neon + + # void pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w8__aarch64_neon( + # size_t mr, +@@ -459,7 +813,374 @@ + # size_t c_stride, + # size_t output_channel_index, + # const union pytorch_qnnp_conv_dynamic_quantization_params quantization_params[restrict static 1]) +-MAKE_PYTORCH_Q8GEMM_DQ_SPARSE_8X1_UKERNEL_8X8_PACKEDA__AARCH64_NEON(8, #1, #0, LDRB) ++BEGIN_FUNCTION pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w8__aarch64_neon ++ ++ STP d15, d14, [sp, -16] ++ STP d13, d12, [sp, -32] ++ STP d11, d10, [sp, -48] ++ STP d9, d8, [sp, -64] ++ ++ MOV x11, x1 ++ /* Load output channel index */ ++ LDR x10, [sp, 8] ++ /* Load params */ ++ LDR x8, [sp, 16] ++ ++ /* Load a_zero_point */ ++ LD1R {v24.8b}, [x8] ++ ADD x8, x8, 8 ++ ++ /* Load pointer to per channel zero points array */ ++ LDR x17, [x8], 8 ++ ++ /* Load pointer to per channel multiplier */ ++ LDR x13, [x8] ++ ++ /* Add offset to the base pointer */ ++ ADD x17, x17, x10 ++ /* Mul by 4 to get byte offset for multiplier */ ++ LSL x10, x10, 2 ++ /* Add offset to the base pointer for multiplier */ ++ ADD x13, x13, x10 ++ ++ /* Load b_zero_point */ ++ LD1 {v25.8b}, [x17] ++ /* Load multiplier c0123 */ ++ LD1 {v26.4s}, [x13], 16 ++ /* Load multiplier c4567 */ ++ LD1 {v30.4s}, [x13] ++ ++ EOR x12, x12, x12 ++ EOR x13, x13, x13 ++ ++ EOR v8.16b, v8.16b, v8.16b ++ EOR v9.16b, v9.16b, v9.16b ++ EOR v10.16b, v10.16b, v10.16b ++ EOR v11.16b, v11.16b, v11.16b ++ EOR v12.16b, v12.16b, v12.16b ++ EOR v13.16b, v13.16b, v13.16b ++ EOR v14.16b, v14.16b, v14.16b ++ EOR v15.16b, v15.16b, v15.16b ++ EOR v16.16b, v16.16b, v16.16b ++ EOR v17.16b, v17.16b, v17.16b ++ EOR v18.16b, v18.16b, v18.16b ++ EOR v19.16b, v19.16b, v19.16b ++ EOR v20.16b, v20.16b, v20.16b ++ EOR v21.16b, v21.16b, v21.16b ++ EOR v22.16b, v22.16b, v22.16b ++ EOR v23.16b, v23.16b, v23.16b ++ ++ /* w12 = w_row_ptr[n], x13 = w_row_ptr[n+1] */ ++ /* x4 = x4 + #1 to point to next n */ ++ LDRB w12, [x4], #1 ++ LDRB w13, [x4] ++ /* x10 = temp_packed_w = packed_w + w_row_ptr[n] * 8 */ ++ /* This points to the first block of nonzero value */ ++ /* for the nth row. 
*/ ++ ADD x10, x3, x12, LSL #3 ++ /* x9 = temp_w_block_ids_ptr = w_block_ids_ptr (x5) + w_row_ptr[n] */ ++ /* LSL for when elements are >1 byte */ ++ /* (4 bytes: LSL #2, 2 bytes: LSL #1, 1 byte: LSL #0) */ ++ /* This points to the block id of the first block */ ++ /* It should contain x13 - x12 number of block ids */ ++ ADD x9, x5, x12, LSL #0 ++ /* x8 = num_blocks that needs to be processed */ ++ SUB x8, x13, x12 ++ SUBS x8, x8, 2 ++ B.LO _1_w8 ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_5 ++k_loop_w8: ++ /* k_loop processes two k values */ ++ /* Load two 8x1 blocks */ ++ LD1 {v0.8b}, [x10], 8 ++ LD1 {v1.8b}, [x10], 8 ++ USUBL v0.8h, v0.8b, v25.8b ++ USUBL v1.8h, v1.8b, v25.8b ++ ++ /* x12 = block_id_ptr[0] */ ++ /* x13 = block_id_ptr[1] */ ++ LDRB w12, [x9], #1 ++ LDRB w13, [x9], #1 ++ /* Add offset to x2 */ ++ /* Shift by 3 because each packed block is a block of 8x1 */ ++ /* which 8 bytes */ ++ ADD x16, x2, x12, LSL #3 ++ ADD x17, x2, x13, LSL #3 ++ ++ /* Load two 8x1 blocks of activation */ ++ /* First 8x1 for first channel */ ++ /* second 8x1 for next channel */ ++ LD1 {v2.8b}, [x16] ++ LD1 {v3.8b}, [x17] ++ ++ USUBL v2.8h, v2.8b, v24.8b ++ USUBL v3.8h, v3.8b, v24.8b ++ ++ /* First channel */ ++ SMLAL v8.4s, v0.4h, v2.h[0] ++ SMLAL2 v9.4s, v0.8h, v2.h[0] ++ SMLAL v10.4s, v0.4h, v2.h[1] ++ SMLAL2 v11.4s, v0.8h, v2.h[1] ++ SMLAL v12.4s, v0.4h, v2.h[2] ++ SMLAL2 v13.4s, v0.8h, v2.h[2] ++ SMLAL v14.4s, v0.4h, v2.h[3] ++ SMLAL2 v15.4s, v0.8h, v2.h[3] ++ SMLAL v16.4s, v0.4h, v2.h[4] ++ SMLAL2 v17.4s, v0.8h, v2.h[4] ++ SMLAL v18.4s, v0.4h, v2.h[5] ++ SMLAL2 v19.4s, v0.8h, v2.h[5] ++ SMLAL v20.4s, v0.4h, v2.h[6] ++ SMLAL2 v21.4s, v0.8h, v2.h[6] ++ SMLAL v22.4s, v0.4h, v2.h[7] ++ SMLAL2 v23.4s, v0.8h, v2.h[7] ++ ++ SUBS x8, x8, 2 ++ /* Second channel */ ++ SMLAL v8.4s, v1.4h, v3.h[0] ++ SMLAL2 v9.4s, v1.8h, v3.h[0] ++ SMLAL v10.4s, v1.4h, v3.h[1] ++ SMLAL2 v11.4s, v1.8h, v3.h[1] ++ SMLAL v12.4s, v1.4h, v3.h[2] ++ SMLAL2 v13.4s, v1.8h, v3.h[2] ++ SMLAL v14.4s, v1.4h, v3.h[3] ++ SMLAL2 v15.4s, v1.8h, v3.h[3] ++ SMLAL v16.4s, v1.4h, v3.h[4] ++ SMLAL2 v17.4s, v1.8h, v3.h[4] ++ SMLAL v18.4s, v1.4h, v3.h[5] ++ SMLAL2 v19.4s, v1.8h, v3.h[5] ++ SMLAL v20.4s, v1.4h, v3.h[6] ++ SMLAL2 v21.4s, v1.8h, v3.h[6] ++ SMLAL v22.4s, v1.4h, v3.h[7] ++ SMLAL2 v23.4s, v1.8h, v3.h[7] ++ ++ B.HS k_loop_w8 ++ ++_1_w8: ++ CMP x8, -2 ++ B.EQ _3_w8 ++ ++ LD1 {v0.8b}, [x10] ++ USUBL v0.8h, v0.8b, v25.8b ++ ++ /* x12 = block_id_ptr[0] */ ++ LDRB w12, [x9] ++ /* Add offset to x2 */ ++ ADD x16, x2, x12, LSL #3 ++ ++ LD1 {v2.8b}, [x16] ++ USUBL v2.8h, v2.8b, v24.8b ++ ++ SMLAL v8.4s, v0.4h, v2.h[0] ++ SMLAL2 v9.4s, v0.8h, v2.h[0] ++ SMLAL v10.4s, v0.4h, v2.h[1] ++ SMLAL2 v11.4s, v0.8h, v2.h[1] ++ SMLAL v12.4s, v0.4h, v2.h[2] ++ SMLAL2 v13.4s, v0.8h, v2.h[2] ++ SMLAL v14.4s, v0.4h, v2.h[3] ++ SMLAL2 v15.4s, v0.8h, v2.h[3] ++ SMLAL v16.4s, v0.4h, v2.h[4] ++ SMLAL2 v17.4s, v0.8h, v2.h[4] ++ SMLAL v18.4s, v0.4h, v2.h[5] ++ SMLAL2 v19.4s, v0.8h, v2.h[5] ++ SMLAL v20.4s, v0.4h, v2.h[6] ++ SMLAL2 v21.4s, v0.8h, v2.h[6] ++ SMLAL v22.4s, v0.4h, v2.h[7] ++ SMLAL2 v23.4s, v0.8h, v2.h[7] ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_4 ++_3_w8: ++ /* row 0: v8, v9 */ ++ /* row 1: v10, v11 */ ++ /* row 2: v12, v13 */ ++ /* row 3: v14, v15 */ ++ /* row 4: v16, v17 */ ++ /* row 5: v18, v19 */ ++ /* row 6: v20, v21 */ ++ /* row 7: v22, v23 */ ++ ++ /* Load c_stride & params */ ++ LDR x16, [sp] ++ LSL x16, x16, 2 ++ LD1 {v24.4s}, [x6], 16 ++ LD1 {v25.4s}, [x6] ++ ++ SCVTF v8.4s, v8.4s ++ SCVTF v9.4s, v9.4s ++ SCVTF v10.4s, v10.4s ++ 
SCVTF v11.4s, v11.4s ++ SCVTF v12.4s, v12.4s ++ SCVTF v13.4s, v13.4s ++ SCVTF v14.4s, v14.4s ++ SCVTF v15.4s, v15.4s ++ SCVTF v16.4s, v16.4s ++ SCVTF v17.4s, v17.4s ++ SCVTF v18.4s, v18.4s ++ SCVTF v19.4s, v19.4s ++ SCVTF v20.4s, v20.4s ++ SCVTF v21.4s, v21.4s ++ SCVTF v22.4s, v22.4s ++ SCVTF v23.4s, v23.4s ++ ++ FMUL v8.4s, v8.4s, v26.4s ++ FMUL v9.4s, v9.4s, v30.4s ++ FMUL v10.4s, v10.4s, v26.4s ++ FMUL v11.4s, v11.4s, v30.4s ++ FMUL v12.4s, v12.4s, v26.4s ++ FMUL v13.4s, v13.4s, v30.4s ++ FMUL v14.4s, v14.4s, v26.4s ++ FMUL v15.4s, v15.4s, v30.4s ++ FMUL v16.4s, v16.4s, v26.4s ++ FMUL v17.4s, v17.4s, v30.4s ++ FMUL v18.4s, v18.4s, v26.4s ++ FMUL v19.4s, v19.4s, v30.4s ++ FMUL v20.4s, v20.4s, v26.4s ++ FMUL v21.4s, v21.4s, v30.4s ++ FMUL v22.4s, v22.4s, v26.4s ++ FMUL v23.4s, v23.4s, v30.4s ++ ++ FADD v8.4s, v8.4s, v24.4s ++ FADD v9.4s, v9.4s, v25.4s ++ FADD v10.4s, v10.4s, v24.4s ++ FADD v11.4s, v11.4s, v25.4s ++ FADD v12.4s, v12.4s, v24.4s ++ FADD v13.4s, v13.4s, v25.4s ++ FADD v14.4s, v14.4s, v24.4s ++ FADD v15.4s, v15.4s, v25.4s ++ FADD v16.4s, v16.4s, v24.4s ++ FADD v17.4s, v17.4s, v25.4s ++ FADD v18.4s, v18.4s, v24.4s ++ FADD v19.4s, v19.4s, v25.4s ++ FADD v20.4s, v20.4s, v24.4s ++ FADD v21.4s, v21.4s, v25.4s ++ FADD v22.4s, v22.4s, v24.4s ++ FADD v23.4s, v23.4s, v25.4s ++ ++ /* Compute c0-c7 */ ++ ++ ADD x9, x7, x16 ++ CMP x0, 2 ++ CSEL x9, x7, x9, LO ++ ++ ADD x10, x9, x16 ++ CSEL x10, x9, x10, LS ++ ++ ADD x8, x10, x16 ++ CMP x0, 4 ++ CSEL x8, x10, x8, LO ++ ++ ADD x12, x8, x16 ++ CSEL x12, x8, x12, LS ++ ++ ADD x13, x12, x16 ++ CMP x0, 6 ++ CSEL x13, x12, x13, LO ++ ++ ADD x14, x13, x16 ++ CSEL x14, x13, x14, LS ++ ++ ADD x15, x14, x16 ++ CMP x0, 8 ++ CSEL x15, x14, x15, NE ++ ++ CMP x11, 8 ++ B.NE _4_w8 ++ ++ ST1 {v8.4s}, [x7], 16 ++ ST1 {v9.4s}, [x7] ++ ST1 {v10.4s}, [x9], 16 ++ ST1 {v11.4s}, [x9] ++ ST1 {v12.4s}, [x10], 16 ++ ST1 {v13.4s}, [x10] ++ ST1 {v14.4s}, [x8], 16 ++ ST1 {v15.4s}, [x8] ++ ST1 {v16.4s}, [x12], 16 ++ ST1 {v17.4s}, [x12] ++ ST1 {v18.4s}, [x13], 16 ++ ST1 {v19.4s}, [x13] ++ ST1 {v20.4s}, [x14], 16 ++ ST1 {v21.4s}, [x14] ++ ST1 {v22.4s}, [x15], 16 ++ ST1 {v23.4s}, [x15] ++ ++ LDP d9, d8, [sp, -64] ++ LDP d11, d10, [sp, -48] ++ LDP d13, d12, [sp, -32] ++ LDP d15, d14, [sp, -16] ++ ++ RET ++ ++ NDEF_IGNORE_CODE_ALIGN_DIRECTIVES_P2ALIGN_3 ++_4_w8: ++ CMP x11, 4 ++ B.LO _5_w8 ++ ++ ST1 {v8.4s}, [x7], 16 ++ ST1 {v10.4s}, [x9], 16 ++ ST1 {v12.4s}, [x10], 16 ++ ST1 {v14.4s}, [x8], 16 ++ ST1 {v16.4s}, [x12], 16 ++ ST1 {v18.4s}, [x13], 16 ++ ST1 {v20.4s}, [x14], 16 ++ ST1 {v22.4s}, [x15], 16 ++ ++ SUB x11, x11, 4 ++ ++ MOV v8.16b, v9.16b ++ MOV v10.16b, v11.16b ++ MOV v12.16b, v13.16b ++ MOV v14.16b, v15.16b ++ MOV v16.16b, v17.16b ++ MOV v18.16b, v19.16b ++ MOV v20.16b, v21.16b ++ MOV v22.16b, v23.16b ++ ++_5_w8: ++ CMP x11, 2 ++ B.LO _6_w8 ++ ++ ST1 {v8.2s}, [x7], 8 ++ ST1 {v10.2s}, [x9], 8 ++ ST1 {v12.2s}, [x10], 8 ++ ST1 {v14.2s}, [x8], 8 ++ ST1 {v16.2s}, [x12], 8 ++ ST1 {v18.2s}, [x13], 8 ++ ST1 {v20.2s}, [x14], 8 ++ ST1 {v22.2s}, [x15], 8 ++ ++ SUB x11, x11, 2 ++ ++ EXT v8.16b, v8.16b, v8.16b, 8 ++ EXT v10.16b, v10.16b, v10.16b, 8 ++ EXT v12.16b, v12.16b, v12.16b, 8 ++ EXT v14.16b, v14.16b, v14.16b, 8 ++ EXT v16.16b, v16.16b, v16.16b, 8 ++ EXT v18.16b, v18.16b, v18.16b, 8 ++ EXT v20.16b, v20.16b, v20.16b, 8 ++ EXT v22.16b, v22.16b, v22.16b, 8 ++ ++_6_w8: ++ CMP x11, 1 ++ B.LO _7_w8 ++ ++ ST1 {v8.s}[0], [x7] ++ ST1 {v10.s}[0], [x9] ++ ST1 {v12.s}[0], [x10] ++ ST1 {v14.s}[0], [x8] ++ ST1 {v16.s}[0], [x12] ++ ST1 {v18.s}[0], [x13] ++ ST1 {v20.s}[0], [x14] ++ ST1 
{v22.s}[0], [x15] ++ ++_7_w8: ++ LDP d9, d8, [sp, -64] ++ LDP d11, d10, [sp, -48] ++ LDP d13, d12, [sp, -32] ++ LDP d15, d14, [sp, -16] ++ ++ RET ++ ++END_FUNCTION pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w8__aarch64_neon + + #ifdef __ELF__ + .section ".note.GNU-stack","",%progbits From 81408976f5712b8269e101ccad5520325c4687f9 Mon Sep 17 00:00:00 2001 From: watanabek Date: Wed, 20 Mar 2024 16:05:02 +0900 Subject: [PATCH 15/16] modified build instructions --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index ddbf918..4b4ef97 100644 --- a/README.md +++ b/README.md @@ -156,6 +156,7 @@ $ sed -i -e "s/CFLAGS=-O3/WITH_BLAS=ssl2 CFLAGS='-O3 -Kopenmp'/g" 5_pytorch.sh ``` #### (11) 富士通コンパイラでのエラー回避のためpatch(pytorch21_q8gemm_sparse.ptach)を適用する。 +PATCH/ディレクトリに格納されているパッチファイル(pytorch21_q8gemm_sparse.ptach)をビルド環境にコピーする。 ``` $ pwd (somewhere)/ pytorch From 7070808708d11e45fefa8ca25571e37840c87a8a Mon Sep 17 00:00:00 2001 From: watanabek Date: Wed, 20 Mar 2024 21:02:59 +0900 Subject: [PATCH 16/16] fix a typo --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 4b4ef97..15622d8 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ## はじめに -本書では、「富岳」におけるAIフレームワークPyTorch v2のビルド手順および標準的なテストデータ(mnist)を用いた動作確認の手順について述べる。 +本書では、「富岳」におけるAIフレームワークPyTorch v2系のビルド手順および標準的なテストデータ(mnist)を用いた動作確認の手順について述べる。 ## AIプレームワークPyTorchのバージョンアップ @@ -21,7 +21,7 @@ ### ビルド環境の整備 -[200~Pytorch v2.1の「富岳」向けビルドでは、富士通Githubで公開されている” 富士通 Supercomputer PRIMEHPC FX1000/FX700 上の PyTorch 構築手順”から入手可能なPytorch v1.13.1向けのビルド用スクリプトを利用する。また、PyTorch v1.13.1における富士通言語環境向けの修正を取り込む。 +Pytorch v2.1の「富岳」向けビルドでは、富士通Githubで公開されている” 富士通 Supercomputer PRIMEHPC FX1000/FX700 上の PyTorch 構築手順”から入手可能なPytorch v1.13.1向けのビルド用スクリプトを利用する。また、PyTorch v1.13.1における富士通言語環境向けの修正を取り込む。 本作業においては、言語環境としてtcsds-1.2.38を用いた。 #### (1) 富士通GithubからPyTorchをクローンする。 @@ -159,7 +159,7 @@ $ sed -i -e "s/CFLAGS=-O3/WITH_BLAS=ssl2 CFLAGS='-O3 -Kopenmp'/g" 5_pytorch.sh PATCH/ディレクトリに格納されているパッチファイル(pytorch21_q8gemm_sparse.ptach)をビルド環境にコピーする。 ``` $ pwd -(somewhere)/ pytorch +(somewhere)/pytorch $ cd aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse $ patch -p1 -i (somewhere)/pytorch21_q8gemm_sparse.ptach ```
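
The hand-written `w16` and `w8` ukernels above are what `pytorch21_q8gemm_sparse.ptach` substitutes for the original `MAKE_PYTORCH_Q8GEMM_DQ_SPARSE_8X1_UKERNEL_8X8_PACKEDA__AARCH64_NEON(...)` macro invocations, expanded by hand to work around the Fujitsu compiler error noted in step (11); the two variants differ only in the width of the sparse-matrix indices (`LDRH`/`LSL #1` for 16-bit, `LDRB`/`LSL #0` for 8-bit). As a reading aid, below is a minimal C sketch of the arithmetic one call performs. All names and the flat argument list are illustrative assumptions, not the actual QNNPACK interface, and the mr/nr tail handling of the assembly is simplified away.

```
#include <stddef.h>
#include <stdint.h>

/* Illustrative model of one ukernel call (assumed names, not the real API).
 * Weights are stored as CSR-like 8x1 blocks: 8 output channels x 1 input
 * position. index_t is uint16_t for the w16 kernel, uint8_t for the w8 one. */
typedef uint16_t index_t;

static void q8gemm_dq_sparse_8x1_ref(
    size_t mr, size_t nr,             /* active rows / channels, <= 8     */
    const uint8_t* a_packed,          /* activations: 8 rows per k index  */
    const uint8_t* packed_w,          /* nonzero 8x1 weight blocks        */
    const index_t* w_row_ptr,         /* [0], [1] bound this block row    */
    const index_t* w_block_ids,       /* k index of each nonzero block    */
    const float* bias,                /* per-channel bias (8 entries)     */
    float* c, size_t c_stride,        /* output tile, stride in floats    */
    uint8_t a_zero_point,
    const uint8_t* w_zero_points,     /* per-channel weight zero points   */
    const float* multipliers)         /* per-channel requantization scale */
{
  int32_t acc[8][8] = {{0}};
  /* k_loop_*: accumulate over the nonzero blocks of this block row */
  for (index_t b = w_row_ptr[0]; b < w_row_ptr[1]; b++) {
    const uint8_t* w = packed_w + (size_t)b * 8;
    const uint8_t* a = a_packed + (size_t)w_block_ids[b] * 8;
    for (size_t m = 0; m < 8; m++)    /* SMLAL / SMLAL2 in the assembly   */
      for (size_t ch = 0; ch < 8; ch++)
        acc[m][ch] += ((int32_t)w[ch] - w_zero_points[ch]) *
                      ((int32_t)a[m] - a_zero_point);
  }
  /* dynamic dequantization: SCVTF -> FMUL -> FADD in the assembly */
  for (size_t m = 0; m < mr; m++)
    for (size_t ch = 0; ch < nr; ch++)
      c[m * c_stride + ch] = (float)acc[m][ch] * multipliers[ch] + bias[ch];
}
```

The expansion is intended to preserve the upstream computation unchanged; only the assembler-macro layer that triggers the error under the Fujitsu toolchain is removed, which is why the patched functions mirror the macro body instruction for instruction.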