gflownet is a library built upon PyTorch to easily train and extend GFlowNets, also known as GFN or generative flow networks. GFlowNets are a machine learning framework for probabilistic and generative modelling, with a wide range of applications, especially in scientific discovery problems.
In a nutshell, GFlowNets can be regarded as a generative model designed to sample objects $x \in \mathcal{X}$ with probability proportional to a reward function $R(x)$.
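GFlowNets rely on the principle of compositionality to generate samples. A meaningful decomposition of samples $x$ into sequences of intermediate states $s_0 \rightarrow s_1 \rightarrow \dots \rightarrow x$ lets a GFlowNet construct each object step by step.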
Consider the problem of generating Tetris-like boards. A natural decomposition of the sample generation process would be to add one piece at a time, starting from an empty board. For any state representing a board with pieces, we could identify its valid parents and children, as illustrated in the figure below.
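We could define a reward function $R(x)$ over the boards $x$, so that a trained GFlowNet samples boards with probability proportional to their reward.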
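To make "sampling proportionally to the reward" concrete, here is a toy sketch that is not part of the gflownet library. The three objects and their rewards are made up for illustration. It computes the target distribution $p(x) = R(x)/Z$, with $Z = \sum_x R(x)$, that a well-trained GFlowNet aims to match, and then samples from it directly. Direct sampling only works here because the toy space can be enumerated; GFlowNets are useful precisely when it cannot.

```python
import random

# Hypothetical objects and rewards, for illustration only.
rewards = {"A": 1.0, "B": 3.0, "C": 6.0}

# Partition function Z: the sum of rewards over all objects.
Z = sum(rewards.values())

# Target distribution p(x) = R(x) / Z.
target = {x: r / Z for x, r in rewards.items()}
print(target)  # {'A': 0.1, 'B': 0.3, 'C': 0.6}

# Sample directly from the target (feasible only because the space is tiny).
rng = random.Random(0)
samples = rng.choices(list(target), weights=list(target.values()), k=10_000)
freq = {x: samples.count(x) / len(samples) for x in target}
print(freq)  # empirical frequencies, close to the target probabilities
```

Object "C" has six times the reward of "A" and is therefore sampled about six times as often; high-reward objects dominate without low-reward ones being ignored.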
The GFlowNet library comprises four core components: environment, proxy, policy models (forward and backward), and GFlowNet agent.
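The environment defines the state space $\mathcal{S}$ and the action space $\mathcal{A}$ of a particular problem, for example the Tetris task, as well as the rules for transitioning from one state to another.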
The Scrabble environment simulates a simple letter arrangement game where words are constructed by adding one letter at a time, up to a maximum sequence length (typically 7). Therefore, the action space is the set of all English letters plus a special end-of-sequence (EOS) action, and the state space is the set of all possible sequences of up to 7 letters. We can represent each state as a list of indices corresponding to the letters, padded with zeroes to the maximum length. For example, the state for the word "CAT" is represented as [3, 1, 20, 0, 0, 0, 0]. Actions in the Scrabble environment are single-element tuples containing the index of the letter, plus the end-of-sequence (EOS) action (-1,).
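The state and action encoding described above can be sketched in plain Python as follows. This is a minimal illustration, not the library's code: the helper names word_to_state, state_to_word and action_space are hypothetical, and the mapping A=1, ..., Z=26 is inferred from the "CAT" example.

```python
import string

MAX_LENGTH = 7                     # maximum word length, as in the example above
ALPHABET = string.ascii_uppercase  # 'A'..'Z' mapped to indices 1..26
EOS = (-1,)                        # end-of-sequence action


def word_to_state(word, max_length=MAX_LENGTH):
    """Encode a word as letter indices (A=1, ..., Z=26), padded with zeros."""
    indices = [ALPHABET.index(letter) + 1 for letter in word.upper()]
    if len(indices) > max_length:
        raise ValueError(f"word longer than {max_length} letters")
    return indices + [0] * (max_length - len(indices))


def state_to_word(state):
    """Decode a padded state back into a word, ignoring the zero padding."""
    return "".join(ALPHABET[idx - 1] for idx in state if idx != 0)


def action_space():
    """One single-element tuple per letter, plus the EOS action."""
    return [(idx,) for idx in range(1, len(ALPHABET) + 1)] + [EOS]


print(word_to_state("CAT"))  # [3, 1, 20, 0, 0, 0, 0]
print(len(action_space()))   # 27
```

Index 0 is reserved for padding, which is why letter indices start at 1.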
Using the gflownet library for a new task will typically require implementing your own environment. The library is particularly designed to make such extensions as easy as possible. In the documentation, we show how to do it step by step. You can also watch this live-coding tutorial on how to code the Scrabble environment.
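We use the term "proxy" to refer to the function or model that provides the rewards for the states of an environment. In other words, in the context of GFlowNets, the proxy can be thought of as a function $E(x)$ from which the reward is derived: $R(x) = g(E(x))$, where $g$ transforms proxy values into non-negative rewards. For example, the ScrabbleScorer proxy computes the sum of the scores of the letters of a word. For the word "CAT" that is $E(x) = 3 + 1 + 1 = 5$.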
Adapting the gflownet library for a new task will also likely require implementing your own proxy, which is usually fairly simple, as illustrated in the documentation.
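As an illustration of a proxy and the reward derived from it, the following is a minimal sketch of a Scrabble-style proxy in plain Python. It is not the library's ScrabbleScorer class; the names LETTER_SCORES, scrabble_score and reward are hypothetical. It uses the standard English Scrabble letter values, and the transformation g defaults to the identity.

```python
# Standard English Scrabble letter values.
LETTER_SCORES = {
    **dict.fromkeys("AEILNORSTU", 1),
    **dict.fromkeys("DG", 2),
    **dict.fromkeys("BCMP", 3),
    **dict.fromkeys("FHVWY", 4),
    "K": 5,
    **dict.fromkeys("JX", 8),
    **dict.fromkeys("QZ", 10),
}


def scrabble_score(word):
    """Proxy E(x): the sum of the Scrabble scores of the letters of a word."""
    return sum(LETTER_SCORES[letter] for letter in word.upper())


def reward(word, g=lambda e: e):
    """Reward R(x) = g(E(x)); g is the identity unless another one is given."""
    return g(scrabble_score(word))


print(scrabble_score("CAT"))               # 5
print(reward("CAT", g=lambda e: e ** 2))   # 25, with a non-identity g
```

Passing a different g changes how strongly high-scoring words are favoured without touching the proxy itself.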
The policy models are neural networks that model the forward and backward transitions between states, $P_F(s_{t+1} \mid s_t)$ (forward) and $P_B(s_t \mid s_{t+1})$ (backward).
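For discrete environments, one common way for a forward policy to define the transition probabilities $P_F(s_{t+1} \mid s_t)$ is to output one logit per action and give zero probability to actions that are invalid in the current state. The sketch below shows only this final step in plain Python, as an assumption about how such a policy could work; the function masked_softmax and the toy mask are hypothetical, whereas the library's policy models are PyTorch neural networks.

```python
import math


def masked_softmax(logits, mask):
    """Turn policy logits into action probabilities, zeroing invalid actions.

    logits: one score per action in the action space.
    mask:   True where the action is invalid in the current state.
    """
    valid = [logit for logit, invalid in zip(logits, mask) if not invalid]
    if not valid:
        raise ValueError("no valid actions in this state")
    top = max(valid)  # subtract the max logit for numerical stability
    exps = [0.0 if invalid else math.exp(logit - top)
            for logit, invalid in zip(logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]


# Toy action space: 'A', 'B' and EOS, with a hypothetical mask that makes EOS
# invalid. Its large logit is ignored because the action is masked out.
probs = masked_softmax([0.0, 0.0, 5.0], [False, False, True])
print(probs)  # [0.5, 0.5, 0.0]
```

Masking after computing the logits lets a single network output layer cover the whole action space while still never sampling an invalid transition.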
The GFlowNet agent is the central component that ties all others together. It orchestrates the interaction between the environment, the policies and the proxy, as well as auxiliary components such as the Evaluator and the Logger. The agent can construct training batches by sampling trajectories, optimise the policy models via gradient descent, compute evaluation metrics, log data to Weights & Biases, etc. It can be configured to optimise any of the following loss functions implemented in the library: flow matching (FM), trajectory balance (TB), detailed balance (DB) and forward-looking (FL).
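As an example of one of these objectives, the trajectory balance loss for a complete trajectory $\tau = (s_0 \rightarrow \dots \rightarrow s_n = x)$ is

$$\mathcal{L}_{TB}(\tau) = \left(\log Z_\theta + \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t) - \log R(x) - \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1})\right)^2,$$

where $Z_\theta$ is a learned estimate of the partition function. The plain-Python sketch below computes it for a single trajectory from given log-probabilities; the function name is hypothetical and this is an illustration, not the library's implementation.

```python
import math


def trajectory_balance_loss(log_z, log_pf, log_pb, reward):
    """Squared trajectory balance residual for one complete trajectory.

    log_z:  learned estimate of log Z.
    log_pf: log P_F(s_{t+1} | s_t) for each forward transition.
    log_pb: log P_B(s_t | s_{t+1}) for each of the same transitions.
    reward: R(x) > 0 of the terminating object x.
    """
    residual = log_z + sum(log_pf) - math.log(reward) - sum(log_pb)
    return residual ** 2


# When letters are only ever appended, every state has exactly one parent,
# so P_B = 1 and log P_B = 0 for every transition.
log_pf = [math.log(0.5), math.log(0.5)]  # P_F(tau) = 0.25
log_pb = [0.0, 0.0]
loss = trajectory_balance_loss(math.log(4.0), log_pf, log_pb, reward=1.0)
print(loss)  # ~0: Z * P_F(tau) = 4 * 0.25 = R(x), so the balance holds
```

Minimising this loss over many trajectories pushes $Z_\theta \, P_F(\tau)$ towards $R(x) \, P_B(\tau \mid x)$, which is what makes terminating states appear with probability proportional to their reward.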
To better understand the GFlowNet components, let us explore the Scrabble environment in more detail below.
When working with a GFlowNet agent, it is useful to explore the properties of its environment first. The library offers various functionalities for this purpose; some of them are illustrated below:
- Checking the Initial State
You can inspect the initial state of the environment. For the Scrabble environment, this is an empty sequence:
env.state
>>> [0, 0, 0, 0, 0, 0, 0]
- Exploring the Action Space
env.get_action_space()
>>> [(1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,), (11,), (12,), (13,), (14,), (15,), (16,), (17,), (18,), (19,), (20,), (21,), (22,), (23,), (24,), (25,), (26,), (-1,)]
For the Scrabble environment, the action space consists of the 26 English letters, indexed from 1 to 26. The action (-1,) represents the end-of-sequence (EOS) action, indicating the termination of word formation.
- Taking a Random Step
new_state, action_taken, valid = env.step_random()
print("New State:", new_state)
print("Action Taken:", action_taken)
print("Action Valid:", valid)
>>> New State: [24, 0, 0, 0, 0, 0, 0]
>>> Action Taken: (24,)
>>> Action Valid: True
This function randomly selects a valid action (adding a letter or ending the sequence) and applies it to the environment. The output shows the new state, the action taken, and whether the action was valid.
- Performing a Specific Action
action = (1,) # Action to add 'A'
new_state, performed_action, is_valid = env.step(action)
print("Updated State:", new_state)
print("Performed Action:", performed_action)
print("Was the Action Valid:", is_valid)
>>> Updated State: [24, 1, 0, 0, 0, 0, 0]
>>> Performed Action: (1,)
>>> Was the Action Valid: True
- Displaying the state in human-readable form
env.state2readable(env.state)
>>> 'X A'
- Interpreting an action in human-readable form
print("Action Meaning:", env.idx2token[action[0]])
>>> Action Meaning: A
- Sampling a Random Trajectory
new_state, action_sequence = env.trajectory_random()
print("New State:", new_state)
print("Action Sequence:", action_sequence)
>>> New State: [16, 16, 17, 20, 11, 16, 0]
>>> Action Sequence: [(16,), (16,), (17,), (20,), (11,), (16,), (-1,)]
- Resetting the environment
env.reset()
env.state
>>> [0, 0, 0, 0, 0, 0, 0]
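To see how the calls above fit together, the sketch below implements a hypothetical, simplified ToyScrabbleEnv in plain Python. It only mirrors the method names used in this walkthrough and is not the gflownet implementation; details such as allowing EOS on an empty word, or forcing EOS once the word reaches the maximum length, are assumptions made for the sketch.

```python
import random
import string


class ToyScrabbleEnv:
    """Hypothetical, simplified stand-in for the Scrabble environment."""

    def __init__(self, max_length=7, seed=None):
        self.max_length = max_length
        self.eos = (-1,)
        # Letter indices start at 1 so that 0 can be used for padding.
        self.idx2token = {i + 1: c for i, c in enumerate(string.ascii_uppercase)}
        self.rng = random.Random(seed)
        self.reset()

    def reset(self):
        self.state = [0] * self.max_length
        self.n_letters = 0
        self.done = False

    def get_action_space(self):
        return [(idx,) for idx in self.idx2token] + [self.eos]

    def valid_actions(self):
        if self.done:
            return []
        if self.n_letters == self.max_length:
            return [self.eos]  # word is full: only termination is allowed
        return self.get_action_space()

    def step(self, action):
        if action not in self.valid_actions():
            return list(self.state), action, False  # state left unchanged
        if action == self.eos:
            self.done = True
        else:
            self.state[self.n_letters] = action[0]
            self.n_letters += 1
        return list(self.state), action, True

    def step_random(self):
        # Assumes the trajectory is not finished (otherwise no valid actions).
        return self.step(self.rng.choice(self.valid_actions()))

    def trajectory_random(self):
        actions = []
        while not self.done:
            _, action, _ = self.step_random()
            actions.append(action)
        return list(self.state), actions

    def state2readable(self, state):
        return " ".join(self.idx2token[idx] for idx in state if idx != 0)


env = ToyScrabbleEnv(seed=0)
env.step((24,))                             # add 'X'
new_state, action, valid = env.step((1,))   # add 'A'
print(new_state, valid)                     # [24, 1, 0, 0, 0, 0, 0] True
print(env.state2readable(new_state))        # X A
env.reset()
state, actions = env.trajectory_random()
print(actions[-1])                          # (-1,)
```

Returning copies of the state list keeps earlier outputs from changing when the environment takes further steps.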
So far, we've discussed how to manually set actions or use random actions in the GFlowNet environment. This approach is useful for testing or understanding the basic mechanics of the environment. However, in practice, the goal of a GFlowNet agent is to learn from its experiences to take increasingly effective actions that are driven by a learned policy.
As the agent interacts with the environment, it collects data about the outcomes of its actions. This data is used to train a policy network, which models the probability distribution over possible actions given the current state. Over time, the policy network learns to favor actions that lead to high-reward objects, so that complete objects are sampled with probability proportional to their reward.
- Sample a batch of trajectories from a trained agent
batch, _ = gflownet.sample_batch(n_forward=3, train=False)
batch.states
>>> [[20, 20, 21, 3, 0, 0, 0], [12, 16, 8, 6, 14, 11, 20], [17, 17, 16, 23, 20, 16, 24]]
We can convert the first state into human-readable form:
env.state2readable(batch.states[0])
>>> 'T T U C'
We can also compute the proxy values for all the states in the batch:
proxy(env.states2proxy(batch.states))
>>> tensor([ 6., 19., 39.])
Or for a single state:
proxy(env.state2proxy(batch.states[0]))
>>> tensor([6.])
The state2proxy and states2proxy methods are helper functions that transform a single state or a batch of states into the input format expected by the proxy, for example a tensor.
We can also compute the rewards. Since the transformation function g is the identity in this example, the rewards are equal to the proxy values.
proxy.rewards(env.states2proxy(batch.states))
>>> tensor([ 6., 19., 39.])
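The numbers above can be checked by hand. The following plain-Python sketch uses hypothetical function names (the library's versions return PyTorch tensors, as shown above). It decodes the padded index states of the sampled batch into words, scores them with the standard English Scrabble letter values, and applies an identity g, reproducing the values 6, 19 and 39.

```python
import string

# Standard English Scrabble letter values.
LETTER_SCORES = {
    **dict.fromkeys("AEILNORSTU", 1),
    **dict.fromkeys("DG", 2),
    **dict.fromkeys("BCMP", 3),
    **dict.fromkeys("FHVWY", 4),
    "K": 5,
    **dict.fromkeys("JX", 8),
    **dict.fromkeys("QZ", 10),
}
IDX2TOKEN = {i + 1: c for i, c in enumerate(string.ascii_uppercase)}


def states_to_words(states):
    """Hypothetical stand-in for states2proxy: padded indices -> words."""
    return ["".join(IDX2TOKEN[idx] for idx in state if idx != 0) for state in states]


def scrabble_proxy(words):
    """Proxy values: the sum of the letter scores of each word."""
    return [sum(LETTER_SCORES[c] for c in word) for word in words]


def rewards_from_proxy(proxy_values, g=lambda e: e):
    """Rewards R = g(proxy); g is the identity here, as in the text."""
    return [g(e) for e in proxy_values]


batch_states = [
    [20, 20, 21, 3, 0, 0, 0],
    [12, 16, 8, 6, 14, 11, 20],
    [17, 17, 16, 23, 20, 16, 24],
]
words = states_to_words(batch_states)
print(words)                                      # ['TTUC', 'LPHFNKT', 'QQPWTPX']
print(scrabble_proxy(words))                      # [6, 19, 39]
print(rewards_from_proxy(scrabble_proxy(words)))  # [6, 19, 39]
```

For instance, the first state decodes to "TTUC", which scores 1 + 1 + 1 + 3 = 6.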
Quickstart: If you simply want to install everything, run setup_all.sh.
- This project requires python 3.10 and cuda 11.8.
- Setup is currently only supported on Ubuntu. It should also work on OSX, but you will need to handle the package dependencies.
- The recommended installation is as follows:
python3.10 -m venv ~/envs/gflownet # Initialize your virtual env.
source ~/envs/gflownet/bin/activate # Activate your environment.
./prereq_ubuntu.sh # Installs some packages required by dependencies.
./prereq_python.sh # Installs python packages with specific wheels.
./prereq_geometric.sh # OPTIONAL - for the molecule environment.
pip install .[all] # Install the remaining elements of this package.
Aside from the base packages, you can optionally install dev tools using the dev tag, materials dependencies using the materials tag, or molecules packages using the molecules tag. The simplest option is to use the all tag, as above, which installs all dependencies.
The configuration is handled via the use of Hydra. To train a GFlowNet model with the default configuration, simply run
python main.py user.logdir.root=<path/to/log/files/>
Alternatively, you can create a user configuration file in config/user/<username>.yaml specifying a logdir.root and run
python main.py user=<username>
Using Hydra, you can easily specify any variable of the configuration in the command line. For example, to train a GFlowNet with the trajectory balance loss on the continuous torus (ctorus) environment with the corresponding proxy:
python main.py gflownet=trajectorybalance env=ctorus proxy=torus
The above command will overwrite the default env and proxy configurations with the configuration files in config/env/ctorus.yaml and config/proxy/torus.yaml, respectively.
Hydra configuration is hierarchical. For instance, you can modify an existing nested variable from the command line, such as logger.do.online=False. For more details, see the Hydra documentation.
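To illustrate what "hierarchical" means here, the toy sketch below applies a dotted override such as logger.do.online=True to a nested Python dictionary. This is not how Hydra is implemented (Hydra also handles config groups, type conversion and much more); the function apply_override and the configuration contents are hypothetical. Like the command-line examples above, it only modifies keys that already exist.

```python
def apply_override(config, override):
    """Apply a dotted 'a.b.c=value' override to a nested dict (toy sketch)."""
    key, _, raw_value = override.partition("=")
    *parents, leaf = key.split(".")
    node = config
    for name in parents:
        node = node[name]  # walk down the hierarchy; KeyError if missing
    if leaf not in node:
        raise KeyError(f"{key} is not an existing configuration key")
    # Minimal value parsing for this example: booleans, otherwise a string.
    node[leaf] = {"True": True, "False": False}.get(raw_value, raw_value)
    return config


# Hypothetical, heavily simplified configuration.
config = {"logger": {"do": {"online": False}}, "device": "cpu"}
apply_override(config, "logger.do.online=True")
apply_override(config, "device=cuda")
print(config)  # {'logger': {'do': {'online': True}}, 'device': 'cuda'}
```

Each dot in the key descends one level into the nested configuration, so a single command-line argument can reach any variable.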
Note that by default, PyTorch will operate on the CPU because we have not observed performance improvements by running on the GPU. You may run on GPU with device=cuda.
Currently, the implementation includes the following GFlowNet losses:
- Flow-matching (FM): gflownet=flowmatch
- Trajectory balance (TB): gflownet=trajectorybalance
- Detailed balance (DB): gflownet=detailedbalance
- Forward-looking (FL): gflownet=forwardlooking
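As an illustration of how these objectives differ, detailed balance acts on individual transitions $s \rightarrow s'$ rather than on whole trajectories, using a learned state flow $F$:

$$\mathcal{L}_{DB}(s, s') = \left(\log F(s) + \log P_F(s' \mid s) - \log F(s') - \log P_B(s \mid s')\right)^2,$$

where, for transitions into a terminating state $x$, $F(x)$ is replaced by the reward $R(x)$. The plain-Python sketch below computes it for one transition; the function name is hypothetical and this is not the library's implementation.

```python
import math


def detailed_balance_loss(log_flow_s, log_pf, log_flow_next, log_pb):
    """Squared detailed balance residual for one transition s -> s'.

    log_flow_s:    log F(s)
    log_pf:        log P_F(s' | s)
    log_flow_next: log F(s'), or log R(x) if s' is a terminating state x
    log_pb:        log P_B(s | s')
    """
    residual = log_flow_s + log_pf - log_flow_next - log_pb
    return residual ** 2


# F(s) = 2, P_F(s'|s) = 0.5, F(s') = 1, P_B(s|s') = 1: the flows balance.
print(detailed_balance_loss(math.log(2.0), math.log(0.5), 0.0, 0.0))  # ~0
```

Because each transition is checked locally, detailed balance does not need complete trajectories, at the cost of learning the extra state-flow function $F$.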
The library also has a Logger class that manages logging during the training and evaluation of the model. It captures and stores logs to track the model's performance and to help with debugging; for instance, it logs training progress, performance metrics, and any potential errors or warnings. It also integrates with wandb.ai, a cloud-based platform for logging training and evaluation metrics. Logging to wandb.ai is disabled by default; to enable it, set the configuration variable logger.do.online to True.
Many wonderful scientists and developers have contributed to this repository: Alex Hernandez-Garcia, Nikita Saxena, Alexandra Volokhova, Michał Koziarski, Divya Sharma, Pierre Luc Carrier and Victor Schmidt.
This repository has been used in at least the following research articles:
- Lahlou et al. A theory of continuous generative flow networks. ICML, 2023.
- Hernandez-Garcia, Saxena et al. Multi-fidelity active learning with GFlowNets. RealML at NeurIPS 2023.
- Mila AI4Science et al. Crystal-GFN: sampling crystals with desirable properties and constraints. AI4Mat at NeurIPS 2023 (spotlight).
- Volokhova, Koziarski et al. Towards equilibrium molecular conformation generation with GFlowNets. AI4Mat at NeurIPS 2023.
Bibtex Format
@misc{hernandez-garcia2024,
author = {Hernandez-Garcia, Alex and Saxena, Nikita and Volokhova, Alexandra and Koziarski, Michał and Sharma, Divya and Viviano, Joseph D and Carrier, Pierre Luc and Schmidt, Victor},
title = {gflownet},
url = {https://github.com/alexhernandezgarcia/gflownet},
year = {2024},
}
Or use the CFF file.