This project builds a system that uses two specialized LLMs, each fine-tuned on a different topic (sleep science and car history), together with a router that directs each query to the appropriate model based on the input question. The system uses FastAPI for the backend and Gradio for the user interface, and it processes requests asynchronously to improve performance.
multi_model_llm_system/
├── data/
│ ├── raw/
│ │ ├── sleep_science_qa.csv
│ │ ├── car_history_qa.csv
│ ├── processed/
│ │ ├── train_sleep.csv
│ │ ├── train_car.csv
│ │ ├── test_sleep.csv
│ │ ├── test_car.csv
├── scripts/
│ ├── data_preparation.py
│ ├── fine_tune_model.py
│ ├── query_router.py
│ ├── inference.py
│ ├── evaluate.py
├── notebooks/
│ ├── data_exploration.ipynb
│ ├── model_training.ipynb
│ ├── router_training.ipynb
│ ├── inference_tests.ipynb
├── requirements.txt
├── environment.yml
├── README.md
├── .env # Add your HF_TOKEN here
- Clone the repository:
git clone https://github.com/lucky-verma/LLM-Router.git
cd LLM-Router
- Install dependencies. Make sure you have conda installed. My setup uses CUDA 12.2 on Ubuntu 22.04.
conda env create -f environment.yml
conda activate webai
- Prepare data. Place your raw datasets in the data/raw/ directory, then run the data preparation script:
python -m scripts.data_preparation
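The preparation step turns each raw CSV in data/raw/ into the train/test files in data/processed/. A minimal sketch of such a split, assuming each raw file has a header row and that a shuffled 80/20 split is acceptable (the ratio, the seed, and the helper names `split_qa_rows`, `write_rows`, and `prepare_domain` are hypothetical, not taken from scripts/data_preparation.py):

```python
import csv
import random


def split_qa_rows(rows, test_fraction=0.2, seed=42):
    """Shuffle rows with a fixed seed and return (train_rows, test_rows)."""
    rows = list(rows)
    rng = random.Random(seed)  # local RNG: reproducible without touching global state
    rng.shuffle(rows)
    n_test = int(len(rows) * test_fraction)
    return rows[n_test:], rows[:n_test]


def write_rows(path, rows, fieldnames):
    """Write dict rows to a CSV file with a header line."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def prepare_domain(raw_path, train_path, test_path, test_fraction=0.2, seed=42):
    """Split one raw QA CSV (hypothetical helper) into train and test CSVs."""
    with open(raw_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        fieldnames = reader.fieldnames  # header columns, e.g. question/answer
    train, test = split_qa_rows(rows, test_fraction, seed)
    write_rows(train_path, train, fieldnames)
    write_rows(test_path, test, fieldnames)
    return len(train), len(test)
```

For example, `prepare_domain("data/raw/sleep_science_qa.csv", "data/processed/train_sleep.csv", "data/processed/test_sleep.csv")` would write the sleep-science split. A fixed seed keeps the test sets stable across runs, so evaluation results stay comparable.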
- Fine-tune models. Run the fine-tuning script:
python -m scripts.fine_tune_model
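Before fine-tuning, each QA pair has to be rendered as a single training text. A minimal sketch, assuming an Alpaca-style instruction template, `question`/`answer` columns, and Mistral's `</s>` end-of-sequence token (the template wording and the helper names are hypothetical; the actual format used by scripts/fine_tune_model.py may differ):

```python
# Hypothetical Alpaca-style template; the project's real prompt format may differ.
PROMPT_TEMPLATE = (
    "Below is a question about {domain}. Write an accurate answer.\n\n"
    "### Question:\n{question}\n\n"
    "### Answer:\n{answer}"
)


def format_example(question, answer, domain, eos_token="</s>"):
    """Render one QA pair as a training string terminated by the EOS token."""
    text = PROMPT_TEMPLATE.format(
        domain=domain, question=question.strip(), answer=answer.strip()
    )
    return text + eos_token


def format_dataset(rows, domain, eos_token="</s>"):
    """Format dict rows with 'question' and 'answer' keys (assumed column names)."""
    return [format_example(r["question"], r["answer"], domain, eos_token) for r in rows]
```

Appending the EOS token teaches the model where an answer ends, so generation stops instead of running on. With a Hugging Face tokenizer loaded, passing `tokenizer.eos_token` avoids hard-coding the string.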
- Run inference. Test the inference pipeline:
python -m scripts.inference
- Run evaluation. Test the evaluation pipeline:
python -m scripts.evaluate
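One way to evaluate the routing stage is to measure how often it picks the expected domain on held-out test questions, together with per-query latency. A minimal sketch, assuming the router is any callable that maps a question to a label such as "sleep" or "car" (the metric choice and the `evaluate_router` name are hypothetical; scripts/evaluate.py may measure different quantities):

```python
import statistics
import time


def evaluate_router(route_fn, labeled_questions):
    """Return (accuracy, mean_latency_ms) of route_fn over (question, expected_label) pairs."""
    correct = 0
    latencies_ms = []
    for question, expected in labeled_questions:
        start = time.perf_counter()
        predicted = route_fn(question)
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
        correct += predicted == expected  # bool counts as 0 or 1
    n = len(labeled_questions)
    accuracy = correct / n if n else 0.0
    mean_latency = statistics.fmean(latencies_ms) if latencies_ms else 0.0
    return accuracy, mean_latency
```

Pairing accuracy with latency makes the trade-off explicit: a heavier router may route more accurately but adds to every request's response time.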
- Start the FastAPI backend:
python main.py
- Start the Gradio user interface:
python gradio_app.py
Open the provided URL in your web browser to interact with the chat interface.
- Base Model: We selected Mistral 7B, loaded and fine-tuned through Unsloth, as our foundation. It offers a good balance between performance and efficiency, providing robust natural language understanding while keeping computational requirements reasonable.
- Query Router: For query classification, we used a zero-shot classification model (BART-large-mnli). Because the candidate domains are supplied as plain-text labels at inference time, queries can be routed to the appropriate domain-specific model without labeled routing data, and a new domain can be supported by adding a label.
- Domain Specialization: We fine-tuned a separate model on each domain-specific dataset, with the aim of producing higher-quality, domain-specific responses:
- Sleep Science Model: Fine-tuned on a dataset of sleep-related research, studies, and expert knowledge.
- Car History Model: Fine-tuned on a dataset covering automotive history, technological advancements, and industry developments.
- Backend Framework: We chose FastAPI for our backend due to its:
- Asynchronous request handling capabilities, enabling efficient processing of multiple queries.
- Built-in request validation and automatically generated API documentation.
- Ease of integration with machine learning models and other Python libraries.
- Frontend Interface: Gradio was selected to create our user interface because it offers:
- A simple yet powerful framework for building interactive AI applications.
- Model Optimization: We utilized quantization techniques to reduce model size and inference time, allowing for more efficient deployment and faster response times.
- Scalability Considerations: The architecture is designed to easily accommodate additional domain-specific models, allowing for future expansion of the system's knowledge base.
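The zero-shot router described above can be sketched as follows. It assumes `classify` behaves like the Hugging Face `transformers` zero-shot pipeline, e.g. `pipeline("zero-shot-classification", model="facebook/bart-large-mnli")`, which returns a dict with parallel `labels` and `scores` lists; the candidate label strings and the `route_query` name are hypothetical.

```python
# Candidate labels given to the zero-shot classifier, mapped to model keys.
# The label wording is an assumption; it affects zero-shot accuracy.
LABEL_TO_MODEL = {
    "sleep science": "sleep",
    "car history": "car",
}


def route_query(question, classify):
    """Return (model_key, score) for the highest-scoring candidate label.

    `classify(question, candidate_labels=[...])` must return a dict with
    parallel 'labels' and 'scores' lists, like the transformers
    zero-shot-classification pipeline.
    """
    result = classify(question, candidate_labels=list(LABEL_TO_MODEL))
    best_label, best_score = max(
        zip(result["labels"], result["scores"]), key=lambda pair: pair[1]
    )
    return LABEL_TO_MODEL[best_label], best_score
```

Passing the classifier in as an argument keeps the routing logic testable with a stub, and adding a domain only requires a new entry in `LABEL_TO_MODEL` plus the corresponding model. The returned score could also feed the ambiguous-query handling proposed under Advanced Router.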
- Asynchronous Processing: Implemented async functions in FastAPI to handle concurrent requests more efficiently.
- Model Caching: Implemented lazy loading and caching of models to reduce startup time and memory usage.
- Gradio Interface: Created a user-friendly chat interface that displays which model (Sleep or Car) is responding to each query.
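The asynchronous processing and lazy model caching described above can be sketched with the standard library alone. The loader and `generate` below are stand-ins that sleep instead of loading weights or running a GPU (all names are hypothetical); the pattern is to cache each model on first use behind a lock and to run blocking inference in a worker thread via `asyncio.to_thread` (Python 3.9+) so the event loop stays responsive.

```python
import asyncio
import threading
import time

_models = {}  # model name -> loaded model
_models_lock = threading.Lock()


def _load_model(name):
    """Stand-in for an expensive model load (e.g. reading weights)."""
    time.sleep(0.05)
    return lambda prompt: f"[{name}] answer to: {prompt}"


def get_model(name):
    """Load a model on first request and reuse it afterwards (lazy loading + cache)."""
    with _models_lock:  # prevents two threads from loading the same model twice
        if name not in _models:
            _models[name] = _load_model(name)
        return _models[name]


def generate(model_name, prompt):
    """Blocking inference call; stand-in for a real generate()."""
    model = get_model(model_name)
    time.sleep(0.05)  # simulate GPU work
    return model(prompt)


async def handle_query(model_name, prompt):
    # Offload blocking work to a thread so other requests can be served meanwhile
    return await asyncio.to_thread(generate, model_name, prompt)


async def handle_many(requests):
    """Serve several (model_name, prompt) requests concurrently; results keep input order."""
    return await asyncio.gather(*(handle_query(m, p) for m, p in requests))
```

Inside an `async def` FastAPI endpoint, awaiting something like `handle_query` follows the same pattern; FastAPI already runs plain `def` endpoints in a thread pool. Offloading keeps the server responsive, but it does not by itself make a single GPU process requests faster. Holding the lock during a load also serializes loads of different models, which is acceptable for two models but worth revisiting as more are added.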
- Model Optimization: Further optimize the models using techniques such as QLoRA and pruning to reduce memory footprint and inference time.
- Distributed Computing: Implement a distributed system to handle model inference across multiple GPUs or machines.
- Caching Mechanism: Implement a response cache for frequent queries to reduce unnecessary model inference.
- Advanced Router: Develop a more sophisticated routing mechanism that can handle multi-topic or ambiguous queries, for example by training a dedicated router on the sleep and car datasets or on a synthetic dataset.
- Performance Profiling: Use detailed profiling tools to identify and address specific bottlenecks in the system.
- Load Balancing: Introduce a load balancer to distribute requests across multiple worker processes or servers.
- Streaming Responses: Implement streaming responses to improve perceived responsiveness for users.
- Monitoring and Logging: Add comprehensive logging and monitoring to track system performance and identify issues in real time.
Known Issues
- High latency: The current system has a high average response time (~2000 ms), which must be reduced for real-time applications.
- Limited scalability: The system doesn't show significant performance improvements with increased concurrency.
- Containerization: The system is not yet containerized; containerizing it would improve scalability and robustness.
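Among the future improvements listed above, a response cache for frequent queries targets latency directly for repeated questions. A minimal sketch of an in-memory LRU cache with per-entry expiry, keyed on the chosen model and a normalized question (the `ResponseCache` name, size limit, TTL, and normalization rule are assumptions, not part of the current system):

```python
import time
from collections import OrderedDict


class ResponseCache:
    """LRU cache with per-entry expiry for (model, normalized question) keys."""

    def __init__(self, max_entries=1024, ttl_seconds=3600.0, clock=time.monotonic):
        self.max_entries = max_entries
        self.ttl_seconds = ttl_seconds
        self._clock = clock  # injectable so expiry can be tested deterministically
        self._data = OrderedDict()  # key -> (expires_at, response)

    @staticmethod
    def _key(model, question):
        # Normalize case and whitespace so trivially different inputs share an entry
        return model, " ".join(question.lower().split())

    def get(self, model, question):
        """Return a cached response, or None if missing or expired."""
        key = self._key(model, question)
        item = self._data.get(key)
        if item is None:
            return None
        expires_at, response = item
        if self._clock() >= expires_at:
            del self._data[key]  # drop stale entry
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return response

    def put(self, model, question, response):
        """Store a response and evict least recently used entries beyond max_entries."""
        key = self._key(model, question)
        self._data[key] = (self._clock() + self.ttl_seconds, response)
        self._data.move_to_end(key)
        while len(self._data) > self.max_entries:
            self._data.popitem(last=False)  # evict least recently used
```

Caching is only appropriate where returning a previously generated answer is acceptable; with sampling-based decoding it freezes one sample per question. A deployment with multiple worker processes would need a shared store rather than this per-process dictionary.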