From d6f7b3f4348cc897e038a8a01d9c90e73020c064 Mon Sep 17 00:00:00 2001 From: Carlos Guzman <42706936+guzman109@users.noreply.github.com> Date: Tue, 12 Dec 2023 16:42:38 -0500 Subject: [PATCH 1/4] Adding files to build new container --- .../image_scoring_plugin/Dockerfile-3.8 | 4 + .../image_scoring_plugin/entrypoint.sh | 2 + .../image_scoring_plugin/pytorch_detector.py | 233 ++++++++++++++++++ 3 files changed, 239 insertions(+) create mode 100644 external_plugins/image_scoring_plugin/pytorch_detector.py diff --git a/external_plugins/image_scoring_plugin/Dockerfile-3.8 b/external_plugins/image_scoring_plugin/Dockerfile-3.8 index 0012ed5..d82b73c 100644 --- a/external_plugins/image_scoring_plugin/Dockerfile-3.8 +++ b/external_plugins/image_scoring_plugin/Dockerfile-3.8 @@ -47,4 +47,8 @@ ADD run_detector_multi.py /run_detector_multi.py RUN chmod +x /entrypoint.sh +# Adding for Triton E2E Example +RUN pip install "tritonclient[all]" +ADD pytorch_detector.py /camera_traps_MD/pytorch_detector.py + ENTRYPOINT ["./entrypoint.sh"] diff --git a/external_plugins/image_scoring_plugin/entrypoint.sh b/external_plugins/image_scoring_plugin/entrypoint.sh index 04c2180..275c1b1 100644 --- a/external_plugins/image_scoring_plugin/entrypoint.sh +++ b/external_plugins/image_scoring_plugin/entrypoint.sh @@ -1,3 +1,5 @@ #!/bin/bash +# Sleep while triton loads. TODO: Will replace with a better method. +sleep 20 python -u image_scoring_plugin.py diff --git a/external_plugins/image_scoring_plugin/pytorch_detector.py b/external_plugins/image_scoring_plugin/pytorch_detector.py new file mode 100644 index 0000000..10f5598 --- /dev/null +++ b/external_plugins/image_scoring_plugin/pytorch_detector.py @@ -0,0 +1,233 @@ +""" +Module to run MegaDetector v5, a PyTorch YOLOv5 (Ultralytics) animal detection model, +on images. +""" + +#%% Imports + +import torch +import numpy as np +from run_detector import CONF_DIGITS, COORD_DIGITS +import ct_utils +try: + # import pre- and post-processing functions from the YOLOv5 repo https://github.com/ultralytics/yolov5 + from utils.general import non_max_suppression, xyxy2xywh + from utils.augmentations import letterbox + + # scale_coords() became scale_boxes() in later YOLOv5 versions + try: + from utils.general import scale_coords + except ImportError: + from utils.general import scale_boxes as scale_coords +except ModuleNotFoundError: + raise ModuleNotFoundError('Could not import YOLOv5 functions.') + +print(f'Using PyTorch version {torch.__version__}') + +# THESE WERE ADD TO WORK WITH TRITON +import tritonclient.http as httpclient +import torchvision + +#%% Classes + +class PTDetector: + + IMAGE_SIZE = 1280 + STRIDE = 64 + + def __init__(self, model_path: str, force_cpu: bool = False): + self.device = 'cpu' + if not force_cpu: + if torch.cuda.is_available(): + self.device = torch.device('cuda:0') + try: + if torch.backends.mps.is_built and torch.backends.mps.is_available(): + self.device = 'mps' + except AttributeError: + pass + + # OMITTING IN ORDER FOR TRITON EXAMPLE TO WORK + # self.model = PTDetector._load_model(model_path, self.device) + # if (self.device != 'cpu'): + # print('Sending model to GPU') + # self.model.to(self.device) + + self.printed_image_size_warning = False + + @staticmethod + def _load_model(model_pt_path, device): + checkpoint = torch.load(model_pt_path, map_location=device) + for m in checkpoint['model'].modules(): + if type(m) is torch.nn.Upsample: + m.recompute_scale_factor = None + torch.save(checkpoint, model_pt_path) + model = checkpoint['model'].float().fuse().eval() # FP32 model + return model + + def generate_detections_one_image(self, img_original, image_id, detection_threshold, image_size=None): + """Apply the detector to an image. + + Args: + img_original: the PIL Image object with EXIF rotation taken into account + image_id: a path to identify the image; will be in the "file" field of the output object + detection_threshold: confidence above which to include the detection proposal + + Returns: + A dict with the following fields, see the 'images' key in https://github.com/microsoft/CameraTraps/tree/master/api/batch_processing#batch-processing-api-output-format + - 'file' (always present) + - 'max_detection_conf' + - 'detections', which is a list of detection objects containing keys 'category', 'conf' and 'bbox' + - 'failure' + """ + + result = { + 'file': image_id + } + detections = [] + max_conf = 0.0 + + try: + + img_original = np.asarray(img_original) + + # padded resize + target_size = PTDetector.IMAGE_SIZE + + # Image size can be an int (which translates to a square target size) or (h,w) + if image_size is not None: + + assert isinstance(image_size,int) or (len(image_size)==2) + + if not self.printed_image_size_warning: + print('Warning: using user-supplied image size {}'.format(image_size)) + self.printed_image_size_warning = True + + target_size = image_size + + else: + + self.printed_image_size_warning = False + + # ...if the caller has specified an image size + + img = letterbox(img_original, new_shape=target_size, stride=PTDetector.STRIDE, auto=True)[0] # JIT requires auto=False + + img = img.transpose((2, 0, 1)) # HWC to CHW; PIL Image is RGB already + img = np.ascontiguousarray(img) + img = torch.from_numpy(img) + img = img.to(self.device) + img = img.float() + img /= 255 + + if len(img.shape) == 3: # always true for now, TODO add inference using larger batch size + img = torch.unsqueeze(img, 0) + + # OMITTING THIS LINE WHICH WILL BE REPLACED BY TRITON CODE. + # pred: list = self.model(img)[0] + + + #-------------------------------------------------------------Triton Client---------------------------------------------------------------------# + # Establish connection to Triton + client = httpclient.InferenceServerClient(url="triton:8000") + + # Get input ready, here we are going to resize the image to dimensions (640x640). + # TensorRT version of yolov5 requires minimum batch size of 4, + # which is why the image is repeated 4 times in the batch. + # Update: Turns out yolov5 does that by default in the background + # when running without triton. + img = torchvision.transforms.Resize((640,640))(img).repeat(4,1,1,1) + + # Infer input types from Triton and add the image data to the request. + input_tensor = [httpclient.InferInput("images", + img.cpu().numpy().shape, + datatype="FP32" + )] + input_tensor[0].set_data_from_numpy(img.cpu().numpy()) + + # Set Outputs data type. + output_tensor = [httpclient.InferRequestedOutput("output0", binary_data=False)] + + # Send image to Triton + resp = client.infer("yolov5", + model_version="1", + inputs=input_tensor, + outputs=output_tensor + ) + + # Retrieve the detection results from Triton's response + pred = torch.from_numpy(resp.as_numpy("output0")).to(self.device) + #-------------------------------------------------------------Triton Client---------------------------------------------------------------------# + + # CONTINUE WITH REGUALR EXECUTION OF MEGADETECTOR + + # NMS + if self.device == 'mps': + # Current v1.13.0.dev20220824 torchvision::nms is not current implemented for the MPS device + # Send pred back to cpu to fix + pred = non_max_suppression(prediction=pred.cpu(), conf_thres=detection_threshold) + else: + pred = non_max_suppression(prediction=pred, conf_thres=detection_threshold) + # format detections/bounding boxes + gn = torch.tensor(img_original.shape)[[1, 0, 1, 0]] # normalization gain whwh + # This is a loop over detection batches, which will always be length 1 in our case, + # since we're not doing batch inference. + for det in pred: + + if len(det): + + # Rescale boxes from img_size to im0 size + det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_original.shape).round() + + for *xyxy, conf, cls in reversed(det): + + # normalized center-x, center-y, width and height + xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() + + api_box = ct_utils.convert_yolo_to_xywh(xywh) + + conf = ct_utils.truncate_float(conf.tolist(), precision=CONF_DIGITS) + + # MegaDetector output format's categories start at 1, but this model's start at 0 + cls = int(cls.tolist()) + 1 + if cls not in range(1, 24): + raise KeyError(f'{cls} is not a valid class.') + + detections.append({ + 'category': str(cls), + 'conf': conf, + 'bbox': ct_utils.truncate_float_array(api_box, precision=COORD_DIGITS) + }) + max_conf = max(max_conf, conf) + + # ...for each detection in this batch + + # ...if this is a non-empty batch + + # ...for each detection batch + + # ...try + + except Exception as e: + + result['failure'] = FAILURE_INFER + print('PTDetector: image {} failed during inference: {}\n'.format(image_id, str(e))) + traceback.print_exc(e) + + result['max_detection_conf'] = max_conf + result['detections'] = detections + + return result + + +if __name__ == '__main__': + # for testing + + import visualization_utils as viz_utils + + model_file = "" + im_file = "test_images/test_images/island_conservation_camera_traps_palau_cam10a_cam10a12122018_palau_cam10a12122018_20181108_174532_rcnx1035.jpg" + + detector = PTDetector(model_file) + image = viz_utils.load_image(im_file) + + res = detector.generate_detections_one_image(image, im_file, detection_threshold=0.00001) From 9536c0dd8d5f65c43687169895954339604a3955 Mon Sep 17 00:00:00 2001 From: Carlos Guzman <42706936+guzman109@users.noreply.github.com> Date: Tue, 12 Dec 2023 16:43:49 -0500 Subject: [PATCH 2/4] Added triton to compose file --- releases/0.3.3/docker-compose.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/releases/0.3.3/docker-compose.yml b/releases/0.3.3/docker-compose.yml index 8d087eb..4d26a9d 100644 --- a/releases/0.3.3/docker-compose.yml +++ b/releases/0.3.3/docker-compose.yml @@ -5,6 +5,29 @@ networks: driver: bridge services: + triton: + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: + - gpu + stdin_open: true + tty: true + shm_size: 256m + networks: + - cameratraps + ports: + - 8000:8000 + - 8001:8001 + - 8002:8002 + volumes: + - /home/guzman.109/Documents/model-commons/model-repository:/models + container_name: triton + image: nvcr.io/nvidia/tritonserver:23.01-py3 + command: tritonserver --model-repository=/models # the name `engine` is important here; in general, the sevice name is addressable by other conatiners on the same # docker network. The default "hostname" used by the python plugin library (pyevents) is "engine" engine: From 45598560ef11807f85111dd4fcaf3b210496bce8 Mon Sep 17 00:00:00 2001 From: Carlos Guzman <42706936+guzman109@users.noreply.github.com> Date: Wed, 13 Dec 2023 09:25:58 -0500 Subject: [PATCH 3/4] Update pytorch_detector.py --- external_plugins/image_scoring_plugin/pytorch_detector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/external_plugins/image_scoring_plugin/pytorch_detector.py b/external_plugins/image_scoring_plugin/pytorch_detector.py index 10f5598..1518d83 100644 --- a/external_plugins/image_scoring_plugin/pytorch_detector.py +++ b/external_plugins/image_scoring_plugin/pytorch_detector.py @@ -133,8 +133,6 @@ def generate_detections_one_image(self, img_original, image_id, detection_thresh # Get input ready, here we are going to resize the image to dimensions (640x640). # TensorRT version of yolov5 requires minimum batch size of 4, # which is why the image is repeated 4 times in the batch. - # Update: Turns out yolov5 does that by default in the background - # when running without triton. img = torchvision.transforms.Resize((640,640))(img).repeat(4,1,1,1) # Infer input types from Triton and add the image data to the request. From ea461d6e53be0c1145e8c4f9bdc010e2d5fd95e9 Mon Sep 17 00:00:00 2001 From: guzman109 Date: Thu, 14 Dec 2023 15:16:32 -0500 Subject: [PATCH 4/4] Example works with Triton running outside of compose camera-traps network --- .gitignore | 5 ++- .../image_scoring_plugin/Dockerfile-3.8 | 4 +- .../image_scoring_plugin/entrypoint.sh | 2 - .../image_scoring_plugin/pytorch_detector.py | 41 ++++++++++--------- releases/0.3.3/docker-compose.yml | 23 ----------- 5 files changed, 29 insertions(+), 46 deletions(-) diff --git a/.gitignore b/.gitignore index 8224638..44fcbcb 100644 --- a/.gitignore +++ b/.gitignore @@ -32,4 +32,7 @@ src/python/dist # Nix result external_plugins/image_generating_plugin/result -external_plugins/image_scoring_plugin/result \ No newline at end of file +external_plugins/image_scoring_plugin/result + +*/.mypy_cache + diff --git a/external_plugins/image_scoring_plugin/Dockerfile-3.8 b/external_plugins/image_scoring_plugin/Dockerfile-3.8 index d82b73c..63e4d59 100644 --- a/external_plugins/image_scoring_plugin/Dockerfile-3.8 +++ b/external_plugins/image_scoring_plugin/Dockerfile-3.8 @@ -48,7 +48,9 @@ ADD run_detector_multi.py /run_detector_multi.py RUN chmod +x /entrypoint.sh # Adding for Triton E2E Example -RUN pip install "tritonclient[all]" +# Install http only version of Triton Client +RUN pip install "tritonclient[http]" +# Replace modified MegaDetector file (Just has code to call Yolov5 on Triton). ADD pytorch_detector.py /camera_traps_MD/pytorch_detector.py ENTRYPOINT ["./entrypoint.sh"] diff --git a/external_plugins/image_scoring_plugin/entrypoint.sh b/external_plugins/image_scoring_plugin/entrypoint.sh index 275c1b1..04c2180 100644 --- a/external_plugins/image_scoring_plugin/entrypoint.sh +++ b/external_plugins/image_scoring_plugin/entrypoint.sh @@ -1,5 +1,3 @@ #!/bin/bash -# Sleep while triton loads. TODO: Will replace with a better method. -sleep 20 python -u image_scoring_plugin.py diff --git a/external_plugins/image_scoring_plugin/pytorch_detector.py b/external_plugins/image_scoring_plugin/pytorch_detector.py index 1518d83..a7ea7e3 100644 --- a/external_plugins/image_scoring_plugin/pytorch_detector.py +++ b/external_plugins/image_scoring_plugin/pytorch_detector.py @@ -87,31 +87,31 @@ def generate_detections_one_image(self, img_original, image_id, detection_thresh max_conf = 0.0 try: - + img_original = np.asarray(img_original) # padded resize target_size = PTDetector.IMAGE_SIZE - + # Image size can be an int (which translates to a square target size) or (h,w) if image_size is not None: - + assert isinstance(image_size,int) or (len(image_size)==2) - + if not self.printed_image_size_warning: print('Warning: using user-supplied image size {}'.format(image_size)) self.printed_image_size_warning = True - + target_size = image_size - + else: - + self.printed_image_size_warning = False - + # ...if the caller has specified an image size - + img = letterbox(img_original, new_shape=target_size, stride=PTDetector.STRIDE, auto=True)[0] # JIT requires auto=False - + img = img.transpose((2, 0, 1)) # HWC to CHW; PIL Image is RGB already img = np.ascontiguousarray(img) img = torch.from_numpy(img) @@ -124,14 +124,14 @@ def generate_detections_one_image(self, img_original, image_id, detection_thresh # OMITTING THIS LINE WHICH WILL BE REPLACED BY TRITON CODE. # pred: list = self.model(img)[0] - + #-------------------------------------------------------------Triton Client---------------------------------------------------------------------# # Establish connection to Triton - client = httpclient.InferenceServerClient(url="triton:8000") + client = httpclient.InferenceServerClient(url="172.17.0.1:8000") # Get input ready, here we are going to resize the image to dimensions (640x640). - # TensorRT version of yolov5 requires minimum batch size of 4, + # TensorRT version of yolov5 requires a minimum batch size of 4, # which is why the image is repeated 4 times in the batch. img = torchvision.transforms.Resize((640,640))(img).repeat(4,1,1,1) @@ -152,11 +152,14 @@ def generate_detections_one_image(self, img_original, image_id, detection_thresh outputs=output_tensor ) - # Retrieve the detection results from Triton's response - pred = torch.from_numpy(resp.as_numpy("output0")).to(self.device) + # Retrieve the detection results from Triton's response. + # Use the first tensor in the batch + # and reshape it to the proper dimensions (25500,8) -> (1,25500,8). + # PyTorch still expects the tensor to be the same device render the bounding box. + pred = torch.from_numpy(resp.as_numpy("output0")[0]).unsqueeze(0).to(self.device) #-------------------------------------------------------------Triton Client---------------------------------------------------------------------# - - # CONTINUE WITH REGUALR EXECUTION OF MEGADETECTOR + + # CONTINUE WITH REGUALR EXECUTION OF MegaDetector # NMS if self.device == 'mps': @@ -203,8 +206,8 @@ def generate_detections_one_image(self, img_original, image_id, detection_thresh # ...for each detection batch - # ...try - + # ...try + except Exception as e: result['failure'] = FAILURE_INFER diff --git a/releases/0.3.3/docker-compose.yml b/releases/0.3.3/docker-compose.yml index 4d26a9d..8d087eb 100644 --- a/releases/0.3.3/docker-compose.yml +++ b/releases/0.3.3/docker-compose.yml @@ -5,29 +5,6 @@ networks: driver: bridge services: - triton: - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: all - capabilities: - - gpu - stdin_open: true - tty: true - shm_size: 256m - networks: - - cameratraps - ports: - - 8000:8000 - - 8001:8001 - - 8002:8002 - volumes: - - /home/guzman.109/Documents/model-commons/model-repository:/models - container_name: triton - image: nvcr.io/nvidia/tritonserver:23.01-py3 - command: tritonserver --model-repository=/models # the name `engine` is important here; in general, the sevice name is addressable by other conatiners on the same # docker network. The default "hostname" used by the python plugin library (pyevents) is "engine" engine: