この記事では、Scikit-learnのデータセットの種類と呼び出し(読み込み)方について紹介します。
データセットの種類
Scikit-learnには、動作テストに便利なデータセットがいくつかあります。
データセットの内容は以下の通りです。
– | データの内容 | 予測対象 |
---|---|---|
load_iris | 3種類のアヤメのがく片、花弁の幅および長さ | 分類 |
load_diabetes | 糖尿病患者の検査数値と1年後の疾患進行状況 | 回帰 |
load_digits | 0~9の手書き文字画像(8×8) | 分類 |
load_boston | 米国ボストン市郊外における地域別の住宅価格 | 回帰 |
load_linnerud | 成人男性の生理学的特徴と運動能力 | 回帰 |
load_wine | 3種類のワインの科学的特徴 | 分類 |
load_breast_cancer | 乳がんの診断結果 | 分類 |
データセットの呼び出し
例として、アヤメのデータセットを呼び出してみます。
from matplotlib import pyplot as plt from sklearn import datasets # データ・セット def main(): # Iris のデータを呼び出す iris = datasets.load_iris() X = iris.data[:, :2] # 最初の二次元のみの特徴量を抽出 Y = iris.target # 目標値(正解データ) # グラフの軸幅 x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5 y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5 # 可視化のベースを作成 plt.figure(2, figsize=(8, 6)) plt.clf() # 実際にプロット plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired) plt.xlabel('Sepal length') plt.ylabel('Sepal width') plt.xlim(x_min, x_max) plt.ylim(y_min, y_max) plt.grid() plt.show() if __name__ == "__main__": main()
関連記事
– | 関連記事 |
---|---|
1 | Scikit-learnをインストールする方法 |
2 | Scikit-learn入門・使い方 |
3 | 機械学習のアルゴリズム入門 |
コメント