目次
Python/Kerasで学習させたモデルを速く動かしたい
こんにちは!AIチームの日比野です。
前回の下記記事で書けなかった内容がようやく書けます。
以前セミナーに参加したときに、「Python/Kerasで学習させたモデルは、Javaとかで動かすだけでかなり速くなるよ」という話を耳に挟みました。確かに速度を求める競技プログラミングの情報を見ていると、Pythonってまったく速くないんですよね。
速さだけでいうなら他の選択肢もあるとは思いますが、弊社の主要言語であるJavaで動かすことを目標にします。
そこで今回はJava用のDeep Learningライブラリである、Deep Learning for Java(DL4J)で推論処理を実行し、Pythonで実行したときとの速さを比較しようと思います!
環境
今回は環境をそろえるため、Google Colabではなく私の社用PCで行います。ただし、モデルの学習部分は時間短縮のためGoogle Colabを用います。
学習以外は以下の環境で行います。
OS | Windows 10 Pro 64-bit (10.0, Build 18363) |
---|---|
CPU | Intel(R) Core(TM) i5-8350U CPU @ 1.70GHz (8 CPUs), ~1.9GHz |
メモリ | 8192MB RAM |
GPU | 無し |
IDE:Python | Jupyter lab Version 0.35.4 |
IDE:Java | Eclipse 2019-12 (4.14.0) |
モデルの学習
ある程度負荷のかかるよう、深めのモデルを選択しました。
keras examplesのcifar10_cnnを使います。コードをコピペして、実行してください。
GPUを用いて学習を行います。だいたい70分程かかったため、念のためGoogle Colaboratoryのセッション切れに注意してください。
参考:
学習が完了したら追加で以下のコードを実行することで、モデルをダウンロードしておきます。
1 2 |
from google.colab import files files.download(model_path) |
Python処理
まずは比較用にPython側の推論処理を作成します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import keras import numpy as np model = keras.models.load_model('keras_cifar10_trained_model.h5', compile=False) # データ準備 テストデータのみを使用 _, (x_test, y_test) = keras.datasets.cifar10.load_data() print("データ数:" + str(x_test.shape[0])) opt = keras.optimizers.RMSprop(lr=0.0001, decay=1e-6) model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy']) y_test = keras.utils.to_categorical(y_test, num_classes) import time sum_time = 0 # 合計経過時間(ns) for img in x_test: img = np.array([img]) start = time.perf_counter_ns() # 推論時間のみを測定するため model.predict(img) sum_time += time.perf_counter_ns() - start print(sum_time) |
結果
データ数:10000
36229353700
テスト用画像の保存
同じデータを確実に使うために、画像ファイルを出力しておきます。
1 2 3 4 5 6 7 |
from keras.datasets import cifar10 from matplotlib import pyplot as plt image_dir = "保存フォルダ" _, (X_test, Y_test) = cifar10.load_data() for i in range(X_test.shape[0]): plt.imsave("image_dir"+f"/{i}_{Y_test[i]}.png",X_test[i]) |
以下のように保存できました。余計な括弧が入ってしまっていますが、まぁ良しとしてこのまま進めます!
Java処理
前回の記事と同じ環境で行います。
まずはプロジェクトを作ります。
新規→Javaプロジェクトから名前だけ編集を行い、完了を押下します。
次に作成したプロジェクトを右クリックし、構成→Mavenプロジェクトへ変換 を押下し、Mavenプロジェクトにします。
生成されたpom.xmlを編集することで、Mavenを使って必要なライブラリを調達していきます。
バージョンが違うとKerasのSequential modelが読み込めなかったりするので気を付けてください。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
<dependencies> <dependency> <groupId>org.deeplearning4j</groupId> <artifactId>deeplearning4j-core</artifactId> <version>1.0.0-beta7</version> </dependency> <dependency> <groupId>org.deeplearning4j</groupId> <artifactId>deeplearning4j-modelimport</artifactId> <version>1.0.0-beta7</version> </dependency> <dependency> <groupId>org.nd4j</groupId> <artifactId>nd4j-native-platform</artifactId> <version>1.0.0-beta7</version> </dependency> <dependency> <groupId>org.slf4j</groupId> <artifactId>slf4j-api</artifactId> <version>1.6.6</version> </dependency> <dependency> <groupId>org.slf4j</groupId> <artifactId>slf4j-log4j12</artifactId> <version>1.6.6</version> </dependency> <dependency> <groupId>log4j</groupId> <artifactId>log4j</artifactId> <version>1.2.16</version> </dependency> </dependencies> |
次に実行するクラスを作成します。
新規からクラスを作成します。
名前はとりあえずMainにしておきましょう。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import java.io.File; import java.io.IOException; import org.datavec.image.loader.ImageLoader; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.common.io.ClassPathResource; import org.nd4j.linalg.api.ndarray.INDArray; public class Main { public static void main(String[] args) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { // モデルの読み込み String simpleMlp = new ClassPathResource("keras_cifar10_trained_model.h5") .getFile().getPath(); MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(simpleMlp); System.out.println(model.getLayerNames()); // モデルを読み込めているか確認 //画像読み込み String imageDir = "Pythonでファイルを保存したディレクトリ"; File dir = new File(imageDir); File[] list = dir.listFiles(); ImageLoader imageLoader = new ImageLoader(32, 32, 3); System.out.println("データ数:" + list.length); long sumTime = 0; long start = 0; for (File f : list) { INDArray image = imageLoader.asMatrix(f, false); start = System.nanoTime(); model.output(image); //推論処理 sumTime += System.nanoTime() - start; } System.out.println(sumTime); } } |
画像の読み込み処理がイケてないですが、筆者のJavaスキルが限界のため、これで実行します。読み込み時間は計測に含まれないため、影響はありません。
編集が完了したら、Main.javaを右クリックし、Javaアプリケーションとして実行します。
結果
[conv2d_1, activation_1, conv2d_2, activation_2, max_pooling2d_1, dropout_1, conv2d_3, activation_3, conv2d_4, activation_4, max_pooling2d_2, dropout_2, dense_1, activation_5, dropout_3, dense_2, activation_6, activation_6_loss]
データ数:10000
103058046600
読み込み、ファイル数、時間計測共に問題ありませんね。すでに結果が見えている気もしますが比較を行います。
PythonとJavaの速度結果比較
言語 | 実行時間(NS) |
---|---|
Python | 36,229,353,700 |
Java | 103,058,046,600 |
はい。桁が違いますね。約3倍Javaの方が時間がかかっています。
Pythonで学習させたモデルをJavaで動かすと遅くなった
本当は「Python/Kerasで学習させたモデルをJavaで動かしたら速くなるってマジ?」にする予定でしたが、単純にPythonで学習させたモデルをJavaで動かすだけではむしろ遅くなるという事がわかりました。
考えてみればKerasのbackendで動いているTensorflowはPythonで動いているわけじゃないですし、推論の処理時間だけ見れば速くならないのは当然かもしれません。
ただし、今回はPythonで作ったモデルの推論時間での比較になります。実際のシステムに組み込む際は他の部分で差がつくことも考えられますので、これだけで「Javaが遅い!」という結論にはならないことはご留意ください。
本当はJava速いじゃん!とりあえず速くしたかったらさっとJavaで動かそ!!という流れにしたかったのですが無念です。
Javaの土俵での比較検証も今後行いたいと思っております。
が、検証計画の立案に少々お時間をいただきます……
参考URL
DL4J keras_import overview
github eclipse/deeplearning4j-examples
DL4J API Documents
執筆者プロフィール
-
入社して半年間ロボコン活動に専念。少しのJavaエンジニア期間を経てデータ分析や機械学習、Deep Learningをテーマに勤労しております。
昔取った杵柄を摩耗させつつ新たな支えを求めて試行錯誤中。
この執筆者の最新記事
- Pick UP!2020.10.30Python/Kerasで学習させたモデルをJavaで動かすだけじゃ速くならなかった
- ITコラム2020.08.04DL4Jを使うためにEclipseでMaven使おうとして困ったら
- AI2019.11.29AIの検証・開発を行うときに必要なハードウェア条件とは?
- AI2019.08.05Google Colaboratoryの無料GPU環境を使ってみた