この記事では、Pythonと機械学習ライブラリ「scikit-learn」を用いて、ニューラルネットワーク(NN)で学習したモデルをファイルに出力し、保存する方法とソースコードを解説します。
ニューラルネットワークとは
前回までは、Python + scikit-learnでニューラルネットワーク(パーセプトロン方式)を実装し、学習・予測・識別率の計算を行いました。
| – | 前回までの記事 |
|---|---|
| 1 | 【Scikit-learn】ニューラルネットワークで学習・予測 |
| 2 | 【Scikit-learn】ニューラルネットワークの識別率を計算 |
| 3 | ニューラルネットワークの原理・計算式・特徴 |
今回は学習データをファイル出力(保存)してみます。
書式
sklearn.externals.joblib.dump(clf, filepath)
| パラメータ | 説明 |
|---|---|
| clf | 学習データ |
| filepath | 出力先のファイルパス |
ソースコード
サンプルプログラムのソースコードは下記の通りです。
# -*- coding: utf-8 -*-
import pandas as pd
from sklearn.neural_network import MLPClassifier
from sklearn.externals import joblib
def main():
# データを取得
data = pd.read_csv("data.csv", sep=",")
# ニューラルネットで学習
clf = MLPClassifier(solver="sgd",random_state=0,max_iter=10000)
# 学習(説明変数x1, x2、目的変数x3)
clf.fit(data[['x1', 'x2']], data['x3'])
# 学習データを元に説明変数x1, x2から目的変数x3を予測
pred = clf.predict(data[['x1', 'x2']])
# 結果表示
print (pred)
joblib.dump(clf, 'nn.learn')
if __name__ == "__main__":
main()
data.csv
x1,x2,x3 45,17.5,30 38,17.0,25 41,18.5,20 34,18.5,30 59,16.0,45 47,19.0,35 35,19.5,25 43,16.0,35 54,18.0,35 52,19.0,40
実行結果
サンプルプログラムの実行結果は下記の通りです。
【学習ファイル】
・nn.learn
| – | 関連記事 |
|---|---|
| 1 | Scikit-learn入門・使い方 |
| 2 | Scikit-learnをインストールする方法 |
| 3 | Python入門 基本文法 |

コメント