Skip to content

Commit 347c90b

Browse files
authored
Merge pull request #63 from salesforce/feature/default-options-interceptor
Implemented DefaultCallOptionsClientInterceptor
2 parents 9076300 + 98c1352 commit 347c90b

File tree

3 files changed

+279
-0
lines changed

3 files changed

+279
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Copyright (c) 2017, salesforce.com, inc.
3+
* All rights reserved.
4+
* Licensed under the BSD 3-Clause license.
5+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6+
*/
7+
8+
package com.salesforce.grpc.contrib.interceptor;
9+
10+
import com.google.common.annotations.VisibleForTesting;
11+
import io.grpc.*;
12+
13+
import java.lang.reflect.Field;
14+
import java.util.ArrayList;
15+
import java.util.List;
16+
import java.util.function.BiFunction;
17+
import java.util.function.Function;
18+
19+
import static com.google.common.base.Preconditions.checkNotNull;
20+
21+
/**
22+
* {@code DefaultCallOptionsClientInterceptor} applies specified gRPC {@code CallOptions} to every outbound request.
23+
* By default, {@code DefaultCallOptionsClientInterceptor} will not overwrite {@code CallOptions} already set on the
24+
* outbound request.
25+
*
26+
* <p>Example uses include:
27+
* <ul>
28+
* <li>Applying a set of {@code CallCredentials} to every request from any stub.</li>
29+
* <li>Applying a compression strategy to every request.</li>
30+
* <li>Attaching a custom {@code CallOptions.Key<T>} to every request.</li>
31+
* <li>Setting the {@code WaitForReady} bit on every request.</li>
32+
* <li>Preventing upstream users from tweaking {@code CallOptions} values by forcibly overwriting the value with a
33+
* specific default.</li>
34+
* </ul>
35+
*/
36+
public class DefaultCallOptionsClientInterceptor implements ClientInterceptor {
37+
private static final Field CUSTOM_OPTIONS_FIELD = getCustomOptionsField();
38+
39+
private static Field getCustomOptionsField() {
40+
try {
41+
Field f;
42+
f = CallOptions.class.getDeclaredField("customOptions");
43+
f.setAccessible(true);
44+
return f;
45+
} catch (NoSuchFieldException e) {
46+
throw new RuntimeException(e);
47+
}
48+
}
49+
50+
private CallOptions defaultOptions;
51+
private boolean overwrite = false;
52+
53+
/**
54+
* Constructs a {@code DefaultCallOptionsClientInterceptor}.
55+
* @param options the set of {@code CallOptions} to apply to every call
56+
*/
57+
public DefaultCallOptionsClientInterceptor(CallOptions options) {
58+
this.defaultOptions = checkNotNull(options, "defaultOptions");
59+
}
60+
61+
/**
62+
* Instructs the interceptor to overwrite {@code CallOptions} values even if they are already present on the
63+
* outbound request.
64+
*
65+
* @return this
66+
*/
67+
public DefaultCallOptionsClientInterceptor overwriteExistingValues() {
68+
this.overwrite = true;
69+
return this;
70+
}
71+
72+
public CallOptions getDefaultOptions() {
73+
return defaultOptions;
74+
}
75+
76+
public void setDefaultOptions(CallOptions options) {
77+
this.defaultOptions = options;
78+
}
79+
80+
@Override
81+
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
82+
return next.newCall(method, patchOptions(callOptions));
83+
}
84+
85+
@VisibleForTesting
86+
CallOptions patchOptions(CallOptions baseOptions) {
87+
CallOptions patchedOptions = baseOptions;
88+
89+
patchedOptions = patchOption(patchedOptions, CallOptions::getAuthority, CallOptions::withAuthority);
90+
patchedOptions = patchOption(patchedOptions, CallOptions::getCredentials, CallOptions::withCallCredentials);
91+
patchedOptions = patchOption(patchedOptions, CallOptions::getCompressor, CallOptions::withCompression);
92+
patchedOptions = patchOption(patchedOptions, CallOptions::getDeadline, CallOptions::withDeadline);
93+
patchedOptions = patchOption(patchedOptions, CallOptions::isWaitForReady, (callOptions, waitForReady) -> waitForReady ? callOptions.withWaitForReady() : callOptions.withoutWaitForReady());
94+
patchedOptions = patchOption(patchedOptions, CallOptions::getMaxInboundMessageSize, CallOptions::withMaxInboundMessageSize);
95+
patchedOptions = patchOption(patchedOptions, CallOptions::getMaxOutboundMessageSize, CallOptions::withMaxOutboundMessageSize);
96+
patchedOptions = patchOption(patchedOptions, CallOptions::getExecutor, CallOptions::withExecutor);
97+
98+
for (ClientStreamTracer.Factory factory : defaultOptions.getStreamTracerFactories()) {
99+
patchedOptions = patchedOptions.withStreamTracerFactory(factory);
100+
}
101+
102+
for (CallOptions.Key<Object> key : customOptionKeys(defaultOptions)) {
103+
patchedOptions = patchOption(patchedOptions, co -> co.getOption(key), (co, o) -> co.withOption(key, o));
104+
}
105+
106+
return patchedOptions;
107+
}
108+
109+
private <T> CallOptions patchOption(CallOptions baseOptions, Function<CallOptions, T> getter, BiFunction<CallOptions, T, CallOptions> setter) {
110+
T baseValue = getter.apply(baseOptions);
111+
if (baseValue == null || overwrite) {
112+
T patchValue = getter.apply(defaultOptions);
113+
if (patchValue != null) {
114+
return setter.apply(baseOptions, patchValue);
115+
}
116+
}
117+
118+
return baseOptions;
119+
}
120+
121+
@SuppressWarnings("unchecked")
122+
private List<CallOptions.Key<Object>> customOptionKeys(CallOptions callOptions) {
123+
try {
124+
Object[][] customOptions = (Object[][]) CUSTOM_OPTIONS_FIELD.get(callOptions);
125+
List<CallOptions.Key<Object>> keys = new ArrayList<>(customOptions.length);
126+
for (Object[] arr : customOptions) {
127+
keys.add((CallOptions.Key<Object>) arr[0]);
128+
}
129+
return keys;
130+
} catch (IllegalAccessException e) {
131+
throw new RuntimeException(e);
132+
}
133+
}
134+
}

