Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions yolov8/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ else()
link_directories(/usr/local/cuda/lib64)

# tensorrt
include_directories(/home/lindsay/TensorRT-8.6.1.6/include)
link_directories(/home/lindsay/TensorRT-8.6.1.6/lib)
# include_directories(/home/lindsay/TensorRT-7.2.3.4/include)
# link_directories(/home/lindsay/TensorRT-7.2.3.4/lib)
# include_directories(/home/lindsay/TensorRT-8.6.1.6/include)
# link_directories(/home/lindsay/TensorRT-8.6.1.6/lib)
include_directories(/usr/include/TensorRT10)
link_directories(/usr/lib/TensorRT10)


endif()
Expand Down
126 changes: 126 additions & 0 deletions yolov8/GPU_POSTPROCESSING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# GPU-Accelerated Segmentation Postprocessing for YOLOv8-Seg

This contribution adds **fully GPU-accelerated instance segmentation postprocessing** to tensorrtx/yolov8, achieving a **100x speedup** on the drawing kernel (from ~5ms to ~0.05ms).

## Performance Problem

The original implementation had O(num_pixels × kMaxNumOutputBbox) complexity in the drawing kernel:

```cpp
// BEFORE: Iterates kMaxNumOutputBbox times per pixel
for (int i = 0; i < kMaxNumOutputBbox; ++i) {
if (nms_output[...].keep_flag == 1) {
// Check if pixel is inside bbox and mask
}
}
```

When `kMaxNumOutputBbox` was increased from 1000 to 8500 (required for standard YOLOv8 models), drawing time increased from 1-2ms to 5ms.

## Solution: O(num_dets) Optimization

The fix uses a two-phase approach:

1. **Gather Phase**: Extract only kept detections into a dense buffer
2. **Draw Phase**: Iterate only over actual detections (typically 5-50, not 8500)

```cpp
// AFTER: Iterates only num_dets times per pixel
for (int i = 0; i < num_dets; ++i) {
// dense_bboxes contains only kept detections
float left = dense_bboxes[i*4 + 0];
// ...
}
```

## Files Modified

### `include/config.h`
- `kMaxNumOutputBbox`: 1000 → **8500** (fixes crash with standard YOLOv8 models)
- Added `kProtoTensorName = "proto"` for TensorRT 10 compatibility

### `include/postprocess.h`
Added GPU segmentation function declarations:
```cpp
void cuda_compact_and_gather_masks(...);
void cuda_gather_kept_bboxes(...);
void cuda_process_mask(...); // With strict bbox clipping
void cuda_blur_masks(...);
void cuda_draw_results(...); // Optimized O(num_dets) version
```

### `src/postprocess.cu`
New CUDA kernels:
- `compact_and_gather_masks_kernel`: Extracts kept mask coefficients
- `gather_kept_bboxes_kernel`: Builds dense bbox buffer
- `process_mask_kernel`: Bilinear interpolation + strict bbox clipping
- `box_blur_horizontal/vertical`: Mask smoothing
- `draw_results_on_image_kernel`: **O(num_dets) optimized drawing**

### `yolov8_seg.cpp`
- Updated `prepare_buffer()`: Allocates GPU buffers for mask processing
- Updated `infer()`: Full GPU postprocessing pipeline
- TensorRT 10 API: Uses `setInputTensorAddress`/`setOutputTensorAddress`/`enqueueV3`

## TensorRT Version Compatibility

> **IMPORTANT**: This implementation was developed and tested with **TensorRT 10.x**.

The inference code uses TensorRT 10 APIs:
- `context.setInputTensorAddress()` / `context.setOutputTensorAddress()`
- `context.enqueueV3(stream)`
- `engine->getTensorShape()`

For **TensorRT 8.x** compatibility, use the original binding-based APIs:
```cpp
// TensorRT 8.x
context.enqueue(batchsize, buffers, stream, nullptr);
auto out_dims = engine->getBindingDimensions(1);
```

