I primarily used TensorFlow for large-scale recommendation tasks when I was at a big company, but PyTorch can be more efficient for the smaller tasks of a smaller company.
This directory aims to train a Click-Through Rate (CTR) model using PyTorch. It is a simple example that keeps everything minimal. While the model is straightforward, the data preprocessing pipeline is more involved because of the variety of inputs.
Supported features include:
- Both numerical and categorical input features
- Categorical: automatic vocabulary extraction, low-frequency filtering, dynamic embedding, hash embedding
- Numerical: standard or 0-1 normalization, automatic discretization, and automatic updating of the normalization statistics when new data is fed in
- Variable-length sequence features. If the sequence is ordered, put the latest data before the oldest, since padding is applied at the end of the sequence
- Sequence features can be weighted by setting a weight column
- Implemented `DataFrameDataset` for straightforward training with input data in pandas/polars DataFrame format (an illustrative DataFrame is sketched below)
- Implemented a common `Trainer` for training PyTorch models and saving/loading the results
- Basic FastAPI for Model API Serving
Not supported:
- Distributed training, since this tool targets small companies
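For illustration only, here is the kind of pandas DataFrame these features are aimed at; the column names and the list-based encoding of the sequence and weight columns are assumptions, not the library's documented schema:

```python
import pandas as pd

# Toy interaction table mixing the feature types listed above.
# Column names and the list-based sequence/weight encoding are assumptions
# made for illustration; see the example notebook for the exact expected format.
df = pd.DataFrame({
    "label": [1, 0, 1],                           # click / no-click target
    "age": [23, 41, 35],                          # numerical feature
    "item_price": [19.9, 5.0, 47.5],              # numerical feature
    "gender": ["m", "f", "m"],                    # categorical feature
    "city": ["beijing", "shanghai", "beijing"],   # categorical feature
    # variable-length sequence feature, latest item first (padding goes at the end)
    "clicked_items": [[105, 42, 7], [8], [301, 105]],
    # optional weight column aligned with the sequence feature
    "clicked_items_weight": [[1.0, 0.8, 0.5], [1.0], [1.0, 0.6]],
})
```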
Install:

```sh
pip install git+https://github.com/xiahouzuoxin/torchctr
```
Usage:
- Use `DataFrameDataset` to load the raw data as a PyTorch Dataset
- Create a model definition file in `torchctr/models` and implement the model by inheriting from `nn.Module`, plus a few extra member methods (a hedged model sketch follows this list):
  - required: `training_step`, `validation_step`
  - optional: `configure_optimizers`, `configure_lr_scheduler`
- Use `Trainer` to train the model (a hedged training sketch follows this list)
- Serve the model via the Model API Serving:
  - [Optional] Depending on your model and data processing, you may need to create a new ServingModel like `BaseServingModel`
  - Set up the service:
    - Debugging: pass the service name and model path on the command line:
      ```sh
      cd $torchctr_root
      python -m torchctr.serving.serve --name [name] --path [path/to/model or path/to/ckpt] --serving_class BaseServingModel
      ```
    - Production: write the command line parameters to the `serving_models` variable in `torchctr/serving/serve.py`
  - Test the service: see `test_predict` in the example (a hedged request sketch follows this list)
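A minimal sketch of such a model definition, assuming `training_step`/`validation_step` receive a batch and return a loss tensor; the constructor arguments, the batch layout, and the hook signatures are assumptions for illustration rather than the Trainer's actual contract:

```python
import torch
import torch.nn as nn

class SimpleCTR(nn.Module):
    """Hypothetical CTR model; only the method names below come from this README,
    their signatures and the (features, labels) batch layout are assumptions."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(input_dim, 64), nn.ReLU(), nn.Linear(64, 1))
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.mlp(features).squeeze(-1)   # raw logits, one per sample

    # --- required extra member methods ---
    def training_step(self, batch):
        features, labels = batch                # assumed batch layout
        return self.loss_fn(self(features), labels.float())

    def validation_step(self, batch):
        features, labels = batch
        return self.loss_fn(self(features), labels.float())

    # --- optional hooks ---
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def configure_lr_scheduler(self, optimizer):
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
```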
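Putting the first three steps together might look roughly like this; `DataFrameDataset` and `Trainer` are named in this repo, but the import path, every constructor argument, and the `fit`/`save` calls shown here are hypothetical placeholders, so check the example notebook for the real signatures:

```python
import pandas as pd
from torchctr import DataFrameDataset, Trainer      # import path is an assumption
from torchctr.models.simple_ctr import SimpleCTR    # hypothetical model file from the sketch above

df = pd.read_csv("ctr_train.csv")                    # hypothetical raw data file

# Wrap the DataFrame as a PyTorch Dataset; these keyword arguments are
# hypothetical -- the real feature-configuration interface may differ.
train_ds = DataFrameDataset(
    df,
    label_cols=["label"],
    dense_cols=["age", "item_price"],
    sparse_cols=["gender", "city"],
)

# Train with the common Trainer and persist the result; the constructor,
# fit(), and save() calls are likewise assumptions for illustration.
model = SimpleCTR(input_dim=16)
trainer = Trainer(model)
trainer.fit(train_ds)
trainer.save("ckpt/simple_ctr")
```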
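And to test the running service, something like the snippet below can be adapted from `test_predict`; the host, port, route, and payload schema are assumptions about the FastAPI app, not its documented contract:

```python
import requests

# Hypothetical endpoint: the real host/port, route, and JSON schema are
# defined by torchctr/serving/serve.py and demonstrated by test_predict.
url = "http://127.0.0.1:8000/[name]/predict"

payload = {
    "age": [23],
    "gender": ["m"],
    "clicked_items": [[105, 42, 7]],
}

resp = requests.post(url, json=payload, timeout=5)
print(resp.status_code, resp.json())    # expect predicted CTR scores in the response
```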