Attached is our (soon deprecated) internally developed LS client
ynouri opened this issue · 1 comments
ynouri commented
As a follow-up to #12, here's our internally developed Label Studio (LS) client (tested against LS 1.0). I understand this is unsolicited and you probably don't need it, but I thought it was worth sharing nonetheless, in case you are curious.
Obviously, our plan is to migrate to label-studio-sdk whenever possible.
There is no actual issue here, just using GH Issues as a way to share this, so feel free to close :)
import logging
from functools import cached_property
from typing import Callable, Generator, List, Optional
from requests_toolbelt.sessions import BaseUrlSession
from tqdm.auto import tqdm
from .models import (
LabelStudioExistingProject,
LabelStudioExistingTask,
LabelStudioNewProject,
LabelStudioNewTask,
LabelStudioProjectData,
LabelStudioProjectsData,
TaskDataType,
)
class LabelStudioClient:
    """Small HTTP client for a Label Studio server.

    The client remembers a "current project" (set by ``create_project``)
    which is used by every method that accepts an optional project ID and is
    called without one.

    This class relies on module-level names that are defined outside of it:
    ``logger``, ``LABEL_STUDIO_BASE_URL`` and
    ``raise_for_status_with_logging``.
    """

    def __init__(self, env: str, token: str):
        """Configure the client.

        Args:
            env: Key used to look up the base URL in ``LABEL_STUDIO_BASE_URL``.
            token: API token sent in the ``Authorization`` header.
        """
        self.base_url = LABEL_STUDIO_BASE_URL[env]
        self.token = token
        # ID of the project used when a method is called without a project ID.
        self.current_project: Optional[int] = None

    @cached_property
    def session(self) -> BaseUrlSession:
        """Return the HTTP session, creating it on first access."""
        logger.info("Creating Label Studio session.")
        session = BaseUrlSession(base_url=self.base_url)
        # Update instead of re-assigning: re-assigning would discard the
        # default headers of the session and swap its header mapping for a
        # plain dict.
        session.headers.update(
            {
                "Content-Type": "application/json",
                "Authorization": f"Token {self.token}",
            }
        )
        return session

    def log_current_user(self):
        """Log the e-mail address of the user the token belongs to."""
        response = self.session.get("current-user/whoami")
        # Fail on an error response instead of on a missing "email" key.
        raise_for_status_with_logging(response)
        current_user = response.json()["email"]
        logger.info(f"Using user = {current_user}")

    def create_project(self, project: LabelStudioNewProject) -> int:
        """Create a project, make it the current project and return its ID."""
        logger.info("Creating project...")
        response = self.session.post("projects/", data=project.json())
        # Same error handling as the other methods; without it an error
        # response only shows up as a KeyError on "id" below.
        raise_for_status_with_logging(response)
        project_id: int = response.json()["id"]
        self.current_project = project_id
        logger.info(f"Created project with id = {self.current_project}")
        return project_id

    def get_project(
        self, project_id: Optional[int] = None
    ) -> LabelStudioExistingProject:
        """Fetch a project and parse it into ``LabelStudioExistingProject``.

        Args:
            project_id: Project to fetch; defaults to the current project.

        Raises:
            ValueError: If no project ID is given and no current project is set.
        """
        project_id = self._default_to_current_project(project_id)
        logger.info(f"Getting project details for project ID = {project_id}")
        response = self.session.get(f"projects/{project_id}")
        raise_for_status_with_logging(response)
        return LabelStudioExistingProject.parse_obj(response.json())

    def delete_project(self, project_id: Optional[int] = None):
        """Delete a project (the current project when no ID is given).

        Raises:
            ValueError: If no project ID is given and no current project is set.
        """
        project_id = self._default_to_current_project(project_id)
        logger.info("Deleting project...")
        response = self.session.delete(f"projects/{project_id}")
        raise_for_status_with_logging(response)
        logger.info("Project deleted.")

    def _default_to_current_project(self, project_id: Optional[int]) -> int:
        """Return ``project_id``, or the current project ID when it is ``None``.

        Raises:
            ValueError: If ``project_id`` is ``None`` and no current project
                is set.
        """
        if project_id is not None:
            return project_id
        if self.current_project is None:
            msg = "Current project is not initialized, please specify a project ID"
            raise ValueError(msg)
        logger.debug(f"Using current project id = {self.current_project}")
        return self.current_project

    def create_task(self, task: LabelStudioNewTask):
        """Create a single task.

        ``task.project`` is overwritten in place with the resolved project ID
        (the current project when it was ``None``).
        """
        task.project = self._default_to_current_project(task.project)
        logger.debug("Creating task.")
        response = self.session.post("tasks/", data=task.json())
        raise_for_status_with_logging(response)

    def patch_task(self, task: LabelStudioExistingTask):
        """Patch the existing task identified by ``task.id``.

        ``task.project`` is overwritten in place with the resolved project ID
        (the current project when it was ``None``).
        """
        task.project = self._default_to_current_project(task.project)
        logger.debug("Patching existing task.")
        response = self.session.patch(f"tasks/{task.id}", data=task.json())
        raise_for_status_with_logging(response)

    def import_tasks(
        self, tasks: LabelStudioTaskList, project_id: Optional[int] = None
    ):
        """Import a list of tasks into a project with a single bulk request.

        Args:
            tasks: Tasks to import; serialized with ``tasks.json()``.
            project_id: Target project; defaults to the current project.
        """
        project_id = self._default_to_current_project(project_id)
        logger.info("Importing tasks in bulk...")
        route = f"projects/{project_id}/tasks/bulk/"
        response = self.session.post(route, data=tasks.json())
        raise_for_status_with_logging(response)
        logger.info("Done!")

    def export_annotated_tasks(
        self, project_id: Optional[int] = None, cancelled: bool = True
    ) -> Generator[LabelStudioExistingTask, None, None]:
        """Yield the tasks returned by the JSON export of a project.

        This is a generator: the request is only sent once iteration starts.

        Args:
            project_id: Project to export; defaults to the current project.
            cancelled: When true, ``download_all_tasks=true`` is added to the
                export request.
        """
        project_id = self._default_to_current_project(project_id)
        logger.info("Exporting annotated tasks in bulk...")
        params = {"exportType": "JSON"}
        if cancelled:
            params["download_all_tasks"] = "true"
        response = self.session.get(f"projects/{project_id}/export", params=params)
        raise_for_status_with_logging(response)
        for task in response.json():
            yield LabelStudioExistingTask.parse_obj(task)

    def get_tasks(
        self, project_id: Optional[int] = None, page_size: int = 100
    ) -> Generator[LabelStudioExistingTask, None, None]:
        """Yield all tasks of a project, fetching them page by page.

        Pagination stops at the first response that is not a 200 or that
        contains no tasks.

        Args:
            project_id: Project to read; defaults to the current project.
            page_size: Number of tasks requested per page.
        """
        project_id = self._default_to_current_project(project_id)
        page_id = 1  # Label studio seems to start pagination at 1.
        route = f"projects/{project_id}/tasks"
        logger.info(f"Using a page size = {page_size}")
        while True:
            params = {"page": page_id, "page_size": page_size}
            logger.debug(f"Retrieving tasks for page ID = {page_id}")
            response = self.session.get(route, params=params)
            # Stop if error code
            if response.status_code != 200:
                # NOTE(review): 404 is assumed to be the "page out of range"
                # answer - confirm against the server version in use. Any
                # other status is still treated as the end of the pages, but
                # is surfaced so a truncated listing does not go unnoticed.
                if response.status_code != 404:
                    logger.warning(
                        f"Unexpected status code {response.status_code} "
                        f"while retrieving page ID = {page_id}."
                    )
                logger.info("Received non 200 error code, assuming no pages left.")
                break
            tasks = response.json()
            # Stop if tasks list empty
            if not tasks:
                logger.info("Received empty list, assuming no pages left.")
                break
            # Else, yield tasks
            for task in tasks:
                yield LabelStudioExistingTask.parse_obj(task)
            # Go to next page
            page_id += 1

    def update_project_label_config(
        self, label_config_xml: str, project_id: Optional[int] = None
    ):
        """Replace the label configuration of a project.

        Args:
            label_config_xml: New label configuration, sent as ``label_config``.
            project_id: Project to update; defaults to the current project.
        """
        project_id = self._default_to_current_project(project_id)
        payload = {"label_config": label_config_xml}
        logger.info("Updating project label configuration...")
        response = self.session.patch(f"projects/{project_id}/", json=payload)
        raise_for_status_with_logging(response)

    def create_ml_backend(
        self,
        title: str,
        url: str,
        description: str = "",
        project_id: Optional[int] = None,
    ):
        """Register an ML backend for a project.

        Args:
            title: Title of the ML backend.
            url: URL of the ML backend.
            description: Optional description.
            project_id: Project to attach to; defaults to the current project.
        """
        project_id = self._default_to_current_project(project_id)
        payload = {
            "project": project_id,
            "title": title,
            "url": url,
            "description": description,
        }
        response = self.session.post("ml/", json=payload)
        raise_for_status_with_logging(response)

    def upgrade_project(
        self,
        upgrade_task_fun: Callable[[LabelStudioExistingTask], LabelStudioExistingTask],
        project_id: Optional[int] = None,
    ):
        """Apply a function to every task of a project and patch the result.

        All tasks are downloaded first, then each one is passed through
        ``upgrade_task_fun`` and sent back with ``patch_task``.

        Args:
            upgrade_task_fun: Maps an existing task to its upgraded version.
            project_id: Project to upgrade; defaults to the current project.
        """
        project_id = self._default_to_current_project(project_id)
        logger.info(f"Upgrading project ID = {project_id}...")
        # Step 1: download all tasks
        logger.info("Downloading task data...")
        tasks = list(self.get_tasks(project_id=project_id))
        logger.info(f"{len(tasks):,} tasks downloaded!")
        # Step 2: upgrade and patch tasks
        logger.info("Patching tasks...")
        for task in tqdm(tasks):
            self.patch_task(upgrade_task_fun(task))
        logger.info("Patching successful.")
        logger.info(f"Project ID = {project_id} upgraded.")

    def export_annotated_tasks_for_project_ids(
        self,
        project_ids: List[int],
        task_data_type: TaskDataType,
        include_cancelled: bool = True,
    ) -> LabelStudioProjectsData:
        """Export several projects, with their tasks, into one container.

        Example with type annotation:
            data: LabelStudioProjectsData[Transaction]
            data = client.export_annotated_tasks_for_project_ids(
                project_ids=[10, 11], task_data_type=Transaction
            )

        Args:
            project_ids: IDs of the projects to export.
            task_data_type: Type used to parametrize the returned containers.
            include_cancelled: When true, both export variants
                (``cancelled=True`` and ``cancelled=False``) are concatenated.
        """
        projects_data = LabelStudioProjectsData[task_data_type]()
        for project_id in project_ids:
            project = self.get_project(project_id)
            if include_cancelled:
                # NOTE(review): if the ``cancelled=True`` export is a superset
                # of the ``cancelled=False`` one, this list contains
                # duplicates - confirm against the server's export semantics.
                tasks = list(self.export_annotated_tasks(project_id, cancelled=True))
                tasks.extend(self.export_annotated_tasks(project_id, cancelled=False))
            else:
                tasks = list(self.export_annotated_tasks(project_id, cancelled=False))
            project_data = LabelStudioProjectData[task_data_type](
                project=project,
                tasks=tasks,
            )
            projects_data.projects.append(project_data)
        return projects_data