How to Train Agatha¶
Training the Agatha deep learning model is the last step to generating
hypotheses after you’ve already processed all necessary information using
agatha.construct. This process uses PyTorch and PyTorch-Lightning to
efficiently manage the distributed training of the predicate ranking model
stored in agatha.ml.hypothesis_predictor.
tl;dr¶
You will need the following files:
predicate_graph.sqlite3
predicate_entities.sqlite3
embeddings/predicate_subset/*.h5
You will need to run python3 -m agatha.ml.hypothesis_predictor with the right
keyword arguments. If performing distributed training, you will need to run this
on each machine in your training cluster.
Take a look at [scripts/train_2020.sh](https://github.com/JSybrandt/agatha/blob/master/scripts/train_2020.sh) for how to train the Agatha model.
If you are running the training process on one machine with only one GPU, you
will want to remove the distributed_backend flag and make sure num_nodes is
set to one. If you are using multiple machines, or multiple GPUs on one
machine, then you will want to make sure that distributed_backend="ddp", and
you should take a look at setting the distributed environment variables if you
run into errors. In the multi-GPU one-machine case, these variables should be
set automatically.
Background¶
The Agatha deep learning model learns to rank entity-pairs. To learn this ranking, we will be comparing existing predicates found within our dataset against randomly sampled entity-pairs. Of course, if a predicate exists in our database, it should receive a higher model output than many random pairs.
A positive sample is an entity-pair that actually occurs in our dataset. A
negative sample is one of those non-existent randomly sampled pairs. We will
use the margin ranking loss criterion to learn to associate higher values
with positive samples. To do this, we will compare one positive sample to a high
number of negative samples. The number of negatives per positive is the
negative-sampling rate.
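To make the ranking objective concrete, here is a minimal pure-Python sketch of a margin ranking loss for one positive sample against several negatives. It assumes the usual hinge form, max(0, margin - (positive - negative)), averaged over the negatives. The function name and the scores are made up for illustration, and the actual training code relies on PyTorch rather than this helper.

```python
from typing import Sequence


def margin_ranking_loss(
    positive_score: float,
    negative_scores: Sequence[float],
    margin: float = 0.1,
) -> float:
    """Average hinge penalty of one positive score against many negatives."""
    # A negative sample incurs no penalty once the positive sample
    # outscores it by at least `margin`.
    penalties = [
        max(0.0, margin - (positive_score - negative))
        for negative in negative_scores
    ]
    return sum(penalties) / len(penalties)


# One positive sample compared against three negatives (made-up scores).
# Only the negatives scoring within `margin` of the positive, or above it,
# contribute to the loss.
loss = margin_ranking_loss(0.9, [0.2, 0.85, 0.95])
```

The loss only reaches zero when every negative sits at least one margin below the positive, which is why the negative-sampling rate matters: more negatives per positive means more chances to find a pair that still violates the margin.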
A single sample, be it positive or negative, is comprised of four parts:
Term 1 (the subject).
Term 2 (the object).
Predicates associated with term 1 (but not term 2).
Predicates associated with term 2 (but not term 1).
This as a whole is referred to as a sample. Generating samples is the primary
bottleneck in the training process. This is because we have many millions of
terms and predicates. As a result, the Agatha deep learning framework comes
along with a number of utilities to make managing the large datasets easier.
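The four parts of a sample can be pictured as a small record. The sketch below is purely illustrative: the Sample class, the build_sample helper, and the predicate names are hypothetical and are not part of the Agatha API. It also takes the first few neighbors in sorted order to stay deterministic, whereas real training would sample them.

```python
from dataclasses import dataclass
from typing import List, Set


@dataclass
class Sample:
    """Hypothetical container for the four parts of one sample."""
    term1: str                   # the subject
    term2: str                   # the object
    term1_predicates: List[str]  # predicates tied to term 1 but not term 2
    term2_predicates: List[str]  # predicates tied to term 2 but not term 1


def build_sample(
    term1: str,
    term2: str,
    neighbors1: Set[str],
    neighbors2: Set[str],
    max_neighbors: int,
) -> Sample:
    # Keep only the disjoint part of each neighborhood, capped in size.
    only1 = sorted(neighbors1 - neighbors2)[:max_neighbors]
    only2 = sorted(neighbors2 - neighbors1)[:max_neighbors]
    return Sample(term1, term2, only1, only2)


sample = build_sample(
    "C0000001", "C0000002",             # made-up term identifiers
    {"p:a", "p:b", "p:shared"},
    {"p:c", "p:shared"},
    max_neighbors=2,
)
```

Because every sample needs two neighborhood lookups plus a set difference, it is easy to see why sample generation dominates the cost of training when there are millions of terms and predicates.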
Datasets¶
In order to begin training you will need the following data:
Embeddings for all entities and predicates, stored as a directory of .h5 files.
Entity metadata, stored as a .sqlite3 file.
The subgraph containing all entity-predicate edges, stored as a .sqlite3 file.
The network construction process will produce these datasets as sqlite3
files.
Sqlite is an embedded database, meaning that we can load the database from
storage and don’t need to spin up a whole server. Additionally, because we are
only going to read and never going to write to these databases during
training, each machine in our distributed training cluster can have independent
access to the same data very efficiently.
All of the sqlite3 databases managed by Agatha are stored in a simple format
that enables easy python access through the
agatha.util.sqlite3_lookup.Sqlite3LookupTable
object. This provides read-only
access to the database through a dictionary-like interface.
For instance, if we want to get the neighbors for the node l:noun:cancer, we
can simply write this code:
from agatha.util.sqlite3_lookup import Sqlite3LookupTable
graph = Sqlite3LookupTable("./data/releases/2020/graph.sqlite3")
graph["l:noun:cancer"]
# Returns:
# ... [
# ... < List of all neighbors >
# ... ]
This process works by first making an sqlite3 connection to the graph database
file. By default, we assume that this database contains a table called
lookup_table that has the schema: (key:str, value:str). Values in this
database are all json-encoded. This means that calling graph[foo] looks up
the value associated with “foo” in the database, and parses whatever it finds
through json.loads(...).
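The lookup described above can be reproduced with nothing but the standard library. The following is a simplified stand-in, not the real Sqlite3LookupTable: the class name TinyLookupTable and the example neighbor names are invented, and an in-memory database replaces the file on disk.

```python
import json
import sqlite3


class TinyLookupTable:
    """Illustrative stand-in for a read-only, dictionary-like sqlite table."""

    def __init__(self, connection: sqlite3.Connection,
                 table_name: str = "lookup_table"):
        self._conn = connection
        self._table = table_name

    def __getitem__(self, key: str):
        # Find the json-encoded value for this key, then decode it.
        row = self._conn.execute(
            f"SELECT value FROM {self._table} WHERE key=?", (key,)
        ).fetchone()
        if row is None:
            raise KeyError(key)
        return json.loads(row[0])


# Build a tiny database that follows the (key:str, value:str) schema.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE lookup_table (key TEXT PRIMARY KEY, value TEXT)")
conn.execute(
    "INSERT INTO lookup_table VALUES (?, ?)",
    ("l:noun:cancer", json.dumps(["p:example_1", "p:example_2"])),
)

graph = TinyLookupTable(conn)
neighbors = graph["l:noun:cancer"]
```

Each access costs one indexed query plus one json.loads call, which is the per-lookup overhead that the training pipeline has to pay.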
This process is slow compared to most other operations in the training pipeline.
Each query has to check against the sqlite key index, which is stored on disk,
load the value, also stored on disk, and then parse the string. There are two
optimizations that make this faster: preloading and caching. Look into the API
documentation for more detail.
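As a rough picture of what these two optimizations mean, the sketch below applies them to a plain sqlite3 table. It is an assumption-laden illustration rather than the Agatha implementation: caching is done here with functools.lru_cache, preloading by reading the whole table into a dict, and every name in the block is made up.

```python
import json
import sqlite3
from functools import lru_cache

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE lookup_table (key TEXT PRIMARY KEY, value TEXT)")
conn.executemany(
    "INSERT INTO lookup_table VALUES (?, ?)",
    [("a", json.dumps([1, 2])), ("b", json.dumps([3]))],
)

queries_issued = []  # records every key that actually reaches sqlite


def lookup(key: str):
    queries_issued.append(key)
    row = conn.execute(
        "SELECT value FROM lookup_table WHERE key=?", (key,)
    ).fetchone()
    if row is None:
        raise KeyError(key)
    return json.loads(row[0])


# Caching: remember recently decoded values, so repeated keys skip the
# query and the json parsing.
cached_lookup = lru_cache(maxsize=1024)(lookup)
cached_lookup("a")
cached_lookup("a")  # served from the cache, no second query

# Preloading: pay the full cost once up front and keep everything in memory.
preloaded = {
    key: json.loads(value)
    for key, value in conn.execute("SELECT key, value FROM lookup_table")
}
```

Caching suits workloads that revisit a small set of keys, while preloading trades memory for speed when most of the table will be touched anyway.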
Installing Apex for AMP¶
Apex is a bit of a weird dependency, but it allows us to take advantage of some GPU optimizations that really cut back our memory footprint. Amp allows us to train using 16-bit precision, enabling more samples per batch, resulting in faster training times. However, note that if you install apex on a node that has one type of GPU, you will get an error if you try to train on a node with a different type. This means that you need to install this dependency on a training node with the appropriate GPU.
Warning: Apex is going to require a different version of GCC than we typically
use. If you’re on palmetto, you can run: module rm gcc/8.1.0; module load gcc/6.3.0
To install apex, first select a location such as ~/software to keep the files.
Next, download apex with git: git clone https://github.com/NVIDIA/apex.git
Finally, install the dependency with: pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
In full, run this:
# SSH into one of your training nodes with the correct GPU configuration.
# Make sure the appropriate modules are loaded.
# Assuming you would like to install apex in ~/software/apex
# Make the software dir if it's not present
mkdir -p ~/software
# Clone apex to ~/software/apex
git clone https://github.com/NVIDIA/apex.git ~/software/apex
# Enter Apex dir
cd ~/software/apex
# Run install
pip install -v \
--no-cache-dir \
--global-option="--cpp_ext" \
--global-option="--cuda_ext" \
./
Model Parameters¶
This is NOT an exhaustive list of the parameters present in the Agatha deep learning model, but is a full list of the parameters you need to know to train the model.
amp_level
: The optimization level used by NVIDIA. O1 works well. O2 causes some
convergence issues, so I would stay away from that.
default_root_dir
: The directory to store model training files.
dataloader-workers
: The number of processes used to generate predicate pairs, per-gpu. Too many
dataloader workers will cause an out-of-memory error. I’ve found 3 works well.
dim
: The number of dimensions of each input embedding. We use 512 in most cases.
This parameter affects the size of various internal parameters.
distributed_backend
: Used to specify how to communicate between GPUs. Ignored if using only one
GPU. Set to ddp
for distributed data parallel (even if only using gpus on
the same node).
embedding-dir
: The system path containing embedding HDF5
(*.h5
) files.
entity-db
: The system path to the entities .sqlite3
database.
gpus
: The specific GPUs enabled on this machine. GPUs are indexed starting from 0.
On a 2-GPU node, this should be set to 0,1
.
gradient_clip_val
: Gradients are clipped to this value before each step of gradient descent,
which limits how far a single step can move the parameters. We find that
setting this to 1.0 enables convergence.
graph-db
: The system path to the graph .sqlite3
database.
lr
: The learning rate. We use 0.02
because we’re cool.
margin
: The objective of the Agatha training procedure is Margin Ranking Loss.
This parameter determines how much higher the score of a positive sample needs
to be than the score of each negative sample. Setting this too high or low will
cause convergence issues. Remember that the model outputs values in the [0,1]
interval. We recommend 0.1.
max_epochs
: The maximum number of times to go through the training set.
negative-scramble-rate
: For each positive sample, how many negative scrambles (easy negative samples).
negative-swap-rate
: For each positive sample, how many negative swaps (hard negative samples).
neighbor-sample-rate
: When sampling a term-pair, we also sample each pair’s disjoint neighborhood.
This determines the maximum number of neighbors to include.
num_nodes
: This determines the number of MACHINES used to train the model.
num_sanity_val_steps
: Before starting training in earnest, we can optionally take a few validation
steps just to make sure everything has been configured properly. If this is
set above zero, we will run multiple validation steps on the newly
instantiated model. Recommended to run around 3
just to make sure everything
is working.
positives-per-batch
: Number of positive samples per batch per machine. More results in faster
training. Keep in mind that the true batch size will be num_nodes * positives-per-batch * (negative-scramble-rate + negative-swap-rate). When
running with 16-bit precision on V100 GPUs, we can handle around 80
positives per batch.
precision
: The number of bits per-float. Set to 16
for half-precision if you’ve
installed apex.
train_percent_check
: Limits the number of actual training examples per-epoch. If set to 0.1
then
one epoch will occur after every 10% of the training data. This is important
because we only checkpoint after every epoch, and don’t want to spend too
much time computing between checkpoints. We recommend that if you set this
value, you should increase max_epochs
accordingly.
transformer-dropout
: Within the transformer encoder of Agatha, there is a dropout parameter that
helps improve performance. Recommended you set this to 0.1
.
transformer-ff-dim
: The size of the feed-forward layer in the transformer encoder. Recommended
you set this to something between 2*dim and 4*dim.
transformer-heads
: The number of self-attention operations per self-attention block in the
transformer encoder. We use 16
.
transformer-layers
: The number of transformer encoder blocks. Each transformer-encoder contains
multi-headed self-attention and a feed-forward layer. More transformer encoder
layers should lead to higher quality, but will require additional training
time and memory.
val_percent_check
: Just like how train_percent_check
limits the number of training samples
per-epoch, val_percent_check
limits the number of validation samples
per-epoch. Recommended that if you set one, you set the other accordingly.
validation-fraction
: Before training, this parameter determines the training-validation split. A
higher value means less training data, but more consistent validation numbers.
Recommended you set to 0.2
.
warmup-steps
: Agatha uses a gradient warmup strategy to improve early convergence. This
parameter indicates the number of steps needed to reach the input learning
rate. For instance, if you specify a learning rate of 0.02 and 100 warmup
steps, at step 50 there will be an effective learning rate around 0.01. We
set this to 100, but higher can be better if you have the time.
weight-decay
: Each step, the weights of the agatha model will be moved towards zero at this
rate. This helps with later convergence and encourages sparsity. We set to
0.01.
weights_save_path
: The result root directory. Model checkpoints will be stored in
weights_save_path/checkpoints/version_X/
. Recommended that this is set to
the same value as default_root_dir
.
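Two of the parameters above involve a little arithmetic, sketched here in plain Python. The batch size function follows the formula given under positives-per-batch. The warmup function assumes a linear ramp, which is consistent with the "0.01 at step 50" example but is not confirmed as the exact schedule Agatha uses. The negative rates in the usage lines are placeholder values, not recommendations.

```python
def true_batch_size(
    num_nodes: int,
    positives_per_batch: int,
    negative_scramble_rate: int,
    negative_swap_rate: int,
) -> int:
    """Batch size formula quoted in the positives-per-batch description."""
    return num_nodes * positives_per_batch * (
        negative_scramble_rate + negative_swap_rate
    )


def warmup_learning_rate(step: int, lr: float, warmup_steps: int) -> float:
    """Effective learning rate, ASSUMING a linear ramp up to `lr`."""
    return lr * min(1.0, step / warmup_steps)


# Two machines, 80 positives per batch, placeholder negative rates.
batch_size = true_batch_size(2, 80, 10, 30)

# Halfway through a 100-step warmup toward lr=0.02.
halfway_lr = warmup_learning_rate(50, 0.02, 100)
# After warmup has finished, the full learning rate applies.
final_lr = warmup_learning_rate(500, 0.02, 100)
```

Running the numbers like this before launching a job is a quick way to check whether a chosen positives-per-batch is likely to fit in GPU memory once the negatives are included.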
Subset Data with Percent Check Flags¶
In the list of model flags are two that deserve more explanation:
train_percent_check and val_percent_check. When debugging the model
training process to ensure everything has been set up correctly, it is worthwhile
to run the training routine through a couple of epochs quickly. This will ensure
that the model output checkpoints are created properly. To do so, set
train_percent_check and val_percent_check to a very small value, such as
0.0001. Preferably, this will be small enough to complete an epoch in a couple
of minutes. Warning: if you set this value too low, you will filter out all of
the training data, which will create problems.
When you actually want to train the model, you still might want a modest
train_percent_check and val_percent_check. For instance, if the estimated
time per epoch is greater than a couple of hours, you might want more frequent
checkpointing. What we want to avoid is losing training time when an
unpredictable system failure causes an outage 40 hours into training, and we
haven’t created our first checkpoint yet. If this were to happen, we would
simply lose all of the progress we had made over nearly two days’ worth of
computational effort.
Therefore, I recommend setting these values to something that reduces the time per epoch to the single-digit hours. Keep in mind that when you reduce the training set, and especially when you reduce the validation set, you should expect poorer convergence in the final model. Therefore, if at all possible, it is recommended that you increase the number of training processes by adding more distributed workers. Once you have as many machines as you can afford, then tune this parameter.
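A quick back-of-the-envelope calculation helps when picking a percent check value. The sketch below assumes that the number of batches per epoch is the total number of batches multiplied by the fraction and truncated to an integer. The batch counts and seconds-per-batch figures are invented for illustration, so substitute the numbers reported by your own run.

```python
def batches_per_epoch(total_batches: int, percent_check: float) -> int:
    """Batches run per epoch, ASSUMING the fraction is truncated."""
    return int(total_batches * percent_check)


def hours_per_epoch(
    total_batches: int, percent_check: float, seconds_per_batch: float
) -> float:
    """Estimated wall-clock hours between checkpoints."""
    batches = batches_per_epoch(total_batches, percent_check)
    return batches * seconds_per_batch / 3600


# Debugging: a tiny fraction of a large dataset still leaves some batches.
debug_batches = batches_per_epoch(500_000, 0.0001)

# The same fraction on a small dataset filters out everything.
empty_batches = batches_per_epoch(5_000, 0.0001)

# Real training: 10% of the data at half a second per batch.
epoch_hours = hours_per_epoch(500_000, 0.1, 0.5)
```

If the estimate comes out at zero batches, raise the fraction. If it comes out well above single-digit hours, lower it and increase max_epochs to compensate.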
Running Distributed Training¶
In order to perform distributed training, you will need to ensure that each machine in your training cluster is configured with the same modules, libraries, and python versions.
On palmetto, and many HPC systems, this can be done with modules and Anaconda. I
recommend adding a section to your .bashrc for the sake of training Agatha
that loads all necessary modules and activates the appropriate conda
environment. As part of this configuration, you will need to set some
environment variables on each machine that help coordinate training. These are
MASTER_ADDR, MASTER_PORT, and NODE_RANK.
Distributed Training Env Variables¶
MASTER_ADDR
: Needs to be set to the hostname of one of your training nodes. This node will
coordinate the others.
MASTER_PORT
: Needs to be set to the same unused network port on every machine. Can be any
large number. We recommend: 12910.
NODE_RANK
: If you have N machines, then each machine needs a unique NODE_RANK
value
between 0 and N-1.
We recommend setting these values automatically using a nodefile. A nodefile
is just a text file containing the hostnames of each machine in your training
cluster. The first name will be the MASTER_ADDR and the NODE_RANK will
correspond to the order of names in the file.
If ~/.nodefile is the path to your nodefile, then you can set these values
with:
export NODEFILE=$HOME/.nodefile
export NODE_RANK=$(grep -n $HOSTNAME $NODEFILE | awk 'BEGIN{FS=":"}{print $1-1}')
export MASTER_ADDR=$(head -1 $NODEFILE)
export MASTER_PORT=12910
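For readers who prefer to see the same logic outside of shell, here is a hedged Python equivalent of the export lines above. The function name distributed_env and the hostnames are hypothetical. It assumes one hostname per line with no duplicates, and it matches hostnames exactly rather than by substring as grep does.

```python
from typing import Dict, List


def distributed_env(nodefile_lines: List[str], hostname: str,
                    port: int = 12910) -> Dict[str, str]:
    """Derive the three coordination variables from nodefile contents."""
    hosts = [line.strip() for line in nodefile_lines if line.strip()]
    return {
        # The first machine listed coordinates the others.
        "MASTER_ADDR": hosts[0],
        "MASTER_PORT": str(port),
        # Rank is the zero-based position of this machine in the file.
        "NODE_RANK": str(hosts.index(hostname)),
    }


# Made-up nodefile contents for a three-machine cluster.
nodefile = ["node0001\n", "node0002\n", "node0003\n"]
env_on_second_node = distributed_env(nodefile, "node0002")
```

Every machine computes the same MASTER_ADDR and MASTER_PORT but a different NODE_RANK, which is exactly the property distributed training needs.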
If you’re on palmetto, you’ve already got access to the nodefile referenced by
PBS_NODEFILE. However, only the first machine will have this variable set. I
recommend automatically copying this file to some shared location whenever it is
detected. You can do that in .bashrc by putting the following lines BEFORE
setting the NODE_RANK and MASTER_ADDR variables.
# If $PBS_NODEFILE is a file
if [[ -f $PBS_NODEFILE ]]; then
cp $PBS_NODEFILE ~/.nodefile
fi
Launching Training on Each Machine with Parallel¶
Once each machine is configured, you will then need to run the agatha training
module on each. We recommend parallel
to help you do this. Parallel runs a
given bash script multiple times simultaneously, and has some flags that let
us run a script on each machine in a nodefile.
Put simply, you can start distributed training with the following:
parallel \
--sshloginfile $NODEFILE \
--ungroup \
--nonall \
python3 -m agatha.ml.hypothesis_predictor \
... agatha args ...
To explain the parameters:
sshloginfile
: Specifies the set of machines to run training on. We use the NODEFILE
created in the previous step.
ungroup
: By default, parallel will wait until a process exits to show us its output.
This flag gives us output every time a process writes the newline character.
nonall
: This specifies that the following command (python3) will not need its
arguments set by parallel, and that we would like to run the following
command as-is, once per machine in $NODEFILE.
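If you would rather assemble the launch command from a Python script, the sketch below builds the same argument list as the parallel invocation shown earlier. The helper name, the nodefile path, and the example Agatha argument value are placeholders. The block only constructs and prints the command, it does not run it.

```python
import shlex
from typing import List


def build_parallel_command(nodefile: str, agatha_args: List[str]) -> List[str]:
    """Argument list that runs the training module once per machine."""
    return [
        "parallel",
        "--sshloginfile", nodefile,  # machines to run on
        "--ungroup",                 # stream output line by line
        "--nonall",                  # run the command as-is on every machine
        "python3", "-m", "agatha.ml.hypothesis_predictor",
        *agatha_args,
    ]


command = build_parallel_command(
    "/home/user/.nodefile",                      # placeholder path
    ["--weights_save_path", "/path/to/output"],  # placeholder arguments
)
printable = shlex.join(command)
print(printable)
```

Passing the resulting list to subprocess.run(command, check=True) would launch training, assuming parallel and passwordless ssh are available on the machine that runs the script.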
Palmetto-Specific Details¶
On palmetto, there are a number of modules that you will need to run Agatha. Here is what I load on every machine I use to train agatha:
# C++ compiler modules
module load gcc/8.3.0
module load mpc/0.8.1
# NVIDIA modules
module load cuda-toolkit/10.2.89
module load cuDNN/10.2.v7.6.5
module load nccl/2.6.4-1
# Needed for parallel
module load gnu-parallel
# Needed to work with HDF5 files
module load hdf5/1.10.5
# Needed to work with sqlite
module load sqlite/3.21.0
conda activate agatha
# Copy PBS_NODEFILE if it exists
if [[ -f $PBS_NODEFILE ]]; then
cp $PBS_NODEFILE ~/.nodefile
fi
# Set distributed training variables
export NODEFILE="$HOME/.nodefile"
export NODE_RANK=$(grep -n $HOSTNAME $NODEFILE | awk 'BEGIN{FS=":"}{print $1-1}')
export MASTER_ADDR=$(head -1 $NODEFILE)
export MASTER_PORT=12910
Loading the Trained Model¶
Once you’ve completed a few epochs of training, you will hopefully see a file
appear in
{weights_save_path}/lightning_logs/version_{#}/checkpoints/epoch={#}.ckpt
Of course, weights_save_path refers to whatever directory you listed in
--weights_save_path in the training command-line arguments. The version number
refers to the model version that pytorch-lightning deduces while training. Each
time you run the training script with the same checkpoint directory, this number
will increment. The epoch number refers to the epoch in which this model
last updated its checkpoint. Note here that the epoch number might be less than
the number of epochs you’ve actually computed, because we will only update the
checkpoint when the validation loss improves.
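Since the version and epoch numbers change from run to run, it can be handy to locate the newest checkpoint programmatically. The helper functions below are hypothetical conveniences built on pathlib, following the path pattern shown above. The demonstration creates empty placeholder files in a temporary directory instead of touching real checkpoints.

```python
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple


def list_checkpoints(weights_save_path: str) -> List[Tuple[int, int, Path]]:
    """Return (version, epoch, path) triples, sorted by version then epoch."""
    pattern = "lightning_logs/version_*/checkpoints/epoch=*.ckpt"
    found = []
    for path in Path(weights_save_path).glob(pattern):
        version = int(path.parent.parent.name.split("_")[1])
        epoch = int(path.stem.split("=")[1])
        found.append((version, epoch, path))
    return sorted(found)


def latest_checkpoint(weights_save_path: str) -> Optional[Path]:
    """Checkpoint with the highest version, then the highest epoch."""
    found = list_checkpoints(weights_save_path)
    return found[-1][2] if found else None


# Demonstration on a throwaway directory with empty placeholder files.
with tempfile.TemporaryDirectory() as root:
    for version, epoch in [(0, 1), (0, 4), (1, 2)]:
        ckpt_dir = (
            Path(root) / "lightning_logs" / f"version_{version}" / "checkpoints"
        )
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        (ckpt_dir / f"epoch={epoch}.ckpt").touch()
    newest = latest_checkpoint(root)
    newest_name = newest.relative_to(root).as_posix()
```

Note that the newest version is not necessarily the best model, so compare validation losses across versions before choosing which checkpoint to share.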
To load the checkpoint in python, use:
from agatha.ml.hypothesis_predictor import HypothesisPredictor
model = HypothesisPredictor.load_from_checkpoint( ... )
When you want to give this model to someone else, you often don’t want to give them the whole checkpoint. For this, you can use a simpler pytorch model format. The conversion is really simple:
checkpoint_path = ...
output_path = ...
import torch
from agatha.ml.hypothesis_predictor import HypothesisPredictor
# Load model from checkpoint
model = HypothesisPredictor.load_from_checkpoint(checkpoint_path)
# Save model in pytorch model format
torch.save(model, output_path)
The reason to do this is so future users can load your model with:
import torch
model = torch.load(...)
Running your new model.¶
Now that you have a model that you can load (either through
load_from_checkpoint or torch.load), you can run some examples to ensure that
everything has been configured properly. The simplest way to do this is to run a
little script like this in your python terminal:
from agatha.ml.hypothesis_predictor import HypothesisPredictor
model = HypothesisPredictor.load_from_checkpoint("...")
# - OR -
import torch
model = torch.load("...")
# Configure auxiliary data paths
model.configure_paths(
embedding_dir="/path/to/embeddings",
entity_db="/path/to/entities.sqlite3",
graph_db="/path/to/graph.sqlite3",
)
# Optional, if you're going to do a lot of queries.
model = model.eval()
model.preload()
# C0006826 is the term for Cancer
# C0040329 is the term for Tobacco
print(model.predict_from_terms([("C0006826", "C0040329")]))
If this outputs something like [0.9] (or any other float, if your model hasn’t
really been trained), then you’re good!