
機械学習を活用したアプリ開発を学ぶ第一歩として、PyTorchを使った手書き文字認識(MNIST)に挑戦します。本記事では、モデルの学習から検証までのコードを習得しましょう。MNISTの学習ができるようになれば、自分のデータで学習モデルを作成することもできるようになります。
こんにちは。wat(@watlablog)です。今回はN番煎じ的なMNISTデータセットを使ったCNNによる機械学習をやってみます!
はじめに
なぜいまさら手書き文字の学習?
生成AI群雄割拠の2025年にもなって手書き文字の学習を行うのは、正直イマサラ感があります。筆者watも別途音声認識の深層学習モデルを構築するといったことは経験していました。ただ、深層学習を活用したコードは書けるようになっても、アプリとして世に公開することは未経験です。
機械学習モデルの構築からフロントエンドの作成、デプロイまでの一連の流れを掴むために、最も基礎的な手書き文字認識をまずは記事化しようという意図でこの記事を書いています。
データセットが既に用意されている問題を例題にすれば、ブログにも書きやすいですしね!
MNISTとは?
MNISTとは、Modified National Institute of Standards and Technologyの略で、米国国勢調査局の職員や高校生が手で書いた数字をまとめたデータセットです。60000枚の訓練用画像と10000枚の評価用画像に分かれています。
データの例を示します。下図のように、正解ラベルと画像が対になった構成です。

この記事の範囲
本記事ではMNISTデータセットをダウンロードし、学習モデルを構築するコードをまず紹介します。そして、機械学習モデルを使って推論(分類)するコードも紹介します。アプリ化やデプロイはまた別の記事です。
動作環境
この記事では深層学習を扱いますが学習はCPUで行います。こちらで動作を確認しているPC環境とPython環境を以下に示します。
Mac | OS | macOS Sonoma 14.3 |
---|---|---|
チップ | Apple M3 | |
CPU | 1.4[GHz] | |
メモリ | 16[GB] |
Python | Python 3.12.3 |
---|---|
torch | 2.6.0 |
matplotlib | 3.10.1 |
torchvision | 0.21.0 |
Pythonコード:学習
学習はPyTorchで行います。次のコードを実行するとデータセットのダウンロードが始まり、学習まで完了します。CPUによる学習をしていますが、この程度のネットワークであれば上記PCスペックで1分もかからないくらいの学習時間です。
コードの中身解説
正規化
transforms.Normalize((0.1307,), (0.3081,)) の0.1307と0.3081はMNIST手書き数字データセットの平均値と標準偏差です。以下のページが参考になります。
https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457
CNNモデルの概要:畳み込み層とプーリング層
このコードの機械学習は深層学習の一種であるCNN(Convolutional Neural Network:畳み込みニューラルネットワーク)を使っています。ここでは各層の内容を簡単に紹介します。
ディープニューラルネットワークは class simpleCNNで定義しています。 nn.Sequentialで層を1層ずつ記述します。ここで nn.Conv2d(1, 16, kernel_size=3, padding=1)は最初のMNISTデータセットの画像(1チャネル)を16個のカーネルを使った畳み込み演算から特徴マップを抽出する部分です。
次の画像は実際に学習過程の特徴マップを拾ってきたものですが、1チャネルの画像データから16個の特徴マップを抽出している様子がわかります。 padding=1は画像の周りに1ピクセル分のデータを追加するという意味で、この段階のアウトプットは28×28ピクセルの画像になります。カーネルは学習パラメータであり、学習が進んでいくとカーネルの数値が変わり、より特徴を捉えるようなものになっていきます。

nn.ReLU()は活性化関数です。この活性化関数は負の値を0にして非線形性を導入することができ、ネットワークの複雑な学習を可能にします。ReLUを通した後の特徴マップがこちらです。

nn.MaxPool2d(2, 2)は最大値プーリング層です。2×2のウィンドウでプーリングを行い、ウィンドウ領域内の最大値を抽出して空間次元を削減します。この層を通ると特徴マップのサイズが14×14になります(以下図)。

