The aim of this project is to reproduce the main results of the paper Discovering Symbolic Models from Deep Learning with Inductive Biases by Cranmer et al.
In this project, we train different variants of a Graph Neural Network (GNN) (standard, bottleneck, L1, KL and pruning) on particle datasets with different interaction forces (charge, r1, r2 and spring).
We first validate whether the GNNs learn the true forces by performing a linear regression of the true forces on the most important messages. Separately, we use SymTorch to approximate the behaviour of the edge model to see if we can extract the true force laws. By combining deep learning with symbolic regression, this framework could be extended to search for new empirical laws in high-dimensional data.
We have extended the original project by introducing a new model variant, the pruning model, in which the dimensionality of the messages decreases throughout training.
We have also created a new demo Colab notebook, demo.ipynb, where the user can test the pipeline on any of the interaction forces or model variants, to improve the reproducibility of the pipeline.
SymTorch_symbolic_distillation_GNNs
├── LICENSE
|
├── README.md
|
├── demo.ipynb #demo colab
|
├── linrepr_plots/ #linear combination plots and R2 scores stored here
|
├── media_for_readme/ #folder for the images in this README
|
├── model.py #GNN PyTorch Geometric models
|
├── plot_linear_rep.py #plot the linear combination of forces and calculate R2
|
├── pruning_experiments.py #run the pruning hyperparameter experiments
|
├── report_plots/ #plots for the report
|
├── report/ #report and executive summary in here
|
├── requirements.txt #install dependencies
|
├── simulations/ #code for simulations here
|
├── symbolic_reg.py #symbolic regression using PySR
|
├── test_models.py #get prediction losses on trained models
|
├── train_models_charge.py #train models on charge dataset
|
├── train_models_r1.py #train models on r1 dataset
|
├── train_models_r2.py #train models on r2 dataset
|
├── train_models_spring.py #train models on spring dataset
|
└── utils.py #contains code on making train/test/val data and loading it
We recommend viewing the accompanying Colab notebook, demo.ipynb, to recreate the pipeline. The notebook mounts your Google Drive and saves all model weights directly to your Drive.
To use this notebook, go to Colab and access the demo notebook from the following url: https://github.com/elizabethsztan/SymTorch_symbolic_distillation_GNNs/blob/main/demo.ipynb.
If you want to run the code locally on your system, please see below.
Ensure you have Python 3.11 installed on your system.
- Clone the repository
```
git clone git@github.com:XXXXXXX/SymTorch_symbolic_distillation_GNNs.git
```

Go to the repository locally:

```
cd SymTorch_symbolic_distillation_GNNs
```

- Install dependencies

Create a new virtual environment and activate it:

```
python3.11 -m venv project_venv
source project_venv/bin/activate
```

Install requirements.txt:

```
pip install -r requirements.txt
```

The code in this repository uses an earlier version of SymTorch from commit 13b9925. This version predates the implementation of native PyTorch model serialization, so the saving and loading mechanisms for symbolic models are different from the current SymTorch release. Specifically:
- Symbolic models are saved using `save_model()` with `save_pytorch=False, save_regressors=True`
- Models are loaded using `SymbolicModel.load_model()` with the original callable function passed as `mlp_architecture`
We need to generate the datasets that we will use for training. For example, for spring:
```
python3 simulations/generate_data.py --sim spring --save
```

You can make all four required datasets; just replace `spring` with `charge`, `r1`, or `r2`.
This will save the data in a new folder in the repository called datasets.
A fully populated datasets folder looks like
.
The dataset contains positions
There are specific training scripts for training all models on the different datasets: `train_models_{sim}.py` for the different simulations.
These scripts train all of the model variations for a specific simulation.
Optional:
If you want wandb experiment logging, you need to log in:

```
wandb login
```

To train the models with wandb logging:

```
python3 train_models_spring.py --save --epoch 100 --wandb_log
```

To train the models without wandb logging:

```
python3 train_models_spring.py --save --epoch 100
```

The model weights are saved at `model_weights/{sim}/{model_type}`. There is also a metrics file in the same folder, which contains the training configuration as well as the train and validation set losses.
When running a training script for the first time, it automatically creates a train/val/test split of your data located in train_val_test_data/{sim}. But if this already exists, then the script will load in the training and validation data from this folder.
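The split step can be sketched as follows (a minimal NumPy sketch; the function name and the 80/10/10 ratio are illustrative assumptions, not necessarily the exact values used in `utils.py`):

```python
import numpy as np

def make_train_val_test_split(X, y, val_frac=0.1, test_frac=0.1, seed=0):
    """Shuffle the snapshots once, then split into train/val/test subsets.

    val_frac/test_frac are illustrative; the repo's actual ratios may differ.
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_test = int(len(X) * test_frac)
    n_val = int(len(X) * val_frac)
    test_idx = idx[:n_test]
    val_idx = idx[n_test:n_test + n_val]
    train_idx = idx[n_test + n_val:]
    return (X[train_idx], y[train_idx]), (X[val_idx], y[val_idx]), (X[test_idx], y[test_idx])

# Example with dummy data: 100 snapshots, 4 particles, 2D positions.
X = np.random.rand(100, 4, 2)
y = np.random.rand(100, 4, 2)
train, val, test = make_train_val_test_split(X, y)
print(train[0].shape, val[0].shape, test[0].shape)  # (80, 4, 2) (10, 4, 2) (10, 4, 2)
```

Because the split is created once and then reloaded, every model variant is trained and evaluated on the same partition of the data.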
To verify that the GNNs have learnt the true forces, we perform a linear regression of the true forces on the two most important messages. We pick the two most important messages as those with the highest standard deviation over the test set (for the standard and L1 models) or the highest KL divergence (for the KL model). The pruning and bottleneck models already have message dimensions matching the dimensionality of the system.
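The selection-and-fit step can be sketched with NumPy (shapes and variable names here are illustrative; `plot_linear_rep.py` implements the real version):

```python
import numpy as np

# Illustrative setup: edge-model messages over the test set
# (n_edges x message_dim) and 2D true pairwise forces (n_edges x 2).
rng = np.random.default_rng(0)
n_edges, message_dim = 5000, 100
true_force = rng.normal(size=(n_edges, 2))
# Fake messages: two elements carry linear combinations of the force,
# the rest are near-constant noise.
messages = 0.01 * rng.normal(size=(n_edges, message_dim))
messages[:, 7] = 2.0 * true_force[:, 0] - 0.5 * true_force[:, 1]
messages[:, 42] = 1.5 * true_force[:, 1] + 0.3 * true_force[:, 0]

# 1. Pick the two most important message elements: highest std over the test set.
stds = messages.std(axis=0)
top2 = np.argsort(stds)[-2:]

# 2. Regress each true force component on the two selected messages
#    (plus an intercept) and report the R^2 of the fit.
M = np.column_stack([messages[:, top2], np.ones(n_edges)])
for k in range(2):
    coef, *_ = np.linalg.lstsq(M, true_force[:, k], rcond=None)
    pred = M @ coef
    ss_res = ((true_force[:, k] - pred) ** 2).sum()
    ss_tot = ((true_force[:, k] - true_force[:, k].mean()) ** 2).sum()
    r2 = 1 - ss_res / ss_tot
    print(f"force component {k}: R^2 = {r2:.3f}")
```

An R² close to 1 indicates the selected messages are (up to a linear transformation) the true forces.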
To plot the linear combination of forces and calculate the $R^2$ scores:

```
python3 plot_linear_rep.py --dataset_name spring --model_type bottleneck --num_epoch 100
```

- The `model_type` arg can take either a model variant as input (`standard`, `bottleneck`, `L1`, `KL`, `pruning`) or `all` if you want to run the analysis for all models at the same time.
- You need to pass the number of epochs you trained your model for in the `num_epoch` argument.
- There is an extra optional argument, `--cutoff`, which takes a number and plots only a subset of the datapoints if you want a cleaner plot. By default, if you don't include it, all the points in the test set are plotted.
This populates the folder linrepr_plots.
- The linear plots are saved at `linrepr_plots/{sim}/{model_type}`. This folder contains the plots for both the robust fit and the fit with all datapoints included (in the `with_outliers` folder).
- The $R^2$ scores are saved in `linrepr_plots/{sim}/r2_scores_epoch_{epoch}.json` for all the model types trained on a specific dataset. This includes $R^2$ scores for both the robust and non-robust fits.
Example populated folder:
We can perform symbolic regression (SR) on our trained models to see if we can reconstruct the force law from the edge model. To do this, we have written a script, symbolic_reg.py, which wraps the edge models with SymTorch's MLP_SR class. To aid the efficiency of the SR, we use transformations of the input variables: the variables allowed in the symbolic regression are
We run the symbolic regression for a large number of iterations (the `niterations` setting) to allow the Pareto front of equations to stabilise.
To perform the symbolic regression:
```
python3 symbolic_reg.py --dataset_name spring --model_type bottleneck
```

Make sure to change num_epoch in the script to the number of epochs that you trained your model for, if this wasn't 100. \
The data is saved in pysr_objects/{sim}/{model_type}/dim{0,1} for message 1 and similar for message 2. If you look at the hall_of_fame.csv file, you can see the full Pareto front of equations.
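A common way to pick an equation from the Pareto front is to take the point where the loss drops fastest per unit of added complexity. The sketch below parses a toy hall-of-fame table; the column names (`Complexity`, `Loss`, `Equation`) are assumptions, so check the header of your own hall_of_fame.csv:

```python
import csv
import io
import math

# Toy hall-of-fame table in an assumed comma-separated layout.
toy_csv = """Complexity,Loss,Equation
1,0.9,x0
3,0.30,x0*2.1
5,0.002,x0*2.1 + x1*0.5
7,0.0019,x0*2.1 + x1*0.5 + 0.01*cos(x2)
"""

rows = list(csv.DictReader(io.StringIO(toy_csv)))

# Score each step along the front: drop in log-loss per unit of added complexity.
best, best_score = None, -math.inf
for prev, cur in zip(rows, rows[1:]):
    dc = float(cur["Complexity"]) - float(prev["Complexity"])
    score = (math.log(float(prev["Loss"])) - math.log(float(cur["Loss"]))) / dc
    if score > best_score:
        best, best_score = cur, score

print("knee of the Pareto front:", best["Equation"])  # x0*2.1 + x1*0.5
```

In this toy example, the complexity-5 equation wins: going from complexity 3 to 5 cuts the loss by over two orders of magnitude, while the extra cosine term at complexity 7 barely helps.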
Example Pareto front of equations:
For spring, bottleneck:

The red box shows a successful reconstruction as we expect equations in the form of

for this sim.
We made a specific script to get the prediction losses on the test set. You need to have made all of the datasets and trained all of your models for 100 epochs (as this is hardcoded into this script).
Run:
```
python3 test_models.py
```

The results will be saved at model_weights/test_results.json. \
As an extension to the original paper, we introduce a new model variant: pruning. This model variant is similar to bottleneck in that it restricts the dimensionality of the message vectors and does not add regularisation terms to the loss.
The pruning model decreases the dimensionality of the message elements throughout training. The message elements with the highest standard deviation are kept, as they vary the most with the input, implying that they are the most important to the model.
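A single pruning step can be sketched as masking out the lowest-variance message elements (a minimal NumPy sketch; the real model applies this inside the PyTorch edge model):

```python
import numpy as np

def prune_messages(messages, keep_dims):
    """Keep the `keep_dims` message elements with the highest std over a batch.

    Returns the kept indices and a 0/1 mask that can be applied to future
    messages (illustrative; the real model masks weights in torch).
    """
    stds = messages.std(axis=0)
    keep = np.sort(np.argsort(stds)[-keep_dims:])
    mask = np.zeros(messages.shape[1])
    mask[keep] = 1.0
    return keep, mask

# Toy batch: 8-dimensional messages where only elements 2 and 5 vary.
rng = np.random.default_rng(1)
msgs = 0.01 * rng.normal(size=(1000, 8))
msgs[:, 2] = rng.normal(size=1000)       # high-variance element
msgs[:, 5] = 3 * rng.normal(size=1000)   # highest-variance element
keep, mask = prune_messages(msgs, keep_dims=2)
print(keep)  # [2 5]
```

Repeating this step on a schedule shrinks the message vector to the target dimensionality by the end of pruning.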
The rate at which the pruning occurs, and the epoch where pruning finishes, are hyperparameters which we have tuned via a grid search procedure. In all cases, we begin pruning after the first epoch (approximately 11,000 optimiser steps).
There were three different pruning schedules that we trialed in the hyperparameter search:

We also varied the point in training at which pruning was completed as part of our hyperparameter search, trialing pruning end points at 65%, 75%, and 85% of the total training duration.
To choose the best hyperparameter combination, we plot the linear representation of the true forces and choose the hyperparameter combination that yields the highest $R^2$ score.
We found that the best hyperparameter combination was to end pruning at 65% of the way through training and use a cosine decay schedule. This was the hyperparameter combination used for the rest of the project.
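A cosine-decay schedule of the kind described above could look like the sketch below. The pruning start (after the first epoch) and end (65% of training) come from the text; the start and final message dimensions are illustrative assumptions, and the exact values in `pruning_experiments.py` may differ:

```python
import math

def message_dim_schedule(epoch, total_epochs, start_dim=100, final_dim=2,
                         start_frac=0.01, end_frac=0.65):
    """Cosine decay of the message dimensionality during training.

    Pruning begins after the first epoch and finishes at `end_frac` of
    training, after which the dimension stays at `final_dim`.
    """
    start = start_frac * total_epochs
    end = end_frac * total_epochs
    if epoch <= start:
        return start_dim
    if epoch >= end:
        return final_dim
    t = (epoch - start) / (end - start)             # progress through pruning, 0..1
    cos_factor = 0.5 * (1 + math.cos(math.pi * t))  # decays smoothly from 1 to 0
    return final_dim + round((start_dim - final_dim) * cos_factor)

for e in [0, 1, 20, 40, 65, 100]:
    print(e, message_dim_schedule(e, 100))
```

The cosine shape prunes slowly at first, fastest in the middle, and slowly again near the end, giving the network time to consolidate the surviving message elements.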
- Train the pruning models
If you want to run the pruning experiments yourself, run:
```
python3 pruning_experiments.py --epoch 100 --wandb_log --save
```

If you don't want to log using wandb:

```
python3 pruning_experiments.py --epoch 100 --save
```

The pruning experiments are run on the charge dataset.
- Calculate
$R^2$ scores of the different models to choose hyperparameters
```
python3 plot_linear_rep.py --dataset_name charge --model_type pruning_experiments --num_epoch 100
```

Ensure you pass in the correct num_epoch corresponding to how long you trained your pruning experiments for, and the correct dataset_name for the dataset you trained the models on. \
The $R^2$ scores are saved at `linrepr_plots/pruning_experiments/r2_scores_epoch_{epoch}.json`. In the same folder, you can also find the individual linear regression plots if you're interested.