Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions vizier/_src/pyvizier/shared/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ class Metric:
def _std_not_negative(self, _, stddev: Optional[float]) -> bool:
if (stddev is not None) and (not stddev >= 0):
raise ValueError(
'Standard deviation must be a non-negative finite number.')
'Standard deviation must be a non-negative finite number.'
) # pytype: disable=bad-return-type

value: float = attr.ib(
converter=float,
Expand Down Expand Up @@ -146,11 +147,11 @@ def cast_as_internal(self,
internal_type.assert_correct_type(self.value)

if internal_type in (ParameterType.DOUBLE, ParameterType.DISCRETE):
return self.as_float
return self.as_float # pytype: disable=bad-return-type
elif internal_type == ParameterType.INTEGER:
return self.as_int
return self.as_int # pytype: disable=bad-return-type
elif internal_type == ParameterType.CATEGORICAL:
return self.as_str
return self.as_str # pytype: disable=bad-return-type
else:
raise RuntimeError(f'Unknown type {internal_type}')

Expand All @@ -175,11 +176,11 @@ def cast(
if external_type == ExternalType.INTERNAL:
return self.value
elif external_type == ExternalType.BOOLEAN:
return self.as_bool
return self.as_bool # pytype: disable=bad-return-type
elif external_type == ExternalType.INTEGER:
return self.as_int
return self.as_int # pytype: disable=bad-return-type
elif external_type == ExternalType.FLOAT:
return self.as_float
return self.as_float # pytype: disable=bad-return-type
else:
raise ValueError(
'Unknown external type enum value: {}.'.format(external_type))
Expand Down Expand Up @@ -361,7 +362,7 @@ class ParameterDict(abc.MutableMapping):

def as_dict(self) -> Dict[str, ParameterValueTypes]:
"""Returns the dict of parameter names to raw values."""
return {k: self.get_value(k) for k in self._items}
return {k: self.get_value(k) for k in self._items} # pytype: disable=bad-return-type

def __init__(self, iterable: Any = tuple(), **kwargs):
self.__attrs_init__()
Expand Down
Loading