-
Notifications
You must be signed in to change notification settings - Fork 2
SDF-based dataset support #32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
5576839
95adba2
ad4432f
66c9659
be7278d
a50ce8f
b629f9c
430c72c
2b9c1e5
3a53cd2
db3d429
e5d870b
315f49c
de3193e
ea77f36
354225b
e78183e
d2bbad5
0ba4f11
6a85dd2
578c2ef
4abdcf3
885a405
7f6cb23
d2522f7
4b8124b
9ba3db5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| name: Pre-commit Check | ||
|
|
||
| on: | ||
| push: | ||
| branches: [main, master] | ||
| pull_request: | ||
|
|
||
| jobs: | ||
| pre-commit: | ||
| runs-on: ubuntu-latest | ||
| steps: | ||
| - uses: actions/checkout@v4 | ||
|
|
||
| - uses: actions/setup-python@v5 | ||
| with: | ||
| python-version: '3.10' | ||
|
|
||
| - name: Run pre-commit | ||
| uses: pre-commit/action@v3.0.1 | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,31 +1,19 @@ | ||||||
| repos: | ||||||
| - repo: https://github.com/psf/black | ||||||
| rev: "25.1.0" | ||||||
| hooks: | ||||||
| - id: black | ||||||
| - id: black-jupyter # for formatting jupyter-notebook | ||||||
| # Use `pre-commit autoupdate` to update all the hook. | ||||||
|
||||||
| # Use `pre-commit autoupdate` to update all the hook. | |
| # Use `pre-commit autoupdate` to update all the hooks. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -156,3 +156,4 @@ RING_164 | |
| RING_71 | ||
| RING_46 | ||
| orthoester | ||
| RING_55 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,3 +5,4 @@ | |
| 1 | ||
| 5 | ||
| 6 | ||
| 8 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,3 +3,4 @@ SINGLE | |
| AROMATIC | ||
| TRIPLE | ||
| DOUBLE | ||
| UNSPECIFIED | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,3 +9,5 @@ | |
| 7 | ||
| 10 | ||
| 12 | ||
| 11 | ||
| 9 | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |||||||||||||||||
| from typing import Optional | ||||||||||||||||||
|
|
||||||||||||||||||
| import pandas as pd | ||||||||||||||||||
| from chebai_graph.preprocessing.reader.augmented_reader import _AugmentorReader | ||||||||||||||||||
| import torch | ||||||||||||||||||
| import tqdm | ||||||||||||||||||
| from chebai.preprocessing.datasets.chebi import ( | ||||||||||||||||||
|
|
@@ -15,6 +16,7 @@ | |||||||||||||||||
| ) | ||||||||||||||||||
| from lightning_utilities.core.rank_zero import rank_zero_info | ||||||||||||||||||
| from torch_geometric.data.data import Data as GeomData | ||||||||||||||||||
| from rdkit import Chem | ||||||||||||||||||
|
|
||||||||||||||||||
| from chebai_graph.preprocessing.properties import ( | ||||||||||||||||||
| AllNodeTypeProperty, | ||||||||||||||||||
|
|
@@ -40,7 +42,7 @@ | |||||||||||||||||
| RandomFeatureInitializationReader, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| from .utils import resolve_property | ||||||||||||||||||
| from chebai_graph.preprocessing.datasets.utils import resolve_property | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class ChEBI50GraphData(ChEBIOver50): | ||||||||||||||||||
|
|
@@ -126,31 +128,53 @@ def enc_if_not_none(encode, value): | |||||||||||||||||
| else None | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| for property in self.properties: | ||||||||||||||||||
| if not os.path.isfile(self.get_property_path(property)): | ||||||||||||||||||
| rank_zero_info(f"Processing property {property.name}") | ||||||||||||||||||
| # read all property values first, then encode | ||||||||||||||||||
| rank_zero_info(f"\tReading property values of {property.name}...") | ||||||||||||||||||
| property_values = [ | ||||||||||||||||||
| self.reader.read_property(feat, property) | ||||||||||||||||||
| for feat in tqdm.tqdm(features) | ||||||||||||||||||
| ] | ||||||||||||||||||
| rank_zero_info(f"\tEncoding property values of {property.name}...") | ||||||||||||||||||
| property.encoder.on_start(property_values=property_values) | ||||||||||||||||||
| encoded_values = [ | ||||||||||||||||||
| enc_if_not_none(property.encoder.encode, value) | ||||||||||||||||||
| for value in tqdm.tqdm(property_values) | ||||||||||||||||||
| if any( | ||||||||||||||||||
| not os.path.isfile(self.get_property_path(property)) | ||||||||||||||||||
| for property in self.properties | ||||||||||||||||||
| ): | ||||||||||||||||||
| # augment molecule graph if possible (this would also happen for the properties if needed, but this avoids redundancy) | ||||||||||||||||||
| if isinstance(self.reader, _AugmentorReader): | ||||||||||||||||||
| returned_results = [] | ||||||||||||||||||
| for mol in features: | ||||||||||||||||||
| try: | ||||||||||||||||||
| r = self.reader._create_augmented_graph(mol) | ||||||||||||||||||
| except Exception: | ||||||||||||||||||
| r = None | ||||||||||||||||||
| returned_results.append(r) | ||||||||||||||||||
| mols = [ | ||||||||||||||||||
| augmented_mol[1] if augmented_mol is not None else None | ||||||||||||||||||
| for augmented_mol in returned_results | ||||||||||||||||||
| ] | ||||||||||||||||||
|
Comment on lines
+135
to
147
|
||||||||||||||||||
|
|
||||||||||||||||||
| torch.save( | ||||||||||||||||||
| [ | ||||||||||||||||||
| {property.name: torch.cat(feat), "ident": id} | ||||||||||||||||||
| for feat, id in zip(encoded_values, idents) | ||||||||||||||||||
| if feat is not None | ||||||||||||||||||
| ], | ||||||||||||||||||
| self.get_property_path(property), | ||||||||||||||||||
| ) | ||||||||||||||||||
| property.on_finish() | ||||||||||||||||||
| else: | ||||||||||||||||||
| mols = features | ||||||||||||||||||
|
|
||||||||||||||||||
| for property in self.properties: | ||||||||||||||||||
| if not os.path.isfile(self.get_property_path(property)): | ||||||||||||||||||
| rank_zero_info(f"Processing property {property.name}") | ||||||||||||||||||
| # read all property values first, then encode | ||||||||||||||||||
| rank_zero_info(f"\tReading property values of {property.name}...") | ||||||||||||||||||
| property_values = [ | ||||||||||||||||||
| self.reader.read_property(mol, property) | ||||||||||||||||||
| if mol is not None | ||||||||||||||||||
| else None | ||||||||||||||||||
| for mol in tqdm.tqdm(mols) | ||||||||||||||||||
| ] | ||||||||||||||||||
| rank_zero_info(f"\tEncoding property values of {property.name}...") | ||||||||||||||||||
| property.encoder.on_start(property_values=property_values) | ||||||||||||||||||
| encoded_values = [ | ||||||||||||||||||
| enc_if_not_none(property.encoder.encode, value) | ||||||||||||||||||
| for value in tqdm.tqdm(property_values) | ||||||||||||||||||
| ] | ||||||||||||||||||
| assert len(encoded_values) == len(idents) == len(features) | ||||||||||||||||||
| torch.save( | ||||||||||||||||||
| [ | ||||||||||||||||||
| {property.name: torch.cat(feat), "ident": id} | ||||||||||||||||||
| for feat, id in zip(encoded_values, idents) | ||||||||||||||||||
| if feat is not None | ||||||||||||||||||
| ], | ||||||||||||||||||
| self.get_property_path(property), | ||||||||||||||||||
| ) | ||||||||||||||||||
| property.on_finish() | ||||||||||||||||||
|
|
||||||||||||||||||
| @property | ||||||||||||||||||
| def processed_properties_dir(self) -> str: | ||||||||||||||||||
|
|
@@ -185,20 +209,23 @@ def _after_setup(self, **kwargs) -> None: | |||||||||||||||||
| super()._after_setup(**kwargs) | ||||||||||||||||||
|
|
||||||||||||||||||
| def _preprocess_smiles_for_pred( | ||||||||||||||||||
| self, idx, smiles: str, model_hparams: Optional[dict] = None | ||||||||||||||||||
| ) -> dict: | ||||||||||||||||||
| self, idx, raw_data: str | Chem.Mol, model_hparams: Optional[dict] = None | ||||||||||||||||||
| ) -> Optional[dict]: | ||||||||||||||||||
| """Preprocess prediction data.""" | ||||||||||||||||||
| # Add dummy labels because the collate function requires them. | ||||||||||||||||||
| # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`, | ||||||||||||||||||
| # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty. | ||||||||||||||||||
| result = self.reader.to_data( | ||||||||||||||||||
| {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]} | ||||||||||||||||||
| {"id": f"smiles_{idx}", "features": raw_data, "labels": [1, 2]} | ||||||||||||||||||
| ) | ||||||||||||||||||
| # _read_data can return an updated version of the input data (e.g. augmented molecule dict) along with the GeomData object | ||||||||||||||||||
| if isinstance(result["features"], tuple): | ||||||||||||||||||
| result["features"], raw_data = result["features"] | ||||||||||||||||||
| if result is None or result["features"] is None: | ||||||||||||||||||
| return None | ||||||||||||||||||
|
Comment on lines
+222
to
225
|
||||||||||||||||||
| if isinstance(result["features"], tuple): | |
| result["features"], raw_data = result["features"] | |
| if result is None or result["features"] is None: | |
| return None | |
| if result is None or result["features"] is None: | |
| return None | |
| if isinstance(result["features"], tuple): | |
| result["features"], raw_data = result["features"] |
Copilot
AI
Mar 8, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GraphPropertiesMixIn._merge_props_into_base now supports row["features"] being a tuple, but other mixins in this same module (e.g., the augmented-graph property mixins) still treat row["features"] as a GeomData and read masks from it. If the processed dataset stores (GeomData, augmented_molecule) (as _AugmentorReader._read_data now returns), those asserts/mask accesses will break at load time. Consider normalizing row["features"] to GeomData earlier (or updating the remaining call sites to unwrap tuples consistently).
Copilot
AI
Mar 8, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file appears to include an executable __main__ block that runs dataset preparation/setup. For a library/module file this is easy to trigger accidentally and can lead to long-running side effects when users run the module directly. Consider removing it from the PR, or moving this into a dedicated CLI/script (or a reproducible notebook) if it’s meant as a debugging entrypoint.
| if __name__ == "__main__": | |
| dataset = ChEBI25_WFGE_WGN_AsPerNodeType(chebi_version=248, subset="3_STAR") | |
| dataset.prepare_data() | |
| dataset.setup() |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,6 +8,8 @@ | |||||||||||||||||||||||||||||||||||||||
| from rdkit.Chem import AllChem | ||||||||||||||||||||||||||||||||||||||||
| from rdkit.Chem import MolToSmiles as m2s | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| from chebi_utils.sdf_extractor import _sanitize_molecule | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| from .fg_constants import ELEMENTS, FLAG_NO_FG | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1911,7 +1913,11 @@ def get_structure(mol): | |||||||||||||||||||||||||||||||||||||||
| structure[frag] = {"atom": atom_idx, "is_ring_fg": False} | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Convert fragment SMILES back to mol to match with fused ring atom indices | ||||||||||||||||||||||||||||||||||||||||
| frag_mol = Chem.MolFromSmiles(frag) | ||||||||||||||||||||||||||||||||||||||||
| frag_mol = Chem.MolFromSmiles(frag, sanitize=False) | ||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||
| frag_mol = _sanitize_molecule(frag_mol) | ||||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||||||
| frag_rings = frag_mol.GetRingInfo().AtomRings() | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1917
to
1921
|
||||||||||||||||||||||||||||||||||||||||
| try: | |
| frag_mol = _sanitize_molecule(frag_mol) | |
| except Exception: | |
| pass | |
| frag_rings = frag_mol.GetRingInfo().AtomRings() | |
| if frag_mol is not None: | |
| try: | |
| frag_mol = _sanitize_molecule(frag_mol) | |
| except Exception: | |
| # If sanitization fails, fall back to the original fragment molecule | |
| pass | |
| # If molecule creation or sanitization failed, treat as having no rings | |
| if frag_mol is None: | |
| frag_rings = () | |
| else: | |
| try: | |
| frag_rings = frag_mol.GetRingInfo().AtomRings() | |
| except Exception: | |
| frag_rings = () |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The workflow YAML is invalid due to indentation:
`steps:` must contain a properly-indented list (the `- uses:` entries should be indented under `steps:`). As written, GitHub Actions will fail to parse this workflow.