Graph-PFN is a Python project for graph-based PFNs (Prior-Fitted Networks).
This repository provides scripts for training and evaluating PFNs and other models, such as Graph Neural Networks (GNNs), using torch, torch_geometric, and other dependencies.
First, clone the repository to your local machine:
git clone git@github.com:aron-bram/graph-pfn.git
cd graph-pfn
It is recommended to use a virtual environment to manage dependencies:
python3.10 -m venv venv  # NOTE: the installation has only been tested with Python 3.10
source venv/bin/activate
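Before you install anything, you can confirm that the virtual environment is active with a quick standard-library check. Inside an environment created with python -m venv, sys.prefix points at the environment, while sys.base_prefix still points at the base interpreter. This is a minimal sketch. The helper name in_virtualenv is hypothetical and is not part of this repository.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when running inside a venv (sys.prefix differs from the base interpreter)."""
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("virtual environment active:", in_virtualenv())
    print("python version:", sys.version.split()[0])
```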
Ensure you have the latest version of pip, setuptools, and wheel:
pip install --upgrade pip setuptools wheel
When running on a CPU, install the following dependencies:
pip install --upgrade pip setuptools wheel
pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cpu
pip install torch_geometric==2.4.0
pip install torch_scatter==2.1.2 torch-sparse==0.6.17 -f https://data.pyg.org/whl/torch-2.1.1+cpu.html
pip install git+https://github.com/automl/PFNs
pip install matplotlib
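Once the packages are installed, a short standard-library script can confirm that the pinned versions actually ended up in the environment. This is a sketch under stated assumptions. It checks only the four pinned torch/PyG packages and skips the PFNs package, because its distribution name is not given here. It also strips local version labels such as +cpu or +cu118 before comparing. The names PINNED and check_pins are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned by the install commands in this README (assumed distribution names).
PINNED = {
    "torch": "2.1.1",
    "torch_geometric": "2.4.0",
    "torch_scatter": "2.1.2",
    "torch_sparse": "0.6.17",
}


def check_pins(pins=PINNED):
    """Return {package: (expected, installed_or_None, matches)}."""
    report = {}
    for name, expected in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        # Wheels may carry a local label like "+cpu" or "+pt21cu118"; compare the base version only.
        base = installed.split("+", 1)[0] if installed else None
        report[name] = (expected, installed, base == expected)
    return report


if __name__ == "__main__":
    for name, (expected, installed, ok) in check_pins().items():
        print(f"{'OK ' if ok else 'BAD'} {name}: expected {expected}, found {installed}")
```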
Alternatively, when running on a GPU with CUDA 11.8 support, run these commands instead:
pip install --upgrade pip setuptools wheel
pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cu118
pip install torch_geometric==2.4.0
pip install torch_scatter==2.1.2 torch-sparse==0.6.17 -f https://data.pyg.org/whl/torch-2.1.1+cu118.html
pip install git+https://github.com/automl/PFNs
pip install matplotlib
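After a GPU install, check that PyTorch was built against CUDA 11.8 and can see a device. The sketch below uses torch.__version__, torch.version.cuda, and torch.cuda.is_available(). It returns early if torch cannot be imported. With the cu118 wheel, the reported CUDA build version should be 11.8. The helper name gpu_status is hypothetical.

```python
import importlib.util


def gpu_status() -> dict:
    """Summarize the installed torch build; returns early if torch is missing."""
    if importlib.util.find_spec("torch") is None:
        return {"torch_installed": False}
    import torch

    return {
        "torch_installed": True,
        "torch_version": torch.__version__,      # e.g. "2.1.1+cu118" for the cu118 wheel
        "built_with_cuda": torch.version.cuda,   # None on CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    print(gpu_status())
```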
Optionally install dev tools:
pip install pytest ruff black isort
To ensure everything is set up correctly, run:
cd scripts
python train_pfn.py --pfn_epochs 2 --pfn_steps_per_epoch 1
You should see the following output at the end of the run:
Successfully trained PFN on 2 samples from the prior, and saved it under prior_fitted_model
With the command above, this output means that you successfully trained a PFN on 2 datasets sampled from the prior, and that the model was saved under prior_fitted_model.
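To automate this check, for example in CI, you can run the same command and look for the success line in its output. This is a sketch, not part of the repository, and run_smoke_test and smoke_test_passed are hypothetical names. It searches both stdout and stderr, because the README does not say which stream receives the message. It also assumes you start it from the repository root, so that scripts/ becomes the script's working directory.

```python
import subprocess
import sys

SUCCESS_MARKER = "Successfully trained PFN"


def smoke_test_passed(output: str) -> bool:
    """True if any output line contains the success message printed by train_pfn.py."""
    return any(SUCCESS_MARKER in line for line in output.splitlines())


def run_smoke_test(scripts_dir: str = "scripts") -> bool:
    """Run the 2-epoch smoke test and check both its exit code and its output."""
    result = subprocess.run(
        [sys.executable, "train_pfn.py", "--pfn_epochs", "2", "--pfn_steps_per_epoch", "1"],
        cwd=scripts_dir,
        capture_output=True,
        text=True,
    )
    return result.returncode == 0 and smoke_test_passed(result.stdout + result.stderr)
```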
graph-pfn/
│── scripts/ # Source code
│── tests/ # Test scripts
│── README.md # Installation guide (this file)
│── requirements.txt # Alternative dependency file
│── venv/ # (Optional) Virtual environment
To start training a PFN model:
cd scripts
python train_pfn.py
To start evaluating a trained PFN model:
cd scripts
python evaluate_pfn.py
To train and evaluate other models on benchmarks:
cd scripts
python train_eval_baselines.py
Refer to the documentation at the top of each .py script for details on its requirements, its outputs, and what it does. For example, train_pfn.py explains in detail how the PFN is trained on the prior.
Each script accepts its own arguments to customize the run, and each argument is documented in the code.
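If you prefer to drive the scripts from Python, for example to sweep over settings, a small wrapper can turn keyword arguments into --name value flags. It then runs the script with scripts/ as the working directory, which matches the cd scripts step above. This is a hypothetical helper and is not part of the repository. The only flags confirmed here are --pfn_epochs and --pfn_steps_per_epoch from the smoke test, so check each script's documented arguments before you pass others.

```python
import subprocess
import sys
from pathlib import Path


def build_command(script: str, **flags) -> list[str]:
    """Turn keyword arguments into '--name value' pairs after the script name.

    Boolean on/off switches (argparse store_true) are not handled; values are passed as strings.
    """
    cmd = [sys.executable, script]
    for name, value in flags.items():
        cmd += [f"--{name}", str(value)]
    return cmd


def run_script(script: str, scripts_dir: str = "scripts", **flags) -> int:
    """Run a script with scripts_dir as the working directory and return its exit code."""
    return subprocess.run(build_command(script, **flags), cwd=Path(scripts_dir)).returncode


# Hypothetical usage, equivalent to the smoke test:
# run_script("train_pfn.py", pfn_epochs=2, pfn_steps_per_epoch=1)
```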
To run the sampler tests using pytest (run pytest without a file argument to collect every test file in the directory):
cd tests
pytest test_sampler.py
If torch_sparse or torch_scatter fails to install, ensure you're using the correct index when running on a CPU:
pip install torch_scatter==2.1.2 torch-sparse==0.6.16 -f https://data.pyg.org/whl/torch-2.1.1+cpu.html
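The -f URL follows the pattern https://data.pyg.org/whl/torch-<torch version>+<compute platform>.html, as in the CPU and CUDA 11.8 commands above. If you are unsure which index applies, you can derive it from the torch build you actually installed. This sketch assumes a torch.__version__ string such as 2.1.1+cpu or 2.1.1+cu118. It raises an error when the version has no local label, because the compute platform cannot be inferred in that case. The function name pyg_wheel_index is hypothetical.

```python
def pyg_wheel_index(torch_version: str) -> str:
    """Build the PyG wheel index URL from a torch version string like '2.1.1+cu118'."""
    base, sep, platform = torch_version.partition("+")
    if not sep or not platform:
        # Wheels without a local label do not say which compute platform they target.
        raise ValueError(f"cannot infer compute platform from {torch_version!r}; pass it explicitly")
    return f"https://data.pyg.org/whl/torch-{base}+{platform}.html"
```

Usage inside the virtual environment: import torch, then print(pyg_wheel_index(torch.__version__)).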
Ensure you are using Python 3.10. The installation does not currently work with Python 3.13. Python 3.11 and 3.12 should work, although I have not tested them.
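To fail fast in a setup script, you can encode the version notes above in a small check. The classification comes directly from this README: 3.10 is tested, 3.13 is known not to work, and everything else is untested. The function name python_support is hypothetical.

```python
import sys


def python_support(version_info=sys.version_info) -> str:
    """Classify an interpreter version against this README's notes."""
    major_minor = tuple(version_info[:2])
    if major_minor == (3, 10):
        return "tested"
    if major_minor == (3, 13):
        return "known not to work"
    return "untested"


if __name__ == "__main__":
    print(f"Python {sys.version.split()[0]}: {python_support()}")
```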
This project is licensed under the MIT License.
Contributions are welcome! Please fork the repository and submit a pull request with your improvements.
If you run into any problems, please open an issue on GitHub.