
ML model library for TrainDB

Primary language: Jupyter Notebook. License: Apache License 2.0 (Apache-2.0).


You can train models in TrainDB using this ML model library. These models can be used to generate synopsis data or to estimate aggregate values in approximate query processing.


Requirements:

  • TrainDB
  • Python 3.8 or 3.9
  • A Python virtual environment manager, such as pyenv (optional)
  • Packages used by the ML models, such as PyTorch; install them from requirements.txt:
$> pip install --no-deps -r requirements.txt
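Because `--no-deps` tells pip not to resolve dependencies on its own, you may want to confirm that the interpreter and key packages match the list above before running anything. The following is a minimal sketch, not part of the library: it checks for Python 3.8 or 3.9 and whether `torch` (the import name of PyTorch) can be found. The package names passed to `missing_packages` are an illustrative assumption; requirements.txt remains the authoritative list.

```python
import importlib.util
import sys

# Python versions listed in the requirements above.
SUPPORTED_PYTHON = {(3, 8), (3, 9)}


def python_supported(version_info=sys.version_info):
    """Return True if the major.minor version is 3.8 or 3.9."""
    return (version_info[0], version_info[1]) in SUPPORTED_PYTHON


def missing_packages(names=("torch",)):
    """Return the names from `names` that cannot be found for import.

    The default list is an assumption for illustration only; see
    requirements.txt for the packages the models actually need.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    if not python_supported():
        major, minor = sys.version_info[0], sys.version_info[1]
        print(f"Warning: Python {major}.{minor} is not 3.8 or 3.9.")
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
        print("Install them with: pip install --no-deps -r requirements.txt")
    else:
        print("Environment looks OK.")
```

`find_spec` only locates a package without importing it, so the check stays fast even when PyTorch is installed.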



You can download TrainDB and this model library in one step by running the following command:

$> git clone --recurse-submodules https://github.com/traindb-project/traindb.git


If you use the traindb-model library with TrainDB, you can run SQL-like statements through trsql. For details, refer to the README file in the TrainDB repository.

You can also train models and generate synthetic data using the CLI model runner. For example, you can train a model on the test dataset as follows:

$> python tools/TrainDBCliModelRunner.py train TableGAN models/TableGAN.py \
       tests/test_dataset/instacart_small/data.csv \
       tests/test_dataset/instacart_small/metadata.json \
epoch 1 step 50 tensor(1.1035, grad_fn=<SubBackward0>) tensor(0.7770, grad_fn=<NegBackward>) None
epoch 1 step 100 tensor(0.8791, grad_fn=<SubBackward0>) tensor(0.9682, grad_fn=<NegBackward>) None

$> python tools/TrainDBCliModelRunner.py synopsis TableGAN models/TableGAN.py output 1000 sample.txt
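Each progress line of the training output above shows an epoch, a step, and two tensor values. The sketch below, which is not part of the library, extracts those numbers so you can plot or compare training runs. It assumes every progress line follows the format shown above; what the two values represent (for example, the discriminator and generator losses of TableGAN) is an assumption, so the code names them neutrally.

```python
import re

# Matches progress lines such as:
# epoch 1 step 50 tensor(1.1035, grad_fn=<SubBackward0>) tensor(0.7770, grad_fn=<NegBackward>) None
# Only the numeric value inside each tensor(...) is captured.
LOG_LINE = re.compile(
    r"epoch\s+(?P<epoch>\d+)\s+step\s+(?P<step>\d+)\s+"
    r"tensor\((?P<value_a>[-+0-9.eE]+)[^)]*\)\s+"
    r"tensor\((?P<value_b>[-+0-9.eE]+)[^)]*\)"
)


def parse_training_log(lines):
    """Return (epoch, step, value_a, value_b) tuples from log lines.

    Lines that do not match the assumed format (warnings, blank
    lines, other messages) are skipped.
    """
    records = []
    for line in lines:
        match = LOG_LINE.search(line)
        if match:
            records.append((
                int(match["epoch"]),
                int(match["step"]),
                float(match["value_a"]),
                float(match["value_b"]),
            ))
    return records


if __name__ == "__main__":
    sample = [
        "epoch 1 step 50 tensor(1.1035, grad_fn=<SubBackward0>) "
        "tensor(0.7770, grad_fn=<NegBackward>) None",
        "epoch 1 step 100 tensor(0.8791, grad_fn=<SubBackward0>) "
        "tensor(0.9682, grad_fn=<NegBackward>) None",
    ]
    for record in parse_training_log(sample):
        print(record)
```

If you save the runner's console output to a file (for example, a hypothetical train.log), you can pass the open file object to `parse_training_log`, since it accepts any iterable of lines.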

Similarly, you can train inference models and run queries as follows:

$> python tools/TrainDBCliModelRunner.py train RSPN \
       models/RSPN.py \
       tests/test_dataset/instacart_small/data.csv \
       tests/test_dataset/instacart_small/metadata.json \

# SELECT COUNT(*) FROM order_products WHERE add_to_cart_order < 12 GROUP BY reordered
$> python tools/TrainDBCliModelRunner.py infer RSPN models/RSPN.py output/ "COUNT(*)" "reordered" "add_to_cart_order < 12"
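To judge how close an approximate answer is, it can help to compute the exact result of the same query, `SELECT COUNT(*) FROM order_products WHERE add_to_cart_order < 12 GROUP BY reordered`, directly from the test CSV. The sketch below uses only the standard library and performs a full scan. It assumes that data.csv has a header row containing the `reordered` and `add_to_cart_order` columns; the function name `exact_group_counts` is hypothetical and not part of the library.

```python
import csv
from collections import Counter


def exact_group_counts(csv_lines, group_col, filter_col, threshold):
    """Exact answer for:
        SELECT COUNT(*) ... WHERE filter_col < threshold GROUP BY group_col

    `csv_lines` is any iterable of CSV text lines with a header row,
    such as an open file. Group keys are returned as strings, as read.
    """
    counts = Counter()
    for row in csv.DictReader(csv_lines):
        # Numeric comparison on the filter column, as in the SQL predicate.
        if float(row[filter_col]) < threshold:
            counts[row[group_col]] += 1
    return dict(counts)


if __name__ == "__main__":
    # Assumes the test dataset has a header row with these column names.
    path = "tests/test_dataset/instacart_small/data.csv"
    with open(path, newline="") as f:
        print(exact_group_counts(f, "reordered", "add_to_cart_order", 12))
```

Comparing these exact counts with the RSPN estimates for each `reordered` group gives a quick sense of the approximation error on the small test dataset.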


For a demo and a detailed explanation, see the Notebook.

You can also run the test code directly in GitHub Codespaces; the instructions are the same as above.