PyTorchモデルをcloudpickleで保存・読み込みする方法

  • このエントリーをはてなブックマークに追加

scikit-learnではpickleを使って学習済のモデルを保存したり読み込んだりできていましたが、PyTorchのモデルが読み込めない問題に直面したので解決方法をメモします。ここでは最も簡単だと感じたcloudpickleを使った方法を紹介します。

こんにちは。wat(@watlablog)です。PyTorchで学習したモデルを再利用する時にハマったので、解決方法をメモします

PyTorchのモデルをpickleで再利用しようとした時のエラー

エラー発生時の状況とエラーメッセージ

pickleを使った保存(シリアライズ)と読み込み(デシリアライズ)のコードは以下の通りです。

このコードを使ってPyTorchで学習したモデルを保存・読み込みすると、読み込み時に以下のエラーが出ます。今回は回帰モデルを使っていたので「Regression」属性が取得できないと言われているようです。

net = pickle.load(f)
AttributeError: Can't get attribute 'Regression' on <module 'main' from '

エラー内容

どうやら公式ドキュメントによると、PyTorchのテンソルオブジェクトはtorch.savetorch.loadを使うそうです。さらに、state_dict()を使う方法が推奨されているようです。おそらくこの辺を理解して使うのが王道でしょう。

Instead of saving a module directly, for compatibility reasons it is recommended to instead save only its state dict. Python modules even have a function, load_state_dict(), to restore their states from a state dict:

SERIALIZATION SEMANTICS

しかし、自分の理解がまだ曖昧で、うまく行かない場合がありました。そんな中、簡単に成功したのがcloudpickleでした。

pickleとcloudpickleについて

pickle

そもそもpickleとは、Python標準のライブラリの一つで、オブジェクトの保存と読み込みを行う事ができるものです。

pickleについては「Python機械学習済モデルをpickleで保存して復元する方法」という記事で使い方を説明していますので是非ご覧ください。

この記事ではscikit-learnという古典的な機械学習アルゴリズム(決定木やサポートベクターマシン等)がまとめられたライブラリを使ってモデルの保存をしていました。

cloudpickle

cloudpickleはPython標準ライブラリのpickleでサポートしていない構造体をシリアライズする目的で開発されたそうです。

公式としては以下のGithubページが見つかりました。全く同じバージョンのPython間でモデルをやりとりをする事のみに使えるという事と、信頼できるソースからのモデル以外を使うとセキュリティ上危険という事が注意点でしょうか。

cloudpickle makes it possible to serialize Python constructs not supported by the default pickle module from the Python standard library.
cloudpickle is especially useful for cluster computing where Python code is shipped over the network to execute on remote hosts, possibly close to the data.
Among other things, cloudpickle supports pickling for lambda functions along with functions and classes defined interactively in the main module (for instance in a script, a shell or a Jupyter notebook).
Cloudpickle can only be used to send objects between the exact same version of Python.
Using cloudpickle for long-term object storage is not supported and strongly discouraged.
Security notice: one should only load pickle data from trusted sources as otherwise pickle.load can lead to arbitrary code execution resulting in a critical security vulnerability.

Github:cloudpickle

僕のように自分のコンピュータのみでモデルを再利用したいという目的にはちょうど良いかも知れませんが、複数人、複数環境で開発をしている人にはちょっと向かないかもですね。

cloudpickleでPyTorchのネットワークモデルを保存・再利用するコード

トレーニングモデルを保存するコード例

以下にモデルを保存するコード例を示します。コピペで動作するはずです。主要部分は「cloudpickle.save」の部分ですが、コード全体の意味を知りたい方は「PyTorchで色々な非線形関数を回帰してみたらすごかった」をご覧ください。

他にも「PyTorchのネットワークモデルを使って線形回帰をする方法」、「PyTorchのネットワークモデルをクラスで書く時のメモ」を見る事でPyTorchによる回帰について参考になると思います。

上記コードを実行する事で、10000回の学習の後にmodel.ptが作成されます。
※.ptの拡張子は任意です。PyTorchのデータは慣例から.ptがよく使われると聞いた事があるために付けています。

コードを実行すると、model.ptが作成される他に以下のプロットも表示されます。ご参考。

cloudpickle.save

保存されたトレーニングモデルを読み込んで再利用するコード例

そして以下のコードで、作成されたmodel.ptを読み込んで使う事が可能です。

 

cloudpickle.loadで読み込んでいるだけですが、データを入力するだけでしっかりとモデルの応答曲面が計算されました。

cloudpickle.load

まとめ

今回はショートTips的な記事で、pickleで発生したエラーを解決するためにcloudpickleを紹介しました。

Pythonのバージョンを揃えるといった注意点はありますが、自分のPCで学習したモデルを使う場合には問題ないと思います。

Google ColaboratoryやGPU環境で学習したモデルを自分のPCで再利用…とかはもう少し調査をするか、素直にtorch.saveに慣れるのが良いかも知れません。

注意点に気をつければcloudpickleはすごく使いやすいです!
Twitterでも関連情報をつぶやいているので、wat(@watlablog)のフォローお待ちしています!

  • このエントリーをはてなブックマークに追加

SNSでもご購読できます。

コメントを残す

*