nn.Conv2d(16, 32, kernel_size=3, padding=1)で再び畳み込み演算を行います。この段階で特徴マップは16個ありますが、それを32個にします。最初の畳み込み演算と特徴マップのサイズが異なっている(小さくなっている)ことも変化点です。

その後再びReLUを通し、非線形性を追加します。

nn.MaxPool2d(2, 2)で再度最大値プーリングを行い、7×7まで空間次元を削減します。

これらCNNの概要、用語は過去記事(【G検定の学習】ディープラーニングの概要と具体的な手法)に記載していますので参照ください。
CNNモデルの概要:全結合層
self.fc_layersは全結合層です。
nn.Linear(32 * 7 * 7, 128)でこれまでの32チャネル分の7×7サイズのデータを平坦化(32×7×7=1568次元のベクトル化)し、さらに128次元に削減します。
nn.ReLU を通し非線形性を追加し、
nn.Linear(128, 10)で128次元を10次元に削減します。この最後の10次元データがそれぞれ0〜9の10個のクラスに相当します。
損失関数:交差エントロピー損失
学習にはデータと正解との誤差(損失)を計算する損失関数が必要です。ここでは入力のロジットとターゲットの交差エントロピー損失を
nn.CrossEntropyLossで計算しています。式は次の形式で表現され、
・PyTorch公式ドキュメント:CrossEntropyLoss
・Qiita:PyTorchのCrossEntropyLossクラスについて
ロジットについては当WATLABブログで扱ったロジスティック回帰の記事が参考になるかもしれません。
最適化手法:Adam
学習の最適化にはAdamオプティマイザを使用しました。この最適化は optim.Adam(model.parameters(), lr=0.001)で与えられます。Adamオプティマイザは確率的勾配降下法(SGD)に加え、各パラメータごとに学習率を自動調整する手法です。
エポック
epochs = 5と書いてあるように、このループではエポックは5回です。エポックとは学習データ全体を1回すべてネットワークに通すことを意味します。各エポックごとに学習( train関数)と評価( test関数)が実行されます。
実行結果
実行結果の例を次に示します。シンプルなCNNモデルで10000個のテストデータに対して99.13%の精度となりました。
Pythonコード:推論
カメラを使って画像処理するコードは次回にするとして、ここでは学習済モデルを使って読み込んだ手書き文字がちゃんと分類されるか推論するコードも書いてみましょう。次のコードはMNISTデータセットから1枚だけ画像を抽出し、その画像をCNNのネットワークに通します。画像を matplotlibでプロットし、画像の右上に推論結果を重ね書きするというものです。
コードの中身解説
モデルの読み込み
model.load_state_dict(torch.load("mnist_cnn.pth", map_location=device))がモデルを読み込んでいる部分です。この推論コードにもCNNのモデル定義をしていますが、これは学習時に torch.save(model.state_dict(), "mnist_cnn.pth")でパラメータだけを保存( .state_dict())しているからです。こうすることでPC環境の違いによって問題になる再現性を確保しやすくなり、さらにモデルも軽量になります。
実行結果の例
こちらが分類結果です。コードを実行する度にランダムにデータが選ばれて、右上に分類結果が描画されます。

シンプルなCNNでもちゃんと分類できていますね!
まとめ
この記事ではイマサラ感のあるMNIST手書き文字認識をやってみました。イマサラやった理由としては、この後機械学習モデルを使ったアプリの構築方法を学ぶ良い例題だと思ったからです。
まだこのブログではCNNの事例を書いたことが無かったので、コンテンツを充実させることもできました。
CNNは非常に汎用性のあるネットワークです。画像やカメラの映像だけでなく、音声も画像情報に変換することで音声認識も可能になります。この記事ではネットワークの定義、損失関数やオプティマイザの適用、データのローディング方法の例を紹介しました。是非みなさんもモデルをカスタマイズして自分の用途にご活用ください。
PyTorchによるCNNの学習ができました!
Xでも関連情報をつぶやいているので、wat(@watlablog)のフォローお待ちしています!