Linting and Formatting: add a couple more Ruff checks #1023

Merged: 7 commits, Jun 17, 2025

Changes from 4 commits
5 changes: 1 addition & 4 deletions code_generation/code_gen.py
@@ -102,10 +102,7 @@ def render_dataset_class_maps(self, template_path: Path, data_path: Path, output
         # create list
         all_map = {}
         for dataset in dataset_meta_data:
-            if dataset.is_template:
-                prefixes = ["sym_", "asym_"]
-            else:
-                prefixes = [""]
+            prefixes = ["sym_", "asym_"] if dataset.is_template else [""]
             for prefix in prefixes:
                 all_components = {}
                 for component in dataset.components:
@@ -25,7 +25,7 @@ class _MetaEnum(EnumMeta):
         Returns:
             bool: True if the member is part of the Enum, False otherwise.
         """
-        return member in cls.__members__.keys()
+        return member in cls.__members__


 class DatasetType(str, Enum, metaclass=_MetaEnum):
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -113,8 +113,8 @@ select = [
     "FURB",
     "FLY",
     "SLOT",
+    "NPY",
 ]
-ignore = ["SIM108", "SIM118", "SIM110", "SIM211"]

 [tool.ruff.lint.isort]
 # Imports that are imported using keyword "as" and are from the same source - are combined.
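Judging from the hunks below, the four rules dropped from the ignore list are all flake8-simplify checks whose findings this PR fixes. A minimal illustration of what each rule flags, using made-up values rather than code from this PR; each "after" form is behavior-preserving:

condition, mapping, batch_size = True, {"a": 1}, 4

# SIM108: an if/else that only assigns collapses into a ternary.
if condition:
    x = 1
else:
    x = 2
x = 1 if condition else 2  # equivalent one-liner

# SIM118: `in mapping` already tests the keys, so `.keys()` is redundant.
assert ("a" in mapping.keys()) == ("a" in mapping)

# SIM110: a loop that returns on the first match collapses into any()/all().
assert any(n > 3 for n in [1, 2, 4])

# SIM211: `False if cond else True` is just the negated condition.
assert (False if batch_size == 1 else True) == (batch_size != 1)
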
10 changes: 2 additions & 8 deletions setup.py
@@ -76,10 +76,7 @@ def get_tag(self):
 class MyBuildExt(build_ext):
     def build_extensions(self):
         if not if_win:
-            if "CXX" in os.environ:
-                cxx = os.environ["CXX"]
-            else:
-                cxx = self.compiler.compiler_cxx[0]
+            cxx = os.environ["CXX"] if "CXX" in os.environ else self.compiler.compiler_cxx[0]
             # check setuptools has an update change in the version 72.2 about cxx compiler options
             # to be compatible with both version, we check if compiler_so_cxx exists
             if not hasattr(self.compiler, "compiler_so_cxx"):
@@ -93,10 +90,7 @@ def build_extensions(self):
             linker_so_cxx[0] = cxx
             self.compiler.compiler_cxx = [cxx]
             # add link time optimization
-            if "clang" in cxx:
-                lto_flag = "-flto=thin"
-            else:
-                lto_flag = "-flto"
+            lto_flag = "-flto=thin" if "clang" in cxx else "-flto"
             compiler_so_cxx += [lto_flag]
             linker_so_cxx += [lto_flag]
             # remove debug and optimization flags
2 changes: 1 addition & 1 deletion src/power_grid_model/_core/dataset_definitions.py
@@ -22,7 +22,7 @@ def __contains__(cls, member):
         Returns:
             bool: True if the member is part of the Enum, False otherwise.
         """
-        return member in cls.__members__.keys()
+        return member in cls.__members__


 class DatasetType(str, Enum, metaclass=_MetaEnum):
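For context, a runnable sketch of the metaclass pattern this file uses (member values assumed for the sketch): Enum.__members__ is a mapping, so `in` tests its keys directly and `.keys()` adds nothing.

from enum import Enum, EnumMeta


class _MetaEnum(EnumMeta):
    def __contains__(cls, member):
        # __members__ is a mapping proxy; `in` checks its keys (the member names)
        return member in cls.__members__


class DatasetType(str, Enum, metaclass=_MetaEnum):
    input = "input"            # values assumed for this sketch
    sym_output = "sym_output"


assert "input" in DatasetType
assert "line" not in DatasetType
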
10 changes: 2 additions & 8 deletions src/power_grid_model/_core/power_grid_core.py
@@ -126,10 +126,7 @@ def _load_core() -> CDLL:

     """
     # first try to find the DLL local
-    if platform.system() == "Windows":
-        dll_file = "_power_grid_core.dll"
-    else:
-        dll_file = "_power_grid_core.so"
+    dll_file = "_power_grid_core.dll" if platform.system() == "Windows" else "_power_grid_core.so"
     dll_path = Path(__file__).parent / dll_file

     # if local DLL is not found, try to find the DLL from conda environment
@@ -192,10 +189,7 @@ def make_c_binding(func: Callable):

     # binding function
     def cbind_func(self, *args, **kwargs):
-        if "destroy" in name:
-            c_inputs = []
-        else:
-            c_inputs = [self._handle]
+        c_inputs = [] if "destroy" in name else [self._handle]
         args = chain(args, (kwargs[key] for key in py_argnames[len(args) :]))
         for arg in args:
             if isinstance(arg, str):
5 changes: 1 addition & 4 deletions src/power_grid_model/_core/power_grid_model.py
@@ -189,10 +189,7 @@ def _get_output_component_count(self, calculation_type: CalculationType):
         }.get(calculation_type, [])

         def include_type(component_type: ComponentType):
-            for exclude_type in exclude_types:
-                if exclude_type.value in component_type.value:
-                    return False
-            return True
+            return all(exclude_type.value not in component_type.value for exclude_type in exclude_types)

         return {ComponentType[k]: v for k, v in self.all_component_count.items() if include_type(k)}
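The loop-to-all() rewrite above is Ruff's SIM110 pattern. A standalone sketch with hypothetical component names, showing the two forms always agree:

exclude_types = ["sensor", "fault"]  # hypothetical exclusion list

def include_type_loop(component_type: str) -> bool:
    # original style: return False on the first match, True otherwise
    for exclude_type in exclude_types:
        if exclude_type in component_type:
            return False
    return True

def include_type_all(component_type: str) -> bool:
    # SIM110 style: a single all() over the same condition
    return all(exclude_type not in component_type for exclude_type in exclude_types)

for name in ("sym_voltage_sensor", "line", "fault"):
    assert include_type_loop(name) == include_type_all(name)
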
12 changes: 4 additions & 8 deletions src/power_grid_model/_core/utils.py
@@ -651,10 +651,8 @@ def _extract_columnar_data(
     """
     not_columnar_data_message = "Expected columnar data"

-    if is_batch is not None:
-        allowed_dims = [2, 3] if is_batch else [1, 2]
-    else:
-        allowed_dims = [1, 2, 3]
+    if_is_batch = [2, 3] if is_batch else [1, 2]
+    allowed_dims = (if_is_batch) if is_batch is not None else [1, 2, 3]

     sub_data = data["data"] if is_sparse(data) else data

@@ -683,10 +681,8 @@ def _extract_row_based_data(
     Returns:
         SingleArray | DenseBatchArray: the contents of row based data
     """
-    if is_batch is not None:
-        allowed_dims = [2] if is_batch else [1]
-    else:
-        allowed_dims = [1, 2]
+    if_is_batch = [2] if is_batch else [1]
+    allowed_dims = if_is_batch if is_batch is not None else [1, 2]

     sub_data = data["data"] if is_sparse(data) else data
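Both rewrites above keep the original three-way outcome, though the intermediate list is now computed even when is_batch is None and then discarded. A hypothetical helper (not in the PR) restating the intended rule:

def _allowed_dims(is_batch: bool | None) -> list[int]:
    # batch data may be 2-D or 3-D, single data 1-D or 2-D; None accepts either
    if is_batch is None:
        return [1, 2, 3]
    return [2, 3] if is_batch else [1, 2]

assert _allowed_dims(None) == [1, 2, 3]
assert _allowed_dims(True) == [2, 3]
assert _allowed_dims(False) == [1, 2]
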
5 changes: 1 addition & 4 deletions src/power_grid_model/validation/_rules.py
@@ -892,10 +892,7 @@ def none_missing(data: SingleDataset, component: ComponentType, fields: str | li
         fields = [fields]
     for field in fields:
         nan = _nan_type(component, field)
-        if np.isnan(nan):
-            invalid = np.isnan(data[component][field])
-        else:
-            invalid = np.equal(data[component][field], nan)
+        invalid = np.isnan(data[component][field]) if np.isnan(nan) else np.equal(data[component][field], nan)

         if invalid.any():
             # handle both symmetric and asymmetric values
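The ternary above preserves a subtlety worth noting: when the missing-value sentinel is NaN, np.equal can never match it (NaN compares unequal to everything, including itself), so np.isnan must be used instead. A self-contained sketch with made-up data (integer fields typically use a sentinel such as the minimum int):

import numpy as np

values = np.array([1.0, np.nan, 3.0])

# NaN sentinels cannot be found with equality:
assert not np.equal(values, np.nan).any()

# np.isnan is the correct test:
assert np.isnan(values).tolist() == [False, True, False]

# integer sentinels, by contrast, are found with np.equal:
int_values = np.array([5, -(2**31), 7], dtype=np.int32)
assert np.equal(int_values, -(2**31)).tolist() == [False, True, False]
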
7 changes: 2 additions & 5 deletions src/power_grid_model/validation/utils.py
@@ -101,7 +101,7 @@ def _update_input_data(input_data: SingleDataset, update_data: SingleDataset):
     """

     merged_data = {component: array.copy() for component, array in input_data.items()}
-    for component in update_data.keys():
+    for component in update_data:
         _update_component_data(component, merged_data[component], update_data[component])
     return merged_data

@@ -140,7 +140,7 @@ def _update_component_array_data(
         if field == "id":
             continue
         nan = _nan_type(component, field, DatasetType.update)
-        if np.isnan(nan):
-            mask = ~np.isnan(update_data[field])
-        else:
-            mask = np.not_equal(update_data[field], nan)
+        mask = ~np.isnan(update_data[field]) if np.isnan(nan) else np.not_equal(update_data[field], nan)

         if mask.ndim == 2:
             for phase in range(mask.shape[1]):
3 changes: 2 additions & 1 deletion tests/unit/test_0Z_model_validation.py
@@ -122,7 +122,8 @@ def test_single_validation(
     # test get indexer
     for component_name, input_array in case_data["input"].items():
         ids_array = input_array["id"].copy()
-        np.random.shuffle(ids_array)
+        rng = np.random.default_rng(3)
+        rng.shuffle(ids_array)
         indexer_array = model.get_indexer(component_name, ids_array)
         # check
         assert np.all(input_array["id"][indexer_array] == ids_array)
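This is the kind of change the newly selected NPY rules drive: NPY002 flags the legacy module-level np.random API in favor of a seeded Generator. A small sketch with made-up ids:

import numpy as np

ids_array = np.array([10, 20, 30, 40, 50])

# legacy style, flagged by NPY002:
#   np.random.shuffle(ids_array)

rng = np.random.default_rng(3)  # fixed seed keeps the test reproducible
rng.shuffle(ids_array)          # in-place shuffle via the Generator API
assert sorted(ids_array) == [10, 20, 30, 40, 50]
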
2 changes: 1 addition & 1 deletion tests/unit/test_data_handling.py
@@ -123,7 +123,7 @@ def test_create_output_data(output_component_types, expected_fns, batch_size):
         output_component_types=output_component_types,
         output_type=DT.sym_output,
         all_component_count=all_component_count,
-        is_batch=False if batch_size == 1 else True,
+        is_batch=batch_size != 1,
         batch_size=batch_size,
     )
2 changes: 1 addition & 1 deletion tests/unit/test_internal_utils.py
@@ -855,7 +855,7 @@ def row_data(request):

 def compare_row_data(actual_row_data, desired_row_data):
     assert actual_row_data.keys() == desired_row_data.keys()

-    for comp_name in actual_row_data.keys():
+    for comp_name in actual_row_data:
         actual_component = actual_row_data[comp_name]
         desired_component = desired_row_data[comp_name]
         if is_sparse(actual_component):
10 changes: 2 additions & 8 deletions tests/unit/test_meta_data.py
@@ -59,15 +59,9 @@ def test_sensor_meta_data():
         assert "id" in attr_names
         # check specific attributes
         if "voltage" in sensor:
-            if "output" in meta_type:
-                expected_attrs = output_voltage
-            else:
-                expected_attrs = input_voltage
+            expected_attrs = output_voltage if "output" in meta_type else input_voltage
         else:
-            if "output" in meta_type:
-                expected_attrs = output_power
-            else:
-                expected_attrs = input_power
+            expected_attrs = output_power if "output" in meta_type else input_power

         for name in expected_attrs:
             assert name in attr_names
4 changes: 2 additions & 2 deletions tests/unit/test_serialization.py
@@ -512,7 +512,7 @@ def assert_individual_data_entry(serialized_dataset, data_filter, component, ser
             if is_attribute_filtered_out(data_filter, component, attr):
                 assert attr not in deserialized_output
                 continue
-            assert attr in deserialized_output.keys()
+            assert attr in deserialized_output
             assert_almost_equal(
                 deserialized_output[attr][comp_idx],
                 serialized_input[comp_idx][attr],
@@ -530,7 +530,7 @@ def assert_individual_data_entry(serialized_dataset, data_filter, component, ser
             if is_attribute_filtered_out(data_filter, component, attr):
                 assert attr not in deserialized_output
                 continue
-            assert attr in deserialized_output.keys()
+            assert attr in deserialized_output
             assert_almost_equal(
                 deserialized_output[attr][comp_idx],
                 serialized_input[comp_idx][attr_idx],
5 changes: 1 addition & 4 deletions tests/unit/utils.py
@@ -180,10 +180,7 @@ def _add_cases(case_dir: Path, calculation_type: str, **kwargs):


 def pytest_cases(get_batch_cases: bool = False, data_dir: str | None = None, test_cases: list[str] | None = None):
-    if data_dir is not None:
-        relevant_calculations = [data_dir]
-    else:
-        relevant_calculations = ["power_flow", "state_estimation", "short_circuit"]
+    relevant_calculations = [data_dir] if data_dir is not None else ["power_flow", "state_estimation", "short_circuit"]

     for calculation_type in relevant_calculations:
         test_case_paths = get_test_case_paths(calculation_type=calculation_type, test_cases=test_cases)