Skip to content

Commit ae56785

Browse files
authored
Merge pull request #960 from AzureAD/avdunn/retry-behavior
Refactor retry logic and add special behavior for IMDS
2 parents 1930c20 + eb10360 commit ae56785

30 files changed

+517
-143
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ static AadInstanceDiscoveryResponse sendInstanceDiscoveryRequest(URL authorityUr
235235

236236
AadInstanceDiscoveryResponse response = JsonHelper.convertJsonStringToJsonSerializableObject(httpResponse.body(), AadInstanceDiscoveryResponse::fromJson);
237237

238-
if (httpResponse.statusCode() != HttpHelper.HTTP_STATUS_200) {
239-
if (httpResponse.statusCode() == HttpHelper.HTTP_STATUS_400 && response.error().equals("invalid_instance")) {
238+
if (httpResponse.statusCode() != HttpStatus.HTTP_OK) {
239+
if (httpResponse.statusCode() == HttpStatus.HTTP_BAD_REQUEST && response.error().equals("invalid_instance")) {
240240
// instance discovery failed due to an invalid authority, throw an exception.
241241
throw MsalServiceExceptionFactory.fromHttpResponse(httpResponse);
242242
}
@@ -310,7 +310,7 @@ static String discoverRegion(MsalRequest msalRequest, ServiceBundle serviceBundl
310310
log.info("Starting call to IMDS endpoint.");
311311
IHttpResponse httpResponse = future.get(IMDS_TIMEOUT, IMDS_TIMEOUT_UNIT);
312312
//If call to IMDS endpoint was successful, return region from response body
313-
if (httpResponse.statusCode() == HttpHelper.HTTP_STATUS_200 && !httpResponse.body().isEmpty()) {
313+
if (httpResponse.statusCode() == HttpStatus.HTTP_OK && !httpResponse.body().isEmpty()) {
314314
log.info(String.format("Region retrieved from IMDS endpoint: %s", httpResponse.body()));
315315
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_IMDS.telemetryValue);
316316

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractApplicationBase.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public abstract class AbstractApplicationBase implements IApplicationBase {
3232
private IHttpClient httpClient;
3333
private Integer connectTimeoutForDefaultHttpClient;
3434
private Integer readTimeoutForDefaultHttpClient;
35+
private boolean retryDisabled;
3536
String tenant;
3637

3738
//The following fields are set in only some applications and/or set internally by the library. To avoid excessive
@@ -150,6 +151,10 @@ public Integer readTimeoutForDefaultHttpClient() {
150151
return this.readTimeoutForDefaultHttpClient;
151152
}
152153

154+
boolean isRetryDisabled() {
155+
return this.retryDisabled;
156+
}
157+
153158
String tenant() {
154159
return this.tenant;
155160
}
@@ -190,6 +195,7 @@ public abstract static class Builder<T extends Builder<T>> {
190195
Boolean onlySendFailureTelemetry = false;
191196
Integer connectTimeoutForDefaultHttpClient;
192197
Integer readTimeoutForDefaultHttpClient;
198+
boolean disableInternalRetries;
193199
private String clientId;
194200
private Authority authenticationAuthority = createDefaultAADAuthority();
195201

@@ -319,6 +325,18 @@ public T readTimeoutForDefaultHttpClient(Integer val) {
319325
return self();
320326
}
321327

328+
/**
329+
* The library has a number of policies for retrying HTTP calls in different scenarios.
330+
* <p>
331+
* This will disable all internal retry behavior, allowing customized retry behavior via your own implementation of {@link IHttpClient}
332+
*
333+
* @return instance of the Builder on which method was called
334+
*/
335+
public T disableInternalRetries() {
336+
disableInternalRetries = true;
337+
return self();
338+
}
339+
322340
T telemetryConsumer(Consumer<List<HashMap<String, String>>> val) {
323341
validateNotNull("telemetryConsumer", val);
324342

@@ -356,5 +374,16 @@ private static Authority createDefaultAADAuthority() {
356374
readTimeoutForDefaultHttpClient = builder.readTimeoutForDefaultHttpClient;
357375
authenticationAuthority = builder.authenticationAuthority;
358376
clientId = builder.clientId;
377+
retryDisabled = builder.disableInternalRetries;
378+
379+
if (builder.httpClient == null) {
380+
httpClient = new DefaultHttpClient(
381+
builder.proxy,
382+
builder.sslSocketFactory,
383+
builder.connectTimeoutForDefaultHttpClient,
384+
builder.readTimeoutForDefaultHttpClient);
385+
} else {
386+
httpClient = builder.httpClient;
387+
}
359388
}
360389
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractClientApplicationBase.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -566,9 +566,7 @@ public T correlationId(String val) {
566566
super.serviceBundle = new ServiceBundle(
567567
builder.executorService,
568568
new TelemetryManager(telemetryConsumer, builder.onlySendFailureTelemetry),
569-
new HttpHelper(builder.httpClient == null ?
570-
new DefaultHttpClient(builder.proxy, builder.sslSocketFactory, builder.connectTimeoutForDefaultHttpClient, builder.readTimeoutForDefaultHttpClient) :
571-
builder.httpClient)
569+
new HttpHelper(this, new DefaultRetryPolicy())
572570
);
573571

574572
if (aadAadInstanceDiscoveryResponse != null) {

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AuthorizationResponseHandler.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class AuthorizationResponseHandler implements HttpHandler {
4141
public void handle(HttpExchange httpExchange) throws IOException {
4242
try {
4343
if (!httpExchange.getRequestURI().getPath().equalsIgnoreCase("/")) {
44-
httpExchange.sendResponseHeaders(200, 0);
44+
httpExchange.sendResponseHeaders(HttpStatus.HTTP_OK, 0);
4545
return;
4646
}
4747
String responseBody = new BufferedReader(new InputStreamReader(
@@ -92,13 +92,13 @@ private void sendErrorResponse(HttpExchange httpExchange, String response) throw
9292
private void send302Response(HttpExchange httpExchange, String redirectUri) throws IOException {
9393
Headers responseHeaders = httpExchange.getResponseHeaders();
9494
responseHeaders.set("Location", redirectUri);
95-
httpExchange.sendResponseHeaders(302, 0);
95+
httpExchange.sendResponseHeaders(HttpStatus.HTTP_FOUND, 0);
9696
}
9797

9898
private void send200Response(HttpExchange httpExchange, String response) throws IOException {
9999
byte[] responseBytes = response.getBytes("UTF-8");
100100
httpExchange.getResponseHeaders().set("Content-Type", "text/html; charset=UTF-8");
101-
httpExchange.sendResponseHeaders(200, responseBytes.length);
101+
httpExchange.sendResponseHeaders(HttpStatus.HTTP_OK, responseBytes.length);
102102
OutputStream os = httpExchange.getResponseBody();
103103
os.write(responseBytes);
104104
os.close();
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
/**
7+
* Default retry policy for most MSAL Java flows
8+
*/
9+
class DefaultRetryPolicy implements IRetryPolicy {
10+
private static final int RETRY_NUM = 1;
11+
private static final int RETRY_DELAY_MS = 1000;
12+
13+
@Override
14+
public boolean isRetryable(IHttpResponse httpResponse) {
15+
return HttpStatus.isServerError(httpResponse.statusCode()) &&
16+
HttpHelper.getRetryAfterHeader(httpResponse) == null;
17+
}
18+
19+
@Override
20+
public int getMaxRetryCount(IHttpResponse httpResponse) {
21+
return RETRY_NUM;
22+
}
23+
24+
@Override
25+
public int getRetryDelayMs(IHttpResponse httpResponse) {
26+
return RETRY_DELAY_MS;
27+
}
28+
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/DeviceCodeFlowRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ DeviceCode acquireDeviceCode(String url,
4646
this.requestContext(),
4747
serviceBundle);
4848

49-
if (response.statusCode() != HttpHelper.HTTP_STATUS_200) {
49+
if (response.statusCode() != HttpStatus.HTTP_OK) {
5050
throw MsalServiceExceptionFactory.fromHttpResponse(response);
5151
}
5252

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/HttpHelper.java

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,20 @@ class HttpHelper implements IHttpHelper {
1818

1919
private static final Logger log = LoggerFactory.getLogger(HttpHelper.class);
2020
public static final String RETRY_AFTER_HEADER = "Retry-After";
21-
public static final int RETRY_NUM = 2;
22-
public static final int RETRY_DELAY_MS = 1000;
23-
24-
public static final int HTTP_STATUS_200 = 200;
25-
public static final int HTTP_STATUS_400 = 400;
26-
public static final int HTTP_STATUS_429 = 429;
27-
public static final int HTTP_STATUS_500 = 500;
2821

2922
private IHttpClient httpClient;
23+
private IRetryPolicy retryPolicy;
24+
private boolean retryDisabled;
3025

31-
HttpHelper(IHttpClient httpClient) {
26+
HttpHelper(IHttpClient httpClient, IRetryPolicy retryPolicy) {
3227
this.httpClient = httpClient;
28+
this.retryPolicy = retryPolicy != null ? retryPolicy : new DefaultRetryPolicy();
29+
}
30+
31+
HttpHelper(AbstractApplicationBase application, IRetryPolicy retryPolicy) {
32+
this.httpClient = application.httpClient();
33+
this.retryDisabled = application.isRetryDisabled();
34+
this.retryPolicy = retryPolicy != null ? retryPolicy : new DefaultRetryPolicy();
3335
}
3436

3537
public IHttpResponse executeHttpRequest(HttpRequest httpRequest,
@@ -141,20 +143,23 @@ private String getRequestThumbprint(RequestContext requestContext) {
141143
return StringHelper.createSha256Hash(sb.toString());
142144
}
143145

144-
boolean isRetryable(IHttpResponse httpResponse) {
145-
return httpResponse.statusCode() >= HTTP_STATUS_500 &&
146-
getRetryAfterHeader(httpResponse) == null;
147-
}
148-
149146
IHttpResponse executeHttpRequestWithRetries(HttpRequest httpRequest, IHttpClient httpClient)
150147
throws Exception {
151-
IHttpResponse httpResponse = null;
152-
for (int i = 0; i < RETRY_NUM; i++) {
148+
IHttpResponse httpResponse = httpClient.send(httpRequest);
149+
150+
if (retryDisabled) {
151+
return httpResponse;
152+
}
153+
154+
int retryCount = 0;
155+
int maxRetries = retryPolicy.getMaxRetryCount(httpResponse);
156+
157+
while (retryPolicy.isRetryable(httpResponse) && retryCount < maxRetries) {
158+
Thread.sleep(retryPolicy.getRetryDelayMs(httpResponse));
159+
160+
retryCount++;
161+
153162
httpResponse = httpClient.send(httpRequest);
154-
if (!isRetryable(httpResponse)) {
155-
break;
156-
}
157-
Thread.sleep(RETRY_DELAY_MS);
158163
}
159164

160165
return httpResponse;
@@ -180,8 +185,8 @@ private void processThrottlingInstructions(IHttpResponse httpResponse, RequestCo
180185
Integer retryAfterHeaderVal = getRetryAfterHeader(httpResponse);
181186
if (retryAfterHeaderVal != null) {
182187
expirationTimestamp = System.currentTimeMillis() + retryAfterHeaderVal * 1000;
183-
} else if (httpResponse.statusCode() == HTTP_STATUS_429 ||
184-
(httpResponse.statusCode() >= HTTP_STATUS_500)) {
188+
} else if (httpResponse.statusCode() == HttpStatus.HTTP_TOO_MANY_REQUESTS ||
189+
(httpResponse.statusCode() >= HttpStatus.HTTP_INTERNAL_ERROR)) {
185190

186191
expirationTimestamp = System.currentTimeMillis() + ThrottlingCache.DEFAULT_THROTTLING_TIME_SEC * 1000;
187192
}
@@ -191,7 +196,7 @@ private void processThrottlingInstructions(IHttpResponse httpResponse, RequestCo
191196
}
192197
}
193198

194-
private Integer getRetryAfterHeader(IHttpResponse httpResponse) {
199+
static Integer getRetryAfterHeader(IHttpResponse httpResponse) {
195200

196201
if (httpResponse.headers() != null) {
197202
TreeMap<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
@@ -279,4 +284,8 @@ private static void verifyReturnedCorrelationId(final HttpRequest httpRequest,
279284
log.info(msg);
280285
}
281286
}
287+
288+
void setRetryPolicy(IRetryPolicy retryPolicy) {
289+
this.retryPolicy = retryPolicy;
290+
}
282291
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/HttpHelperManagedIdentity.java

Lines changed: 0 additions & 52 deletions
This file was deleted.
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
class HttpStatus {
7+
8+
static final int HTTP_OK = 200;
9+
static final int HTTP_FOUND = 302;
10+
static final int HTTP_BAD_REQUEST = 400;
11+
static final int HTTP_UNAUTHORIZED = 401;
12+
static final int HTTP_NOT_FOUND = 404;
13+
static final int HTTP_REQUEST_TIMEOUT = 408;
14+
static final int HTTP_GONE = 410;
15+
static final int HTTP_TOO_MANY_REQUESTS = 429;
16+
static final int HTTP_INTERNAL_ERROR = 500;
17+
static final int HTTP_UNAVAILABLE = 503;
18+
static final int HTTP_GATEWAY_TIMEOUT = 504;
19+
20+
/**
21+
* Determines if the status code represents a server error (5xx).
22+
*
23+
* @param code The HTTP status code
24+
* @return true if the status code is between 500 and 599, inclusive
25+
*/
26+
static boolean isServerError(int code) {
27+
return code >= 500 && code < 600;
28+
}
29+
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IMDSManagedIdentitySource.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ public IMDSManagedIdentitySource(MsalRequest msalRequest,
3636
super(msalRequest, serviceBundle, ManagedIdentitySourceType.IMDS);
3737
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
3838

39+
//IMDS uses a different retry policy than the default used in other MI flows
40+
IHttpHelper httpHelper = serviceBundle.getHttpHelper();
41+
if (httpHelper instanceof HttpHelper) {
42+
((HttpHelper) httpHelper).setRetryPolicy(new IMDSRetryPolicy());
43+
}
44+
3945
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST))){
4046
LOG.info(String.format("[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: %s", environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST)));
4147
try {

0 commit comments

Comments
 (0)