Skip to content

Commit 7bb1a67

Browse files
Merge pull request #981 from roboflow/feature/detectctions_filter_on_steroids
Add new UQL extension - picking up bounding boxes that are inside parent detections of a specific class
2 parents 8f8dabf + fd8f63f commit 7bb1a67

File tree

13 files changed

+386
-159
lines changed

13 files changed

+386
-159
lines changed

inference/core/workflows/core_steps/common/query_language/entities/operations.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,28 @@ class DetectionsRename(OperationDefinition):
536536
)
537537

538538

539+
# UQL operation definition: keep only those detections whose parent box
# (a detection of `parent_class` in the same predictions object) contains them.
# No docstring on purpose — json_schema_extra below fully defines the schema
# description, and the field/config layout mirrors the sibling operations.
class PickDetectionsByParentClass(OperationDefinition):
    model_config = ConfigDict(
        json_schema_extra={
            "description": "Picks only those detections which are located inside "
            "parent detections of specific class",
            "compound": False,
            # Accepts and returns any detection-like prediction kind.
            "input_kind": [
                OBJECT_DETECTION_PREDICTION_KIND,
                INSTANCE_SEGMENTATION_PREDICTION_KIND,
                KEYPOINT_DETECTION_PREDICTION_KIND,
            ],
            "output_kind": [
                OBJECT_DETECTION_PREDICTION_KIND,
                INSTANCE_SEGMENTATION_PREDICTION_KIND,
                KEYPOINT_DETECTION_PREDICTION_KIND,
            ],
        },
    )
    # Discriminator value used by the AllOperationsType tagged union.
    type: Literal["PickDetectionsByParentClass"]
    parent_class: str = Field(description="Class of parent detections")
559+
560+
539561
AllOperationsType = Annotated[
540562
Union[
541563
StringToLowerCase,
@@ -569,6 +591,7 @@ class DetectionsRename(OperationDefinition):
569591
ConvertImageToBase64,
570592
DetectionsToDictionary,
571593
ConvertDictionaryToJSON,
594+
PickDetectionsByParentClass,
572595
],
573596
Field(discriminator="type"),
574597
]

inference/core/workflows/core_steps/common/query_language/operations/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
extract_detections_property,
3030
filter_detections,
3131
offset_detections,
32+
pick_detections_by_parent_class,
3233
rename_detections,
3334
select_detections,
3435
shift_detections,
@@ -199,6 +200,7 @@ def build_detections_filter_operation(
199200
"ConvertImageToBase64": encode_image_to_base64,
200201
"DetectionsToDictionary": detections_to_dictionary,
201202
"ConvertDictionaryToJSON": dictionary_to_json,
203+
"PickDetectionsByParentClass": pick_detections_by_parent_class,
202204
}
203205

