diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c06c9c3..6b502ed 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -34,8 +34,8 @@ jobs: python -m mypy `git ls-tree --full-tree --name-only -r HEAD | grep ".py$" | grep -v "tests/"` --explicit-package-bases --follow-imports=normal - name: Test run: | - python -m unittest discover - python -m unittest discover -s medcat/compare_models + python tests/runner/custom_test_runner.py + python tests/runner/custom_test_runner.py -s medcat/compare_models # TODO - in the future, we might want to add automated tests for notebooks as well # though it's not really possible right now since the notebooks are designed # in a way that assumes interaction (i.e specifying model pack names) diff --git a/medcat/1_create_model/create_cdb/create_cdb.py b/medcat/1_create_model/create_cdb/create_cdb.py index b163422..11b895b 100644 --- a/medcat/1_create_model/create_cdb/create_cdb.py +++ b/medcat/1_create_model/create_cdb/create_cdb.py @@ -1,7 +1,8 @@ import os import pandas as pd from medcat.config import Config -from medcat.cdb_maker import CDBMaker +from medcat.model_creation.cdb_maker import CDBMaker +from medcat.storage.serialisers import serialise, AvailableSerialisers pd.options.mode.chained_assignment = None # type: ignore @@ -24,6 +25,10 @@ model_dir = os.path.join(BASE_PATH, "models", "cdb") output_cdb = os.path.join(model_dir, f"{release}_SNOMED_cdb.dat") +os.makedirs(output_cdb, exist_ok=True) +# NOTE: by default, new models creaeted at the same location will not be saved +# so here we allow overwrtiing +allow_overwrite = True csv = pd.read_csv(csv_path) # Remove null values @@ -50,9 +55,9 @@ # Setup config config = Config() -config.general['spacy_model'] = 'en_core_web_md' -config.cdb_maker['remove_parenthesis'] = 1 -config.general['cdb_source_name'] = f'SNOMED_{release}' +config.general.nlp.modelname = 'en_core_web_md' +config.cdb_maker.remove_parenthesis = 1 +# config.general.cdb_source_name = f'SNOMED_{release}' maker = CDBMaker(config) @@ -64,8 +69,8 @@ # Add type_id pretty names to cdb cdb.addl_info['type_id2name'] = pd.Series(csv.description_type_ids.values, index=csv.type_ids.astype(str)).to_dict() -cdb.config.linking['filters']['cuis'] = set(csv['cui'].tolist()) # Add all cuis to filter out legacy terms. +cdb.config.components.linking.filters.cuis = set(csv['cui'].tolist()) # Add all cuis to filter out legacy terms. # save model -cdb.save(output_cdb) +serialise(AvailableSerialisers.dill, cdb, output_cdb, overwrite=allow_overwrite) print(f"CDB Model saved successfully as: {output_cdb}") diff --git a/medcat/1_create_model/create_cdb/create_umls_cdb.py b/medcat/1_create_model/create_cdb/create_umls_cdb.py index f692024..73314d6 100644 --- a/medcat/1_create_model/create_cdb/create_umls_cdb.py +++ b/medcat/1_create_model/create_cdb/create_umls_cdb.py @@ -1,7 +1,8 @@ import os import pandas as pd from medcat.config import Config -from medcat.cdb_maker import CDBMaker +from medcat.model_creation.cdb_maker import CDBMaker +from medcat.storage.serialisers import serialise, AvailableSerialisers pd.options.mode.chained_assignment = None # type: ignore @@ -28,6 +29,10 @@ model_dir = os.path.join(BASE_PATH, "models", "cdb") output_cdb = os.path.join(model_dir, f"{release}_UMLS_cdb.dat") +os.makedirs(output_cdb, exist_ok=True) +# NOTE: by default, new models creaeted at the same location will not be saved +# so here we allow overwrtiing +allow_overwrite = True csv = pd.read_csv(csv_path) # Remove null values @@ -39,9 +44,9 @@ # Setup config config = Config() -config.general['spacy_model'] = 'en_core_web_md' -config.cdb_maker['remove_parenthesis'] = 1 -config.general['cdb_source_name'] = f'UMLS_{release}' +config.general.nlp.modelname = 'en_core_web_md' +config.cdb_maker.remove_parenthesis = 1 +# config.general.cdb_source_name = f'UMLS_{release}' maker = CDBMaker(config) @@ -52,8 +57,8 @@ cdb = maker.prepare_csvs(csv_paths, full_build=True) # Add type_id pretty names to cdb -cdb.config.linking['filters']['cuis'] = set(csv['cui'].tolist()) # Add all cuis to filter out legacy terms. +cdb.config.components.linking.filters.cuis = set(csv['cui'].tolist()) # Add all cuis to filter out legacy terms. # save model -cdb.save(output_cdb) +serialise(AvailableSerialisers.dill, cdb, output_cdb, overwrite=allow_overwrite) print(f"CDB Model saved successfully as: {output_cdb}") diff --git a/medcat/1_create_model/create_modelpack/create_modelpack.py b/medcat/1_create_model/create_modelpack/create_modelpack.py index 949e681..750af4b 100644 --- a/medcat/1_create_model/create_modelpack/create_modelpack.py +++ b/medcat/1_create_model/create_modelpack/create_modelpack.py @@ -39,27 +39,37 @@ def load_cdb_and_save_modelpack(cdb_path: str, str: The model pack path. """ # Load cdb - cdb = CDB.load(cdb_path) + cdb: CDB + try: + cdb = CDB.load(cdb_path) + except NotADirectoryError: + from medcat.utils.legacy.convert_cdb import get_cdb_from_old + cdb = get_cdb_from_old(cdb_path) # Set cdb configuration # technically we already created this during the cdb creation - cdb.config.ner['min_name_len'] = 2 - cdb.config.ner['upper_case_limit_len'] = 3 - cdb.config.general['spell_check'] = True - cdb.config.linking['train_count_threshold'] = 10 - cdb.config.linking['similarity_threshold'] = 0.3 - cdb.config.linking['train'] = True - cdb.config.linking['disamb_length_limit'] = 4 - cdb.config.general['full_unlink'] = True + cdb.config.components.ner.min_name_len = 2 + cdb.config.components.ner.upper_case_limit_len = 3 + cdb.config.general.spell_check = True + cdb.config.components.linking.train_count_threshold = 10 + cdb.config.components.linking.similarity_threshold = 0.3 + cdb.config.components.linking.train = True + cdb.config.components.linking.disamb_length_limit = 4 + cdb.config.general.full_unlink = True # Load vocab - vocab = Vocab.load(vocab_path) + vocab: Vocab + try: + vocab = Vocab.load(vocab_path) + except NotADirectoryError: + from medcat.utils.legacy.convert_vocab import get_vocab_from_old + vocab = get_vocab_from_old(vocab_path) # Initialise the model cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab) # Create and save model pack - return cat.create_model_pack(save_dir_path=modelpack_path, model_pack_name=modelpack_name) + return cat.save_model_pack(modelpack_path, pack_name=modelpack_name) def load_cdb_and_save_modelpack_in_def_location(cdb_name: str, diff --git a/medcat/1_create_model/create_vocab/create_vocab.py b/medcat/1_create_model/create_vocab/create_vocab.py index 0d74668..5487622 100644 --- a/medcat/1_create_model/create_vocab/create_vocab.py +++ b/medcat/1_create_model/create_vocab/create_vocab.py @@ -1,4 +1,5 @@ from medcat.vocab import Vocab +from medcat.storage.serialisers import serialise, AvailableSerialisers import os vocab = Vocab() @@ -17,5 +18,6 @@ # embeddings of 300 dimensions is standard vocab.add_words(os.path.join(vocab_dir, 'vocab_data.txt'), replace=True) -vocab.make_unigram_table() -vocab.save(os.path.join(vocab_dir, "vocab.dat")) +vocab_folder = os.path.join(vocab_dir, "vocab.dat") +os.makedirs(vocab_folder, exist_ok=True) +serialise(AvailableSerialisers.dill, vocab, vocab_folder) diff --git a/medcat/2_train_model/1_unsupervised_training/unsupervised training.ipynb b/medcat/2_train_model/1_unsupervised_training/unsupervised training.ipynb index 77cd181..44f9771 100644 --- a/medcat/2_train_model/1_unsupervised_training/unsupervised training.ipynb +++ b/medcat/2_train_model/1_unsupervised_training/unsupervised training.ipynb @@ -55,7 +55,7 @@ "metadata": {}, "outputs": [], "source": [ - "cat.cdb.print_stats()" + "cat.cdb.get_basic_info()" ] }, { @@ -88,21 +88,12 @@ "outputs": [], "source": [ "# Print statistics on the CDB before training\n", - "cat.cdb.print_stats()\n", + "cat.cdb.get_basic_info()\n", "\n", "# Run the annotation procedure over all the documents we have,\n", "# given that we have a large number of documents this can take quite some time.\n", "\n", - "for i, text in enumerate(data['text'].values):\n", - " # This will now run the training in the background \n", - " try:\n", - " _ = cat(text, do_train=True)\n", - " except TypeError:\n", - " pass\n", - " \n", - " # So we know how things are moving\n", - " if i % 10000 == 0:\n", - " print(\"Finished {} - text blocks\".format(i))\n" + "cat.trainer.train_unsupervised(data.text)\n" ] }, { @@ -112,7 +103,7 @@ "outputs": [], "source": [ "# Print statistics on the CDB after training\n", - "cat.cdb.print_stats()" + "cat.cdb.get_basic_info()" ] }, { @@ -122,7 +113,8 @@ "outputs": [], "source": [ "# save modelpack\n", - "cat.create_model_pack(save_dir_path=model_dir, model_pack_name=output_modelpack)\n" + "\n", + "cat.save_model_pack(model_dir, pack_name=output_modelpack)\n" ] }, { @@ -135,7 +127,7 @@ ], "metadata": { "kernelspec": { - "display_name": "medcat", + "display_name": "venv_v2", "language": "python", "name": "python3" }, @@ -149,12 +141,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8 (main, Nov 24 2022, 08:08:27) [Clang 14.0.6 ]" - }, - "vscode": { - "interpreter": { - "hash": "4e4ccc64ca47f932c34194843713e175cf3a19af3798844e4190152d16ba61ca" - } + "version": "3.10.13" } }, "nbformat": 4, diff --git a/medcat/2_train_model/1_unsupervised_training/unsupervised_medcattraining.py b/medcat/2_train_model/1_unsupervised_training/unsupervised_medcattraining.py index 89c9b51..d82d7f6 100644 --- a/medcat/2_train_model/1_unsupervised_training/unsupervised_medcattraining.py +++ b/medcat/2_train_model/1_unsupervised_training/unsupervised_medcattraining.py @@ -28,13 +28,14 @@ df = cs.DataFrame(index=cogstack_indices, columns=text_columns) # type: ignore cat = CAT.load_model_pack(model_pack_path+model_pack_name) -cat.cdb.print_stats() -cat.train(data_iterator=df[text_columns].iterrows(), - nepochs=1, - fine_tune=True, - progress_print=10000, - is_resumed=False) +print(cat.cdb.get_basic_info()) +cat.trainer.train_unsupervised( + data_iterator=df[text_columns].iterrows(), + nepochs=1, + fine_tune=True, + progress_print=10000, + is_resumed=False) -cat.cdb.print_stats() +print(cat.cdb.get_basic_info()) -cat.create_model_pack(save_dir_path=model_pack_path, model_pack_name=output_modelpack_name) +cat.save_model_pack(target_folder=model_pack_path, pack_name=output_modelpack_name) diff --git a/medcat/2_train_model/1_unsupervised_training/unsupervised_training.py b/medcat/2_train_model/1_unsupervised_training/unsupervised_training.py index 9a202e0..b62967d 100644 --- a/medcat/2_train_model/1_unsupervised_training/unsupervised_training.py +++ b/medcat/2_train_model/1_unsupervised_training/unsupervised_training.py @@ -1,3 +1,4 @@ +from medcat.cat import logger as cat_logger from medcat.cat import CAT import pandas as pd import os @@ -44,10 +45,10 @@ # Load modelpack print('Loading modelpack') cat = CAT.load_model_pack(model_pack_path) -cat.log.addHandler(logging.StreamHandler()) # add console output +cat_logger.addHandler(logging.StreamHandler()) # add console output print('STATS:') -cat.cdb.print_stats() +print(cat.cdb.get_basic_info()) # CHANGE AS NEEDED - if the number of spligt files is different all_data_files = [f'split_notes_5M_{nr}.csv' for nr in range(1, 20)] # file containing training material. @@ -55,14 +56,14 @@ # Load training data print('Load data for', i, 'from', data_file) data = pd.read_csv(os.path.join(data_dir, data_file)) - cat.train(data.text.values, progress_print=100) + cat.trainer.train_unsupervised(data.text.values, progress_print=100) print('Stats now, after', i) - cat.cdb.print_stats() + print(cat.cdb.get_basic_info()) # save modelpack - cat.create_model_pack(save_dir_path=model_dir, model_pack_name=f"{output_modelpack}_{i}") + cat.save_model_pack(target_folder=model_dir, pack_name=f"{output_modelpack}_{i}") # save modelpack - ALL -cat.create_model_pack(save_dir_path=model_dir, model_pack_name=output_modelpack) +cat.save_model_pack(target_folder=model_dir, pack_name=output_modelpack) diff --git a/medcat/2_train_model/2_supervised_training/meta_annotation_training.ipynb b/medcat/2_train_model/2_supervised_training/meta_annotation_training.ipynb index 54a6c2c..af5d206 100644 --- a/medcat/2_train_model/2_supervised_training/meta_annotation_training.ipynb +++ b/medcat/2_train_model/2_supervised_training/meta_annotation_training.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "d58c720d", "metadata": {}, "outputs": [], @@ -11,9 +11,9 @@ "import os\n", "from datetime import date\n", "from medcat.cat import CAT\n", - "from medcat.meta_cat import MetaCAT\n", - "from medcat.config_meta_cat import ConfigMetaCAT\n", - "from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT" + "from medcat.components.addons.meta_cat import MetaCAT, MetaCATAddon\n", + "from medcat.config.config_meta_cat import ConfigMetaCAT\n", + "from medcat.components.addons.meta_cat.mctokenizers.bert_tokenizer import TokenizerWrapperBERT" ] }, { @@ -60,7 +60,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "b205d51b", "metadata": {}, "outputs": [ @@ -73,14 +73,19 @@ } ], "source": [ - "\n", + "def get_meta_cats(cat: CAT) -> list[MetaCATAddon]:\n", + " return [\n", + " addon for addon in cat._pipeline.iter_addons()\n", + " if isinstance(addon, MetaCATAddon)\n", + " ]\n", + "meta_cats = get_meta_cats(cat)\n", "# Check what meta cat models are in this model pack.\n", - "print(f'There are: {len(cat._meta_cats)} meta cat models in this model pack.')" + "print(f'There are: {len(meta_cats)} meta cat models in this model pack.')" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "31d7632a", "metadata": {}, "outputs": [ @@ -102,12 +107,12 @@ } ], "source": [ - "print(cat._meta_cats[0])" + "print(meta_cats[0])" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "e9180c4c", "metadata": {}, "outputs": [ @@ -129,12 +134,12 @@ } ], "source": [ - "print(cat._meta_cats[1])" + "print(meta_cats[1])" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "275ca9ff", "metadata": {}, "outputs": [ @@ -156,7 +161,7 @@ } ], "source": [ - "print(cat._meta_cats[2])" + "print(meta_cats[2])" ] }, { @@ -181,7 +186,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(cat._meta_cats[0].config.general.alternative_category_names)" + "print(meta_cats[0].config.general.alternative_category_names)" ] }, { @@ -203,8 +208,8 @@ "category_name_mapping = [[\"Presence\"],[\"Temporality\",\"Time\"],[\"Experiencer\",\"Subject\"]]\n", "lookup = {item: group for group in category_name_mapping for item in group}\n", "\n", - "for meta_model in range(len(cat._meta_cats)):\n", - " cat._meta_cats[meta_model].config.general.alternative_category_names = lookup.get(cat._meta_cats[meta_model].config.general.category_name)" + "for meta_model in range(len(meta_cats)):\n", + " meta_cats[meta_model].config.general.alternative_category_names = lookup.get(meta_cats[meta_model].config.general.category_name)" ] }, { @@ -229,7 +234,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(cat._meta_cats[0].config.general.alternative_class_names)" + "print(meta_cats[0].config.general.alternative_class_names)" ] }, { @@ -256,8 +261,8 @@ " \"Presence\": [[\"Hypothetical (N/A)\", \"Hypothetical\"], [\"Not present (False)\", \"False\"], [\"Present (True)\", \"True\"]]\n", "}\n", "\n", - "for meta_model in range(len(cat._meta_cats)):\n", - " cat._meta_cats[meta_model].config.general.alternative_class_names = class_name_mapping[cat._meta_cats[meta_model].config.general.category_name]" + "for meta_model in range(len(meta_cats)):\n", + " meta_cats[meta_model].config.general.alternative_class_names = class_name_mapping[meta_cats[meta_model].config.general.category_name]" ] }, { @@ -276,7 +281,7 @@ "outputs": [], "source": [ "# Train the first meta cat model - 'Temporality' Task.\n", - "meta_cat = cat._meta_cats[0]\n", + "meta_cat: MetaCATAddon = meta_cats[0]\n", "\n", "# to overwrite the existing model, resave the fine-tuned model with the same model pack dir\n", "meta_cat_task = meta_cat.config.general.category_name\n", @@ -287,7 +292,7 @@ "#save_dir_path= \"test_meta_\"+meta_cat_task # Where to save the meta_model and results. \n", "\n", "# train the meta_model\n", - "results = meta_cat.train_from_json(mctrainer_export, save_dir_path=save_dir_path)\n", + "results = meta_cat.mc.train_from_json(mctrainer_export, save_dir_path=save_dir_path)\n", "\n", "# Save results\n", "json.dump(results['report'], open(os.path.join(save_dir_path,'meta_'+meta_cat_task+'_results.json'), 'w'))" diff --git a/medcat/2_train_model/2_supervised_training/meta_annotation_training_advanced.ipynb b/medcat/2_train_model/2_supervised_training/meta_annotation_training_advanced.ipynb index d32d399..fb85f6c 100644 --- a/medcat/2_train_model/2_supervised_training/meta_annotation_training_advanced.ipynb +++ b/medcat/2_train_model/2_supervised_training/meta_annotation_training_advanced.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "d58c720d", "metadata": {}, "outputs": [], @@ -11,9 +11,9 @@ "import os\n", "from datetime import date\n", "from medcat.cat import CAT\n", - "from medcat.meta_cat import MetaCAT\n", - "from medcat.config_meta_cat import ConfigMetaCAT\n", - "from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT" + "from medcat.components.addons.meta_cat import MetaCAT, MetaCATAddon\n", + "from medcat.config.config_meta_cat import ConfigMetaCAT\n", + "from medcat.components.addons.meta_cat.mctokenizers.bert_tokenizer import TokenizerWrapperBERT" ] }, { @@ -219,7 +219,7 @@ "\n", "synthetic_data_export = [[],[],[]]\n", "\n", - "results = mc.train_from_json(mctrainer_export, save_dir_path=save_dir_path,data_oversampled=synthetic_data_export)\n", + "results = meta_model.train_from_json(mctrainer_export, save_dir_path=save_dir_path,data_oversampled=synthetic_data_export)\n", "\n", "# Save results\n", "json.dump(results['report'], open(os.path.join(save_dir_path,'meta_'+meta_model+'_results.json'), 'w'))" diff --git a/medcat/2_train_model/2_supervised_training/supervised training.ipynb b/medcat/2_train_model/2_supervised_training/supervised training.ipynb index 3d634ca..2a1ac10 100644 --- a/medcat/2_train_model/2_supervised_training/supervised training.ipynb +++ b/medcat/2_train_model/2_supervised_training/supervised training.ipynb @@ -10,7 +10,8 @@ "import json\n", "import pandas as pd\n", "from datetime import date\n", - "from medcat.cat import CAT" + "from medcat.cat import CAT\n", + "from medcat.stats.stats import get_stats" ] }, { @@ -52,7 +53,7 @@ "source": [ "# Create CAT - the main class from medcat used for concept annotation\n", "cat = CAT.load_model_pack(model_pack_path)\n", - "cat.config.linking['filters'] = {'cuis':set()} # To remove exisitng filters" + "cat.config.components.linking.filters.cuis = set() # To remove exisitng filters" ] }, { @@ -74,7 +75,7 @@ "if snomed_filter_path:\n", " snomed_filter = set(json.load(open(snomed_filter_path)))\n", "else:\n", - " snomed_filter = set(cat.cdb.cui2preferred_name.keys())\n" + " snomed_filter = set(cat.cdb.cui2info.keys())\n" ] }, { @@ -90,13 +91,17 @@ "metadata": {}, "outputs": [], "source": [ - "cat.train_supervised_from_json(data_path=mctrainer_export_path, \n", - " nepochs=3,\n", - " reset_cui_count=False,\n", - " print_stats=True,\n", - " use_filters=True,\n", - " extra_cui_filter=snomed_filter, # If not filter is set remove this line\n", - " )\n" + "import json\n", + "with open(mctrainer_export_path) as f:\n", + " data = json.load(f)\n", + "cat.trainer.train_supervised_raw(\n", + " data=data, \n", + " nepochs=3,\n", + " reset_cui_count=False,\n", + " print_stats=True,\n", + " use_filters=True,\n", + " extra_cui_filter=snomed_filter, # If not filter is set remove this line\n", + ")\n" ] }, { @@ -106,15 +111,6 @@ "# Stats" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data = json.load(open(mctrainer_export_path))" - ] - }, { "cell_type": "code", "execution_count": null, @@ -123,7 +119,7 @@ }, "outputs": [], "source": [ - "fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples = cat._print_stats(data, use_filters=True)" + "fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples = get_stats(cat, data, use_project_filters=True)" ] }, { @@ -159,7 +155,7 @@ "outputs": [], "source": [ "# save modelpack\n", - "cat.create_model_pack(os.path.join(model_dir, output_modelpack))" + "cat.save_model_pack(os.path.join(model_dir, output_modelpack))" ] }, { @@ -198,7 +194,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "venv_v2", "language": "python", "name": "python3" }, @@ -212,12 +208,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6 (default, Sep 26 2022, 11:37:49) \n[Clang 14.0.0 (clang-1400.0.29.202)]" - }, - "vscode": { - "interpreter": { - "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" - } + "version": "3.10.13" } }, "nbformat": 4, diff --git a/medcat/3_run_model/run_model.ipynb b/medcat/3_run_model/run_model.ipynb index 23964ed..11cb7b0 100755 --- a/medcat/3_run_model/run_model.ipynb +++ b/medcat/3_run_model/run_model.ipynb @@ -120,9 +120,9 @@ " snomed_filter = set(json.load(open(snomed_filter_path)))\n", "else:\n", " print('There is no concept filter set')\n", - " snomed_filter = set(cat.cdb.cui2preferred_name.keys())\n", + " snomed_filter = set(cat.cdb.cui2info.keys())\n", "\n", - "cat.config.linking['filters']['cuis'] = snomed_filter \n" + "cat.config.linking.filters.cuis = snomed_filter \n" ] }, { @@ -155,14 +155,26 @@ "outputs": [], "source": [ "batch_char_size = 50000 # Batch size (BS) in number of characters\n", - "cat.multiprocessing_batch_char_size(data_iterator(df, doc_id_column, doc_text_column),\n", - " batch_size_chars=batch_char_size,\n", - " only_cui=False,\n", - " nproc=8, # Number of processors\n", - " out_split_size_chars=20*batch_char_size,\n", - " save_dir_path=ann_folder_path,\n", - " min_free_memory=0.1,\n", - " )\n", + "for text_id, text in data_iterator(df, doc_id_column, doc_text_column):\n", + " # NOTE: get_entities_multi_text returns an generator\n", + " # so no work gets done until the generator use materialised\n", + " output = cat.get_entities_multi_texts(text,\n", + " only_cui=False,\n", + " # nproc=8, # Number of processors\n", + " # out_split_size_chars=20*batch_char_size,\n", + " save_dir_path=ann_folder_path,\n", + " # min_free_memory=0.1,\n", + " )\n", + " # so if we're doing a small amount of data and/or not saving it on disk\n", + " # we probably want to just convert it to a list\n", + " output = list(output)\n", + " # However, if we we're saving the data on disk and don't\n", + " # want to duplicate in memory (i.e there's a lot of data\n", + " # and it can't all be held in memory), we may want to\n", + " # just exhaust the generator\n", + " # NOTE: uncomment to use, but commnet the `ouput = list(ouput)`` line\n", + " # for _ in output:\n", + " # pass\n", "\n", "medcat_logger.warning(f'Annotation process complete!')\n" ] @@ -225,7 +237,7 @@ "source": [ "text = \"He was diagnosed with heart failure\"\n", "doc = cat(text)\n", - "print(doc.ents)" + "print(doc.linked_ents)" ] }, { @@ -235,8 +247,8 @@ "outputs": [], "source": [ "# Display Snomed codes\n", - "for ent in doc.ents:\n", - " print(ent, \" - \", ent._.cui, \" - \", cat.cdb.cui2preferred_name[ent._.cui])" + "for ent in doc.linked_ents:\n", + " print(ent, \" - \", ent.cui, \" - \", cat.cdb.cui2info[ent.cui]['preferred_name'])" ] }, { @@ -246,8 +258,8 @@ "outputs": [], "source": [ "# To show semantic types for each entity\n", - "for ent in doc.ents:\n", - " print(ent, \" - \", cat.cdb.cui2type_ids.get(ent._.cui))" + "for ent in doc.linked_ents:\n", + " print(ent, \" - \", cat.cdb.cui2info[ent.cui]['type_ids'])" ] }, { @@ -258,7 +270,7 @@ "source": [ "# Display\n", "from spacy import displacy\n", - "displacy.render(doc, style='ent', jupyter=True)" + "displacy.render(doc._delegate, style='ent', jupyter=True)" ] }, { @@ -298,7 +310,7 @@ "metadata": {}, "outputs": [], "source": [ - "cat.cdb.print_stats()" + "cat.cdb.get_basic_info()" ] }, { @@ -321,7 +333,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "venv_v2", "language": "python", "name": "python3" }, @@ -335,12 +347,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" - }, - "vscode": { - "interpreter": { - "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" - } + "version": "3.10.13" } }, "nbformat": 4, diff --git a/medcat/3_run_model/run_model.py b/medcat/3_run_model/run_model.py index 2a840ab..5e9ab0d 100644 --- a/medcat/3_run_model/run_model.py +++ b/medcat/3_run_model/run_model.py @@ -57,9 +57,9 @@ if snomed_filter_path: snomed_filter = set(json.load(open(snomed_filter_path))) else: - snomed_filter = set(cat.cdb.cui2preferred_name.keys()) + snomed_filter = set(cat.cdb.cui2info.keys()) -cat.config.linking['filters']['cuis'] = snomed_filter +cat.config.components.linking.filters.cuis = snomed_filter del snomed_filter # build query, change as appropriate @@ -82,14 +82,16 @@ def relevant_text_gen(generator, doc_id = '_id', text_col='body_analysed'): batch_char_size = 500000 # Batch size (BS) in number of characters -cat.multiprocessing_batch_char_size(relevant_text_gen(search_gen), - batch_size_chars=batch_char_size, - only_cui=False, - nproc=8, # Number of processors - out_split_size_chars=20*batch_char_size, - save_dir_path=ann_folder_path, - min_free_memory=0.1, - ) +# NOTE: no multiprocessing in v2 right now +for text in relevant_text_gen(search_gen): + cat.get_entities(text, + # batch_size_chars=batch_char_size, + # only_cui=False, + # nproc=8, # Number of processors + # out_split_size_chars=20*batch_char_size, + # save_dir_path=ann_folder_path, + # min_free_memory=0.1, + ) medcat_logger.warning(f'Annotation process complete!') diff --git a/medcat/compare_models/compare.py b/medcat/compare_models/compare.py index 5bdce6e..5110c92 100644 --- a/medcat/compare_models/compare.py +++ b/medcat/compare_models/compare.py @@ -1,6 +1,7 @@ -from typing import List, Tuple, Dict, Set, Optional, Union, Iterator +from typing import Tuple, Dict, Set, Optional, Union, Iterator from functools import partial import glob +import json from medcat.cat import CAT @@ -34,14 +35,12 @@ def do_counting(cat1: CAT, cat2: CAT, ann_diffs: PerAnnotationDifferences, doc_limit: int = -1) -> ResultsTally: def cui2name(cat, cui): - if cui in cat.cdb.cui2preferred_name: - return cat.cdb.cui2preferred_name[cui] - all_names = cat.cdb.cui2names[cui] - # longest anme - return sorted(all_names, key=lambda name: len(name), reverse=True)[0] - res1 = ResultsTally(pt2ch=_get_pt2ch(cat1), cat_data=cat1.cdb.make_stats(), + ci = cat.cdb.cui2info[cui] + # longest name + return ci['preferred_name'] or sorted(ci['names'], key=lambda name: len(name), reverse=True)[0] + res1 = ResultsTally(pt2ch=_get_pt2ch(cat1), cat_data=cat1.cdb.get_basic_info(), cui2name=partial(cui2name, cat1)) - res2 = ResultsTally(pt2ch=_get_pt2ch(cat2), cat_data=cat2.cdb.make_stats(), + res2 = ResultsTally(pt2ch=_get_pt2ch(cat2), cat_data=cat2.cdb.get_basic_info(), cui2name=partial(cui2name, cat2)) total = doc_limit if doc_limit != -1 else None for per_doc in tqdm.tqdm(ann_diffs.per_doc_results.values(), total=total): @@ -65,8 +64,8 @@ def get_per_annotation_diffs(cat1: CAT, cat2: CAT, documents: Iterator[Tuple[str save_opts = SaveOptions(use_db=True, db_file_name=temp_file.name, clean_callback=temp_file.close) pad = PerAnnotationDifferences(pt2ch1=pt2ch1, pt2ch2=pt2ch2, - model1_cuis=set(cat1.cdb.cui2names), - model2_cuis=set(cat2.cdb.cui2names), + model1_cuis=set(cat1.cdb.cui2info), + model2_cuis=set(cat2.cdb.cui2info), keep_raw=keep_raw, save_options=save_opts) total = doc_limit if doc_limit != -1 else None @@ -99,10 +98,14 @@ def load_and_train(model_pack_path: str, mct_export_path: str) -> CAT: # NOTE: Allowing mct_export_path to contain wildcat ("*"). # And in such a case, iterating over all matching files if "*" not in mct_export_path: - cat.train_supervised_from_json(mct_export_path) + with open(mct_export_path) as f: + mct_export = json.load(f) + cat.trainer.train_supervised_raw(mct_export) else: for file in glob.glob(mct_export_path): - cat.train_supervised_from_json(file) + with open(file) as f: + mct_export = json.load(f) + cat.trainer.train_supervised_raw(mct_export) return cat @@ -151,8 +154,8 @@ def get_diffs_for(model_pack_path_1: str, if show_progress: print("After adding children from 2nd model have a total of", len(cui_filter), "CUIs") - cat1.config.linking.filters.cuis = cui_filter - cat2.config.linking.filters.cuis = cui_filter + cat1.config.components.linking.filters.cuis = cui_filter + cat2.config.components.linking.filters.cuis = cui_filter ann_diffs = get_per_annotation_diffs(cat1, cat2, documents, keep_raw=keep_raw, doc_limit=doc_limit) if show_progress: diff --git a/medcat/compare_models/compare_cdb.py b/medcat/compare_models/compare_cdb.py index 5f99574..bdaf890 100644 --- a/medcat/compare_models/compare_cdb.py +++ b/medcat/compare_models/compare_cdb.py @@ -1,6 +1,7 @@ from typing import Dict, Set, Tuple from medcat.cdb import CDB +from medcat.cdb.concepts import CUIInfo import tqdm from itertools import chain @@ -96,7 +97,7 @@ class DictComparisonResults(BaseModel): values: DictCompareValues @classmethod - def get(cls, d1: dict, d2: dict, progress: bool = True) -> "DictComparisonResults": + def get(cls, d1: dict[str, CUIInfo], d2: dict[str, CUIInfo], progress: bool = True) -> "DictComparisonResults": return cls(keys=DictCompareKeys.get(d1, d2), values=DictCompareValues.get(d1, d2, progress=progress)) @@ -119,6 +120,6 @@ def compare(cdb1: CDB, Returns: CDBCompareResults: _description_ """ - reg = DictComparisonResults.get(cdb1.cui2names, cdb2.cui2names, progress=show_progress) - snames = DictComparisonResults.get(cdb1.cui2snames, cdb2.cui2snames, progress=show_progress) + reg = DictComparisonResults.get(cdb1.cui2info, cdb2.cui2info, progress=show_progress) + snames = DictComparisonResults.get(cdb1.cui2info, cdb2.cui2info, progress=show_progress) return CDBCompareResults(names=reg, snames=snames) diff --git a/medcat/compare_models/tests/test_compare.py b/medcat/compare_models/tests/test_compare.py index 79f53e9..4a9f975 100644 --- a/medcat/compare_models/tests/test_compare.py +++ b/medcat/compare_models/tests/test_compare.py @@ -77,11 +77,17 @@ class TrainAndCompareTests(unittest.TestCase): # this tests that the training is called @classmethod - @unittest.mock.patch("medcat.cat.CAT.train_supervised_from_json") - def _get_diffs(cls, mct_export_path: str, method): + @unittest.mock.patch("medcat.trainer.Trainer.train_supervised_raw") + def _get_diffs(cls, mct_export_path: str, method: unittest.mock.MagicMock): + orig_load_method = CAT.load_model_pack + def _wrapped_load_method(*args, **kwargs): + cat = orig_load_method(*args, **kwargs) + cat.trainer.train_supervised_raw = method + return cat + CAT.load_model_pack = _wrapped_load_method diffs = get_diffs_for(cls.cat_path, mct_export_path, cls.docs_file, supervised_train_comparison_model=True) - cls.assertTrue(cls, method.called) + method.assert_called() return diffs diff --git a/medcat/compare_models/tests/test_compare_annotations.py b/medcat/compare_models/tests/test_compare_annotations.py index b2b6fa5..5fdf747 100644 --- a/medcat/compare_models/tests/test_compare_annotations.py +++ b/medcat/compare_models/tests/test_compare_annotations.py @@ -38,8 +38,9 @@ def _cui2name(self, cui: str) -> str: return self.cui2name[cui] def setUp(self) -> None: - self.res = compare_annotations.ResultsTally(cat_data={"stats": "don't matter"}, - cui2name=self._cui2name, pt2ch=None) + self.res = compare_annotations.ResultsTally( + pt2ch=None, cat_data={"stats": "don't matter"}, + cui2name=self._cui2name) for entities in self.entities: self.res.count(entities['entities']) diff --git a/medcat/evaluate_mct_export/mct_analysis.py b/medcat/evaluate_mct_export/mct_analysis.py index cc35707..a28cb1c 100644 --- a/medcat/evaluate_mct_export/mct_analysis.py +++ b/medcat/evaluate_mct_export/mct_analysis.py @@ -2,7 +2,9 @@ import plotly.graph_objects as go from medcat.cat import CAT from datetime import date +from typing import cast +import os import json import torch import math @@ -11,13 +13,17 @@ import pandas as pd from collections import Counter from typing import List, Dict, Iterator, Tuple, Optional, Union -from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase +from medcat.components.addons.meta_cat.mctokenizers.tokenizers import TokenizerWrapperBase -from medcat.utils.meta_cat.ml_utils import create_batch_piped_data +from medcat.components.addons.meta_cat.ml_utils import create_batch_piped_data -from medcat.meta_cat import MetaCAT -from medcat.config_meta_cat import ConfigMetaCAT -from medcat.utils.meta_cat.data_utils import prepare_from_json, encode_category_values +from medcat.components.addons.meta_cat.meta_cat import MetaCATAddon, MetaCAT +from medcat.stats.stats import get_stats +from medcat.utils.legacy.identifier import is_legacy_model_pack +from medcat.utils.legacy.convert_meta_cat import get_meta_cat_from_old +from medcat.config.config_meta_cat import ConfigMetaCAT +from medcat.components.addons.meta_cat.data_utils import prepare_from_json, encode_category_values +from medcat.storage.serialisers import deserialise import warnings @@ -35,8 +41,11 @@ def __init__(self, mct_export_paths: List[str], model_pack_path: Optional[str] = :param model_pack_path: Path to medcat modelpack """ self.cat: Optional[CAT] = None + self.is_legacy_model_pack = False if model_pack_path: self.cat = CAT.load_model_pack(model_pack_path) + mpp = model_pack_path.removesuffix(".zip") + self.is_legacy_model_pack = is_legacy_model_pack(mpp) self.mct_export_paths = mct_export_paths self.mct_export = self._load_mct_exports(self.mct_export_paths) self.project_names: List[str] = [] @@ -103,7 +112,7 @@ def annotation_df(self) -> pd.DataFrame: """ annotation_df = pd.DataFrame(self.annotations) if self.cat: - annotation_df.insert(5, 'concept_name', annotation_df['cui'].map(self.cat.cdb.cui2preferred_name)) + annotation_df.insert(5, 'concept_name', annotation_df['cui'].map(lambda cui: cast(CAT, self.cat).cdb.get_name(cui))) exceptions: List[ValueError] = [] # try the default format as well as the format specified above for format in [None, DATETIME_FORMAT]: @@ -137,9 +146,11 @@ def concept_summary(self, extra_cui_filter: Optional[str] = None) -> pd.DataFram concept_count_df['count_variations_ratio'] = round(concept_count_df['concept_count'] / concept_count_df['variations'], 3) if self.cat: - fps,fns,tps,cui_prec,cui_rec,cui_f1,cui_counts,examples = self.cat._print_stats(data=self.mct_export, - use_project_filters=True, - extra_cui_filter=extra_cui_filter) + fps,fns,tps,cui_prec,cui_rec,cui_f1,cui_counts,examples = get_stats(self.cat, + data=self.mct_export, # type: ignore + use_project_filters=True, + # extra_cui_filter=extra_cui_filter + ) concept_count_df['fps'] = concept_count_df['cui'].map(fps) concept_count_df['fns'] = concept_count_df['cui'].map(fns) concept_count_df['tps'] = concept_count_df['cui'].map(tps) @@ -252,11 +263,11 @@ def rename_meta_anns(self, meta_anns2rename: dict = dict(), meta_ann_values2rena return def _eval_model(self, model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: TokenizerWrapperBase) -> Dict: - device = torch.device(config.general['device']) # Create a torch device - batch_size_eval = config.general['batch_size_eval'] - pad_id = config.model['padding_idx'] - ignore_cpos = config.model['ignore_cpos'] - class_weights = config.train['class_weights'] + device = torch.device(config.general.device) # Create a torch device + batch_size_eval = config.general.batch_size_eval + pad_id = config.model.padding_idx + ignore_cpos = config.model.ignore_cpos + class_weights = config.train.class_weights if class_weights is not None: class_weights = torch.FloatTensor(class_weights).to(device) @@ -336,11 +347,25 @@ def full_annotation_df(self) -> pd.DataFrame: & (anns_df['irrelevant'] != True)] meta_df = meta_df.reset_index(drop=True) - for meta_model_card in self.cat.get_model_card(as_dict=True)['MetaCAT models']: - meta_model = meta_model_card['Category Name'] + for meta_model_category in self.cat.get_model_card(as_dict=True)['MetaCAT models']: + meta_model = meta_model_category print(f'Checking metacat model: {meta_model}') - _meta_model = MetaCAT.load(self.model_pack_path + '/meta_' + meta_model) - meta_results = self._eval(_meta_model, self.mct_export) + if self.is_legacy_model_pack: + _meta_model = get_meta_cat_from_old( + self.model_pack_path + '/meta_' + meta_model, self.cat._pipeline._tokenizer) + else: + meta_model_path = os.path.join( + self.model_pack_path, "saved_components", f"addon_meta_cat.{meta_model}") + # NOTE: the expected workflow when loading the model + # is one where the config is stored as part of the overall config + # and thus using it for loading is trivial + # but here we need to manually load the config from disk + config_path = os.path.join(meta_model_path, "meta_cat", "config") + cnf: ConfigMetaCAT = deserialise(config_path) # type: ignore + _meta_model = MetaCATAddon.load_existing( + cnf, self.cat._pipeline._tokenizer, meta_model_path) + meta_cat = _meta_model.mc + meta_results = self._eval(meta_cat, self.mct_export) _meta_values = {v: k for k, v in meta_results['meta_values'].items()} pred_meta_values = [] counter = 0 @@ -368,12 +393,23 @@ def meta_anns_concept_summary(self) -> pd.DataFrame: for cui in meta_df.cui.unique(): temp_meta_df = meta_df[meta_df['cui'] == cui] meta_task_results = {} - for meta_model_card in self.cat.get_model_card(as_dict=True)['MetaCAT models']: - meta_task = meta_model_card['Category Name'] + for meta_task in self.cat.get_model_card(as_dict=True)['MetaCAT models']: list_meta_anns = list(zip(temp_meta_df[meta_task], temp_meta_df['predict_' + meta_task])) counter_meta_anns = Counter(list_meta_anns) - meta_value_results: Dict[Tuple[Dict, str, str], Union[int, float]] = {} - for meta_value in meta_model_card['Classes'].keys(): + meta_value_results: Dict[Tuple[str, str, str], Union[int, float]] = {} + # TODO: maybe make this easier? + meta_cats: list[MetaCATAddon] = [ + addon for addon in + self.cat._pipeline.iter_addons() + if (isinstance(addon, MetaCATAddon) and + addon.config.comp_name == meta_task) + ] + if len(meta_cats) != 1: + raise ValueError( + f"Unable to uniquely identify meta task {meta_task}. " + f"Found {len(meta_cats)} options") + meta_cat = meta_cats[0] + for meta_value in meta_cat.config.general.category_value2id.keys(): total = 0 fp = 0 fn = 0 @@ -409,7 +445,7 @@ def meta_anns_concept_summary(self) -> pd.DataFrame: meta_anns_df['total_anns'] = meta_anns_df[col_lst].sum(axis=1) meta_anns_df = meta_anns_df.sort_values(by='total_anns', ascending=False) meta_anns_df = meta_anns_df.rename_axis('cui').reset_index(drop=False) - meta_anns_df.insert(1, 'concept_name', meta_anns_df['cui'].map(self.cat.cdb.cui2preferred_name)) + meta_anns_df.insert(1, 'concept_name', meta_anns_df['cui'].map(lambda cui: cast(CAT, self.cat).cdb.get_name(cui))) return meta_anns_df def generate_report(self, path: str = 'mct_report.xlsx', meta_ann=False, concept_filter: Optional[List] = None): diff --git a/requirements.txt b/requirements.txt index f476450..a7cc1f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -spacy>=3.6.0,<4.0 -medcat~=1.16.0 +spacy>=3.8.0,<4.0 +medcat[meta-cat,spacy,deid,rel-cat]~=2.0.0b4 plotly~=5.19.0 -eland==8.12.1 +# eland~=8.18.1 # NOTE: there is no numpy2-compatible eland release as of 2025-05-13 en_core_web_md @ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.8.0/en_core_web_md-3.8.0-py3-none-any.whl ipyfilechooser jupyter_contrib_nbextensions diff --git a/tests/runner/custom_test_runner.py b/tests/runner/custom_test_runner.py new file mode 100644 index 0000000..fc71cc9 --- /dev/null +++ b/tests/runner/custom_test_runner.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +""" +Custom test runner that ensures the compatibility layer is loaded +before unittest discovers and runs tests. +""" +import sys +import os +import unittest +import importlib +import argparse + +# First, ensure medcat compatibility is set up +import medcat # noqa + + +# Now run the tests +if __name__ == '__main__': + # Parse arguments to mimic unittest discover behavior + parser = argparse.ArgumentParser(description='Run tests with compatibility layer') + parser.add_argument('-s', '--start-directory', default='tests', + help='Directory to start discovery (default: tests)') + parser.add_argument('-p', '--pattern', default='test*.py', + help='Pattern to match test files (default: test*.py)') + parser.add_argument('-t', '--top-level-directory', default=None, + help='Top level directory of project (default: None)') + parser.add_argument('--verbosity', '-v', type=int, default=2, + help='Verbosity level (default: 2)') + + args = parser.parse_args() + + # Ensure the start directory exists + if not os.path.isdir(args.start_directory): + print(f"Error: Start directory '{args.start_directory}' does not exist") + sys.exit(1) + + # Get all tests using the specified parameters + test_loader = unittest.TestLoader() + test_suite = test_loader.discover( + start_dir=args.start_directory, + pattern=args.pattern, + top_level_dir=args.top_level_directory + ) + + # Run the tests + test_runner = unittest.TextTestRunner(verbosity=args.verbosity) + result = test_runner.run(test_suite) + + # Return non-zero exit code if tests failed + sys.exit(not result.wasSuccessful()) diff --git a/tests/medcat/1_create_model/__init__.py b/tests/tmedcat/1_create_model/__init__.py similarity index 100% rename from tests/medcat/1_create_model/__init__.py rename to tests/tmedcat/1_create_model/__init__.py diff --git a/tests/medcat/1_create_model/create_cdb/__init__.py b/tests/tmedcat/1_create_model/create_cdb/__init__.py similarity index 100% rename from tests/medcat/1_create_model/create_cdb/__init__.py rename to tests/tmedcat/1_create_model/create_cdb/__init__.py diff --git a/tests/medcat/1_create_model/create_cdb/test_create_cdb.py b/tests/tmedcat/1_create_model/create_cdb/test_create_cdb.py similarity index 87% rename from tests/medcat/1_create_model/create_cdb/test_create_cdb.py rename to tests/tmedcat/1_create_model/create_cdb/test_create_cdb.py index cd734c7..b115f52 100644 --- a/tests/medcat/1_create_model/create_cdb/test_create_cdb.py +++ b/tests/tmedcat/1_create_model/create_cdb/test_create_cdb.py @@ -1,6 +1,9 @@ import os import sys +import shutil + import medcat.cdb +from medcat.storage.serialisers import deserialise _FILE_DIR = os.path.dirname(__file__) @@ -18,8 +21,8 @@ from unittest.mock import patch # SNOMED pre-cdb csv -PRE_CDB_CSV_PATH_SNOMED = os.path.join(_WWC_BASE_FOLDER, "tests", "medcat", "resources", "example_cdb_input_snomed.csv") -PRE_CDB_CSV_PATH_UMLS = os.path.join(_WWC_BASE_FOLDER, "tests", "medcat", "resources", "example_cdb_input_umls.csv") +PRE_CDB_CSV_PATH_SNOMED = os.path.join(_WWC_BASE_FOLDER, "tests", "tmedcat", "resources", "example_cdb_input_snomed.csv") +PRE_CDB_CSV_PATH_UMLS = os.path.join(_WWC_BASE_FOLDER, "tests", "tmedcat", "resources", "example_cdb_input_umls.csv") def get_mock_input(output: str): @@ -35,12 +38,12 @@ def setUp(self) -> None: def tearDown(self) -> None: if self.output_cdb is not None and os.path.exists(self.output_cdb): - os.remove(self.output_cdb) + shutil.rmtree(self.output_cdb) def assertHasCDB(self, path: str): self.assertTrue(os.path.exists(path)) self.assertTrue(path.endswith(".dat")) - cdb = medcat.cdb.CDB.load(path) + cdb: CDB = deserialise(path) self.assertIsInstance(cdb, medcat.cdb.CDB) def test_snomed_cdb_creation(self): diff --git a/tests/medcat/1_create_model/create_modelpack/__init__.py b/tests/tmedcat/1_create_model/create_modelpack/__init__.py similarity index 100% rename from tests/medcat/1_create_model/create_modelpack/__init__.py rename to tests/tmedcat/1_create_model/create_modelpack/__init__.py diff --git a/tests/medcat/1_create_model/create_modelpack/test_create_modelpack.py b/tests/tmedcat/1_create_model/create_modelpack/test_create_modelpack.py similarity index 87% rename from tests/medcat/1_create_model/create_modelpack/test_create_modelpack.py rename to tests/tmedcat/1_create_model/create_modelpack/test_create_modelpack.py index 6e789a1..89dae1a 100644 --- a/tests/medcat/1_create_model/create_modelpack/test_create_modelpack.py +++ b/tests/tmedcat/1_create_model/create_modelpack/test_create_modelpack.py @@ -21,7 +21,7 @@ import create_modelpack -RESOURCES_FOLDER = os.path.join(_WWC_BASE_FOLDER, "tests", "medcat", "resources") +RESOURCES_FOLDER = os.path.join(_WWC_BASE_FOLDER, "tests", "tmedcat", "resources") DEFAULT_CDB_PATH = os.path.join(RESOURCES_FOLDER, "cdb.dat") DEFAULT_VOCAB_PATH = os.path.join(RESOURCES_FOLDER, "vocab.dat") @@ -39,9 +39,10 @@ def tearDownClass(cls): cls.tempfolder.cleanup() def test_a(self): - model_pack_name = create_modelpack.load_cdb_and_save_modelpack( + model_pack_name_full = create_modelpack.load_cdb_and_save_modelpack( DEFAULT_CDB_PATH, self.model_pack_name, self.tempfolder.name, DEFAULT_VOCAB_PATH) + model_pack_name = os.path.basename(model_pack_name_full) self.assertTrue(model_pack_name.startswith(self.model_pack_name)) model_pack_path = os.path.join(self.tempfolder.name, model_pack_name) self.assertTrue(os.path.exists(model_pack_path)) diff --git a/tests/medcat/1_create_model/create_vocab/__init__.py b/tests/tmedcat/1_create_model/create_vocab/__init__.py similarity index 100% rename from tests/medcat/1_create_model/create_vocab/__init__.py rename to tests/tmedcat/1_create_model/create_vocab/__init__.py diff --git a/tests/medcat/1_create_model/create_vocab/test_create_vocab.py b/tests/tmedcat/1_create_model/create_vocab/test_create_vocab.py similarity index 81% rename from tests/medcat/1_create_model/create_vocab/test_create_vocab.py rename to tests/tmedcat/1_create_model/create_vocab/test_create_vocab.py index b2b358c..7fea449 100644 --- a/tests/medcat/1_create_model/create_vocab/test_create_vocab.py +++ b/tests/tmedcat/1_create_model/create_vocab/test_create_vocab.py @@ -1,7 +1,9 @@ import os import sys +import shutil import medcat.vocab +from medcat.storage.serialisers import deserialise _FILE_DIR = os.path.dirname(__file__) @@ -40,16 +42,19 @@ class CreateVocabTest(unittest.TestCase): def setUp(self) -> None: if os.path.exists(VOCAB_OUTPUT_PATH): - os.rename(VOCAB_OUTPUT_PATH, self.temp_vocab_path) + # NOTE: this is a folder in v2 + shutil.move(VOCAB_OUTPUT_PATH, self.temp_vocab_path) self.moved = True else: self.moved = False def tearDown(self) -> None: if os.path.exists(VOCAB_OUTPUT_PATH): - os.remove(VOCAB_OUTPUT_PATH) + # NOTE: this is a folder in v2 + shutil.rmtree(VOCAB_OUTPUT_PATH) if self.moved: - os.rename(self.temp_vocab_path, VOCAB_OUTPUT_PATH) + # NOTE: this is a folder in v2 + shutil.move(self.temp_vocab_path, VOCAB_OUTPUT_PATH) def test_creating_vocab(self): with patch('builtins.open', side_effect=custom_open): @@ -57,5 +62,5 @@ def test_creating_vocab(self): vocab_path = os.path.join(create_vocab.vocab_dir, "vocab.dat") self.assertEqual(os.path.abspath(vocab_path), VOCAB_OUTPUT_PATH) self.assertTrue(os.path.exists(vocab_path)) - vocab = medcat.vocab.Vocab.load(vocab_path) + vocab: medcat.vocab.Vocab = deserialise(vocab_path) self.assertIsInstance(vocab, medcat.vocab.Vocab) diff --git a/tests/medcat/2_train_model/1_unsupervised_training/__init__.py b/tests/tmedcat/2_train_model/1_unsupervised_training/__init__.py similarity index 100% rename from tests/medcat/2_train_model/1_unsupervised_training/__init__.py rename to tests/tmedcat/2_train_model/1_unsupervised_training/__init__.py diff --git a/tests/medcat/2_train_model/1_unsupervised_training/test_splitter.py b/tests/tmedcat/2_train_model/1_unsupervised_training/test_splitter.py similarity index 95% rename from tests/medcat/2_train_model/1_unsupervised_training/test_splitter.py rename to tests/tmedcat/2_train_model/1_unsupervised_training/test_splitter.py index f336674..73dae89 100644 --- a/tests/medcat/2_train_model/1_unsupervised_training/test_splitter.py +++ b/tests/tmedcat/2_train_model/1_unsupervised_training/test_splitter.py @@ -19,7 +19,7 @@ import splitter -FILE_TO_SPLIT = os.path.join(_WWC_BASE_FOLDER, "tests", "medcat", "resources", "example_file_to_split.csv") +FILE_TO_SPLIT = os.path.join(_WWC_BASE_FOLDER, "tests", "tmedcat", "resources", "example_file_to_split.csv") NR_OF_LINES_IN_FILE = 125 NR_OF_COLUMNS_IN_FILE = 20 diff --git a/tests/medcat/2_train_model/__init__.py b/tests/tmedcat/2_train_model/__init__.py similarity index 100% rename from tests/medcat/2_train_model/__init__.py rename to tests/tmedcat/2_train_model/__init__.py diff --git a/tests/medcat/__init__.py b/tests/tmedcat/__init__.py similarity index 100% rename from tests/medcat/__init__.py rename to tests/tmedcat/__init__.py diff --git a/tests/medcat/evaluate_mct_export/__init__.py b/tests/tmedcat/evaluate_mct_export/__init__.py similarity index 100% rename from tests/medcat/evaluate_mct_export/__init__.py rename to tests/tmedcat/evaluate_mct_export/__init__.py diff --git a/tests/medcat/evaluate_mct_export/offline_test_mct_analysis.py b/tests/tmedcat/evaluate_mct_export/offline_test_mct_analysis.py similarity index 100% rename from tests/medcat/evaluate_mct_export/offline_test_mct_analysis.py rename to tests/tmedcat/evaluate_mct_export/offline_test_mct_analysis.py diff --git a/tests/medcat/evaluate_mct_export/test_mct_analysis.py b/tests/tmedcat/evaluate_mct_export/test_mct_analysis.py similarity index 100% rename from tests/medcat/evaluate_mct_export/test_mct_analysis.py rename to tests/tmedcat/evaluate_mct_export/test_mct_analysis.py diff --git a/tests/medcat/resources/MCT_export_example.json b/tests/tmedcat/resources/MCT_export_example.json similarity index 100% rename from tests/medcat/resources/MCT_export_example.json rename to tests/tmedcat/resources/MCT_export_example.json diff --git a/tests/medcat/resources/cdb.dat b/tests/tmedcat/resources/cdb.dat similarity index 100% rename from tests/medcat/resources/cdb.dat rename to tests/tmedcat/resources/cdb.dat diff --git a/tests/medcat/resources/example_cdb_input_snomed.csv b/tests/tmedcat/resources/example_cdb_input_snomed.csv similarity index 100% rename from tests/medcat/resources/example_cdb_input_snomed.csv rename to tests/tmedcat/resources/example_cdb_input_snomed.csv diff --git a/tests/medcat/resources/example_cdb_input_umls.csv b/tests/tmedcat/resources/example_cdb_input_umls.csv similarity index 100% rename from tests/medcat/resources/example_cdb_input_umls.csv rename to tests/tmedcat/resources/example_cdb_input_umls.csv diff --git a/tests/medcat/resources/example_file_to_split.csv b/tests/tmedcat/resources/example_file_to_split.csv similarity index 100% rename from tests/medcat/resources/example_file_to_split.csv rename to tests/tmedcat/resources/example_file_to_split.csv diff --git a/tests/medcat/resources/vocab.dat b/tests/tmedcat/resources/vocab.dat similarity index 100% rename from tests/medcat/resources/vocab.dat rename to tests/tmedcat/resources/vocab.dat