오늘은 안드로이드에서 커스텀 텐서플로우 라이트 모델을 사용하는 방법을 알아보겠습니다! 구글에 안드로이드에서 텐서플로우 라이트 모델 사용하는 법을 검색하면 거의 안드로이드 프로젝트에 텐서플로우 라이트 모델을 넣고 사용하는 방법만 알려주더라구요! 안드로이드 프로젝트에 모델을 직접 넣는 것이 아닌 FirebaseML을 사용하여 Firebase에서 모델을 가져와 안드로이드에서 모델을 사용하는 방법을 알아보겠습니다 :)
FirebaseML
FirebaseML을 사용해보겠습니다! 이것을 사용하기 위해서는 Firebase에 모델을 업로드해야 합니다 안드로이드에서 텐서플로우 라이트를 사용할 것이기 때문에 모델을 업로드할 때 .tflite 파일로 업로드를 해주어야 합니다 모델을 .tflite로 변환하는 방법은 조금만 구글링하면 엄청나게 잘 나오므로 여기서는 생략하겠습니다 FirebaseML을 사용하기 전에 모델을 .tflite로 변환해주세요!
1. Firebase 사이트에 접속하여 콘솔로 이동한 후 프로젝트를 만들어줍니다 프로젝트를 만드는 방법은 아래의 링크를 참고해주세요!
프로젝트를 만든 후에 프로젝트를 선택하면 다른 페이지로 이동됩니다 그 페이지에서 왼쪽을 보면 다음과 같은 화면을 볼 수 있습니다. 여기서 Machine Learning을 선택합니다.
2. 또 페이지가 이동되면 다음과 같은 화면을 볼 수 있습니다. 여기서 Custom을 선택한 후 .tflite로 변환된 모델을 추가해줍니다. 모델 추가가 끝나면 Firebase쪽은 세팅이 다 끝난 것입니다.
Android
1. build.gradle에 다음 코드를 추가해줍니다
implementation platform('com.google.firebase:firebase-bom:26.7.0')
implementation 'com.google.firebase:firebase-ml-modeldownloader'
implementation 'org.tensorflow:tensorflow-lite:2.4.0'
2. AndroidManifest.xml에 인터넷 권한을 추가해줍니다. 다음 코드를 추가해주면 됩니다.
<uses-permission android:name="android.permission.INTERNET" />
3. 모델의 input과 output을 확인합니다. 모델 코드에서 다음 코드를 넣어보면 input과 ouput을 알 수 있습니다.
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()
# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))
# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))
결과는 다음과 같이 나올 것입니다!
4. Firebase에서 모델을 가져와 실행시켜봅니다
FirebaseCustomRemoteModel remoteModel = new FirebaseCustomRemoteModel.Builder("model").build();
FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder().requireWifi().build();
FirebaseModelManager.getInstance().download(remoteModel, conditions).addOnSuccessListener(new OnSuccessListener<Void>() {
@Override
public void onSuccess(Void v) {
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel).addOnCompleteListener(new OnCompleteListener<File>() {
@Override
public void onComplete(@NonNull Task<File> task) {
File modelFile = task.getResult();
if (modelFile != null) {
interpreter = new Interpreter(modelFile); // 인터프리터 생성
float[][] input = {0,0,0,0,0,0,0,0,0,0}; // 모델의 Input
float[][] output = new float[1][7]; // 모델의 Output
if(interpreter != null) {
interpreter.run(input, output); // 모델 실행
}
// 모델 결과값 출력하기
for (int i = 0; i < 7; i++) {
System.out.println(i + " : " + output[0][i]);
}
interpreter.close(); // 인터프리터 종료
}
}
});
}
});
첫번째 줄을 보면 이 코드를 통해 Firebase에서 모델을 가져옵니다. Builder(모델명) 형태로 써주면 되는데 여기서 모델명은 FirebaseML에 올린 모델명을 말합니다. 모델이 잘 가져와졌다면 onComplete() 가 실행됩니다. 인터프리터를 생성해주고 위에서 확인한 모델의 input과 output에 맞게 모델에 넣어주면 정상적으로 잘 동작할 것입니다 input과 output이 맞지 않는다면 모델이 실행되지 않으니 주의해주세요!! 모델 실행이 끝나면 모델의 결과인 output을 출력하여 확인합니다 다음과 같이 출력된다면 모델이 잘 실행된 것입니다.
'Android' 카테고리의 다른 글
[Android][Kotlin] Fragment의 toolbar 메뉴 클릭 이벤트 (0) | 2021.05.22 |
---|---|
[Android] RxJava란? (0) | 2021.05.14 |
[Android] 12주차 스터디 (Volley) (0) | 2021.02.19 |
[Android] 11주차 스터디 (Firebase Storage with Glide) (0) | 2021.01.27 |
[Android] 10주차 스터디 (Firebase 클라우드 메시징 (FCM)) (0) | 2021.01.14 |