grpc-contrib/src/test/java/com/salesforce/grpc/contrib/MoreMetadataTest.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,16 @@ public boolean equals(Object o) {
6868
}
6969
}
7070

71+
@Test
72+
public void jsonMarshallerPrimitiveRoundtrip() {
73+
Metadata.AsciiMarshaller<Integer> marshaller = MoreMetadata.JSON_MARSHALLER(Integer.class);
74+
String s = marshaller.toAsciiString(42);
75+
assertThat(s).isEqualTo("42");
76+
77+
Integer l = marshaller.parseAsciiString(s);
78+
assertThat(l).isEqualTo(42);
79+
}
80+
7181
@Test
7282
public void protobufMarshallerRoundtrip() {
7383
HelloRequest request = HelloRequest.newBuilder().setName("World").build();
@@ -134,4 +144,18 @@ public void rawJsonToTypedJson() {
134144
assertThat(bar.cheese).isEqualTo("swiss");
135145
assertThat(bar.age).isEqualTo(42);
136146
}
147+
148+
@Test
149+
public void rawBytesToTypedProto() {
150+
Metadata.Key<byte[]> byteKey = Metadata.Key.of("key-bin", Metadata.BINARY_BYTE_MARSHALLER);
151+
Metadata.Key<HelloRequest> protoKey = Metadata.Key.of("key-bin", MoreMetadata.PROTOBUF_MARSHALLER(HelloRequest.class));
152+
153+
HelloRequest request = HelloRequest.newBuilder().setName("World").build();
154+
Metadata metadata = new Metadata();
155+
metadata.put(byteKey, request.toByteArray());
156+
157+
HelloRequest request2 = metadata.get(protoKey);
158+
assertThat(request2).isNotNull();
159+
assertThat(request2.getName()).isEqualTo("World");
160+
}
137161
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright (c) 2017, salesforce.com, inc.
3+
* All rights reserved.
4+
* Licensed under the BSD 3-Clause license.
5+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6+
*/
7+
8+
package com.salesforce.grpc.contrib.interceptor;
9+
10+
import io.grpc.CallOptions;
11+
import io.grpc.ClientStreamTracer;
12+
import org.junit.Test;
13+
14+
import static org.assertj.core.api.Java6Assertions.assertThat;
15+
16+
public class DefaultCallOptionsClientInterceptorTest {
17+
@Test
18+
public void simpleValueTransfers() {
19+
CallOptions baseOptions = CallOptions.DEFAULT;
20+
CallOptions defaultOptions = CallOptions.DEFAULT.withAuthority("FOO");
21+
22+
DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions);
23+
24+
CallOptions patchedOptions = interceptor.patchOptions(baseOptions);
25+
26+
assertThat(patchedOptions.getAuthority()).isEqualTo("FOO");
27+
}
28+
29+
@Test
30+
public void clientStreamTracerTransfers() {
31+
ClientStreamTracer.Factory factory1 = new ClientStreamTracer.Factory() {};
32+
ClientStreamTracer.Factory factory2 = new ClientStreamTracer.Factory() {};
33+
34+
CallOptions baseOptions = CallOptions.DEFAULT.withStreamTracerFactory(factory1);
35+
CallOptions defaultOptions = CallOptions.DEFAULT.withStreamTracerFactory(factory2);
36+
37+
DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions);
38+
39+
CallOptions patchedOptions = interceptor.patchOptions(baseOptions);
40+
41+
assertThat(patchedOptions.getStreamTracerFactories()).containsExactly(factory1, factory2);
42+
}
43+
44+
@Test
45+
public void customKeyTransfers() {
46+
CallOptions.Key<String> k1 = CallOptions.Key.of("k1", null);
47+
CallOptions.Key<String> k2 = CallOptions.Key.of("k2", null);
48+
49+
CallOptions baseOptions = CallOptions.DEFAULT.withOption(k1, "FOO");
50+
CallOptions defaultOptions = CallOptions.DEFAULT.withOption(k2, "BAR");
51+
52+
DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions);
53+
54+
CallOptions patchedOptions = interceptor.patchOptions(baseOptions);
55+
56+
assertThat(patchedOptions.getOption(k1)).isEqualTo("FOO");
57+
assertThat(patchedOptions.getOption(k2)).isEqualTo("BAR");
58+
}
59+
60+
@Test
61+
public void noOverwriteWorks() {
62+
CallOptions baseOptions = CallOptions.DEFAULT.withAuthority("FOO");
63+
CallOptions defaultOptions = CallOptions.DEFAULT.withAuthority("BAR");
64+
65+
DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions);
66+
67+
CallOptions patchedOptions = interceptor.patchOptions(baseOptions);
68+
69+
assertThat(patchedOptions.getAuthority()).isEqualTo("FOO");
70+
}
71+
72+
@Test
73+
public void noOverwriteWorksCustomKeys() {
74+
CallOptions.Key<String> k1 = CallOptions.Key.of("k1", null);
75+
CallOptions.Key<String> k2 = CallOptions.Key.of("k2", null);
76+
CallOptions.Key<String> k3 = CallOptions.Key.of("k3", null);
77+
78+
CallOptions baseOptions = CallOptions.DEFAULT.withOption(k1, "FOO").withOption(k3, "BAZ");
79+
CallOptions defaultOptions = CallOptions.DEFAULT.withOption(k2, "BAR").withOption(k3, "BOP");
80+
81+
DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions);
82+
83+
CallOptions patchedOptions = interceptor.patchOptions(baseOptions);
84+
85+
assertThat(patchedOptions.getOption(k1)).isEqualTo("FOO");
86+
assertThat(patchedOptions.getOption(k2)).isEqualTo("BAR");
87+
assertThat(patchedOptions.getOption(k3)).isEqualTo("BAZ");
88+
}
89+
90+
@Test
91+
public void overwriteWorks() {
92+
CallOptions baseOptions = CallOptions.DEFAULT.withAuthority("FOO");
93+
CallOptions defaultOptions = CallOptions.DEFAULT.withAuthority("BAR");
94+
95+
DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions)
96+
.overwriteExistingValues();
97+
98+
CallOptions patchedOptions = interceptor.patchOptions(baseOptions);
99+
100+
assertThat(patchedOptions.getAuthority()).isEqualTo("BAR");
101+
}
102+
103+
@Test
104+
public void overwriteWorksCustomKeys() {
105+
CallOptions.Key<String> k1 = CallOptions.Key.of("k1", null);
106+
CallOptions.Key<String> k2 = CallOptions.Key.of("k2", null);
107+
CallOptions.Key<String> k3 = CallOptions.Key.of("k3", null);
108+
109+
CallOptions baseOptions = CallOptions.DEFAULT.withOption(k1, "FOO").withOption(k3, "BAZ");
110+
CallOptions defaultOptions = CallOptions.DEFAULT.withOption(k2, "BAR").withOption(k3, "BOP");
111+
112+
DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions)
113+
.overwriteExistingValues();
114+
115+
CallOptions patchedOptions = interceptor.patchOptions(baseOptions);
116+
117+
assertThat(patchedOptions.getOption(k1)).isEqualTo("FOO");
118+
assertThat(patchedOptions.getOption(k2)).isEqualTo("BAR");
119+
assertThat(patchedOptions.getOption(k3)).isEqualTo("BOP");
120+
}
121+
}

0 commit comments

Comments
 (0)