/restricttotopic

Validator for Guardrails Hub that checks whether a text is related to a given topic.

Overview

Developed by: Tryolabs
Date of development: Feb 15, 2024
Validator type: Format
Blog: -
License: Apache 2
Input/Output: Output

Description

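This validator checks whether a text is about at least one of the allowed topics (valid_topics) and about none of the forbidden ones (invalid_topics). Topics are detected with a Zero-Shot classification model, an LLM, or both, as controlled by disable_classifier and disable_llm.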
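The pass/fail rule implied by the valid_topics and invalid_topics parameters can be sketched as a pure function. This is a simplified illustration, not the validator's source: passes_topic_check is a hypothetical name, and it assumes topic detection (by the classifier and/or the LLM) has already produced a list of topics found in the text.

```python
from typing import Iterable


def passes_topic_check(
    found_topics: Iterable[str],
    valid_topics: Iterable[str],
    invalid_topics: Iterable[str] = (),
) -> bool:
    """Hypothetical helper: decide pass/fail from topics already detected in a text.

    Assumes detection (Zero-Shot classifier and/or LLM) has already run.
    """
    found = set(found_topics)
    # Fail as soon as the text touches any forbidden topic.
    if found & set(invalid_topics):
        return False
    # Otherwise pass only if at least one allowed topic was detected.
    return bool(found & set(valid_topics))


print(passes_topic_check(["sports"], ["sports"], ["music"]))  # True
print(passes_topic_check(["music"], ["sports"], ["music"]))   # False
```

In this sketch a forbidden topic always wins, so a text detected as both "sports" and "music" fails, matching the description of invalid_topics as topics the text cannot be about.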

Requirements

  • Dependencies:

    • guardrails-ai>=0.4.0
    • tenacity>=8.1.0
    • transformers>=4.11.3
    • torch>=2.1.1
  • Foundation model access keys:

    • OPENAI_API_KEY
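Because the LLM path defaults to an OpenAI model (see llm_callable below), a script might check for the key before deciding whether to enable that path. The sketch below is a hypothetical pre-flight helper, not part of the validator; should_disable_llm is an illustrative name whose result could be passed as disable_llm.

```python
import os
from typing import Mapping, Optional


def should_disable_llm(env: Optional[Mapping[str, str]] = None) -> bool:
    """Hypothetical check: disable the LLM path when no OpenAI key is available.

    Assumes the LLM path uses an OpenAI model (the documented default).
    """
    env = os.environ if env is None else env
    # A missing or empty key means the OpenAI-backed LLM cannot be called.
    return not env.get("OPENAI_API_KEY")


print(should_disable_llm({"OPENAI_API_KEY": "sk-placeholder"}))  # False
print(should_disable_llm({}))                                    # True
```

If this helper can return True, keep disable_classifier=False, since at least one of the two detection paths must stay enabled.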

Installation

guardrails hub install hub://tryolabs/restricttotopic

Usage Examples

Validating string output via Python

In this example, we apply the validator to a string output generated by an LLM.

# Import Guard and Validator
from guardrails.hub import RestrictToTopic
from guardrails import Guard

# Setup Guard
guard = Guard().use(
    RestrictToTopic(
        valid_topics=["sports"],
        invalid_topics=["music"],
        disable_classifier=True,
        disable_llm=False,
        on_fail="exception"
    )
)

guard.validate("""
In Super Bowl LVII in 2023, the Chiefs clashed with the Philadelphia Eagles in a fiercely contested battle, ultimately emerging victorious with a score of 38-35.
""")  # Validator passes

guard.validate("""
The Beatles were a charismatic English pop-rock band of the 1960s.
""")  # Validator fails; with on_fail="exception" this call raises an exception

Validating JSON output via Python

In this example, we apply the validator to a string field of a JSON output generated by an LLM.

# Import Guard and Validator
from pydantic import BaseModel, Field
from guardrails.hub import RestrictToTopic
from guardrails import Guard

# Initialize Validator
val = RestrictToTopic(
    valid_topics=["sports"],
    disable_classifier=True,
    disable_llm=False,
    on_fail="exception"
)

# Create Pydantic BaseModel
class TopicSummary(BaseModel):
    topic: str
    summary: str = Field(validators=[val])

# Create a Guard to check for valid Pydantic output
guard = Guard.from_pydantic(output_class=TopicSummary)

# Run LLM output generating JSON through guard
guard.parse("""
{
    "topic": "Super Bowl LVII",
    "summary": "In Super Bowl LVII in 2023, the Chiefs clashed with the Philadelphia Eagles in a fiercely contested battle, ultimately emerging victorious with a score of 38-35."
}
""")  # Validator passes
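To see which part of that JSON the validator actually inspects, the snippet below uses only the standard json module (it does not reproduce Guardrails internals). In the TopicSummary model only the summary field carries validators=[val], so only that string is topic-checked; topic has no validator attached.

```python
import json

# The same JSON string passed to guard.parse(...) in the example above.
llm_output = """
{
    "topic": "Super Bowl LVII",
    "summary": "In Super Bowl LVII in 2023, the Chiefs clashed with the Philadelphia Eagles in a fiercely contested battle, ultimately emerging victorious with a score of 38-35."
}
"""

data = json.loads(llm_output)

# Only "summary" has RestrictToTopic attached; "topic" has no validators.
text_checked_by_restrict_to_topic = data["summary"]
print(text_checked_by_restrict_to_topic)
```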

API Reference

__init__(self, on_fail="noop")

    Initializes a new instance of the RestrictToTopic class.

    Parameters

    • valid_topics (List[str]): topics that the text should be about (one or many).
    • invalid_topics (List[str]): topics that the text cannot be about. Defaults to [].
    • device (int): Device ordinal on which the Zero-Shot classifier runs. Set it to -1 to use the CPU; a non-negative value runs the Zero-Shot model on the CUDA device with that id. Defaults to -1.
    • model (str): The Zero-Shot model that will be used to classify the topic. See a list of all models here: https://huggingface.co/models?pipeline_tag=zero-shot-classification. Defaults to facebook/bart-large-mnli.
    • llm_callable (Union[str, Callable, None]): Either the name of the OpenAI model, or a callable that takes a prompt and returns a response. Defaults to gpt-3.5-turbo.
    • disable_classifier (bool): Controls whether to use the Zero-Shot model. At least one of disable_classifier and disable_llm must be False. Defaults to False.
    • disable_llm (bool): Controls whether to use the LLM fallback. At least one of disable_classifier and disable_llm must be False. Defaults to False.
    • model_threshold (float): The threshold used to determine whether to accept a topic from the Zero-Shot model. Must be a number between 0 and 1. Defaults to 0.5.
    • on_fail (str, Callable): The policy to enact when a validator fails. If str, must be one of reask, fix, filter, refrain, noop, exception or fix_reask. Otherwise, must be a function that is called when the validator fails.
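The constraints on disable_classifier, disable_llm and model_threshold, and the way a threshold turns classifier scores into accepted topics, can be sketched as follows. Both functions are hypothetical illustrations rather than the validator's code. The sketch assumes per-topic scores are already available (for example from a Hugging Face zero-shot-classification pipeline with multi_label=True, though the README does not say which mode the validator uses) and that a score equal to the threshold counts as a match.

```python
from typing import Dict, List


def check_config(disable_classifier: bool, disable_llm: bool, model_threshold: float) -> None:
    """Hypothetical pre-check mirroring the documented parameter constraints."""
    if disable_classifier and disable_llm:
        raise ValueError("At least one of disable_classifier and disable_llm must be False.")
    if not 0 <= model_threshold <= 1:
        raise ValueError("model_threshold must be between 0 and 1.")


def topics_above_threshold(scores: Dict[str, float], model_threshold: float = 0.5) -> List[str]:
    """Keep topics whose Zero-Shot score reaches the threshold.

    Assumption: a score equal to the threshold is accepted.
    """
    return [topic for topic, score in scores.items() if score >= model_threshold]


check_config(disable_classifier=False, disable_llm=True, model_threshold=0.5)
print(topics_above_threshold({"sports": 0.92, "music": 0.08}))  # ['sports']
```

Raising the threshold makes the classifier stricter about what it accepts as a topic, which in turn makes it harder for a text to match any valid topic.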

validate(self, value, metadata) → ValidationResult

    Validates the given `value` using the rules defined in this validator, relying on the `metadata` provided to customize the validation process. This method is automatically invoked by `guard.parse(...)`, ensuring the validation logic is applied to the input data.

    Note:

    1. This method should not be called directly by the user. Instead, invoke guard.parse(...) where this method will be called internally for each associated Validator.
    2. When invoking guard.parse(...), make sure to pass a metadata dictionary containing the keys and values this validator requires. If the guard has multiple validators, combine all of their required metadata into a single dictionary.
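Combining metadata for several validators amounts to merging their dictionaries. The keys below are hypothetical placeholders: RestrictToTopic itself needs no metadata, and "example_key" stands in for whatever another validator on the same guard might require.

```python
from typing import Any, Dict

# RestrictToTopic needs no metadata keys.
restrict_to_topic_metadata: Dict[str, Any] = {}
# Hypothetical metadata for some other validator attached to the same guard.
other_validator_metadata: Dict[str, Any] = {"example_key": "example_value"}

# Merge into the single dictionary passed as metadata to guard.parse(...).
metadata = {**restrict_to_topic_metadata, **other_validator_metadata}
print(metadata)  # {'example_key': 'example_value'}
```

With this unpacking syntax, a key present in both dictionaries takes the value from the later one, so validators that share a key name must agree on its value.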

    Parameters

    • value (Any): The input value to validate.
    • metadata (dict): A dictionary containing metadata required for validation. No additional metadata keys are needed for this validator.