Skip to content

Commit cdda9d4

Browse files
committed
Support graphql-transport-ws websocket subprotocol
1 parent 366d1aa commit cdda9d4

File tree

3 files changed

+39
-16
lines changed

3 files changed

+39
-16
lines changed

graphql_ws/base.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44
from graphql import format_error, graphql
55

66
from .constants import (
7+
GQL_COMPLETE,
78
GQL_CONNECTION_ERROR,
89
GQL_CONNECTION_INIT,
910
GQL_CONNECTION_TERMINATE,
1011
GQL_DATA,
1112
GQL_ERROR,
13+
GQL_NEXT,
1214
GQL_START,
1315
GQL_STOP,
16+
GQL_SUBSCRIBE,
17+
TRANSPORT_WS_PROTOCOL,
1418
)
1519

1620

@@ -23,6 +27,9 @@ def __init__(self, ws, request_context=None):
2327
self.ws = ws
2428
self.operations = {}
2529
self.request_context = request_context
30+
self.transport_ws_protocol = request_context and TRANSPORT_WS_PROTOCOL in (
31+
request_context.get("subprotocols") or []
32+
)
2633

2734
def has_operation(self, op_id):
2835
return op_id in self.operations
@@ -41,7 +48,7 @@ def remove_operation(self, op_id):
4148

4249
def unsubscribe(self, op_id):
4350
async_iterator = self.remove_operation(op_id)
44-
if hasattr(async_iterator, 'dispose'):
51+
if hasattr(async_iterator, "dispose"):
4552
async_iterator.dispose()
4653
return async_iterator
4754

@@ -84,12 +91,16 @@ def process_message(self, connection_context, parsed_message):
8491
elif op_type == GQL_CONNECTION_TERMINATE:
8592
return self.on_connection_terminate(connection_context, op_id)
8693

87-
elif op_type == GQL_START:
94+
elif op_type == (
95+
GQL_SUBSCRIBE if connection_context.transport_ws_protocol else GQL_START
96+
):
8897
assert isinstance(payload, dict), "The payload must be a dict"
8998
params = self.get_graphql_params(connection_context, payload)
9099
return self.on_start(connection_context, op_id, params)
91100

92-
elif op_type == GQL_STOP:
101+
elif op_type == (
102+
GQL_COMPLETE if connection_context.transport_ws_protocol else GQL_STOP
103+
):
93104
return self.on_stop(connection_context, op_id)
94105

95106
else:
@@ -142,7 +153,12 @@ def build_message(self, id, op_type, payload):
142153

143154
def send_execution_result(self, connection_context, op_id, execution_result):
144155
result = self.execution_result_to_dict(execution_result)
145-
return self.send_message(connection_context, op_id, GQL_DATA, result)
156+
return self.send_message(
157+
connection_context,
158+
op_id,
159+
GQL_NEXT if connection_context.transport_ws_protocol else GQL_DATA,
160+
result,
161+
)
146162

147163
def execution_result_to_dict(self, execution_result):
148164
result = OrderedDict()

graphql_ws/constants.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
GRAPHQL_WS = "graphql-ws"
22
WS_PROTOCOL = GRAPHQL_WS
3+
TRANSPORT_WS_PROTOCOL = "graphql-transport-ws"
34

45
GQL_CONNECTION_INIT = "connection_init" # Client -> Server
56
GQL_CONNECTION_ACK = "connection_ack" # Server -> Client
@@ -8,8 +9,10 @@
89
# NOTE: This one here don't follow the standard due to connection optimization
910
GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server
1011
GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client
11-
GQL_START = "start" # Client -> Server
12-
GQL_DATA = "data" # Server -> Client
12+
GQL_START = "start" # Client -> Server (graphql-ws)
13+
GQL_SUBSCRIBE = "subscribe" # Client -> Server (graphql-transport-ws START equivalent)
14+
GQL_DATA = "data" # Server -> Client (graphql-ws)
15+
GQL_NEXT = "next" # Server -> Client (graphql-transport-ws DATA equivalent)
1316
GQL_ERROR = "error" # Server -> Client
14-
GQL_COMPLETE = "complete" # Server -> Client
15-
GQL_STOP = "stop" # Client -> Server
17+
GQL_COMPLETE = "complete" # Server -> Client (and Client -> Server for graphql-transport-ws STOP equivalent)
18+
GQL_STOP = "stop" # Client -> Server (graphql-ws only)

graphql_ws/django/consumers.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,25 @@
22

33
from channels.generic.websocket import AsyncJsonWebsocketConsumer
44

5-
from ..constants import WS_PROTOCOL
5+
from ..constants import TRANSPORT_WS_PROTOCOL, WS_PROTOCOL
66
from .subscriptions import subscription_server
77

88

99
class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer):
10-
1110
async def connect(self):
1211
self.connection_context = None
13-
if WS_PROTOCOL in self.scope["subprotocols"]:
14-
self.connection_context = await subscription_server.handle(
15-
ws=self, request_context=self.scope
16-
)
17-
await self.accept(subprotocol=WS_PROTOCOL)
18-
else:
12+
found_protocol = None
13+
for protocol in [WS_PROTOCOL, TRANSPORT_WS_PROTOCOL]:
14+
if protocol in self.scope["subprotocols"]:
15+
found_protocol = protocol
16+
break
17+
if not found_protocol:
1918
await self.close()
19+
return
20+
self.connection_context = await subscription_server.handle(
21+
ws=self, request_context=self.scope
22+
)
23+
await self.accept(subprotocol=found_protocol)
2024

2125
async def disconnect(self, code):
2226
if self.connection_context:

0 commit comments

Comments
 (0)