Pythonで実装する画像認識アルゴリズム SLIC 入門
2018.2.13
こんにちは。データサイエンスチーム tmtkです。
この記事では、SLIC (Simple Linear Iterative Clustering) を紹介します。紹介にあたって、私がPython 3で実装したものを使って解説していきます。
(今回処理する画像。choco.jpg
として保存)
SLICとは
SLIC (Simple Linear Iterative Clustering) とは、画像認識にかかわるアルゴリズムのひとつです。Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi, Pascal Fua, and Sabine Süsstrunkらが2010年に発明したようです(文献[1, 3])。
画像は画素(pixel)が縦横に並んでできています。superpixelは、距離的・色的に近い画素をひとまとまりにとらえたものです。画像認識の前処理としてsuperpixelを計算しておくことで、画像の情報量を上手に減らし、他の画像認識アルゴリズムが適用しやすくなるようです。
(superpixel化された画像)
SLICは画像を入力として、superpixelへの分割を出力とするアルゴリズムです。分割したいsuperpixelの数と画像を指定してSLICで処理すると、入力画像のsuperpixelへの分割を得ることができます。
superpixelの計算アルゴリズムの中で、SLICは領域境界への追従性、時間計算量、空間計算量などの点で他のアルゴリズムに負けず劣らず優れていることが文献[2]で示されています。
SLICのアルゴリズムの概略
SLICのアルゴリズムはk平均法を基にしており、それにいくつかのアイデアを加えて改良されています。SLICのアルゴリズムの特徴を列挙すると、以下のようになります。
- 画像を読み込み、グレースケールやRGBであらわされている値をLab色空間に変換する
- 座標にある色の画素の特徴量として、5次元ユークリッド空間上の点を使う
- 手順2.で得た画素の特徴量の空間に、k平均法の亜種を適用し、クラスタリングを行う
SLICのアルゴリズムの詳細
以下で、SLICのアルゴリズムの詳細を解説していきます。
実装に、以下のライブラリを使います。Pythonの画像処理ライブラリscikit-imageには既にSLICが実装されていますが、ここではscikit-imageのSLICは使わず、画像の読み込みとLab色空間への変換にのみscikit-imageを使うことにします。
import sys, math import numpy as np from skimage import io, color
SLICクラスを定義します。以下でメソッドを定義していきます。
class SLIC:
コンストラクタで各種パラメータを定義します。は計算するsuperpixel(クラスタ)の数、は画素の距離を計算するとき色成分に比べて近さ成分をどれだけ重視するかを決めるパラメータです。文献[2]ではで決めるよう書かれているので、デフォルトでとすることにします。
def __init__(self, k, m = 20): """ Constructor. k: the number of superpixels. m: a parameter to weigh the relative importance of spatial proximity. """ self.k = k self.m = m self.iter_max = 10 # c.f. the paper.
SLICのアルゴリズムは、大きく初期化と繰り返し処理の二つに分かれています。初期化処理をfit_init()
、繰り返し処理をfit_iter()
として、それぞれ実装していくことにします。
def fit(self, img_path): """ Calculate superpixels. Returns the mask array. """ self.fit_init(img_path) self.fit_iter() return self.l
初期化処理では、
- 画像をLab色空間に変換する
- 位置にある色の画素を座標の点とみなす
(コンピュータで処理する都合上、座標はではなくという逆転した順番になっています) - 個のクラスタの中心を等間隔に初期化する
(文献[2]では、物体のふちの部分にクラスタの中心点を置くことをさけるために、等間隔でおいたクラスタの中心の周囲3×3画素も見て、その中で勾配(文献[3]を参照)が一番小さい画素にクラスタの中心をおきなおすと説明されていますが、ここではその処理は省略しています) - 番目の点と最寄のクラスタの中心の距離をに初期化する
- クラスタの直径の近似値(は画素数)を計算しておく
などの処理を行います。
def fit_init(self, img_path): """ Read the image from img_path, convert to Lab color space, and initialize cluster centers. """ img_rgb = io.imread(img_path) if img_rgb.ndim != 3 or img_rgb.shape[2] != 3: raise Exception("Non RGB file. The shape was {}.".format(img_rgb.shape)) img_lab = color.rgb2lab(img_rgb) self.height = img_lab.shape[0] self.width = img_lab.shape[1] self.pixels = [] for h in range(self.height): for w in range(self.width): self.pixels.append(np.array([img_lab[h][w][0], img_lab[h][w][1], img_lab[h][w][2], h, w])) self.size = len(self.pixels) # Initialize cluster centers to be regularly spaced. self.cluster_center = [] k_w = int(math.sqrt(self.k * self.width / self.height)) + 1 k_h = int(math.sqrt(self.k * self.height / self.width)) + 1 for h_cnt in range(k_h): h = (2 * h_cnt + 1) * self.height // (2 * k_h) for w_cnt in range(k_w): w = (2 * w_cnt + 1) * self.width // (2 * k_w) self.cluster_center.append(self.pixels[h*self.width + w]) self.k = k_w*k_h self.l = [None] * self.size # The cluster labels self.d = [math.inf] * self.size # The distance between a pixel and the nearest cluster center self.S = int(math.sqrt(self.size/self.k)) # The approximate distance between cluster centers self.metric = np.diagflat([1/(self.m**2)]*3 + [1/(self.S**2)]*2)
繰り返し処理では、k平均法の類似を行います。
各クラスタの中心ごとに、番目の点との距離を計算します。その際、中心からの差が以下(つまり)の点のみについて距離を計算します。そして、その距離が記録されている最小距離より小さかった場合、番目の点は番目のクラスタに所属する、と更新します。おおむね通常のk平均法と同じですが、クラスタの中心との距離を計算する点が、クラスタの中心の位置と画素の位置が近い点のみに絞られていることが特徴です。
所属するクラスタを更新したら、クラスタの中心も更新します(calc_new_center()
)。
また、距離の定義も通常のユークリッド距離とは別の距離を用います(distance()
)。点と点の距離は、と定義します。これは、色の尺度と位置の尺度のスケールが異なるためです。
文献[2]によると、このk平均法の類似の繰り返し回数は、10回程度で十分であることが経験的に知られているそうです。
def fit_iter(self): """ Iteration step. """ for iter_cnt in range(self.iter_max): for center_idx, center in enumerate(self.cluster_center): for h in range(max(0, int(center[3])-self.S), min(self.height, int(center[3])+self.S)): for w in range(max(0, int(center[4])-self.S), min(self.width, int(center[4])+self.S)): d = self.distance(self.pixels[h*self.width + w], center) if d < self.d[h*self.width + w]: self.d[h*self.width + w] = d self.l[h*self.width + w] = center_idx self.calc_new_center() def distance(self, x, y): """ Squared distance between x and y. """ return (x-y).dot(self.metric).dot(x-y) def calc_new_center(self): """ Caluclate new cluster centers. """ cnt = [0] * self.k new_cluster_center = [np.array([0., 0., 0., 0. ,0.]) for _ in range(self.k)] for i in range(self.size): new_cluster_center[self.l[i]] += self.pixels[i] cnt[self.l[i]] += 1 for i in range(self.k): new_cluster_center[i] /= cnt[i] self.cluster_center = new_cluster_center
ここまでの実装で、 SLIC(k=100).fit("choco.jpg")
などとすると座標ごとに何番目のsuperpixelに所属するのかのラベルの配列が計算できるようになりました。
superpixelごとにsuperpixelに属する画素の色の平均を計算し、RGBに変換してから返すメソッドも実装します。
def transform(self): """ Returns new image RGB ndarray """ cnt = [0] * self.k cluster_color = [np.array([0., 0., 0.]) for _ in range(self.k)] for i in range(self.size): cluster_color[self.l[i]] += self.pixels[i][:3] cnt[self.l[i]] += 1 for i in range(self.k): cluster_color[i] /= cnt[i] new_img_lab = np.zeros((self.height, self.width, 3)) for h in range(self.height): for w in range(self.width): new_img_lab[h][w] = cluster_color[self.l[h*self.width + w]] return color.lab2rgb(new_img_lab)
ここまでの実装をchoco.jpg
にかけ、変換してみます。私の環境で60秒くらいかかります。
slic = SLIC(k = 100) slic.fit("choco.jpg") res = slic.transform() io.imshow(res)
写っている物体の境界にうまく追従していることがわかると思います。
superpixelを連結にするため、文献[1]の実装では、孤立したsuperpixelを近くの大きなsuperpixelに併合する後処理が追加されていますが、この記事では省略します。scikit-imageで実装されているSLICによる結果は以下のとおりです。こちらはCythonで実装されているため、1秒程度で終わります。
from skimage import io, segmentation, color img = io.imread("choco.jpg") label = segmentation.slic(img, compactness=20) out = color.label2rgb(label, img, kind = 'avg') io.imsave("lena_skimage.png", out)
scikit-imageでは孤立したsuperpixelを併合する後処理が追加されているため、イチゴの部分などに孤立したsuperpixelがありません。
さらなる話題
パラメータを自動的に決めるSLICOという手法もあります(文献[1, 2])。
まとめ
画像認識の前処理に使われるsuperpixelを計算するアルゴリズムのひとつであるSLICの紹介・解説をしました。SLICはk平均法を応用したアルゴリズムです。
SLIC
クラスの全コード
""" SLIC implementation in Python 3 """ import sys, math import numpy as np from skimage import io, color class SLIC: def __init__(self, k, m = 20): """ Constructor. k: the number of superpixels. m: a parameter to weigh the relative importance of spatial proximity. """ self.k = k self.m = m self.iter_max = 10 # c.f. the paper. def fit(self, img_path): """ Calculate superpixels. Returns the mask array. """ self.fit_init(img_path) self.fit_iter() return self.l def fit_init(self, img_path): """ Read the image from img_path, convert to Lab color space, and initialize cluster centers. """ img_rgb = io.imread(img_path) if img_rgb.ndim != 3 or img_rgb.shape[2] != 3: raise Exception("Non RGB file. The shape was {}.".format(img_rgb.shape)) img_lab = color.rgb2lab(img_rgb) self.height = img_lab.shape[0] self.width = img_lab.shape[1] self.pixels = [] for h in range(self.height): for w in range(self.width): self.pixels.append(np.array([img_lab[h][w][0], img_lab[h][w][1], img_lab[h][w][2], h, w])) self.size = len(self.pixels) # Initialize cluster centers to be regularly spaced. self.cluster_center = [] k_w = int(math.sqrt(self.k * self.width / self.height)) + 1 k_h = int(math.sqrt(self.k * self.height / self.width)) + 1 for h_cnt in range(k_h): h = (2 * h_cnt + 1) * self.height // (2 * k_h) for w_cnt in range(k_w): w = (2 * w_cnt + 1) * self.width // (2 * k_w) self.cluster_center.append(self.pixels[h*self.width + w]) self.k = k_w*k_h self.l = [None] * self.size # The cluster labels self.d = [math.inf] * self.size # The distance between a pixel and the nearest cluster center self.S = int(math.sqrt(self.size/self.k)) # The approximate distance between cluster centers self.metric = np.diagflat([1/(self.m**2)]*3 + [1/(self.S**2)]*2) def fit_iter(self): """ Iteration step. """ for iter_cnt in range(self.iter_max): for center_idx, center in enumerate(self.cluster_center): for h in range(max(0, int(center[3])-self.S), min(self.height, int(center[3])+self.S)): for w in range(max(0, int(center[4])-self.S), min(self.width, int(center[4])+self.S)): d = self.distance(self.pixels[h*self.width + w], center) if d < self.d[h*self.width + w]: self.d[h*self.width + w] = d self.l[h*self.width + w] = center_idx self.calc_new_center() def distance(self, x, y): return (x-y).dot(self.metric).dot(x-y) self.iter_max = 10 # c.f. the paper. def fit(self, img_path): """ Calculate superpixels. Returns the mask array. """ self.fit_init(img_path) self.fit_iter() return self.l def fit_init(self, img_path): """ Read the image from img_path, convert to Lab color space, and initialize cluster centers. """ img_rgb = io.imread(img_path) if img_rgb.ndim != 3 or img_rgb.shape[2] != 3: raise Exception("Non RGB file. The shape was {}.".format(img_rgb.shape)) img_lab = color.rgb2lab(img_rgb) self.height = img_lab.shape[0] self.width = img_lab.shape[1] self.pixels = [] for h in range(self.height): for w in range(self.width): self.pixels.append(np.array([img_lab[h][w][0], img_lab[h][w][1], img_lab[h][w][2], h, w])) self.size = len(self.pixels) # Initialize cluster centers to be regularly spaced. self.cluster_center = [] k_w = int(math.sqrt(self.k * self.width / self.height)) + 1 k_h = int(math.sqrt(self.k * self.height / self.width)) + 1 for h_cnt in range(k_h): h = (2 * h_cnt + 1) * self.height // (2 * k_h) for w_cnt in range(k_w): w = (2 * w_cnt + 1) * self.width // (2 * k_w) self.cluster_center.append(self.pixels[h*self.width + w]) self.k = k_w*k_h self.l = [None] * self.size # The cluster labels self.d = [math.inf] * self.size # The distance between a pixel and the nearest cluster center self.S = int(math.sqrt(self.size/self.k)) # The approximate distance between cluster centers self.metric = np.diagflat([1/(self.m**2)]*3 + [1/(self.S**2)]*2) def fit_iter(self): """ Iteration step. """ for iter_cnt in range(self.iter_max): for center_idx, center in enumerate(self.cluster_center): for h in range(max(0, int(center[3])-self.S), min(self.height, int(center[3])+self.S)): for w in range(max(0, int(center[4])-self.S), min(self.width, int(center[4])+self.S)): d = self.distance(self.pixels[h*self.width + w], center) if d < self.d[h*self.width + w]: self.d[h*self.width + w] = d self.l[h*self.width + w] = center_idx self.calc_new_center() def distance(self, x, y): """ Squared distance between x and y. """ return (x-y).dot(self.metric).dot(x-y) def calc_new_center(self): """ Caluclate new cluster centers. """ cnt = [0] * self.k new_cluster_center = [np.array([0., 0., 0., 0. ,0.]) for _ in range(self.k)] for i in range(self.size): new_cluster_center[self.l[i]] += self.pixels[i] cnt[self.l[i]] += 1 for i in range(self.k): new_cluster_center[i] /= cnt[i] self.cluster_center = new_cluster_center def transform(self): """ Returns new image RGB ndarray """ cnt = [0] * self.k cluster_color = [np.array([0., 0., 0.]) for _ in range(self.k)] for i in range(self.size): cluster_color[self.l[i]] += self.pixels[i][:3] cnt[self.l[i]] += 1 for i in range(self.k): cluster_color[i] /= cnt[i] new_img_lab = np.zeros((self.height, self.width, 3)) for h in range(self.height): for w in range(self.width): new_img_lab[h][w] = cluster_color[self.l[h*self.width + w]] return color.lab2rgb(new_img_lab)
文献
- Superpixel segmentation | IVRL
- Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi, Pascal Fua, and Sabine Süsstrunk, SLIC Superpixels Compared to State-of-the-art Superpixel Methods, IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 34, num. 11, p. 2274 – 2282, May 2012.
- Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi, Pascal Fua, and Sabine Süsstrunk, SLIC Superpixels, EPFL Technical Report no. 149300, June 2010.
- k平均法 – Wikipedia
- Lab色空間 – Wikipedia
- バレンタインのチョコレートケーキを焼く女性|ぱくたそフリー素材
- scikit-image: Image processing in Python — scikit-image
- Normalized Cut — skimage v0.14dev docs
テックブログ新着情報のほか、AWSやGoogle Cloudに関するお役立ち情報を配信中!
Follow @twitterデータ分析と機械学習とソフトウェア開発をしています。 アルゴリズムとデータ構造が好きです。
Recommends
こちらもおすすめ
-
PythonやR言語で相関係数を計算する方法
2018.2.20
-
相関係数は外れ値の影響をうけやすい?Pythonで確認してみた。
2018.2.28
-
基礎からはじめる時系列解析入門
2019.2.22
Special Topics
注目記事はこちら
データ分析入門
これから始めるBigQuery基礎知識
2024.02.28
AWSの料金が 10 %割引になる!
『AWSの請求代行リセールサービス』
2024.07.16