ギークなエンジニアを目指す男

基幹系SIerがWeb系とかネイティブ系の知識を蓄えようとするブログ

MENU

k-最近傍法でアイリスのクラス分類問題を解く(python)

こんばんは。
本日は、機械学習の定番とも言える、アイリスの花のクラス分類問題をk-最近傍法を用いて解いてみようと思います。

実際、アイリスの花を分類したいというモチベーションがビジネス上役に立つかと問われると微妙ですが、学習だと割り切っていきましょう。(お花屋さんの方、すみません)

k-最近傍法とは

k-最近傍法とは、新しいデータポイントに対して、予測する際に新しい点に最も近い点を訓練セットから探し、新しい点に最も近かった点のラベルを予測データとする手法です。
(k-nearest neighbor algorithm という名前から、k-NNとも呼ばれます)
この手法では、近傍数というものをハイパーパラメータとして与えることができます。
近傍数を複数にした場合は、該当する点のラベルから多数決で予測ラベルを決定することになります。
言葉だけでの説明ではイメージしづらいと思いますので、いくつか図を紹介します。

下図は、近傍数 = 1 として予測を行なった結果です。
左上の星は●クラス
真ん中の星は▲クラス
右下の星は●クラス
に分類されていることが分かりますね。
(星は予測したいデータ)

f:id:taxa_program:20180627223442p:plain

では近傍数 = 3とした場合はどうでしょうか。
左上の星は▲クラス(●が1個, ▲が2個)
真ん中の星は▲クラス(●が1個, ▲が2個)
右下の星は●クラス(●が3個, ▲が0個)
に分類されていることが分かりますね。

f:id:taxa_program:20180627223716p:plain

左上の★について、近傍数が1のときと3の時で、予測値が変わっていることに気づきましたでしょうか?
このように、近傍数の数によって、予測値が異なる可能性もあります。

これで少しイメージついたでしょうか?

また、下記サイトで公開されているスライドも分かりやすいです。

www.slideshare.net

アイリスの花データを眺めてみる

では、早速アイリスの花データを眺めてみます。 とりあえず使いそうなライブラリをimportして、データセットを読み込みます。

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mglearn
from IPython.display import display
from sklearn.datasets import load_iris
%matplotlib inline

# データセットの読み込み
iris_dataset = load_iris()

データセットの中身を見てみます。

# データセットのKeyを表示
print("Keys of iris_dataset: \n{}".format(iris_dataset.keys()))
# -> Keys of iris_dataset: 
# -> dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])

# target_namesには予測しようとしている花の種類が格納されている
print("Taget names: {}".format(iris_dataset['target_names']))
# -> Taget names: ['setosa' 'versicolor' 'virginica']

# データ本体はtargetとdataフィールドに格納されている
# 花のデータとしては、150個のサンプル数にそれぞれ4つづつの特徴量を持つ
print("Shape of data: {}".format(iris_dataset['data'].shape))
# -> Shape of data: (150, 4)

# 先頭10データの特徴量を表示
print("ten coloms of data: \n{}".format(iris_dataset['data'][:10]))
# -> ten coloms of data: 
# -> [[5.1 3.5 1.4 0.2]
# ->  [4.9 3.  1.4 0.2]
# ->  [4.7 3.2 1.3 0.2]
# ->  [4.6 3.1 1.5 0.2]
# ->  [5.  3.6 1.4 0.2]
# ->  [5.4 3.9 1.7 0.4]
# ->  [4.6 3.4 1.4 0.3]
# ->  [5.  3.4 1.5 0.2]
# ->  [4.4 2.9 1.4 0.2]
# ->  [4.9 3.1 1.5 0.1]]

# dataがそれぞれ何を表しているかは、feature_namesに格納されている。
# それぞれ、ガクの長さ、ガクの幅、花弁の長さ、花弁の幅がNumpy配列として格納されている。
print("Feature names: \n{}".format(iris_dataset['feature_names']))
# -> Feature names: 
# -> ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

データを可視化してみる

データの中身がわかったところで、今度は可視化を行なってみます。

# X_trainのデータからpandasのDataframeに変換する。
# iris_dataset.feature_namesの文字列を利用して、カラムに名前をつける
# 先でも確認したが、feature_namesには['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']が格納されている
iris_dataframe = pd.DataFrame(X_train, columns = iris_dataset.feature_names)
# Dataframe変換後の値を表示してみる。print関数より、display関数を用いることで見やすくなる。
# print(iris_dataframe[:10])
display(iris_dataframe[:10])

f:id:taxa_program:20180627234231p:plain

プロットもしてみます。
これを見る限り、データは比較的分離しているように見えるため、モデルの作成には問題ないことが確認できます。

# データフレームからscatter matrixを作成し、y_trainに従って色をつける(青、緑、赤はアイリスの種類を表している)
pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o', hist_kwds={'bins': 20}, s=60, alpha=.8, cmap=mglearn.cm3)

f:id:taxa_program:20180627234404p:plain

学習から予測

k-最近傍法のアルゴリズム、データセットの中身がわかったところで、 アイリスの花の分類を行なってみます。
下記がソースコードになります。

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mglearn
from IPython.display import display
from sklearn.neighbors import KNeighborsClassifier # k-最近傍法関数のimport
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
 
# データのロード
iris_dataset = load_iris()
# データセットのKeyを表示
print("Keys of iris_dataset: \n{}".format(iris_dataset.keys()))
# -> Keys of iris_dataset: dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])
 
# データセットを並び替えて学習データと検証データに分割する
# テストデータを25%とし、乱数シードを固定する
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], test_size=0.25 ,random_state=0)
 
# インスタンス生成(パラメータは近傍数)
knn = KNeighborsClassifier(n_neighbors=1)
knn.fit(X_train, y_train)
 
# 精度を計算
acc = np.mean(y_pred == y_test)
print("Test set score: {:.2f}".format(acc))
# -> Test set score: 0.97

僅かこれだけのコーディングで97%もの精度がでます。

k-最近傍法の利点と欠点

k-最近傍法には、下記のように利点と欠点があります。

  • 利点
    モデルの理解しやすさ
    調整しなくても、そこそこの性能が出る

  • 欠点
    訓練セットが大きくなると(サンプル数が多くなると)予測が遅くなる
    多数の特徴量を扱うことができない

よって、高度な技術の利用を考える前に、k-最近傍法をベースラインとして試すときなどに有用なアルゴリズムとなる場合が多いです。

本日はここまで。