diff --git a/packages/examples/cvat/exchange-oracle/README.md b/packages/examples/cvat/exchange-oracle/README.md index 639423301d..7d70eeac82 100644 --- a/packages/examples/cvat/exchange-oracle/README.md +++ b/packages/examples/cvat/exchange-oracle/README.md @@ -73,5 +73,6 @@ Available at `/docs` route To run tests ``` -docker compose -f docker-compose.test.yml up --build test --attach test --exit-code-from test +docker compose -f docker-compose.test.yml up --build test --attach test --exit-code-from test && \ + docker compose -f docker-compose.test.yml down ``` \ No newline at end of file diff --git a/packages/examples/cvat/exchange-oracle/poetry.lock b/packages/examples/cvat/exchange-oracle/poetry.lock index f1970962d5..8af6c9bc2d 100644 --- a/packages/examples/cvat/exchange-oracle/poetry.lock +++ b/packages/examples/cvat/exchange-oracle/poetry.lock @@ -945,13 +945,13 @@ test-randomorder = ["pytest-randomly"] [[package]] name = "cvat-sdk" -version = "2.31.0" +version = "2.37.0" description = "CVAT REST API" optional = false python-versions = ">=3.9" files = [ - {file = "cvat_sdk-2.31.0-py3-none-any.whl", hash = "sha256:b33e8526dad8c481f82e445badfced5d69747eaf7e5660b0d176cf86d394a02e"}, - {file = "cvat_sdk-2.31.0.tar.gz", hash = "sha256:aaeff833c32bfe711f418c62bdab135e0746eff0e89757e8b61cfad14a42ef23"}, + {file = "cvat_sdk-2.37.0-py3-none-any.whl", hash = "sha256:faa94cfd6678089814179a8da828761dfa3daf08eb752490ee85551a1045dac5"}, + {file = "cvat_sdk-2.37.0.tar.gz", hash = "sha256:e990908a473c499eb6d7b84f7f2e640ea729ef027d4c4cc32a5a925752532689"}, ] [package.dependencies] @@ -5047,4 +5047,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.10,<3.13" -content-hash = "c643f28ae7113ae0b8051952c0adfcb74a0ae182ba5039133f57f52b78608007" +content-hash = "8bf7f09b99af5cd8b02a36fc0a1b5ad4af28d5d17d7c0275afa38edb281c3cfc" diff --git a/packages/examples/cvat/exchange-oracle/pyproject.toml b/packages/examples/cvat/exchange-oracle/pyproject.toml index 187d9c81e3..5dd04c1813 100644 --- a/packages/examples/cvat/exchange-oracle/pyproject.toml +++ b/packages/examples/cvat/exchange-oracle/pyproject.toml @@ -16,7 +16,7 @@ sqlalchemy-utils = "^0.41.1" alembic = "^1.11.1" httpx = "^0.24.1" pytest = "^7.2.2" -cvat-sdk = "2.31.0" +cvat-sdk = "2.37.0" sqlalchemy = "^2.0.16" apscheduler = "^3.10.1" xmltodict = "^0.13.0" diff --git a/packages/examples/cvat/exchange-oracle/src/.env.template b/packages/examples/cvat/exchange-oracle/src/.env.template index b07c93515c..88a546cff2 100644 --- a/packages/examples/cvat/exchange-oracle/src/.env.template +++ b/packages/examples/cvat/exchange-oracle/src/.env.template @@ -77,6 +77,8 @@ CVAT_IOU_THRESHOLD= CVAT_OKS_SIGMA= CVAT_EXPORT_TIMEOUT= CVAT_IMPORT_TIMEOUT= +CVAT_PROJECTS_PAGE_SIZE= +CVAT_JOBS_PAGE_SIZE= # Storage Config (S3/GCS) diff --git a/packages/examples/cvat/exchange-oracle/src/core/config.py b/packages/examples/cvat/exchange-oracle/src/core/config.py index 9e40bb1265..245380cc58 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/config.py +++ b/packages/examples/cvat/exchange-oracle/src/core/config.py @@ -146,7 +146,7 @@ class CronConfig: "Maximum number of downloading attempts per job or project during results downloading" track_completed_escrows_jobs_downloading_batch_size = int( - getenv("TRACK_COMPLETED_ESCROWS_JOBS_DOWNLOADING_BATCH_SIZE", 500) + getenv("TRACK_COMPLETED_ESCROWS_JOBS_DOWNLOADING_BATCH_SIZE", 10) ) "Maximum number of parallel downloading requests during results downloading" @@ -183,6 +183,9 @@ class CvatConfig: incoming_webhooks_url = getenv("CVAT_INCOMING_WEBHOOKS_URL") webhook_secret = getenv("CVAT_WEBHOOK_SECRET", "thisisasamplesecret") + projects_page_size = int(getenv("CVAT_PROJECTS_PAGE_SIZE", 100)) + jobs_page_size = int(getenv("CVAT_JOBS_PAGE_SIZE", 100)) + class StorageConfig: provider: ClassVar[str] = os.environ["STORAGE_PROVIDER"].lower() diff --git a/packages/examples/cvat/exchange-oracle/src/core/tasks/skeletons_from_boxes.py b/packages/examples/cvat/exchange-oracle/src/core/tasks/skeletons_from_boxes.py index dc70d33a91..d48397e5d1 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/tasks/skeletons_from_boxes.py +++ b/packages/examples/cvat/exchange-oracle/src/core/tasks/skeletons_from_boxes.py @@ -23,6 +23,9 @@ class RoiInfo: bbox_y: int bbox_label: int + point_x: int + point_y: int + # RoI is centered on the bbox center # Coordinates can be out of image boundaries. # In this case RoI includes extra margins to be centered on bbox center @@ -117,7 +120,10 @@ def parse_skeleton_bbox_mapping(self, skeleton_bbox_mapping_data: bytes) -> Skel return {int(k): int(v) for k, v in parse_json(skeleton_bbox_mapping_data).items()} def parse_roi_info(self, rois_info_data: bytes) -> RoiInfos: - return [RoiInfo(**roi_info) for roi_info in parse_json(rois_info_data)] + return [ + RoiInfo(**{"point_x": 0, "point_y": 0, **roi_info}) + for roi_info in parse_json(rois_info_data) + ] def parse_roi_filenames(self, roi_filenames_data: bytes) -> RoiFilenames: return {int(k): v for k, v in parse_json(roi_filenames_data).items()} diff --git a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py index 27c869c863..06934225b4 100644 --- a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py +++ b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py @@ -47,15 +47,23 @@ def _request_annotations(endpoint: Endpoint, cvat_id: int, format_name: str) -> _get_annotations(request_id, ...) """ - (_, response) = endpoint.call_with_http_info( - id=cvat_id, - format=format_name, - save_images=False, - _parse_response=False, - ) + try: + (_, response) = endpoint.call_with_http_info( + id=cvat_id, + format=format_name, + save_images=False, + _parse_response=False, + ) + + assert response.status in [HTTPStatus.ACCEPTED, HTTPStatus.CREATED] + rq_id = response.json()["rq_id"] + except exceptions.ApiException as e: + if e.status == HTTPStatus.CONFLICT: + rq_id = json.loads(e.body)["rq_id"] + else: + raise - assert response.status in [HTTPStatus.ACCEPTED, HTTPStatus.CREATED] - return response.json()["rq_id"] + return rq_id def _get_annotations( @@ -462,6 +470,7 @@ def fetch_task_jobs(task_id: int) -> list[models.JobRead]: api_client.jobs_api.list_endpoint, task_id=task_id, type="annotation", + page_size=Config.cvat_config.jobs_page_size, ) except exceptions.ApiException as e: logger.exception(f"Exception when calling JobsApi.list: {e}\n") @@ -535,6 +544,7 @@ def fetch_projects(assignee: str = "") -> list[models.ProjectRead]: return get_paginated_collection( api_client.projects_api.list_endpoint, **({"assignee": assignee} if assignee else {}), + page_size=Config.cvat_config.projects_page_size, ) except exceptions.ApiException as e: logger.exception(f"Exception when calling ProjectsApi.list(): {e}\n") @@ -711,6 +721,7 @@ def update_quality_control_settings( logger = logging.getLogger("app") params = { + "inherit": False, "max_validations_per_job": max_validations_per_job, "target_metric": target_metric, "target_metric_threshold": target_metric_threshold, diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py index c9a474fe67..96949a1e23 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py @@ -20,7 +20,7 @@ import datumaro as dm import numpy as np from datumaro.util import filter_dict, take_by -from datumaro.util.annotation_util import BboxCoords, bbox_iou +from datumaro.util.annotation_util import BboxCoords, bbox_iou, find_instances from datumaro.util.image import IMAGE_EXTENSIONS, decode_image, encode_image import src.core.tasks.boxes_from_points as boxes_from_points_task @@ -1709,13 +1709,18 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) - ) "Minimum absolute ROI size, (w, h)" - self.boxes_format = "coco_instances" + self.boxes_format = "coco_person_keypoints" self.embed_bbox_in_roi_image = True "Put a bbox into the extracted skeleton RoI images" self.embed_tile_border = True + self.embedded_point_radius = 15 + self.min_embedded_point_radius_percent = 0.005 + self.max_embedded_point_radius_percent = 0.01 + self.embedded_point_color = (0, 255, 255) + self.roi_embedded_bbox_color = (0, 255, 255) # BGR self.roi_background_color = (245, 240, 242) # BGR - CVAT background color @@ -1729,6 +1734,9 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) - GT annotations or samples for successful job launch """ + self.gt_id_attribute = "object_id" + "An additional way to match GT skeletons with input boxes" + # TODO: probably, need to also add an absolute number of minimum GT RoIs per class def _download_input_data(self): @@ -1948,7 +1956,7 @@ def _validate_boxes_filenames(self): ) ) - def _validate_boxes_annotations(self): + def _validate_boxes_annotations(self): # noqa: PLR0912 # Convert possible polygons and masks into boxes self._boxes_dataset.transform(InstanceSegmentsToBbox) self._boxes_dataset.init_cache() @@ -1962,15 +1970,70 @@ def _validate_boxes_annotations(self): # Could fail on this as well image_h, image_w = sample.media_as(dm.Image).size - sample_boxes = [a for a in sample.annotations if isinstance(a, dm.Bbox)] - valid_boxes = [] - for bbox in sample_boxes: - if not ( - (0 <= int(bbox.x) < int(bbox.x + bbox.w) <= image_w) - and (0 <= int(bbox.y) < int(bbox.y + bbox.h) <= image_h) - ): + valid_instances: list[tuple[dm.Bbox, dm.Points]] = [] + instances = find_instances( + [a for a in sample.annotations if isinstance(a, dm.Bbox | dm.Skeleton)] + ) + for instance_anns in instances: + if len(instance_anns) != 2: + excluded_boxes_info.add_message( + "Sample '{}': object #{} ({}) skipped - unexpected group size ({})".format( + sample.id, + instance_anns[0].id, + label_cat[instance_anns[0].label].name, + len(instance_anns), + ), + sample_id=sample.id, + sample_subset=sample.subset, + ) + continue + + bbox = next((a for a in instance_anns if isinstance(a, dm.Bbox)), None) + if not bbox: + excluded_boxes_info.add_message( + "Sample '{}': object #{} ({}) skipped - no matching bbox".format( + sample.id, instance_anns[0].id, label_cat[instance_anns[0].label].name + ), + sample_id=sample.id, + sample_subset=sample.subset, + ) + continue + + skeleton = next((a for a in instance_anns if isinstance(a, dm.Skeleton)), None) + if not skeleton: + excluded_boxes_info.add_message( + "Sample '{}': object #{} ({}) skipped - no matching skeleton".format( + sample.id, instance_anns[0].id, label_cat[instance_anns[0].label].name + ), + sample_id=sample.id, + sample_subset=sample.subset, + ) + continue + + if len(skeleton.elements) != 1 or len(skeleton.elements[0].points) != 2: + excluded_boxes_info.add_message( + "Sample '{}': object #{} ({}) skipped - invalid skeleton points".format( + sample.id, skeleton.id, label_cat[skeleton.label].name + ), + sample_id=sample.id, + sample_subset=sample.subset, + ) + continue + + point = skeleton.elements[0] + if not is_point_in_bbox(point.points[0], point.points[1], (0, 0, image_w, image_h)): excluded_boxes_info.add_message( - "Sample '{}': bbox #{} ({}) skipped - invalid coordinates".format( + "Sample '{}': object #{} ({}) skipped - invalid point coordinates".format( + sample.id, skeleton.id, label_cat[skeleton.label].name + ), + sample_id=sample.id, + sample_subset=sample.subset, + ) + continue + + if not is_point_in_bbox(int(bbox.x), int(bbox.y), (0, 0, image_w, image_h)): + excluded_boxes_info.add_message( + "Sample '{}': object #{} ({}) skipped - invalid bbox coordinates".format( sample.id, bbox.id, label_cat[bbox.label].name ), sample_id=sample.id, @@ -1978,6 +2041,16 @@ def _validate_boxes_annotations(self): ) continue + if not is_point_in_bbox(point.points[0], point.points[1], bbox): + excluded_boxes_info.add_message( + "Sample '{}': object #{} ({}) skipped - point is outside the bbox".format( + sample.id, skeleton.id, label_cat[skeleton.label].name + ), + sample_id=sample.id, + sample_subset=sample.subset, + ) + continue + if bbox.id in visited_ids: excluded_boxes_info.add_message( "Sample '{}': bbox #{} ({}) skipped - repeated annotation id {}".format( @@ -1988,14 +2061,18 @@ def _validate_boxes_annotations(self): ) continue - valid_boxes.append(bbox) + valid_instances.append( + (bbox, point.wrap(group=bbox.group, id=bbox.id, attributes=bbox.attributes)) + ) visited_ids.add(bbox.id) - excluded_boxes_info.excluded_count += len(sample_boxes) - len(valid_boxes) - excluded_boxes_info.total_count += len(sample_boxes) + excluded_boxes_info.excluded_count += len(instances) - len(valid_instances) + excluded_boxes_info.total_count += len(instances) - if len(valid_boxes) != len(sample.annotations): - self._boxes_dataset.put(sample.wrap(annotations=valid_boxes)) + if len(valid_instances) != len(sample.annotations): + self._boxes_dataset.put( + sample.wrap(annotations=list(chain.from_iterable(valid_instances))) + ) if excluded_boxes_info.excluded_count > ceil( excluded_boxes_info.total_count * self.max_discarded_threshold @@ -2066,8 +2143,14 @@ def _find_unambiguous_matches( input_boxes: list[dm.Bbox], gt_skeletons: list[dm.Skeleton], *, + input_points: list[dm.Points], gt_annotations: list[dm.Annotation], ) -> list[tuple[dm.Bbox, dm.Skeleton]]: + bbox_point_mapping: dict[int, dm.Points] = { + bbox.id: next(p for p in input_points if p.group == bbox.group) + for bbox in input_boxes + } + matches = [ [ (input_bbox.label == gt_skeleton.label) @@ -2077,6 +2160,18 @@ def _find_unambiguous_matches( self._get_skeleton_bbox(gt_skeleton, gt_annotations), ) ) + and (input_point := bbox_point_mapping[input_bbox.id]) + and is_point_in_bbox( + input_point.points[0], + input_point.points[1], + self._get_skeleton_bbox(gt_skeleton, gt_annotations), + ) + and ( + # a way to customize matching if the default method is too rough + not (bbox_id := input_bbox.attributes.get(self.gt_id_attribute)) + or not (skeleton_id := gt_skeleton.attributes.get(self.gt_id_attribute)) + or bbox_id == skeleton_id + ) for gt_skeleton in gt_skeletons ] for input_bbox in input_boxes @@ -2167,10 +2262,11 @@ def _find_good_gt_skeletons( input_boxes: list[dm.Bbox], gt_skeletons: list[dm.Skeleton], *, + input_points: list[dm.Points], gt_annotations: list[dm.Annotation], ) -> list[dm.Skeleton]: matches = _find_unambiguous_matches( - input_boxes, gt_skeletons, gt_annotations=gt_annotations + input_boxes, gt_skeletons, input_points=input_points, gt_annotations=gt_annotations ) matched_skeletons = [] @@ -2221,13 +2317,18 @@ def _find_good_gt_skeletons( gt_skeletons = [a for a in gt_sample.annotations if isinstance(a, dm.Skeleton)] input_boxes = [a for a in boxes_sample.annotations if isinstance(a, dm.Bbox)] + input_points = [a for a in boxes_sample.annotations if isinstance(a, dm.Points)] + assert len(input_boxes) == len(input_points) # Samples without boxes are allowed, so we just skip them without an error if not gt_skeletons: continue matched_skeletons = _find_good_gt_skeletons( - input_boxes, gt_skeletons, gt_annotations=gt_sample.annotations + input_boxes, + gt_skeletons, + input_points=input_points, + gt_annotations=gt_sample.annotations, ) if not matched_skeletons: continue @@ -2294,9 +2395,10 @@ def _prepare_roi_infos(self): rois: list[skeletons_from_boxes_task.RoiInfo] = [] for sample in self._boxes_dataset: - for bbox in sample.annotations: - if not isinstance(bbox, dm.Bbox): - continue + instances = find_instances(sample.annotations) + for instance_anns in instances: + bbox = next(a for a in instance_anns if isinstance(a, dm.Bbox)) + point = next(a for a in instance_anns if isinstance(a, dm.Points)) # RoI is centered on bbox center original_bbox_cx = int(bbox.x + bbox.w / 2) @@ -2320,6 +2422,8 @@ def _prepare_roi_infos(self): bbox_label=bbox.label, bbox_x=new_bbox_x, bbox_y=new_bbox_y, + point_x=point.points[0] - roi_x, + point_y=point.points[1] - roi_y, roi_x=roi_x, roi_y=roi_y, roi_w=roi_w, @@ -2511,6 +2615,32 @@ def _draw_roi_bbox(self, roi_image: np.ndarray, bbox: dm.Bbox) -> np.ndarray: cv2.LINE_4, ) + def _draw_roi_point(self, roi_image: np.ndarray, point: tuple[float, float]) -> np.ndarray: + roi_r = (roi_image.shape[0] ** 2 + roi_image.shape[1] ** 2) ** 0.5 / 2 + radius = int( + min( + self.max_embedded_point_radius_percent * roi_r, + max(self.embedded_point_radius, self.min_embedded_point_radius_percent * roi_r), + ) + ) + + roi_image = cv2.circle( + roi_image, + tuple(map(int, (point[0], point[1]))), + radius + 1, + (255, 255, 255), + -1, + cv2.LINE_4, + ) + return cv2.circle( + roi_image, + tuple(map(int, (point[0], point[1]))), + radius, + self.embedded_point_color, + -1, + cv2.LINE_4, + ) + def _extract_and_upload_rois(self): assert self._roi_filenames is not _unset assert self._roi_infos is not _unset @@ -2564,6 +2694,9 @@ def process_file(filename: str, image_pixels: np.ndarray): if self.embed_bbox_in_roi_image: roi_pixels = self._draw_roi_bbox(roi_pixels, bbox_by_id[roi_info.bbox_id]) + roi_pixels = self._draw_roi_point( + roi_pixels, (roi_info.point_x, roi_info.point_y) + ) filename = self._roi_filenames[roi_info.bbox_id] roi_bytes = encode_image(roi_pixels, os.path.splitext(filename)[-1]) diff --git a/packages/examples/cvat/exchange-oracle/src/services/cvat.py b/packages/examples/cvat/exchange-oracle/src/services/cvat.py index ca0e980c1e..5b2bbc96e0 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cvat.py @@ -794,7 +794,8 @@ def count_jobs_by_escrow_address( def get_free_job( session: Session, - cvat_projects: list[int], + escrow_address: str, + chain_id: int, *, user_wallet_address: str, for_update: bool | ForUpdateParams = False, @@ -805,7 +806,11 @@ def get_free_job( return ( _maybe_for_update(session.query(Job), enable=for_update) .where( - Job.cvat_project_id.in_(cvat_projects), + Job.project.has( + (Project.escrow_address == escrow_address) + & (Project.chain_id == chain_id) + & (Project.status == ProjectStatuses.annotation) + ), Job.status == JobStatuses.new, ~Job.assignments.any( ( @@ -984,22 +989,27 @@ def get_user_assignments_in_cvat_projects( ) -def count_active_user_assignments( +def has_active_user_assignments( session: Session, wallet_address: int, - cvat_projects: list[int], -) -> int: - return ( + escrow_address: str, + chain_id: int, +) -> bool: + return session.query( session.query(Assignment) .where( - Assignment.job.has(Job.cvat_project_id.in_(cvat_projects)), + Assignment.job.has( + Job.project.has( + (Project.escrow_address == escrow_address) & (Project.chain_id == chain_id) + ) + ), Assignment.user_wallet_address == wallet_address, Assignment.status == AssignmentStatuses.created.value, Assignment.completed_at == None, utcnow() < Assignment.expires_at, ) - .count() - ) + .exists() + ).scalar() # Image diff --git a/packages/examples/cvat/exchange-oracle/src/services/exchange.py b/packages/examples/cvat/exchange-oracle/src/services/exchange.py index c1fec576f0..23a7b8e1cc 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/services/exchange.py @@ -18,7 +18,7 @@ def __str__(self) -> str: ) -def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: str) -> str | None: # noqa: ARG001 (don't we want to use chain_id for filter?) +def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: str) -> str | None: with SessionLocal.begin() as session: user = get_or_404( cvat_service.get_user_by_id(session, wallet_address, for_update=True), @@ -26,44 +26,37 @@ def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: s "user", ) - # There can be several projects under one escrow, we need any - project = cvat_service.get_project_by_escrow_address( + if cvat_service.has_active_user_assignments( session, - escrow_address, - status_in=[ - ProjectStatuses.annotation - ], # avoid unnecessary locking on completed projects - for_update=True, - ) - - if not project: - # Retry without a lock to check if the project doesn't exist - get_or_404( - cvat_service.get_project_by_escrow_address( - session, escrow_address, status_in=[ProjectStatuses.annotation] - ), - escrow_address, - "job", - ) - return None - - has_active_assignments = ( - cvat_service.count_active_user_assignments( - session, wallet_address=wallet_address, cvat_projects=[project.cvat_id] - ) - > 0 - ) - if has_active_assignments: + wallet_address=wallet_address, + escrow_address=escrow_address, + chain_id=chain_id.value, + ): raise UserHasUnfinishedAssignmentError( "The user already has an unfinished assignment in this project" ) + # TODO: Try to put into 1 request. SQLAlchemy generates 2 queries with simple + # .options(selectinload(Job.project)) + project = get_or_404( + cvat_service.get_project_by_escrow_address( + session, escrow_address, status_in=[ProjectStatuses.annotation] + ), + escrow_address, + "job", + ) + unassigned_job = cvat_service.get_free_job( session, - cvat_projects=[project.cvat_id], + escrow_address=escrow_address, + chain_id=chain_id.value, user_wallet_address=wallet_address, for_update=True, + # lock the job to be able to make a rollback if CVAT requests fail + # can potentially be optimized to make less DB requests + # and rely only on assignment expiration ) + if not unassigned_job: return None @@ -72,7 +65,12 @@ def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: s wallet_address=user.wallet_address, cvat_job_id=unassigned_job.cvat_id, expires_at=utcnow() - + timedelta(seconds=get_default_assignment_timeout(TaskTypes(project.job_type))), + + timedelta( + seconds=get_default_assignment_timeout( + TaskTypes(project.job_type) + # TODO: need to update this if we have multiple job types per escrow + ) + ), ) cvat_service.touch(session, Job, [unassigned_job.id]) diff --git a/packages/examples/cvat/exchange-oracle/src/utils/annotations.py b/packages/examples/cvat/exchange-oracle/src/utils/annotations.py index 075ce7d035..110ec2b069 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/annotations.py +++ b/packages/examples/cvat/exchange-oracle/src/utils/annotations.py @@ -8,7 +8,7 @@ import datumaro as dm import numpy as np from datumaro.util import filter_dict, mask_tools -from datumaro.util.annotation_util import find_group_leader, find_instances, max_bbox +from datumaro.util.annotation_util import BboxCoords, find_group_leader, find_instances, max_bbox from defusedxml import ElementTree @@ -343,8 +343,12 @@ def transform_item(self, item): return item.wrap(annotations=annotations) -def is_point_in_bbox(px: float, py: float, bbox: dm.Bbox) -> bool: - return (bbox.x <= px <= bbox.x + bbox.w) and (bbox.y <= py <= bbox.y + bbox.h) +def is_point_in_bbox(px: float, py: float, bbox: dm.Bbox | BboxCoords) -> bool: + if isinstance(bbox, dm.Bbox): + bbox = bbox.get_bbox() + + x, y, w, h = bbox + return (x <= px <= x + w) and (y <= py <= y + h) class InstanceSegmentsToBbox(dm.ItemTransform): diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py index 25d1e6f81b..7d074d12e9 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py @@ -8,13 +8,14 @@ from fastapi import HTTPException from pydantic import ValidationError -from src.core.types import AssignmentStatuses, JobStatuses, Networks +from src.core.types import AssignmentStatuses, JobStatuses, Networks, TaskTypes from src.db import SessionLocal from src.endpoints.serializers import serialize_job from src.models.cvat import Assignment, User from src.schemas import exchange as service_api from src.services.exchange import create_assignment +from tests.utils.constants import ESCROW_ADDRESS, WALLET_ADDRESS1, WALLET_ADDRESS2 from tests.utils.db_helper import ( create_job, create_project, @@ -32,7 +33,7 @@ def tearDown(self): def test_serialize_job(self): cvat_id = 1 - escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" + escrow_address = ESCROW_ADDRESS cvat_project = create_project(self.session, escrow_address, cvat_id) self.session.commit() @@ -64,7 +65,7 @@ def test_serialize_task_invalid_project(self): def test_serialize_task_invalid_manifest(self): cvat_id = 1 - escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" + escrow_address = ESCROW_ADDRESS cvat_project = create_project(self.session, escrow_address, cvat_id) self.session.commit() @@ -75,10 +76,8 @@ def test_serialize_task_invalid_manifest(self): serialize_job(cvat_project) def test_create_assignment(self): - cvat_project_1, _, cvat_job_1 = create_project_task_and_job( - self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 - ) - user_address = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + cvat_project_1, _, cvat_job_1 = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) + user_address = WALLET_ADDRESS1 user = User( wallet_address=user_address, cvat_email="test@hmt.ai", @@ -87,15 +86,9 @@ def test_create_assignment(self): self.session.add(user) self.session.commit() - with ( - open("tests/utils/manifest.json") as data, - patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest, - patch("src.services.exchange.cvat_api"), - ): - manifest = json.load(data) - mock_get_manifest.return_value = manifest + with patch("src.services.exchange.cvat_api"): assignment_id = create_assignment( - cvat_project_1.escrow_address, cvat_project_1.chain_id, user_address + cvat_project_1.escrow_address, Networks(cvat_project_1.chain_id), user_address ) assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() @@ -105,15 +98,13 @@ def test_create_assignment(self): assert assignment.status == AssignmentStatuses.created def test_create_assignment_many_jobs_1_completed(self): - cvat_project, _, cvat_job_1 = create_project_task_and_job( - self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 - ) + cvat_project, _, cvat_job_1 = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) cvat_job_1.status = JobStatuses.completed.value cvat_task_2 = create_task(self.session, 2, cvat_project.cvat_id) cvat_job_2 = create_job(self.session, 2, cvat_task_2.cvat_id, cvat_project.cvat_id) - user_address = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + user_address = WALLET_ADDRESS1 user = User( wallet_address=user_address, cvat_email="test@hmt.ai", @@ -135,15 +126,9 @@ def test_create_assignment_many_jobs_1_completed(self): self.session.commit() - with ( - open("tests/utils/manifest.json") as data, - patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest, - patch("src.services.exchange.cvat_api"), - ): - manifest = json.load(data) - mock_get_manifest.return_value = manifest + with patch("src.services.exchange.cvat_api"): assignment_id = create_assignment( - cvat_project.escrow_address, cvat_project.chain_id, user_address + cvat_project.escrow_address, Networks(cvat_project.chain_id), user_address ) assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() @@ -153,20 +138,18 @@ def test_create_assignment_many_jobs_1_completed(self): assert assignment.status == AssignmentStatuses.created def test_create_assignment_invalid_user_address(self): - cvat_project_1, _, _ = create_project_task_and_job( - self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 - ) + cvat_project_1, _, _ = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) self.session.commit() - with pytest.raises(HTTPException): + with pytest.raises(HTTPException, match="Can't find user"): create_assignment( cvat_project_1.escrow_address, - cvat_project_1.chain_id, + Networks(cvat_project_1.chain_id), "invalid_address", ) def test_create_assignment_invalid_project(self): - user_address = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + user_address = WALLET_ADDRESS1 user = User( wallet_address=user_address, cvat_email="test@hmt.ai", @@ -175,14 +158,12 @@ def test_create_assignment_invalid_project(self): self.session.add(user) self.session.commit() - with pytest.raises(HTTPException): - create_assignment("1", Networks.localhost.value, user_address) + with pytest.raises(HTTPException, match="Can't find job"): + create_assignment("1", Networks.localhost, user_address) def test_create_assignment_unfinished_assignment(self): - _, _, cvat_job = create_project_task_and_job( - self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 - ) - user_address = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + _, _, cvat_job = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) + user_address = WALLET_ADDRESS1 user = User( wallet_address=user_address, cvat_email="test@hmt.ai", @@ -200,23 +181,52 @@ def test_create_assignment_unfinished_assignment(self): self.session.commit() with ( - open("tests/utils/manifest.json") as data, - patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest, patch("src.services.exchange.cvat_api"), + pytest.raises(Exception, match="unfinished assignment"), ): - manifest = json.load(data) - mock_get_manifest.return_value = manifest + create_assignment(ESCROW_ADDRESS, Networks.localhost, user_address) - with pytest.raises(HTTPException): - create_assignment("1", Networks.localhost.value, user_address) + def test_create_assignment_has_expired_assignment_and_available_jobs(self): + escrow_address = ESCROW_ADDRESS + project1, _, cvat_job1 = create_project_task_and_job(self.session, escrow_address, 1) + project2, _, cvat_job2 = create_project_task_and_job(self.session, escrow_address, 2) + project1.job_type = TaskTypes.image_skeletons_from_boxes + project2.job_type = TaskTypes.image_skeletons_from_boxes + self.session.add_all([project1, project2]) - def test_create_assignment_no_available_jobs_completed_assignment(self): - cvat_project, _, cvat_job_1 = create_project_task_and_job( - self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 + user_address = WALLET_ADDRESS1 + user = User( + wallet_address=user_address, + cvat_email="test@hmt.ai", + cvat_id=1, ) + self.session.add(user) + + old_assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=user_address, + cvat_job_id=cvat_job1.cvat_id, + created_at=datetime.now() - timedelta(hours=1), + expires_at=datetime.now() - timedelta(minutes=1), + status=AssignmentStatuses.expired.value, + ) + self.session.add(old_assignment) + + self.session.commit() + + with patch("src.services.exchange.cvat_api"): + new_assignment_id = create_assignment(escrow_address, Networks.localhost, user_address) + + new_assignment = self.session.query(Assignment).filter_by(id=new_assignment_id).first() + assert new_assignment.cvat_job_id == cvat_job2.cvat_id # job1 was attempted already + assert new_assignment.user_wallet_address == user_address + assert new_assignment.status == AssignmentStatuses.created + + def test_create_assignment_no_available_jobs_completed_assignment(self): + cvat_project, _, cvat_job_1 = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) cvat_job_1.status = JobStatuses.completed.value - user_address1 = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + user_address1 = WALLET_ADDRESS1 user1 = User( wallet_address=user_address1, cvat_email="test1@hmt.ai", @@ -224,7 +234,7 @@ def test_create_assignment_no_available_jobs_completed_assignment(self): ) self.session.add(user1) - user_address2 = "0x86e83d346041E8806e352681f3F14549C0d2BC70" + user_address2 = WALLET_ADDRESS2 user2 = User( wallet_address=user_address2, cvat_email="test2@hmt.ai", @@ -246,25 +256,17 @@ def test_create_assignment_no_available_jobs_completed_assignment(self): self.session.commit() - with ( - open("tests/utils/manifest.json") as data, - patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest, - patch("src.services.exchange.cvat_api"), - ): - manifest = json.load(data) - mock_get_manifest.return_value = manifest + with patch("src.services.exchange.cvat_api"): assignment_id = create_assignment( - cvat_project.escrow_address, cvat_project.chain_id, user_address2 + cvat_project.escrow_address, Networks(cvat_project.chain_id), user_address2 ) assert assignment_id == None def test_create_assignment_no_available_jobs_active_foreign_assignment(self): - cvat_project, _, cvat_job_1 = create_project_task_and_job( - self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 - ) + cvat_project, _, cvat_job_1 = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) - user_address1 = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + user_address1 = WALLET_ADDRESS1 user1 = User( wallet_address=user_address1, cvat_email="test1@hmt.ai", @@ -272,7 +274,7 @@ def test_create_assignment_no_available_jobs_active_foreign_assignment(self): ) self.session.add(user1) - user_address2 = "0x86e83d346041E8806e352681f3F14549C0d2BC70" + user_address2 = WALLET_ADDRESS2 user2 = User( wallet_address=user_address2, cvat_email="test2@hmt.ai", @@ -290,27 +292,19 @@ def test_create_assignment_no_available_jobs_active_foreign_assignment(self): self.session.commit() - with ( - open("tests/utils/manifest.json") as data, - patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest, - patch("src.services.exchange.cvat_api"), - ): - manifest = json.load(data) - mock_get_manifest.return_value = manifest + with patch("src.services.exchange.cvat_api"): assignment_id = create_assignment( - cvat_project.escrow_address, cvat_project.chain_id, user_address2 + cvat_project.escrow_address, Networks(cvat_project.chain_id), user_address2 ) assert assignment_id == None def test_create_assignment_wont_reassign_job_to_previous_user(self): - cvat_project_1, _, cvat_job_1 = create_project_task_and_job( - self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 - ) + cvat_project_1, _, cvat_job_1 = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) cvat_job_1.status = JobStatuses.new.value # validated and rejected return to 'new' user = User( - wallet_address="0x86e83d346041E8806e352681f3F14549C0d2BC69", + wallet_address=WALLET_ADDRESS1, cvat_email="test@hmt.ai", cvat_id=1, ) @@ -330,32 +324,26 @@ def test_create_assignment_wont_reassign_job_to_previous_user(self): self.session.commit() - with ( - open("tests/utils/manifest.json") as data, - patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest, - patch("src.services.exchange.cvat_api"), - ): - manifest = json.load(data) - mock_get_manifest.return_value = manifest + with patch("src.services.exchange.cvat_api"): assignment_id = create_assignment( - cvat_project_1.escrow_address, cvat_project_1.chain_id, user.wallet_address + cvat_project_1.escrow_address, + Networks(cvat_project_1.chain_id), + user.wallet_address, ) assert assignment_id is None def test_create_assignment_can_assign_job_to_new_user(self): - cvat_project_1, _, cvat_job_1 = create_project_task_and_job( - self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 - ) + cvat_project_1, _, cvat_job_1 = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) cvat_job_1.status = JobStatuses.new.value # validated and rejected return to 'new' previous_user = User( - wallet_address="0x86e83d346041E8806e352681f3F14549C0d2BC69", + wallet_address=WALLET_ADDRESS1, cvat_email="previous@hmt.ai", cvat_id=1, ) new_user = User( - wallet_address="0x69e83d346041E8806e352681f3F14549C0d2BC42", + wallet_address=WALLET_ADDRESS2, cvat_email="new@hmt.ai", cvat_id=2, ) @@ -376,15 +364,11 @@ def test_create_assignment_can_assign_job_to_new_user(self): self.session.commit() - with ( - open("tests/utils/manifest.json") as data, - patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest, - patch("src.services.exchange.cvat_api"), - ): - manifest = json.load(data) - mock_get_manifest.return_value = manifest + with patch("src.services.exchange.cvat_api"): assignment_id = create_assignment( - cvat_project_1.escrow_address, cvat_project_1.chain_id, new_user.wallet_address + cvat_project_1.escrow_address, + Networks(cvat_project_1.chain_id), + new_user.wallet_address, ) assignment = self.session.get(Assignment, assignment_id) diff --git a/packages/examples/cvat/exchange-oracle/tests/utils/constants.py b/packages/examples/cvat/exchange-oracle/tests/utils/constants.py index 2a5082c255..a1d535a814 100644 --- a/packages/examples/cvat/exchange-oracle/tests/utils/constants.py +++ b/packages/examples/cvat/exchange-oracle/tests/utils/constants.py @@ -20,6 +20,9 @@ TOKEN_ADDRESS = "0x976EA74026E726554dB657fA54763abd0C3a0aa9" FACTORY_ADDRESS = "0x14dC79964da2C08b23698B3D3cc7Ca32193d9955" +WALLET_ADDRESS1 = "0x86e83d346041E8806e352681f3F14549C0d2BC69" +WALLET_ADDRESS2 = "0x86e83d346041E8806e352681f3F14549C0d2BC70" + DEFAULT_MANIFEST_URL = "http://host.docker.internal:9000/manifests/manifest.json" DEFAULT_HASH = "test" diff --git a/packages/examples/cvat/recording-oracle/poetry.lock b/packages/examples/cvat/recording-oracle/poetry.lock index 4b99cada57..52dd95468a 100644 --- a/packages/examples/cvat/recording-oracle/poetry.lock +++ b/packages/examples/cvat/recording-oracle/poetry.lock @@ -914,13 +914,13 @@ test-randomorder = ["pytest-randomly"] [[package]] name = "cvat-sdk" -version = "2.31.0" +version = "2.37.0" description = "CVAT REST API" optional = false python-versions = ">=3.9" files = [ - {file = "cvat_sdk-2.31.0-py3-none-any.whl", hash = "sha256:b33e8526dad8c481f82e445badfced5d69747eaf7e5660b0d176cf86d394a02e"}, - {file = "cvat_sdk-2.31.0.tar.gz", hash = "sha256:aaeff833c32bfe711f418c62bdab135e0746eff0e89757e8b61cfad14a42ef23"}, + {file = "cvat_sdk-2.37.0-py3-none-any.whl", hash = "sha256:faa94cfd6678089814179a8da828761dfa3daf08eb752490ee85551a1045dac5"}, + {file = "cvat_sdk-2.37.0.tar.gz", hash = "sha256:e990908a473c499eb6d7b84f7f2e640ea729ef027d4c4cc32a5a925752532689"}, ] [package.dependencies] @@ -4732,4 +4732,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "5f830a339a6f870a60e94be16dc742280e0ec9002fb4a51404fca9e18a6f399f" +content-hash = "3f4ce0cc7668a0c9ffaa02c1306404603d41215e39199c212dabef36ab112a7c" diff --git a/packages/examples/cvat/recording-oracle/pyproject.toml b/packages/examples/cvat/recording-oracle/pyproject.toml index fa03eb2769..194543a267 100644 --- a/packages/examples/cvat/recording-oracle/pyproject.toml +++ b/packages/examples/cvat/recording-oracle/pyproject.toml @@ -24,7 +24,7 @@ google-cloud-storage = "^2.14.0" datumaro = {git = "https://github.com/cvat-ai/datumaro.git", rev = "ff83c00c2c1bc4b8fdfcc55067fcab0a9b5b6b11"} hexbytes = ">=1.2.0" # required for to_0x_hex() function starlette = ">=0.40.0" # avoid the vulnerability with multipart/form-data -cvat-sdk = "2.31.0" +cvat-sdk = "2.37.0" cryptography = "<44.0.0" # human-protocol-sdk -> pgpy dep requires cryptography < 45 human-protocol-sdk = "^4.0.3" diff --git a/packages/examples/cvat/recording-oracle/src/.env.template b/packages/examples/cvat/recording-oracle/src/.env.template index 6abb5d245f..a5851b0fb6 100644 --- a/packages/examples/cvat/recording-oracle/src/.env.template +++ b/packages/examples/cvat/recording-oracle/src/.env.template @@ -61,6 +61,7 @@ CVAT_ADMIN_PASS= CVAT_ORG_SLUG= CVAT_QUALITY_RETRIEVAL_TIMEOUT= CVAT_QUALITY_CHECK_INTERVAL= +CVAT_QUALITY_REPORTS_PAGE_SIZE= # Localhost diff --git a/packages/examples/cvat/recording-oracle/src/core/config.py b/packages/examples/cvat/recording-oracle/src/core/config.py index fa262b8b0b..248f51cc6e 100644 --- a/packages/examples/cvat/recording-oracle/src/core/config.py +++ b/packages/examples/cvat/recording-oracle/src/core/config.py @@ -195,6 +195,7 @@ class ValidationConfig: warmup_iterations = int(getenv("WARMUP_ITERATIONS", "1")) """ The first escrow iterations where the annotation speed is checked to be big enough. + Set to 0 to disable. """ min_warmup_progress = float(getenv("MIN_WARMUP_PROGRESS", "10")) @@ -234,6 +235,7 @@ class CvatConfig: quality_retrieval_timeout = int(getenv("CVAT_QUALITY_RETRIEVAL_TIMEOUT", 60 * 60)) quality_check_interval = int(getenv("CVAT_QUALITY_CHECK_INTERVAL", 5)) + quality_reports_page_size = int(getenv("CVAT_QUALITY_REPORTS_PAGE_SIZE", 100)) class Config: diff --git a/packages/examples/cvat/recording-oracle/src/cvat/api_calls.py b/packages/examples/cvat/recording-oracle/src/cvat/api_calls.py index 0b51e7986c..574e44c918 100644 --- a/packages/examples/cvat/recording-oracle/src/cvat/api_calls.py +++ b/packages/examples/cvat/recording-oracle/src/cvat/api_calls.py @@ -134,7 +134,10 @@ def get_jobs_quality_reports(parent_id: int) -> list[models.QualityReport]: with get_api_client() as api_client: try: return get_paginated_collection( - api_client.quality_api.list_reports_endpoint, parent_id=parent_id, target="job" + api_client.quality_api.list_reports_endpoint, + parent_id=parent_id, + target="job", + page_size=Config.cvat_config.quality_reports_page_size, ) except exceptions.ApiException as e: