この記事では、Pythonと「scikit-learn」を用いて、ニューラルネットワーク(NN)で学習したモデルのファイルを読み込み、使用する方法とソースコードを解説します。
ニューラルネットワークとは
前回までは、Python + scikit-learnでニューラルネットワーク(パーセプトロン方式)を実装し、学習・予測・識別率の計算・ファイル出力を行いました。
| – | 前回までの記事 |
|---|---|
| 1 | 【Scikit-learn】ニューラルネットワークで学習・予測 |
| 2 | 【Scikit-learn】ニューラルネットワークの識別率を計算 |
| 3 | 【Scikit-learn】ニューラルネットワーク学習モデルのファイル出力・保存 |
| 4 | ニューラルネットワークの原理・計算式・特徴 |
今回は、ファイル出力した学習結果(CSVファイル)を読み込んで再度学習・予測させてみました。
書式
sklearn.externals.joblib.load
ソースコード
サンプルプログラムのソースコードです。
# -*- coding: utf-8 -*-
import pandas as pd
from sklearn.neural_network import MLPClassifier
from sklearn.externals import joblib
def main():
# データを取得
data = pd.read_csv("data.csv", sep=",")
# ニューラルネットで学習モデルを読み込み(復元)
clf = joblib.load('nn.learn')
# 学習データを元に説明変数x1, x2から目的変数x3を予測
pred = clf.predict(data[['x1', 'x2']])
# 識別率を表示
print (sum(pred == data['x3']) / len(data[['x1', 'x2']]))
if __name__ == "__main__":
main()
| ファイル名 | 読み込んだファイル |
|---|---|
| data.csv | data.csv |
| 学習ファイル | nn.learn |
実行結果
サンプルプログラムの実行結果です。
0.4
| – | 関連記事 |
|---|---|
| 1 | Scikit-learn入門・使い方 |
| 2 | Scikit-learnをインストールする方法 |
| 3 | Python入門 基本文法 |

コメント