
A Flask API for training and querying text-generating RNNs with the textgenrnn module using Google App Engine and Google Cloud Storage.


textgenrnn-api

Description

A lightweight Python API, designed to run on Google Cloud, that lets clients train RNNs on arbitrary strings and then generate text from them. It uses the phenomenal textgenrnn module for text generation and Flask as the web framework. textgenrnn-api will not output anything worthy of an NLP paper, but it's still pretty fun.

Routes

textgenrnn-api has two POST routes:

  • /train:
    • supply a list of training_strings
    • get back model_id
  • /generate:
    • supply a model_id and, optionally, a prompt, max_length, and/or temperature
    • get back output
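A minimal client sketch using only the Python standard library. It assumes that both routes accept a JSON request body and return JSON, and that the server is running locally at Flask's default address, http://127.0.0.1:5000. The README does not yet pin down the request format (see the TODO about specifying JSON routes), so adjust post_json if the routes expect form data instead. The helper names build_train_payload, build_generate_payload, and post_json are hypothetical.

```python
import json
import urllib.request

# Assumption: local dev server started with `flask run` (default host/port).
BASE_URL = "http://127.0.0.1:5000"


def build_train_payload(training_strings):
    """Body for POST /train: a list of training_strings."""
    return {"training_strings": list(training_strings)}


def build_generate_payload(model_id, prompt=None, max_length=None, temperature=None):
    """Body for POST /generate. model_id is required; prompt, max_length and
    temperature are optional and are omitted when not given, so the server
    can fall back to its own defaults."""
    payload = {"model_id": model_id}
    optional = {"prompt": prompt, "max_length": max_length, "temperature": temperature}
    payload.update({key: value for key, value in optional.items() if value is not None})
    return payload


def post_json(route, payload, base_url=BASE_URL):
    """POST `payload` as JSON to `route` and decode the JSON response.
    Assumption: the API speaks JSON in both directions."""
    request = urllib.request.Request(
        base_url + route,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the server running, a round trip would be post_json("/train", build_train_payload([...])) followed by post_json("/generate", build_generate_payload(result["model_id"], temperature=0.5)), reading the output field from the second response.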

Setup

  1. Clone this repository.

    git clone https://github.com/jkatofsky/textgenrnn-api.git
    cd textgenrnn-api
  2. Create a Python virtual environment (optional, but good practice).

    python3 -m venv env
    source env/bin/activate
  3. Install the required modules.

    pip3 install -r requirements.txt
  4. Create a Google Cloud project and set the PROJECT_NAME variable in settings.py to match it.

  5. Create a Google Cloud Storage bucket in your project to store the models, and set the MODEL_BUCKET_NAME variable in settings.py to its name. You can optionally limit how long models are kept by adding a delete lifecycle rule to the bucket.

  6. Download a service account credentials JSON key for your project with permissions on the model bucket, and set the CREDENTIALS_JSON_PATH variable in settings.py to its path.

  7. To test the server locally (with convenient hot reload), use the following command.

    python3 -m flask run --reload
  8. Assuming you have the Google Cloud SDK (gcloud) installed, you can deploy this repo directly to App Engine.

    gcloud app deploy

    For more information, see Google's guide to deploying a Flask project to App Engine.
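For the optional delete lifecycle rule in step 5, one approach (an assumption; the README does not say how the rule should be configured) is to generate a lifecycle config in the JSON layout that gsutil accepts and apply it to the model bucket. The 7-day age is an arbitrary placeholder, and MODEL_MAX_AGE_DAYS, delete_rule_config, and write_lifecycle_file are hypothetical names.

```python
import json

# Placeholder: delete stored models this many days after they were created.
MODEL_MAX_AGE_DAYS = 7


def delete_rule_config(max_age_days):
    """Lifecycle config with a single rule: delete objects older than
    max_age_days. Uses the {"rule": [...]} layout read by `gsutil lifecycle set`."""
    if max_age_days < 1:
        # An age of 0 would make every model eligible for deletion right away.
        raise ValueError("max_age_days must be at least 1")
    return {
        "rule": [
            {
                "action": {"type": "Delete"},
                "condition": {"age": max_age_days},
            }
        ]
    }


def write_lifecycle_file(path, max_age_days=MODEL_MAX_AGE_DAYS):
    """Write the lifecycle config to `path` as JSON and return it."""
    config = delete_rule_config(max_age_days)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    return config
```

After writing lifecycle.json, running gsutil lifecycle set lifecycle.json gs://YOUR_BUCKET (with the bucket named in MODEL_BUCKET_NAME) applies the rule. Cloud Storage evaluates lifecycle rules asynchronously, so expired models may linger for a while before they disappear.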

TODOs

  • Investigate the feasibility of running on Compute Engine or a GC AI offering; this could really increase the speed/efficacy of the ML.
  • Investigate other packages for textgen?
  • Way of remembering clients?
    • Only allow N trains/N generates by a given client?
  • Route for testing model existence?
  • Specify the routes as JSON routes and provide a cURL example in the README.
  • Play with default textgen parameters (# epochs, # training chars, word-level vs. char-level - could have a flag for this in settings.py).
  • If my PR is accepted, use the proper fork of textgenrnn again.
  • Return the expiration time of the model with every response?
  • Memory usage issues:
    • Server is using ~2 times more memory than local is.
      • Use fewer workers?
        • Still using 1.9 GB even with 3 workers, so it's one process that's using it all.
      • Try loading and deleting TensorFlow dynamically?
      • It's not the memory from training a model, I think, but the memory from loading TensorFlow when the App Engine instance wakes from sleep.
    • The issue appears this way locally, but doesn't seem to be the same issue as described above?
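For the "remembering clients" idea, a minimal in-memory sketch of per-client limits might look like the following. Everything here is hypothetical: the ClientQuota class, the action names "train" and "generate", and how a client would be identified (for example by IP address or an API key) are not part of the current API. Since the notes above mention running several workers, and each worker process would hold its own copy of these counters, a real version would likely need a shared store rather than a Python dict.

```python
from collections import defaultdict


class ClientQuota:
    """Hypothetical per-client counter: allows at most limits[action] calls of
    each action per client id. State lives only in the current process."""

    def __init__(self, limits):
        # e.g. {"train": 3, "generate": 50}
        self.limits = dict(limits)
        self.counts = defaultdict(int)  # (client_id, action) -> calls so far

    def allow(self, client_id, action):
        """Record one call and return True if the client is under its limit;
        return False (without recording anything) once the limit is reached."""
        if action not in self.limits:
            raise KeyError(f"no limit configured for action {action!r}")
        key = (client_id, action)
        if self.counts[key] >= self.limits[action]:
            return False
        self.counts[key] += 1
        return True

    def remaining(self, client_id, action):
        """How many more calls of `action` this client may make."""
        return self.limits[action] - self.counts[(client_id, action)]
```

In a Flask view, a request that fails allow() could be rejected with HTTP 429 (Too Many Requests) before any training or generation starts.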