Skip to content

Commit b9055b9

Browse files
committed
Implement TransmitUnexpectedExceptionInterceptor
1 parent c79814c commit b9055b9

File tree

3 files changed

+366
-0
lines changed

3 files changed

+366
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/*
2+
* Copyright (c) 2019, Salesforce.com, Inc.
3+
* All rights reserved.
4+
* Licensed under the BSD 3-Clause license.
5+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6+
*/
7+
8+
package com.salesforce.grpc.contrib.interceptor;
9+
10+
import io.grpc.*;
11+
12+
import java.io.PrintWriter;
13+
import java.io.StringWriter;
14+
import java.util.Collection;
15+
import java.util.HashSet;
16+
import java.util.Set;
17+
18+
/**
19+
* A class that intercepts uncaught exceptions of all types and handles them by closing the {@link ServerCall}, and
20+
* transmitting the exception's description and stack trace to the client. This class is a complement to gRPC's
21+
* {@code TransmitStatusRuntimeExceptionInterceptor}.
22+
*
23+
* <p>Without this interceptor, gRPC will strip all details and close the {@link ServerCall} with
24+
* a generic {@link Status#UNKNOWN} code.
25+
*
26+
* <p>Security warning: the exception description and stack trace may contain sensitive server-side
27+
* state information, and generally should not be sent to clients. Only install this interceptor
28+
* if all clients are trusted.
29+
*/
30+
// Heavily inspired by https://github.com/saturnism/grpc-java-by-example/blob/master/error-handling-example/error-server/src/main/java/com/example/grpc/server/UnknownStatusDescriptionInterceptor.java
31+
public class TransmitUnexpectedExceptionInterceptor implements ServerInterceptor {
32+
33+
private final Set<Class<? extends Throwable>> exactTypes = new HashSet<>();
34+
private final Set<Class<? extends Throwable>> parentTypes = new HashSet<>();
35+
36+
/**
37+
* Allows this interceptor to match on an exact exception type.
38+
* @param exactType The exact type to match on.
39+
* @return this
40+
*/
41+
public TransmitUnexpectedExceptionInterceptor forExactType(Class<? extends Throwable> exactType) {
42+
this.exactTypes.add(exactType);
43+
return this;
44+
}
45+
46+
/**
47+
* Allows this interceptor to match on a set of exact exception type.
48+
* @param exactTypes The set of exact types to match on.
49+
* @return this
50+
*/
51+
public TransmitUnexpectedExceptionInterceptor forExactTypes(Collection<Class<? extends Throwable>> exactTypes) {
52+
this.exactTypes.addAll(exactTypes);
53+
return this;
54+
}
55+
56+
/**
57+
* Allows this interceptor to match on any exception type deriving from {@code parentType}.
58+
* @param parentType The parent type to match on.
59+
* @return this
60+
*/
61+
public TransmitUnexpectedExceptionInterceptor forParentType(Class<? extends Throwable> parentType) {
62+
this.parentTypes.add(parentType);
63+
return this;
64+
}
65+
66+
/**
67+
* Allows this interceptor to match on any exception type deriving from any element of {@code parentTypes}.
68+
* @param parentTypes The set of parent types to match on.
69+
* @return this
70+
*/
71+
public TransmitUnexpectedExceptionInterceptor forParentTypes(Collection<Class<? extends Throwable>> parentTypes) {
72+
this.parentTypes.addAll(parentTypes);
73+
return this;
74+
}
75+
76+
/**
77+
* Allows this interceptor to match all exceptions. Use with caution!
78+
* @return this
79+
*/
80+
public TransmitUnexpectedExceptionInterceptor forAllExceptions() {
81+
return forParentType(Throwable.class);
82+
}
83+
84+
@Override
85+
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
86+
ServerCall<ReqT, RespT> wrappedCall = new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call) {
87+
@Override
88+
public void sendMessage(RespT message) {
89+
super.sendMessage(message);
90+
}
91+
92+
@Override
93+
public void close(Status status, Metadata trailers) {
94+
95+
if (status.getCode() == Status.Code.UNKNOWN &&
96+
status.getDescription() == null &&
97+
status.getCause() != null &&
98+
exceptionTypeIsAllowed(status.getCause().getClass())) {
99+
Throwable e = status.getCause();
100+
status = Status.INTERNAL
101+
.withDescription(e.getMessage())
102+
.augmentDescription(stacktraceToString(e));
103+
}
104+
super.close(status, trailers);
105+
}
106+
};
107+
108+
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(next.startCall(wrappedCall, headers)) {
109+
@Override
110+
public void onHalfClose() {
111+
try {
112+
super.onHalfClose();
113+
} catch (Throwable e) {
114+
if (exceptionTypeIsAllowed(e.getClass())) {
115+
call.close(Status.INTERNAL
116+
.withDescription(e.getMessage())
117+
.augmentDescription(stacktraceToString(e)), new Metadata());
118+
} else {
119+
throw e;
120+
}
121+
}
122+
}
123+
};
124+
}
125+
126+
private boolean exceptionTypeIsAllowed(Class<? extends Throwable> exceptionClass) {
127+
// exact matches
128+
for (Class<? extends Throwable> clazz : exactTypes) {
129+
if (clazz.equals(exceptionClass)) {
130+
return true;
131+
}
132+
}
133+
134+
// parent type matches
135+
for (Class<? extends Throwable> clazz : parentTypes) {
136+
if (clazz.isAssignableFrom(exceptionClass)) {
137+
return true;
138+
}
139+
}
140+
141+
// no match
142+
return false;
143+
}
144+
145+
private String stacktraceToString(Throwable e) {
146+
StringWriter stringWriter = new StringWriter();
147+
PrintWriter printWriter = new PrintWriter(stringWriter);
148+
e.printStackTrace(printWriter);
149+
return stringWriter.toString();
150+
}
151+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
/*
2+
* Copyright (c) 2019, Salesforce.com, Inc.
3+
* All rights reserved.
4+
* Licensed under the BSD 3-Clause license.
5+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6+
*/
7+
8+
package com.salesforce.grpc.contrib.interceptor;
9+
10+
import com.salesforce.grpc.contrib.GreeterGrpc;
11+
import com.salesforce.grpc.contrib.HelloRequest;
12+
import com.salesforce.grpc.contrib.HelloResponse;
13+
import io.grpc.ServerInterceptor;
14+
import io.grpc.ServerInterceptors;
15+
import io.grpc.Status;
16+
import io.grpc.StatusRuntimeException;
17+
import io.grpc.stub.StreamObserver;
18+
import io.grpc.testing.GrpcServerRule;
19+
import org.junit.Rule;
20+
import org.junit.Test;
21+
22+
import java.util.Iterator;
23+
24+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
25+
26+
public class TransmitUnexpectedExceptionInterceptorTest {
27+
@Rule
28+
public final GrpcServerRule serverRule = new GrpcServerRule();
29+
30+
@Test
31+
public void noExceptionDoesNotInterfere() {
32+
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
33+
@Override
34+
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
35+
responseObserver.onNext(HelloResponse.newBuilder().setMessage("Hello " + request.getName()).build());
36+
responseObserver.onCompleted();
37+
}
38+
};
39+
40+
ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor();
41+
42+
serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
43+
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());
44+
45+
stub.sayHello(HelloRequest.newBuilder().setName("World").build());
46+
}
47+
48+
@Test
49+
public void exactTypeMatches() {
50+
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
51+
@Override
52+
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
53+
responseObserver.onError(new ArithmeticException("Divide by zero"));
54+
}
55+
};
56+
57+
ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(ArithmeticException.class);
58+
59+
serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
60+
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());
61+
62+
assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
63+
.isInstanceOf(StatusRuntimeException.class)
64+
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL")
65+
.hasMessageContaining("Divide by zero");
66+
}
67+
68+
@Test
69+
public void parentTypeMatches() {
70+
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
71+
@Override
72+
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
73+
responseObserver.onError(new ArithmeticException("Divide by zero"));
74+
}
75+
};
76+
77+
ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forParentType(RuntimeException.class);
78+
79+
serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
80+
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());
81+
82+
assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
83+
.isInstanceOf(StatusRuntimeException.class)
84+
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL")
85+
.hasMessageContaining("Divide by zero");
86+
}
87+
88+
@Test
89+
public void parentTypeMatchesExactly() {
90+
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
91+
@Override
92+
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
93+
responseObserver.onError(new RuntimeException("Divide by zero"));
94+
}
95+
};
96+
97+
ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forParentType(RuntimeException.class);
98+
99+
serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
100+
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());
101+
102+
assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
103+
.isInstanceOf(StatusRuntimeException.class)
104+
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL")
105+
.hasMessageContaining("Divide by zero");
106+
}
107+
108+
@Test
109+
public void alleMatches() {
110+
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
111+
@Override
112+
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
113+
responseObserver.onError(new ArithmeticException("Divide by zero"));
114+
}
115+
};
116+
117+
ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forAllExceptions();
118+
119+
serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
120+
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());
121+
122+
assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
123+
.isInstanceOf(StatusRuntimeException.class)
124+
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL")
125+
.hasMessageContaining("Divide by zero");
126+
}
127+
128+
@Test
129+
public void unknownTypeDoesNotMatch() {
130+
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
131+
@Override
132+
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
133+
responseObserver.onError(new NullPointerException("Bananas!"));
134+
}
135+
};
136+
137+
ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(ArithmeticException.class);
138+
139+
serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
140+
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());
141+
142+
assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
143+
.isInstanceOf(StatusRuntimeException.class)
144+
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.UNKNOWN.getCode()), "is Status.UNKNOWN")
145+
.hasMessageContaining("UNKNOWN");
146+
}
147+
148+
@Test
149+
public void unexpectedExceptionCanMatch() {
150+
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
151+
@Override
152+
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
153+
throw new ArithmeticException("Divide by zero");
154+
}
155+
};
156+
157+
ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(ArithmeticException.class);
158+
159+
serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
160+
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());
161+
162+
assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
163+
.isInstanceOf(StatusRuntimeException.class)
164+
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL")
165+
.hasMessageContaining("Divide by zero");
166+
}
167+
168+
@Test
169+
public void unexpectedExceptionCanNotMatch() {
170+
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
171+
@Override
172+
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
173+
throw new ArithmeticException("Divide by zero");
174+
}
175+
};
176+
177+
ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(NullPointerException.class);
178+
179+
serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
180+
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());
181+
182+
assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
183+
.isInstanceOf(StatusRuntimeException.class)
184+
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.UNKNOWN.getCode()), "is Status.UNKNOWN")
185+
.hasMessageContaining("UNKNOWN");
186+
}
187+
188+
@Test
189+
public void unexpectedExceptionCanMatchStreaming() {
190+
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
191+
@Override
192+
public void sayHelloStream(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
193+
responseObserver.onNext(HelloResponse.getDefaultInstance());
194+
responseObserver.onNext(HelloResponse.getDefaultInstance());
195+
throw new ArithmeticException("Divide by zero");
196+
}
197+
};
198+
199+
ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(ArithmeticException.class);
200+
201+
serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
202+
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());
203+
204+
Iterator<HelloResponse> it = stub.sayHelloStream(HelloRequest.newBuilder().setName("World").build());
205+
it.next();
206+
it.next();
207+
assertThatThrownBy(it::next)
208+
.isInstanceOf(StatusRuntimeException.class)
209+
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL")
210+
.hasMessageContaining("Divide by zero");
211+
}
212+
}

contrib/grpc-contrib/src/test/proto/helloworld.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ service Greeter {
1414
// Sends a greeting
1515
rpc SayHello (HelloRequest) returns (HelloResponse) {}
1616

17+
// Sends many greetings
18+
rpc SayHelloStream (HelloRequest) returns (stream HelloResponse) {}
19+
1720
// Sends the current time
1821
rpc SayTime (google.protobuf.Empty) returns (TimeResponse) {}
1922
}

0 commit comments

Comments
 (0)