
AI trading bot from Modulus Labs

Primary language: TypeScript. License: MIT.

RockyBot


"watercolor of rocky staying alive" -- DALLE (original, digital ink)

RockyBot is the first-ever fully on-chain AI trading bot!! Features include:

  • An L1 contract which holds funds and exchanges WEth / USDC on Uniswap.
  • An L2 contract implementing a simple (but flexible) 3-layer neural network for predicting future WEth prices.
  • A simple frontend for visualization and PyTorch code for training both regressors and classifiers.

Rocky is live at rockybot.app -- check out how he's doing!!

The Cairo neural net model can be found in the L2ContractHelper directory, in the L2RockafellerBot.cairo file (pardon our misspelling!).

To experiment with creating your own neural net, copy over all the code in L2RockafellerBot.cairo from this line onwards and follow the example given by the three_layer_nn function. For tips or tricks, feel free to hop on our Discord and reach out!

Getting Started

The pytorch-model/ directory is for all things training-related with respect to Rocky!

Installation

  • Install Conda from the official site.
  • Create Conda env: conda create --name rockybot-env --file pytorch-model-env.txt

(Note that you should run the above command from the pytorch-model/ directory, which contains pytorch-model-env.txt!)

Everything below (commands and paths) is relative to the pytorch-model/ directory!

Note that the model type implemented in Cairo is the simple 3-layer neural net with ReLU activations between layers (the final layer has no activation and outputs raw logits, which can be turned into a softmax distribution), trained on the dataset produced by the process_playground_task() function.
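To make the architecture concrete, here is a minimal, dependency-free Python sketch of the forward pass described above: ReLU after the first two linear layers, no activation after the third (raw logits), and an optional softmax on top. It is not the repo's Cairo three_layer_nn or its PyTorch simple_3_layer_classifier; the helper names (linear, relu, softmax, three_layer_forward), the toy layer sizes, and the weights are all hypothetical, and it uses Python floats with no claim about the numeric representation used on-chain.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]  # shape: (out_features, in_features)


def linear(x: Sequence[float], weight: Matrix, bias: Vector) -> Vector:
    """Affine layer: y[i] = sum_j weight[i][j] * x[j] + bias[i]."""
    return [sum(w * xj for w, xj in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(x: Sequence[float]) -> Vector:
    return [max(0.0, v) for v in x]


def softmax(logits: Sequence[float]) -> Vector:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def three_layer_forward(x: Sequence[float], layers: List[Tuple[Matrix, Vector]]) -> Vector:
    """ReLU after layers 1 and 2 only; layer 3 returns raw logits."""
    (w1, b1), (w2, b2), (w3, b3) = layers
    h1 = relu(linear(x, w1, b1))
    h2 = relu(linear(h1, w2, b2))
    return linear(h2, w3, b3)


# HYPOTHETICAL toy weights: 2 inputs -> 2 hidden -> 2 hidden -> 3 classes.
layers = [
    ([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]),
    ([[1.0, 0.0], [0.0, 1.0]], [0.0, -1.0]),
    ([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]], [0.0, 0.0, 0.0]),
]
logits = three_layer_forward([2.0, 1.0], layers)
probs = softmax(logits)
print(logits)  # [1.0, 0.5, -1.5]
```

The predicted class is simply the index of the largest logit; applying softmax does not change which index that is, it only turns the logits into probabilities.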

Data Generation

Simply run process_dataset.py with no arguments (python3 process_dataset.py). This generates the .npy files for the playground_task task, which are used for classification training and evaluation.

In the dataset currently generated for training/eval, each example's features are the price differences (1 WEth --> USDC) between the current timestamp's price and the price 0-35 hours earlier. The label is the price bucket (defined here) into which the difference between the next hour's price and the current price falls.
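The feature/label construction described above might look roughly like the following sketch, assuming an hourly list of WEth prices in USDC. The function name make_example and the bucket boundaries in BUCKET_BOUNDARIES are hypothetical (the real buckets are defined in the repo), and we read "0-35 hours before" literally as lags k = 0..35, so the k = 0 feature is always zero; process_dataset.py may index lags differently.

```python
import bisect
from typing import List, Sequence, Tuple

# HYPOTHETICAL bucket boundaries (USDC per WEth); the repo defines its own buckets.
BUCKET_BOUNDARIES = (-10.0, 0.0, 10.0)


def make_example(
    prices: Sequence[float],
    t: int,
    max_lag: int = 35,
    boundaries: Sequence[float] = BUCKET_BOUNDARIES,
) -> Tuple[List[float], int]:
    """Build one (features, label) pair at hourly index t (hypothetical helper).

    features[k] = prices[t] - prices[t - k] for k = 0..max_lag (k = 0 is always 0).
    label = bucket index of prices[t + 1] - prices[t].
    """
    if t < max_lag or t + 1 >= len(prices):
        raise ValueError("need max_lag hours of history and one hour of future data")
    features = [prices[t] - prices[t - k] for k in range(max_lag + 1)]
    next_diff = prices[t + 1] - prices[t]
    # bisect_right counts the boundaries <= next_diff, which is the bucket index.
    label = bisect.bisect_right(boundaries, next_diff)
    return features, label


# Toy series rising by exactly 1 USDC per hour.
prices = [1000.0 + i for i in range(40)]
feats, label = make_example(prices, t=35)
print(label)  # 2
```

With this toy series every feature equals its lag, and the next-hour difference of +1 falls in bucket 2 under these made-up boundaries.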

Model Training

The training script makes liberal use of argparse! See the get_train_args() function for full details. A sample command is as follows:

python3 classification_train.py \
	--dataset playground_task \
	--model-type simple_3_layer_classifier \
	--num-epochs 100 \
	--model-name rockybot_sample_1

Results (saved model files, training stats, etc.) are stored in playground_task_task/simple_3_layer_classifier/rockybot_sample_1.
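For readers new to argparse, the sketch below shows how flags like those in the sample training command are parsed, and how a results path could be assembled from them, assuming the <dataset>_task/<model-type>/<model-name> pattern suggested by the example path. get_toy_train_args and results_dir are hypothetical stand-ins, not the repo's get_train_args(); the real parser has more options and its own defaults, and the default of 100 epochs here is made up.

```python
import argparse
import os


def get_toy_train_args(argv=None):
    """HYPOTHETICAL parser mirroring only the flags in the sample command."""
    parser = argparse.ArgumentParser(description="Train a RockyBot classifier (sketch)")
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--model-type", required=True)
    parser.add_argument("--num-epochs", type=int, default=100)  # hypothetical default
    parser.add_argument("--model-name", required=True)
    # argparse turns dashes into underscores: --model-type -> args.model_type
    return parser.parse_args(argv)


def results_dir(args) -> str:
    # Assumed layout: <dataset>_task/<model-type>/<model-name>
    return os.path.join(f"{args.dataset}_task", args.model_type, args.model_name)


args = get_toy_train_args([
    "--dataset", "playground_task",
    "--model-type", "simple_3_layer_classifier",
    "--num-epochs", "100",
    "--model-name", "rockybot_sample_1",
])
print(results_dir(args))  # playground_task_task/simple_3_layer_classifier/rockybot_sample_1 (on POSIX)
```

Passing an explicit list to parse_args makes the parser easy to test; with argv=None it reads sys.argv as usual.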

Model Validation

As with the training script, we use argparse here. An example command is as follows:

python3 classification_eval.py \
	--dataset playground_task \
	--model-name rockybot_sample_1 \
	--model-type simple_3_layer_classifier

Use --model-name (together with --dataset and --model-type) to select the trained checkpoint to load and evaluate. This command outputs the validation loss and accuracy (note that these models grossly overfit the training set: market data is noisy and, as far as we can tell, what the models learn does not generalize) and also generates a confusion matrix (see playground_task_viz/simple_3_layer_classifier/rockybot_sample_1/Confusion_Matrix.png).
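Both the accuracy and the confusion matrix can be computed from the true and predicted bucket indices. The minimal, dependency-free sketch below shows the idea; the function names are hypothetical, it makes no assumption about which library classification_eval.py actually uses, and it only builds the count matrix rather than rendering Confusion_Matrix.png.

```python
from typing import List, Sequence


def confusion_matrix(labels: Sequence[int], preds: Sequence[int], num_classes: int) -> List[List[int]]:
    """cm[i][j] counts examples whose true class is i and predicted class is j."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for y, p in zip(labels, preds):
        cm[y][p] += 1
    return cm


def accuracy(cm: List[List[int]]) -> float:
    # Correct predictions sit on the diagonal.
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(len(cm)))
    return correct / total if total else 0.0


labels = [0, 1, 2, 2, 1, 0]
preds = [0, 2, 2, 2, 1, 1]
cm = confusion_matrix(labels, preds, num_classes=3)
print(cm)            # [[1, 1, 0], [0, 1, 1], [0, 0, 2]]
print(accuracy(cm))  # 0.6666666666666666
```

Reading the matrix row by row shows which buckets get confused with which, something a single accuracy number hides.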

Sample Directory

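This is how your directory structure should look after running process_dataset.py, classification_train.py, and classification_eval.py! (In this example tree the model was named ryan_test_8; your run directories will be named after whatever you pass to --model-name, e.g. rockybot_sample_1.)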

├── classification_eval.py
├── classification_train.py
├── constants.py
├── datasets
│   └── playground_task
│       ├── feature_idx.json
│       ├── label_idx.json
│       ├── train.npy
│       ├── train_labels.npy
│       ├── val.npy
│       └── val_labels.npy
├── datasets.py
├── eth_btc_pricedata.csv
├── models.py
├── opts.py
├── playground_task_task
│   └── simple_3_layer_classifier
│       └── ryan_test_8
│           ├── final_model.pth
│           └── train_stats.json
├── playground_task_viz
│   └── simple_3_layer_classifier
│       └── ryan_test_8
│           └── Confusion_Matrix.png
├── process_dataset.py
├── pytorch-model-env.txt
├── regression_eval.py
├── regression_train.py
├── test.json
├── test.npy
├── tree.txt
├── utils.py
└── viz_utils.py