diff --git a/bertviz/transformers_neuron_view/modeling_utils.py b/bertviz/transformers_neuron_view/modeling_utils.py index 6c45a0c..e5e72d3 100644 --- a/bertviz/transformers_neuron_view/modeling_utils.py +++ b/bertviz/transformers_neuron_view/modeling_utils.py @@ -137,11 +137,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')` config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') - config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) - assert config.output_attention == True - config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, + config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False) + assert config.output_attentions == True + config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False, return_unused_kwargs=True) - assert config.output_attention == True + assert config.output_attentions == True assert unused_kwargs == {'foo': False} """ @@ -405,7 +405,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. kwargs: (`optional`) Remaining dictionary of keyword arguments: - Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: + Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``). Behave differently depending on whether a `config` is provided or automatically loaded: - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. @@ -414,8 +414,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` - model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading - assert model.config.output_attention == True + model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True) # Update configuration during loading + assert model.config.output_attentions == True # Loading from a TF checkpoint file instead of a PyTorch model (slower) config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json') model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)