This project started as a way for me to learn JAX by building something I already understood: Knowledge Graph Embedding models.
It is not intended to be a production-ready library, and I would not recommend it for serious work. If you want a mature, battle-tested package with broader coverage, better tooling, and a much more complete feature set, use PyKEEN.
KGE-JAXed is a small, learning-oriented knowledge graph embedding library built with JAX and Flax NNX and inspired by PyKEEN. It provides a compact implementation of classic KGE training workflows. The main goal is to keep the code easy to read, inspect, modify, and extend while still following JAX-friendly patterns.
The main entry point is `KGEPipeline`, which handles dataset loading, model construction, training, evaluation, and checkpointing. Datasets are resolved through PyKEEN, so the project can focus on model, loss, sampling, and evaluation logic rather than dataset packaging.
Currently, the project supports:
- Training classic KGE models such as `TransE`, `DistMult`, `ComplEx`, and `RotatE`
- Running link-prediction evaluation with MR, MRR, and Hits@K metrics
- Inspecting per-triple ranking results directly through `ranks_df`, rather than only seeing aggregated metrics
- Saving and resuming training with checkpoints
- Customizing embedding initializers, constrainers, and regularizers
The project is best suited for learning JAX, experimenting with KGE model internals, and working with a smaller codebase where the full training and evaluation path is easy to follow.
This repository uses uv. To install the package and associated dependencies, run:
```bash
uv sync
```

The simplest way to train a model is:

```python
from kge_jaxed import KGEPipeline

pipeline = KGEPipeline(
    model="transe",
    dataset="nations",
    loss_name="mrl",
    embedding_dim=128,
    negative_samples=1,
    learning_rate=1e-2,
    optimizer_name="adam",
    seed=42,
)

pipeline.train(
    epochs=100,
    log_every=10,
)
```

That will train a TransE model on the PyKEEN `nations` dataset.
- Start with `transe` + `mrl` on `nations` if you want the quickest first run
- Try `rotate` + `nssa` once you want something a bit heavier
- Use smaller embedding sizes and fewer epochs first if you are just checking that everything works
Here is a fuller example that trains a model and then evaluates it on the test split:
```python
from kge_jaxed import KGEPipeline

pipeline = KGEPipeline(
    model="transe",
    dataset="nations",
    loss_name="mrl",
    embedding_dim=128,
    negative_samples=1,
    learning_rate=1e-2,
    optimizer_name="adam",
    seed=42,
)

train_summary = pipeline.train(
    epochs=100,
    log_every=10,
)

metrics_df, ranks_df = pipeline.evaluate(
    split="test",
    filtered=True,
)

print(train_summary["train_losses"][-5:])
print(metrics_df)
```

`evaluate()` returns two dataframes:
- `metrics_df` gives you the usual aggregate ranking metrics, such as MRR, MR, and Hits@K.
- `ranks_df` gives you the actual ranks and scores for each evaluated triple.
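The text above does not spell out how a single rank is obtained. The sketch below shows the standard filtered protocol for one tail query `(h, r, ?)`, written in NumPy so it runs on its own. `filtered_tail_rank` is a hypothetical helper, not part of KGE-JAXed, and counting only strictly higher scores is an optimistic tie policy that the library may not share. Head ranks work the same way with the roles of head and tail swapped.

```python
import numpy as np


def filtered_tail_rank(scores, true_tail, known_tails=()):
    """1-based rank of the true tail for one (h, r, ?) query (hypothetical helper).

    scores: one score per candidate entity (higher = more plausible).
    true_tail: index of the entity that completes the evaluated triple.
    known_tails: indices of other tails t' with (h, r, t') in train/valid/test.
        They are removed from the competition (the "filtered" setting);
        an empty collection gives the raw setting.
    """
    scores = np.asarray(scores, dtype=float)
    competitors = np.ones(scores.shape[0], dtype=bool)
    competitors[np.asarray(list(known_tails), dtype=int)] = False  # drop other true answers
    competitors[true_tail] = False  # the true tail does not compete with itself
    true_score = scores[true_tail]
    # Optimistic ties: only strictly higher scores push the true tail down.
    return 1 + int(np.sum(scores[competitors] > true_score))


scores = [0.9, 0.5, 0.8, 0.7]
raw_rank = filtered_tail_rank(scores, true_tail=1)  # -> 4: entities 0, 2, 3 score higher
filtered_rank = filtered_tail_rank(scores, true_tail=1, known_tails={0, 2})  # -> 2
```

In the raw setting, other correct answers can push the true tail down the ranking; filtering removes that penalty, which is why filtered metrics are usually the ones reported.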
One of the main reasons I like this project is that evaluation does not stop at a single aggregate score. Aggregate metrics are useful for comparing runs, but they hide which triples the model gets right, which ones it struggles with, and whether the errors are mostly coming from head prediction or tail prediction.
`ranks_df` keeps that information visible. It gives you a row-by-row view of the evaluation results, with one row per evaluated triple.
It contains:
- `head`, `relation`, `tail` for the evaluated triple
- `rank_head` for the head prediction rank
- `rank_tail` for the tail prediction rank
- `score_head` for the score assigned to the true head when predicting heads
- `score_tail` for the score assigned to the true tail when predicting tails
That makes it easy to answer more diagnostic questions:
- Which triples are hardest for the model?
- Is the model much better at head prediction or tail prediction?
- Which examples are getting rank 1?
- Which examples are failing badly even when the aggregate MRR looks reasonable?
- Which relations seem to be causing the worst errors?
First, you can map the integer IDs back to readable entity and relation labels:
```python
ranks_with_labels = ranks_df.assign(
    head_label=lambda df: df["head"].map(pipeline.dataset.id_to_entity),
    relation_label=lambda df: df["relation"].map(pipeline.dataset.id_to_relation),
    tail_label=lambda df: df["tail"].map(pipeline.dataset.id_to_entity),
)

print(ranks_with_labels.head())
```

Then you can inspect the worst-ranked examples directly:
```python
# Hardest triples by tail rank
print(
    ranks_with_labels.sort_values("rank_tail", ascending=False)[
        ["head_label", "relation_label", "tail_label", "rank_tail", "score_tail"]
    ].head(10)
)

# Hardest triples by head rank
print(
    ranks_with_labels.sort_values("rank_head", ascending=False)[
        ["head_label", "relation_label", "tail_label", "rank_head", "score_head"]
    ].head(10)
)
```

Or summarize the errors by relation:
```python
relation_summary = (
    ranks_with_labels.assign(avg_rank=lambda df: (df["rank_head"] + df["rank_tail"]) / 2)
    .groupby("relation_label")["avg_rank"]
    .mean()
    .sort_values(ascending=False)
)

print(relation_summary.head(10))
```

This makes evaluation much easier to debug: you can look beyond a single MRR number and inspect where the model is doing well or badly.
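The head-versus-tail and rank-1 questions listed earlier can be answered with a few more pandas operations. This sketch uses a small stand-in frame with toy values and the documented `ranks_df` column names (score columns omitted) so it runs on its own; with a real run you would apply the same expressions to the dataframe returned by `pipeline.evaluate()`. The threshold of 5 for a "bad" rank is an arbitrary choice for illustration.

```python
import pandas as pd

# Stand-in for the dataframe returned by pipeline.evaluate(); toy values only.
ranks_df = pd.DataFrame(
    {
        "head": [0, 4, 3],
        "relation": [1, 2, 1],
        "tail": [2, 7, 5],
        "rank_head": [1, 8, 1],
        "rank_tail": [3, 1, 1],
    }
)

# Head vs tail prediction: mean reciprocal rank on each side.
side_mrr = {
    "head": (1.0 / ranks_df["rank_head"]).mean(),
    "tail": (1.0 / ranks_df["rank_tail"]).mean(),
}

# Triples the model ranks first for both head and tail prediction.
rank_1_both = ranks_df[(ranks_df["rank_head"] == 1) & (ranks_df["rank_tail"] == 1)]

# Triples that fail badly on at least one side (threshold chosen arbitrarily).
bad_on_either_side = ranks_df[ranks_df[["rank_head", "rank_tail"]].max(axis=1) > 5]

print(side_mrr)
print(rank_1_both)
print(bad_on_either_side)
```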
A typical `metrics_df` looks like this:

```text
         head  tail  avg
mrr       ...   ...  ...
mr        ...   ...  ...
hits@1    ...   ...  ...
hits@3    ...   ...  ...
hits@10   ...   ...  ...
```
And `ranks_df` looks like this:

```text
   head  relation  tail  rank_head  rank_tail  score_head  score_tail
0     0         1     2          1          3         ...         ...
1     4         2     7          8          1         ...         ...
```
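Assuming `metrics_df` is derived from the `rank_head` and `rank_tail` columns in the usual way, the aggregation can be sketched as below. `rank_metrics` and `metrics_table` are hypothetical helpers. The sketch computes `avg` over the pooled head and tail ranks; for MR, MRR, and Hits@K this equals the mean of the two per-side values, because every triple contributes exactly one rank to each side. The library's own aggregation may differ in detail.

```python
import numpy as np


def rank_metrics(ranks, ks=(1, 3, 10)):
    """MR, MRR and Hits@K for a 1-D collection of 1-based ranks."""
    ranks = np.asarray(ranks, dtype=float)
    metrics = {"mr": float(np.mean(ranks)), "mrr": float(np.mean(1.0 / ranks))}
    for k in ks:
        metrics[f"hits@{k}"] = float(np.mean(ranks <= k))  # fraction of ranks within top k
    return metrics


def metrics_table(rank_head, rank_tail, ks=(1, 3, 10)):
    """Per-side metrics plus an "avg" column computed over both sides' ranks."""
    return {
        "head": rank_metrics(rank_head, ks),
        "tail": rank_metrics(rank_tail, ks),
        "avg": rank_metrics(np.concatenate([rank_head, rank_tail]), ks),
    }


table = metrics_table([1, 2], [4, 1])
# table["head"]["mrr"] == 0.75, table["tail"]["mrr"] == 0.625, table["avg"]["mrr"] == 0.6875
```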
You can swap the model and loss configuration directly in the pipeline:
```python
from kge_jaxed import KGEPipeline

pipeline = KGEPipeline(
    model="rotate",
    dataset="fb15k",
    loss_name="nssa",
    embedding_dim=256,
    negative_samples=16,
    learning_rate=1e-3,
    optimizer_name="adam",
    dataset_kwargs={
        "batch_size": 256,
        "shuffle": True,
    },
    loss_kwargs={
        "adversarial_temperature": 1.0,
        "margin": 6.0,
    },
    seed=42,
)

pipeline.train(
    epochs=50,
    log_every=5,
)

metrics_df, ranks_df = pipeline.evaluate(
    split="test",
    filtered=True,
    ks=(1, 3, 10),
)

print(metrics_df)
```

You can also save checkpoints during training:
```python
from kge_jaxed import KGEPipeline

pipeline = KGEPipeline(
    model="transe",
    dataset="nations",
    loss_name="mrl",
    embedding_dim=128,
)

pipeline.train(
    epochs=50,
    log_every=5,
    save_checkpoint_dir="checkpoints/transe-nations",
    save_every=10,
)
```

Then later:
```python
from kge_jaxed import KGEPipeline

pipeline = KGEPipeline(
    model="transe",
    dataset="nations",
    loss_name="mrl",
    embedding_dim=128,
)

pipeline.load_checkpoint("checkpoints/transe-nations")
pipeline.train(epochs=20, log_every=5)
```

This project uses PyKEEN datasets rather than re-implementing dataset download and packaging.
If you pass a dataset name such as "nations" or "fb15k" into the pipeline, it is resolved through PyKEEN. For the full list of available datasets, see the PyKEEN dataset documentation.
These are the models currently registered in the library:
| Model name | Class | Status |
|---|---|---|
| `transe` | `TransE` | Implemented |
| `trans` | `TranS` | Implemented |
| `distmult` | `DistMult` | Implemented |
| `complex` | `ComplEx` | Implemented |
| `rotate` | `RotatE` | Implemented |
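For reference, the textbook score functions of four of the registered models can be sketched as follows, using the convention that a higher score means a more plausible triple. These are NumPy illustrations of the published formulas, not KGE-JAXed's interaction code: the exact norm, sign, and scaling choices inside the library are assumptions here, and TranS is left out because its formulation is more involved. The same calls exist in `jax.numpy`.

```python
import numpy as np

# Convention assumed here: higher score = more plausible triple.


def transe_score(h, r, t, p=2):
    # Negative Lp distance between the translated head h + r and the tail t.
    return -np.linalg.norm(h + r - t, ord=p, axis=-1)


def distmult_score(h, r, t):
    # Trilinear product <h, r, t>; symmetric in h and t.
    return np.sum(h * r * t, axis=-1)


def complex_score(h, r, t):
    # Real part of <h, r, conj(t)> with complex-valued embeddings.
    return np.real(np.sum(h * r * np.conj(t), axis=-1))


def rotate_score(h, r, t):
    # r is expected to have unit modulus, so h * r rotates h in the complex plane.
    # Negative sum of element-wise moduli, as in the original RotatE code;
    # other implementations use an L2 norm instead.
    return -np.sum(np.abs(h * r - t), axis=-1)
```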
The available loss names are:

| Loss name | Meaning |
|---|---|
| `mrl` | Margin ranking loss |
| `bce` | Binary cross-entropy loss |
| `softplus` | Softplus loss |
| `nssa` | Self-adversarial negative sampling loss |
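Two of these losses can be sketched in NumPy as below, assuming higher scores mean more plausible triples and one row of negative scores per positive. The parameter names `margin` and `adversarial_temperature` mirror the `loss_kwargs` in the RotatE example above, but the functions themselves are hypothetical, and the library's reduction and sign conventions may differ. `bce` and `softplus` are not shown.

```python
import numpy as np


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)).
    return -np.logaddexp(0.0, -x)


def margin_ranking_loss(pos, neg, margin=1.0):
    """pos: (batch,), neg: (batch, num_negatives). Hinge on margin - pos + neg."""
    return float(np.mean(np.maximum(0.0, margin - pos[:, None] + neg)))


def nssa_loss(pos, neg, margin=6.0, adversarial_temperature=1.0):
    """Self-adversarial negative sampling loss in the style of Sun et al. (RotatE)."""
    # Softmax over each row of negatives: higher-scoring (harder) negatives get more weight.
    logits = adversarial_temperature * neg
    logits = logits - logits.max(axis=1, keepdims=True)
    weights = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    # In an autodiff framework these weights are treated as constants,
    # e.g. by wrapping them in jax.lax.stop_gradient.
    pos_term = -log_sigmoid(margin + pos)
    neg_term = -np.sum(weights * log_sigmoid(-margin - neg), axis=1)
    return float(np.mean(pos_term + neg_term))
```

The self-adversarial weights focus the gradient on negatives the model currently scores highly, which is why `nssa` typically benefits from larger `negative_samples` values than `mrl`.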
KGE models often need small operations around the embedding tables in addition to the interaction function itself.
- Initializers choose the starting embedding values before training.
- Constrainers project embedding values after initialization and after each optimizer step. They enforce hard constraints such as unit-length entity vectors or unit-modulus complex relation values.
- Regularizers add a soft penalty to the training loss. They encourage smaller or simpler embeddings, but they do not strictly enforce a constraint.
These are configured separately for entity and relation embeddings through `model_kwargs`:
```python
from kge_jaxed import KGEPipeline

pipeline = KGEPipeline(
    model="rotate",
    dataset="nations",
    loss_name="nssa",
    embedding_dim=128,
    model_kwargs={
        "entity_embedding_kwargs": {
            "embedding_init": "normal_norm",
            "embedding_init_kwargs": {"stddev": 0.1},
        },
        "relation_embedding_kwargs": {
            "embedding_init": "init_phases",
        },
        "relation_constrainer_kwargs": {
            "name": "unit_modulus",
        },
        "entity_regularizer_kwargs": {
            "name": "lp",
            "p": 2.0,
            "normalize": True,
            "weight": 0.01,
        },
    },
)
```

The most useful built-in initializer names are:
| Name | Meaning |
|---|---|
| `uniform`, `normal` | Random uniform or normal initialization |
| `uniform_norm`, `normal_norm` | Random initialization followed by row-wise unit normalization |
| `xavier`, `xavier_uniform`, `glorot_uniform` | Glorot uniform initialization |
| `xavier_uniform_norm`, `glorot_uniform_norm` | Glorot uniform initialization followed by row-wise unit normalization |
| `xavier_normal`, `glorot_normal` | Glorot normal initialization |
| `xavier_normal_norm`, `glorot_normal_norm` | Glorot normal initialization followed by row-wise unit normalization |
| `complex_normal`, `complex_uniform` | Complex-valued initialization with independent real and imaginary parts |
| `init_phases`, `complex_phases` | Complex unit-modulus values, useful for RotatE relations |
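As an illustration of two rows of this table, the sketch below assumes that `normal_norm` draws Gaussian values and then rescales each row to unit L2 norm, and that `init_phases` produces `exp(i * theta)` with phases drawn uniformly from `[0, 2*pi)`. The function names match the table entries, but these are standalone NumPy sketches with hypothetical signatures, not the library's initializers.

```python
import numpy as np


def normal_norm(rng, num, dim, stddev=1.0):
    # Gaussian draw followed by row-wise unit L2 normalization.
    x = rng.normal(0.0, stddev, size=(num, dim))
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def init_phases(rng, num, dim):
    # Unit-modulus complex values exp(i * theta) with uniformly random phases.
    theta = rng.uniform(0.0, 2.0 * np.pi, size=(num, dim))
    return np.exp(1j * theta)


rng = np.random.default_rng(42)
entity_table = normal_norm(rng, num=10, dim=8, stddev=0.1)
relation_table = init_phases(rng, num=5, dim=8)
```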
The built-in constrainers are:
| Name | Meaning |
|---|---|
| `unit_norm` or `normalize` | Normalize each embedding row to unit L2 norm |
| `max_norm` or `clamp_norm` | Project rows with norm above `max_value` back to that norm |
| `clip` or `clamp` | Clip individual values into `[min_value, max_value]` |
| `non_negative` | Replace negative values with zero |
| `unit_modulus` or `complex_normalize` | Project each complex value to magnitude one |
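The constrainers in this table are simple projections. A NumPy sketch of each is shown below, assuming row-wise operation over the last axis; the function names and default arguments are illustrative, not KGE-JAXed's API. Per the description earlier, such a projection would run once after initialization and again after every optimizer step.

```python
import numpy as np


def unit_norm(x, eps=1e-12):
    # Scale every row to unit L2 norm (eps guards against division by zero).
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def max_norm(x, max_value=1.0, eps=1e-12):
    # Rows with norm above max_value are rescaled to exactly max_value; others are untouched.
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x * np.minimum(1.0, max_value / np.maximum(norms, eps))


def clip(x, min_value=-1.0, max_value=1.0):
    # Element-wise clipping into [min_value, max_value].
    return np.clip(x, min_value, max_value)


def non_negative(x):
    # Replace negative entries with zero.
    return np.maximum(x, 0.0)


def unit_modulus(z, eps=1e-12):
    # Divide each complex entry by its magnitude; the phase is preserved.
    return z / np.maximum(np.abs(z), eps)


# Hypothetical use after a plain gradient step:
# entity_table = unit_norm(entity_table - learning_rate * grads)
```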
The built-in regularizers are:
| Name | Meaning |
|---|---|
| `lp` | Mean or sum of row-wise Lp norms |
| `np`, `powersum`, `power_sum`, `n3` | Mean or sum of `sum(abs(x) ** p)` per row; `p=3` gives an N3-style penalty |
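A minimal NumPy sketch of the two regularizer families follows, assuming the penalty is computed over an embedding table, multiplied by `weight`, and added to the task loss. The `normalize` option from the configuration example is left out because its exact effect is not described here, and the helper names are hypothetical.

```python
import numpy as np


def lp_penalty(x, p=2.0, reduction="mean"):
    # Row-wise Lp norm, then mean or sum over rows (np.abs also handles complex rows).
    values = np.sum(np.abs(x) ** p, axis=-1) ** (1.0 / p)
    return float(values.mean() if reduction == "mean" else values.sum())


def power_sum_penalty(x, p=3.0, reduction="mean"):
    # sum(|x| ** p) per row without the outer root; p=3 gives an N3-style penalty.
    values = np.sum(np.abs(x) ** p, axis=-1)
    return float(values.mean() if reduction == "mean" else values.sum())


def regularized_loss(task_loss, embeddings, weight=0.01, p=2.0):
    # Assumed combination: weighted soft penalty added on top of the training loss.
    return task_loss + weight * lp_penalty(embeddings, p=p)


x = np.array([[3.0, 4.0], [0.0, 0.0]])
# Row L2 norms are 5 and 0, so lp_penalty(x) is their mean, 2.5.
```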
Passing `None` for a model config uses that model's default. Passing an empty dict such as `entity_constrainer_kwargs={}` disables that default. This matters for models such as TransE, which constrains entity embeddings by default, and RotatE, which constrains relation embeddings to unit modulus by default.
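The `None` versus `{}` rule can be sketched as a small resolution function. `MODEL_DEFAULTS` and `resolve_component` are hypothetical, and the `unit_norm` default shown for TransE is an assumption (the text only says that TransE constrains entity embeddings by default); the real defaults live inside the library.

```python
from typing import Optional

# Hypothetical defaults table; the actual per-model defaults live inside KGE-JAXed.
MODEL_DEFAULTS = {
    "transe": {"entity_constrainer_kwargs": {"name": "unit_norm"}},
    "rotate": {"relation_constrainer_kwargs": {"name": "unit_modulus"}},
}


def resolve_component(model: str, key: str, value: Optional[dict]) -> Optional[dict]:
    """Return the kwargs to build a component with, or None for "no component"."""
    if value is None:
        # Not specified: fall back to the model's default (which may be nothing).
        return MODEL_DEFAULTS.get(model, {}).get(key)
    if not value:
        # An explicit empty dict disables the component, even if a default exists.
        return None
    # An explicit, non-empty config always wins.
    return value
```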
Linting and unit tests are configured through tox.
```bash
tox -e lint
tox -e test
```