How to Run Your Agatha Model

The first set to run your agatha model is to either train your own, or download a pretrained model. Other help pages describe how to do this in more detail.

To begin, lets assume you downloaded Agatha 2020. When extracted, you should see the following directory structure. This is consistent with other pretrained models.

model_release/
  model.pt
  predicate_embeddings/
    embeddings_*.v5.h5
  predicate_entities.sqlite3
  predicate_graph.sqlite3

TL;DR

Once you have the necessary files, you can run the model with the following snippet:

# Example paths, change according to your download location
model_path = "model_release/model.pt"
embedding_dir = "model_release/predicate_embeddings"
entity_db_path = "model_release/predicate_entities.sqlite3"
graph_db_path = "model_release/predicate_graph.sqlite3"

# Load the model
import torch
model = torch.load(model_path)

# Configure auxilary data paths
model.configure_paths(
  embedding_dir=embedding_dir
  entity_db=entity_db_path,
  graph_db=graph_db_path,
)

# Now you're ready to run some predictions!
# C0006826 : Cancer
# C0040329 : Tobacco
model.predict_from_terms([("C0006826", "C0040329")])
>>> [0.9946276545524597]

# Speed up by moving the model to GPU
model = model.cuda()

# If we would like to run thousands of queries, we want to load everything
# before the query process. This takes a while, and is optional.
model.preload()

# Get a list of valid terms (slow if preload not called beforehand)
from agatha.util.entity_types import is_umls_term_type
valid_terms = list(filter(is_umls_term_type, model.graph.keys()))
len(valid_terms)
>>> 278139

# Run a big batch of queries
from random import choice
rand_term_pairs =  [
 (choice(valid_terms), choice(valid_terms))
 for _ in range(100)
]
scores = model.predict_from_terms(rand_term_pairs, batch_size=64)
# Now, scores[i] corresponds to rand_term_pairs[i]

Required Files

model.pt : This is the model file. You can load this with torch.load. Behind the scenes, this will startup the appropriate Agatha module.

predicate_embeddings : This directory contains graph embeddings for each entity needed to make predicate predictions using the Agatha model.

predicate_entities.sqlite3 : This database contains embedding metadata for each entity managed by the Agatha model. This database is loaded with agatha.util.sqlite3_lookup.Sqlite3LookupTable.

predicate_graph.sqlite3 : This database contains term-predicate relationships for each entity managed by the Agatha model. This database is loaded with agatha.util.sqlite3_lookup.Sqlite3LookupTable.

Bulk queries

In order to run bulk queries efficiently, you will want to run:

model.cuda()
model.preload()

The first command, model.cuda() moves the weight matrices to the GPU. The second command, model.preload() modes all graph and embedding information into RAM. This way, each request for an embedding, of which we will do tens of times per query, can be handled without a slow lookup in the storage system. Warning, expect this to take around 30 GB of RAM to start. Additionally, Agatha employs caching intermediate values that will increase the memory usage as the query process goes on.

Batch size

When running model.predict_from_terms the optional batch_size parameter can be used to improve GPU usage. Set this value to an integer greater than one to pack more than one query within each call to the GPU. You may need to experiment to find a value that is large, but doesn’t exceed GPU memory.