PyTorchのネットワークモデルをクラスで書く時のメモ

  • このエントリーをはてなブックマークに追加

PythonのディープラーニングフレームワークであるPyTorchはネットワークモデルをモジュール化して使うとわかりやすいコードになります。ここでは初心者向けにクラスの使い方と、簡単な線形ネットワークを例にモジュール化の方法を紹介します。

こんにちは。wat(@watlablog)です。ここではプログラミング初心者向けにクラスを使ったPyTorchネットワークモデルのモジュール化を説明します

本記事の対象者

他の人の書いたPyTorchコードを読みたい人…

ニューラルネットワーク、機械学習関係の参考コードはGitHub等でよく公開されていますが、プロの方は多くがネットワークをモジュール化(クラス化)した書き方でコードを書いていらっしゃいます。

有名モデル(VGG, ResNetとか)を始めとした他の人のネットワークモデルを参考にする時に、そもそもクラスベースで書いてあると、読み方を知らないともう退場するしかありません。

そのため、本記事は他の人の書いたPyTorchコードを読みたい人向けです。

Pythonのクラス記述を遠ざけて来た人…

ここではそもそもPythonのクラスの書き方が微妙、できれば使いたくない…という方を対象にしています。

かく言う僕もできればブログではクラス表記を避けたいとずっと思っていました。

当WATLABブログの記事はTips的な内容なので、そもそもクラス設計をする程のメリットは無く、クラス表記するとかえって初心者が記事を去ってしまうという懸念がありました。

大抵の内容はdef文があれば綺麗に書けますし、ちょっとしたコードであればそれで可読性も十分でしょう。

クラスのメリットは大規模なプログラムになった時に変数の数を減らす事ができたり、関数をまとめておく事でわかりやすくなったりする所にあります。

Pythonのクラスやオブジェクト指向の考え方、書き方については「Pythonのクラスの使い方とオブジェクト指向の考え方を理解する」をご覧下さい。

動作環境

このページのプログラムは以下のPC&Python環境で動作検証を行なっています(環境変更が面倒でずっとアップデートしていませんね…)。

Windows OS Windows10 64bit
CPU 2.4[GHz]
メモリ 4[GB]
Mac OS macOS Catalina 10.15.7
CPU 1.4[GHz]
メモリ 8[GB]
Python Python 3.7.7
PyCharm (IDE) PyCharm CE 2020.1
PyTorch torch==1.5.1

PyTorchの線形回帰ネットワークをモジュール化するコード

それではPyTorchネットワークのモジュール化(クラスで書く)について、順を追って説明していきます。

classを使わないで書いたコードとして「PyTorchのネットワークモデルを使って線形回帰をする方法」の物を題材とします。

import文

今回の例では以下のimportを行います。torch以外に個別にnnとoptimを書いているのは単純に毎回torch.nn等と書かなくて良くなるというだけです。

nn.Moduleを継承したclassを作成

今回、まずはLinearRegression()というclassを作ります。クラスは単語の区切りを大文字で分け、関数の場合は小文字とアンダーバーで分けるのが一般的な命名規則のようです。以下にクラス部分のコードを示します。

classの引数にnn.Moduleをとる事で、継承を行います。これでPyTorchのネットワークモデル全てを利用可能になります。

次にコンストラクタを設定します。クラスは定義したままではただの設計図であり、インスタンス(実体)を生成して使います。コンストラクタにはインスタンス生成時に最初に実行されるコードを書きます。

継承は色々なやり方がありますが、Python3ではsuper().__init__()を使って継承する事が推奨されています。僕を含めクラス初心者にこの書き方はかなり特殊と感じてしまうと思いますが、今はおまじないのように書いておきます。

最後にメソッドを設定していきます。メソッドにはforward()を設定し、ここにネットワークモデルをシーケンシャル(順番に実行されるように)に登録していきます。

○○.forward()とこのメソッドを実行する事でニューラルネットワークが順伝播するイメージです。

この例では線形回帰ネットワーク1つで説明するのでこれだけですが、このforwardメソッドに活性化関数や中間層を追加していきディープニューラルネットワークを作るのがディープラーニングのモデルとなります。

トレーニング用の関数を作成

次にトレーニング用の関数を定義します。内容は「PyTorchのネットワークモデルを使って線形回帰をする方法」で書いた物と同一です。

グラフ確認まで含めた全コード(コピペ用)

コピペで動作できるようにグラフ確認まで含めた全コードを以下に示します。

ネットワークをクラスでモジュール化していますが、トレーニング(学習)は別途関数で作成、最適化アルゴリズムや損失関数は関数の外で設定して引数として設定値を渡すようにしました。

これらをクラスに含めた方が良いのか、それとも外で設定した方が良いのかはまだわかっていません(書籍等でもこのように作っているようなので真似しました)。

また、前回の記事では回帰係数(model.parameter)から回帰直線を描いていましたが、学習したネットワークモデルは別途テストデータを使って出力を得るという使い方が機械学習の基本と思うため、今回は回帰係数を使わずにネットワークモデルにデータを渡して直線を描画するという内容にしました。

以下がコード実行結果です。モジュール化したネットワークモデルでも損失がしっかりと低下し、ネットワークに通したデータは見事回帰直線を描きました。

回帰結果

まとめ

PyTorchのネットワークモデルを使って線形回帰をする方法」で初歩的なPyTorchによる回帰の方法を学び、本記事でネットワークのモジュール化を行いました。

モジュール化はclassでnn.Moduleの継承を行いましたが、この方法であればネットワークの層構造がforwardメソッドにまとまるので可読性が上がると考えられます。

まだ線形回帰しかやっていませんが、仕組みを学ぶのにはちょうど良いレベルと思っています。

線形ネットワークを例題にPyTorchネットワークモデルのモジュール化を行いました!久々にクラスを使った感覚です!
Twitterでも関連情報をつぶやいているので、wat(@watlablog)のフォローお待ちしています!

SNSでもご購読できます。

コメント

コメントを残す

*