こんにちわ〜。シュです。今回の記事では、データ駆動型AIアプリをすばやく作成する方法を紹介します。 完全な例を使用して、簡単なWebアプリを作成し、学習済みの深層学習モデルを呼び出す方法をご紹介します。
1. データサイエンスをやるのになぜアプリを作る必要があるの?
ほとんどのデータサイエンティスト、または機械学習エンジニアは、アルゴリズムを研究し、ビジネスデータでそれらを応用するなど、アルゴリズムかビジネスシーンに焦点を当てています。 通常、結果報告もしくわ納品時に、ダッシュボードを作成するか、レポートに直接PPTを使用します。 しかし、近年、特に2019年以降、多くの企業が研究レベルのデータサイエンスの仕事に段々満足できなくなり、実際の問題を解決するために実用的なアプリが求められています。 さらに、お客様に報告したりプレゼンテーションを行ったりする場合、シンプルなアプリが自分のソリューションをより適切に説明するのに役立つツールの1つです。
2. では、どういうアプリを作るの?
今豊富なオープンソースフレームワークがありますので、多くの選択肢があります。
- Jupyter Notebookはデータサイエンス分野でよく使われるツールです。私は昔Jupyter Notebookをアプリのように使えるようにパッケージ化してみましたが、Notebookの管理は非常に難しく、特にソフトウェアエンジニアリングの問題が発生する時の対応は厄介です。
- 最先端のデータサイエンスプラットフォームを使うのが一つの選択肢です。Dataiku,Databricksなどのプラットフォームを使うことで、モデリングからアプリケーションまで一括で開発することができて、非常に便利です。但し、基本的にこれらのプラットフォームはサブスクリプションモデルを使っているので、無料なものは存在しません。
- 最後にWebフレークワークを利用し簡単なウェブアプリを作る方法があります。少しソフトウェアエンジニアリングの知識が必要になりますが、無料ですし、深くエンジニアリングのことを理解しなくても、簡単にアプリを作れます。アプリのデプロイも簡単にできますから、スマホアプリとデスクトップアプリに比べると、ウェブアプリの方がPrototypeを作るのにベターな選択肢ではないかと思います。
ここから,PythonのライトウェイのウェブフレームワークFlaskを使って簡単なリアルタイム画像異常検出アプリを作ってみたいと思います。
3. 今回着手するウェブアプリの概要
アプリの概要は以下の通り:

