[Feature]: Groq Agent
Arunprakaash opened this issue · 7 comments
Arunprakaash commented
Brief Description
Add support for a ChatGroq agent.
Rationale
- Faster streaming responses
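
For context, a minimal sketch of the streaming call this rides on (assumes the `groq` SDK is installed and `GROQ_API_KEY` is set; not part of the suggested change):

```python
import asyncio

from groq import AsyncGroq


async def demo() -> None:
    client = AsyncGroq()  # reads GROQ_API_KEY from the environment
    stream = await client.chat.completions.create(
        model="mixtral-8x7b-32768",
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )
    async for chunk in stream:
        # delta.content can be None on the final chunk
        print(chunk.choices[0].delta.content or "", end="", flush=True)


asyncio.run(demo())
```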
Suggested Implementation
`vocode/streaming/agent/groq_agent.py`

```python
import logging
from typing import AsyncGenerator, Optional, Tuple

from langchain import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain.schema import AIMessage, ChatMessage, HumanMessage
from langchain_groq import ChatGroq

from vocode import getenv
from vocode.streaming.agent.base_agent import RespondAgent
from vocode.streaming.agent.utils import get_sentence_from_buffer
from vocode.streaming.models.agent import ChatGroqAgentConfig

SENTENCE_ENDINGS = [".", "!", "?"]


class ChatGroqAgent(RespondAgent[ChatGroqAgentConfig]):
    def __init__(
        self,
        agent_config: ChatGroqAgentConfig,
        logger: Optional[logging.Logger] = None,
        groq_api_key: Optional[str] = None,
    ):
        super().__init__(agent_config=agent_config, logger=logger)
        from groq import AsyncGroq

        groq_api_key = groq_api_key or getenv("GROQ_API_KEY")
        if not groq_api_key:
            raise ValueError(
                "GROQ_API_KEY must be set in environment or passed in"
            )
        self.prompt = ChatPromptTemplate.from_messages(
            [
                MessagesPlaceholder(variable_name="history"),
                HumanMessagePromptTemplate.from_template("{input}"),
            ]
        )
        self.llm = ChatGroq(
            model_name=agent_config.model_name,
            groq_api_key=groq_api_key,
        )
        # The raw async client is only needed for the streaming path
        self.groq_client = (
            AsyncGroq(api_key=groq_api_key) if agent_config.generate_responses else None
        )
        self.memory = ConversationBufferMemory(return_messages=True)
        self.memory.chat_memory.messages.append(
            HumanMessage(content=self.agent_config.prompt_preamble)
        )
        if agent_config.initial_message:
            self.memory.chat_memory.messages.append(
                AIMessage(content=agent_config.initial_message.text)
            )
        self.conversation = ConversationChain(
            memory=self.memory, prompt=self.prompt, llm=self.llm
        )

    async def respond(
        self,
        human_input,
        conversation_id: str,
        is_interrupt: bool = False,
    ) -> Tuple[str, bool]:
        text = await self.conversation.apredict(input=human_input)
        self.logger.debug(f"LLM response: {text}")
        return text, False

    async def generate_response(
        self,
        human_input,
        conversation_id: str,
        is_interrupt: bool = False,
    ) -> AsyncGenerator[Tuple[str, bool], None]:
        self.memory.chat_memory.messages.append(HumanMessage(content=human_input))
        bot_memory_message = AIMessage(content="")
        self.memory.chat_memory.messages.append(bot_memory_message)
        # Reuse the LLM's message conversion to build the Groq request payload
        prompt = self.llm._create_message_dicts(self.memory.chat_memory.messages, None)[0]
        if self.groq_client:
            streamed_response = await self.groq_client.chat.completions.create(
                messages=prompt,
                model=self.agent_config.model_name,
                stream=True,
                max_tokens=self.agent_config.max_tokens_to_sample,
                stop=None,
            )
            buffer = ""
            async for completion in streamed_response:
                # delta.content can be None on the final chunk
                buffer += completion.choices[0].delta.content or ""
                sentence, remainder = get_sentence_from_buffer(buffer)
                if sentence:
                    bot_memory_message.content = bot_memory_message.content + sentence
                    buffer = remainder
                    yield sentence, True
            # Flush any trailing partial sentence left in the buffer
            if buffer:
                bot_memory_message.content = bot_memory_message.content + buffer
                yield buffer, True

    def update_last_bot_message_on_cut_off(self, message: str):
        for memory_message in self.memory.chat_memory.messages[::-1]:
            if (
                isinstance(memory_message, ChatMessage)
                and memory_message.role == "assistant"
            ) or isinstance(memory_message, AIMessage):
                memory_message.content = message
                return
```
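
One caveat on the snippet above: `_create_message_dicts` is a private LangChain method, so it may break between releases. A hedged alternative that builds the Groq payload by hand (the helper name is hypothetical; the role mapping is an assumption based on LangChain's message `type` values):

```python
from typing import List

from langchain.schema import BaseMessage


def to_groq_messages(messages: List[BaseMessage]) -> List[dict]:
    """Convert LangChain messages to Groq-style role/content dicts."""
    role_map = {"human": "user", "ai": "assistant", "system": "system"}
    return [
        {"role": role_map.get(message.type, "user"), "content": message.content}
        for message in messages
    ]
```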
`vocode/streaming/models/agent.py`

```python
from enum import Enum
from typing import List, Optional, Union

from langchain.prompts import PromptTemplate
from pydantic import validator

from vocode.streaming.models.actions import ActionConfig
from vocode.streaming.models.message import BaseMessage

from .model import TypedModel, BaseModel
from .vector_db import VectorDBConfig

FILLER_AUDIO_DEFAULT_SILENCE_THRESHOLD_SECONDS = 0.5
LLM_AGENT_DEFAULT_TEMPERATURE = 1.0
LLM_AGENT_DEFAULT_MAX_TOKENS = 256
LLM_AGENT_DEFAULT_MODEL_NAME = "text-curie-001"
CHAT_GPT_AGENT_DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
ACTION_AGENT_DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
CHAT_ANTHROPIC_DEFAULT_MODEL_NAME = "claude-v1"
CHAT_VERTEX_AI_DEFAULT_MODEL_NAME = "chat-bison@001"
AZURE_OPENAI_DEFAULT_API_TYPE = "azure"
AZURE_OPENAI_DEFAULT_API_VERSION = "2023-03-15-preview"
AZURE_OPENAI_DEFAULT_ENGINE = "gpt-35-turbo"
CHAT_GROQ_DEFAULT_MODEL_NAME = "mixtral-8x7b-32768"


class AgentType(str, Enum):
    BASE = "agent_base"
    LLM = "agent_llm"
    CHAT_GPT_ALPHA = "agent_chat_gpt_alpha"
    CHAT_GPT = "agent_chat_gpt"
    CHAT_ANTHROPIC = "agent_chat_anthropic"
    CHAT_GROQ = "agent_chat_groq"
    CHAT_VERTEX_AI = "agent_chat_vertex_ai"
    ECHO = "agent_echo"
    GPT4ALL = "agent_gpt4all"
    LLAMACPP = "agent_llamacpp"
    INFORMATION_RETRIEVAL = "agent_information_retrieval"
    RESTFUL_USER_IMPLEMENTED = "agent_restful_user_implemented"
    WEBSOCKET_USER_IMPLEMENTED = "agent_websocket_user_implemented"
    ACTION = "agent_action"


class FillerAudioConfig(BaseModel):
    silence_threshold_seconds: float = FILLER_AUDIO_DEFAULT_SILENCE_THRESHOLD_SECONDS
    use_phrases: bool = True
    use_typing_noise: bool = False

    @validator("use_typing_noise")
    def typing_noise_excludes_phrases(cls, v, values):
        if v and values.get("use_phrases"):
            values["use_phrases"] = False
        if not v and not values.get("use_phrases"):
            raise ValueError("must use either typing noise or phrases for filler audio")
        return v


class WebhookConfig(BaseModel):
    url: str


class AzureOpenAIConfig(BaseModel):
    api_type: str = AZURE_OPENAI_DEFAULT_API_TYPE
    api_version: Optional[str] = AZURE_OPENAI_DEFAULT_API_VERSION
    engine: str = AZURE_OPENAI_DEFAULT_ENGINE


class AgentConfig(TypedModel, type=AgentType.BASE.value):
    initial_message: Optional[BaseMessage] = None
    generate_responses: bool = True
    allowed_idle_time_seconds: Optional[float] = None
    allow_agent_to_be_cut_off: bool = True
    end_conversation_on_goodbye: bool = False
    send_filler_audio: Union[bool, FillerAudioConfig] = False
    webhook_config: Optional[WebhookConfig] = None
    track_bot_sentiment: bool = False
    actions: Optional[List[ActionConfig]] = None


class CutOffResponse(BaseModel):
    messages: List[BaseMessage] = [BaseMessage(text="Sorry?")]


class LLMAgentConfig(AgentConfig, type=AgentType.LLM.value):
    prompt_preamble: str
    expected_first_prompt: Optional[str] = None
    model_name: str = LLM_AGENT_DEFAULT_MODEL_NAME
    temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE
    max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS
    cut_off_response: Optional[CutOffResponse] = None


class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT.value):
    prompt_preamble: str
    expected_first_prompt: Optional[str] = None
    model_name: str = CHAT_GPT_AGENT_DEFAULT_MODEL_NAME
    temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE
    max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS
    cut_off_response: Optional[CutOffResponse] = None
    azure_params: Optional[AzureOpenAIConfig] = None
    vector_db_config: Optional[VectorDBConfig] = None


class ChatAnthropicAgentConfig(AgentConfig, type=AgentType.CHAT_ANTHROPIC.value):
    prompt_preamble: str
    model_name: str = CHAT_ANTHROPIC_DEFAULT_MODEL_NAME
    max_tokens_to_sample: int = 200


class ChatGroqAgentConfig(AgentConfig, type=AgentType.CHAT_GROQ.value):
    prompt_preamble: str
    model_name: str = CHAT_GROQ_DEFAULT_MODEL_NAME
    max_tokens_to_sample: int = 200
    generate_responses: bool = True


class ChatVertexAIAgentConfig(AgentConfig, type=AgentType.CHAT_VERTEX_AI.value):
    prompt_preamble: str
    model_name: str = CHAT_VERTEX_AI_DEFAULT_MODEL_NAME
    generate_responses: bool = False  # Google Vertex AI doesn't support streaming


class LlamacppAgentConfig(AgentConfig, type=AgentType.LLAMACPP.value):
    prompt_preamble: str
    llamacpp_kwargs: dict = {}
    prompt_template: Optional[Union[PromptTemplate, str]] = None


class InformationRetrievalAgentConfig(
    AgentConfig, type=AgentType.INFORMATION_RETRIEVAL.value
):
    recipient_descriptor: str
    caller_descriptor: str
    goal_description: str
    fields: List[str]
    # TODO: add fields for IVR, voicemail


class EchoAgentConfig(AgentConfig, type=AgentType.ECHO.value):
    pass


class GPT4AllAgentConfig(AgentConfig, type=AgentType.GPT4ALL.value):
    prompt_preamble: str
    model_path: str
    generate_responses: bool = False


class RESTfulUserImplementedAgentConfig(
    AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED.value
):
    class EndpointConfig(BaseModel):
        url: str
        method: str = "POST"

    respond: EndpointConfig
    generate_responses: bool = False
    # generate_response: Optional[EndpointConfig]


class RESTfulAgentInput(BaseModel):
    conversation_id: str
    human_input: str


class RESTfulAgentOutputType(str, Enum):
    BASE = "restful_agent_base"
    TEXT = "restful_agent_text"
    END = "restful_agent_end"


class RESTfulAgentOutput(TypedModel, type=RESTfulAgentOutputType.BASE):
    pass


class RESTfulAgentText(RESTfulAgentOutput, type=RESTfulAgentOutputType.TEXT):
    response: str


class RESTfulAgentEnd(RESTfulAgentOutput, type=RESTfulAgentOutputType.END):
    pass
```
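
One wiring step the snippets above don't show: if agents are created through a factory rather than by constructing `ChatGroqAgent` directly, the new config type needs a branch there too. A hypothetical sketch (the factory shape and function name are assumptions, not vocode's actual code):

```python
import logging
from typing import Optional

from vocode.streaming.agent.base_agent import BaseAgent
from vocode.streaming.agent.groq_agent import ChatGroqAgent
from vocode.streaming.models.agent import AgentConfig, ChatGroqAgentConfig


def create_agent(
    agent_config: AgentConfig, logger: Optional[logging.Logger] = None
) -> BaseAgent:
    # Dispatch on config type, mirroring how other agents would be wired up
    if isinstance(agent_config, ChatGroqAgentConfig):
        return ChatGroqAgent(agent_config=agent_config, logger=logger)
    raise Exception(f"Unsupported agent config: {agent_config.type}")
```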
Streaming conversation usage of the Groq agent:
```python
import asyncio
import logging
import signal

from dotenv import load_dotenv

from vocode.streaming.agent.groq_agent import ChatGroqAgent

load_dotenv()

from vocode.streaming.streaming_conversation import StreamingConversation
from vocode.helpers import create_streaming_microphone_input_and_speaker_output
from vocode.streaming.transcriber import *
from vocode.streaming.agent import *
from vocode.streaming.synthesizer import *
from vocode.streaming.models.transcriber import *
from vocode.streaming.models.agent import *
from vocode.streaming.models.synthesizer import *
from vocode.streaming.models.message import BaseMessage

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


async def main():
    (
        microphone_input,
        speaker_output,
    ) = create_streaming_microphone_input_and_speaker_output(
        use_default_devices=False,
        logger=logger,
        use_blocking_speaker_output=True,  # moves playback to a separate thread; set to False to use the main thread
    )

    conversation = StreamingConversation(
        output_device=speaker_output,
        transcriber=DeepgramTranscriber(
            DeepgramTranscriberConfig.from_input_device(
                microphone_input,
                endpointing_config=PunctuationEndpointingConfig(),
            )
        ),
        agent=ChatGroqAgent(
            ChatGroqAgentConfig(
                initial_message=BaseMessage(text="What up"),
                prompt_preamble="""The AI is having a pleasant conversation about life""",
            )
        ),
        synthesizer=AzureSynthesizer(
            AzureSynthesizerConfig.from_output_device(speaker_output)
        ),
        logger=logger,
    )
    await conversation.start()
    print("Conversation started, press Ctrl+C to end")
    signal.signal(
        signal.SIGINT, lambda _0, _1: asyncio.create_task(conversation.terminate())
    )
    while conversation.is_active():
        chunk = await microphone_input.get_audio()
        conversation.receive_audio(chunk)


if __name__ == "__main__":
    asyncio.run(main())
```
spikecodes commented
I'd recommend opening this as a PR so changes can be reviewed and merged into the project
Kevin7744 commented
Hey @Arunprakaash, are you not getting errors from `ConversationChain` about `self.llm` not being a `BaseLanguageModel`?
Arunprakaash commented
I have not tested that yet. I have just replaced the ChatGPT agent with this one and got it working. Once I've figured out everything, I'll create a pull request.
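
For anyone who wants to check the chain wiring in isolation, here is a minimal sketch mirroring the agent's own prompt/memory setup (assumes `langchain-groq` is installed and `GROQ_API_KEY` is set):

```python
from langchain import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain_groq import ChatGroq

# Same prompt shape the agent builds in __init__
prompt = ChatPromptTemplate.from_messages(
    [
        MessagesPlaceholder(variable_name="history"),
        HumanMessagePromptTemplate.from_template("{input}"),
    ]
)
chain = ConversationChain(
    llm=ChatGroq(model_name="mixtral-8x7b-32768"),  # reads GROQ_API_KEY from env
    prompt=prompt,
    memory=ConversationBufferMemory(return_messages=True),
)
print(chain.predict(input="Hello"))
```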
github-actions commented
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Scylla2020 commented
This gives an error:

```
File "C:\Users\UserX\AppData\Local\Programs\Python\Python310\lib\site-packages\pydantic\deprecated\class_validators.py", line 249, in root_validator
    raise PydanticUserError(
pydantic.errors.PydanticUserError: If you use `@root_validator` with pre=False (the default) you MUST specify `skip_on_failure=True`. Note that `@root_validator` is deprecated and should be replaced with `@model_validator`.
```
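
That traceback is Pydantic v2's compatibility shim rejecting a v1-style `@root_validator`, so it's likely a dependency-version issue rather than something in this diff. The error message itself points at the fix: either pin `pydantic<2`, add `skip_on_failure=True` to the offending `@root_validator`, or migrate it to `@model_validator`. A sketch of the v2-style migration (illustrative model, not a patch against a specific file):

```python
from pydantic import BaseModel, model_validator


class ExampleFillerAudioConfig(BaseModel):
    use_phrases: bool = True
    use_typing_noise: bool = False

    # Pydantic v2 replacement for a v1 @root_validator
    @model_validator(mode="after")
    def typing_noise_excludes_phrases(self):
        if self.use_typing_noise and self.use_phrases:
            self.use_phrases = False
        if not self.use_typing_noise and not self.use_phrases:
            raise ValueError("must use either typing noise or phrases for filler audio")
        return self
```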