이번 글에서는 tf saved model(.pb)을 TensorFlow.js model(.json)으로 변환시키는 것을 목적으로 합니다. 2021년 google에서 나온 Efficientnetv2을 대상으로 TensorFlow.js로 변환하고 웹사이트에서 Efficientetv2으로 classification까지 해보죠! (만약 Efficientnetv2에 대해 알고싶다면 EfficientNetv2 논문 리뷰 참고해주세요~)
0. keras model를 tf saved model로 변환
TensorFlow.js 변환 하기 전에 EfficientNetv2는 keras model로 제공하고 있기 때문에 tf saved model로 변환부터 해보죠.
(변환만 관심 있으시면 넘어 가시면 돼요!!)
해당 github을 clone하여 필요한 lib들을 설치하고 save_to_pb.py파일을 실행시켜 efficientnetv2-b0를 tf saved model로 변환해봅시다. (save_to_pb.py는 제가 efficientnetv2-b0모델을 tf saved model로 변경하기 위해 만든 코드입니다.)
git clone https://github.com/da2so/efficientnetv2.git
cd efficientnetv2
pip install -r requirements.txt
python save_to_pb.py
위의 명령어가 모두 정상적으로 동작했다면 다음과 같이 model(.pb)와 폴더가 생성되어야 합니다.
save_to_pb.py의 (중요 부분) 코드는 다음과 같습니다.
def main(_) -> None:
model = build_tf2_model() #build efficientnetv2-b0 model
input = tf.keras.Input(shape=(224,224,3), batch_size=1) # input shape: (1x3x224x224)
keras_model = tf.keras.Model(inputs=[input], outputs=tf.nn.softmax(model.call(input, training=False))) #keras model
keras_model.save('./efficientnetv2-b0_saved_model', save_format='tf') #save to tf saved model
efficientnetv2-b0을 keras model로 만들 때 softmax부분을 추가해주었고 save함수을 통해 tf saved model 형태로 저장하였습니다.
1. TensorFlow.js model로 변환
TensorFlow.js로 변환하는 방법은 아주 간단합니다. 먼저 tensorflowjs를 설치합니다.
pip install tensorflowjs
이제 tf saved model을 tensorflowjs_converter
명령어를 통해 TensorFlow.js 모델로 변환해봅시다.
tensorflowjs_converter --input_format=tf_saved_model efficientnetv2-b0_saved_model efficientnetv2-b0_web_model
- --input_format: 입력 모델 형식
- tf saved model -> tf_saved_model (제가 사용한 옵션)
- keras model(.h5) -> keras
- frozen model -> tf_frozen_model
옵션에 대한 설정이 끝났으면 다음으로는 source_model인 efficientnetv2-b0_saved_model을 설정하고 마지막은 TensorFlow.js파일들이 저장될 디렉토리(efficientnetv2-b0_web_model)을 설정합니다. 위의 명령어가 정상적으로 작동했다면 다음과 같이 모델의 구조를 담는 model.json과 weight를 담고 있는 bin파일들로 저장됩니다.
2. EfficientNetv2 웹사이트에 deploy
EfficientNetv2를 TensorFlow.js 형태로 만들었으니 실제 웹사이트에서 inference가 잘 작동하는 지 확인해 봐야겠죠?? 저는 javascript기반 react를 기반으로 코딩하였습니다. 직접 자신의 웹사이트에서 작동 확인하시려면 제 github을 clone: git clone https://github.com/da2so/tfjs-efficientnetv2.git
하셔서 사용하시면됩니다.
(npm사용해서 localhost에서 실습하시거나 github page deploy사용하시면 됩니다.)
TensorFlow.js 모델을 react코드에서 어떻게 load하고 inference하는 지 중요한 부분만 골라서 알려드리겠습니다.
const weights = 'https://raw.githubusercontent.com/da22so/tfjs_models/main/efficientnetv2-b0_web_model/model.json';
... //생략
class App extends React.Component {
state = {
model: null,
... //생략
};
componentDidMount() {
tf.loadGraphModel(weights).then(model => { //Efficientnetv2-b0 모델 load
this.setState({
model: model
});
});
}
this.state.model.executeAsync(input).then(res => { // classification execute!
... //생략
const pred = res;
const pred_data = pred.dataSync(); // classification compelete done!
- tf.loadGraphModel: TensorFlow.js모델을 Load하는 함수
- .json 확장자 파일을 입력으로 받으며 해당 json파일은 위에서 만든 model.json이랑 같음
- url을 통해서만 TensorFlow.js model을 load하는 함수임
- load된 모델은 this.state.model에 할당
- this.state.model.executeAsync(input): model의 inference을 execute하는 함수
- input은 입력 이미지(1x3x224x224)를 말함 (자세한거는 코드에서 확인해주세요!)
- classification결과는 res에 할당되지만 javascript 특성상 비동기적이므로 classification이 execute되고 complete되면 pred_data에 할당함
여기서 tfjs-efficientnetv2 실습 EfficientNetv2-b0을 실제로 사용해보실 수 있게 해 놓았으니 이미지 업로드해보세요~~ 저는 저희 집 고양이인 코넛이 사진을 넣어 classification해보았습니다!
'AI Engineering > TensorFlow' 카테고리의 다른 글
TFLite 뽀개기 (4) XNNPACK 이해 및 성능 비교 (0) | 2022.08.10 |
---|---|
TensorFlow.js (4) YOLOv5 Live demo (7) | 2022.04.11 |
TensorFlow.js (2) - WebGL 기반 hand pose detection (2) | 2022.03.23 |
TensorFlow.js (1) - TensorFlow.js 이해 및 detection 예제 (1) | 2022.03.17 |
Mediapipe (2) - custom segmentation model with mediapipe (0) | 2022.03.14 |