これは、ディープラーニング技術を使用して異常検出を実現するWebアプリです。工場の生産ラインでは、製造された部品が正常に機能するかどうかを監視する必要があります。 このアプリは、生産ライン上にあるサーキュレーターの作動を監視し、異常動作(ファンが異常停止など)を検出し、結果をブラウザ画面にリアルタイムで表示します。
技術的な観点から説明すると、Apache Kafkaを使用して、カメラから受信したリアルタイム画像データを処理します。 Producerのプログラムはリアルタイムの画像データ(ストリーミングデータ)をKafkaに保存し、ConsumerのプログラムはKafkaに保存されたデータを読み取り画像処理を行います。 ブラウザーでアプリ画面を表示すると、FlaskフレームワークはConsumerプログラムを呼び出して、異常検出結果をリアルタイムで表示します。
4. 実装
ハードウェア環境:
- Mac x 1
- カメラ x 1
ソフトウェア環境:
- 開発言語:Python3
- Pythonライブラリ: Flask, Kafka, OpenCV, Keras
環境の設定についてはcondaかpipを使うと簡単に設定できますので、ここで紹介しません。
4.1. Apache Kafkaについて
Kafkaは今回のアプリに置いて重要な部分です。但し、そんなにKafkaを深く理解しなくても大丈夫です。ProducerとConsumerのコンポーネントを理解するだけで十分です。
Apache Kafkaは、主にデータハブとストリーミング処理の2つの用途があります。 そして、ストリーミング処理はまさに今回私たちが必要とするものです。 Apache Kafkaは、パブリッシュ/サブスクライブ型のシステムです。 もともとはLinkedInによって提案され、後にオープンソースプロジェクトになりました。
Kafkaの重要な概念:
- Broker:Kafkaに送信されたデータはBrokerに保存されます。 複数のBrokerでKafkaクラスターを構築することができます。 各Brokerは、Apache ZooKeeperを介して通信します。
- Producer:ProducerはBrokerにデータを送信するために使用されるプログラムです。Kafkaの独自のProducerはJavaで記述されていますが、Pythonなどの他の言語で利用可能な多くのサードパーティのKafkaライブラリがあります。
- Consumer:Consumerは、Producerによって保存されたデータを要求するプログラムです。
- Topic:データが実際にBrokerに保管される場所です。 ConsumerがTopicをサブスクライブして情報を取得し、ProducerがTopicを公開します。
Kafkaの概念のみだとまだ抽象的ですが、アプリでは、カメラのリアルタイム画像を、ProducerプログラムによりBrokerのTopicに送信し保存します。そして、ConsumerはTopic内の未処理データをリアルタイムで読み取り、異常検出のために学習済みモデルを呼び出します。
実際に実装するのは、ProducerとConsumerのみです。その他はKafkaフレームワークが全て用意してくれています。
4.2 FlaskのRESTfulインタフェースについて
今回作るのはウェブアプリなので、RESTfulインターフェースを少し理解する必要があります。RESTfulインターフェースは、Webサービスの通信技術です。RESTfulはHTTPプロトコルを使用してデータを転送します。通常、一つのRESTfulインターフェースは一つのURLに対応していますので、アプリ起動後、URLをブラウザーに入力することにより、インターフェイスに実装したビジネスロジックを呼び出すことができます。Flaskでは、Python関数の前に@app.routeアノテーションを追加すると、RESTfulインターフェースとして定義されます。
4.3. ソースコード
ではProducerとConsumerの実装をみてみましょう。
Producerのコード:
import cv2
from kafka import KafkaProducer
topic = "distributed-video"
def publish_camera():
# Start up producer
producer = KafkaProducer(bootstrap_servers='localhost:9092')
camera = cv2.VideoCapture(0)
try:
while(True):
success, frame = camera.read()
ret, buffer = cv2.imencode('.jpg', frame)
producer.send(topic, buffer.tobytes())
# Choppier stream, reduced load on processor
time.sleep(0.2)
except:
print("\nExiting.")
sys.exit(1)
camera.release()
if __name__ == '__main__':
print("publishing feed!")
publish_camera()
コードは非常に簡潔です。localhost:9092でProducerのインスタンスを初期化します。次に、OpenCVを使用してカメラのインスタンスを初期化し、画像データをリアルタイムで読み取ります。次に、Producerのインスタンスを使用して、KafkaのBrokeに送信します。BrokeのTopic名は、distributed-videoと定義しています。
では、リアルタイムでKafkaのBrokerにデータを保存するProducerが完成したところで、次に、Brokerからデータを読み取り、異常検出処理を行うConsumerを実装します。
Consumerのコードは以下の通り:
from flask import Flask, Response
from kafka import KafkaConsumer
import numpy as np
import cv2
# Fire up the Kafka Consumer
topic = "distributed-video"
consumer = KafkaConsumer(
topic,
bootstrap_servers=['localhost:9092'])
# Set the consumer in a Flask App
app = Flask(__name__)
@app.route('/video', methods=['GET'])
def video():
return Response(
get_video_stream(),
mimetype='multipart/x-mixed-replace; boundary=frame')
get_video_stream()関数の実装は以下の通り:
from keras.models import model_from_json
from skimage.measure import compare_ssim
def get_video_stream():
ssim_score_list = []
# load keras model
model = model_from_json(open('./model/vae_anomaly_detection_model.json', 'r').read())
model.load_weights('./model/vae_anomaly_detection_weight.h5')
model._make_predict_function()
msg_buff = []
for msg in consumer:
input_img = np.frombuffer(msg.value, dtype=np.uint8)
# decode msg
input_img = cv2.imdecode(input_img, -1)
# prereprocessing
input_img = prereprocess(input_img)
# predict(generate) output
decoded_img_arr = model.predict(input_img, batch_size=1)
decoded_img = decoded_img_arr[0]
# compare decode_img with input_img
ssim_score = compare_ssim(input_img, decoded_img, multichannel=True)
ssim_score_list.append(ssim_score)
# concat ssim and the original image
im_v = make_score_graph(ssim_score_list)
ret, buffer = cv2.imencode('.jpg', im_v)
yield (b'--frame\r\n'
b'Content-Type: image/jpg\r\n\r\n' + buffer.tobytes() + b'\r\n\r\n')
Consumerは、リアルタイムで受信した画像データに対して、オブジェクト検出、フィールドカット、色調整などの基本的な画像前処理を実施する必要がありますが、 画像の前処理は今回の話題ではありません。前処理及び学習済みのVAE(Variational Auto Encoder)モデルの神経ネットワークに関しては別の記事で詳しく説明します。ここでは、Kerasで学習済みのモデルをエクスポートしたもの読み込んで利用します。もちろん、Pytorchなどの他のフレームワークを使用する場合は、ConsumerのモデルをロードするコードをPytorchバージョンに置き換える必要があります。
最後に、モデルの推論(Inference)結果をブラウザで表示する処理が必要です。 私たちのモデルはVAE生成モデルですので、モデルは受け取った画像に基づいて新しい画像を生成し、受け取った画像と生成された画像の類似性を比較することで、モデルの評価指標を作れます。 私たちのモデルは、稼働中のサーキュレーターの画像からのみ学習するため、稼働していない画像に過学習(Overfit)します。そうなると、稼働していない画像に対して、非常に異なる画像が生成されます。したがって、類似性評価指標で、サーキュレーターの異常動作(停止状態)を検出することができる。
以下は、matplotlibを使用して異常検出結果をリアルタイムで表示する、つまり計算された画像類似度を表示する方法です。
def make_score_graph(ssim_score_list):
if len(ssim_score_list) > 50:
ssim_score_list.pop(0)
else:
pass
# make ssim score graph
# make an agg figure
fig, ax = plt.subplots()
ax.plot(ssim_score_index[:len(ssim_score_list)], ssim_score_list)
#ax.set_title('a simple figure')
ax.set_ylim(0.1, 0.9)
ax.set_xticks([])
ax.set_ylabel("SSIM Score")
fig.canvas.draw()
# grab the pixel buffer and dump it into a numpy array
ssim_graph = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8')
ssim_graph = ssim_graph.reshape(fig.canvas.get_width_height()[::-1] + (3,))
print(ssim_graph.shape)
# resize original image and ssim_graph to the same height
ssim_graph = cv2.resize(ssim_graph, (400, 400))
src_img = cv2.resize(src_img, (400, 400))
# concat ssim and the original image
im_v = np.concatenate((src_img, ssim_graph), axis=1)
return im_v
実際の運用では、サーキュレーターを急にオフにすると、画面に表示される類似度の値がそれに応じて低下するため、異常動作が発生していること検出できる訳です。
5. アプリを起動
先にZooKeeperとKafkaを起動する必要があります。
$ brew install kafka $ brew services start zookeeper $ brew services start kafka

Flaskアプリを起動(アプリを起動するだけではConsumerのロジクが呼ばれません)
$ python consumer.py
正常起動の結果は以下の通りです。ブラウザでhttp://0.0.0.0:5000/video にアクセスするとウェブアプリの画面が見えるようになります。
* Serving Flask app "videoConsumer" (lazy loading) * Environment: production WARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead. * Debug mode: on * Running on http://0.0.0.0:5000/ (Press CTRL+C to quit) * Restarting with stat
最後にProducerを起動し、カメラからリアルタイムデータ収集を開始します。起動語、先ほどのブラウザの画面でリアルタイムの異常検出結果が表示されます。
$ python producer.py

6. 最後に
ここまで、完全でシンプルな画像ストリーム処理アプリを実装しました。アプリ内でリアルタイム画像に対して異常検出を実行し、結果をブラウザーに表示していました。次回は今回で利用した学習済みのVAEの解説篇にしたいと考えています。