import torch
from torch import nn, optim
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
import os
import glob
import cloudpickle
# GIFアニメーションを作成
def create_gif(in_dir, out_filename):
# ファイルパスをソートしてリストする
path_list = sorted(glob.glob(os.path.join(*[in_dir, '*'])))
imgs = []
# ファイルのフルパスからファイル名と拡張子を抽出
for i in range(len(path_list)):
img = Image.open(path_list[i])
imgs.append(img)
# appendした画像配列をGIFにする。durationで持続時間、loopでループ数を指定可能。
imgs[0].save(out_filename,
save_all=True, append_images=imgs[1:], optimize=False, duration=100, loop=0)
# 線形回帰ネットワークのclassをnn.Moduleの継承で定義
class Regression(nn.Module):
# コンストラクタ(インスタンス生成時の初期化)
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(3, 32)
self.linear2 = nn.Linear(32, 32)
self.linear3 = nn.Linear(32, 32)
self.linear4 = nn.Linear(32, 32)
self.linear5 = nn.Linear(32, 16)
self.linear6 = nn.Linear(16, 1)
# メソッド(ネットワークをシーケンシャルに定義)
def forward(self, x):
x = nn.functional.relu(self.linear1(x))
x = nn.functional.relu(self.linear2(x))
x = nn.functional.relu(self.linear3(x))
x = nn.functional.relu(self.linear4(x))
x = nn.functional.relu(self.linear5(x))
x = self.linear6(x)
return x
# トレーニング関数
def train(model, optimizer, E, iteration, x, y):
# 学習ループ
losses = []
for i in range(iteration):
optimizer.zero_grad() # 勾配情報を0に初期化
y_pred = model(x) # 予測
loss = E(y_pred.reshape(y.shape), y) # 損失を計算(shapeを揃える)
loss.backward() # 勾配の計算
optimizer.step() # 勾配の更新
losses.append(loss.item()) # 損失値の蓄積
print('epoch=', i+1, 'loss=', loss)
#グラフ描画用にXY軸とグリッドを作成
X1 = np.arange(-512, 512, 16)
X2 = np.arange(-512, 512, 16)
X, Y = np.meshgrid(X1, X2)
# データをテンソルに変換(切片用の定数1も結合)
X2 = torch.from_numpy(X.ravel().astype(np.float32)).float()
Y2 = torch.from_numpy(Y.ravel().astype(np.float32)).float()
Input = torch.stack([torch.ones(len(X.ravel())), X2, Y2], 1)
# 100計算毎にプロットを保存
if (i + 1) % 100 == 0:
Z = test(model, Input).reshape(X.shape)
plot_3d(x.T[1], x.T[2], y, X, Y, Z, losses, 'out', i+1)
return model, losses
# テスト
def test(model, x):
y_pred = model(x).data.numpy()
return y_pred
# グラフ描画関数
def plot_3d(x1, x2, z, X, Y, Z, losses, dir, index):
# ここからグラフ描画-------------------------------------------------
# フォントの種類とサイズを設定する。
plt.rcParams['font.size'] = 14
plt.rcParams['font.family'] = 'Times New Roman'
# 目盛を内側にする。
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# グラフの上下左右に目盛線を付ける。
fig = plt.figure(figsize=(9, 4))
ax1 = fig.add_subplot(121, projection='3d')
ax1.yaxis.set_ticks_position('both')
ax1.xaxis.set_ticks_position('both')
ax2 = fig.add_subplot(122)
ax2.yaxis.set_ticks_position('both')
ax2.xaxis.set_ticks_position('both')
# 軸のラベルを設定する。
ax1.set_xlabel('x1')
ax1.set_ylabel('x2')
ax1.set_zlabel('y')
ax2.set_xlabel('Iteration')
ax2.set_ylabel('E')
# スケール設定
ax1.set_xlim(-512, 512)
ax1.set_ylim(-512, 512)
ax1.set_zlim(-1000, 1000)
ax2.set_xlim(0, 10000)
ax2.set_ylim(1000, 100000)
ax2.set_yscale('log')
# データプロット
ax1.scatter3D(x1, x2, z, label='dataset')
ax1.plot_surface(X, Y, Z, cmap='jet')
ax2.plot(np.arange(0, len(losses), 1), losses)
ax2.scatter(len(losses), losses[len(losses) - 1], color='red')
ax2.text(600, 3000, 'Loss=' + str(round(losses[len(losses)-1], 2)), fontsize=16)
ax2.text(600, 8000, 'Iteration=' + str(round(len(losses), 1)), fontsize=16)
# グラフを表示する。
ax1.legend(bbox_to_anchor=(0, 1), loc='upper left')
fig.tight_layout()
# dirフォルダが無い時に新規作成
if os.path.exists(dir):
pass
else:
os.mkdir(dir)
# 画像保存パスを準備
path = os.path.join(*[dir, str("{:05}".format(index)) + '.png'])
# 画像を保存する
plt.savefig(path)
# plt.show()
if index == 2000:
plt.show()
plt.close()
plt.close()
# -------------------------------------------------------------------
# トレーニングデータ
x1 = np.random.uniform(-512, 512, 50)
x2 = np.random.uniform(-512, 512, 50)
grid_x, grid_y = np.meshgrid(x1, x2)
# ノイズを含んだ平面点列データを作成
# Eggholder function (https://qiita.com/tomitomi3/items/d4318bf7afbc1c835dda)
z = -(grid_y + 47) * np.sin(np.sqrt(np.abs(grid_y + (grid_x/2) + 47))) - grid_x * np.sin(np.abs(grid_x - (grid_y + 47)))
# テンソル変換とデータ整形
grid_x = torch.from_numpy(grid_x.ravel().astype(np.float32)).float()
grid_y = torch.from_numpy(grid_y.ravel().astype(np.float32)).float()
z = torch.from_numpy(z.astype(np.float32)).float()
X = torch.stack([torch.ones(len(grid_x)), grid_x, grid_y], 1)
# ネットワークのインスタンスを生成
net = Regression()
# 最適化アルゴリズム(RMSProp)と損失関数(MSE)を設定
optimizer = optim.RMSprop(net.parameters(), lr=0.005)
E = nn.MSELoss()
# トレーニング
net, losses = train(model=net, optimizer=optimizer, E=E, iteration=10000, x=X, y=z)
# ネットワークモデルをシリアライズしてファイルに保存
with open('model-B.pt', mode='wb') as f:
cloudpickle.dump(net, f)
# GIFアニメーションを作成する関数を実行する
create_gif(in_dir='out', out_filename='movie-B.gif')