Skip to content

Commit 4105c61

Browse files
committed
Include usage of filter_size for importing logs, user, computer
Building the query for the matching nodes from the backend can surpass the sqlite limit of allowed query parameters and thus result in failure when importing the archive. In this PR we implement the usage of the `filter_size` which limits the number of parameters in the query for importing logs, users and computers.
1 parent cf07e9f commit 4105c61

File tree

4 files changed

+130
-20
lines changed

4 files changed

+130
-20
lines changed

src/aiida/tools/archive/imports.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -262,11 +262,18 @@ def _import_users(
262262
# get matching emails from the backend
263263
output_email_id: Dict[str, int] = {}
264264
if input_id_email:
265-
output_email_id = dict(
266-
orm.QueryBuilder(backend=backend_to)
267-
.append(orm.User, filters={'email': {'in': list(input_id_email.values())}}, project=['email', 'id'])
268-
.all(batch_size=query_params.batch_size)
269-
)
265+
output_email_id = {
266+
key: value
267+
for query_results in [
268+
dict(
269+
orm.QueryBuilder(backend=backend_to)
270+
.append(orm.User, filters={'email': {'in': chunk}}, project=['email', 'id'])
271+
.all(batch_size=query_params.batch_size)
272+
)
273+
for _, chunk in batch_iter(set(input_id_email.values()), query_params.filter_size)
274+
]
275+
for key, value in query_results.items()
276+
}
270277

271278
new_users = len(input_id_email) - len(output_email_id)
272279
existing_users = len(output_email_id)
@@ -300,11 +307,18 @@ def _import_computers(
300307
# get matching uuids from the backend
301308
backend_uuid_id: Dict[str, int] = {}
302309
if input_id_uuid:
303-
backend_uuid_id = dict(
304-
orm.QueryBuilder(backend=backend_to)
305-
.append(orm.Computer, filters={'uuid': {'in': list(input_id_uuid.values())}}, project=['uuid', 'id'])
306-
.all(batch_size=query_params.batch_size)
307-
)
310+
backend_uuid_id = {
311+
key: value
312+
for query_results in [
313+
dict(
314+
orm.QueryBuilder(backend=backend_to)
315+
.append(orm.Computer, filters={'uuid': {'in': chunk}}, project=['uuid', 'id'])
316+
.all(batch_size=query_params.batch_size)
317+
)
318+
for _, chunk in batch_iter(set(input_id_uuid.values()), query_params.filter_size)
319+
]
320+
for key, value in query_results.items()
321+
}
308322

309323
new_computers = len(input_id_uuid) - len(backend_uuid_id)
310324
existing_computers = len(backend_uuid_id)
@@ -460,17 +474,20 @@ def _import_nodes(
460474

461475
# get matching uuids from the backend
462476
backend_uuid_id: Dict[str, int] = {}
463-
input_id_uuid_uuids = list(input_id_uuid.values())
464477

465478
if input_id_uuid:
466-
for _, batch in batch_iter(input_id_uuid_uuids, query_params.filter_size):
467-
backend_uuid_id.update(
479+
backend_uuid_id = {
480+
key: value
481+
for query_results in [
468482
dict(
469483
orm.QueryBuilder(backend=backend_to)
470-
.append(orm.Node, filters={'uuid': {'in': batch}}, project=['uuid', 'id'])
484+
.append(orm.Node, filters={'uuid': {'in': chunk}}, project=['uuid', 'id'])
471485
.all(batch_size=query_params.batch_size)
472486
)
473-
)
487+
for _, chunk in batch_iter(set(input_id_uuid.values()), query_params.filter_size)
488+
]
489+
for key, value in query_results.items()
490+
}
474491

475492
new_nodes = len(input_id_uuid) - len(backend_uuid_id)
476493

@@ -544,12 +561,20 @@ def _import_logs(
544561

545562
# get matching uuids from the backend
546563
backend_uuid_id: Dict[str, int] = {}
564+
547565
if input_id_uuid:
548-
backend_uuid_id = dict(
549-
orm.QueryBuilder(backend=backend_to)
550-
.append(orm.Log, filters={'uuid': {'in': list(input_id_uuid.values())}}, project=['uuid', 'id'])
551-
.all(batch_size=query_params.batch_size)
552-
)
566+
backend_uuid_id = {
567+
key: value
568+
for query_results in [
569+
dict(
570+
orm.QueryBuilder(backend=backend_to)
571+
.append(orm.Log, filters={'uuid': {'in': chunk}}, project=['uuid', 'id'])
572+
.all(batch_size=query_params.batch_size)
573+
)
574+
for _, chunk in batch_iter(set(input_id_uuid.values()), query_params.filter_size)
575+
]
576+
for key, value in query_results.items()
577+
}
553578

554579
new_logs = len(input_id_uuid) - len(backend_uuid_id)
555580
existing_logs = len(backend_uuid_id)

tests/tools/archive/orm/test_computers.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,3 +328,36 @@ def test_import_of_django_sqla_export_file(aiida_localhost, backend):
328328
res = builder.dict()[0]
329329

330330
assert res['comp']['metadata'] == comp1_metadata
331+
332+
333+
def test_filter_size(tmp_path, aiida_profile_clean):
334+
"""Tests if the query still works when the number of computers is beyond the `filter_size` limit."""
335+
nb_nodes = 5
336+
nodes = []
337+
for i in range(nb_nodes):
338+
node = orm.CalcJobNode()
339+
node.computer = orm.Computer(
340+
label=f'{i}', hostname='localhost', transport_type='core.local', scheduler_type='core.direct'
341+
).store()
342+
node.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1})
343+
node.label = f'{i}'
344+
node.store()
345+
node.seal()
346+
nodes.append(node)
347+
348+
builder = orm.QueryBuilder().append(orm.Computer, project=['uuid', 'label'])
349+
builder = builder.all()
350+
351+
# Export DB
352+
export_file_existing = tmp_path.joinpath('export.aiida')
353+
create_archive(nodes, filename=export_file_existing)
354+
355+
# Clean database and reimport DB
356+
aiida_profile_clean.reset_storage()
357+
import_archive(export_file_existing, filter_size=2)
358+
359+
# Check correct import
360+
builder = orm.QueryBuilder().append(orm.Computer, project=['uuid', 'label'])
361+
builder = builder.all()
362+
363+
assert len(builder) == nb_nodes

tests/tools/archive/orm/test_logs.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,28 @@ def test_reimport_of_logs_for_single_node(tmp_path, aiida_profile_clean):
378378
log_message = str(log[1])
379379
assert log_uuid in total_log_uuids
380380
assert log_message in log_msgs
381+
382+
383+
def test_filter_size(tmp_path, aiida_profile_clean):
384+
"""Tests if the query still works when the number of logs is beyond the `filter_size` limit."""
385+
node = orm.CalculationNode().store()
386+
node.seal()
387+
388+
nb_nodes = 5
389+
for _ in range(nb_nodes):
390+
node.logger.critical('some')
391+
392+
# Export DB
393+
export_file_existing = tmp_path.joinpath('export.aiida')
394+
create_archive([node], filename=export_file_existing)
395+
396+
# Clean database and reimport DB
397+
aiida_profile_clean.reset_storage()
398+
import_archive(export_file_existing, filter_size=2)
399+
400+
# Check correct import
401+
builder = orm.QueryBuilder().append(orm.Node, tag='node', project=['uuid'])
402+
builder.append(orm.Log, with_node='node', project=['uuid'])
403+
builder = builder.all()
404+
405+
assert len(builder) == nb_nodes

tests/tools/archive/orm/test_users.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,30 @@ def test_non_default_user_nodes(aiida_profile_clean, tmp_path, aiida_localhost_f
161161
assert orm.load_node(uuid).user.email == new_email
162162
for uuid in uuids2:
163163
assert orm.load_node(uuid).user.email == manager.get_profile().default_user_email
164+
165+
166+
def test_filter_size(tmp_path, aiida_profile_clean):
167+
"""Tests if the query still works when the number of users is beyond the `filter_size` limit."""
168+
nb_nodes = 5
169+
nodes = []
170+
# We need to attach a node to the user otherwise it is not exported
171+
for i in range(nb_nodes):
172+
node = orm.Int(5, user=orm.User(email=f'{i}').store())
173+
node.store()
174+
nodes.append(node)
175+
176+
# Export DB
177+
export_file_existing = tmp_path.joinpath('export.aiida')
178+
create_archive(nodes, filename=export_file_existing)
179+
180+
# Clean database and reimport DB
181+
aiida_profile_clean.reset_storage()
182+
import_archive(export_file_existing, filter_size=2)
183+
184+
# Check correct import
185+
builder = orm.QueryBuilder().append(orm.User, project=['id'])
186+
builder = builder.all()
187+
188+
# We need to add one because default profile is added by reset_storage
189+
# automatically
190+
assert len(builder) == nb_nodes + 1

0 commit comments

Comments
 (0)