|
| 1 | +/* |
| 2 | + * Copyright (c) 2017, salesforce.com, inc. |
| 3 | + * All rights reserved. |
| 4 | + * Licensed under the BSD 3-Clause license. |
| 5 | + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
| 6 | + */ |
| 7 | + |
| 8 | +package com.salesforce.grpc.contrib.interceptor; |
| 9 | + |
| 10 | +import com.google.common.annotations.VisibleForTesting; |
| 11 | +import io.grpc.*; |
| 12 | + |
| 13 | +import java.lang.reflect.Field; |
| 14 | +import java.util.ArrayList; |
| 15 | +import java.util.List; |
| 16 | +import java.util.function.BiFunction; |
| 17 | +import java.util.function.Function; |
| 18 | + |
| 19 | +import static com.google.common.base.Preconditions.checkNotNull; |
| 20 | + |
| 21 | +/** |
| 22 | + * {@code DefaultCallOptionsClientInterceptor} applies specified gRPC {@code CallOptions} to every outbound request. |
| 23 | + * By default, {@code DefaultCallOptionsClientInterceptor} will not overwrite {@code CallOptions} already set on the |
| 24 | + * outbound request. |
| 25 | + * |
| 26 | + * <p>Example uses include: |
| 27 | + * <ul> |
| 28 | + * <li>Applying a set of {@code CallCredentials} to every request from any stub.</li> |
| 29 | + * <li>Applying a compression strategy to every request.</li> |
| 30 | + * <li>Attaching a custom {@code CallOptions.Key<T>} to every request.</li> |
| 31 | + * <li>Setting the {@code WaitForReady} bit on every request.</li> |
| 32 | + * <li>Preventing upstream users from tweaking {@code CallOptions} values by forcibly overwriting the value with a |
| 33 | + * specific default.</li> |
| 34 | + * </ul> |
| 35 | + */ |
| 36 | +public class DefaultCallOptionsClientInterceptor implements ClientInterceptor { |
| 37 | + private static final Field CUSTOM_OPTIONS_FIELD = getCustomOptionsField(); |
| 38 | + |
| 39 | + private static Field getCustomOptionsField() { |
| 40 | + try { |
| 41 | + Field f; |
| 42 | + f = CallOptions.class.getDeclaredField("customOptions"); |
| 43 | + f.setAccessible(true); |
| 44 | + return f; |
| 45 | + } catch (NoSuchFieldException e) { |
| 46 | + throw new RuntimeException(e); |
| 47 | + } |
| 48 | + } |
| 49 | + |
| 50 | + private CallOptions defaultOptions; |
| 51 | + private boolean overwrite = false; |
| 52 | + |
| 53 | + /** |
| 54 | + * Constructs a {@code DefaultCallOptionsClientInterceptor}. |
| 55 | + * @param options the set of {@code CallOptions} to apply to every call |
| 56 | + */ |
| 57 | + public DefaultCallOptionsClientInterceptor(CallOptions options) { |
| 58 | + this.defaultOptions = checkNotNull(options, "defaultOptions"); |
| 59 | + } |
| 60 | + |
| 61 | + /** |
| 62 | + * Instructs the interceptor to overwrite {@code CallOptions} values even if they are already present on the |
| 63 | + * outbound request. |
| 64 | + * |
| 65 | + * @return this |
| 66 | + */ |
| 67 | + public DefaultCallOptionsClientInterceptor overwriteExistingValues() { |
| 68 | + this.overwrite = true; |
| 69 | + return this; |
| 70 | + } |
| 71 | + |
| 72 | + public CallOptions getDefaultOptions() { |
| 73 | + return defaultOptions; |
| 74 | + } |
| 75 | + |
| 76 | + public void setDefaultOptions(CallOptions options) { |
| 77 | + this.defaultOptions = options; |
| 78 | + } |
| 79 | + |
| 80 | + @Override |
| 81 | + public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { |
| 82 | + return next.newCall(method, patchOptions(callOptions)); |
| 83 | + } |
| 84 | + |
| 85 | + @VisibleForTesting |
| 86 | + CallOptions patchOptions(CallOptions baseOptions) { |
| 87 | + CallOptions patchedOptions = baseOptions; |
| 88 | + |
| 89 | + patchedOptions = patchOption(patchedOptions, CallOptions::getAuthority, CallOptions::withAuthority); |
| 90 | + patchedOptions = patchOption(patchedOptions, CallOptions::getCredentials, CallOptions::withCallCredentials); |
| 91 | + patchedOptions = patchOption(patchedOptions, CallOptions::getCompressor, CallOptions::withCompression); |
| 92 | + patchedOptions = patchOption(patchedOptions, CallOptions::getDeadline, CallOptions::withDeadline); |
| 93 | + patchedOptions = patchOption(patchedOptions, CallOptions::isWaitForReady, (callOptions, waitForReady) -> waitForReady ? callOptions.withWaitForReady() : callOptions.withoutWaitForReady()); |
| 94 | + patchedOptions = patchOption(patchedOptions, CallOptions::getMaxInboundMessageSize, CallOptions::withMaxInboundMessageSize); |
| 95 | + patchedOptions = patchOption(patchedOptions, CallOptions::getMaxOutboundMessageSize, CallOptions::withMaxOutboundMessageSize); |
| 96 | + patchedOptions = patchOption(patchedOptions, CallOptions::getExecutor, CallOptions::withExecutor); |
| 97 | + |
| 98 | + for (ClientStreamTracer.Factory factory : defaultOptions.getStreamTracerFactories()) { |
| 99 | + patchedOptions = patchedOptions.withStreamTracerFactory(factory); |
| 100 | + } |
| 101 | + |
| 102 | + for (CallOptions.Key<Object> key : customOptionKeys(defaultOptions)) { |
| 103 | + patchedOptions = patchOption(patchedOptions, co -> co.getOption(key), (co, o) -> co.withOption(key, o)); |
| 104 | + } |
| 105 | + |
| 106 | + return patchedOptions; |
| 107 | + } |
| 108 | + |
| 109 | + private <T> CallOptions patchOption(CallOptions baseOptions, Function<CallOptions, T> getter, BiFunction<CallOptions, T, CallOptions> setter) { |
| 110 | + T baseValue = getter.apply(baseOptions); |
| 111 | + if (baseValue == null || overwrite) { |
| 112 | + T patchValue = getter.apply(defaultOptions); |
| 113 | + if (patchValue != null) { |
| 114 | + return setter.apply(baseOptions, patchValue); |
| 115 | + } |
| 116 | + } |
| 117 | + |
| 118 | + return baseOptions; |
| 119 | + } |
| 120 | + |
| 121 | + @SuppressWarnings("unchecked") |
| 122 | + private List<CallOptions.Key<Object>> customOptionKeys(CallOptions callOptions) { |
| 123 | + try { |
| 124 | + Object[][] customOptions = (Object[][]) CUSTOM_OPTIONS_FIELD.get(callOptions); |
| 125 | + List<CallOptions.Key<Object>> keys = new ArrayList<>(customOptions.length); |
| 126 | + for (Object[] arr : customOptions) { |
| 127 | + keys.add((CallOptions.Key<Object>) arr[0]); |
| 128 | + } |
| 129 | + return keys; |
| 130 | + } catch (IllegalAccessException e) { |
| 131 | + throw new RuntimeException(e); |
| 132 | + } |
| 133 | + } |
| 134 | +} |
0 commit comments