1. TensorRT Plugin이란?
TensorRT는 C++ library이고 nvidia GPUs와 deep learning accelerator를 제공함으로써 뛰어난 performance를 제공합니다. 그래서 nvidia GPU가 장착된 서버를 쓰신다면 TensorRT(.trt)모델로 변환하여 inference하는 것이 효과적입니다.
추가로 TensorRT에서는 plugin기능을 제공하는데요. Plugin을 사용하여 model의 추가적인 연산(preprocess, postprocess)를 C++, cuda programming 으로 대체할 수 있어서 (1) 코드의 간결화 (2) 연산속도의 효율의 장점이 있습니다. 대표적인 예시로 대부분의 AI개발자분들은 detection model의 NMS(Non Maximum Suppression) 코드를 python으로 개발하실텐데 TensorRT에서는 detection model의 뒤에 NMS plugin을 붙일 수 있어서 model의 output이 NMS를 통과한 output으로 간결화되고 python이 아닌 C++이기때문에 연산속도의 효율도 가지게 됩니다.
TensorRT에서 제공하는 plugin은 아래의 그림을 통해 확인 가능하며 사용 가능한 모든 plugin은 여기서 보시면 됩니다.
2. TensorRT plugin 예제 및 실습
최근에 새로나온 yolov7 model에 batchedNMSPlugin을 추가해보는 실습을 해보겠습니다. batchedNMSPlugin은 NMS step을 C++언어와 GPU로 inference가능하다는 장점이 있습니다. 개발 환경은 다음과 같습니다.
- onnx: 1.8.0
- torch: 1.9.0a0+df837d0
- onnx-graphsurgeon: 0.2.8
- tensorrt: 7.2.2.3
- CUDA: 11.2
- Driver Version: 460.73.01
2.1 batchedNMSPlugin 이란?
일단 먼저 batchedNMSPlugin의 input, output 형태가 어떤 지 알아봅시다. 해당 plugin의 input 형태는 yolov7모델의 output과 동일해야 해당 plugin을 사용가능하다는 것을 말합니다. 그리고 output형태는 해당 plugin의 결과 의미합니다.
- Input
- Boxes input: [batch_size, number_boxes, 1, number_box_parameters]
- number_box_parameters는 bbox의 정보를 담고 있는데 [x1, y1, x2, y2]으로 (x1, y1), (x2,y2)는 각각 왼쪽 위, 오른쪽 아래 bbox좌표를 나타냄
- Scores input: [batch_size, number_boxes, class_with_confidence]
- class_with_condience = number_classes(각 클래스의 확률) * confidence(objectness)를 의미함
- Boxes input: [batch_size, number_boxes, 1, number_box_parameters]
- Output
- num_detections: [bacth_size]
- batch마다 detection된 object수를 나타냄
- nmsed_boxes: [batch_size, keepTopK, 4]
- NMS를 통과한 bounding box 좌표 [x1, y1, x2, y2]
- nmsed_scores: [batch_size, keepTopK]
- NMS를 통과한 bounding box score
- nmsed_classes: [batch_size, keepTopK]
- NMS를 통과한 bounding box class
- num_detections: [bacth_size]
그리고 batchedNMSPlugin에 필요한 parameter(중요한 것은 highlight)은 다음과 같습니다.
2.2 Torch모델 ONNX모델로 변환
TensorRT Plugin을 사용하기 위해서는 TRT모델로 만들어야 합니다. TRT모델은 ONNX모델으로부터 생성 가능하므로 Torch모델을 ONNX모델로 변환해보죠. 하지만 YOLOv7 모델의 output shape은 [1, 25200, 85]이기 때문에 batchedNMSPlugin의 input shape으로는 맞지 않습니다. 그래서 아래의 코드를 통해 YOLOv7의 output shape을 바꿔보죠.
class ProcModel(nn.Module):
def __init__(self, model, class_num):
super(ProcModel, self).__init__()
self.model = model
self.class_num = class_num
def forward(self, x):
out = self.model(x)[0] # out shape = [batch, num_object, 85], 85 = class_num(80)+bbox(4)+confidence(1)
bbox_out = torch.unsqueeze(out[:,:,:4], 2) # bbox_out shape = [batch, num_object, 1, bbox], bbox = [cx,cy,w,h]
x1 = bbox_out[:,:,:,0] - bbox_out[:,:,:,2] / 2
y1 = bbox_out[:,:,:,1] - bbox_out[:,:,:,3] / 2
x2 = bbox_out[:,:,:,0] + bbox_out[:,:,:,2] / 2
y2 = bbox_out[:,:,:,1] + bbox_out[:,:,:,3] / 2
bbox_out = torch.stack((x1,y1,x2,y2), dim=3) # bbox_out shape = [batch, num_object, 1, bbox], bbox = [x1,y1,x2,y2]
conf_out = out[:,:,4] # [batch, num_object, 1]
conf_out = torch.reshape(conf_out, (conf_out.shape[1],)) # [batch, num_object]
class_out = torch.mul(out[:,:,5:].transpose(1,2) , conf_out).transpose(1,2) # [batch, num_object, num_classes]
return [bbox_out, class_out]
procmodel = ProcModel(model, 80)
위와 같이 YOLOv7의 output shape을 batchedNMSPlugin의 input shape에 맞춰 변환하였습니다. YOLOv7 모델은 coco dataset(class num: 80)으로 학습되었으므로 bbox_out의 shape은 [1, 25200, 1, 4]이고 class_out의 shape은 [1,25200, 80]입니다. (bbox_out은 batchedNMSPIugin의 input중 하나인 Boxes input에 대응되고 class out은 Scores out에 대응됨)
이제 torch.onnx.export
함수를 통해 ONNX 모델로 변환해봅시다.
f = str(weights).replace('.pt', '.onnx') # yolov7.pt -> yolov7.onnnx
input_names = ['images']
output_names = ['bbox_out','class_out']
train = False
opset_version = 12
torch.onnx.export(procmodel, img, f, verbose=False, opset_version=opset_version,
training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
do_constant_folding=not train,
input_names=input_names,
output_names=output_names,
dynamic_axes=None)
opset은 12를 기준으로 하였고 training이 아닌 eval버전으로 export하였습니다. 그리고 output_names=[bbox_out, class_out]인것을 기억해야 합니다. 다음 스텝에서 중요한 요소이거든요.
위의 코드를 실행 시키면 ONNX 모델의 output이 잘 바뀐 것을 확인 가능합니다. (전체 변환 코드는 export.py에서 확인합니다.)
2.3 onnx_graphsurgeon사용하여 batchedNMSPlugin추가
ONNX모델로 변환하면 ONNX 모델의 output shape이 위에서 설명드린 batchedNMSPlugin의 input shape과 같은 형태로 변환되었을 것입니다. 이제 onnx모델을 TRT모델로 변환하는 CLI인 trtexec
를 사용하기 전에 onnx_graphsurgeon으로 ONNX 모델에 batchedNMSPlugin을 추가해주어야 합니다.
onnx_graphsurgeon는 TensorRT/tools/onnx-graphsurgeon에서 제공하는 TensoRT tool로 ONNX model에 ONNX graph를 추가하거나 수정 가능하게 해 줍니다. 해당 폴더에서 1. make install
2. make build
명령어로 쉽게 설치 가능합니다. (저는 release/7.2 version을 설치하였습니다)
이제 batchedNMSPlugin을 추가하는 코드를 설명드립니다.
import onnx_graphsurgeon as gs
def create_attrs(input_h, input_w, topK, keepTopK):
attrs = {}
attrs["shareLocation"] = 1
attrs["backgroundLabelId"] = -1
attrs["numClasses"] = 80
attrs["topK"] = topK
attrs["keepTopK"] = keepTopK
attrs["scoreThreshold"] = 0.25
attrs["iouThreshold"] = 0.6
attrs["isNormalized"] = False
attrs["clipBoxes"] = False
# 001 is the default plugin version the parser will search for, and therefore can be omitted,
# but we include it here for illustrative purposes.
attrs["plugin_version"] = "1"
return attrs
graph = gs.import_onnx(onnx.load('yolov7.onxx')) # load onnx model
batch_size = graph.inputs[0].shape[0]
input_h = graph.inputs[0].shape[2]
input_w = graph.inputs[0].shape[3]
tensors = graph.tensors()
boxes_tensor = tensors["bbox_out"] # match with onnx model output name
confs_tensor = tensors["class_out"] # match with onnx model output name
topK = 100
keepTopK = 50
num_detections = gs.Variable(name="num_detections").to_variable(dtype=np.int32, shape=[batch_size, 1]) # do not change
nmsed_boxes = gs.Variable(name="nmsed_boxes").to_variable(dtype=np.float32, shape=[batch_size, keepTopK, 4]) # do not change
nmsed_scores = gs.Variable(name="nmsed_scores").to_variable(dtype=np.float32, shape=[batch_size, keepTopK]) # do not change
nmsed_classes = gs.Variable(name="nmsed_classes").to_variable(dtype=np.float32, shape=[batch_size, keepTopK]) # do not change
new_outputs = [num_detections, nmsed_boxes, nmsed_scores, nmsed_classes] # do not change
nms_node = gs.Node( # define nms plugin
op="BatchedNMSDynamic_TRT", # match with batchedNMSPlugn
attrs=create_attrs(input_h, input_w, topK, keepTopK), # set attributes for nms plugin
inputs=[boxes_tensor, confs_tensor],
outputs=new_outputs)
graph.nodes.append(nms_node) # nms plugin added
graph.outputs = new_outputs
graph = graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), 'yolov7_gs.onnx') # save model
create_attrs
함수를 통해 위에서 표로 설명드린 batchedNMSPlugin의 parameter를 설정함numClasses
는 coco dataset를 사용하므로 80keepTopK
를 50으로 설정하여 한 이미지 당 detect 가능 object개수 제한isNormalized
를 False를 설정하여 yolov7의 bbox 좌표 output이 normalized안되어 있음을 명시
gs.import_onnx
함수를 통하여 onnx model를 변경할 수 있도록 graph형태로 생성tensors['bbox_out'] , tensors['class_out']
은 ONNX model로 변환 시에 output_names로 명시한 이름gs.Variable
을 통해 batchedNMSPlugin의 output형태를 만듦gs.Node
를 통해 추가하는 batchedNMSPlugin의 노드를 만듦- op에
BatchedNMSDynamic_TRT
은 노드 이름이며 이후 trtexec cli실행 시에 해당 이름을 보고 plugin 구현체 만듦 - 여기서 attrs인자의 값으로 위의
create_attrs
함수의 리턴 값을 넘김
- op에
graph.nodes.append(nms_node)
으로 ONNXM model의 뒤에 batchedNMSPlugin삽입- 여기서 중요한 것은 batchedNMSPlugin의 노드의 이름, 입출력 형태, 해당 노드의 속성만 정의한 것임
- 그래서 실제 batchedNMSPlugin의 구현체는
trtexec
를 통해 TRT모델로 변환 시에 생성됨
그래서 위 코드를 기반으로 실행하면 yolov7_gs.onnx파일이 새로 생성되었으며 netron으로 까 보면 위에서 정한 노드 이름(BatchedNMSDynamic_TRT)으로 추가된 output 노드를 볼 수 있다. (전체 코드는 add_nmsplugin.py 확인!)
2.4 ONNX모델 TRT모델로 변환
드디어 TRT모델을 만들어 볼 시간입니다. trtexec
명령어로 ONNX 모델을 TRT모델로 변경해보죠.
trtexec --onnx=yolov7_gs.onnx --fp16 --workspace=1024 --saveEngine=yolov7_gs.trt
- --onnx: input model이며 BatchedNMSDynamic_TRT node가 추가된 ONNX model path
- --fp16: weight의 data type을 fp16으로 함
- --workspace: workspace의 최대 크기를 정함 (클수록 성능이 올라갈 수 있다 함)
- --saveEngine: TRT모델이 저장될 path
위 CLI를 실행하고 출력된 부분을 보았을 때 아래와 같이 plugin이 정상적으로 생성됨을 알 수 있다.
해당 CLI의 동작은 기기마다 다르겠지만 10~30분 걸릴 것이니 직접 하신 다면 편안하게 기다리시죠. 기다리고 나면 yolov7_gs.trt모델이 생성될 것입니다.
2.5Plugin 추가된 TRT모델 실행
이제 batchedNMSPlugin이 추가된 TRT모델이 정상적으로 detection 되는지 확인해보죠. TRT모델을 Load하고 inference하는 코드는 이번 글의 목적은 아니니 자세히 설명드리지는 않습니다. (다음 글에서 설명드릴게요!) batchedNMSPlugin이 잘 작동하여 위에서 말씀드린 output대로 잘 나오는지 detection이 잘 되는 지 확인하는 것을 목적으로 합니다.
TRT모델로 detection 하는 코드는 detect_trt_plugin.py 확인 가능합니다.
2.5.1 batchedNMSPlugin의 output 확인
yolov7 모델의 input image는 아래와 같이 horse.jpg입니다.
TensorRT모델을 통해 inference 하는 (부분) 코드입니다.
image = input_image.transpose(0, 3, 1, 2) # NHWC to NWHC
image = np.ascontiguousarray(image)
cuda.memcpy_htod(self.inputs[0]['allocation'], image) # input image array(host)를 GPU(device)로 보내주는 작업
self.context.execute_v2(self.allocations) #inference 실행!
for o in range(len(self.outputs)):
cuda.memcpy_dtoh(self.outputs[o]['host_allocation'], self.outputs[o]['allocation']) # GPU에서 작업한 값을 host로 보냄
num_detections = self.outputs[0]['host_allocation'] # detection된 object개수
nmsed_boxes = self.outputs[1]['host_allocation'] # detection된 object coordinate
nmsed_scores = self.outputs[2]['host_allocation'] # detection된 object confidence
nmsed_classes = self.outputs[3]['host_allocation'] # detection된 object class number
result = [num_detections, nmsed_boxes, nmsed_scores, nmsed_classes]
batchedNMSPlugin의 output은 위와 같이 총 4개이며 각각이 제대로 출력되는지 확인해봅니다.
위를 통해 총 6개의 object가 detection 되었고 각 object의 좌표는 nms_boxes에서 확인 가능하며 nms_scores와 nmsed_classes 또한 순차적으로 object의 confidence와 class를 나타냅니다.
2.5.2 detection결과 확인
위의 output기반으로 detection결과를 확인하면 아래와 같습니다.
batchedNMSPlugin이 추가된 상태로 잘 detection 되는 것을 확인 가능합니다! 아래는 실제 yolov7 repo(pytorch 모델 + python NMS code)에서 제공하는 horse.jpg에 대한 결과입니다.
'AI Engineering > NVIDIA' 카테고리의 다른 글
[NVIDIA] DALI multi-GPU 사용법 with PyTorch (2) | 2023.06.22 |
---|---|
[NVIDIA] DALI 사용법 with PyTorch (4) | 2023.06.18 |
[NVIDIA] TensorRT inference 코드 및 예제 (feat. yolov7) (3) | 2022.07.29 |
[NVIDIA] DeepStream 이해 및 설명 (0) | 2022.07.13 |