尽管这里有tensorflow lite的详细 介绍,在实战中还是不免要踩一些坑。下面是我踩过的坑,记录下来,免得下次踩同样的坑)。

测试环境:

  • 自己训练的模型。使用预训练的模型应该会更简单一些,不需要模型转换这个步骤。
  • Tensorflow 2.0.0-beta1
  • Android Studio 3.4.2
  • 简单的回归预测案例

模型转换

试验了独立程序和训练后即时转换两种方式,最后觉得在训练模型后马上进行模型转换更方便,原因有两个:

  1. 在训练程序中,可以直接使用模型对象model进行模型的转换。
  2. 通常输入数据要进行标准化处理,而每次训练时所采用的数据集是不同的,导致标准化数据的mean和std会随之变化,因此需要将标准化数据的mean和std保存下来,以便传递给android app对输入数据进行同样的数据标准化处理。

上面的两个数据:转换后的模型和数据标准化基础数据都需要复制到Android app的assets目录下,因此在训练的程序中统一写一下更加方便,下面是我这边相应的代码:

# 模型训练完毕后转化模型
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open("mymodel.tflite", "wb").write(tflite_model)
# 保存数据标准化相关数据(主要是mean和std)
train_stats = train_dataset.describe()
train_stats = train_stats.transpose()
# 创建csv文件,以便移动端使用相同的统计数据标准化数据
train_stats.to_csv('train_stats.csv', index=False)

Android App的配置

Android Studio这一端的配置涉及到以下几个方面:

模型和数据标准化基准数据的存放位置

转化后的模型和数据标准化基准数据一般要保存到app/src/main/assets目录下,方便在程序中引用。

依赖的引入

在app/build.gradle文件中增加如下的依赖:

dependencies {
    ......
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
    implementation 'org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly'
    implementation 'com.opencsv:opencsv:4.5'
    ......
}

同样的,在app/build.gradle文件中增加如下的片段,原因已经在注释中说明了:

android {
    compileSdkVersion 28

    // aapt默认压缩assets下面的文件,直接openFd打不开
    aaptOptions {
        noCompress "tflite"  //表示不让aapt压缩的文件后缀
    }
......
}

Interpreter的创建和使用

有了以上的准备工作,就可以在适当的Action中使用Interpreter来运行模型了:

private static final String MODEL = "mymodel.tflite";
......
try (Interpreter interpreter = new Interpreter(loadModelFile(MODEL))) {
        ......
        interpreter.run(normed_input, output);
}

    /**
     * Memory-map the model file in Assets.
     *
     * @see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
     */
private MappedByteBuffer loadModelFile(String modelPath) throws IOException {
        AssetFileDescriptor fileDescriptor = getAssets().openFd(modelPath);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}

输入/输出数据的维度和标准化处理

Interpreter的参数很简单,一个是input,一个output,但是要注意这两个参数的维度和类型。input只能使用Interpreter能够识别的基本数据类型(int,float,long,byte)数组,output的类型和维度因程序而异。input的类型和维度在训练程序中查看更加方便,因此一般要在训练程序中搞清楚input和output的类型和维度再来写android端的程序,此时你会发现,python在AI程序方面的优势太大了。

我这边的情况input和output都是float数组:

    private float[][] input;
    private float[][] output;

在上面说过,喂入模型的数据一般要经过标准化处理,那么在android端的输入数据也要进行同样的标准化处理,这在android端是个有点麻烦的事情。我采取的方法是读取模型训练时的标准化基准数据(主要是mean和std)csv文件,然后对输入数据使用mean和std进行标准化处理:

            private static final String MEAN_STD = "train_stats.csv";
            
            private float[][] norm(float[][] input) {
                float[][] normed_input = new float[input.length][43];
                float[][] mean_std = new float[43][2];
                CSVReader reader = null;
                try {
                    reader = new CSVReader(new BufferedReader(new InputStreamReader(getAssets().open(MEAN_STD))));
                    List<String[]> myEntries = reader.readAll();
                    int index = 0;
                    for(String[] entry:myEntries){
                        float mean = Float.parseFloat(entry[1]);
                        float std = Float.parseFloat(entry[2]);
                        mean_std[index][0] = mean;
                        mean_std[index][1] = std;
                        index++;
                    }

                    for(int i = 0; i < input.length; i++){
                        for(int j = 0; j < 43; j++){
                            normed_input[i][j] = (input[i][j] - mean_std[j][0])/mean_std[j][1];
                        }
                    }


                } catch (FileNotFoundException e) {
                    e.printStackTrace();
                } catch (IOException e) {
                    e.printStackTrace();
                }

                return normed_input;
            }