204206
REGISTERED_COMPOUND_OPERATIONS_BUILDERS = {

inference/core/workflows/core_steps/common/query_language/operations/detections/base.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,62 @@ def detections_to_dictionary(
351351
context=f"step_execution | roboflow_query_language_evaluation | {execution_context}",
352352
inner_error=error,
353353
)
354+
355+
356+
def pick_detections_by_parent_class(
    detections: Any,
    parent_class: str,
    execution_context: str,
    **kwargs,
) -> sv.Detections:
    """Keep only detections contained inside a parent detection of ``parent_class``.

    Validates that ``detections`` is an ``sv.Detections`` instance, then
    delegates to ``_pick_detections_by_parent_class``. Any failure in the
    delegate is re-raised as ``OperationError`` carrying the RQL context.

    Raises:
        InvalidInputTypeError: if ``detections`` is not ``sv.Detections``.
        OperationError: wrapping any error raised while filtering.
    """
    # Same context string is used by both error paths — build it once.
    rql_context = (
        f"step_execution | roboflow_query_language_evaluation | {execution_context}"
    )
    if not isinstance(detections, sv.Detections):
        stringified_value = safe_stringify(value=detections)
        raise InvalidInputTypeError(
            public_message=f"Executing pick_detections_by_parent_class(...) in context {execution_context}, "
            f"expected sv.Detections object as value, got {stringified_value} of type {type(detections)}",
            context=rql_context,
        )
    try:
        return _pick_detections_by_parent_class(
            detections=detections, parent_class=parent_class
        )
    except Exception as inner:
        raise OperationError(
            public_message=f"While Using operation pick_detections_by_parent_class(...) in context {execution_context} "
            f"encountered error: {inner}",
            context=rql_context,
            inner_error=inner,
        )
380+
381+
382+
def _pick_detections_by_parent_class(
    detections: sv.Detections,
    parent_class: str,
) -> sv.Detections:
    """Return non-parent detections whose center lies inside any parent box.

    Detections with ``class_name == parent_class`` act as containers; every
    other detection is kept iff its CENTER anchor falls within at least one
    parent's xyxy box. Returns ``sv.Detections.empty()`` when there are no
    class names at all or no parent detections.
    """
    class_names = detections.data.get("class_name")
    if class_names is None or len(class_names) == 0:
        # Without class names we cannot identify parents — nothing to keep.
        return sv.Detections.empty()
    if not isinstance(class_names, np.ndarray):
        class_names = np.array(class_names)
    parent_mask = class_names == parent_class
    parent_detections = detections[parent_mask]
    if len(parent_detections) == 0:
        return sv.Detections.empty()
    dependent_detections = detections[~parent_mask]
    dependent_detections_anchors = dependent_detections.get_anchors_coordinates(
        anchor=Position.CENTER
    )
    dependent_detections_to_keep = set()
    for detection_idx, anchor in enumerate(dependent_detections_anchors):
        for parent_detection_box in parent_detections.xyxy:
            if _is_point_within_box(point=anchor, box=parent_detection_box):
                dependent_detections_to_keep.add(detection_idx)
                # One containing parent is enough — stop scanning parents.
                # (Previously this was `continue`, which only advanced the
                # inner loop and kept re-checking remaining parent boxes;
                # output was unchanged thanks to set semantics, but the
                # intent is `break`.)
                break
    detections_to_keep_list = sorted(dependent_detections_to_keep)
    return dependent_detections[detections_to_keep_list]
407+
408+
409+
def _is_point_within_box(point: np.ndarray, box: np.ndarray) -> bool:
410+
px, py = point
411+
x1, y1, x2, y2 = box
412+
return x1 <= px <= x2 and y1 <= py <= y2

inference_cli/lib/container_adapter.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def start_inference_container(
138138
docker_run_kwargs = {}
139139
is_gpu = "gpu" in image and "jetson" not in image
140140
is_jetson = "jetson" in image
141-
141+
142142
if is_gpu:
143143
device_requests = [
144144
docker.types.DeviceRequest(device_ids=["all"], capabilities=[["gpu"]])
@@ -170,7 +170,8 @@ def start_inference_container(
170170
labels=labels,
171171
ports=ports,
172172
device_requests=device_requests,
173-
environment=environment + [
173+
environment=environment
174+
+ [
174175
"MODEL_CACHE_DIR=/tmp/model-cache",
175176
"TRANSFORMERS_CACHE=/tmp/huggingface",
176177
"YOLO_CONFIG_DIR=/tmp/yolo",
@@ -182,14 +183,13 @@ def start_inference_container(
182183
cpu_shares=1024,
183184
security_opt=["no-new-privileges"] if not is_jetson else None,
184185
cap_drop=["ALL"] if not is_jetson else None,
185-
cap_add=(["NET_BIND_SERVICE"] + (["SYS_ADMIN"] if is_gpu else [])) if not is_jetson else None,
186+
cap_add=(
187+
(["NET_BIND_SERVICE"] + (["SYS_ADMIN"] if is_gpu else []))
188+
if not is_jetson
189+
else None
190+
),
186191
read_only=not is_jetson,
187-
volumes={
188-
"/tmp": {
189-
"bind": "/tmp",
190-
"mode": "rw"
191-
}
192-
},
192+
volumes={"/tmp": {"bind": "/tmp", "mode": "rw"}},
193193
network_mode="bridge",
194194
ipc_mode="private" if not is_jetson else None,
195195
**docker_run_kwargs,

tests/inference/unit_tests/core/models/utils/test_keypoints.py

Lines changed: 8 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -123,34 +123,12 @@ def test_model_keypoints_to_response() -> None:
123123
# List of keypoints
124124
assert result == (
125125
[
126+
Keypoint(x=100, y=100, confidence=0.5, class_id=0, **{"class": "nose"}),
127+
Keypoint(x=200, y=200, confidence=0.5, class_id=1, **{"class": "left_eye"}),
126128
Keypoint(
127-
x=100,
128-
y=100,
129-
confidence=0.5,
130-
class_id=0,
131-
**{"class": "nose"}
132-
),
133-
Keypoint(
134-
x=200,
135-
y=200,
136-
confidence=0.5,
137-
class_id=1,
138-
**{"class": "left_eye"}
139-
),
140-
Keypoint(
141-
x=300,
142-
y=300,
143-
confidence=0.5,
144-
class_id=2,
145-
**{"class": "right_eye"}
146-
),
147-
Keypoint(
148-
x=400,
149-
y=400,
150-
confidence=0.5,
151-
class_id=3,
152-
**{"class": "left_ear"}
129+
x=300, y=300, confidence=0.5, class_id=2, **{"class": "right_eye"}
153130
),
131+
Keypoint(x=400, y=400, confidence=0.5, class_id=3, **{"class": "left_ear"}),
154132
],
155133
)
156134

@@ -200,33 +178,11 @@ def test_model_keypoints_to_response_padded_points() -> None:
200178
# List of keypoints
201179
assert result == (
202180
[
181+
Keypoint(x=100, y=100, confidence=0.5, class_id=0, **{"class": "nose"}),
182+
Keypoint(x=200, y=200, confidence=0.5, class_id=1, **{"class": "left_eye"}),
203183
Keypoint(
204-
x=100,
205-
y=100,
206-
confidence=0.5,
207-
class_id=0,
208-
**{"class": "nose"}
209-
),
210-
Keypoint(
211-
x=200,
212-
y=200,
213-
confidence=0.5,
214-
class_id=1,
215-
**{"class": "left_eye"}
216-
),
217-
Keypoint(
218-
x=300,
219-
y=300,
220-
confidence=0.5,
221-
class_id=2,
222-
**{"class": "right_eye"}
223-
),
224-
Keypoint(
225-
x=400,
226-
y=400,
227-
confidence=0.5,
228-
class_id=3,
229-
**{"class": "left_ear"}
184+
x=300, y=300, confidence=0.5, class_id=2, **{"class": "right_eye"}
230185
),
186+
Keypoint(x=400, y=400, confidence=0.5, class_id=3, **{"class": "left_ear"}),
231187
],
232188
)

tests/workflows/integration_tests/execution/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import numpy as np
99
import pytest
1010

11-
1211
ASSETS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "assets"))
1312
ROCK_PAPER_SCISSORS_ASSETS = os.path.join(ASSETS_DIR, "rock_paper_scissors")
1413

tests/workflows/integration_tests/execution/test_workflow_with_dynamic_zone_and_perspective_converter.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS
44
from inference.core.managers.base import ModelManager
55
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
6+
from inference.core.workflows.core_steps.transformations.dynamic_zones.v1 import (
7+
OUTPUT_KEY as DYNAMIC_ZONES_OUTPUT_KEY,
8+
)
69
from inference.core.workflows.core_steps.transformations.perspective_correction.v1 import (
710
OUTPUT_DETECTIONS_KEY as PERSPECTIVE_CORRECTION_OUTPUT_DETECTIONS_KEY,
8-
OUTPUT_IMAGE_KEY as PERSPECTIVE_CORRECTION_OUTPUT_IMAGE_KEY,
911
)
10-
from inference.core.workflows.core_steps.transformations.dynamic_zones.v1 import (
11-
OUTPUT_KEY as DYNAMIC_ZONES_OUTPUT_KEY,
12+
from inference.core.workflows.core_steps.transformations.perspective_correction.v1 import (
13+
OUTPUT_IMAGE_KEY as PERSPECTIVE_CORRECTION_OUTPUT_IMAGE_KEY,
1214
)
1315
from inference.core.workflows.execution_engine.core import ExecutionEngine
1416
from tests.workflows.integration_tests.execution.workflows_gallery_collector.decorators import (
Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1-
from inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1 import MQTTWriterSinkBlockV1
2-
import pytest
31
import threading
42

3+
import pytest
4+
5+
from inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1 import (
6+
MQTTWriterSinkBlockV1,
7+
)
8+
9+
510
@pytest.mark.timeout(5)
611
def test_successful_connection_and_publishing(fake_mqtt_broker):
712
# given
813
block = MQTTWriterSinkBlockV1()
9-
published_message = 'Test message'
10-
expected_message = 'Message published successfully'
14+
published_message = "Test message"
15+
expected_message = "Message published successfully"
1116

1217
fake_mqtt_broker.messages_count_to_wait_for = 1
1318
broker_thread = threading.Thread(target=fake_mqtt_broker.start)
@@ -18,7 +23,7 @@ def test_successful_connection_and_publishing(fake_mqtt_broker):
1823
host=fake_mqtt_broker.host,
1924
port=fake_mqtt_broker.port,
2025
topic="RoboflowTopic",
21-
message=published_message
26+
message=published_message,
2227
)
2328

2429
broker_thread.join(timeout=2)
@@ -27,4 +32,4 @@ def test_successful_connection_and_publishing(fake_mqtt_broker):
2732
assert result["error_status"] is False, "No error expected"
2833
assert result["message"] == expected_message
2934

30-
assert published_message.encode() in fake_mqtt_broker.messages[-1]
35+
assert published_message.encode() in fake_mqtt_broker.messages[-1]

tests/workflows/unit_tests/core_steps/analytics/test_velocity.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import supervision as sv
66

77
from inference.core.workflows.core_steps.analytics.velocity.v1 import VelocityBlockV1
8-
from inference.core.workflows.execution_engine.entities.base import VideoMetadata, WorkflowImageData
8+
from inference.core.workflows.execution_engine.entities.base import (
9+
VideoMetadata,
10+
WorkflowImageData,
11+
)
912

1013

1114
def test_velocity_block_basic_calculation() -> None:

0 commit comments

Comments
 (0)