Python/Kerasで学習させたモデルをJavaで動かすだけじゃ速くならなかった

Pocket

Python/Kerasで学習させたモデルを速く動かしたい

こんにちは!AIチームの日比野です。

前回の下記記事で書けなかった内容がようやく書けます。

以前セミナーに参加したときに、「Python/Kerasで学習させたモデルは、Javaとかで動かすだけでかなり速くなるよ」という話を耳に挟みました。確かに速度を求める競技プログラミングの情報を見ていると、Pythonってまったく速くないんですよね。

速さだけでいうなら他の選択肢もあるとは思いますが、弊社の主要言語であるJavaで動かすことを目標にします。

そこで今回はJava用のDeep Learningライブラリである、Deep Learning for Java(DL4J)で推論処理を実行し、Pythonで実行したときとの速さを比較しようと思います!

環境

今回は環境をそろえるため、Google Colabではなく私の社用PCで行います。ただし、モデルの学習部分は時間短縮のためGoogle Colabを用います。

学習以外は以下の環境で行います。

OS Windows 10 Pro 64-bit (10.0, Build 18363)
CPU Intel(R) Core(TM) i5-8350U CPU @ 1.70GHz (8 CPUs), ~1.9GHz
メモリ 8192MB RAM
GPU 無し
IDE:Python Jupyter lab Version 0.35.4
IDE:Java Eclipse 2019-12 (4.14.0)

モデルの学習

ある程度負荷のかかるよう、深めのモデルを選択しました。
keras examplesのcifar10_cnnを使います。コードをコピペして、実行してください。

GPUを用いて学習を行います。だいたい70分程かかったため、念のためGoogle Colaboratoryのセッション切れに注意してください。

参考:

学習が完了したら追加で以下のコードを実行することで、モデルをダウンロードしておきます。

Python処理

まずは比較用にPython側の推論処理を作成します。

結果

データ数:10000
36229353700

テスト用画像の保存

同じデータを確実に使うために、画像ファイルを出力しておきます。

以下のように保存できました。余計な括弧が入ってしまっていますが、まぁ良しとしてこのまま進めます!

Java処理

前回の記事と同じ環境で行います。

まずはプロジェクトを作ります。
新規→Javaプロジェクトから名前だけ編集を行い、完了を押下します。

次に作成したプロジェクトを右クリックし、構成→Mavenプロジェクトへ変換 を押下し、Mavenプロジェクトにします。

生成されたpom.xmlを編集することで、Mavenを使って必要なライブラリを調達していきます。

バージョンが違うとKerasのSequential modelが読み込めなかったりするので気を付けてください。

次に実行するクラスを作成します。

新規からクラスを作成します。
名前はとりあえずMainにしておきましょう。

 

画像の読み込み処理がイケてないですが、筆者のJavaスキルが限界のため、これで実行します。読み込み時間は計測に含まれないため、影響はありません。

編集が完了したら、Main.javaを右クリックし、Javaアプリケーションとして実行します。

結果

[conv2d_1, activation_1, conv2d_2, activation_2, max_pooling2d_1, dropout_1, conv2d_3, activation_3, conv2d_4, activation_4, max_pooling2d_2, dropout_2, dense_1, activation_5, dropout_3, dense_2, activation_6, activation_6_loss]
データ数:10000
103058046600

読み込み、ファイル数、時間計測共に問題ありませんね。すでに結果が見えている気もしますが比較を行います。

PythonとJavaの速度結果比較

言語 実行時間(NS)
Python 36,229,353,700
Java 103,058,046,600

はい。桁が違いますね。約3倍Javaの方が時間がかかっています。

Pythonで学習させたモデルをJavaで動かすと遅くなった

本当は「Python/Kerasで学習させたモデルをJavaで動かしたら速くなるってマジ?」にする予定でしたが、単純にPythonで学習させたモデルをJavaで動かすだけではむしろ遅くなるという事がわかりました。

考えてみればKerasのbackendで動いているTensorflowはPythonで動いているわけじゃないですし、推論の処理時間だけ見れば速くならないのは当然かもしれません。

ただし、今回はPythonで作ったモデルの推論時間での比較になります。実際のシステムに組み込む際は他の部分で差がつくことも考えられますので、これだけで「Javaが遅い!」という結論にはならないことはご留意ください。

本当はJava速いじゃん!とりあえず速くしたかったらさっとJavaで動かそ!!という流れにしたかったのですが無念です。

Javaの土俵での比較検証も今後行いたいと思っております。
が、検証計画の立案に少々お時間をいただきます……

参考URL

DL4J keras_import overview
github eclipse/deeplearning4j-examples
DL4J API Documents

お問い合わせ先

執筆者プロフィール

Hibino Ichirou
Hibino Ichiroutdi デジタルイノベーション技術部
入社して半年間ロボコン活動に専念。少しのJavaエンジニア期間を経てデータ分析や機械学習、Deep Learningをテーマに勤労しております。
昔取った杵柄を摩耗させつつ新たな支えを求めて試行錯誤中。
Pocket

関連記事