HumanSignal/label-studio-sdk

Attached is our (soon-to-be-deprecated) internally developed Label Studio (LS) client

ynouri opened this issue · 1 comment

As a follow-up to #12, here's our internally developed LS client (tested against LS 1.0). I understand this is unsolicited and you probably don't need it, but I thought it was worth sharing nonetheless in case you are curious.

Obviously our plan is to migrate to label-studio-sdk whenever possible.

There is no actual issue here, just using GH Issues as a way to share this, so feel free to close :)

import logging
from functools import cached_property
from typing import Callable, Generator, List, Optional

from requests_toolbelt.sessions import BaseUrlSession
from tqdm.auto import tqdm

from .models import (
    LabelStudioExistingProject,
    LabelStudioExistingTask,
    LabelStudioNewProject,
    LabelStudioNewTask,
    LabelStudioProjectData,
    LabelStudioProjectsData,
    TaskDataType,
)

class LabelStudioClient:
    """Thin client for the Label Studio REST API (tested against LS 1.0).

    The client remembers the id of the last project it created in
    ``current_project``. Most methods accept an optional ``project_id`` and
    fall back to that value when it is omitted.

    Uses the module-level ``logger``, the ``LABEL_STUDIO_BASE_URL`` mapping
    (indexed by environment name) and the ``raise_for_status_with_logging``
    helper.
    """

    def __init__(self, env: str, token: str):
        """Create a client for one Label Studio environment.

        Args:
            env: Key into ``LABEL_STUDIO_BASE_URL`` selecting the server URL.
            token: API token, sent as ``Authorization: Token <token>``.
        """
        self.base_url = LABEL_STUDIO_BASE_URL[env]
        self.token = token
        # Id of the most recently created project; used as the default
        # project for methods called without an explicit ``project_id``.
        self.current_project: Optional[int] = None

    @cached_property
    def session(self) -> BaseUrlSession:
        """HTTP session bound to ``base_url`` and carrying the auth header.

        Created lazily on first access and reused for every later request.
        """
        logger.info("Creating Label Studio session.")
        session = BaseUrlSession(base_url=self.base_url)
        # Update rather than replace the headers: assigning a plain dict would
        # drop requests' defaults (User-Agent, Accept-Encoding, ...) and the
        # case-insensitive header mapping.
        session.headers.update(
            {
                "Content-Type": "application/json",
                "Authorization": f"Token {self.token}",
            }
        )
        return session

    def log_current_user(self):
        """Log the e-mail address of the user the API token belongs to."""
        response = self.session.get("current-user/whoami")
        # Check the status first so an auth failure is reported clearly
        # instead of surfacing as a KeyError on "email".
        raise_for_status_with_logging(response)
        current_user = response.json()["email"]
        logger.info(f"Using user = {current_user}")

    def create_project(self, project: LabelStudioNewProject) -> int:
        """Create a project, make it the current project and return its id.

        Args:
            project: Project definition, serialized with its ``.json()``.

        Returns:
            The id assigned by Label Studio to the new project.
        """
        logger.info("Creating project...")
        response = self.session.post("projects/", data=project.json())
        # Check the status first so an error response is reported via the
        # shared helper rather than as a KeyError on "id".
        raise_for_status_with_logging(response)
        created_project = response.json()
        self.current_project = created_project["id"]
        logger.info(f"Created project with id = {self.current_project}")
        return self.current_project

    def get_project(
        self, project_id: Optional[int] = None
    ) -> LabelStudioExistingProject:
        """Fetch a project's details.

        Args:
            project_id: Project to fetch; defaults to the current project.

        Returns:
            The parsed project.
        """
        project_id = self._default_to_current_project(project_id)
        logger.info(f"Getting project details for project ID = {project_id}")
        response = self.session.get(f"projects/{project_id}")
        raise_for_status_with_logging(response)
        return LabelStudioExistingProject.parse_obj(response.json())

    def delete_project(self, project_id: Optional[int] = None):
        """Delete a project (defaults to the current project).

        If the deleted project is the current project, ``current_project`` is
        reset to ``None``. Later calls without an explicit id then fail fast
        instead of targeting a project that no longer exists.
        """
        project_id = self._default_to_current_project(project_id)
        logger.info("Deleting project...")
        response = self.session.delete(f"projects/{project_id}")
        raise_for_status_with_logging(response)
        if self.current_project == project_id:
            self.current_project = None
        logger.info("Project deleted.")

    def _default_to_current_project(self, project_id: Optional[int]) -> int:
        """Return ``project_id``, or the current project when it is ``None``.

        Raises:
            ValueError: If ``project_id`` is ``None`` and no current project
                has been set.
        """
        if project_id is not None:
            return project_id
        if self.current_project is None:
            msg = "Current project is not initialized, please specify a project ID"
            raise ValueError(msg)
        logger.debug(f"Using current project id = {self.current_project}")
        return self.current_project

    def create_task(self, task: LabelStudioNewTask):
        """Create a single task.

        Mutates ``task.project``, filling it with the current project when
        it is ``None``.
        """
        task.project = self._default_to_current_project(task.project)
        logger.debug("Creating task.")
        response = self.session.post("tasks/", data=task.json())
        raise_for_status_with_logging(response)

    def patch_task(self, task: LabelStudioExistingTask):
        """Send ``task`` as a PATCH to ``tasks/<task.id>``.

        Mutates ``task.project``, filling it with the current project when
        it is ``None``.
        """
        task.project = self._default_to_current_project(task.project)
        logger.debug("Patching existing task.")
        response = self.session.patch(f"tasks/{task.id}", data=task.json())
        raise_for_status_with_logging(response)

    def import_tasks(
        # String annotation: LabelStudioTaskList is not imported at the top of
        # the module, and an unquoted name would raise NameError when the
        # class body is executed.
        self, tasks: "LabelStudioTaskList", project_id: Optional[int] = None
    ):
        """Bulk-import ``tasks`` (serialized via ``.json()``) into a project."""
        project_id = self._default_to_current_project(project_id)
        logger.info("Importing tasks in bulk...")
        route = f"projects/{project_id}/tasks/bulk/"
        response = self.session.post(route, data=tasks.json())
        raise_for_status_with_logging(response)
        logger.info("Done!")

    def export_annotated_tasks(
        self, project_id: Optional[int] = None, cancelled: bool = True
    ) -> Generator[LabelStudioExistingTask, None, None]:
        """Yield tasks from the project's JSON export.

        Args:
            project_id: Project to export; defaults to the current project.
            cancelled: When True, adds ``download_all_tasks=true`` to the
                export request. The author relies on this to also obtain
                cancelled tasks; without it the export is expected to contain
                annotated tasks only.

        Yields:
            Each exported task, parsed. The request is only sent once
            iteration starts, because this is a generator.
        """
        project_id = self._default_to_current_project(project_id)
        logger.info("Exporting annotated tasks in bulk...")
        params = {"exportType": "JSON"}
        if cancelled:
            params["download_all_tasks"] = "true"
        response = self.session.get(f"projects/{project_id}/export", params=params)
        raise_for_status_with_logging(response)
        for task in response.json():
            yield LabelStudioExistingTask.parse_obj(task)

    def get_tasks(
        self, project_id: Optional[int] = None, page_size: int = 100
    ) -> Generator[LabelStudioExistingTask, None, None]:
        """Yield every task of a project, one page at a time.

        Pagination stops at the first non-200 response or empty page.
        """
        project_id = self._default_to_current_project(project_id)
        page_id = 1  # Label studio seems to start pagination at 1.
        route = f"projects/{project_id}/tasks"
        logger.info(f"Using a page size = {page_size}")
        while True:
            params = dict(page=page_id, page_size=page_size)
            logger.debug(f"Retrieving tasks for page ID = {page_id}")
            response = self.session.get(route, params=params)
            # Stop if error code. Requesting a page past the end presumably
            # yields 404 (Django REST framework pagination). Any other status
            # (auth or server errors) also stops iteration, but is flagged so
            # it is not silently mistaken for "no more pages".
            if response.status_code != 200:
                if response.status_code != 404:
                    logger.warning(
                        f"Unexpected status code {response.status_code} "
                        f"while paginating tasks for project ID = {project_id}."
                    )
                logger.info("Received non 200 error code, assuming no pages left.")
                break
            tasks = response.json()
            # Stop if tasks list empty
            if not tasks:
                logger.info("Received empty list, assuming no pages left.")
                break
            for task in tasks:
                yield LabelStudioExistingTask.parse_obj(task)
            page_id += 1

    def update_project_label_config(
        self, label_config_xml: str, project_id: Optional[int] = None
    ):
        """Replace a project's labeling configuration with ``label_config_xml``."""
        project_id = self._default_to_current_project(project_id)
        payload = dict(label_config=label_config_xml)
        logger.info("Updating project label configuration...")
        response = self.session.patch(f"projects/{project_id}/", json=payload)
        raise_for_status_with_logging(response)

    def create_ml_backend(
        self,
        title: str,
        url: str,
        description: str = "",
        project_id: Optional[int] = None,
    ):
        """Register an ML backend reachable at ``url`` for a project."""
        project_id = self._default_to_current_project(project_id)
        payload = dict(
            project=project_id,
            title=title,
            url=url,
            description=description,
        )
        response = self.session.post("ml/", json=payload)
        raise_for_status_with_logging(response)

    def upgrade_project(
        self,
        upgrade_task_fun: Callable[[LabelStudioExistingTask], LabelStudioExistingTask],
        project_id: Optional[int] = None,
    ):
        """Apply ``upgrade_task_fun`` to every task of a project and PATCH it back.

        All tasks are downloaded before any patch is sent. Presumably this
        keeps patches from interfering with page-based pagination.
        """
        project_id = self._default_to_current_project(project_id)
        logger.info(f"Upgrading project ID = {project_id}...")

        # Step 1: download all tasks
        logger.info("Downloading task data...")
        tasks = list(self.get_tasks(project_id=project_id))
        logger.info(f"{len(tasks):,} tasks downloaded!")

        # Step 2: upgrade and patch tasks
        logger.info("Patching tasks...")
        for task in tqdm(tasks):
            self.patch_task(upgrade_task_fun(task))
        logger.info("Patching successful.")
        logger.info(f"Project ID = {project_id} upgraded.")

    def export_annotated_tasks_for_project_ids(
        self,
        project_ids: List[int],
        task_data_type: TaskDataType,
        include_cancelled: bool = True,
    ) -> LabelStudioProjectsData:
        """Export project details and annotated tasks for several projects.

        When ``include_cancelled`` is True, the ``download_all_tasks`` export
        is merged with the annotated-only export. Tasks are deduplicated by
        ``id``, because the first export may already contain every task from
        the second.

        Example with type annotation:
            data: LabelStudioProjectsData[Transaction]
            data = client.export_annotated_tasks_for_project_ids(
                project_ids=[10, 11], task_data_type=Transaction
            )
        """
        projects_data = LabelStudioProjectsData[task_data_type]()
        for project_id in project_ids:
            project = self.get_project(project_id)
            tasks = list(
                self.export_annotated_tasks(project_id, cancelled=include_cancelled)
            )
            if include_cancelled:
                seen_ids = {task.id for task in tasks}
                tasks.extend(
                    task
                    for task in self.export_annotated_tasks(project_id, cancelled=False)
                    if task.id not in seen_ids
                )
            project_data = LabelStudioProjectData[task_data_type](
                project=project,
                tasks=tasks,
            )
            projects_data.projects.append(project_data)
        return projects_data

@ynouri Thank you very much for sharing this! We will definitely look into it and may borrow some ideas.