Skip to content

Commit ef27874

Browse files
authored
feat(relay): Allow to customize max_results per connection in relay (#3746)
1 parent 1e0e1ef commit ef27874

File tree

6 files changed

+122
-3
lines changed

6 files changed

+122
-3
lines changed

RELEASE.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
Release type: minor
2+
3+
Add the ability to override the "max results" a relay's connection can return on
4+
a per-field basis.
5+
6+
The default value for this is defined in the schema's config, and set to `100`
7+
unless modified by the user. Now, that per-field value will take precedence over
8+
it.
9+
10+
For example:
11+
12+
```python
13+
@strawerry.type
14+
class Query:
15+
# This will still use the default value in the schema's config
16+
fruits: ListConnection[Fruit] = relay.connection()
17+
18+
# This will reduce the maximum number of results to 10
19+
limited_fruits: ListConnection[Fruit] = relay.connection(max_results=10)
20+
21+
# This will increase the maximum number of results to 10
22+
higher_limited_fruits: ListConnection[Fruit] = relay.connection(max_results=10_000)
23+
```
24+
25+
Note that this only affects `ListConnection` and subclasses. If you are
26+
implementing your own connection resolver, there's an extra keyword named
27+
`max_results: int | None` that will be passed to it.

docs/guides/relay.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,24 @@ It can be defined in the `Query` objects in 4 ways:
205205
- `node: List[Optional[Node]]`: The same as `List[Node]`, but the returned list
206206
can contain `null` values if the given objects don't exist.
207207

208+
### Max results for connections
209+
210+
The implementation of `relay.ListConnection` will limit the number of results to
211+
the `relay_max_results` configuration in the
212+
[schema's config](../types/schema-configurations.md) (which defaults to `100`).
213+
214+
That can also be configured on a per-field basis by passing `max_results` to the
215+
`@connection` decorator. For example:
216+
217+
```python
218+
@strawerry.type
219+
class Query:
220+
fruits: ListConnection[Fruit] = relay.connection(max_results=10_000)
221+
```
222+
208223
### Custom connection pagination
209224

210-
The default `relay.Connection` class don't implement any pagination logic, and
225+
The default `relay.Connection` class doesn't implement any pagination logic, and
211226
should be used as a base class to implement your own pagination logic. All you
212227
need to do is implement the `resolve_connection` classmethod.
213228

strawberry/relay/fields.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,9 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]:
207207
class ConnectionExtension(FieldExtension):
208208
connection_type: type[Connection[Node]]
209209

210+
def __init__(self, max_results: Optional[int] = None) -> None:
211+
self.max_results = max_results
212+
210213
def apply(self, field: StrawberryField) -> None:
211214
field.arguments = [
212215
*field.arguments,
@@ -313,6 +316,7 @@ def resolve(
313316
after=after,
314317
first=first,
315318
last=last,
319+
max_results=self.max_results,
316320
)
317321

318322
async def resolve_async(
@@ -341,6 +345,7 @@ async def resolve_async(
341345
after=after,
342346
first=first,
343347
last=last,
348+
max_results=self.max_results,
344349
)
345350

346351
# If nodes was an AsyncIterable/AsyncIterator, resolve_connection
@@ -382,6 +387,7 @@ def connection(
382387
metadata: Optional[Mapping[Any, Any]] = None,
383388
directives: Optional[Sequence[object]] = (),
384389
extensions: list[FieldExtension] = (), # type: ignore
390+
max_results: Optional[int] = None,
385391
# This init parameter is used by pyright to determine whether this field
386392
# is added in the constructor or not. It is not used to change
387393
# any behaviour at the moment.
@@ -414,6 +420,9 @@ def connection(
414420
metadata: The metadata of the field.
415421
directives: The directives to apply to the field.
416422
extensions: The extensions to apply to the field.
423+
max_results: The maximum number of results this connection can return.
424+
Can be set to override the default value of 100 defined in the
425+
schema configuration.
417426
init: Used only for type checking purposes.
418427
419428
Examples:
@@ -476,7 +485,7 @@ def get_some_nodes(self, age: int) -> Iterable[SomeType]: ...
476485
default_factory=default_factory,
477486
metadata=metadata,
478487
directives=directives or (),
479-
extensions=[*extensions, ConnectionExtension()],
488+
extensions=[*extensions, ConnectionExtension(max_results=max_results)],
480489
)
481490
if resolver is not None:
482491
f = f(resolver)

strawberry/relay/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,7 @@ def resolve_connection(
717717
after: Optional[str] = None,
718718
first: Optional[int] = None,
719719
last: Optional[int] = None,
720+
max_results: Optional[int] = None,
720721
**kwargs: Any,
721722
) -> AwaitableOrValue[Self]:
722723
"""Resolve a connection from nodes.
@@ -731,6 +732,7 @@ def resolve_connection(
731732
after: Returns the items in the list that come after the specified cursor.
732733
first: Returns the first n items from the list.
733734
last: Returns the items in the list that come after the specified cursor.
735+
max_results: The maximum number of results to resolve.
734736
kwargs: Additional arguments passed to the resolver.
735737
736738
Returns:
@@ -767,6 +769,7 @@ def resolve_connection( # noqa: PLR0915
767769
after: Optional[str] = None,
768770
first: Optional[int] = None,
769771
last: Optional[int] = None,
772+
max_results: Optional[int] = None,
770773
**kwargs: Any,
771774
) -> AwaitableOrValue[Self]:
772775
"""Resolve a connection from the list of nodes.
@@ -780,6 +783,7 @@ def resolve_connection( # noqa: PLR0915
780783
after: Returns the items in the list that come after the specified cursor.
781784
first: Returns the first n items from the list.
782785
last: Returns the items in the list that come after the specified cursor.
786+
max_results: The maximum number of results to resolve.
783787
kwargs: Additional arguments passed to the resolver.
784788
785789
Returns:
@@ -794,6 +798,7 @@ def resolve_connection( # noqa: PLR0915
794798
after=after,
795799
first=first,
796800
last=last,
801+
max_results=max_results,
797802
)
798803

799804
type_def = get_object_definition(cls)

strawberry/relay/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,16 @@ def from_arguments(
131131
after: str | None = None,
132132
first: int | None = None,
133133
last: int | None = None,
134+
max_results: int | None = None,
134135
) -> Self:
135136
"""Get the slice metadata to use on ListConnection."""
136137
from strawberry.relay.types import PREFIX
137138

138-
max_results = info.schema.config.relay_max_results
139+
max_results = (
140+
max_results
141+
if max_results is not None
142+
else info.schema.config.relay_max_results
143+
)
139144
start = 0
140145
end: int | None = None
141146

tests/relay/test_connection.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import strawberry
99
from strawberry.permission import BasePermission
1010
from strawberry.relay import Connection, Node
11+
from strawberry.relay.types import ListConnection
12+
from strawberry.schema.config import StrawberryConfig
1113

1214

1315
@strawberry.type
@@ -34,6 +36,8 @@ def resolve_connection(
3436
before: Optional[str] = None,
3537
first: Optional[int] = None,
3638
last: Optional[int] = None,
39+
max_results: Optional[int] = None,
40+
**kwargs: Any,
3741
) -> Optional[Self]:
3842
return None
3943

@@ -124,3 +128,57 @@ def users(self) -> Optional[list[User]]: # pragma: no cover
124128
result = schema.execute_sync(query)
125129
assert result.data == {"users": None}
126130
assert result.errors[0].message == "Not allowed"
131+
132+
133+
@pytest.mark.parametrize(
134+
("field_max_results", "schema_max_results", "results", "expected"),
135+
[
136+
(5, 100, 5, 5),
137+
(5, 2, 5, 5),
138+
(5, 100, 10, 5),
139+
(5, 2, 10, 5),
140+
(5, 100, 0, 0),
141+
(5, 2, 0, 0),
142+
(None, 100, 5, 5),
143+
(None, 2, 5, 2),
144+
],
145+
)
146+
def test_max_results(
147+
field_max_results: Optional[int],
148+
schema_max_results: int,
149+
results: int,
150+
expected: int,
151+
):
152+
@strawberry.type
153+
class User(Node):
154+
id: strawberry.relay.NodeID[str]
155+
156+
@strawberry.type
157+
class Query:
158+
@strawberry.relay.connection(
159+
ListConnection[User],
160+
max_results=field_max_results,
161+
)
162+
def users(self) -> list[User]:
163+
return [User(id=str(i)) for i in range(results)]
164+
165+
schema = strawberry.Schema(
166+
query=Query,
167+
config=StrawberryConfig(relay_max_results=schema_max_results),
168+
)
169+
query = """
170+
query {
171+
users {
172+
edges {
173+
node {
174+
id
175+
}
176+
}
177+
}
178+
}
179+
"""
180+
181+
result = schema.execute_sync(query)
182+
assert result.data is not None
183+
assert isinstance(result.data["users"]["edges"], list)
184+
assert len(result.data["users"]["edges"]) == expected

0 commit comments

Comments
 (0)