【Scikit-learn】ニューラルネットワーク学習モデルを読み込む(インポート)

この記事では、Pythonと「scikit-learn」を用いて、ニューラルネットワーク(NN)で学習したモデルのファイルを読み込み、使用する方法とソースコードを解説します。

ニューラルネットワークとは

前回までは、Python + scikit-learnでニューラルネットワーク(パーセプトロン方式)を実装し、学習・予測・識別率の計算・ファイル出力を行いました。

前回までの記事
1 【Scikit-learn】ニューラルネットワークで学習・予測
2 【Scikit-learn】ニューラルネットワークの識別率を計算
3 【Scikit-learn】ニューラルネットワーク学習モデルのファイル出力・保存
4 ニューラルネットワークの原理・計算式・特徴

今回は、ファイル出力した学習結果(CSVファイル)を読み込んで再度学習・予測させてみました。

書式

sklearn.externals.joblib.load

ソースコード

サンプルプログラムのソースコードです。

# -*- coding: utf-8 -*-
import pandas as pd
from sklearn.neural_network import MLPClassifier
from sklearn.externals import joblib

def main():
    # データを取得
    data = pd.read_csv("data.csv", sep=",")

    # ニューラルネットで学習モデルを読み込み(復元)
    clf = joblib.load('nn.learn')

    # 学習データを元に説明変数x1, x2から目的変数x3を予測
    pred = clf.predict(data[['x1', 'x2']])

    # 識別率を表示
    print (sum(pred == data['x3']) / len(data[['x1', 'x2']]))

if __name__ == "__main__":
    main()
ファイル名 読み込んだファイル
data.csv data.csv
学習ファイル nn.learn

実行結果

サンプルプログラムの実行結果です。

0.4 
関連記事
1 Scikit-learn入門・使い方
2 Scikit-learnをインストールする方法
3 Python入門 基本文法

コメント