From 65479cdd7aaf60b99e954b9c7dea59f770f64b42 Mon Sep 17 00:00:00 2001 From: mart-r Date: Wed, 25 Jun 2025 15:04:30 +0100 Subject: [PATCH] CU-8699hj2dx Revamp component initialisation (CogStack/MedCAT2#95) * CU-8699hj2dx: Initial changes to remove config-based init args and hardcode it (WIP) * CU-8699hj2dx: Update/fix a registration test * CU-8699hj2dx: Some minor keyword argument renaming * CU-8699hj2dx: Fix RelCAT tests (init) * CU-8699hj2dx: Update Transformers NER to work when loading models * CU-8699hj2dx: Fix DeID deserialising test * CU-8699hj2dx: Fix MetaCAT init * CU-8699hj2dx: Fix RelCAT init/load * CU-8699hj2dx: Remove unused import * CU-8699hj2dx: Add doc string regarding keyword arguments when manually deserialising * CU-8699hj2dx: Update pipeline with notes regarding keyword arguments for manual deserialisation --- medcat-v2/medcat/components/addons/addons.py | 46 +++++-- .../components/addons/meta_cat/meta_cat.py | 37 +++-- .../addons/relation_extraction/rel_cat.py | 46 ++++--- .../linking/context_based_linker.py | 17 +-- .../components/linking/no_action_linker.py | 16 +-- .../linking/two_step_context_based_linker.py | 15 +- .../medcat/components/ner/dict_based_ner.py | 15 +- .../components/ner/trf/transformers_ner.py | 41 ++++-- .../medcat/components/ner/vocab_based_ner.py | 15 +- .../components/normalizing/normalizer.py | 17 +-- medcat-v2/medcat/components/tagging/tagger.py | 15 +- medcat-v2/medcat/components/types.py | 87 ++++++------ medcat-v2/medcat/config/config.py | 28 ---- medcat-v2/medcat/pipeline/pipeline.py | 79 +++++------ medcat-v2/medcat/storage/serialisables.py | 21 +++ medcat-v2/medcat/storage/serialisers.py | 7 + .../medcat/tokenizing/regex_impl/tokenizer.py | 8 +- .../tokenizing/spacy_impl/tokenizers.py | 13 +- medcat-v2/medcat/tokenizing/tokenizers.py | 27 ++-- medcat-v2/medcat/utils/default_args.py | 130 ------------------ .../medcat/utils/legacy/convert_rel_cat.py | 10 +- .../relation_extraction/test_rel_cat.py | 8 +- 
.../test_rel_cat_in_model_pack.py | 8 +- .../tests/components/addons/test_addons.py | 36 +++-- medcat-v2/tests/components/helper.py | 14 +- .../linking/test_context_based_linker.py | 1 + .../components/ner/test_vocab_based_ner.py | 1 + .../ner/trf/test_transformers_ner.py | 3 +- .../components/normalizing/test_normalizer.py | 1 + .../tests/components/tagging/test_tagger.py | 1 + .../tests/components/test_registration.py | 51 +++---- medcat-v2/tests/components/test_types.py | 16 ++- .../tokenizing/spacy_impl/test_tokenizers.py | 10 +- medcat-v2/tests/utils/ner/test_deid.py | 4 +- 34 files changed, 354 insertions(+), 490 deletions(-) delete mode 100644 medcat-v2/medcat/utils/default_args.py diff --git a/medcat-v2/medcat/components/addons/addons.py b/medcat-v2/medcat/components/addons/addons.py index 0b3f05524..d32c4ce73 100644 --- a/medcat-v2/medcat/components/addons/addons.py +++ b/medcat-v2/medcat/components/addons/addons.py @@ -1,8 +1,11 @@ -from typing import Callable, Protocol, Any, runtime_checkable +from typing import Callable, Protocol, Any, runtime_checkable, Optional from medcat.components.types import BaseComponent, MutableEntity from medcat.utils.registry import Registry from medcat.config.config import ComponentConfig +from medcat.cdb import CDB +from medcat.vocab import Vocab +from medcat.tokenizing.tokenizers import BaseTokenizer @runtime_checkable @@ -19,9 +22,15 @@ def addon_type(self) -> str: def is_core(self) -> bool: return False + @classmethod + def get_folder_name_for_addon_and_name( + cls, addon_type: str, name: str) -> str: + return (cls.NAME_PREFIX + addon_type + + cls.NAME_SPLITTER + name) + def get_folder_name(self) -> str: - return (self.NAME_PREFIX + self.addon_type + - self.NAME_SPLITTER + self.name) + return self.get_folder_name_for_addon_and_name( + self.addon_type, self.name) @property def full_name(self) -> str: @@ -36,11 +45,15 @@ def get_output_key_val(self, ent: MutableEntity pass +AddonClass = Callable[[ComponentConfig, 
BaseTokenizer, + CDB, Vocab, Optional[str]], AddonComponent] + + _DEFAULT_ADDONS: dict[str, tuple[str, str]] = { 'meta_cat': ('medcat.components.addons.meta_cat.meta_cat', - 'MetaCATAddon.create_new'), + 'MetaCATAddon.create_new_component'), 'rel_cat': ('medcat.components.addons.relation_extraction.rel_cat', - 'RelCATAddon.create_new') + 'RelCATAddon.create_new_component') } # NOTE: type error due to non-concrete type @@ -48,30 +61,32 @@ def get_output_key_val(self, ent: MutableEntity def register_addon(addon_name: str, - addon_cls: Callable[..., AddonComponent]) -> None: + addon_cls: AddonClass) -> None: """Register a new addon. Args: addon_name (str): The addon name. - addon_cls (Callable[..., AddonComponent]): The addon creator. + addon_cls (AddonClass): The addon creator. """ _ADDON_REGISTRY.register(addon_name, addon_cls) -def get_addon_creator(addon_name: str) -> Callable[..., AddonComponent]: +def get_addon_creator(addon_name: str) -> AddonClass: """Get the creator for an addon. Args: addon_name (str): The name of the addonl Returns: - Callable[..., AddonComponent]: The creator of the addon. + AddonClass: The creator of the addon. """ return _ADDON_REGISTRY.get_component(addon_name) -def create_addon(addon_name: str, cnf: ComponentConfig, - *args, **kwargs) -> AddonComponent: +def create_addon( + addon_name: str, cnf: ComponentConfig, + tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, + model_load_path: Optional[str]) -> AddonComponent: """Create an addon of the specified name with the specified arguments. All the `*args`, and `**kwrags` are passed to the creator. @@ -79,8 +94,15 @@ def create_addon(addon_name: str, cnf: ComponentConfig, Args: addon_name (str): The name of the addon. cnf (ComponentConfig): The addon config. + tokenizer (BaseTokenizer): The base tokenizer to be passed to creator. + cdb (CDB): The CDB to be passed to creator. + vocab (Vocab): The Vocab to be passed to creator. 
+ model_load_path (Optional[str]): The optional model load path to be + passed to creator. + Returns: AddonComponent: The resulting / created addon. """ - return get_addon_creator(addon_name)(cnf, *args, **kwargs) + return get_addon_creator(addon_name)( + cnf, tokenizer, cdb, vocab, model_load_path) diff --git a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py index b0625870c..cac3e3caa 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py +++ b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py @@ -11,6 +11,7 @@ import torch from torch import nn, Tensor from medcat.tokenizing.tokenizers import BaseTokenizer +from medcat.config.config import ComponentConfig from medcat.config.config_meta_cat import ConfigMetaCAT from medcat.components.addons.meta_cat.ml_utils import ( predict, train_model, set_all_seeds, eval_model) @@ -25,6 +26,7 @@ from medcat.tokenizing.tokens import MutableDocument, MutableEntity from medcat.cdb import CDB from medcat.vocab import Vocab +from medcat.utils.defaults import COMPONENTS_FOLDER from peft import get_peft_model, LoraConfig, TaskType # It should be safe to do this always, as all other multiprocessing @@ -84,6 +86,23 @@ def create_new(cls, config: ConfigMetaCAT, base_tokenizer: BaseTokenizer, meta_cat = MetaCAT(tokenizer, embeddings=None, config=config) return cls(config, base_tokenizer, meta_cat) + @classmethod + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str] + ) -> 'MetaCATAddon': + if not isinstance(cnf, ConfigMetaCAT): + raise ValueError(f"Incompatible config: {cnf}") + if model_load_path is not None: + components_folder = os.path.join( + model_load_path, COMPONENTS_FOLDER) + folder_name = cls.get_folder_name_for_addon_and_name( + cls.addon_type, str(cnf.general.category_name)) + load_path = os.path.join(components_folder, folder_name) + return 
cls.load_existing(cnf, tokenizer, load_path) + # TODO: tokenizer preprocessing for (e.g) BPE tokenizer (see PR #67) + return cls.create_new(cnf, tokenizer, None) + @classmethod def load_existing(cls, cnf: ConfigMetaCAT, base_tokenizer: BaseTokenizer, @@ -100,18 +119,6 @@ def name(self) -> str: def __call__(self, doc: MutableDocument) -> MutableDocument: return self.mc(doc) - @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - # NOTE: cnf is silent init parameter - return [] - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - # cls.init_tokenizer(cnf, model_load_path) - return {'base_tokenizer': tokenizer} - def load(self, folder_path: str) -> 'MetaCAT': mc_path, tokenizer_folder = self._get_meta_cat_and_tokenizer_paths( folder_path) @@ -169,8 +176,10 @@ def serialise_to(self, folder_path: str) -> None: @classmethod def deserialise_from(cls, folder_path: str, **init_kwargs ) -> 'MetaCATAddon': - # NOTE: model load path sent by kwargs - return cls.load_existing(load_path=folder_path, **init_kwargs) + return cls.load_existing( + load_path=folder_path, + cnf=init_kwargs['cnf'], + base_tokenizer=init_kwargs['tokenizer']) def get_strategy(self) -> SerialisingStrategy: return SerialisingStrategy.MANUAL diff --git a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py index bb6902a61..37d91f253 100644 --- a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py +++ b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py @@ -2,7 +2,7 @@ import logging import os import random -from typing import Optional, Any +from typing import Optional from sklearn.utils import compute_class_weight import torch @@ -18,7 +18,7 @@ from medcat.cdb import CDB from medcat.vocab import Vocab -from medcat.config 
import Config +from medcat.config.config import Config, ComponentConfig from medcat.config.config_rel_cat import ConfigRelCAT from medcat.storage.serialisers import deserialise from medcat.storage.serialisables import SerialisingStrategy @@ -32,6 +32,7 @@ from medcat.components.addons.relation_extraction.rel_dataset import RelData from medcat.tokenizing.tokenizers import BaseTokenizer, create_tokenizer from medcat.tokenizing.tokens import MutableDocument +from medcat.utils.defaults import COMPONENTS_FOLDER logger = logging.getLogger(__name__) @@ -54,6 +55,20 @@ def create_new(cls, config: ConfigRelCAT, base_tokenizer: BaseTokenizer, return cls(config, RelCAT(base_tokenizer, cdb, config=config, init_model=True)) + @classmethod + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str] + ) -> 'RelCATAddon': + if not isinstance(cnf, ConfigRelCAT): + raise ValueError(f"Incompatible config: {cnf}") + config = cnf + if model_load_path is not None: + load_path = os.path.join(model_load_path, COMPONENTS_FOLDER, + cls.NAME_PREFIX + cls.addon_type) + return cls.load_existing(config, tokenizer, cdb, load_path) + return cls.create_new(config, tokenizer, cdb) + @classmethod def load_existing(cls, cnf: ConfigRelCAT, base_tokenizer: BaseTokenizer, @@ -70,21 +85,6 @@ def serialise_to(self, folder_path: str) -> None: os.mkdir(folder_path) self._rel_cat.save(folder_path) - @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - # NOTE: cnf is silent init parameter - return [] - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - # cls.init_tokenizer(cnf, model_load_path) - return { - 'base_tokenizer': tokenizer, - "cdb": cdb - } - @property def name(self) -> str: return str(self.addon_type) @@ -95,7 +95,12 @@ def name(self) -> 
str: def deserialise_from(cls, folder_path: str, **init_kwargs ) -> 'RelCATAddon': # NOTE: model load path sent by kwargs - return cls.load_existing(load_path=folder_path, **init_kwargs) + return cls.load_existing( + load_path=folder_path, + base_tokenizer=init_kwargs['tokenizer'], + cnf=init_kwargs['cnf'], + cdb=init_kwargs['cdb'], + ) def get_strategy(self) -> SerialisingStrategy: return SerialisingStrategy.MANUAL @@ -232,7 +237,7 @@ def load(cls, load_path: str = "./") -> "RelCAT": rel_cat = RelCAT( # NOTE: this is a throaway tokenizer just for registrations - create_tokenizer(cdb.config.general.nlp.provider), + create_tokenizer(cdb.config.general.nlp.provider, cdb.config), cdb=cdb, config=component.relcat_config, task=component.task) rel_cat.device = device rel_cat.component = component @@ -883,7 +888,8 @@ def predict_text_with_anns(self, text: str, annotations: list[dict] Doc: spacy doc with the relations. """ # NOTE: This runs not an empty language, but the specified one - base_tokenizer = create_tokenizer(self.cdb.config.general.nlp.provider) + base_tokenizer = create_tokenizer( + self.cdb.config.general.nlp.provider, self.cdb.config) doc = base_tokenizer(text) for ann in annotations: diff --git a/medcat-v2/medcat/components/linking/context_based_linker.py b/medcat-v2/medcat/components/linking/context_based_linker.py index 82f3db08e..7142a24fd 100644 --- a/medcat-v2/medcat/components/linking/context_based_linker.py +++ b/medcat-v2/medcat/components/linking/context_based_linker.py @@ -1,6 +1,6 @@ import random import logging -from typing import Iterator, Optional, Union, Any +from typing import Iterator, Optional, Union from medcat.components.types import CoreComponentType, AbstractCoreComponent from medcat.tokenizing.tokens import MutableEntity, MutableDocument @@ -8,7 +8,7 @@ ContextModel, PerDocumentTokenCache) from medcat.cdb import CDB from medcat.vocab import Vocab -from medcat.config import Config +from medcat.config.config import Config, 
ComponentConfig from medcat.utils.defaults import StatusTypes as ST from medcat.utils.postprocessing import create_main_ann from medcat.tokenizing.tokenizers import BaseTokenizer @@ -245,11 +245,8 @@ def train(self, cui: str, cui, entity, doc, per_doc_valid_token_cache, negative, names) @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - return [cdb, vocab, cdb.config] - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - return {} + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str] + ) -> 'Linker': + return cls(cdb, vocab, cdb.config) diff --git a/medcat-v2/medcat/components/linking/no_action_linker.py b/medcat-v2/medcat/components/linking/no_action_linker.py index 14a599ba2..fe14cce86 100644 --- a/medcat-v2/medcat/components/linking/no_action_linker.py +++ b/medcat-v2/medcat/components/linking/no_action_linker.py @@ -1,10 +1,11 @@ -from typing import Any, Optional +from typing import Optional from medcat.components.types import CoreComponentType, AbstractCoreComponent from medcat.tokenizing.tokens import MutableDocument from medcat.tokenizing.tokenizers import BaseTokenizer from medcat.cdb.cdb import CDB from medcat.vocab import Vocab +from medcat.config.config import ComponentConfig class NoActionLinker(AbstractCoreComponent): @@ -17,11 +18,8 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: return doc @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - return [] - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - return {} + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, 
+ cdb: CDB, vocab: Vocab, model_load_path: Optional[str] + ) -> 'NoActionLinker': + return cls() diff --git a/medcat-v2/medcat/components/linking/two_step_context_based_linker.py b/medcat-v2/medcat/components/linking/two_step_context_based_linker.py index e5be1410e..005c01e20 100644 --- a/medcat-v2/medcat/components/linking/two_step_context_based_linker.py +++ b/medcat-v2/medcat/components/linking/two_step_context_based_linker.py @@ -7,7 +7,7 @@ from medcat.cdb.cdb import CDB from medcat.vocab import Vocab -from medcat.config.config import Config, SerialisableBaseModel +from medcat.config.config import Config, SerialisableBaseModel, ComponentConfig from medcat.utils.defaults import StatusTypes as ST from medcat.utils.matutils import sigmoid from medcat.utils.config_utils import temp_changed_config @@ -255,14 +255,11 @@ def train(self, cui: str, per_doc_valid_token_cache=pdc) @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - return [cdb, vocab, cdb.config] - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - return {} + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str] + ) -> 'TwoStepLinker': + return cls(cdb, vocab, cdb.config) @property def two_step_config(self) -> 'TwoStepLinkerConfig': diff --git a/medcat-v2/medcat/components/ner/dict_based_ner.py b/medcat-v2/medcat/components/ner/dict_based_ner.py index 1d5ee0643..83a041e9d 100644 --- a/medcat-v2/medcat/components/ner/dict_based_ner.py +++ b/medcat-v2/medcat/components/ner/dict_based_ner.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Optional import logging from medcat.tokenizing.tokens import MutableDocument @@ -8,6 +8,7 @@ from medcat.tokenizing.tokenizers import BaseTokenizer from medcat.vocab import Vocab from 
medcat.cdb import CDB +from medcat.config.config import ComponentConfig from ahocorasick import Automaton import medcat @@ -100,11 +101,7 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: return doc @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - return [tokenizer, cdb] - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - return {} + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str]) -> 'NER': + return cls(tokenizer, cdb) diff --git a/medcat-v2/medcat/components/ner/trf/transformers_ner.py b/medcat-v2/medcat/components/ner/trf/transformers_ner.py index bae6ea1f7..459ca88b2 100644 --- a/medcat-v2/medcat/components/ner/trf/transformers_ner.py +++ b/medcat-v2/medcat/components/ner/trf/transformers_ner.py @@ -4,7 +4,7 @@ import datasets import torch from datetime import datetime -from typing import Iterable, Iterator, Optional, Union, Callable, Any +from typing import Iterable, Iterator, Optional, Union, Callable from typing import cast import inspect from functools import partial @@ -15,6 +15,7 @@ from medcat.utils.postprocessing import create_main_ann from medcat.utils.hasher import Hasher from medcat.config.config_transformers_ner import ConfigTransformersNER +from medcat.config.config import ComponentConfig from medcat.components.ner.trf.tokenizer import ( TransformersTokenizer) from medcat.utils.ner.metrics import metrics @@ -27,6 +28,7 @@ from medcat.preprocessors.cleaners import NameDescriptor from medcat.components.types import CoreComponentType, AbstractCoreComponent from medcat.vocab import Vocab +from medcat.utils.defaults import COMPONENTS_FOLDER from transformers import ( Trainer, AutoModelForTokenClassification, AutoTokenizer) @@ -63,6 +65,25 @@ def create_new(cls, cdb: 
CDB, base_tokenizer: BaseTokenizer, config=config, training_arguments=training_arguments, component=comp) + @classmethod + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str] + ) -> 'TransformersNER': + config = cdb.config.components.ner.custom_cnf + if not isinstance(config, ConfigTransformersNER): + raise ValueError( + "Did not find correct Transformers NER config. " + f"Found: {config}") + # TODO: anywhere to get these? + training_arguments = None + if model_load_path is not None: + load_path = os.path.join( + model_load_path, COMPONENTS_FOLDER, cls.NAME_PREFIX + "ner") + return cls.load_existing(cdb, tokenizer, load_path, + training_arguments, config) + return cls.create_new(cdb, tokenizer, config, training_arguments) + @classmethod def load_existing(cls, cdb: CDB, base_tokenizer: BaseTokenizer, load_path: str, training_arguments=None, @@ -76,17 +97,6 @@ def load_existing(cls, cdb: CDB, base_tokenizer: BaseTokenizer, def get_type(self): return CoreComponentType.ner - @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - # NOTE: TrfNER-specific config is at config.components.ner.custom_cnf - return [] - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - return {'cdb': cdb, 'base_tokenizer': tokenizer} - @property def should_save(self) -> bool: return True @@ -110,7 +120,12 @@ def serialise_to(self, folder_path: str) -> None: @classmethod def deserialise_from(cls, folder_path: str, **init_kwargs ) -> 'TransformersNER': - return cls.load_existing(load_path=folder_path, **init_kwargs) + return cls.load_existing( + load_path=folder_path, + cdb=init_kwargs['cdb'], + base_tokenizer=init_kwargs['tokenizer'], + # from Config.components.ner (of type Ner) + config=init_kwargs['cnf'].custom_cnf) def 
get_strategy(self) -> SerialisingStrategy: return SerialisingStrategy.MANUAL diff --git a/medcat-v2/medcat/components/ner/vocab_based_ner.py b/medcat-v2/medcat/components/ner/vocab_based_ner.py index 190de667c..afd12e41e 100644 --- a/medcat-v2/medcat/components/ner/vocab_based_ner.py +++ b/medcat-v2/medcat/components/ner/vocab_based_ner.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Optional import logging from medcat.tokenizing.tokens import MutableDocument @@ -7,6 +7,7 @@ from medcat.tokenizing.tokenizers import BaseTokenizer from medcat.vocab import Vocab from medcat.cdb import CDB +from medcat.config.config import ComponentConfig logger = logging.getLogger(__name__) @@ -108,11 +109,7 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: return doc @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - return [tokenizer, cdb] - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - return {} + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str]) -> 'NER': + return cls(tokenizer, cdb) diff --git a/medcat-v2/medcat/components/normalizing/normalizer.py b/medcat-v2/medcat/components/normalizing/normalizer.py index bae6e0334..ce32e8457 100644 --- a/medcat-v2/medcat/components/normalizing/normalizer.py +++ b/medcat-v2/medcat/components/normalizing/normalizer.py @@ -1,9 +1,9 @@ -from typing import Optional, Iterable, Iterator, Any, Union, overload, Literal +from typing import Optional, Iterable, Iterator, Union, overload, Literal import re from medcat.tokenizing.tokens import MutableDocument from medcat.tokenizing.tokenizers import BaseTokenizer -from medcat.config.config import Config +from medcat.config.config import Config, ComponentConfig from medcat.vocab import Vocab from 
medcat.cdb import CDB from medcat.components.types import CoreComponentType, AbstractCoreComponent @@ -222,11 +222,8 @@ def __call__(self, doc: MutableDocument): return doc @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - return [tokenizer, cdb.config, cdb.token_counts, vocab] - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - return {} + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str] + ) -> 'TokenNormalizer': + return cls(tokenizer, cdb.config, cdb.token_counts, vocab) diff --git a/medcat-v2/medcat/components/tagging/tagger.py b/medcat-v2/medcat/components/tagging/tagger.py index 71f9ac319..66a6af7a1 100644 --- a/medcat-v2/medcat/components/tagging/tagger.py +++ b/medcat-v2/medcat/components/tagging/tagger.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Optional import re from medcat.config.config import Preprocessing @@ -7,6 +7,7 @@ from medcat.tokenizing.tokenizers import BaseTokenizer from medcat.cdb import CDB from medcat.vocab import Vocab +from medcat.config.config import ComponentConfig class TagAndSkipTagger(AbstractCoreComponent): @@ -37,12 +38,10 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: return doc - @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - return [cdb.config.preprocessing] @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - return {} + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str] + ) -> 'TagAndSkipTagger': + return cls(cdb.config.preprocessing) diff --git 
a/medcat-v2/medcat/components/types.py b/medcat-v2/medcat/components/types.py index 5112a7cea..77c53c8ed 100644 --- a/medcat-v2/medcat/components/types.py +++ b/medcat-v2/medcat/components/types.py @@ -1,11 +1,13 @@ -from typing import Optional, Protocol, Callable, runtime_checkable, Union, Any +from typing import Optional, Protocol, Callable, runtime_checkable, Union +from typing_extensions import Self from enum import Enum, auto -from medcat.utils.registry import Registry +from medcat.utils.registry import Registry, MedCATRegistryException from medcat.tokenizing.tokens import MutableDocument, MutableEntity from medcat.tokenizing.tokenizers import BaseTokenizer from medcat.cdb import CDB from medcat.vocab import Vocab +from medcat.config.config import ComponentConfig class CoreComponentType(Enum): @@ -40,34 +42,22 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: pass @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - """Get the init arguments for the component. + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str]) -> Self: + """Create a new component or load one off disk if load path presented. - Args: - tokenizer (BaseTokenizer): The tokenizer. - cdb (CDB): The CDB. - vocab (Vocab): The Vocab. - model_load_path (Optional[str]): The model load path (or None). - - Returns: - list[Any]: The list of init arguments. - """ - pass - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - """Get init keyword arguments for the component. + This may raise an exception if the wrong type of config is provided. Args: - tokenizer (BaseTokenizer): The tokenizer. + cnf (ComponentConfig): The config relevant to this components. + tokenizer (BaseTokenizer): The base tokenizer. cdb (CDB): The CDB. 
vocab (Vocab): The Vocab. - model_load_path (Optional[str]): The model load path (or None). + model_load_path (Optional[str]): Model load path (if present). Returns: - dict[str, Any]: The keywrod arguments. + Self: The new components. """ pass @@ -123,25 +113,29 @@ def train(self, cui: str, _DEFAULT_TAGGERS: dict[str, tuple[str, str]] = { - "default": ("medcat.components.tagging.tagger", "TagAndSkipTagger"), + "default": ("medcat.components.tagging.tagger", + "TagAndSkipTagger.create_new_component"), } _DEFAULT_NORMALIZERS: dict[str, tuple[str, str]] = { "default": ("medcat.components.normalizing.normalizer", - "TokenNormalizer"), + "TokenNormalizer.create_new_component"), } _DEFAULT_NER: dict[str, tuple[str, str]] = { - "default": ("medcat.components.ner.vocab_based_ner", "NER"), - "dict": ("medcat.components.ner.dict_based_ner", "NER"), + "default": ("medcat.components.ner.vocab_based_ner", + "NER.create_new_component"), + "dict": ("medcat.components.ner.dict_based_ner", + "NER.create_new_component"), "transformers_ner": ("medcat.components.ner.trf.transformers_ner", - "TransformersNER.create_new"), + "TransformersNER.create_new_component"), } _DEFAULT_LINKING: dict[str, tuple[str, str]] = { - "default": ("medcat.components.linking.context_based_linker", "Linker"), + "default": ("medcat.components.linking.context_based_linker", + "Linker.create_new_component"), "no_action": ("medcat.components.linking.no_action_linker", - "NoActionLinker"), + "NoActionLinker.create_new_component"), "medcat2_two_step_linker": ( "medcat.components.linking.two_step_context_based_linker", - "TwoStepLinker") + "TwoStepLinker.create_new_component") } @@ -156,16 +150,19 @@ def train(self, cui: str, lazy_defaults=_DEFAULT_LINKING), } +CompClass = Callable[[ComponentConfig, BaseTokenizer, + CDB, Vocab, Optional[str]], CoreComponent] + def register_core_component(comp_type: CoreComponentType, comp_name: str, - comp_clazz: Callable[..., CoreComponent]) -> None: + comp_clazz: CompClass) -> 
None: """Register a new core component. Args: comp_type (CoreComponentType): The component type. comp_name (str): The component name. - comp_clazz (Callable[..., CoreComponent]): The component creator. + comp_clazz (ComplClass): The component creator. """ _CORE_REGISTRIES[comp_type].register(comp_name, comp_clazz) @@ -183,7 +180,7 @@ def get_core_registry(comp_type: CoreComponentType) -> Registry[CoreComponent]: def get_component_creator(comp_type: CoreComponentType, - comp_name: str) -> Callable[..., CoreComponent]: + comp_name: str) -> CompClass: """Get the component creator. Args: @@ -196,22 +193,30 @@ def get_component_creator(comp_type: CoreComponentType, return get_core_registry(comp_type).get_component(comp_name) -def create_core_component(comp_type: CoreComponentType, - comp_name: str, - *args, **kwargs) -> CoreComponent: +def create_core_component( + comp_type: CoreComponentType, comp_name: str, cnf: ComponentConfig, + tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, + model_load_path: Optional[str]) -> CoreComponent: """Creat a core component. - All `*args` and `**kwrags` are passed directly to the component creator. - Args: comp_type (CoreComponentType): The component type. comp_name (str): The name of the component. + cnf (ComponentConfig): The config to be passed to creator. + tokenizer (BaseTokenizer): The tokenizer to be passed to creator. + cdb (CDB): The CDB to be passed to creator. + vocab (Vocab): The vocab to be passed to the creator. + model_load_path (Optional[str]): The optional load path to be passed + to the creators. Returns: CoreComponent: The resulting / created component. 
""" - comp_getter = get_core_registry(comp_type).get_component(comp_name) - return comp_getter(*args, **kwargs) + try: + comp_getter = get_core_registry(comp_type).get_component(comp_name) + except MedCATRegistryException as err: + raise MedCATRegistryException(f"With comp type '{comp_type}'") from err + return comp_getter(cnf, tokenizer, cdb, vocab, model_load_path) def get_registered_components(comp_type: CoreComponentType diff --git a/medcat-v2/medcat/config/config.py b/medcat-v2/medcat/config/config.py index 75bde587f..5fcce13b3 100644 --- a/medcat-v2/medcat/config/config.py +++ b/medcat-v2/medcat/config/config.py @@ -140,22 +140,6 @@ class ComponentConfig(DirtiableBaseModel): , , ) By default, only the 'default' component is registered. """ - init_args: list = Field(default_factory=list, exclude=True) - """These are the positional arguments required to construct the component. - - For default components, these will be automatically filled. However, if a - custom component is used, these would need to be set manually. - """ - init_kwargs: dict = Field(default_factory=dict, exclude=True) - """These are the keyword arguments required to construct the component. - - For default components, these will be automatically filled. However, if a - custom component is used, these would need to be set manually. - """ - - @classmethod - def ignore_attrs(cls): - return ["init_args", "init_kwargs"] class NLPConfig(SerialisableBaseModel): @@ -179,18 +163,6 @@ class NLPConfig(SerialisableBaseModel): NB! For these changes to take effect, the pipe would need to be recreated. """ - init_args: list = Field(default_factory=list, exclude=True) - """These are the positional arguments required to construct the component. - - For default components, these will be automatically filled. However, if a - custom component is used, these would need to be set manually. 
- """ - init_kwargs: dict = Field(default_factory=dict, exclude=True) - """These are the keyword arguments required to construct the component. - - For default components, these will be automatically filled. However, if a - custom component is used, these would need to be set manually. - """ # NOTE: this will allow for more config entries # since we don't know what other implementations may require diff --git a/medcat-v2/medcat/pipeline/pipeline.py b/medcat-v2/medcat/pipeline/pipeline.py index b431a1755..12f3b430c 100644 --- a/medcat-v2/medcat/pipeline/pipeline.py +++ b/medcat-v2/medcat/pipeline/pipeline.py @@ -1,4 +1,4 @@ -from typing import Optional, Any, Iterable, Union +from typing import Optional, Iterable, Union import logging import os @@ -19,9 +19,6 @@ from medcat.config.config import ComponentConfig from medcat.config.config_meta_cat import ConfigMetaCAT from medcat.config.config_rel_cat import ConfigRelCAT -from medcat.utils.default_args import (set_tokenizer_defaults, - set_components_defaults, - set_addon_defaults) logger = logging.getLogger(__name__) @@ -55,12 +52,8 @@ def __call__(self, text: str) -> MutableDocument: return doc @classmethod - def get_init_args(cls, config: Config) -> list[Any]: - return [] - - @classmethod - def get_init_kwargs(cls, config: Config) -> dict[str, Any]: - return {} + def create_new_tokenizer(cls, config: Config) -> 'DelegatingTokenizer': + raise ValueError("Initialise the delegating tokenizer with its initialiser") def get_doc_class(self) -> type[MutableDocument]: return self.tokenizer.get_doc_class() @@ -80,18 +73,14 @@ def __init__(self, cdb: CDB, vocab: Optional[Vocab], model_load_path: Optional[str], # NOTE: upon reload, old pipe can be useful old_pipe: Optional['Pipeline'] = None): - # NOTE: this only sets the default arguments if the - # default tokenizer is used - set_tokenizer_defaults(cdb.config) self.cdb = cdb + # NOTE: Vocab is None in case of DeID models and thats fine then, + # but it should be non-None 
otherwise + self.vocab: Vocab = vocab # type: ignore self.config = self.cdb.config self._tokenizer = self._init_tokenizer() self._components: list[CoreComponent] = [] self._addons: list[AddonComponent] = [] - set_components_defaults(cdb, vocab, self._tokenizer, model_load_path) - set_addon_defaults(cdb, vocab, self._tokenizer, model_load_path) - # NOTE: this only sets the default arguments if the - # a specific default component is used self._init_components(model_load_path, old_pipe) @property @@ -108,22 +97,22 @@ def tokenizer_with_tag(self) -> BaseTokenizer: def _init_tokenizer(self) -> BaseTokenizer: nlp_cnf = self.config.general.nlp try: - return create_tokenizer(nlp_cnf.provider, *nlp_cnf.init_args, - **nlp_cnf.init_kwargs) + return create_tokenizer(nlp_cnf.provider, self.config) except TypeError as type_error: if nlp_cnf.provider == 'spacy': raise type_error raise IncorrectArgumentsForTokenizer( nlp_cnf.provider) from type_error - def _init_component(self, comp_type: CoreComponentType) -> CoreComponent: + def _init_component(self, comp_type: CoreComponentType, + model_load_path: Optional[str]) -> CoreComponent: comp_config: ComponentConfig = getattr(self.config.components, comp_type.name) comp_name = comp_config.comp_name try: - comp = create_core_component(comp_type, comp_name, - *comp_config.init_args, - **comp_config.init_kwargs) + comp = create_core_component( + comp_type, comp_name, comp_config, self.tokenizer, self.cdb, + self.vocab, model_load_path) except TypeError as type_error: if comp_name == 'default': raise type_error @@ -164,11 +153,13 @@ def _load_saved_core_component(self, cct_name: str, comp_folder_path: str ) -> CoreComponent: logger.info("Using loaded component for '%s' for", cct_name) cnf: ComponentConfig = getattr(self.config.components, cct_name) - if cnf.init_args: - raise IncorrectCoreComponent( - "Manually serialisable core components need to define all " - "their arguments as keyword arguments") - comp = 
deserialise(comp_folder_path, **cnf.init_kwargs) + comp = deserialise( + comp_folder_path, + # NOTE: the following are keyword arguments used + # for manual deserialisation + cnf=cnf, tokenizer=self.tokenizer, cdb=self.cdb, + vocab=self.vocab, model_load_path=os.path.dirname( + os.path.dirname(comp_folder_path))) if not isinstance(comp, CoreComponent): raise IncorrectFolderUponLoad( f"Did not find a CoreComponent at {comp_folder_path} " @@ -191,7 +182,8 @@ def _init_components(self, model_load_path: Optional[str], comp = self._load_saved_core_component( cct_name, loaded_core_component_paths.pop(cct_name)) else: - comp = self._init_component(CoreComponentType[cct_name]) + comp = self._init_component( + CoreComponentType[cct_name], model_load_path) self._components.append(comp) for addon_cnf in self.config.components.addons: addon = self._init_addon( @@ -224,12 +216,14 @@ def _get_loaded_addon_path( def _load_addon(self, cnf: ComponentConfig, load_from: str ) -> AddonComponent: - if cnf.init_args: - raise IncorrectAddonLoaded( - "Manually serialisable addons need to define all their init " - "arguments as keyword arguments") # config is implicitly required argument - addon = deserialise(load_from, **cnf.init_kwargs, cnf=cnf) + model_load_path = os.path.dirname(os.path.dirname(load_from)) + addon = deserialise( + load_from, + # NOTE: the following are keyword arguments used + # for manual deserialisation + cnf=cnf, tokenizer=self.tokenizer, cdb=self.cdb, + vocab=self.vocab, model_load_path=model_load_path) if not isinstance(addon, AddonComponent): raise IncorrectAddonLoaded( f"Expected {AddonComponent.__name__}, but goet " @@ -258,8 +252,9 @@ def _init_addon( cnf, loaded_addon_component_paths) if loaded_path: return self._load_addon(cnf, loaded_path) - return create_addon(cnf.comp_name, cnf, - *cnf.init_args, **cnf.init_kwargs) + return create_addon( + cnf.comp_name, cnf=cnf, tokenizer=self.tokenizer, cdb=self.cdb, + vocab=self.vocab, model_load_path=None) def 
get_doc(self, text: str) -> MutableDocument: """Get the document for this text. @@ -357,11 +352,7 @@ class IncorrectArgumentsForTokenizer(TypeError): def __init__(self, provider: str): super().__init__( - f"Incorrect arguments for tokenizer ({provider}). Did you forget " - "to set `config.general.nlp.init_args` or " - "`config.general.nlp.init_kwargs`? When using a custom tokenizer, " - "you need to specify the arguments required for construction " - "manually.") + f"Incorrect arguments for tokenizer ({provider}).") class IncorrectArgumentsForComponent(TypeError): @@ -369,11 +360,7 @@ class IncorrectArgumentsForComponent(TypeError): def __init__(self, comp_type: CoreComponentType, comp_name: str): super().__init__( f"Incorrect arguments for core component {comp_type.name} " - f"({comp_name}). Did you forget to set " - f"`config.components.{comp_type.name}.init_args` and/or " - f"`config.components.{comp_type.name}.init_kwargs`? " - "When using a custom component, you need to specify the arguments" - "required or construction manually.") + f"({comp_name}).") class IncorrectCoreComponent(ValueError): diff --git a/medcat-v2/medcat/storage/serialisables.py b/medcat-v2/medcat/storage/serialisables.py index 65e50c09d..e7e04492b 100644 --- a/medcat-v2/medcat/storage/serialisables.py +++ b/medcat-v2/medcat/storage/serialisables.py @@ -169,11 +169,32 @@ def __eq__(self, other: Any) -> bool: class ManualSerialisable(Serialisable, Protocol): def serialise_to(self, folder_path: str) -> None: + """Serialise to a folder. + + Args: + folder_path (str): The folder to serialise to. + """ pass @classmethod def deserialise_from(cls, folder_path: str, **init_kwargs ) -> 'ManualSerialisable': + """Deserialise from a specifc path. 
+ + The init keyword arguments are generally: + - cnf: The config relevant to the components + - tokenizer (BaseTokenizer): The base tokenizer for the model + - cdb (CDB): The CDB for the model + - vocab (Vocab): The Vocab for the model + - model_load_path (Optional[str]): The model load path, + but not the component load path + + Args: + folder_path (str): The path to deserialsie form. + + Returns: + ManualSerialisable: The deserialised object. + """ pass diff --git a/medcat-v2/medcat/storage/serialisers.py b/medcat-v2/medcat/storage/serialisers.py index 03077e002..176c2dea6 100644 --- a/medcat-v2/medcat/storage/serialisers.py +++ b/medcat-v2/medcat/storage/serialisers.py @@ -342,6 +342,13 @@ def deserialise(folder_path: str, """Deserialise contents of a folder. Extra init keyword arguments can be provided if needed. + These are generally: + - cnf: The config relevant to the components + - tokenizer (BaseTokenizer): The base tokenizer for the model + - cdb (CDB): The CDB for the model + - vocab (Vocab): The Vocab for the model + - model_load_path (Optional[str]): The model load path, + but not the component load path This method finds the serialiser to be used based on the files on disk. 
diff --git a/medcat-v2/medcat/tokenizing/regex_impl/tokenizer.py b/medcat-v2/medcat/tokenizing/regex_impl/tokenizer.py index 058e14290..874f51dcb 100644 --- a/medcat-v2/medcat/tokenizing/regex_impl/tokenizer.py +++ b/medcat-v2/medcat/tokenizing/regex_impl/tokenizer.py @@ -360,12 +360,8 @@ def __call__(self, text: str) -> MutableDocument: return doc @classmethod - def get_init_args(cls, config: Config) -> list[Any]: - return [] - - @classmethod - def get_init_kwargs(cls, config: Config) -> dict[str, Any]: - return {} + def create_new_tokenizer(cls, config: Config) -> 'RegexTokenizer': + return cls() def get_doc_class(self) -> Type[MutableDocument]: return Document diff --git a/medcat-v2/medcat/tokenizing/spacy_impl/tokenizers.py b/medcat-v2/medcat/tokenizing/spacy_impl/tokenizers.py index 1017e2270..71634b248 100644 --- a/medcat-v2/medcat/tokenizing/spacy_impl/tokenizers.py +++ b/medcat-v2/medcat/tokenizing/spacy_impl/tokenizers.py @@ -1,4 +1,4 @@ -from typing import Optional, Callable, cast, Any, Type +from typing import Optional, Callable, cast, Type import re import os import shutil @@ -85,18 +85,13 @@ def __call__(self, text: str) -> MutableDocument: return Document(self._nlp(text)) @classmethod - def get_init_args(cls, config: Config) -> list[Any]: + def create_new_tokenizer(cls, config: Config) -> 'SpacyTokenizer': nlp_cnf = config.general.nlp - return [ - nlp_cnf.modelname, + return cls(nlp_cnf.modelname, nlp_cnf.disabled_components, config.general.diacritics, config.preprocessing.max_document_length, - ] - - @classmethod - def get_init_kwargs(cls, config: Config) -> dict[str, Any]: - return {"stopwords": config.preprocessing.stopwords} + stopwords=config.preprocessing.stopwords) def get_doc_class(self) -> Type[MutableDocument]: return Document diff --git a/medcat-v2/medcat/tokenizing/tokenizers.py b/medcat-v2/medcat/tokenizing/tokenizers.py index c359b7e70..834a7ec34 100644 --- a/medcat-v2/medcat/tokenizing/tokenizers.py +++ 
b/medcat-v2/medcat/tokenizing/tokenizers.py @@ -1,4 +1,5 @@ -from typing import Protocol, Type, Any, Callable, runtime_checkable +from typing import Protocol, Type, Callable, runtime_checkable +from typing_extensions import Self import logging from medcat.config import Config @@ -47,11 +48,7 @@ def __call__(self, text: str) -> MutableDocument: pass @classmethod - def get_init_args(cls, config: Config) -> list[Any]: - pass - - @classmethod - def get_init_kwargs(cls, config: Config) -> dict[str, Any]: + def create_new_tokenizer(cls, config: Config) -> Self: pass def get_doc_class(self) -> Type[MutableDocument]: @@ -106,15 +103,18 @@ def load_internals_from(self, folder_path: str) -> bool: _DEFAULT_TOKENIZING: dict[str, tuple[str, str]] = { - "regex": ("medcat.tokenizing.regex_impl.tokenizer", "RegexTokenizer"), - "spacy": ("medcat.tokenizing.spacy_impl.tokenizers", "SpacyTokenizer") + "regex": ("medcat.tokenizing.regex_impl.tokenizer", + "RegexTokenizer.create_new_tokenizer"), + "spacy": ("medcat.tokenizing.spacy_impl.tokenizers", + "SpacyTokenizer.create_new_tokenizer") } _TOKENIZERS_REGISTRY = Registry(BaseTokenizer, # type: ignore lazy_defaults=_DEFAULT_TOKENIZING) -def get_tokenizer_creator(tokenizer_name: str) -> Callable[..., BaseTokenizer]: +def get_tokenizer_creator(tokenizer_name: str + ) -> Callable[[Config], BaseTokenizer]: """Get the creator method for the tokenizer. While this is generally just the class instance (i.e refers @@ -124,23 +124,22 @@ def get_tokenizer_creator(tokenizer_name: str) -> Callable[..., BaseTokenizer]: tokenizer_name (str): The name of the tokenizer. Returns: - Callable[..., BaseTokenizer]: The creator for the tokenizer. + Callable[[Config], BaseTokenizer]: The creator for the tokenizer. 
""" return _TOKENIZERS_REGISTRY.get_component(tokenizer_name) -def create_tokenizer(tokenizer_name: str, *args, **kwargs) -> BaseTokenizer: +def create_tokenizer(tokenizer_name: str, config: Config) -> BaseTokenizer: """Create the tokenizer given the init arguments. - The `*args`, and `**kwargs` will be directly passed to the creator. - Args: tokenizer_name (str): The tokenizer name. + config (Config): The config to be passed to the constructor. Returns: BaseTokenizer: The created tokenizer. """ - return _TOKENIZERS_REGISTRY.get_component(tokenizer_name)(*args, **kwargs) + return _TOKENIZERS_REGISTRY.get_component(tokenizer_name)(config) def list_available_tokenizers() -> list[tuple[str, str]]: diff --git a/medcat-v2/medcat/utils/default_args.py b/medcat-v2/medcat/utils/default_args.py deleted file mode 100644 index 59a16d8d1..000000000 --- a/medcat-v2/medcat/utils/default_args.py +++ /dev/null @@ -1,130 +0,0 @@ -"""This module exists purely to set the default arguments -in the config for the default tokenizer and the default -components creation. -""" -from typing import Optional - -from medcat.components.types import get_component_creator, CoreComponentType -from medcat.components.addons.addons import get_addon_creator -from medcat.tokenizing.tokenizers import BaseTokenizer, get_tokenizer_creator -from medcat.config.config import ComponentConfig -from medcat.config import Config -from medcat.cdb import CDB -from medcat.vocab import Vocab - -import logging - - -logger = logging.getLogger(__name__) - - -def set_tokenizer_defaults(config: Config) -> None: - """Set the default init arguments for the tokenizer. - - This generally uses the `get_init_args` and `get_init_kwargs` - method bound to the tokenizer class. - - Args: - config (Config): The same (modified) config. 
- """ - nlp_cnf = config.general.nlp - tok_cls = get_tokenizer_creator(nlp_cnf.provider) - if hasattr(tok_cls, 'get_init_args'): - nlp_cnf.init_args = tok_cls.get_init_args(config) - else: - logger.warning( - "Could not set init arguments for tokenizer (%s). " - "You generally need to specify these with the class method " - "get_init_args(Config) -> list[Any].", nlp_cnf.provider) - if hasattr(tok_cls, 'get_init_kwargs'): - nlp_cnf.init_kwargs = tok_cls.get_init_kwargs(config) - else: - logger.warning( - "Could not set init keyword arguments for tokenizer (%s). " - "You generally need to specify these with the class method " - "get_init_kwargs(Config) -> dict[str, Any].", nlp_cnf.provider) - - -def set_components_defaults(cdb: CDB, vocab: Optional[Vocab], - tokenizer: BaseTokenizer, - model_load_path: Optional[str]): - """Set the default init arguments for the componts. - - This generally uses the `get_init_args` and `get_init_kwargs` - method bound to the tokenizer class. - - Args: - cdb (CDB): The CDB. - vocab (Optional[Vocab]): The Vocab. - tokenizer (BaseTokenizer): The tokenizer. - model_load_path (Optional[str]): The model load path. - """ - for comp_name, comp_cnf in cdb.config.components: - if not isinstance(comp_cnf, ComponentConfig): - # e.g ignore order - continue - comp_cls = get_component_creator(CoreComponentType[comp_name], - comp_cnf.comp_name) - if not isinstance(comp_cls, type): - # i.e get CompCls from CompCls.create_new - comp_cls = comp_cls.__self__ # type: ignore - if hasattr(comp_cls, 'get_init_args'): - comp_cnf.init_args = comp_cls.get_init_args(tokenizer, cdb, vocab, - model_load_path) - else: - logger.warning( - "The component %s (%s) does not define init arguments. 
" - "You generally need to specify these with the class method " - "get_init_args(BaseTokenizer, CDB, Vocab) -> list[Any]", - comp_name, comp_cnf.comp_name) - if hasattr(comp_cls, 'get_init_kwargs'): - comp_cnf.init_kwargs = comp_cls.get_init_kwargs( - tokenizer, cdb, vocab, model_load_path) - else: - logger.warning( - "The component %s (%s) does not define init keyword arguments." - " You generally need to specify these with the class method " - "get_init_kwargs(BaseTokenizer, CDB, Vocab) -> dict[str, Any]", - comp_name, comp_cnf.comp_name) - - -def set_addon_defaults(cdb: CDB, vocab: Optional[Vocab], - tokenizer: BaseTokenizer, - model_load_path: Optional[str]): - """Set default init arguments for addons. - - Args: - cdb (CDB): The CDB. - vocab (Optional[Vocab]): The Vocab. - tokenizer (BaseTokenizer): The tokenizer. - model_load_path (Optional[str]): The model load path. - """ - for addon_cnf in cdb.config.components.addons: - addon_cls = get_addon_creator(addon_cnf.comp_name) - if not isinstance(addon_cls, type): - # i.e get MetaCAT from MetaCAT.create_new - addon_cls = addon_cls.__self__ # type: ignore - if hasattr(addon_cls, 'get_init_args'): - addon_cnf.init_args = addon_cls.get_init_args( - tokenizer, cdb, vocab, model_load_path) - else: - logger.warning( - "The addon '%s' does not define init arguments. " - "You generally need to specify these with the class method " - "get_init_args(BaseTokenizer, CDB, Vocab) -> list[Any]", - addon_cnf.comp_name) - if hasattr(addon_cls, 'get_init_kwargs'): - addon_cnf.init_kwargs = addon_cls.get_init_kwargs( - tokenizer, cdb, vocab, model_load_path) - else: - logger.warning( - "The component '%s' does not define init keyword arguments." 
- " You generally need to specify these with the class method " - "get_init_kwargs(BaseTokenizer, CDB, Vocab) -> dict[str, Any]", - addon_cnf.comp_name) - - -class OptionalPartNotInstalledException(ValueError): - - def __init__(self, *args): - super().__init__(*args) diff --git a/medcat-v2/medcat/utils/legacy/convert_rel_cat.py b/medcat-v2/medcat/utils/legacy/convert_rel_cat.py index e4ddb45d5..4d077c6f4 100644 --- a/medcat-v2/medcat/utils/legacy/convert_rel_cat.py +++ b/medcat-v2/medcat/utils/legacy/convert_rel_cat.py @@ -71,11 +71,5 @@ def get_rel_cat_from_old(cdb: CDB, old_path: str, tokenizer: BaseTokenizer from medcat.config import Config cdb = CDB(Config()) cdb.config.general.nlp.provider = 'spacy' - rc = get_rel_cat_from_old(cdb, sys.argv[1], - create_tokenizer( - "spacy", - "en_core_web_md", # model name - cdb.config.general.nlp.disabled_components, - False, # diacritics - cdb.config.preprocessing.max_document_length - )) + rc = get_rel_cat_from_old( + cdb, sys.argv[1], create_tokenizer("spacy", cdb.config)) diff --git a/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat.py b/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat.py index cd78ae130..37075777b 100644 --- a/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat.py +++ b/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat.py @@ -68,13 +68,7 @@ def setUpClass(cls) -> None: cls.config_rel_cat: ConfigRelCAT = config cls.base_tokenizer = create_tokenizer( - cdb.config.general.nlp.provider, - # NOTE: the following only required for spacy models - # but the saved model should be a spacy model - cdb.config.general.nlp.modelname, - cdb.config.general.nlp.disabled_components, - cdb.config.general.diacritics, - cdb.config.preprocessing.max_document_length) + cdb.config.general.nlp.provider, cdb.config) cls.rel_cat: RelCAT = RelCAT(cls.base_tokenizer, cdb, config=config, init_model=True) diff --git 
a/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat_in_model_pack.py b/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat_in_model_pack.py index 73975c916..89cb68c01 100644 --- a/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat_in_model_pack.py +++ b/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat_in_model_pack.py @@ -51,13 +51,7 @@ def setUpClass(cls): cnf = Config() cnf.general.nlp.provider = 'spacy' cdb = CDB(cnf) - tokenizer = create_tokenizer( - "spacy", - "en_core_web_md", # model name - cdb.config.general.nlp.disabled_components, - False, # diacritics - cdb.config.preprocessing.max_document_length - ) + tokenizer = create_tokenizer("spacy", cdb.config) rc = get_rel_cat_from_old(cdb, cls._unpacked_v1_rel_cat_path, tokenizer) # add to model cat = CAT.load_model_pack(UNPACKED_EXAMPLE_MODEL_PACK_PATH) diff --git a/medcat-v2/tests/components/addons/test_addons.py b/medcat-v2/tests/components/addons/test_addons.py index 72c8e5e48..1d4340dfe 100644 --- a/medcat-v2/tests/components/addons/test_addons.py +++ b/medcat-v2/tests/components/addons/test_addons.py @@ -45,6 +45,12 @@ def get_output_key_val(self, ent: MutableEntity ) -> tuple[str, dict[str, Any]]: return '', {} + @classmethod + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str]) -> 'FakeAddonNoInit': + return cls(cnf) + class FakeAddonWithInit: name = 'fake_addon_w_init' @@ -60,14 +66,10 @@ def __call__(self, doc): return doc @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - return [tokenizer, cdb] - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - return {} + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, 
model_load_path: Optional[str]) -> 'FakeAddonWithInit': + return cls(cnf, tokenizer, cdb) @property def should_save(self) -> bool: @@ -93,21 +95,26 @@ class AddonsRegistrationTests(unittest.TestCase): @classmethod def setUpClass(cls): - addons.register_addon(cls.addon_cls.name, cls.addon_cls) + cls.addon_creator = cls.addon_cls.create_new_component + addons.register_addon(cls.addon_cls.name, cls.addon_creator) @classmethod def tearDownClass(cls): addons._ADDON_REGISTRY.unregister_all_components() addons._ADDON_REGISTRY._lazy_defaults.update(addons._DEFAULT_ADDONS) + def creator_args(self): + return ( + ComponentConfig(comp_name=self.addon_cls.name), + None, None, None, None) + def test_has_registration(self): - addon_cls = addons.get_addon_creator(self.addon_cls.name) - self.assertIs(addon_cls, self.addon_cls) + addon_creator = addons.get_addon_creator(self.addon_cls.name) + self.assertIs(addon_creator, self.addon_creator) def test_can_create_empty_addon(self): addon = addons.create_addon( - self.addon_cls.name, ComponentConfig( - comp_name=self.addon_cls.name)) + self.addon_cls.name, *self.creator_args()) self.assertIsInstance(addon, self.addon_cls) @@ -117,7 +124,8 @@ class AddonUsageTests(unittest.TestCase): @classmethod def setUpClass(cls): - addons.register_addon(cls.addon_cls.name, cls.addon_cls) + cls.addon_creator = cls.addon_cls.create_new_component + addons.register_addon(cls.addon_cls.name, cls.addon_creator) cls.cnf = Config() cls.cdb = CDB(cls.cnf) cls.vocab = Vocab() diff --git a/medcat-v2/tests/components/helper.py b/medcat-v2/tests/components/helper.py index ed15329c7..368a91473 100644 --- a/medcat-v2/tests/components/helper.py +++ b/medcat-v2/tests/components/helper.py @@ -1,8 +1,7 @@ -from typing import runtime_checkable, Type +from typing import runtime_checkable, Type, Callable from medcat.components import types from medcat.config.config import Config, ComponentConfig -from medcat.utils.default_args import set_components_defaults class FakeCDB: 
@@ -31,10 +30,7 @@ class ComponentInitTests: # these need to be specified when overriding comp_type: types.CoreComponentType default_cls: Type[types.BaseComponent] - - @classmethod - def set_def_args(cls, cdb: FakeCDB, vocab: FVocab, tokenizer: FTokenizer): - set_components_defaults(cdb, vocab, tokenizer, None) + default_creator: Callable[..., types.BaseComponent] @classmethod def setUpClass(cls): @@ -42,7 +38,6 @@ def setUpClass(cls): cls.fcdb = FakeCDB(cls.cnf) cls.fvocab = FVocab() cls.vtokenizer = FTokenizer() - cls.set_def_args(cls.fcdb, cls.fvocab, cls.vtokenizer) cls.comp_cnf: ComponentConfig = getattr( cls.cnf.components, cls.comp_type.name) @@ -51,13 +46,12 @@ def test_has_default(self): self.assertEqual(len(avail_components), self.expected_def_components) name, cls_name = avail_components[0] self.assertEqual(name, self.default) - self.assertIs(cls_name, self.default_cls.__name__) + self.assertIs(cls_name, self.default_creator.__name__) def test_can_create_def_component(self): component = types.create_core_component( self.comp_type, - self.default, *self.comp_cnf.init_args, - **self.comp_cnf.init_kwargs) + self.default, self.cnf, self.vtokenizer, self.fcdb, self.fvocab, None) self.assertIsInstance(component, runtime_checkable(types.BaseComponent)) self.assertIsInstance(component, self.default_cls) diff --git a/medcat-v2/tests/components/linking/test_context_based_linker.py b/medcat-v2/tests/components/linking/test_context_based_linker.py index 5246b96cd..119cc7278 100644 --- a/medcat-v2/tests/components/linking/test_context_based_linker.py +++ b/medcat-v2/tests/components/linking/test_context_based_linker.py @@ -38,6 +38,7 @@ class LinkingInitTests(ComponentInitTests, unittest.TestCase): expected_def_components = 3 comp_type = types.CoreComponentType.linking default_cls = context_based_linker.Linker + default_creator = context_based_linker.Linker.create_new_component module = context_based_linker @classmethod diff --git 
a/medcat-v2/tests/components/ner/test_vocab_based_ner.py b/medcat-v2/tests/components/ner/test_vocab_based_ner.py index cab810f0d..5a971b8aa 100644 --- a/medcat-v2/tests/components/ner/test_vocab_based_ner.py +++ b/medcat-v2/tests/components/ner/test_vocab_based_ner.py @@ -29,6 +29,7 @@ class NerInitTests(ComponentInitTests, unittest.TestCase): expected_def_components = 3 comp_type = types.CoreComponentType.ner default_cls = vocab_based_ner.NER + default_creator = vocab_based_ner.NER.create_new_component module = vocab_based_ner @classmethod diff --git a/medcat-v2/tests/components/ner/trf/test_transformers_ner.py b/medcat-v2/tests/components/ner/trf/test_transformers_ner.py index d6655cc19..5fc11906e 100644 --- a/medcat-v2/tests/components/ner/trf/test_transformers_ner.py +++ b/medcat-v2/tests/components/ner/trf/test_transformers_ner.py @@ -236,8 +236,9 @@ def test_ignore_extra_labels(self): # Load the saved model loaded_ner = TransformersNER.deserialise_from( model_path, + cnf=self.cdb.config.components.ner, cdb=self.cdb, - base_tokenizer=self.base_tokenizer)._component + tokenizer=self.base_tokenizer)._component # Get initial number of labels initial_num_labels = len(loaded_ner.tokenizer.label_map) diff --git a/medcat-v2/tests/components/normalizing/test_normalizer.py b/medcat-v2/tests/components/normalizing/test_normalizer.py index 0631ad2e1..359dd5973 100644 --- a/medcat-v2/tests/components/normalizing/test_normalizer.py +++ b/medcat-v2/tests/components/normalizing/test_normalizer.py @@ -22,6 +22,7 @@ def __call__(selt, text: str) -> FakeDocument: class NormaliserInitTests(ComponentInitTests, unittest.TestCase): comp_type = types.CoreComponentType.token_normalizing default_cls = normalizer.TokenNormalizer + default_creator = normalizer.TokenNormalizer.create_new_component module = normalizer @classmethod diff --git a/medcat-v2/tests/components/tagging/test_tagger.py b/medcat-v2/tests/components/tagging/test_tagger.py index 7478924db..88615e14c 100644 --- 
a/medcat-v2/tests/components/tagging/test_tagger.py +++ b/medcat-v2/tests/components/tagging/test_tagger.py @@ -9,4 +9,5 @@ class TaggerInitTests(ComponentInitTests, unittest.TestCase): comp_type = types.CoreComponentType.tagging default_cls = tagger.TagAndSkipTagger + default_creator = tagger.TagAndSkipTagger.create_new_component module = tagger diff --git a/medcat-v2/tests/components/test_registration.py b/medcat-v2/tests/components/test_registration.py index 247b1c3d7..61496c292 100644 --- a/medcat-v2/tests/components/test_registration.py +++ b/medcat-v2/tests/components/test_registration.py @@ -1,5 +1,3 @@ -from typing import Any, Optional - from medcat.components import types from medcat.config.config import Config, ComponentConfig from medcat.cdb.cdb import CDB @@ -29,14 +27,8 @@ def get_type(self): return types.CoreComponentType.ner @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - return [] - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - return {} + def create_new_component(cls, cnf, tokenizer, cdb, vocab, model_load_path): + return cls() class WithInitNER(types.AbstractCoreComponent): @@ -55,30 +47,25 @@ def get_type(self): return types.CoreComponentType.ner @classmethod - def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> list[Any]: - return [tokenizer, cdb] - - @classmethod - def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, - model_load_path: Optional[str]) -> dict[str, Any]: - return {} + def create_new_component(cls, cnf, tokenizer, cdb, vocab, model_load_path): + return cls(cnf, tokenizer) class RegisteredCompBaseTests(unittest.TestCase): TYPE = types.CoreComponentType.ner - TO_REGISTR = NoInitNER + TO_REGISTR_CLS = NoInitNER @classmethod def setUpClass(cls): - 
types.register_core_component(cls.TYPE, cls.TO_REGISTR.name, + cls.TO_REGISTR = cls.TO_REGISTR_CLS.create_new_component + types.register_core_component(cls.TYPE, cls.TO_REGISTR_CLS.name, cls.TO_REGISTR) @classmethod def tearDownClass(cls): # unregister component types._CORE_REGISTRIES[cls.TYPE].unregister_component( - cls.TO_REGISTR.name) + cls.TO_REGISTR_CLS.name) class CoreCompNoInitRegistrationTests(RegisteredCompBaseTests): @@ -90,19 +77,17 @@ class CoreCompNoInitRegistrationTests(RegisteredCompBaseTests): @classmethod def setUpClass(cls): super().setUpClass() - cls.init_args = cls.TO_REGISTR.get_init_args( - cls.FTOK, cls.FCDB, cls.FVOCAB, None) - cls.init_kwargs = cls.TO_REGISTR.get_init_kwargs( - cls.FTOK, cls.FCDB, cls.FVOCAB, None) + + def register_args(self): + return None, None, None, None, None def test_can_create_component(self): - comp = types.create_core_component(self.TYPE, self.TO_REGISTR.name, - *self.init_args, **self.init_kwargs) - self.assertIsInstance(comp, self.TO_REGISTR) + comp = types.create_core_component(self.TYPE, self.TO_REGISTR_CLS.name, *self.register_args()) + self.assertIsInstance(comp, self.TO_REGISTR_CLS) class CoreCompWithInitRegistrationTests(CoreCompNoInitRegistrationTests): - TO_REGISTR = WithInitNER + TO_REGISTR_CLS = WithInitNER class CoreCompNoInitCATTests(RegisteredCompBaseTests): @@ -117,7 +102,7 @@ def setUpClass(cls): # set name in component config comp_cnf: ComponentConfig = getattr(cls.cdb.config.components, cls.TYPE.name) - comp_cnf.comp_name = cls.TO_REGISTR.name + comp_cnf.comp_name = cls.TO_REGISTR_CLS.name # NOTE: init arguments should be handled automatically cls.cat = CAT(cdb=cls.cdb, vocab=cls.vocab) @@ -125,7 +110,7 @@ def test_can_be_used_in_config(self): self.assertIsInstance(self.cat, CAT) def test_comp_runs(self): - with unittest.mock.patch.object(self.TO_REGISTR, "__call__", + with unittest.mock.patch.object(self.TO_REGISTR_CLS, "__call__", unittest.mock.MagicMock() ) as mock_call: 
self.cat.get_entities("Some text") @@ -142,8 +127,8 @@ def test_can_save_and_load(self): cat = CAT.load_model_pack(full_path) self.assertIsInstance(cat, CAT) comp = cat._pipeline.get_component(self.TYPE) - self.assertIsInstance(comp, self.TO_REGISTR) + self.assertIsInstance(comp, self.TO_REGISTR_CLS) class CoreCompWithInitCATTests(CoreCompNoInitCATTests): - TO_REGISTR = WithInitNER + TO_REGISTR_CLS = WithInitNER diff --git a/medcat-v2/tests/components/test_types.py b/medcat-v2/tests/components/test_types.py index 74236254b..f39fe231f 100644 --- a/medcat-v2/tests/components/test_types.py +++ b/medcat-v2/tests/components/test_types.py @@ -21,6 +21,11 @@ def __call__(self, raw: BaseDocument, mutable: MutableDocument ) -> MutableDocument: return mutable + @classmethod + def create_new_component(cls, cnf, tokenizer, + cdb, vocab, model_load_path) -> 'FakeCoreComponent': + return cls(model_load_path) + class TypesRegistrationTests(unittest.TestCase): # NOTE: if/when default commponents get added, this needs to change @@ -28,12 +33,15 @@ class TypesRegistrationTests(unittest.TestCase): COMP_TYPE = types.CoreComponentType.linking WRONG_TYPE = types.CoreComponentType.ner COMP_NAME = "test-linker" - BCC = FakeCoreComponent + BCC = FakeCoreComponent.create_new_component + + def creation_args(self): + return None, None, None, None, self.COMP_TYPE def setUp(self): types.register_core_component(self.COMP_TYPE, self.COMP_NAME, self.BCC) self.registered = types.create_core_component( - self.COMP_TYPE, self.COMP_NAME, self.COMP_TYPE) + self.COMP_TYPE, self.COMP_NAME, *self.creation_args()) def tearDown(self): for registry in types._CORE_REGISTRIES.values(): @@ -71,11 +79,11 @@ def test_registered_is_fake_component(self): def test_does_not_get_incorrect_type(self): with self.assertRaises(MedCATRegistryException): - types.create_core_component(self.WRONG_TYPE, self.COMP_NAME) + types.create_core_component(self.WRONG_TYPE, self.COMP_NAME, *self.creation_args()) def 
test_does_not_get_incorrect_name(self): with self.assertRaises(MedCATRegistryException): - types.create_core_component(self.COMP_TYPE, "#" + self.COMP_NAME) + types.create_core_component(self.COMP_TYPE, "#" + self.COMP_NAME, *self.creation_args()) def test_lists_registered_component(self): comps = types.get_registered_components(self.COMP_TYPE) diff --git a/medcat-v2/tests/tokenizing/spacy_impl/test_tokenizers.py b/medcat-v2/tests/tokenizing/spacy_impl/test_tokenizers.py index 87a80779d..1d9ebaef8 100644 --- a/medcat-v2/tests/tokenizing/spacy_impl/test_tokenizers.py +++ b/medcat-v2/tests/tokenizing/spacy_impl/test_tokenizers.py @@ -11,14 +11,12 @@ class DefaultTokenizerInitTests(unittest.TestCase): default_provider = 'spacy' default_cls = SpacyTokenizer + default_creator = SpacyTokenizer.create_new_tokenizer exp_num_def_tokenizers = 2 @classmethod def setUpClass(cls): cls.cnf = Config() - cls.cnf.general.nlp.init_args = cls.default_cls.get_init_args(cls.cnf) - cls.cnf.general.nlp.init_kwargs = cls.default_cls.get_init_kwargs( - cls.cnf) def test_has_default(self): avail_tokenizers = tokenizers.list_available_tokenizers() @@ -26,12 +24,11 @@ def test_has_default(self): name, cls_name = [(t_name, t_cls) for t_name, t_cls in avail_tokenizers if t_name == self.default_provider][0] self.assertEqual(name, self.default_provider) - self.assertIs(cls_name, self.default_cls.__name__) + self.assertIs(cls_name, self.default_creator.__name__) def test_can_create_def_tokenizer(self): tokenizer = tokenizers.create_tokenizer( - self.default_provider, *self.cnf.general.nlp.init_args, - **self.cnf.general.nlp.init_kwargs) + self.default_provider, self.cnf) self.assertIsInstance(tokenizer, runtime_checkable(tokenizers.BaseTokenizer)) self.assertIsInstance(tokenizer, self.default_cls) @@ -40,3 +37,4 @@ def test_can_create_def_tokenizer(self): class DefaultTokenizerInitTests2(DefaultTokenizerInitTests): default_provider = 'regex' default_cls = RegexTokenizer + default_creator = 
RegexTokenizer.create_new_tokenizer diff --git a/medcat-v2/tests/utils/ner/test_deid.py b/medcat-v2/tests/utils/ner/test_deid.py index 5d7b0eb4b..cc5a3e968 100644 --- a/medcat-v2/tests/utils/ner/test_deid.py +++ b/medcat-v2/tests/utils/ner/test_deid.py @@ -61,9 +61,7 @@ def test_can_create_model(self): self.assertIsNotNone(deid_model) -tokenizer = create_tokenizer( - 'spacy', 'en_core_web_md', cnf.general.nlp.disabled_components, - cnf.general.diacritics, cnf.preprocessing.max_document_length) +tokenizer = create_tokenizer('spacy', Config()) def _create_model() -> deid.DeIdModel: