Skip to content

Commit 0794f34

Browse files
Format: ruff format
1 parent d0ecd51 commit 0794f34

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

src/transformers/models/interns1/modular_interns1.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,15 @@
2929
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
3030
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
3131
from ...processing_utils import Unpack
32-
from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging, torch_int, ModelOutput, TransformersKwargs
32+
from ...utils import (
33+
auto_docstring,
34+
can_return_tuple,
35+
is_torchdynamo_compiling,
36+
logging,
37+
torch_int,
38+
ModelOutput,
39+
TransformersKwargs,
40+
)
3341
from ..clip.modeling_clip import CLIPMLP
3442
from ..janus.modeling_janus import JanusVisionAttention
3543
from ..llama.modeling_llama import LlamaRMSNorm
@@ -411,7 +419,6 @@ def forward(
411419

412420
@auto_docstring
413421
class InternS1VisionModel(InternS1VisionPreTrainedModel):
414-
415422
def __init__(self, config: InternS1VisionConfig) -> None:
416423
super().__init__(config)
417424
self.config = config
@@ -545,6 +552,7 @@ class InternS1ModelOutputWithPast(ModelOutput):
545552

546553
class InternS1Model(LlavaModel):
547554
_checkpoint_conversion_mapping = {}
555+
548556
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
549557
"""Perform pixel shuffle downsampling on vision features.
550558

src/transformers/models/interns1/tokenization_interns1.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class InternS1CheckModuleMixin(ABC):
6161
6262
Note that short strings are ignored by this module.
6363
"""
64+
6465
def __init__(self, *, min_length: int):
6566
self.min_length = min_length
6667
self.REGEX = self._build_regex()
@@ -123,6 +124,7 @@ class FastaCheckModule(InternS1CheckModuleMixin):
123124
124125
Automatically detects protein sequence using regex patterns.
125126
"""
127+
126128
def __init__(self, *, min_length: int = 27):
127129
super().__init__(min_length=min_length)
128130
self.auto_detect_token = ["<FASTA_AUTO_DETECT>", "</FASTA_AUTO_DETECT>"]
@@ -135,6 +137,7 @@ def check_legitimacy(self, candidate: str):
135137
return True
136138

137139

140+
# fmt: off
138141
bonds = ["-", "=", "#", ":", "/", "\\", ".", "$"]
139142
organic_symbols = ["B", "C", "N", "O", "P", "S", "F", "Cl", "Br", "I"]
140143
other_allows = bonds + ["[", "]", "(", ")", ";"]
@@ -153,6 +156,7 @@ def check_legitimacy(self, candidate: str):
153156
"Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
154157
"Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"
155158
]
159+
# fmt: on
156160

157161

158162
class SmilesCheckModule(InternS1CheckModuleMixin):
@@ -163,13 +167,15 @@ class SmilesCheckModule(InternS1CheckModuleMixin):
163167
or chemical syntax rules. Uses RDKit for precise validation when available,
164168
otherwise falls back to rule-based validation.
165169
"""
170+
166171
def __init__(self, *, min_length: int = 10):
167172
super().__init__(min_length=min_length)
168173
self.auto_detect_token = ["<SMILES_AUTO_DETECT>", "</SMILES_AUTO_DETECT>"]
169-
self._SQ_BRACKET_BAN_1 = re.compile(r'(?:[A-GI-Z]|[a-z]){3,}')
170-
self._SQ_BRACKET_BAN_2 = re.compile(r'\d{4,}')
174+
self._SQ_BRACKET_BAN_1 = re.compile(r"(?:[A-GI-Z]|[a-z]){3,}")
175+
self._SQ_BRACKET_BAN_2 = re.compile(r"\d{4,}")
171176

172177
def _build_regex(self):
178+
# fmt: off
173179
_two_letter_elements = [
174180
'Ac', 'Ag', 'Al', 'Am', 'Ar', 'As', 'At', 'Au', 'Ba', 'Be', 'Bh', 'Bi', 'Bk', 'Br', 'Ca', 'Cd',
175181
'Ce', 'Cf', 'Cl', 'Cm', 'Cn', 'Co', 'Cr', 'Cs', 'Cu', 'Db', 'Ds', 'Dy', 'Er', 'Es', 'Eu', 'Fe',
@@ -182,6 +188,7 @@ def _build_regex(self):
182188
_single_letter_elements = [
183189
"B", "C", "F", "H", "I", "K", "N", "O", "P", "S", "U", "V", "W", "Y", 'b', 'c', 'n', 'o', 'p', 's'
184190
]
191+
# fmt: on
185192
all_elements_sorted = sorted(_two_letter_elements + _single_letter_elements, key=lambda x: (-len(x), x))
186193
elements_pattern_str = "|".join(all_elements_sorted)
187194

@@ -263,17 +270,17 @@ def check_rings_and_brackets(self, text):
263270
left_sq_bracket += 1
264271
if left_sq_bracket > right_sq_bracket + 1:
265272
return False
266-
if pos == len(text)-1:
273+
if pos == len(text) - 1:
267274
return False
268-
if ']' not in text[pos+1:]:
275+
if "]" not in text[pos + 1 :]:
269276
return False
270-
bracket_span = text[pos+1:text.find(']')]
277+
bracket_span = text[pos + 1 : text.find("]")]
271278

272279
if self._SQ_BRACKET_BAN_1.search(bracket_span) or self._SQ_BRACKET_BAN_2.search(bracket_span):
273280
return False
274281

275-
matches = re.findall(r'\d+', bracket_span)
276-
if len(matches)>2:
282+
matches = re.findall(r"\d+", bracket_span)
283+
if len(matches) > 2:
277284
return False
278285
if c == "]":
279286
step = 1
@@ -477,7 +484,9 @@ def __init__(
477484
for token in self.protect_end_sp_tokens:
478485
self.tokens_trie.add(token)
479486

480-
self.new_sp_token_offset.append(len(self._added_tokens_decoder) - sum(self.new_sp_token_offset) + len(self._extra_special_tokens))
487+
self.new_sp_token_offset.append(
488+
len(self._added_tokens_decoder) - sum(self.new_sp_token_offset) + len(self._extra_special_tokens)
489+
)
481490
self.check_module_list = [SmilesCheckModule(), FastaCheckModule()]
482491

483492
@property

0 commit comments

Comments
 (0)