일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |
- recommendation system
- kubernetes
- eks
- 개발자
- 추천시스템
- 빅데이터플랫폼
- 데이터엔지니어링
- AWS SageMaker
- mlops
- apache spark
- hadoop
- redis bloom filter
- pyspark
- Terraform
- BigData
- dataengineer
- 개발자혜성
- 클라우데라
- spark
- 데이터엔지니어
- kafka
- 블로그
- Data engineering
- 하둡
- Python
- 하둡에코시스템
- Spark structured streaming
- cloudera
- 빅데이터
- DataEngineering
- Today
- Total
Hyesung Oh
추론 최적화 시리즈 [1] Bert4rec Pytorch module을 Torch-Tensorrt로 compile 하여 Tritonserver로 실시간 추론하기 본문
추론 최적화 시리즈 [1] Bert4rec Pytorch module을 Torch-Tensorrt로 compile 하여 Tritonserver로 실시간 추론하기
혜성 Hyesung 2024. 7. 15. 17:37TL;DR
지난 추천 시스템 고도화 시리즈의 실시간 추론 편 마지막 단락에서 계획 중인 사이드 프로젝트에 대해 말씀드렸었는데요, 운이 좋게도 사내 추천 시스템에 실시간 추론을 도입하여 사용자에게 조금 더 다이내믹한 탐색 경험을 제공하자는 방향성이 논의되어 사내 PoC Task로 진행해 보게 되었습니다 :).
https://surgach.tistory.com/139
우선 이번 실시간 추론 시리즈를 통해서 달성하고 싶은 목표를 아래와 같이 정했습니다.
- 학습된 torch nn.Module을 Torch-TensorRT 라이브러리를 사용하여 TensorRT engine이 내장된 TorchScript로 컴파일하여 추론 속도를 개선해 봅니다.
- 컴파일된 모델 파일을 Triton server(pytorch backend를 사용)에 호스팅 하여 실시간 추론 퍼포먼스를 측정해 봅니다.
- Torch-TensorRT 라이브러리 사용법 및 컴파일 동작을 코드 레벨에서 이해합니다.
- Triton server 모델 구성 설정법을 습득합니다.
PoC
PoC를 진행하면서 실행한 코드, 트러블 슈팅 과정을 순서대로 첨부하였습니다.
Version Compatibility
PoC 기본 환경 설정입니다. Sagemaker jupyterlab on EC2 g5.xlarge volume 16GB, 진행 일자 기준 가장 최신 NGC Container Tag 24.05를 사용하였습니다.
Container Tag: 24.05
*cuda driver: 12.1.1
*nvidia container toolkit: 1.13.5
Torch-Tensorrt Triton
Torch Conainer | Tritonserver Container | |
Release note | https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-05.html | https://docs.nvidia.com/deeplearning/triton-inference-server/release-notes/rel-23-11.html#rel-23-11 |
NGC Catalog Container Tag | https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver/tags | https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags |
lightning, pytorch_lightning, torch version 간 compatibility는 아래 문서를 참고해 주세요.
https://lightning.ai/docs/pytorch/stable/versioning.html
Model Compiler
학습된 torch module을 불러와 Tensorrt로 변환화는 과정입니다. torch, tensorrt, torch-tensorrt, cuda driver 등 버전 간 호환이 되지 않아 제대로 동작하지 않고, 하위 버전의 컨테이너에선 torch_tensorrt github main branch 기준 제공하는 torch_tensnsorrt.save api가 지원되지 않는 등 문서가 친절하지 않아 PoC 과정 중에서 가장 오랜 시간 삽질을 하였습니다.
여기선 torch_tensorrt.compile api 사용법에 대한 소개만 하고 자세한 디버깅 과정에 대해선 아래에서 다루도록 하겠습니다.
모델로는 팀에서 사용 중인 Bert4rec을 사용하였고, 컴파일하기에 앞서 forward method를 추론 환경에 적합하게 수정하였습니다.
class CustomBert4rec(bert4rec.Bert4Rec):
def forward(
self,
inputs: torch.Tensor,
target_idx: torch.Tensor | None = None,
targets: torch.Tensor | None = None,
):
batch_size = inputs.size(0)
seq_len = inputs.size(1)
pos_ids = torch.arange(seq_len, dtype=torch.long, device=self.device)
pos_ids = pos_ids.unsqueeze(0).expand(batch_size, -1)
output = self._bert(
inputs,
labels=targets, # only used for training and validation steps
position_ids=pos_ids,
attention_mask=inputs != self._pad_token_id,
)
return output.logits, output.loss
원래는 위와 같이 HuggingFace의 BertForMaskedLM 모델을 사용하여 계산된 output인 logit과 loss를 곧바로 return 하는 형식이며 이와 같은 형태가 일반적입니다. 하지만 output.loss의 경우 evalutaion 단계에선 None이며 이 경우 torch_tensorrt AOT Compiler에서 에러가 발생합니다.
또한 output.logits의 차원은 학습한 item diemension size와도 같은데요, 이 경우 매 학습마다 dimension 이 달라질 수 있으므로 compiler에게 static 한 정보를 줄 수가 없는 문제가 있습니다(dynamic input shape 사용법은 아래에서 다루겠습니다).
이뿐만 아니라 실제 유즈 케이스에선 모든 item을 추천하는 것이 아닌, score 기준 top_k item을 줄 것이기에 과도한 메모리 사용을 줄이고자 하였습니다. top_k 로직은 postprocessing으로 분리할 수 있지만 그렇게 되면 Triton Python backend를 사용하게 되어 성능저하 우려가 되기도 하여 결국 forward 내부로 넣었습니다. 이로 인한 성능 결과와 고민에 대해선 아래에서 다루겠습니다.
*테스트용도로 top_k는 10으로 설정.
output = self._bert(
inputs,
labels=targets, # only used for training and validation steps
position_ids=pos_ids,
attention_mask=inputs != self._pad_token_id,
)
# always assume target_idx is not None which means inference mode.
gather_index = target_idx.view(-1, 1, 1).expand(-1, -1, output.logits.shape[-1])
# replace top_k logic with fixed slicing size 10.
logits = output.logits.gather(dim=1, index=gather_index).squeeze(1)[:, :-2]
top_k = 10
topk_score, topk_idx = torch.topk(logits, top_k)
return topk_idx, topk_score
model load -> compile -> save 하는 최종코드는 아래와 같게 됩니다.
# 1. load model
model = model_utils.load_model(
_MDOEL_CKPT_PATH, model_klass=CustomBert4rec, generator_klass=item2user.ExclusionGenerator, device=_DEVICE
)
model.eval()
# 2. Compile with Torch TensorRT;
inputs = [
torch_tensorrt.Input(shape=[1, _INPUT_MAX_SEQ_LEN], dtype=torch.int64),
torch_tensorrt.Input(shape=[1, 1], dtype=torch.int64),
]
trt_model = torch_tensorrt.compile(
model,
ir="dynamo",
inputs=inputs,
enabled_precisions={torch.half, torch.float32},
workspace_size=2000000000,
truncate_long_and_double=True,
)
inputs = [torch.rand(1, _INPUT_MAX_SEQ_LEN).long().cuda(), torch.rand(1, 1).long().cuda()]
# 3. Save the model
torch_tensorrt.save(
trt_model,
f"{_MODEL_REPO_PATH}/{_MODEL_NAME}/{_MODEL_VERSION}/model.pt",
output_format="torchscript",
inputs=inputs,
)
Input
inputs = [
torch_tensorrt.Input(shape=[1, _INPUT_MAX_SEQ_LEN], dtype=torch.int64),
torch_tensorrt.Input(shape=[1, 1], dtype=torch.int64),
]
forward 메서드는 evaluation mode에선 inputs(item sequence입니다), target_idx 두 개의 Input을 받기 때문에 위와 같이 정의하였습니다. shape의 첫 dim은 batch_size로 triton dynamic batching을 사용하지 않는 단일 request 추론(with max_batch_size=1)을 가정하였습니다.
dynamic shape input을 사용하려면 아래와 같이 정의하면 됩니다.
dynamic_inputs = [
torch_tensorrt.Input(
name="inputs",
min_shape=[1, 40],
opt_shape=[32,40],
max_shape=[64,40],
dtype=torch.int64,
),
torch_tensorrt.Input(
name="target_idx",
min_shape=[1, 1],
opt_shape=[32, 1],
max_shape=[64, 1],
dtype=torch.int64,
),
]
min(최소), opt(최적), max(최대) 세 가지 정보를 미리 알려주면 되고, AOT는 opt_shape에 가장 최적화되게 컴파일하게 됩니다. 그 외는 동일합니다.
IR (Intermediate Representation)
ir=dynamo를 지정하였습니다. IR은 중간표현으로 하드웨어 최적화 단계에 input이라 하드웨어 독립적입니다. ir=dynamo를 사용할 시 output은 GraphModule입니다.
trt_model = torch_tensorrt.compile(
model,
ir="dynamo",
inputs=inputs,
enabled_precisions={torch.half, torch.float32},
workspace_size=2000000000,
truncate_long_and_double=True,
)
compile은 내부적으로 model tracing을 하는데 이때 모델의 forward 함수에서 사용되는 tensor 연산을 Graph 데이터 구조로 파싱하고 저장합니다. 이를 통해 연산 최적화 진입점을 제공하며, 이는 중간표현식(IR)이므로 다른 형식(예: onnx, tensorrt)으로 변환될 수 있습니다.
그 다음으로는 trace 결과로 파싱 한 오퍼레이션들을 하드웨어 최적화된 연산으로 변환하는 compile을 진행합니다. 이 단계는 대표적으로는 TensorRT 엔진에서 지원하는 오퍼레이션으로 변환하는 단계가 포함됩니다. 모든 연산이 진행되는 것은 아니며 직접 구현하여 converter mapper에 등록하는 방법도 있습니다. 자세한 가이드는 아래 문서를 참고해주세요.
https://github.com/pytorch/TensorRT/tree/main/py#registering-custom-converters
오퍼레이터 변환 시 실제로는 tensorrt engine과 통신할 수 있는 tensorrt executor 정보를 저장하며 실제 연산은 tensorrt engine에서 실행되어 jit interpreter로 반한되는 구조입니다.
graph(%self.1 : __torch__.___torch_mangle_10.LeNet_trt,
%2 : Tensor):
%1 : int = prim::Constant[value=94106001690080]()
%3 : Tensor = trt::execute_engine(%1, %2)
return (%3)
(AddEngineToGraph)
Tensorrt 연산 최적화 단계에서는 이 graph 정보를 사용하게 됩니다. Tensorrt에서 지원되는 연산은 JIT interpreter → Tensorrt engine으로 실행되며, 전환 불가능한 연산은 TorchScript로 fallback 하는 구조입니다.
*각 IR 마다 사용하는 trace, compile 구현이 다르며 각 구현별 동작에 대해 다음 포스팅에서 더 깊게 알아보겠습니다.
Output
torch_tensorrt.save(
trt_model,
f"{_MODEL_REPO_PATH}/{_MODEL_NAME}/{_MODEL_VERSION}/model.pt",
output_format="torchscript",
inputs=inputs,
)
compile output을 serialize 하여 파일로 저장하는 단계입니다.
The output type of ir=dynamo compilation of Torch-TensorRT is torch.fx.GraphModule object by default. We can save this object in either TorchScript (torch.jit.ScriptModule) or ExportedProgram (torch.export.ExportedProgram) formats by specifying the output_format flag. Here are the options output_format will accept - exported_program : This is the default. We perform transformations on the graphmodule first and use torch.export.save to save the module. - torchscript : We trace the graphmodule via torch.jit.trace and save it via torch.jit.save.
- 출처: https://pytorch.org/TensorRT/user_guide/saving_models.html
ScriptModule과 ExportedProgram은 serializable, optimizable 한 In-memory 표현입니다.
GraphModule은 ExportedProgram의 메타 클래스로 존재합니다. GraphModule은 특히 forward 함수의 연산에 대한 정보를 가지고 있고, ExportedProgram은 실행을 위한 모델의 메타 정보를 모두 가지고 있습니다.
ScriptModule도 자체적으로 graph 속성을 가지나, GraphModule과는 별개의 타입과 표현입니다.
Tritonserver
컴파일된 모델 준비가 완료되었으니 Tritonserver를 구동하여 모델을 서빙할 단계입니다.
아래 명령어와 도커 컨테이너로 간단하게 실행하며 앞선 compile 단계의 model_repository/ 디렉터리의 알맞은 경로에 model.pt 파일이 있으면 정상적으로 로드되게 됩니다.
$ cd ${WORKSPACE}
$ docker run --gpus=all --rm -p8000:8000 -p8001:8001 -p8002:8002 -v $PWD:/workspace nvcr.io/nvidia/tritonserver:24.05-py3
$ cd /workspace
$ pip install fsspec==2023.5.0 # 필요한 dependency 추가 설치, 생략가능
$ tritonserver --model-repository=/model_repository
config.pbtxt
max_batch_size를 1로 이상으로 설정하게 되면 tritonserver 내부적으로 input을 넘겨줄 때 batch_dim이 없을 시 expand 해줍니다. 반대로 batch를 사용하지 않으면 batch_dim이 있으면 squeeze 해줍니다. 따라서 아래에서 다시 다루겠지만 이를 고려하여 preprocess, postprocess 단계를 구현해야 합니다.
name: "bert4rec"
platform: "pytorch_libtorch"
max_batch_size : 1
input [
{
name: "inputs"
data_type: TYPE_INT64
dims: [ 40 ]
},
{
name: "target_idx"
data_type: TYPE_INT64
dims: [ 1 ]
}
]
output [
{
name: "output__0"
data_type: TYPE_INT64
dims: [ 10 ]
},
{
name: "output__1"
data_type: TYPE_FP32
dims: [ 10 ]
}
]
instance_group [
{
count: 2
kind: KIND_GPU
}
]
version_policy: { specific: { versions: [ 1 ]}}
Model Ensemble
위 단일 모델은 순수 추론만 수행하지만, 실제로 모델을 서빙하게 되면 Input, output을 그대로 사용할 수 없는 경우가 대부분입니다. 예를 들어 item 학습 시 고유식별자를 각 index로 매핑하는 인덱싱 단계가 필요하며 약간의 전처리가 필요할 수 있습니다. 현재 팀 추천 파이프라인에서는 배치 학습 시 Dataprocess layer가 해당 역할을 하는데요, 마찬가지로 preprocess, postprocess 단계를 하나의 모델로서 정의할 수 있습니다
그리고 preprocess -> inference -> postprocess 과정을 하나의 ensemble 모델로서 추상화하여 client에 제공할 수 있습니다. 아래는 bert4rec_ensemble config.pbtxt 파일 내용입니다. client는 input spec에 맞게 요청하면 output을 받게 되며 그 사이 step 들에 대해선 알지 못합니다.
내부적으로는 각 backend framework 마다 scheduler 및 queue가 있고 C API core backend에서 받은 요청을 queue에 넣고 scheduler가 이를 backend에 요청하여 응답을 반환하는 구조입니다.
아래는 ensemble 모델을 위한 설정 파일입니다. input_map, output_map을 잘 정의해야 하며 preprocess, postprocess에서는 이에 맞게 input을 파싱 하여 output을 return 하는 model.py를 구현하면 됩니다.
name: "bert4rec_ensemble"
platform: "ensemble"
input [
{
name: "inputs"
data_type: TYPE_INT64
dims: [ 39 ]
}
]
output [
{
name: "output__0"
data_type: TYPE_INT64
dims: [ 10 ]
},
{
name: "output__1"
data_type: TYPE_FP32
dims: [ 10 ]
}
]
ensemble_scheduling {
step [
{
model_name: "bert4rec_preprocessing"
model_version: -1
input_map {
key: "bert4rec_preprocessing_input__0"
value: "inputs"
}
output_map {
key: "bert4rec_preprocessing_output__0"
value: "processed_inputs"
}
output_map {
key: "bert4rec_preprocessing_output__1"
value: "processed_target_idx"
}
},
{
model_name: "bert4rec"
model_version: -1
input_map {
key: "inputs"
value: "processed_inputs"
}
input_map {
key: "target_idx"
value: "processed_target_idx"
}
output_map {
key: "output__0"
value: "item_idxes"
}
output_map {
key: "output__1"
value: "item_scores"
}
},
{
model_name: "bert4rec_postprocessing"
model_version: -1
input_map {
key: "bert4rec_postprocessing_input__0"
value: "item_idxes"
}
input_map {
key: "bert4rec_postprocessing_input__1"
value: "item_scores"
}
output_map {
key: "bert4rec_postprocessing_output__0"
value: "output__0"
}
output_map {
key: "bert4rec_postprocessing_output__1"
value: "output__1"
}
}
]
}
preprocess
name: "bert4rec_preprocessing"
backend: "python"
input [
{
name: "bert4rec_preprocessing_input__0"
data_type: TYPE_INT64
dims: [ 39 ]
}
]
output [
{
name: "bert4rec_preprocessing_output__0"
data_type: TYPE_INT64
dims: [ 40 ]
},
{
name: "bert4rec_preprocessing_output__1"
data_type: TYPE_INT64
dims: [ 1 ]
}
]
instance_group [{ kind: KIND_CPU }]
preprocess, postprocess 모델은 python backend에서 실행됩니다.
import triton_python_backend_utils as pb_utils
class TritonPythonModel:
"""Your Python model must use the same class name. Every Python model
that is created must have "TritonPythonModel" as the class name.
"""
def initialize(self, args):
"""`initialize` is called only once when the model is being loaded.
Implementing `initialize` function is optional. This function allows
the model to initialize any state associated with this model.
Parameters
----------
args : dict
Both keys and values are strings. The dictionary keys and values are:
* model_config: A JSON string containing the model configuration
* model_instance_kind: A string containing model instance kind
* model_instance_device_id: A string containing model instance device ID
* model_repository: Model repository path
* model_version: Model version
* model_name: Model name
"""
# You must parse model_config. JSON string is not parsed here
model_config = json.loads(args["model_config"])
# Get OUTPUT0 configuration
output0_config = pb_utils.get_output_config_by_name(model_config, "bert4rec_preprocessing_output__0")
# Convert Triton types to numpy types
self.output0_dtype = pb_utils.triton_string_to_numpy(output0_config["data_type"])
# pylint:disable=line-too-long
# TODO(hyesung): load from s3 in production env. we just hard code meta for test.
self.meta = meta_utils.load_meta(_META_PATH)
self.id_to_idx = self.meta.get_id_to_idx("item_id")
self.item_size = len(self.id_to_idx)
self.masked_token = self.item_size + 1
model.py에서는 `TritonPythonModel` 이름의 class를 정의해야 합니다.
import triton_python_backend_utils as pb_utils
는 triton python runtime에서 사용할 수 있는 유틸리티입니다. 사용법에 대한 가이드를 찾지 못하여 github repo를 참고했습니다. https://github.com/triton-inference-server/python_backend/blob/main/src/resources/triton_python_backend_utils.py
크게 initialize, execute, finalize 3 개의 메서드를 구현하면 되며, initialize와 finalize는 optional입니다.
def execute(self, requests):
"""`execute` MUST be implemented in every Python model. `execute`
function receives a list of pb_utils.InferenceRequest as the only
argument. This function is called when an inference request is made
for this model. Depending on the batching configuration (e.g. Dynamic
Batching) used, `requests` may contain multiple requests. Every
Python model, must create one pb_utils.InferenceResponse for every
pb_utils.InferenceRequest in `requests`. If there is an error, you can
set the error argument when creating a pb_utils.InferenceResponse
Parameters
----------
requests : list
A list of pb_utils.InferenceRequest
Returns
-------
list
A list of pb_utils.InferenceResponse. The length of this list must
be the same as `requests`
"""
responses = []
# Every Python backend must iterate over everyone of the requests
# and create a pb_utils.InferenceResponse for each of them.
for request in requests:
item_ids = pb_utils.get_input_tensor_by_name(request, "bert4rec_preprocessing_input__0").as_numpy()
item_idxes = [self.id_to_idx[str(id_)] for id_ in item_ids]
# add mask token at last which would be target index.
input_sequence = np.array(item_idxes + [self.masked_token]).astype(self.output0_dtype)
input_sequence_tensor = pb_utils.Tensor(
"bert4rec_preprocessing_output__0",
input_sequence,
)
# target index is last index of input sequence.
target_idx = np.array([len(input_sequence) - 1])
target_idx_tensor = pb_utils.Tensor(
"bert4rec_preprocessing_output__1",
target_idx,
)
# Create InferenceResponse. You can set an error here in case
# there was a problem with handling this inference request.
# Below is an example of how you can set errors in inference
# response:
#
# pb_utils.InferenceResponse(
# output_tensors=..., TritonError("An error occurred"))
inference_response = pb_utils.InferenceResponse(output_tensors=[input_sequence_tensor, target_idx_tensor])
responses.append(inference_response)
# You should return a list of pb_utils.InferenceResponse. Length
# of this list must match the length of `requests` list.
return responses
bert4rec_ensemble/1/config.pbtxt 의 ensemble_scheduling.step 에 정의한 Input_map, output_map에 맞게 input을 파싱 하여 output을 return 해주는 것 외에 특별할 것은 없습니다.
postprocess
name: "bert4rec_postprocessing"
backend: "python"
input [
{
name: "bert4rec_postprocessing_input__0"
data_type: TYPE_INT64
dims: [ 10 ]
},
{
name: "bert4rec_postprocessing_input__1"
data_type: TYPE_FP32
dims: [ 10 ]
}
]
output [
{
name: "bert4rec_postprocessing_output__0"
data_type: TYPE_INT64
dims: [ 10 ]
},
{
name: "bert4rec_postprocessing_output__1"
data_type: TYPE_FP32
dims: [ 10 ]
}
]
instance_group [{ kind: KIND_CPU }]
특별할 것은 없지만 한 가지 짚고 넘어갈 부분이 있는데요, bert4rec/1/config.pbtxt에서 max_batch_size는 1이고 실제 forward의 output은 [batch_dim, ..] 형태입니다.
def execute(self, requests):
responses = []
for request in requests:
item_idxes = pb_utils.get_input_tensor_by_name(request, "bert4rec_postprocessing_input__0").as_numpy()
item_ids = np.array([self.idx_to_id[idx] for idx in item_idxes]).astype(self.output0_dtype)
item_ids_tensor = pb_utils.Tensor(
"bert4rec_postprocessing_output__0",
item_ids,
)
item_scores = pb_utils.get_input_tensor_by_name(
request,
"bert4rec_postprocessing_input__1",
)
item_scores_tensor = pb_utils.Tensor(
"bert4rec_postprocessing_output__1",
item_scores.as_numpy(),
)
inference_response = pb_utils.InferenceResponse(output_tensors=[item_ids_tensor, item_scores_tensor])
responses.append(inference_response)
return responses
하지만 bert4rec_postproecssing/1/config.pbtxt 에선 max_batch_size를 지정하지 않았고 shape 은 [ 10 ] 으로 지정하였습니다. 따라서 bert4rec_postprocessing_input__0 의 shape은 [batch_dim, topk] 가 아니라 [ topk ] (여기선 10으로 하드코딩 하였습니다) 가 되게 됩니다.
Triton Client
$ cd ${WORKSPACE}
$ docker run --gpus=all --rm -it -v $PWD:/workspace --net=host --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 nvcr.io/nvidia/pytorch:24.05-py3
$ cd /workspace
$ pip install tritonclient[all]
$ python client.py
최종적으로 우리가 호출할 것은 model_name = "bert4rec_ensemble"이 되게 되며 item의 고유식별자 기반 sequence를 input으로 요청하면
import numpy as np
import tritonclient.http as httpclient
def main():
model_name = "bert4rec_ensemble"
client = httpclient.InferenceServerClient(url="localhost:8000", verbose=True)
input_inputs = httpclient.InferInput("inputs", [39], "INT64")
input_inputs.set_data_from_numpy(
input_tensor=np.array(
[
1,
2,
1,
3,
4,
5,
10,
6,
11,
7,
8,
9,
12,
1,
2,
5,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
],
dtype=np.int64,
),
binary_data=True,
)
inputs = [
input_inputs,
]
outputs = [httpclient.InferRequestedOutput("output__0"), httpclient.InferRequestedOutput("output__1")]
response = client.infer(model_name, inputs, outputs=outputs)
print(response.as_numpy("output__0"))
print(response.as_numpy("output__1"))
아래와 같이 순서대로 score, item_id 가 출력되는 것을 확인할 수 있습니다.
Performance Analyzer
Triton에서는 배포된 모델의 최적의 성능을 찾기 위한 도구를 제공하는데요,
이를 이용하여 RPS와 Latency를 확인해 보았습니다. 모델 구성의 변주로는 instance_group 정도를 사용하였지만, single gpu 환경이라 성능 지표가 instance 수에 비례하여 개선되진 않았습니다.
$ cd ${WORKSPACE}
$ docker run --gpus all --rm -it --net host -v $PWD:/workspace nvcr.io/nvidia/tritonserver:24.05-py3
$ cd /workspace
$ pip install tritonclient[perf_analyzer]
$ perf_analyzer -m bert4rec -u 127.0.0.1:8001 -i grpc --input-data ./test_data/data.json --request-rate-range 1500:3000:500 -f perf_analyzer_result.csv
자세한 사용법 가이드는 위 문서를 첨부하였으니 참고하시면 될 거 같습니다.
single gpu and max_batch_size=1 (without dynamic batch)
- compute instance: g5.xlarge instance 기준
- model instance: 1
- max_batch_size: 1
대략 throughput 1000/sec, latency 4000 usec 정도의 성능을 보여주는 것을 확인할 수 있습니다.
model ensemble의 각 단계별 latency는 아래와 같이 출력됩니다.
instance_group size를 1 -> 2로 변경하였을 시에는 동일 gpu에 대한 경합으로 인한 overhead 때문인지 throughput peak 지점이 내려간 것을 확인할 수 있었습니다.
$ watch -n0.1 nvidia-smi
아쉬웠던 것은 gpu utilization이 다소 낮았는데요, memory 또한 많이 남아서 전체적으로 인스턴스 자원을 제대로 활용하지 못하는 모습이었습니다. 따라서 실제 운영환경에서는 gpu utilization 대비 메모리 사용율을 최대한 맞추어 운영할 수 있도록 최적화 작업이 필요할 거 같습니다.
이상으로 PoC의 처음부터 끝까지 완주하였는데요, 아래 PoC 과정 중에 겪은 trouble shooting 과정을 첨부하며 이만 마치도록 하겠습니다.
감사합니다 : ).
TroubleShooting
pip install timeout
NGC container로 PoC를 하는 과정에서 필요한 의존성을 추가로 설치해 줄 일이 조금 있었는데, 설치 추가를 반복하다 겪은 이슈입니다. pip install 시에 extra index url로 설정되어 있는 https://pypi.ngc.nvidia.com 주소에서 timeout이 나서 이를 제거해 주면 해결됩니다.
아래 모든 파일에서 extra-index-url을 제거해야 합니다.
# [Priority 1] Site level configuration files
# 1. `/usr/pip.conf`
#
# [Priority 2] User level configuration files
# 1. `/root/.config/pip/pip.conf`
# 2. `/root/.pip/pip.conf`
#
# [Priority 3] Global level configuration files
# 1. `/etc/pip.conf`
# 2. `/etc/xdg/pip/pip.conf`
$ pip config list
:env:.default-timeout='100'
global.index-url='https://pypi.org/simple'
global.no-cache-dir='true'
pip install no space left on device
cache dir을 다른 곳으로 잡아주면 됩니다.
$ pip install --cache-dir=/var/tmp torch-tensorrt==2.2.0
torch.jit.save error
# Save the model
torch.jit.save(trt_model, f"{_TEST_MODEL_REPO}/bert4rec/1/model.pt") # <- AttributeError: 'GraphModule' object has no attribute 'save with ir='dynamo'
torch_tensorrt.save(trt_model, "trt.ts", output_format="torchscript", inputs=inputs) <- AttributeError: module 'torch_tensorrt' has no attribute 'save' in current torch_tensorrt version.
torch_tensorrt 버전을 2.2.x 에서 발생하는 문제인데, torch_tensorrt, tensorrt, torch, torchvision, cuda driver 간에 버전 및 설정이 하나라도 틀어지면 실행이 되지 않는 문제를 겪을 수 있습니다.
ImportError: /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so: cannot read file data
이는 가장 최신 NGC container image(tag 24.05)를 사용하면 해결됩니다.
2.2.x 아래 버전대에서는 torch_tensorrt.compile 함수를 아래와 같이 호출할 시에
trt_model = torch_tensorrt.compile(
model,
ir="dynamo",
inputs=inputs,
enabled_precisions={torch.half, torch.float32},
workspace_size=2000000000,
truncate_long_and_double=True,
output_format="torchscript",
)
output_format을 torchscript로 설정하면 output type이 ScriptModule (torch.jit.trace 결과)이므로 torch.jit.save를 사용할 수 있습니다.
하지만 2.3 이상에서는 ir="dynamo"일 경우, output_format은 항상 GraphModule이므로 torch.jit.trace(output) 단계가 추가적으로 필요합니다.
24.05 container image를 사용할 경우에는 torch_tensorrt.save api를 사용할 수 있으므로 아래와 같이 사용할 수 있습니다.
torch_tensorrt.save(
trt_model,
f"{_MODEL_REPO_PATH}/{_MODEL_NAME}/{_MODEL_VERSION}/model.pt",
output_format="torchscript",
inputs=inputs,
)
docker pull image no space left on device
Sagemaker jupyerlab instance에 docker image를 버전별로 설치해 보며 디버깅하다 겪은 이슈입니다.
안 쓰는 이미지 삭제해 줍니다.
경로: /var/lib/docker
$ docker images
$ docker rmi {{ image_id }}
tritonserver model load시 file extension 제약
.pth 도 되는지 모르겠으나 ScriptModule이라 해서 .ts 확장자로 저장하면 인식이 되지 않고 .pt 확장자로 저장해야 모델 로딩이 정상 동작하였습니다.
ExportedProgram은 save시 input에 int64를 지원하지 않는 문제
torch_tensorrt.compile의 출력인 GraphModule은 재추적(retrace) 또는 건너뛰기(skip) 단계를 거쳐 ExportedProgram 또는 ScriptModule로 변환 가능합니다. 하지만 이때 재추적이 필요하여 입력 텐서를 제공해야 하는데, ExportedProgram의 경우 입력 텐서로 int64를 지원하지 않는 문제가 있습니다.
Dynamic Shape Input
torch.expand operation이 제대로 컴파일 안 되는 문제가 있습니다.
이는 torch_tensorrt 구현에서 dynamic shape을 expand operator 매핑 시 호환이 안 되는 문제인 것으로 추측되는데요, 해당 구현부 코드를 탐독하고 수정 PR을 제출해 볼 계획입니다.
'Data Engineering' 카테고리의 다른 글
Hive, RDMBS, Hbase, HDFS 개념잡기 (0) | 2020.10.19 |
---|---|
빅데이터 플랫폼 Pilot 프로젝트 04 feat. Cloudera Data Platform (0) | 2020.08.31 |
빅데이터 플랫폼 Pilot 프로젝트 03 feat. Cloudera Data Platform (0) | 2020.08.31 |
빅데이터 플랫폼 Pilot 프로젝트 02 feat. Cloudera Data Platform (0) | 2020.08.31 |
빅데이터 플랫폼 Pilot 프로젝트 01 feat. Cloudera Data Platform (0) | 2020.08.31 |