diff --git a/InstructorEmbedding/instructor.py b/InstructorEmbedding/instructor.py
index 9d4f90a..09da233 100644
--- a/InstructorEmbedding/instructor.py
+++ b/InstructorEmbedding/instructor.py
@@ -23,7 +23,7 @@ def batch_to_device(batch, target_device: str):
     return batch
 
 
-class InstructorPooling(nn.Module):
+class INSTRUCTOR_Pooling(nn.Module):
     """Performs pooling (max or mean) on the token embeddings.
 
     Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding.
@@ -245,7 +245,7 @@ def load(input_path):
         ) as config_file:
             config = json.load(config_file)
 
-        return InstructorPooling(**config)
+        return INSTRUCTOR_Pooling(**config)
 
 
 def import_from_string(dotted_path):
@@ -271,7 +271,7 @@ def import_from_string(dotted_path):
         raise ImportError(msg)
 
 
-class InstructorTransformer(Transformer):
+class INSTRUCTORTransformer(Transformer):
     def __init__(
         self,
         model_name_or_path: str,
@@ -378,7 +378,7 @@ def load(input_path: str):
         with open(sbert_config_path, encoding="UTF-8") as config_file:
             config = json.load(config_file)
 
-        return InstructorTransformer(model_name_or_path=input_path, **config)
+        return INSTRUCTORTransformer(model_name_or_path=input_path, **config)
 
     def tokenize(self, texts):
         """
@@ -395,7 +395,7 @@ def tokenize(self, texts):
 
         input_features = self.tokenizer(
             *to_tokenize,
-            padding="max_length",
+            padding=True,
             truncation="longest_first",
             return_tensors="pt",
             max_length=self.max_seq_length,
@@ -420,7 +420,7 @@ def tokenize(self, texts):
             input_features = self.tokenize(instruction_prepended_input_texts)
             instruction_features = self.tokenize(instructions)
-            input_features = Instructor.prepare_input_features(
+            input_features = INSTRUCTOR.prepare_input_features(
                 input_features, instruction_features
             )
         else:
@@ -430,7 +430,7 @@ def tokenize(self, texts):
         return output
 
 
-class Instructor(SentenceTransformer):
+class INSTRUCTOR(SentenceTransformer):
     @staticmethod
     def prepare_input_features(
         input_features, instruction_features, return_data_type: str = "pt"
@@ -510,27 +510,39 @@ def smart_batching_collate(self, batch):
                 input_features = self.tokenize(instruction_prepended_input_texts)
                 instruction_features = self.tokenize(instructions)
-                input_features = Instructor.prepare_input_features(
+                input_features = INSTRUCTOR.prepare_input_features(
                     input_features, instruction_features
                 )
             batched_input_features.append(input_features)
 
         return batched_input_features, labels
 
-    def _load_sbert_model(self, model_path, token = None, cache_folder = None, revision = None, trust_remote_code = False):
+    def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=None, trust_remote_code=False, local_files_only=False, model_kwargs=None, tokenizer_kwargs=None, config_kwargs=None):
         """
         Loads a full sentence-transformers model
         """
-        # Taken mostly from: https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/util.py#L544
-        download_kwargs = {
-            "repo_id": model_path,
-            "revision": revision,
-            "library_name": "sentence-transformers",
-            "token": token,
-            "cache_dir": cache_folder,
-            "tqdm_class": disabled_tqdm,
-        }
-        model_path = snapshot_download(**download_kwargs)
+        # copied from https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/util.py#L559
+        # because we need to get files outside of the allow_patterns too
+        # If file is local
+        if os.path.isdir(model_path):
+            model_path = str(model_path)
+        else:
+            # If model_path is a Hugging Face repository ID, download the model
+            download_kwargs = {
+                "repo_id": model_path,
+                "revision": revision,
+                "library_name": "InstructorEmbedding",
+                "token": token,
+                "cache_dir": cache_folder,
+                "tqdm_class": disabled_tqdm,
+            }
+            # Try to download from the remote
+            try:
+                model_path = snapshot_download(**download_kwargs)
+            except Exception:
+                # Otherwise, try local (i.e. cache) only
+                download_kwargs["local_files_only"] = True
+                model_path = snapshot_download(**download_kwargs)
 
         # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
         config_sentence_transformers_json_path = os.path.join(
@@ -559,9 +571,9 @@ def _load_sbert_model(self, model_path, token = None, cache_folder = None, revis
         modules = OrderedDict()
         for module_config in modules_config:
             if module_config["idx"] == 0:
-                module_class = InstructorTransformer
+                module_class = INSTRUCTORTransformer
             elif module_config["idx"] == 1:
-                module_class = InstructorPooling
+                module_class = INSTRUCTOR_Pooling
             else:
                 module_class = import_from_string(module_config["type"])
             module = module_class.load(os.path.join(model_path, module_config["path"]))
@@ -619,7 +631,7 @@ def encode(
             input_was_string = True
 
         if device is None:
-            device = self._target_device
+            device = self.device
 
         self.to(device)
diff --git a/requirements.txt b/requirements.txt
index a7fe466..410b6c3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ numpy
 requests>=2.26.0
 scikit_learn>=1.0.2
 scipy
-sentence_transformers>=2.2.0
+sentence_transformers>=2.3.0
 torch
 tqdm
 rich
diff --git a/setup.py b/setup.py
index 311e732..6f17fb0 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ setup(
     name='InstructorEmbedding',
     packages=['InstructorEmbedding'],
-    version='1.0.1',
+    version='1.0.2',
     license='Apache License 2.0',
     description='Text embedding tool',
     long_description=readme,
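
With the public classes renamed to INSTRUCTOR, INSTRUCTORTransformer and INSTRUCTOR_Pooling, downstream usage would look roughly as follows. This is a minimal sketch, not part of the patch: it assumes the hkunlp/instructor-large checkpoint is reachable (or already cached via the snapshot_download fallback above) and the sentence_transformers>=2.3.0 pin from requirements.txt.

    # Sketch only: assumes the hkunlp/instructor-large checkpoint is available.
    from InstructorEmbedding import INSTRUCTOR

    model = INSTRUCTOR("hkunlp/instructor-large")

    # Each input is an [instruction, text] pair; tokenize() prepends the
    # instruction to the text and prepare_input_features() builds context
    # masks so the instruction tokens are excluded from pooling.
    sentences = [
        ["Represent the science title:", "Parton distributions with LHC data"],
        ["Represent the science title:", "Observation of gravitational waves"],
    ]
    embeddings = model.encode(sentences)
    print(embeddings.shape)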
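
The switch from padding="max_length" to padding=True in tokenize() pads each batch only to its longest sequence instead of always to max_seq_length. The effect can be illustrated with a plain Hugging Face tokenizer; this sketch is an assumption-labeled illustration of tokenizer behavior, not code from the patch, and assumes the instructor-large tokenizer files are available.

    # Sketch: compare dynamic padding (padding=True) with fixed-size padding
    # (padding="max_length") on the same batch.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("hkunlp/instructor-large")
    batch = ["a short query", "a noticeably longer query about instruction-tuned embeddings"]

    dynamic = tokenizer(batch, padding=True, truncation=True, max_length=512, return_tensors="pt")
    fixed = tokenizer(batch, padding="max_length", truncation=True, max_length=512, return_tensors="pt")

    print(dynamic["input_ids"].shape)  # (2, length of the longest sequence in this batch)
    print(fixed["input_ids"].shape)    # (2, 512)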