Changes from all commits (27 commits)
5576839 add option for passing mol object, gentle error handling (sfluegel05, Feb 25, 2026)
95adba2 update reader for mol objects (sfluegel05, Feb 25, 2026)
ad4432f fix read property (sfluegel05, Feb 25, 2026)
66c9659 make read_properties more flexible, don't recalculate properties (sfluegel05, Feb 25, 2026)
be7278d avoid redundant property calculation (sfluegel05, Feb 25, 2026)
a50ce8f fix function call (sfluegel05, Feb 25, 2026)
b629f9c catch mol processing errors (sfluegel05, Feb 25, 2026)
430c72c make sure that property cache is mutable (sfluegel05, Feb 25, 2026)
2b9c1e5 fix data loading (sfluegel05, Feb 28, 2026)
3a53cd2 add chebi25 dataset (sfluegel05, Feb 28, 2026)
db3d429 use chebi-utils (sfluegel05, Feb 28, 2026)
e5d870b fix function name (sfluegel05, Mar 1, 2026)
315f49c fix function name (sfluegel05, Mar 2, 2026)
de3193e fix import (sfluegel05, Mar 2, 2026)
ea77f36 add new tokens (Mar 2, 2026)
354225b catch none (sfluegel05, Mar 2, 2026)
e78183e Merge branch 'feature/sdf-support' of https://github.com/ChEB-AI/pyth… (sfluegel05, Mar 2, 2026)
d2bbad5 fix assertion error (sfluegel05, Mar 2, 2026)
0ba4f11 only calculate entended molecule graph if needed, sanitize molecule w… (sfluegel05, Mar 2, 2026)
6a85dd2 update default values in configs (sfluegel05, Mar 2, 2026)
578c2ef reformat w/ ruff (sfluegel05, Mar 2, 2026)
4abdcf3 update fg tokens (sfluegel05, Mar 6, 2026)
885a405 property failure handling (sfluegel05, Mar 6, 2026)
7f6cb23 assert idents match property values (sfluegel05, Mar 6, 2026)
d2522f7 format (sfluegel05, Mar 6, 2026)
4b8124b use ruff for precommit (sfluegel05, Mar 6, 2026)
9ba3db5 ruff format (sfluegel05, Mar 6, 2026)
26 changes: 0 additions & 26 deletions .github/workflows/lint.yml

This file was deleted.

19 changes: 19 additions & 0 deletions .github/workflows/pre-commit.yml
@@ -0,0 +1,19 @@
name: Pre-commit Check

on:
  push:
    branches: [main, master]
  pull_request:

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
  - uses: actions/checkout@v4

  - uses: actions/setup-python@v5
    with:
      python-version: '3.10'

  - name: Run pre-commit
    uses: pre-commit/action@v3.0.1
Comment on lines +12 to +19
Copilot AI Mar 8, 2026

The workflow YAML is invalid due to indentation: `steps:` must contain a properly-indented list (the `- uses:` entries should be indented under `steps:`). As written, GitHub Actions will fail to parse this workflow.

Suggested change:
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Run pre-commit
        uses: pre-commit/action@v3.0.1
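The structural error flagged above can be caught before pushing by simply parsing the workflow. A minimal sketch, assuming PyYAML is installed (GitHub Actions uses its own parser, so this only catches parse/structure errors, not Actions-specific semantics):

```python
import yaml  # PyYAML; third-party, not part of the stdlib

# The workflow with the step entries properly indented under `steps:`.
WORKFLOW = """
name: Pre-commit Check

on:
  push:
    branches: [main, master]
  pull_request:

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Run pre-commit
        uses: pre-commit/action@v3.0.1
"""

doc = yaml.safe_load(WORKFLOW)
# With correct indentation, `steps` parses as a list of three mappings.
steps = doc["jobs"]["pre-commit"]["steps"]
assert isinstance(steps, list) and len(steps) == 3
```

Note that PyYAML (YAML 1.1) resolves the bare `on` key to the boolean `True`, which is why the check above goes through `jobs` instead.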
44 changes: 16 additions & 28 deletions .pre-commit-config.yaml
@@ -1,31 +1,19 @@
repos:
  - repo: https://github.com/psf/black
    rev: "25.1.0"
    hooks:
      - id: black
      - id: black-jupyter # for formatting jupyter-notebook
# Use `pre-commit autoupdate` to update all the hook.
Copilot AI Mar 8, 2026

Minor grammar: "update all the hook" should be "update all the hooks".

Suggested change:
# Use `pre-commit autoupdate` to update all the hooks.

  - repo: https://github.com/pycqa/isort
    rev: 5.13.2
    hooks:
      - id: isort
        name: isort (python)
        args: ["--profile=black"]
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version. https://docs.astral.sh/ruff/integrations/#pre-commit
    rev: v0.14.11
    hooks:
      # Run the linter.
      - id: ruff-check
        args: [ --fix ]
      # Run the formatter.
      - id: ruff-format

  - repo: https://github.com/asottile/seed-isort-config
    rev: v2.2.0
    hooks:
      - id: seed-isort-config

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.12.2
    hooks:
      - id: ruff
        args: [--fix]
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v6.0.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
18 changes: 9 additions & 9 deletions chebai_graph/models/dynamic_gni.py
@@ -85,9 +85,9 @@ def __init__(self, config: dict[str, Any], **kwargs: Any):
        print("Using complete randomness: ", self.complete_randomness)

        if not self.complete_randomness:
            assert (
                "pad_node_features" in config or "pad_edge_features" in config
            ), "Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False"
            assert "pad_node_features" in config or "pad_edge_features" in config, (
                "Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False"
            )
            self.pad_node_features = (
                int(config["pad_node_features"])
                if config.get("pad_node_features") is not None
@@ -112,9 +112,9 @@ def __init__(self, config: dict[str, Any], **kwargs: Any):
                f"in each forward pass."
            )

        assert (
            self.pad_node_features > 0 or self.pad_edge_features > 0
        ), "'pad_node_features' or 'pad_edge_features' must be positive integers"
        assert self.pad_node_features > 0 or self.pad_edge_features > 0, (
            "'pad_node_features' or 'pad_edge_features' must be positive integers"
        )

        self.resgated: BasicGNN = ResGatedModel(
            in_channels=self.in_channels,
@@ -182,9 +182,9 @@ def forward(self, batch: dict[str, Any]) -> Tensor:
            )
            new_edge_attr = torch.cat((graph_data.edge_attr, pad_edge), dim=1)

        assert (
            new_x is not None and new_edge_attr is not None
        ), "Feature initialization failed"
        assert new_x is not None and new_edge_attr is not None, (
            "Feature initialization failed"
        )
        out = self.resgated(
            x=new_x.float(),
            edge_index=graph_data.edge_index.long(),
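The assert reshuffling in this file is purely stylistic (ruff moves the parentheses from the condition to the message), but the distinction matters: the parentheses must wrap either the condition alone or the message alone. Wrapping both turns the statement into an assert on a two-element tuple, which is always truthy and therefore never fires. A minimal sketch:

```python
def check_strict(config: dict) -> None:
    # Style used after the reformat: bare condition, parenthesized message.
    assert "pad_node_features" in config or "pad_edge_features" in config, (
        "Missing 'pad_node_features' or 'pad_edge_features' in config"
    )


def check_broken(config: dict) -> None:
    # Pitfall: parentheses around (condition, message) build a tuple,
    # and a non-empty tuple is always truthy, so this never raises.
    assert ("pad_node_features" in config, "Missing 'pad_node_features'")


failed = False
try:
    check_strict({})
except AssertionError:
    failed = True  # fires as expected on an empty config

check_broken({})  # silently passes despite the missing key
```

Recent CPython versions emit a SyntaxWarning for the tuple form, which is exactly the class of bug the parenthesized-message style avoids.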
Expand Up @@ -156,3 +156,4 @@ RING_164
RING_71
RING_46
orthoester
RING_55
Expand Up @@ -5,3 +5,4 @@
1
5
6
8
Expand Up @@ -3,3 +3,4 @@ SINGLE
AROMATIC
TRIPLE
DOUBLE
UNSPECIFIED
Expand Up @@ -9,3 +9,5 @@
7
10
12
11
9
124 changes: 87 additions & 37 deletions chebai_graph/preprocessing/datasets/chebi.py
@@ -5,6 +5,7 @@
from typing import Optional

import pandas as pd
from chebai_graph.preprocessing.reader.augmented_reader import _AugmentorReader
import torch
import tqdm
from chebai.preprocessing.datasets.chebi import (
@@ -15,6 +16,7 @@
)
from lightning_utilities.core.rank_zero import rank_zero_info
from torch_geometric.data.data import Data as GeomData
from rdkit import Chem

from chebai_graph.preprocessing.properties import (
AllNodeTypeProperty,
@@ -40,7 +42,7 @@
RandomFeatureInitializationReader,
)

from .utils import resolve_property
from chebai_graph.preprocessing.datasets.utils import resolve_property


class ChEBI50GraphData(ChEBIOver50):
@@ -126,31 +128,53 @@ def enc_if_not_none(encode, value):
else None
)

for property in self.properties:
if not os.path.isfile(self.get_property_path(property)):
rank_zero_info(f"Processing property {property.name}")
# read all property values first, then encode
rank_zero_info(f"\tReading property values of {property.name}...")
property_values = [
self.reader.read_property(feat, property)
for feat in tqdm.tqdm(features)
]
rank_zero_info(f"\tEncoding property values of {property.name}...")
property.encoder.on_start(property_values=property_values)
encoded_values = [
enc_if_not_none(property.encoder.encode, value)
for value in tqdm.tqdm(property_values)
if any(
not os.path.isfile(self.get_property_path(property))
for property in self.properties
):
# augment molecule graph if possible (this would also happen for the properties if needed, but this avoids redundancy)
if isinstance(self.reader, _AugmentorReader):
returned_results = []
for mol in features:
try:
r = self.reader._create_augmented_graph(mol)
except Exception:
r = None
returned_results.append(r)
mols = [
augmented_mol[1] if augmented_mol is not None else None
for augmented_mol in returned_results
]
Comment on lines +135 to 147
Copilot AI Mar 8, 2026

In _setup_properties, when self.reader is an _AugmentorReader you call self.reader._create_augmented_graph(mol) for each entry in features. For existing SMILES-based datasets, features is typically a str, but _create_augmented_graph expects an RDKit Chem.Mol and will fail (caught and turned into None), resulting in mols being all None and property extraction being skipped. Consider using self.reader.read_property(...) directly, or converting SMILES to Chem.Mol via the reader's _smiles_to_mol before calling _create_augmented_graph.

torch.save(
[
{property.name: torch.cat(feat), "ident": id}
for feat, id in zip(encoded_values, idents)
if feat is not None
],
self.get_property_path(property),
)
property.on_finish()
else:
mols = features

for property in self.properties:
if not os.path.isfile(self.get_property_path(property)):
rank_zero_info(f"Processing property {property.name}")
# read all property values first, then encode
rank_zero_info(f"\tReading property values of {property.name}...")
property_values = [
self.reader.read_property(mol, property)
if mol is not None
else None
for mol in tqdm.tqdm(mols)
]
rank_zero_info(f"\tEncoding property values of {property.name}...")
property.encoder.on_start(property_values=property_values)
encoded_values = [
enc_if_not_none(property.encoder.encode, value)
for value in tqdm.tqdm(property_values)
]
assert len(encoded_values) == len(idents) == len(features)
torch.save(
[
{property.name: torch.cat(feat), "ident": id}
for feat, id in zip(encoded_values, idents)
if feat is not None
],
self.get_property_path(property),
)
property.on_finish()

@property
def processed_properties_dir(self) -> str:
@@ -185,20 +209,23 @@ def _after_setup(self, **kwargs) -> None:
super()._after_setup(**kwargs)

def _preprocess_smiles_for_pred(
self, idx, smiles: str, model_hparams: Optional[dict] = None
) -> dict:
self, idx, raw_data: str | Chem.Mol, model_hparams: Optional[dict] = None
) -> Optional[dict]:
"""Preprocess prediction data."""
# Add dummy labels because the collate function requires them.
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
# which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
result = self.reader.to_data(
{"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
{"id": f"smiles_{idx}", "features": raw_data, "labels": [1, 2]}
)
# _read_data can return an updated version of the input data (e.g. augmented molecule dict) along with the GeomData object
if isinstance(result["features"], tuple):
result["features"], raw_data = result["features"]
if result is None or result["features"] is None:
return None
Comment on lines +222 to 225
Copilot AI Mar 8, 2026

_preprocess_smiles_for_pred checks isinstance(result["features"], tuple) before verifying that result is not None. Since the next line explicitly handles result is None, this can currently raise a TypeError/AttributeError when to_data(...) returns None. Move the result is None / result["features"] is None guard before dereferencing result["features"].

Suggested change:
        if result is None or result["features"] is None:
            return None
        if isinstance(result["features"], tuple):
            result["features"], raw_data = result["features"]
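The guard-ordering issue flagged above generalizes: an optional result must be None-checked before it is subscripted. A small self-contained sketch of the corrected ordering (the to_data stub here is hypothetical, standing in for the reader's real method, which may return None on unparsable input):

```python
from typing import Optional


def to_data(raw) -> Optional[dict]:
    # Hypothetical stand-in for reader.to_data: None signals failure,
    # otherwise features may be a (data, extras) tuple.
    if raw is None:
        return None
    return {"features": (raw, {"augmented": True})}


def preprocess(raw) -> Optional[dict]:
    result = to_data(raw)
    # Guard first: subscripting result["features"] before this check
    # raises TypeError when to_data returned None.
    if result is None or result["features"] is None:
        return None
    # Only now is it safe to unpack the optional (data, extras) tuple.
    if isinstance(result["features"], tuple):
        result["features"], _extras = result["features"]
    return result


print(preprocess(None))               # None, no exception
print(preprocess("CCO")["features"])  # CCO
```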
for property in self.properties:
property.encoder.eval = True
property_value = self.reader.read_property(smiles, property)
property_value = self.reader.read_property(raw_data, property)
if property_value is None or len(property_value) == 0:
encoded_value = None
else:
@@ -250,7 +277,9 @@ def __init__(
assert (
distribution is not None
and distribution in RandomFeatureInitializationReader.DISTRIBUTIONS
), "When using padding for features, a valid distribution must be specified."
), (
"When using padding for features, a valid distribution must be specified."
)
self.distribution = distribution
if self.pad_node_features:
print(
@@ -278,7 +307,12 @@ def _merge_props_into_base(self, row: pd.Series | dict) -> GeomData:
Returns:
A GeomData object with merged features.
"""
geom_data = row["features"]
if isinstance(row["features"], tuple):
geom_data, _ = row[
"features"
] # ignore additional returned data from _read_data (e.g. augmented molecule dict)
else:
geom_data = row["features"]
Comment on lines +310 to +315
Copilot AI Mar 8, 2026

GraphPropertiesMixIn._merge_props_into_base now supports row["features"] being a tuple, but other mixins in this same module (e.g., the augmented-graph property mixins) still treat row["features"] as a GeomData and read masks from it. If the processed dataset stores (GeomData, augmented_molecule) (as _AugmentorReader._read_data now returns), those asserts/mask accesses will break at load time. Consider normalizing row["features"] to GeomData earlier (or updating the remaining call sites to unwrap tuples consistently).
assert isinstance(geom_data, GeomData)
edge_attr = geom_data.edge_attr
x = geom_data.x
@@ -538,6 +572,10 @@ def _merge_props_into_base(
geom_data = row["features"]
if geom_data is None:
return None
if isinstance(geom_data, tuple):
geom_data = geom_data[
0
] # ignore additional returned data from _read_data (e.g. augmented molecule dict)
assert isinstance(geom_data, GeomData)

is_atom_node = geom_data.is_atom_node
@@ -550,9 +588,9 @@
edge_attr = geom_data.edge_attr

# Initialize node feature matrix
assert (
max_len_node_properties is not None
), "Maximum len of node properties should not be None"
assert max_len_node_properties is not None, (
"Maximum len of node properties should not be None"
)
x = torch.zeros((num_nodes, max_len_node_properties))

# Track column offsets for each node type
@@ -607,9 +645,9 @@ def _merge_props_into_base(
raise TypeError(f"Unsupported property type: {type(property).__name__}")

total_used_columns = max(atom_offset, fg_offset, graph_offset)
assert (
total_used_columns <= max_len_node_properties
), f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}"
assert total_used_columns <= max_len_node_properties, (
f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}"
)

return GeomData(
x=x,
@@ -805,3 +843,15 @@ class ChEBI50_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver50):

class ChEBI100_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver100):
READER = AtomFGReader_WithFGEdges_WithGraphNode


class ChEBI25_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOverX):
READER = AtomFGReader_WithFGEdges_WithGraphNode

THRESHOLD = 25


if __name__ == "__main__":
dataset = ChEBI25_WFGE_WGN_AsPerNodeType(chebi_version=248, subset="3_STAR")
dataset.prepare_data()
dataset.setup()
Comment on lines +852 to +857
Copilot AI Mar 8, 2026

This file appears to include an executable __main__ block that runs dataset preparation/setup. For a library/module file this is easy to trigger accidentally and can lead to long-running side effects when users run the module directly. Consider removing it from the PR, or moving this into a dedicated CLI/script (or a reproducible notebook) if it's meant as a debugging entrypoint.

Suggested change: delete the __main__ block.
Expand Up @@ -8,6 +8,8 @@
from rdkit.Chem import AllChem
from rdkit.Chem import MolToSmiles as m2s

from chebi_utils.sdf_extractor import _sanitize_molecule

from .fg_constants import ELEMENTS, FLAG_NO_FG


@@ -1911,7 +1913,11 @@ def get_structure(mol):
structure[frag] = {"atom": atom_idx, "is_ring_fg": False}

# Convert fragment SMILES back to mol to match with fused ring atom indices
frag_mol = Chem.MolFromSmiles(frag)
frag_mol = Chem.MolFromSmiles(frag, sanitize=False)
try:
frag_mol = _sanitize_molecule(frag_mol)
except Exception:
pass
frag_rings = frag_mol.GetRingInfo().AtomRings()
Comment on lines +1917 to 1921
Copilot AI Mar 8, 2026

frag_mol can be None when Chem.MolFromSmiles(frag, sanitize=False) fails, or it can become None if sanitization fails. In both cases the next line calls frag_mol.GetRingInfo(), which will raise an AttributeError. Add a None check (and decide whether to mark is_ring_fg as False or skip the fragment) before accessing ring info.

Suggested change:
        if frag_mol is not None:
            try:
                frag_mol = _sanitize_molecule(frag_mol)
            except Exception:
                # If sanitization fails, fall back to the original fragment molecule
                pass
        # If molecule creation or sanitization failed, treat as having no rings
        if frag_mol is None:
            frag_rings = ()
        else:
            try:
                frag_rings = frag_mol.GetRingInfo().AtomRings()
            except Exception:
                frag_rings = ()
if len(frag_rings) >= 1:
structure[frag]["is_ring_fg"] = True