PyTorchモデルをTorchScriptに変換する方法

  • このエントリーをはてなブックマークに追加

 機械学習モデルの構築はPythonを使うのが一般的です。しかしモバイルアプリやその他組み込み環境に機械学習ができるレベルのPython環境を整えるのはかなりシンドイと思います。ここではPyTorchのモデルをTorchScript形式に変換し、容易にデプロイができる状態を目指します。

こんにちは。wat(@watlablog)です。これまで色々とPyTorchで機械学習モデルを作ってきましたが、TorchScript形式に変換することでよりデプロイを意識します!

TorchScriptとは?

 TorchScriptは、PyTorchのモデルを中間表現Intermediate Representation: IR)に変換し、それをシリアライズ可能(保存可能)にすることで、PyTorchの実行環境(ランタイム)がある場所であればPythonなしでモデルをロード・推論できるようにします。

 例えばモバイルアプリのフレームワークであるFlutterはGoogleが開発したDart言語を使いますが、flutter_pytorchプラグインを使うことでTorchScript形式で書かれたpytorchモデルを読み込めるとのこと(ChatGPTより)。これは面白そうだと思ったので本当にできるものなのか後で検証しようと思います。

 別にモバイルに限らず、Webアプリのバックエンドで動作させるのも良いかも知れません。さらに、TorchScript形式のモデルを使うことで速度UPも期待できる[1]とのことです。TorchScript化する時に直面する課題もあると思うので、まずはやってみるというコンセプトでこの記事を書きました。

題材:2つのサイン波を分類する音声認識モデル

Advertisements

 この記事では誰でも簡単にできる題材として、100Hzと1000Hzの音声を分類する音声認識モデルを作りたいと思います。たった2つの周波数の音声であれば超シンプルなCNNモデルで表現できると思うので、機能調査には打って付けではないでしょうか。WATLABブログは音声系が得意なので、音声認識モデルを題材にしておけば今後Flutterでアプリ化する時にマイク入力等も一緒に学べて一石二鳥です。

サンプルwavファイル

 読者の皆様がすぐにテストできるように、100Hz.wavと1000Hz.wavをダウンロードできるようにしておきました。これもPythonで作っています。

機械学習モデルの作成

TorchScript無し

 まずはTorchScriptを使わず、これまで当ブログで書いてきた書き方で音声認識モデルを書きます。

学習コード

 次のコードが学習(トレーニング)を行うためのPythonコードです。 wavというフォルダの中に 100Hz.wav 1000Hz.wavを入れてコードを実行すると、プログラム実行フォルダの直下に audio_classifier.ptというPyTorchのモデルが保存されます。

テストコード

 そしてこちらが学習モデル audio_classifier.pt をロードして推論を行うテストコードです。コード内でwavファイルをパス指定して実行すると、そのファイルの音声が100Hzなのか1000Hzなのかを分類します。学習に使ったデータをそのまま読み込めば100%分類ができていることが確認できるでしょう。

TorchScript形式で機械学習モデルを出力するPythonコード

 それではTorchScript形式でモデルを保存するコードを書いて検証します。

学習コード

 TorchScriptには torchaudioの関数を含めることが可能です。今回のコードではデータの前処理として、 torchaudio.transformsでスペクトログラム変換をしています。せっかくモデルをTorchScript形式としてIRに落とし込んでも、モバイルアプリやWebアプリのバックエンドで複雑な前処理を書き直すのは面倒です。すべての関数が対応可能というわけではないようですが、できるだけ torchの機能でコード化しておけば、モデルにデータを投入するだけで結果を出すという非常にシンプルな構造にできます。

テストコード

 こちらがテストコードです。推論する関数にもはやネットワーク構造を記載する必要はなく、形を整えたwavファイルを入力するだけで分類ができます。

 特に画像等はないですが、コードを実行すると正確に分類できていることがわかりました。そして気持ち速度も上がっているようです。

 …書いていて思ったのですが、もしかして別に機械学習モデルに限らなくてもPyTorchの機能でPythonコードを書いてTorchScriptでモデルにすれば、色々な環境下でPythonコードを走らせることができる?

Pythonの強みである外部ライブラリは変換できないかもしれないけど、PyTorchには多くの便利な関数があるので、TorchScriptを使ったハック的なものができるかも!
後でやってみよう。

まとめ

 この記事ではPyTorchの機械学習モデルをさまざまなプラットフォーム上で利用できるTorchScript形式に変換する方法を紹介しました。IR(中間表現)にも torchaudioのスペクトログラム変換を含めることができることも確認しています。
 前処理をモデルに含めることができるとわかったので、今後試す予定のモバイルアプリやWebアプリ側での処理がかなり簡便になるはずです。次回はアプリ側で機械学習モデルを利用する方法について調査してみようと思います。

参考文献

[1] Qiita:TorchScriptを使用してPyTorchのモデルを保存する

TorchScriptすごい!!前処理もモデルに含めることができるのは知らなかった!
Xでも関連情報をつぶやいているので、wat(@watlablog)のフォローお待ちしています!

  • このエントリーをはてなブックマークに追加

SNSでもご購読できます。

コメントを残す

*