sbonner0/kge-jaxed

KGE-JAXed

Python JAX uv Ruff

This project started as a way for me to learn JAX by building something I already understood: Knowledge Graph Embedding models.

It is not intended to be a production-ready library, and I would not recommend it for serious work. If you want a mature, battle-tested package with broader coverage, better tooling, and a much more complete feature set, use PyKEEN.

KGE-JAXed is a small JAX/Flax NNX knowledge graph embedding library inspired by PyKEEN. The main goal is to keep the code easy to read and extend, while still following JAX-friendly patterns.

Project Overview

The package provides a compact implementation of classic KGE training workflows, and the codebase is kept readable enough to inspect, modify, and extend.

The main entry point is KGEPipeline, which handles dataset loading, model construction, training, evaluation, and checkpointing. Datasets are resolved through PyKEEN, so the project can focus on model, loss, sampling, and evaluation logic rather than dataset packaging.

Currently, the project supports:

  • Training classic KGE models such as TransE, DistMult, ComplEx, and RotatE
  • Running link-prediction evaluation with MR, MRR, and Hits@K metrics
  • Inspecting per-triple ranking results directly through ranks_df, rather than only seeing aggregated metrics
  • Saving and resuming training with checkpoints
  • Customizing embedding initializers, constrainers, and regularizers

The project is best suited for learning JAX, experimenting with KGE model internals, and working with a smaller codebase where the full training and evaluation path is easy to follow.

Installation

This repository uses uv. To install the package and associated dependencies, run:

uv sync

Quick start

The simplest way to train a model is:

from kge_jaxed import KGEPipeline

pipeline = KGEPipeline(
    model="transe",
    dataset="nations",
    loss_name="mrl",
    embedding_dim=128,
    negative_samples=1,
    learning_rate=1e-2,
    optimizer_name="adam",
    seed=42,
)

pipeline.train(
    epochs=100,
    log_every=10,
)

That will train a TransE model on the PyKEEN nations dataset.

Recommended first runs

  • Start with transe + mrl on nations if you want the quickest first run
  • Try rotate + nssa once you want something a bit heavier
  • Use smaller embedding sizes and fewer epochs first if you are just checking that everything works

Train and evaluate a model

Here is a fuller example that trains a model and then evaluates it on the test split:

from kge_jaxed import KGEPipeline

pipeline = KGEPipeline(
    model="transe",
    dataset="nations",
    loss_name="mrl",
    embedding_dim=128,
    negative_samples=1,
    learning_rate=1e-2,
    optimizer_name="adam",
    seed=42,
)

train_summary = pipeline.train(
    epochs=100,
    log_every=10,
)

metrics_df, ranks_df = pipeline.evaluate(
    split="test",
    filtered=True,
)

print(train_summary["train_losses"][-5:])
print(metrics_df)

evaluate() returns two dataframes:

  • metrics_df gives you the usual aggregate ranking metrics such as MRR, MR, and Hits@K.
  • ranks_df gives you the actual ranks and scores for each evaluated triple.
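
The filtered=True argument in the call above suggests the standard filtered ranking protocol, in which other known-true answers are excluded before the true entity is ranked. The following is a minimal NumPy sketch of that idea for tail prediction. The function name, the optimistic tie handling, and the higher-is-better score convention are assumptions made for illustration, not the library's code.

```python
import numpy as np


def filtered_tail_rank(scores, true_tail, known_tails):
    """Rank the true tail while ignoring other known-true tails.

    scores: 1-D array with one score per candidate entity (higher is better).
    true_tail: integer ID of the correct tail for this query.
    known_tails: IDs of every tail known to be true for the same (head, relation).
    """
    scores = np.asarray(scores, dtype=float).copy()
    # Mask every other true answer so it cannot outrank the target.
    others = np.array([t for t in known_tails if t != true_tail], dtype=int)
    scores[others] = -np.inf
    # Optimistic rank: 1 + number of candidates scoring strictly higher.
    return int(np.sum(scores > scores[true_tail])) + 1


# Toy scores for four candidate entities; entity 3 is the tail being evaluated.
scores = np.array([0.9, 0.2, 0.8, 0.5])
raw_rank = filtered_tail_rank(scores, true_tail=3, known_tails=[3])
filtered_rank = filtered_tail_rank(scores, true_tail=3, known_tails=[0, 3])
print(raw_rank, filtered_rank)
```

In this toy query, entity 0 is also a correct tail, so filtering stops it from counting against entity 3 and the rank improves by one.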

Per-triple ranking results

One of the main reasons I like this project is that evaluation does not stop at a single aggregate score. Aggregate metrics are useful for comparing runs, but they hide which triples the model gets right, which ones it struggles with, and whether the errors are mostly coming from head prediction or tail prediction.

ranks_df keeps that information visible. It gives you a row-by-row view of the evaluation results, with one row per evaluated triple.

It contains:

  • head, relation, tail for the evaluated triple
  • rank_head for head prediction rank
  • rank_tail for tail prediction rank
  • score_head for the score assigned to the true head query
  • score_tail for the score assigned to the true tail query

That makes it easy to answer more diagnostic questions:

  • Which triples are hardest for the model?
  • Is the model much better at head prediction or tail prediction?
  • Which examples are getting rank 1?
  • Which examples are failing badly even when the aggregate MRR looks reasonable?
  • Which relations seem to be causing the worst errors?

First, you can map the integer IDs back to readable entity and relation labels:

ranks_with_labels = ranks_df.assign(
    head_label=lambda df: df["head"].map(pipeline.dataset.id_to_entity),
    relation_label=lambda df: df["relation"].map(pipeline.dataset.id_to_relation),
    tail_label=lambda df: df["tail"].map(pipeline.dataset.id_to_entity),
)

print(ranks_with_labels.head())

Then you can inspect the worst-ranked examples directly:

# Hardest triples by tail rank
print(
    ranks_with_labels.sort_values("rank_tail", ascending=False)[
        ["head_label", "relation_label", "tail_label", "rank_tail", "score_tail"]
    ].head(10)
)

# Hardest triples by head rank
print(
    ranks_with_labels.sort_values("rank_head", ascending=False)[
        ["head_label", "relation_label", "tail_label", "rank_head", "score_head"]
    ].head(10)
)

Or summarize the errors by relation:

relation_summary = (
    ranks_with_labels.assign(avg_rank=lambda df: (df["rank_head"] + df["rank_tail"]) / 2)
    .groupby("relation_label")["avg_rank"]
    .mean()
    .sort_values(ascending=False)
)

print(relation_summary.head(10))

This makes evaluation much easier to debug. You can look beyond a single MRR number and inspect where the model is doing well or badly.

A typical metrics_df looks like this:

         head   tail    avg
mrr      ...    ...    ...
mr       ...    ...    ...
hits@1   ...    ...    ...
hits@3   ...    ...    ...
hits@10  ...    ...    ...

And ranks_df looks like this:

   head  relation  tail  rank_head  rank_tail  score_head  score_tail
0     0         1     2          1          3      ...        ...
1     4         2     7          8          1      ...        ...
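
The aggregate rows in metrics_df follow the usual definitions, so they can be recomputed from a column of ranks. The sketch below applies those textbook definitions to the two example rows shown above. The function name is hypothetical, and how the library derives the avg column is not shown here, so that column is left out.

```python
import numpy as np


def ranking_metrics(ranks, ks=(1, 3, 10)):
    """Compute MR, MRR, and Hits@K from a sequence of 1-based ranks."""
    ranks = np.asarray(ranks, dtype=float)
    metrics = {
        "mr": float(ranks.mean()),           # mean rank, lower is better
        "mrr": float((1.0 / ranks).mean()),  # mean reciprocal rank, higher is better
    }
    for k in ks:
        # Fraction of queries whose true entity is ranked in the top k.
        metrics[f"hits@{k}"] = float((ranks <= k).mean())
    return metrics


# Ranks taken from the two example rows of ranks_df above.
rank_head = [1, 8]
rank_tail = [3, 1]

head_metrics = ranking_metrics(rank_head)
tail_metrics = ranking_metrics(rank_tail)
print(head_metrics)
print(tail_metrics)
```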

Train with a different model

You can swap the model and loss configuration directly in the pipeline:

from kge_jaxed import KGEPipeline

pipeline = KGEPipeline(
    model="rotate",
    dataset="fb15k",
    loss_name="nssa",
    embedding_dim=256,
    negative_samples=16,
    learning_rate=1e-3,
    optimizer_name="adam",
    dataset_kwargs={
        "batch_size": 256,
        "shuffle": True,
    },
    loss_kwargs={
        "adversarial_temperature": 1.0,
        "margin": 6.0,
    },
    seed=42,
)

pipeline.train(
    epochs=50,
    log_every=5,
)

metrics_df, ranks_df = pipeline.evaluate(
    split="test",
    filtered=True,
    ks=(1, 3, 10),
)

print(metrics_df)
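
The negative_samples argument is not described in detail here. A common reading, assumed in the sketch below, is that each positive triple is paired with that many corrupted triples, where either the head or the tail is replaced by a random entity. The function name and the 50/50 head/tail split are illustrative assumptions and may differ from what the library does.

```python
import numpy as np


def corrupt_triples(triples, num_entities, negative_samples, rng):
    """Create negatives by replacing the head or tail with a random entity."""
    triples = np.asarray(triples)
    # np.repeat returns a copy: each positive appears once per requested negative.
    negatives = np.repeat(triples, negative_samples, axis=0)
    n = negatives.shape[0]
    # Decide per row whether to corrupt the head (column 0) or the tail (column 2).
    corrupt_head = rng.random(n) < 0.5
    random_entities = rng.integers(0, num_entities, size=n)
    negatives[corrupt_head, 0] = random_entities[corrupt_head]
    negatives[~corrupt_head, 2] = random_entities[~corrupt_head]
    return negatives


rng = np.random.default_rng(42)
# Toy (head, relation, tail) ID triples and a toy entity count.
positives = np.array([[0, 1, 2], [4, 2, 7]])
negatives = corrupt_triples(positives, num_entities=14, negative_samples=16, rng=rng)
print(negatives.shape)
```

Uniform corruption like this can occasionally produce a triple that is actually true. Whether such false negatives are filtered out is a design choice this sketch does not address.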

Save and resume training

You can also save checkpoints during training:

from kge_jaxed import KGEPipeline

pipeline = KGEPipeline(
    model="transe",
    dataset="nations",
    loss_name="mrl",
    embedding_dim=128,
)

pipeline.train(
    epochs=50,
    log_every=5,
    save_checkpoint_dir="checkpoints/transe-nations",
    save_every=10,
)

Then later:

from kge_jaxed import KGEPipeline

pipeline = KGEPipeline(
    model="transe",
    dataset="nations",
    loss_name="mrl",
    embedding_dim=128,
)

pipeline.load_checkpoint("checkpoints/transe-nations")
pipeline.train(epochs=20, log_every=5)

Datasets

This project uses PyKEEN datasets rather than re-implementing dataset download and packaging.

If you pass a dataset name such as "nations" or "fb15k" into the pipeline, it is resolved through PyKEEN. For the full list of available datasets, see the PyKEEN dataset documentation.

Supported models

These are the models currently registered in the library:

Model name  Class     Status
transe      TransE    Implemented
trans       TranS     Implemented
distmult    DistMult  Implemented
complex     ComplEx   Implemented
rotate      RotatE    Implemented
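
For reference, the textbook interaction functions for four of these models fit in a few lines. The sketch below uses NumPy and the standard formulations from the literature. It is not the library's implementation: the function names and the higher-is-better sign convention are assumptions, and TranS is omitted.

```python
import numpy as np


def transe_score(h, r, t, p=2):
    # TransE: a true triple should satisfy h + r ≈ t, so use negative distance.
    return -np.linalg.norm(h + r - t, ord=p, axis=-1)


def distmult_score(h, r, t):
    # DistMult: trilinear product with a diagonal relation matrix.
    return np.sum(h * r * t, axis=-1)


def complex_score(h, r, t):
    # ComplEx: real part of the trilinear product with the conjugated tail.
    return np.real(np.sum(h * r * np.conj(t), axis=-1))


def rotate_score(h, r, t):
    # RotatE: r has unit modulus, so h * r rotates each complex coordinate of h.
    return -np.linalg.norm(h * r - t, axis=-1)


s_transe = transe_score(np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([1.0, 1.0]))
s_distmult = distmult_score(np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0]))
s_complex = complex_score(np.array([1 + 1j]), np.array([1j]), np.array([-1 + 1j]))
s_rotate = rotate_score(np.array([1 + 0j]), np.array([1j]), np.array([1j]))
print(s_transe, s_distmult, s_complex, s_rotate)
```

The TransE and RotatE examples are perfect matches, so their scores reach the maximum of zero.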

Supported losses

Loss name  Meaning
mrl        Margin ranking loss
bce        Binary cross-entropy loss
softplus   Softplus loss
nssa       Self-adversarial negative sampling loss
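
To make two of these names concrete, here is a NumPy sketch of the margin ranking loss and the self-adversarial negative sampling loss in their commonly published forms. It assumes higher scores mean more plausible triples. The function names, reductions, and defaults are illustrative and may not match the library's loss classes.

```python
import numpy as np


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)).
    return -np.logaddexp(0.0, -x)


def margin_ranking_loss(pos_scores, neg_scores, margin=1.0):
    """pos_scores: (batch,), neg_scores: (batch, num_negatives)."""
    pos = np.asarray(pos_scores, dtype=float)[:, None]
    neg = np.asarray(neg_scores, dtype=float)
    # Penalize negatives that score within `margin` of their positive.
    return float(np.mean(np.maximum(0.0, margin - pos + neg)))


def nssa_loss(pos_scores, neg_scores, margin=6.0, adversarial_temperature=1.0):
    """Self-adversarial negative sampling loss with softmax-weighted negatives."""
    pos = np.asarray(pos_scores, dtype=float)
    neg = np.asarray(neg_scores, dtype=float)
    # Softmax over each row of negatives: harder negatives get larger weights.
    # In an autodiff framework these weights are normally treated as constants.
    logits = adversarial_temperature * neg
    logits = logits - logits.max(axis=1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=1, keepdims=True)
    pos_term = -log_sigmoid(margin + pos)
    neg_term = -np.sum(weights * log_sigmoid(-(margin + neg)), axis=1)
    return float(np.mean(pos_term + neg_term))


mrl_value = margin_ranking_loss([2.0], [[0.0, 3.0]], margin=1.0)
nssa_value = nssa_loss([0.0], [[0.0, 0.0]], margin=0.0)
print(mrl_value, nssa_value)
```

In the margin example only the second negative violates the margin, by 2, so the mean over both negatives is 1.0.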

Embedding initializers, constrainers, and regularizers

KGE models often need small operations around the embedding tables in addition to the interaction function itself.

  • Initializers choose the starting embedding values before training.
  • Constrainers project embedding values after initialization and after each optimizer step. They enforce hard constraints such as unit-length entity vectors or unit-modulus complex relation values.
  • Regularizers add a soft penalty to the training loss. They encourage smaller or simpler embeddings, but they do not strictly enforce a constraint.

These are configured separately for entity and relation embeddings through model_kwargs:

from kge_jaxed import KGEPipeline

pipeline = KGEPipeline(
    model="rotate",
    dataset="nations",
    loss_name="nssa",
    embedding_dim=128,
    model_kwargs={
        "entity_embedding_kwargs": {
            "embedding_init": "normal_norm",
            "embedding_init_kwargs": {"stddev": 0.1},
        },
        "relation_embedding_kwargs": {
            "embedding_init": "init_phases",
        },
        "relation_constrainer_kwargs": {
            "name": "unit_modulus",
        },
        "entity_regularizer_kwargs": {
            "name": "lp",
            "p": 2.0,
            "normalize": True,
            "weight": 0.01,
        },
    },
)

The most useful built-in initializer names are:

Name                                      Meaning
uniform, normal                           Random uniform or normal initialization
uniform_norm, normal_norm                 Random initialization followed by row-wise unit normalization
xavier, xavier_uniform, glorot_uniform    Glorot uniform initialization
xavier_uniform_norm, glorot_uniform_norm  Glorot uniform initialization followed by row-wise unit normalization
xavier_normal, glorot_normal              Glorot normal initialization
xavier_normal_norm, glorot_normal_norm    Glorot normal initialization followed by row-wise unit normalization
complex_normal, complex_uniform           Complex-valued initialization with independent real and imaginary parts
init_phases, complex_phases               Complex unit-modulus values, useful for RotatE relations
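
As an illustration of two entries in this table, the sketch below draws row-normalized normal values and complex unit-modulus values with NumPy. The function names and the phase range of [0, 2π) are assumptions. The library's initializers may sample differently.

```python
import numpy as np


def normal_then_normalize(rng, num_rows, dim, stddev=0.1):
    """Normal initialization followed by row-wise unit L2 normalization."""
    values = rng.normal(0.0, stddev, size=(num_rows, dim))
    return values / np.linalg.norm(values, axis=1, keepdims=True)


def random_phases(rng, num_rows, dim):
    """Complex values exp(i * theta) with theta uniform in [0, 2*pi)."""
    theta = rng.uniform(0.0, 2.0 * np.pi, size=(num_rows, dim))
    return np.exp(1j * theta)


rng = np.random.default_rng(0)
entity_table = normal_then_normalize(rng, num_rows=5, dim=8)
relation_table = random_phases(rng, num_rows=3, dim=8)

# Every entity row has unit length; every relation value has magnitude one.
print(np.linalg.norm(entity_table, axis=1))
print(np.abs(relation_table))
```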

The built-in constrainers are:

Name                               Meaning
unit_norm or normalize             Normalize each embedding row to unit L2 norm
max_norm or clamp_norm             Project rows with norm above max_value back to that norm
clip or clamp                      Clip individual values into [min_value, max_value]
non_negative                       Replace negative values with zero
unit_modulus or complex_normalize  Project each complex value to magnitude one
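
Each of these projections is a one-line array operation. The NumPy sketch below follows the descriptions in the table. The function signatures and the eps guard against division by zero are assumptions, not the library's API.

```python
import numpy as np


def unit_norm(x, eps=1e-12):
    # Rescale every row to unit L2 norm.
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, eps)


def max_norm(x, max_value=1.0, eps=1e-12):
    # Shrink only the rows whose norm exceeds max_value.
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    scale = np.minimum(1.0, max_value / np.maximum(norms, eps))
    return x * scale


def clip(x, min_value, max_value):
    # Clip each individual value into [min_value, max_value].
    return np.clip(x, min_value, max_value)


def non_negative(x):
    # Replace negative values with zero.
    return np.maximum(x, 0.0)


def unit_modulus(z, eps=1e-12):
    # Keep the phase of each complex value and set its magnitude to one.
    return z / np.maximum(np.abs(z), eps)


x = np.array([[3.0, 4.0], [0.3, 0.4]])
normalized = unit_norm(x)               # both rows become [0.6, 0.8]
clamped = max_norm(x, max_value=1.0)    # only the first row is rescaled
rotated = unit_modulus(np.array([3 + 4j]))
print(normalized, clamped, rotated)
```

Applying such a projection after each optimizer step, as described above, keeps the constraint exact rather than merely encouraged.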

The built-in regularizers are:

Name                         Meaning
lp                           Mean or sum of row-wise Lp norms
np, powersum, power_sum, n3  Mean or sum of sum(abs(x) ** p) per row; p=3 gives an N3-style penalty
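
The two penalties differ only in whether the p-th root is taken. The NumPy sketch below follows the table's descriptions. The reduce argument is a hypothetical stand-in for the "mean or sum" choice, and the sketch does not model the normalize option used in the configuration example above.

```python
import numpy as np


def lp_penalty(x, p=2.0, weight=0.01, reduce="mean"):
    """Weighted mean or sum of row-wise Lp norms."""
    row_norms = np.sum(np.abs(x) ** p, axis=-1) ** (1.0 / p)
    total = row_norms.mean() if reduce == "mean" else row_norms.sum()
    return float(weight * total)


def power_sum_penalty(x, p=3.0, weight=0.01, reduce="mean"):
    """Weighted mean or sum of sum(abs(x) ** p) per row, without the root."""
    row_sums = np.sum(np.abs(x) ** p, axis=-1)
    total = row_sums.mean() if reduce == "mean" else row_sums.sum()
    return float(weight * total)


x = np.array([[3.0, 4.0], [0.0, 0.0]])
lp_value = lp_penalty(x, p=2.0, weight=1.0)         # row norms are 5 and 0
n3_value = power_sum_penalty(x, p=3.0, weight=1.0)  # row sums are 91 and 0
print(lp_value, n3_value)
```

The returned value would be added to the training loss, so unlike a constrainer it only nudges the embeddings toward smaller values.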

Passing None for a model config uses that model's default. Passing an empty dict such as entity_constrainer_kwargs={} disables that default. This matters for models such as TransE, which constrains entity embeddings by default, and RotatE, which constrains relation embeddings to unit modulus by default.

Development

Linting and unit tests are configured through tox.

tox -e lint
tox -e test

About

KGE-JAXed: A simple knowledge graph embedding library created in JAX
