11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | - |
| 14 | +import torch |
15 | 15 | from collections import defaultdict |
16 | 16 | from enum import Enum |
17 | 17 | from typing import Annotated, Any, Dict, List, Optional, Set, Union |
@@ -279,5 +279,27 @@ def requires_calibration_data(self): |
279 | 279 |
280 | 280 |         return False
281 | 281 |
| 282 | +    def model_dump(self, *args, **kwargs):
| 283 | +        # Call the parent dump first
| 284 | +        data = super().model_dump(*args, **kwargs)
| 285 | +
| 286 | +        # Convert any torch.dtype values to their string names, e.g. "float16"
| 287 | +        schemes = ["config_groups", "kv_cache_scheme"]
| 288 | +        for scheme in schemes:
| 289 | +            if data.get(scheme) is not None:
| 290 | +                for group in data[scheme].values():
| 291 | +                    weights = group.get("weights")
| 292 | +                    input_acts = group.get("input_activations")
| 293 | +                    output_acts = group.get("output_activations")
| 294 | +
| 295 | +                    for arg in (weights, input_acts, output_acts):
| 296 | +                        if arg is None:  # section may be absent
| 297 | +                            continue
| 298 | +                        for key, value in arg.items():
| 299 | +                            if isinstance(value, torch.dtype):
| 300 | +                                arg[key] = str(value).replace("torch.", "")
| 301 | +
| 302 | +        return data
| 303 | +
282 | 304 | # TODO set `extra="forbid"` when upstream transformers is compatible |
283 | 305 | model_config = ConfigDict(extra="ignore") |
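For context, the dtype-to-string conversion the new model_dump override performs can be sketched in isolation. The snippet below is a minimal, stand-alone illustration, not the project's actual schema: the group dict and its pt_dtype field are hypothetical stand-ins for one serialized "config_groups" entry; only the conversion step mirrors the diff.

import torch

# Hypothetical serialized group, shaped like one "config_groups" entry;
# field names here are illustrative, and some sections may be None.
group = {
    "weights": {"num_bits": 8, "pt_dtype": torch.float16},
    "input_activations": None,
    "output_activations": {"pt_dtype": torch.bfloat16},
}

for arg in (group.get("weights"), group.get("input_activations"), group.get("output_activations")):
    if arg is None:
        continue  # section not present in this group
    for key, value in arg.items():
        if isinstance(value, torch.dtype):
            # str(torch.float16) == "torch.float16", so strip the "torch." prefix
            arg[key] = str(value).replace("torch.", "")

print(group["weights"]["pt_dtype"])             # "float16"
print(group["output_activations"]["pt_dtype"])  # "bfloat16"

The point of the conversion is that torch.dtype objects are not JSON-serializable, while their string names round-trip cleanly through a dumped config.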