4
4
from graphql import format_error , graphql
5
5
6
6
from .constants import (
7
+ GQL_COMPLETE ,
7
8
GQL_CONNECTION_ERROR ,
8
9
GQL_CONNECTION_INIT ,
9
10
GQL_CONNECTION_TERMINATE ,
10
11
GQL_DATA ,
11
12
GQL_ERROR ,
13
+ GQL_NEXT ,
12
14
GQL_START ,
13
15
GQL_STOP ,
16
+ GQL_SUBSCRIBE ,
17
+ TRANSPORT_WS_PROTOCOL ,
14
18
)
15
19
16
20
@@ -23,6 +27,9 @@ def __init__(self, ws, request_context=None):
23
27
self .ws = ws
24
28
self .operations = {}
25
29
self .request_context = request_context
30
+ self .transport_ws_protocol = request_context and TRANSPORT_WS_PROTOCOL in (
31
+ request_context .get ("subprotocols" ) or []
32
+ )
26
33
27
34
def has_operation (self , op_id ):
28
35
return op_id in self .operations
@@ -41,7 +48,7 @@ def remove_operation(self, op_id):
41
48
42
49
def unsubscribe (self , op_id ):
43
50
async_iterator = self .remove_operation (op_id )
44
- if hasattr (async_iterator , ' dispose' ):
51
+ if hasattr (async_iterator , " dispose" ):
45
52
async_iterator .dispose ()
46
53
return async_iterator
47
54
@@ -84,12 +91,16 @@ def process_message(self, connection_context, parsed_message):
84
91
elif op_type == GQL_CONNECTION_TERMINATE :
85
92
return self .on_connection_terminate (connection_context , op_id )
86
93
87
- elif op_type == GQL_START :
94
+ elif op_type == (
95
+ GQL_SUBSCRIBE if connection_context .transport_ws_protocol else GQL_START
96
+ ):
88
97
assert isinstance (payload , dict ), "The payload must be a dict"
89
98
params = self .get_graphql_params (connection_context , payload )
90
99
return self .on_start (connection_context , op_id , params )
91
100
92
- elif op_type == GQL_STOP :
101
+ elif op_type == (
102
+ GQL_COMPLETE if connection_context .transport_ws_protocol else GQL_STOP
103
+ ):
93
104
return self .on_stop (connection_context , op_id )
94
105
95
106
else :
@@ -142,7 +153,12 @@ def build_message(self, id, op_type, payload):
142
153
143
154
def send_execution_result (self , connection_context , op_id , execution_result ):
144
155
result = self .execution_result_to_dict (execution_result )
145
- return self .send_message (connection_context , op_id , GQL_DATA , result )
156
+ return self .send_message (
157
+ connection_context ,
158
+ op_id ,
159
+ GQL_NEXT if connection_context .transport_ws_protocol else GQL_DATA ,
160
+ result ,
161
+ )
146
162
147
163
def execution_result_to_dict (self , execution_result ):
148
164
result = OrderedDict ()
0 commit comments