diff --git a/app/packages/annotation/src/agents/OperatorAnnotationAgent.test.ts b/app/packages/annotation/src/agents/OperatorAnnotationAgent.test.ts index d6e2706abbc..705d2521fb6 100644 --- a/app/packages/annotation/src/agents/OperatorAnnotationAgent.test.ts +++ b/app/packages/annotation/src/agents/OperatorAnnotationAgent.test.ts @@ -92,7 +92,7 @@ describe("OperatorAnnotationAgent", () => { it("should send the correct request body", async () => { mockFetch.mockResolvedValue({ result: { result: {} } }); - const ctx = makeContext({ textPrompt: "cat" }); + const ctx = makeContext(); await agent.infer(ctx); diff --git a/app/packages/annotation/src/agents/hooks/useAnnotationAgent.ts b/app/packages/annotation/src/agents/hooks/useAnnotationAgent.ts index b9fd980b151..e9293546bf9 100644 --- a/app/packages/annotation/src/agents/hooks/useAnnotationAgent.ts +++ b/app/packages/annotation/src/agents/hooks/useAnnotationAgent.ts @@ -126,7 +126,6 @@ const useAgentContext = (): AnnotationContext | null => { selectedLabel && "bounding_box" in selectedLabel.data ? { taskType: AgentTaskType.SEGMENT, - textPrompt: selectedLabel.data.label, regionsOfInterest: [ bboxToRoi( (selectedLabel as DetectionAnnotationLabel).data.bounding_box diff --git a/app/packages/annotation/src/agents/hooks/useToolsContext.ts b/app/packages/annotation/src/agents/hooks/useToolsContext.ts index 983c2e0ce04..5cef11629e9 100644 --- a/app/packages/annotation/src/agents/hooks/useToolsContext.ts +++ b/app/packages/annotation/src/agents/hooks/useToolsContext.ts @@ -13,8 +13,6 @@ export type ToolsContext = { negativePoints?: Vec2[]; /** The current selection of region-of-interest prompts. */ regionsOfInterest?: ROI[]; - /** The current text prompt. */ - textPrompt?: string; }; /** @@ -40,8 +38,6 @@ export interface ToolsState extends ToolsContext { removeNegativePoint(index: number): void; /** Replaces the full set of ROI prompts. */ setRegionsOfInterest(rois: ROI[]): void; - /** Sets the free-text prompt. 
*/ - setTextPrompt(prompt: string): void; /** Clears all tool inputs back to their initial state. */ reset(): void; } @@ -49,7 +45,6 @@ export interface ToolsState extends ToolsContext { const positivePointsAtom = atom([]); const negativePointsAtom = atom([]); const regionsOfInterestAtom = atom([]); -const textPromptAtom = atom(null); /** * Hook which returns the current {@link ToolsContext} (read-only). @@ -61,7 +56,6 @@ export const useToolsContext = (): ToolsContext => { const positivePoints = useAtomValue(positivePointsAtom); const negativePoints = useAtomValue(negativePointsAtom); const regionsOfInterest = useAtomValue(regionsOfInterestAtom); - const textPrompt = useAtomValue(textPromptAtom); return useMemo( () => ({ @@ -69,9 +63,8 @@ export const useToolsContext = (): ToolsContext => { positivePoints, negativePoints, regionsOfInterest, - textPrompt, }), - [activeTask, positivePoints, negativePoints, regionsOfInterest, textPrompt] + [activeTask, positivePoints, negativePoints, regionsOfInterest] ); }; @@ -87,7 +80,6 @@ export const useToolsState = (): ToolsState => { const [regionsOfInterest, setRegionsOfInterest] = useAtom( regionsOfInterestAtom ); - const [textPrompt, setTextPrompt] = useAtom(textPromptAtom); const addPositivePoint = useCallback( (point: Vec2) => setPositivePoints((prev) => [...prev, point]), @@ -115,13 +107,7 @@ export const useToolsState = (): ToolsState => { setPositivePoints([]); setNegativePoints([]); setRegionsOfInterest([]); - setTextPrompt(null); - }, [ - setPositivePoints, - setNegativePoints, - setRegionsOfInterest, - setTextPrompt, - ]); + }, [setPositivePoints, setNegativePoints, setRegionsOfInterest]); return useMemo( () => ({ @@ -129,13 +115,11 @@ export const useToolsState = (): ToolsState => { positivePoints, negativePoints, regionsOfInterest, - textPrompt, addPositivePoint, removePositivePoint, addNegativePoint, removeNegativePoint, setRegionsOfInterest, - setTextPrompt, reset, }), [ @@ -143,13 +127,11 @@ export const 
useToolsState = (): ToolsState => { positivePoints, negativePoints, regionsOfInterest, - textPrompt, addPositivePoint, removePositivePoint, addNegativePoint, removeNegativePoint, setRegionsOfInterest, - setTextPrompt, reset, ] ); diff --git a/app/packages/annotation/src/agents/types.ts b/app/packages/annotation/src/agents/types.ts index 3e980e17149..7ea05655d58 100644 --- a/app/packages/annotation/src/agents/types.ts +++ b/app/packages/annotation/src/agents/types.ts @@ -27,13 +27,11 @@ export enum AgentTaskType { * - `"positivePoint"` - a set of positive point prompts indicating regions to include * - `"negativePoint"` - a set of negative point prompts indicating regions to exclude * - `"roi"` - a set of polygonal regions of interest to bound the labels - * - `"textPrompt"` - a free-text description of the target object */ export enum InferenceCapability { POSITIVE_POINT = "positivePoint", NEGATIVE_POINT = "negativePoint", ROI = "roi", - TEXT_PROMPT = "textPrompt", } /** @@ -75,8 +73,6 @@ export type AnnotationContext = { negativePoints?: Vec2[]; /** Polygonal regions of interest drawn by the user. */ regionsOfInterest?: ROI[]; - /** Free-text description of the target object. */ - textPrompt?: string; }; /** diff --git a/app/packages/components/src/components/FloatingToolbar/FloatingToolbar.tsx b/app/packages/components/src/components/FloatingToolbar/FloatingToolbar.tsx new file mode 100644 index 00000000000..eb07d32aee9 --- /dev/null +++ b/app/packages/components/src/components/FloatingToolbar/FloatingToolbar.tsx @@ -0,0 +1,451 @@ +/** + * Copyright 2017-2026, Voxel51, Inc. + * + * A generic floating, draggable toolbar. + * + * Completely abstract — knows nothing about segmentation, 3D, or any specific + * domain. Compose it with `FloatingToolbar.Group` and `FloatingToolbar.Action` + * to build any tool palette. + * + * Designed to be portable to Voodo. 
+ * + * @example + * ```tsx + * + * + * + * + * + * + * + * ``` + */ + +import React, { useCallback, useEffect, useRef, useState } from "react"; +import styled, { css } from "styled-components"; + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +type Orientation = "horizontal" | "vertical"; + +/** + * Controls which axes the user can drag along. + * + * - `"both"` (default) — free movement in x and y + * - `"x"` — horizontal only + * - `"y"` — vertical only + * - `"none"` — fixed, no dragging + */ +type DragAxis = "both" | "x" | "y" | "none"; + +interface Position { + /** Horizontal offset as a percentage of the parent width (0–100). */ + x: number; + /** Vertical offset as a percentage of the parent height (0–100). */ + y: number; +} + +export interface FloatingToolbarProps { + children: React.ReactNode; + /** Layout direction for groups and actions. Default `"vertical"`. */ + orientation?: Orientation; + /** Which axes the toolbar can be dragged along. Default `"both"`. */ + dragAxis?: DragAxis; + /** Initial position as `{ x, y }` percentages. Default `{ x: 5, y: 50 }`. */ + defaultPosition?: Partial; + /** Clamp range for each axis as `[min, max]` percentage. Default `[5, 95]`. */ + clamp?: [number, number]; + /** CSS `z-index`. Default 10005. */ + zIndex?: number; + /** Additional class name. */ + className?: string; + /** Additional inline styles on the outer container. */ + style?: React.CSSProperties; + /** Whether the toolbar is visible. Default true. */ + visible?: boolean; +} + +export interface FloatingToolbarGroupProps { + children: React.ReactNode; + /** Optional label rendered above/before the group. */ + label?: string; +} + +export interface FloatingToolbarActionProps { + children: React.ReactNode; + /** Whether the action is currently active/selected. */ + active?: boolean; + /** Whether the action is disabled. 
*/ + disabled?: boolean; + /** Native title attribute for basic tooltip. Callers can wrap with any + * tooltip component they prefer. */ + title?: string; + onClick?: (e: React.MouseEvent) => void; +} + +// --------------------------------------------------------------------------- +// Styled primitives +// --------------------------------------------------------------------------- + +const orientationStyles = { + vertical: css` + flex-direction: column; + `, + horizontal: css` + flex-direction: row; + `, +}; + +const Container = styled.div<{ + $isDragging: boolean; +}>` + position: absolute; + display: flex; + background: ${({ theme }) => theme.background.level2}; + border-radius: 6px; + border: 1px solid ${({ theme }) => theme.primary.plainBorder}; + box-shadow: ${({ $isDragging }) => + $isDragging + ? "0 4px 16px rgba(0, 0, 0, 0.25)" + : "0 2px 8px rgba(0, 0, 0, 0.12)"}; + min-width: 36px; + opacity: ${({ $isDragging }) => ($isDragging ? 0.95 : 0.75)}; + user-select: none; + transition: ${({ $isDragging }) => + $isDragging ? "none" : "opacity 0.2s ease, box-shadow 0.2s ease"}; + + &:hover { + opacity: 0.95; + } +`; + +const DragHandle = styled.div<{ + $orientation: Orientation; + $isDragging: boolean; + $hidden: boolean; +}>` + display: ${({ $hidden }) => ($hidden ? "none" : "flex")}; + align-items: center; + justify-content: center; + cursor: ${({ $isDragging }) => ($isDragging ? "grabbing" : "grab")}; + opacity: 0; + transition: opacity 0.2s ease; + color: ${({ theme }) => theme.text.secondary}; + + ${Container}:hover & { + opacity: 0.8; + } + + ${({ $orientation }) => + $orientation === "vertical" + ? 
css` + width: 100%; + height: 12px; + border-radius: 6px 6px 0 0; + margin: 2px 2px 0 2px; + svg { + transform: rotate(90deg); + } + ` + : css` + height: 100%; + width: 12px; + border-radius: 6px 0 0 6px; + margin: 2px 0 2px 2px; + `} + + svg { + font-size: 14px; + } +`; + +const Content = styled.div<{ $orientation: Orientation }>` + display: flex; + ${({ $orientation }) => orientationStyles[$orientation]} + gap: 8px; + padding: 8px; +`; + +const GroupContainer = styled.div<{ $orientation: Orientation }>` + display: flex; + ${({ $orientation }) => orientationStyles[$orientation]} + gap: 6px; + align-items: center; +`; + +const GroupLabel = styled.span` + font-size: 9px; + font-weight: 600; + color: ${({ theme }) => theme.text.secondary}; + text-align: center; + margin-bottom: 2px; + text-transform: uppercase; + letter-spacing: 0.3px; +`; + +const ActionContainer = styled.div<{ + $active: boolean; + $disabled: boolean; +}>` + width: 28px; + height: 28px; + border-radius: 4px; + display: flex; + align-items: center; + justify-content: center; + cursor: ${({ $disabled }) => ($disabled ? "not-allowed" : "pointer")}; + transition: all 0.15s ease; + + color: ${({ $active, $disabled, theme }) => + $active + ? theme.primary.plainColor + : $disabled + ? theme.text.tertiary + : theme.text.secondary}; + + background: ${({ $active, theme }) => + $active ? 
theme.background.level1 : "transparent"}; + + ${({ $active, theme }) => + $active && + css` + svg { + filter: drop-shadow(0 0 4px ${theme.primary.plainColor}); + } + `} + + &:hover { + ${({ $disabled, theme }) => + !$disabled && + css` + background: ${theme.background.level1}; + color: ${theme.text.primary}; + transform: scale(1.1); + `} + } + + &:active { + ${({ $disabled }) => + !$disabled && + css` + transform: scale(0.95); + `} + } + + svg { + font-size: 18px; + } +`; + +// --------------------------------------------------------------------------- +// Drag handle icon (inline SVG — zero external deps) +// --------------------------------------------------------------------------- + +const DragIndicatorSvg = () => ( + + + + + + + + +); + +// --------------------------------------------------------------------------- +// Context (passes orientation to sub-components without prop drilling) +// --------------------------------------------------------------------------- + +const OrientationContext = React.createContext("vertical"); + +// --------------------------------------------------------------------------- +// Sub-components +// --------------------------------------------------------------------------- + +const Group: React.FC = ({ children, label }) => { + const orientation = React.useContext(OrientationContext); + return ( + + {label && {label}} + {children} + + ); +}; +Group.displayName = "FloatingToolbar.Group"; + +const Action: React.FC = ({ + children, + active = false, + disabled = false, + title, + onClick, +}) => { + const handleClick = useCallback( + (e: React.MouseEvent) => { + if (disabled) return; + onClick?.(e); + }, + [disabled, onClick] + ); + + return ( + + {children} + + ); +}; +Action.displayName = "FloatingToolbar.Action"; + +// --------------------------------------------------------------------------- +// Defaults +// --------------------------------------------------------------------------- + +const DEFAULT_POSITION: Position = { x: 
5, y: 50 }; +const DEFAULT_CLAMP: [number, number] = [5, 95]; + +// --------------------------------------------------------------------------- +// Main component +// --------------------------------------------------------------------------- + +const FloatingToolbarInner: React.FC = ({ + children, + orientation = "vertical", + dragAxis = "both", + defaultPosition, + clamp = DEFAULT_CLAMP, + zIndex = 10005, + className, + style, + visible = true, +}) => { + const initialPos: Position = { + ...DEFAULT_POSITION, + ...defaultPosition, + }; + + const [position, setPosition] = useState(initialPos); + const [isDragging, setIsDragging] = useState(false); + const dragStartRef = useRef({ clientX: 0, clientY: 0, pos: initialPos }); + const containerRef = useRef(null); + + const canDrag = dragAxis !== "none"; + const canDragX = dragAxis === "both" || dragAxis === "x"; + const canDragY = dragAxis === "both" || dragAxis === "y"; + + // ---- drag handlers ---- + + const handleDragStart = useCallback( + (e: React.MouseEvent) => { + if (!canDrag) return; + e.preventDefault(); + e.stopPropagation(); + setIsDragging(true); + dragStartRef.current = { + clientX: e.clientX, + clientY: e.clientY, + pos: position, + }; + }, + [canDrag, position] + ); + + const handleMouseMove = useCallback( + (e: MouseEvent) => { + const parent = containerRef.current?.parentElement; + const parentW = parent?.clientWidth ?? window.innerWidth; + const parentH = parent?.clientHeight ?? window.innerHeight; + const { clientX, clientY, pos } = dragStartRef.current; + + const nextX = canDragX + ? Math.max( + clamp[0], + Math.min(clamp[1], pos.x + ((e.clientX - clientX) / parentW) * 100) + ) + : pos.x; + + const nextY = canDragY + ? 
Math.max( + clamp[0], + Math.min(clamp[1], pos.y + ((e.clientY - clientY) / parentH) * 100) + ) + : pos.y; + + setPosition({ x: nextX, y: nextY }); + }, + [canDragX, canDragY, clamp] + ); + + const handleMouseUp = useCallback(() => { + setIsDragging(false); + }, []); + + useEffect(() => { + if (!isDragging) return; + document.addEventListener("mousemove", handleMouseMove); + document.addEventListener("mouseup", handleMouseUp); + return () => { + document.removeEventListener("mousemove", handleMouseMove); + document.removeEventListener("mouseup", handleMouseUp); + }; + }, [isDragging, handleMouseMove, handleMouseUp]); + + // ---- render ---- + + if (!visible) return null; + + return ( + + e.stopPropagation()} + onClick={(e) => e.stopPropagation()} + > + + + + {children} + + + ); +}; + +function orientationToFlex(o: Orientation): React.CSSProperties { + return { flexDirection: o === "vertical" ? "column" : "row" }; +} + +// --------------------------------------------------------------------------- +// Compound component export +// --------------------------------------------------------------------------- + +export const FloatingToolbar = Object.assign(FloatingToolbarInner, { + Group, + Action, +}); diff --git a/app/packages/components/src/components/FloatingToolbar/index.ts b/app/packages/components/src/components/FloatingToolbar/index.ts new file mode 100644 index 00000000000..3b679babfb8 --- /dev/null +++ b/app/packages/components/src/components/FloatingToolbar/index.ts @@ -0,0 +1,6 @@ +export { + FloatingToolbar, + type FloatingToolbarActionProps, + type FloatingToolbarGroupProps, + type FloatingToolbarProps, +} from "./FloatingToolbar"; diff --git a/app/packages/components/src/components/index.ts b/app/packages/components/src/components/index.ts index e98268d0fed..f9cdd27f35d 100644 --- a/app/packages/components/src/components/index.ts +++ b/app/packages/components/src/components/index.ts @@ -13,6 +13,12 @@ export { default as EditableLabel } from 
"./EditableLabel"; export { default as ErrorBoundary, ErrorDisplayMarkup } from "./ErrorBoundary"; export { default as ExternalLink, useExternalLink } from "./ExternalLink"; export { default as FilterAndSelectionIndicator } from "./FilterAndSelectionIndicator"; +export { FloatingToolbar } from "./FloatingToolbar"; +export type { + FloatingToolbarActionProps, + FloatingToolbarGroupProps, + FloatingToolbarProps, +} from "./FloatingToolbar"; export { default as Header } from "./Header"; export { default as HelpPanel } from "./HelpPanel"; export { default as HelpTooltip } from "./HelpTooltip"; diff --git a/app/packages/core/src/components/Modal/Modal.tsx b/app/packages/core/src/components/Modal/Modal.tsx index 72ffbe77a95..13b22324b47 100644 --- a/app/packages/core/src/components/Modal/Modal.tsx +++ b/app/packages/core/src/components/Modal/Modal.tsx @@ -46,6 +46,7 @@ import { Sidebar } from "./Sidebar"; import SchemaManagementProvider from "./Sidebar/Annotate/SchemaManagementProvider"; import useCanManageSchema from "./Sidebar/Annotate/useCanManageSchema"; import { useAnnotationTracking } from "./Sidebar/Annotate/useAnnotationTracking"; +import { SegmentationToolbar } from "./Sidebar/Annotate/Edit/SegmentationToolbar"; import { TooltipInfo } from "./TooltipInfo"; import { useLookerHelpers, useTooltipEventHandler } from "./hooks"; import { modalContext } from "./modal-context"; @@ -351,6 +352,7 @@ const Modal = () => { data-cy="modal" > + {isAnnotationEnabled && } diff --git a/app/packages/core/src/components/Modal/Sidebar/Annotate/Actions.tsx b/app/packages/core/src/components/Modal/Sidebar/Annotate/Actions.tsx index 8471689ca84..ad2792a5606 100644 --- a/app/packages/core/src/components/Modal/Sidebar/Annotate/Actions.tsx +++ b/app/packages/core/src/components/Modal/Sidebar/Annotate/Actions.tsx @@ -41,6 +41,7 @@ import { editing } from "./Edit"; import { fieldsOfType } from "./Edit/state"; import useCreate from "./Edit/useCreate"; import { useQuickDraw } from 
"./Edit/useQuickDraw"; +import { useSegmentationMasks } from "./Edit/useSegmentationMasks"; import useCanManageSchema from "./useCanManageSchema"; import useShowModal from "./useShowModal"; @@ -225,6 +226,55 @@ const Detection = () => { ); }; +const Segmentation = () => { + const { active, enter, exit } = useSegmentationMasks(); + const isPatchView = useRecoilValue(isPatchesView); + const fields = useAtomValue(fieldsOfType(DETECTION)); + const disabled = isPatchView || fields.length === 0; + + const tooltip = isPatchView + ? "Segmentation is not supported in this view" + : active + ? "Exit segmentation mode" + : "Segmentation tools"; + + return ( + + { + if (disabled) return; + if (active) { + exit(); + } else { + enter(); + } + }} + > + + Segmentation + {/* Layers / stack icon */} + + + + + ); +}; + export const Undo = () => { const { undo, undoEnabled } = useUndoRedo(); @@ -419,7 +469,10 @@ const Actions = () => { ) : ( - + <> + + + )} diff --git a/app/packages/core/src/components/Modal/Sidebar/Annotate/Edit/SegmentationToolbar.tsx b/app/packages/core/src/components/Modal/Sidebar/Annotate/Edit/SegmentationToolbar.tsx new file mode 100644 index 00000000000..d9354d82e03 --- /dev/null +++ b/app/packages/core/src/components/Modal/Sidebar/Annotate/Edit/SegmentationToolbar.tsx @@ -0,0 +1,89 @@ +/** + * Copyright 2017-2026, Voxel51, Inc. + * + * Segmentation toolbar built on the generic FloatingToolbar. + * Appears when the segmentation mode is active (layers icon in Actions bar). + * Contains tool buttons for AI Segment, pen, brush, etc. 
+ */ + +import { FloatingToolbar, Tooltip } from "@fiftyone/components"; +import React from "react"; +import { useSegmentationMasks } from "./useSegmentationMasks"; +import { useAISegment } from "./useAISegment"; + +// --------------------------------------------------------------------------- +// Icons +// --------------------------------------------------------------------------- + +const AISegmentIcon = () => ( + + AI Segment + {/* Four-point star / sparkle icon */} + + +); + +const SelectIcon = () => ( + + Select + + +); + +const CloseIcon = () => ( + + Exit AI Segment + + +); + +// --------------------------------------------------------------------------- +// Component +// --------------------------------------------------------------------------- + +export const SegmentationToolbar: React.FC = () => { + const { active: segmentationActive } = useSegmentationMasks(); + const { + active: aiSegmentActive, + enter: enterAI, + exit: exitAI, + } = useAISegment(); + + return ( + + + {aiSegmentActive ? ( + + + + + + ) : ( + + + + + + )} + + { + if (aiSegmentActive) { + exitAI(); + } else { + enterAI(); + } + }} + > + + + + + + ); +}; diff --git a/app/packages/core/src/components/Modal/Sidebar/Annotate/Edit/useAISegment.ts b/app/packages/core/src/components/Modal/Sidebar/Annotate/Edit/useAISegment.ts new file mode 100644 index 00000000000..fd3223b0093 --- /dev/null +++ b/app/packages/core/src/components/Modal/Sidebar/Annotate/Edit/useAISegment.ts @@ -0,0 +1,421 @@ +/** + * Copyright 2017-2026, Voxel51, Inc. + * + * Hook for AI-assisted segmentation mode. + * + * Flow: + * 1. User enters AI segment mode → point overlay created, interactive mode activated + * 2. User clicks canvas → positive point placed → inference triggered + * 3. First result → pending Detection overlay created with mask + bbox + * 4. More points → inference re-triggered → same Detection overlay updated + * 5. 
User configures field/class in sidebar → confirms label + */ + +import { useCallback, useEffect, useMemo, useRef } from "react"; +import { atom, getDefaultStore, useAtomValue, useSetAtom } from "jotai"; +import { useAtomCallback } from "jotai/utils"; +import { getEventBus } from "@fiftyone/events"; + +import { + type AISegmentPointOverlay, + type BoundingBoxOverlay, + type BoundingBoxOptions, + type KeypointOptions, + type LighterEventGroup, + InteractiveKeypointHandler, + UNDEFINED_LIGHTER_SCENE_ID, + useLighter, + useLighterEventHandler, +} from "@fiftyone/lighter"; +import type { DetectionLabel } from "@fiftyone/looker"; +import type { AnnotationLabel } from "@fiftyone/state"; +import { useToolsState } from "@fiftyone/annotation/src/agents/hooks/useToolsContext"; +import { useActiveTask } from "@fiftyone/annotation/src/agents/hooks/useActiveTask"; +import { useAgentSelector } from "@fiftyone/annotation/src/agents/hooks/useAgentSelector"; +import { useAgentRegistry } from "@fiftyone/annotation/src/agents/hooks/useAgentRegistry"; +import { useAnnotationAgent } from "@fiftyone/annotation/src/agents/hooks/useAnnotationAgent"; +import { OperatorAnnotationAgent } from "@fiftyone/annotation/src/agents/OperatorAnnotationAgent"; +import { + AgentTaskType, + type InferenceResult, + type SegmentationInferenceResult, +} from "@fiftyone/annotation/src/agents/types"; +import { DETECTION, objectId } from "@fiftyone/utilities"; +import { v4 as generateUUID } from "uuid"; +import { defaultField, editing, savedLabel } from "./state"; +import { useAnnotationContext } from "./state"; + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +const AI_SEGMENT_OPERATOR_URI = "@voxel51/annotation/segment"; +const AI_SEGMENT_AGENT_ID = "ai-segment-operator"; + +// --------------------------------------------------------------------------- +// Atoms +// 
--------------------------------------------------------------------------- + +const aiSegmentActiveAtom = atom(false); +const aiSegmentOverlayIdAtom = atom(null); +const pendingDetectionIdAtom = atom(null); +/** Whether the agent has been registered in this session */ +const agentRegisteredAtom = atom(false); + +// --------------------------------------------------------------------------- +// Hook +// --------------------------------------------------------------------------- + +export const useAISegment = () => { + const { scene, addOverlay, removeOverlay, overlayFactory, getOverlay } = + useLighter(); + const toolsState = useToolsState(); + const { setActiveTask } = useActiveTask(); + const { activeAgent, setActiveAgent } = useAgentSelector(); + const registry = useAgentRegistry(); + const resolvedAgent = useAnnotationAgent(activeAgent?.agent); + const { selectedLabel } = useAnnotationContext(); + + const useEventHandler = useLighterEventHandler( + scene?.getEventChannel() ?? UNDEFINED_LIGHTER_SCENE_ID + ); + + const sceneRef = useRef(scene); + sceneRef.current = scene; + const resolvedAgentRef = useRef(resolvedAgent); + resolvedAgentRef.current = resolvedAgent; + + const active = useAtomValue(aiSegmentActiveAtom); + const overlayId = useAtomValue(aiSegmentOverlayIdAtom); + const setActive = useSetAtom(aiSegmentActiveAtom); + const setOverlayId = useSetAtom(aiSegmentOverlayIdAtom); + const setEditing = useSetAtom(editing); + const setPendingDetectionId = useSetAtom(pendingDetectionIdAtom); + + const getAgentRegistered = useAtomCallback( + useCallback((get) => get(agentRegisteredAtom), []) + ); + const setAgentRegistered = useSetAtom(agentRegisteredAtom); + + // ---- lazy agent registration (on first point, not on enter) ---- + + const ensureAgent = useCallback(async () => { + if (getAgentRegistered()) return; + + const agent = new OperatorAnnotationAgent( + AI_SEGMENT_OPERATOR_URI + ); + await registry.register(AI_SEGMENT_AGENT_ID, "AI Segment", agent); + 
setActiveAgent({ id: AI_SEGMENT_AGENT_ID, label: "AI Segment", agent }); + setAgentRegistered(true); + }, [registry, setActiveAgent, getAgentRegistered, setAgentRegistered]); + + // ---- enter / exit ---- + + const enter = useCallback(() => { + const currentScene = sceneRef.current; + if (!currentScene || !overlayFactory) return; + + setActiveTask(AgentTaskType.SEGMENT); + setActive(true); + + const id = `ai-segment-points-${generateUUID()}`; + const overlay = overlayFactory.create< + Omit, + AISegmentPointOverlay + >("ai-segment-point", { + id, + field: "", + label: { id, label: "", tags: [], points: [] } as any, + }); + + addOverlay(overlay, false); + setOverlayId(id); + }, [overlayFactory, addOverlay, setActive, setActiveTask, setOverlayId]); + + const exit = useCallback(() => { + const currentScene = sceneRef.current; + if (currentScene && !currentScene.isDestroyed) { + currentScene.exitInteractiveMode(); + } + + const currentOverlayId = overlayId; + if (currentOverlayId) { + removeOverlay(currentOverlayId, false); + } + + toolsState.reset(); + setActiveTask(null); + setActive(false); + setOverlayId(null); + setPendingDetectionId(null); + }, [ + overlayId, + removeOverlay, + toolsState, + setActiveTask, + setActive, + setOverlayId, + setPendingDetectionId, + ]); + + // ---- enter interactive mode as soon as the overlay is ready ---- + + const interactiveModeRef = useRef(false); + + useEffect(() => { + const currentScene = sceneRef.current; + if (!currentScene || currentScene.isDestroyed || !active || !overlayId) + return; + + if (!interactiveModeRef.current) { + const overlay = getOverlay?.(overlayId); + if (overlay) { + const eventBus = getEventBus( + currentScene.getEventChannel() + ); + const handler = new InteractiveKeypointHandler( + overlay as AISegmentPointOverlay, + eventBus + ); + currentScene.enterInteractiveMode(handler); + interactiveModeRef.current = true; + } + } + }, [active, overlayId, getOverlay]); + + useEffect(() => { + if (!active) { + 
interactiveModeRef.current = false; + } + }, [active]); + + // ---- apply inference result ---- + + const getPendingDetectionId = useAtomCallback( + useCallback((get) => get(pendingDetectionIdAtom), []) + ); + + const applyResult = useCallback( + (result: InferenceResult) => { + if (result.type !== "sync") { + console.warn("[AI Segment] Async results not yet supported"); + return; + } + + const detection = result.response?.detections?.[0]; + if (!detection) { + console.warn("[AI Segment] No detection in inference result"); + return; + } + + // (a) Existing label selected → update its mask in-place + if (selectedLabel && "bounding_box" in selectedLabel.data) { + const overlay = getOverlay?.(selectedLabel.data._id) as + | BoundingBoxOverlay + | undefined; + if (overlay) { + overlay.updateLabel({ + ...overlay.label, + bounding_box: detection.bounding_box ?? overlay.label.bounding_box, + mask: detection.mask, + }); + overlay.markDirty(); + return; + } + } + + // (b) Pending detection from a previous inference → update it + const existingId = getPendingDetectionId(); + const existingOverlay = existingId + ? (getOverlay?.(existingId) as BoundingBoxOverlay | undefined) + : undefined; + + if (existingOverlay) { + existingOverlay.updateLabel({ + ...existingOverlay.label, + label: detection.label || existingOverlay.label.label, + bounding_box: detection.bounding_box, + mask: detection.mask, + }); + existingOverlay.markDirty(); + } else { + // (c) First inference → create new pending Detection + if (!overlayFactory || !scene) return; + + const store = getDefaultStore(); + const id = objectId(); + const field = store.get(defaultField(DETECTION)) ?? 
undefined; + if (!field) { + console.warn("[AI Segment] No detection field available"); + return; + } + + const labelValue = detection.label || "object"; + + const labelData: DetectionLabel = { + _id: id, + _cls: "Detection", + label: labelValue, + bounding_box: detection.bounding_box, + mask: detection.mask, + tags: [], + } as any; + + const overlay = overlayFactory.create< + BoundingBoxOptions, + BoundingBoxOverlay + >("bounding-box", { + id, + field, + label: labelData, + draggable: true, + resizeable: true, + }); + + const bb = detection.bounding_box; + if (bb && bb.length === 4) { + const cs = scene.getCoordinateSystem(); + if (cs) { + const t = cs.getTransform(); + overlay.bounds = { + x: t.offsetX + bb[0] * t.scaleX, + y: t.offsetY + bb[1] * t.scaleY, + width: bb[2] * t.scaleX, + height: bb[3] * t.scaleY, + }; + } + } + + addOverlay(overlay); + setPendingDetectionId(id); + + store.set(savedLabel, labelData); + setEditing( + atom({ + isNew: true, + data: labelData, + overlay, + path: field, + type: DETECTION, + }) + ); + } + }, + [ + overlayFactory, + scene, + addOverlay, + getOverlay, + getPendingDetectionId, + selectedLabel, + setPendingDetectionId, + setEditing, + toolsState, + ] + ); + + const applyResultRef = useRef(applyResult); + applyResultRef.current = applyResult; + + // ---- point events → inference ---- + + const getOverlayId = useAtomCallback( + useCallback((get) => get(aiSegmentOverlayIdAtom), []) + ); + + useEventHandler( + "lighter:keypoint-point-added", + useCallback( + (payload) => { + const currentOverlayId = getOverlayId(); + if (!currentOverlayId || payload.id !== currentOverlayId) return; + + const overlay = getOverlay?.(currentOverlayId); + if (!overlay || !("getRelativePoints" in overlay)) return; + + const keypointOverlay = overlay as AISegmentPointOverlay; + const relativePoints = keypointOverlay.getRelativePoints(); + const newPoint = relativePoints[payload.pointIndex]; + if (!newPoint) return; + + // Update tools state so 
AnnotationContext stays in sync + toolsState.addPositivePoint(newPoint); + + // Register agent lazily on first point, then infer. + // Use queueMicrotask to let Jotai flush the addPositivePoint + // state update before resolvedAgent.infer() reads the context. + ensureAgent().then(() => { + queueMicrotask(() => { + const agent = resolvedAgentRef.current; + if (agent) { + agent + .infer() + .then((result) => { + keypointOverlay.stopProcessing(); + if (result) applyResultRef.current(result); + }) + .catch((err) => { + console.error("[AI Segment] Inference failed:", err); + keypointOverlay.stopProcessing(); + }); + } else { + keypointOverlay.stopProcessing(); + } + }); + }); + }, + [getOverlay, getOverlayId, toolsState, ensureAgent] + ) + ); + + useEventHandler( + "lighter:keypoint-point-deleted", + useCallback( + (payload) => { + const currentOverlayId = getOverlayId(); + if (!currentOverlayId || payload.id !== currentOverlayId) return; + + const overlay = getOverlay?.(currentOverlayId) as + | AISegmentPointOverlay + | undefined; + if (!overlay) return; + + toolsState.removePositivePoint(payload.pointIndex); + + const points = overlay.getRelativePoints(); + if (points.length > 0) { + overlay.startProcessing(points.length - 1); + + queueMicrotask(() => { + const agent = resolvedAgentRef.current; + if (agent) { + agent + .infer() + .then((result) => { + overlay.stopProcessing(); + if (result) applyResultRef.current(result); + }) + .catch((err) => { + console.error("[AI Segment] Inference failed:", err); + overlay.stopProcessing(); + }); + } + }); + } + }, + [getOverlay, getOverlayId, toolsState] + ) + ); + + // ---- return ---- + + return useMemo( + () => ({ + active, + enter, + exit, + }), + [active, enter, exit] + ); +}; diff --git a/app/packages/core/src/components/Modal/Sidebar/Annotate/Edit/useSegmentationMasks.ts b/app/packages/core/src/components/Modal/Sidebar/Annotate/Edit/useSegmentationMasks.ts new file mode 100644 index 00000000000..75c7e08bbf0 --- /dev/null 
+++ b/app/packages/core/src/components/Modal/Sidebar/Annotate/Edit/useSegmentationMasks.ts @@ -0,0 +1,214 @@ +/** + * Copyright 2017-2026, Voxel51, Inc. + */ + +import { useCallback, useMemo, useRef } from "react"; +import { atom, useAtomValue, useSetAtom } from "jotai"; +import { useAtomCallback } from "jotai/utils"; + +import { + BaseOverlay, + UNDEFINED_LIGHTER_SCENE_ID, + useLighter, + useLighterEventHandler, +} from "@fiftyone/lighter"; +import { useAnnotationContext } from "./state"; +import { DETECTION } from "@fiftyone/utilities"; +import useCreate from "./useCreate"; + +export const DEFAULT_TOOL_SIZE = 16; +export const MIN_TOOL_SIZE = 1; +export const MAX_TOOL_SIZE = 32; + +export type SegmentationTool = "select" | "brush" | "eraser"; // | "pen"; +export type SegmentationToolShape = "circle" | "square"; + +export interface SegmentationToolState { + active: boolean; + size: number; + tool: SegmentationTool; + shape: SegmentationToolShape; +} + +// --------------------------------------------------------------------------- +// Atoms (internal) +// --------------------------------------------------------------------------- + +const segmentationActiveAtom = atom(false); +const toolAtom = atom("select"); +const toolSizeAtom = atom(DEFAULT_TOOL_SIZE); +const toolShapeAtom = atom("circle"); + +/** + * Tracks the last processed `lighter:overlay-create` event ID so that only one + * `useSegmentationMasks` instance handles each event, even though the hook is + * called in multiple components. + */ +const lastProcessedCreateIdAtom = atom(null); + +// --------------------------------------------------------------------------- +// Unsafe exports for non-React bridge access only. +// Do not use directly in React components — use useSegmentationMasks() instead. 
+// --------------------------------------------------------------------------- + +/** @internal */ export { segmentationActiveAtom as _unsafeSegmentationActiveAtom }; +/** @internal */ export { toolAtom as _unsafeToolAtom }; +/** @internal */ export { toolSizeAtom as _unsafeToolSizeAtom }; +/** @internal */ export { toolShapeAtom as _unsafeToolShapeAtom }; + +/** + * Segmentation mask tool state hook. + * + * Selection/editing state is managed by the existing annotation system + * (editing atom in state.ts, SelectionManager in Lighter). + * This hook only owns segmentation-specific tool state. + */ +export const useSegmentationMasks = () => { + const { scene, addOverlay } = useLighter(); + const { selectedLabel } = useAnnotationContext(); + const useEventHandler = useLighterEventHandler( + scene?.getEventChannel() ?? UNDEFINED_LIGHTER_SCENE_ID + ); + + // Using refs to prevent shared closure contexts from retaining old Scene2D instances. + const sceneRef = useRef(scene); + sceneRef.current = scene; + const selectedLabelRef = useRef(selectedLabel); + selectedLabelRef.current = selectedLabel; + + const segmentationActive = useAtomValue(segmentationActiveAtom); + const tool = useAtomValue(toolAtom); + const toolSize = useAtomValue(toolSizeAtom); + const toolShape = useAtomValue(toolShapeAtom); + + const setActive = useSetAtom(segmentationActiveAtom); + const setTool = useSetAtom(toolAtom); + const setToolSizeRaw = useSetAtom(toolSizeAtom); + const setToolShape = useSetAtom(toolShapeAtom); + + const createDetection = useCreate(DETECTION); + + const enter = useCallback(() => { + setActive(true); + }, [setActive]); + + const exit = useCallback(() => { + setActive(false); + setTool("select"); + }, [setActive, setTool]); + + const switchTool = useCallback( + (newTool: SegmentationTool) => { + setTool(newTool); + }, + [setTool] + ); + + const increaseToolSize = useCallback(() => { + setToolSizeRaw((prev) => Math.min(prev + 1, MAX_TOOL_SIZE)); + }, [setToolSizeRaw]); + + 
const decreaseToolSize = useCallback(() => { + setToolSizeRaw((prev) => Math.max(prev - 1, MIN_TOOL_SIZE)); + }, [setToolSizeRaw]); + + const setToolSize = useCallback( + (size: number) => { + const n = Number(size); + if (Number.isNaN(n)) return; + setToolSizeRaw(Math.max(MIN_TOOL_SIZE, Math.min(n, MAX_TOOL_SIZE))); + }, + [setToolSizeRaw] + ); + + const switchToolShape = useCallback( + (shape: SegmentationToolShape) => { + setToolShape(shape); + }, + [setToolShape] + ); + + const claimCreateEvent = useAtomCallback( + useCallback((get, set, eventId: string) => { + if (get(lastProcessedCreateIdAtom) === eventId) { + return false; + } + + set(lastProcessedCreateIdAtom, eventId); + + return true; + }, []) + ); + + /** + * Handles the `lighter:overlay-create` event fired by `InteractionManager` + * on pointer-down when no interactive handler exists. + * + * 1. Finalize the previous detection (exit interactive mode, persist overlay, + * remember field/label for auto-assignment). + * 2. Resolve field and label for the next detection. + * 3. Create the next detection. + */ + useEventHandler( + "lighter:overlay-create", + useCallback( + (payload) => { + if (!segmentationActive || !claimCreateEvent(payload.eventId)) { + return; + } + + // Finalize the previous detection if one exists + const currentScene = sceneRef.current; + const currentLabel = selectedLabelRef.current; + + if (currentLabel) { + if ( + currentScene && + !currentScene.isDestroyed && + currentScene.renderLoopActive + ) { + currentScene.exitInteractiveMode(); + + if (currentLabel.overlay) { + addOverlay(currentLabel.overlay as BaseOverlay); + } + } + } + + // TODO: assume previous `field` and `labelValue` + // e.g. 
createDetection({ field, labelValue }); + createDetection(); + }, + [claimCreateEvent, segmentationActive] + ) + ); + + return useMemo( + () => ({ + active: segmentationActive, + tool, + toolSize, + toolShape, + enter, + exit, + switchTool, + switchToolShape, + increaseToolSize, + decreaseToolSize, + setToolSize, + }), + [ + segmentationActive, + tool, + toolSize, + toolShape, + enter, + exit, + switchTool, + switchToolShape, + increaseToolSize, + decreaseToolSize, + setToolSize, + ] + ); +}; diff --git a/app/packages/lighter/src/index.ts b/app/packages/lighter/src/index.ts index 685fba893e0..cdbdb94b20a 100644 --- a/app/packages/lighter/src/index.ts +++ b/app/packages/lighter/src/index.ts @@ -25,6 +25,7 @@ export { ClassificationOverlay } from "./overlay/ClassificationOverlay"; export type { ClassificationOptions } from "./overlay/ClassificationOverlay"; export { ImageOverlay } from "./overlay/ImageOverlay"; export type { ImageOptions } from "./overlay/ImageOverlay"; +export { AISegmentPointOverlay } from "./overlay/AISegmentPointOverlay"; export { KeypointOverlay } from "./overlay/KeypointOverlay"; export type { KeypointLabel, KeypointOptions } from "./overlay/KeypointOverlay"; export { OverlayFactory } from "./overlay/OverlayFactory"; diff --git a/app/packages/lighter/src/interaction/InteractionManager.ts b/app/packages/lighter/src/interaction/InteractionManager.ts index baaff757393..a83cc21710c 100644 --- a/app/packages/lighter/src/interaction/InteractionManager.ts +++ b/app/packages/lighter/src/interaction/InteractionManager.ts @@ -12,6 +12,7 @@ import type { Renderer2D } from "../renderer/Renderer2D"; import type { SelectionManager } from "../selection/SelectionManager"; import type { Point, Rect } from "../types"; import { InteractiveDetectionHandler } from "./InteractiveDetectionHandler"; +import { InteractiveKeypointHandler } from "./InteractiveKeypointHandler"; import { v4 as generateUUID } from "uuid"; /** @@ -280,8 +281,14 @@ export class 
InteractionManager { const interactiveHandler = this.getInteractiveHandler(); if (interactiveHandler) { - handler = interactiveHandler.getOverlay(); - this.selectionManager.select(handler.id); + if (interactiveHandler instanceof InteractiveKeypointHandler) { + // Keypoint handlers manage their own point placement — route + // events directly to the handler, not to the underlying overlay. + handler = interactiveHandler; + } else { + handler = interactiveHandler.getOverlay(); + } + this.selectionManager.select(interactiveHandler.getOverlay().id); } else { handler = this.findHandlerAtPoint(point); // Prevent pan/zoom when target is selectable @@ -477,7 +484,12 @@ export class InteractionManager { this.maintainAspectRatio ); } else { - handler = interactiveHandler.getOverlay(); + // For keypoint handlers, route to the handler itself (manages preview line); + // for detection handlers, route to the underlying overlay (manages resize). + handler = + interactiveHandler instanceof InteractiveKeypointHandler + ? interactiveHandler + : interactiveHandler.getOverlay(); handler.onMove?.( point, @@ -540,7 +552,10 @@ export class InteractionManager { const interactiveHandler = this.getInteractiveHandler(); if (interactiveHandler) { - handler = interactiveHandler.getOverlay(); + handler = + interactiveHandler instanceof InteractiveKeypointHandler + ? 
interactiveHandler + : interactiveHandler.getOverlay(); } else { handler = this.findMovingHandler() || this.findHandlerAtPoint(point); } @@ -831,8 +846,15 @@ export class InteractionManager { ); } - private getInteractiveHandler(): InteractiveDetectionHandler | undefined { - return this.handlers.find((h) => h instanceof InteractiveDetectionHandler); + private getInteractiveHandler(): + | InteractiveDetectionHandler + | InteractiveKeypointHandler + | undefined { + return this.handlers.find( + (h) => + h instanceof InteractiveDetectionHandler || + h instanceof InteractiveKeypointHandler + ) as InteractiveDetectionHandler | InteractiveKeypointHandler | undefined; } /** diff --git a/app/packages/lighter/src/interaction/InteractiveKeypointHandler.ts b/app/packages/lighter/src/interaction/InteractiveKeypointHandler.ts index 04a94d17d6b..13b32e5be54 100644 --- a/app/packages/lighter/src/interaction/InteractiveKeypointHandler.ts +++ b/app/packages/lighter/src/interaction/InteractiveKeypointHandler.ts @@ -55,10 +55,38 @@ export class InteractiveKeypointHandler implements InteractionHandler { worldPoint: Point, _event: PointerEvent ): boolean { + // Block placement while inference is processing (ripple active) + if ( + "isProcessing" in this.overlay && + (this.overlay as any).isProcessing() + ) { + return true; // swallow the click + } + + // Only place points within the sample image bounds + if (!this.isWithinSample(worldPoint)) { + return false; + } + this.overlay.addPoint(worldPoint); return true; } + /** + * Checks whether a world-space point falls within the [0,1] normalized + * sample bounds (i.e. inside the image). 
+ */ + private isWithinSample(worldPoint: Point): boolean { + const cs = this.overlay.getCoordinateSystemPublic?.(); + if (!cs) return true; // no coordinate system — allow placement + + const t = cs.getTransform(); + const rx = (worldPoint.x - t.offsetX) / t.scaleX; + const ry = (worldPoint.y - t.offsetY) / t.scaleY; + + return rx >= 0 && rx <= 1 && ry >= 0 && ry <= 1; + } + onMove(_point: Point, worldPoint: Point, _event: PointerEvent): boolean { this.overlay.setPreviewPoint(worldPoint); return true; diff --git a/app/packages/lighter/src/overlay/AISegmentPointOverlay.ts b/app/packages/lighter/src/overlay/AISegmentPointOverlay.ts new file mode 100644 index 00000000000..4238640707e --- /dev/null +++ b/app/packages/lighter/src/overlay/AISegmentPointOverlay.ts @@ -0,0 +1,201 @@ +/** + * Copyright 2017-2026, Voxel51, Inc. + * + * Overlay for AI-assisted segmentation prompt points. + * Renders each point as a green circle with a white "+" cross, + * with a looping ripple animation while inference is in flight. + */ + +import { STROKE_WIDTH } from "../constants"; +import type { Renderer2D } from "../renderer/Renderer2D"; +import type { Point, RenderMeta } from "../types"; +import { KeypointOverlay, type KeypointOptions } from "./KeypointOverlay"; + +const AI_POINT_RADIUS = 10; +const AI_POINT_SELECTED_RADIUS = 13; +const POSITIVE_COLOR = "#22c55e"; // green-500 +const POSITIVE_SELECTED_COLOR = "#4ade80"; // green-400 +const CROSS_COLOR = "#ffffff"; +const CROSS_LINE_WIDTH = 2.5; + +// Ripple animation settings +const RIPPLE_CYCLE_MS = 800; // duration of one ripple cycle +const RIPPLE_MAX_RADIUS = 30; +const RIPPLE_RINGS = 2; + +/** + * AI-segment point overlay that renders positive prompt-points as + * green circles with white "+" crosses, with a looping ripple animation + * while inference is processing. + * + * Call {@link startProcessing} when inference begins and + * {@link stopProcessing} when the result arrives. 
+ */ +export class AISegmentPointOverlay extends KeypointOverlay { + /** Index of the point currently showing a ripple, or null. */ + private processingPointIndex: number | null = null; + private processingStartTime = 0; + private animationFrameId: number | null = null; + + constructor(options: Omit) { + super({ ...options, connections: [], closed: false }); + } + + override getOverlayType(): string { + return "AISegmentPointOverlay"; + } + + /** + * No-op: AI segment points are independent — no preview line from + * the last point to the cursor. + */ + override setPreviewPoint(_worldPoint: { x: number; y: number } | null): void { + // intentionally empty + } + + /** + * Override addPoint to auto-start processing ripple on the new point. + */ + override addPoint(worldPoint: Point): number { + const idx = super.addPoint(worldPoint); + this.startProcessing(idx); + return idx; + } + + /** + * Start the looping ripple animation on a specific point. + * Call this when inference begins. + */ + startProcessing(pointIndex: number): void { + this.processingPointIndex = pointIndex; + this.processingStartTime = performance.now(); + this.scheduleAnimation(); + } + + /** + * Stop the ripple animation. + * Call this when inference result arrives. + */ + stopProcessing(): void { + this.processingPointIndex = null; + if (this.animationFrameId !== null) { + cancelAnimationFrame(this.animationFrameId); + this.animationFrameId = null; + } + this.markDirty(); + } + + /** + * Whether a processing ripple is currently active. 
+ */ + isProcessing(): boolean { + return this.processingPointIndex !== null; + } + + private scheduleAnimation(): void { + if (this.animationFrameId !== null) return; + + this.animationFrameId = requestAnimationFrame(() => { + this.animationFrameId = null; + if (this.processingPointIndex !== null) { + this.markDirty(); + this.scheduleAnimation(); + } + }); + } + + protected override renderImpl( + renderer: Renderer2D, + _renderMeta: RenderMeta + ): void { + renderer.dispose(this.containerId); + + const absPoints = this.getAbsolutePoints(); + if (absPoints.length === 0) { + this.emitLoaded(); + return; + } + + const scale = renderer.getScale() || 1; + + // 1. Draw looping ripple rings on the processing point + if ( + this.processingPointIndex !== null && + this.processingPointIndex < absPoints.length + ) { + const center = absPoints[this.processingPointIndex]; + const elapsed = performance.now() - this.processingStartTime; + // Loop: progress cycles 0→1 repeatedly + const cycleProgress = (elapsed % RIPPLE_CYCLE_MS) / RIPPLE_CYCLE_MS; + + for (let ring = 0; ring < RIPPLE_RINGS; ring++) { + // Stagger each ring + const ringProgress = (cycleProgress - ring * 0.3 + 1) % 1; + // Ease-out + const eased = 1 - Math.pow(1 - ringProgress, 3); + const rippleRadius = + (AI_POINT_RADIUS + eased * RIPPLE_MAX_RADIUS) / scale; + const opacity = (1 - eased) * 0.5; + + if (opacity > 0.01) { + renderer.drawPoint( + center, + rippleRadius * scale, // drawPoint divides by scale internally + { + strokeStyle: POSITIVE_COLOR, + lineWidth: 2, + opacity, + }, + this.containerId + ); + } + } + } + + // 2. Draw each point as a green circle with a white "+" cross + for (let i = 0; i < absPoints.length; i++) { + const isSelected = this.selectedPointIndex === i; + const center = absPoints[i]; + const radius = isSelected ? AI_POINT_SELECTED_RADIUS : AI_POINT_RADIUS; + const fillColor = isSelected ? 
POSITIVE_SELECTED_COLOR : POSITIVE_COLOR; + + // Green filled circle with white border + renderer.drawPoint( + center, + radius, + { + fillStyle: fillColor, + strokeStyle: CROSS_COLOR, + lineWidth: STROKE_WIDTH, + }, + this.containerId + ); + + // White "+" cross lines (60% of circle diameter) + const crossHalf = (radius * 0.6) / scale; + renderer.drawLine( + { x: center.x, y: center.y - crossHalf }, + { x: center.x, y: center.y + crossHalf }, + { strokeStyle: CROSS_COLOR, lineWidth: CROSS_LINE_WIDTH }, + this.containerId + ); + renderer.drawLine( + { x: center.x - crossHalf, y: center.y }, + { x: center.x + crossHalf, y: center.y }, + { strokeStyle: CROSS_COLOR, lineWidth: CROSS_LINE_WIDTH }, + this.containerId + ); + } + + this.emitLoaded(); + } + + override destroy(): void { + if (this.animationFrameId !== null) { + cancelAnimationFrame(this.animationFrameId); + this.animationFrameId = null; + } + this.processingPointIndex = null; + super.destroy(); + } +} diff --git a/app/packages/lighter/src/overlay/KeypointOverlay.ts b/app/packages/lighter/src/overlay/KeypointOverlay.ts index 205de7013d8..f72ea078c14 100644 --- a/app/packages/lighter/src/overlay/KeypointOverlay.ts +++ b/app/packages/lighter/src/overlay/KeypointOverlay.ts @@ -84,7 +84,7 @@ export class KeypointOverlay #relativePoints: [number, number][]; // Per-point sub-selection - private selectedPointIndex: number | null = null; + protected selectedPointIndex: number | null = null; // Drag state for individual points private dragPointIndex: number | null = null; @@ -92,7 +92,7 @@ export class KeypointOverlay private moveStartRelativePoint?: [number, number]; // Preview point for interactive creation (cursor tracking) - private previewPoint?: Point | null = null; + protected previewPoint?: Point | null = null; // Caches — invalidated in markDirty() private _absPointsCache: Point[] | null = null; @@ -135,7 +135,7 @@ export class KeypointOverlay return [(ap.x - t.offsetX) / t.scaleX, (ap.y - t.offsetY) / 
t.scaleY]; } - private getAbsolutePoints(): Point[] { + protected getAbsolutePoints(): Point[] { if (this._absPointsCache) return this._absPointsCache; this._absPointsCache = this.#relativePoints.map((p) => this.relativePointToAbsolute(p) @@ -158,6 +158,14 @@ export class KeypointOverlay return this.coordinateSystem; } + /** + * Public accessor for the coordinate system, used by external handlers + * (e.g. InteractiveKeypointHandler) to check sample bounds. + */ + getCoordinateSystemPublic() { + return this.coordinateSystem; + } + // --------------------------------------------------------------------------- // Spatial interface // --------------------------------------------------------------------------- diff --git a/app/packages/lighter/src/overlay/OverlayFactory.ts b/app/packages/lighter/src/overlay/OverlayFactory.ts index 799b78e1375..63c3ece16b0 100644 --- a/app/packages/lighter/src/overlay/OverlayFactory.ts +++ b/app/packages/lighter/src/overlay/OverlayFactory.ts @@ -6,6 +6,7 @@ import type { BaseOverlay } from "./BaseOverlay"; import { BoundingBoxOverlay } from "./BoundingBoxOverlay"; import { ClassificationOverlay } from "./ClassificationOverlay"; import { ImageOverlay } from "./ImageOverlay"; +import { AISegmentPointOverlay } from "./AISegmentPointOverlay"; import { KeypointOverlay } from "./KeypointOverlay"; /** @@ -34,6 +35,10 @@ export class OverlayFactory { ); factory.register("image", (opts) => new ImageOverlay(opts)); factory.register("keypoint", (opts) => new KeypointOverlay(opts)); + factory.register( + "ai-segment-point", + (opts) => new AISegmentPointOverlay(opts) + ); return factory; }