> **Note**: The upstream `model.cpp` in tensorrtx requires additional modifications for TensorRT 10 compatibility (setStrideNd, setPaddingNd, etc.). This contribution focuses on the postprocessing optimization.

## Performance Results

| Metric | Before | After | Speedup |
|--------|--------|-------|---------|
| Drawing Kernel | ~5.0 ms | ~0.05 ms | **100x** |
| Total GPU Postprocess | ~6.5 ms | ~1.5 ms | **4x** |

Tested on:
- GPU: NVIDIA RTX 3080 Ti (CUDA 12.6)
- TensorRT: 10.x
- Model: YOLOv8m-seg (640x640)



## Usage

```cpp
// After NMS
cuda_compact_and_gather_masks(decode_ptr_device, final_count_device,
compacted_masks_device, mask_mapping_device, kMaxNumOutputBbox, stream);

int num_dets = 0;
cudaMemcpy(&num_dets, final_count_device, sizeof(int), cudaMemcpyDeviceToHost);

if (num_dets > 0) {
cuda_gather_kept_bboxes(decode_ptr_device, mask_mapping_device,
dense_bboxes_device, kMaxNumOutputBbox, stream);

cuda_process_mask(proto_device, compacted_masks_device, dense_bboxes_device,
final_masks_device, num_dets, 160, 160, kInputH, kInputW, stream);

cuda_blur_masks(final_masks_device, num_dets, kInputH, kInputW, stream);

cuda_draw_results(image_buffer, final_masks_device, decode_ptr_device,
mask_mapping_device, num_dets, mask_mode, mask_thresh, stream);
}
```

## Key Design Decisions

1. **Dense bbox buffer**: Avoids sparse iteration over kMaxNumOutputBbox
2. **Strict bbox clipping**: Masks are zeroed outside their bounding box during generation
3. **Bilinear interpolation**: High-quality mask upscaling (mimics `retina_masks=True`)
4. **In-place drawing**: Modifies input buffer directly, avoiding extra copies
5 changes: 3 additions & 2 deletions yolov8/include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
//#define USE_INT8

const static char* kInputTensorName = "images";
const static char* kOutputTensorName = "output";
const static char* kOutputTensorName = "output0";
const static char* kProtoTensorName = "proto";
const static int kNumClass = 80;
const static int kBatchSize = 1;
const static int kGpuId = 0;
Expand All @@ -13,7 +14,7 @@ const static float kNmsThresh = 0.45f;
const static float kConfThresh = 0.5f;
const static float kConfThreshKeypoints = 0.5f; // keypoints confidence
const static int kMaxInputImageSize = 3000 * 3000;
const static int kMaxNumOutputBbox = 1000;
const static int kMaxNumOutputBbox = 8500;
//Quantization input image folder path
const static char* kInputQuantizationFolder = "./coco_calib";

Expand Down
16 changes: 16 additions & 0 deletions yolov8/include/postprocess.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,19 @@ void draw_bbox_obb(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Dete
void draw_bbox_keypoints_line(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);
void draw_mask_bbox(cv::Mat& img, std::vector<Detection>& dets, std::vector<cv::Mat>& masks,
std::unordered_map<int, std::string>& labels_map);

// CUDA Segmentation
void cuda_compact_and_gather_masks(const float* decode_ptr_device, int* final_count_device,
float* compacted_masks_device, int* mask_mapping_device, int max_objects,
cudaStream_t stream);

void cuda_gather_kept_bboxes(const float* nms_output, const int* mask_mapping, float* dense_bboxes, int max_bboxes,
cudaStream_t stream);

void cuda_process_mask(const float* proto, const float* masks_in, const float* bboxes_in, float* masks_out,
int num_dets, int proto_h, int proto_w, int out_h, int out_w, cudaStream_t stream);

void cuda_blur_masks(float* masks_out, int num_dets, int h, int w, cudaStream_t stream);

void cuda_draw_results(float* image_buffer, const float* final_masks, const float* nms_output, const int* mask_mapping,
int num_dets, int mask_mode, float mask_thresh, cudaStream_t stream);
Loading