この記事では、Scikit-learnのデータセットの種類と呼び出し(読み込み)方について紹介します。
データセットの種類
Scikit-learnには、動作テストに便利なデータセットがいくつかあります。
データセットの内容は以下の通りです。
| – | データの内容 | 予測対象 |
|---|---|---|
| load_iris | 3種類のアヤメのがく片、花弁の幅および長さ | 分類 |
| load_diabetes | 糖尿病患者の検査数値と1年後の疾患進行状況 | 回帰 |
| load_digits | 0~9の手書き文字画像(8×8) | 分類 |
| load_boston | 米国ボストン市郊外における地域別の住宅価格 | 回帰 |
| load_linnerud | 成人男性の生理学的特徴と運動能力 | 回帰 |
| load_wine | 3種類のワインの科学的特徴 | 分類 |
| load_breast_cancer | 乳がんの診断結果 | 分類 |
データセットの呼び出し
例として、アヤメのデータセットを呼び出してみます。
from matplotlib import pyplot as plt
from sklearn import datasets # データ・セット
def main():
# Iris のデータを呼び出す
iris = datasets.load_iris()
X = iris.data[:, :2] # 最初の二次元のみの特徴量を抽出
Y = iris.target # 目標値(正解データ)
# グラフの軸幅
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
# 可視化のベースを作成
plt.figure(2, figsize=(8, 6))
plt.clf()
# 実際にプロット
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.grid()
plt.show()
if __name__ == "__main__":
main()
関連記事
| – | 関連記事 |
|---|---|
| 1 | Scikit-learnをインストールする方法 |
| 2 | Scikit-learn入門・使い方 |
| 3 | 機械学習のアルゴリズム入門 |

コメント