Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion duo-client/src/main/java/com/duosecurity/client/Http.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class Http {
private Headers.Builder headers;
private SortedMap<String, Object> params = new TreeMap<String, Object>();
protected int sigVersion = 5;
private long maxBackoffMs = MAX_BACKOFF_MS;
private Random random = new Random();
private OkHttpClient httpClient;
private SortedMap<String, String> additionalDuoHeaders = new TreeMap<String, String>();
Expand Down Expand Up @@ -314,7 +315,7 @@ private Response executeRequest(Request request) throws Exception {
long backoffMs = INITIAL_BACKOFF_MS;
while (true) {
Response response = httpClient.newCall(request).execute();
if (response.code() != RATE_LIMIT_ERROR_CODE || backoffMs > MAX_BACKOFF_MS) {
if (response.code() != RATE_LIMIT_ERROR_CODE || backoffMs > maxBackoffMs) {
return response;
}

Expand All @@ -327,6 +328,13 @@ protected void sleep(long ms) throws Exception {
Thread.sleep(ms);
}

protected void setMaxBackoffMs(long maxBackoffMs) {
if (maxBackoffMs < 0) {
throw new IllegalArgumentException("maxBackoffMs must be >= 0");
}
this.maxBackoffMs = maxBackoffMs;
}

public void signRequest(String ikey, String skey)
throws UnsupportedEncodingException {
signRequest(ikey, skey, sigVersion);
Expand Down Expand Up @@ -529,6 +537,7 @@ protected abstract static class ClientBuilder<T extends Http> {
private final String uri;

private int timeout = DEFAULT_TIMEOUT_SECS;
private long maxBackoffMs = MAX_BACKOFF_MS;
private String[] caCerts = null;
private SortedMap<String, String> additionalDuoHeaders = new TreeMap<String, String>();
private Map<String, String> headers = new HashMap<String, String>();
Expand Down Expand Up @@ -558,6 +567,26 @@ public ClientBuilder<T> useTimeout(int timeout) {
return this;
}

/**
* Set the maximum base backoff time in milliseconds for rate limit (429) retries.
* When a request receives a 429 response, the client retries with exponential
* backoff until the base backoff exceeds this threshold. Note that actual sleep
* time includes up to 1000ms of random jitter on top of the base backoff.
* Setting to 0 disables retries. Default is 32000ms (32 seconds).
*
* @param maxBackoffMs the maximum base backoff in milliseconds (must be >= 0)
* @return the Builder
* @throws IllegalArgumentException if maxBackoffMs is negative
*/
public ClientBuilder<T> useMaxBackoffMs(long maxBackoffMs) {
if (maxBackoffMs < 0) {
throw new IllegalArgumentException("maxBackoffMs must be >= 0");
}
this.maxBackoffMs = maxBackoffMs;

return this;
}

/**
* Provide custom CA certificates for certificate pinning.
*
Expand Down Expand Up @@ -604,6 +633,7 @@ public ClientBuilder<T> addHeader(String name, String value) {
*/
public T build() {
T duoClient = createClient(method, host, uri, timeout);
duoClient.setMaxBackoffMs(maxBackoffMs);
if (caCerts != null) {
duoClient.useCustomCertificates(caCerts);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ public class HttpRateLimitRetryTest {

private final int RANDOM_INT = 234;

@Before
public void before() throws Exception {
http = new Http.HttpBuilder("GET", "example.test", "/foo/bar").build();
http = Mockito.spy(http);
private void setupHttp(Http client) throws Exception {
http = Mockito.spy(client);

Field httpClientField = Http.class.getDeclaredField("httpClient");
httpClientField.setAccessible(true);
Expand All @@ -39,6 +37,12 @@ public void before() throws Exception {
Mockito.doNothing().when(http).sleep(Mockito.any(Long.class));
}

@Before
public void before() throws Exception {
Http client = new Http.HttpBuilder("GET", "example.test", "/foo/bar").build();
setupHttp(client);
}

@Test
public void testSingleRateLimitRetry() throws Exception {
final List<Response> responses = new ArrayList<Response>();
Expand Down Expand Up @@ -128,4 +132,87 @@ public Call answer(InvocationOnMock invocationOnMock) throws Throwable {
assertEquals(16000L + RANDOM_INT, (long) sleepTimes.get(4));
assertEquals(32000L + RANDOM_INT, (long) sleepTimes.get(5));
}

@Test
public void testMaxBackoffZeroDisablesRetry() throws Exception {
Http customHttp = new Http.HttpBuilder("GET", "example.test", "/foo/bar")
.useMaxBackoffMs(0)
.build();
setupHttp(customHttp);

final List<Response> responses = new ArrayList<Response>();

Mockito.when(httpClient.newCall(Mockito.any(Request.class))).thenAnswer(new Answer<Call>() {
@Override
public Call answer(InvocationOnMock invocationOnMock) throws Throwable {
Call call = Mockito.mock(Call.class);

Response resp = new Response.Builder()
.protocol(Protocol.HTTP_2)
.code(429)
.request((Request) invocationOnMock.getArguments()[0])
.message("HTTP 429")
.build();
responses.add(resp);
Mockito.when(call.execute()).thenReturn(resp);

return call;
}
});

Response actualRes = http.executeHttpRequest();
assertEquals(1, responses.size());
assertEquals(429, actualRes.code());

// Verify no sleep was called
Mockito.verify(http, Mockito.never()).sleep(Mockito.any(Long.class));
}

@Test
public void testMaxBackoffCustomLimit() throws Exception {
Http customHttp = new Http.HttpBuilder("GET", "example.test", "/foo/bar")
.useMaxBackoffMs(4000)
.build();
setupHttp(customHttp);

final List<Response> responses = new ArrayList<Response>();

Mockito.when(httpClient.newCall(Mockito.any(Request.class))).thenAnswer(new Answer<Call>() {
@Override
public Call answer(InvocationOnMock invocationOnMock) throws Throwable {
Call call = Mockito.mock(Call.class);

Response resp = new Response.Builder()
.protocol(Protocol.HTTP_2)
.code(429)
.request((Request) invocationOnMock.getArguments()[0])
.message("HTTP 429")
.build();
responses.add(resp);
Mockito.when(call.execute()).thenReturn(resp);

return call;
}
});

// With maxBackoff=4000, retries at 1000, 2000, 4000, then 8000 > 4000 exits
// That's 4 total requests (1 initial + 3 retries)
Response actualRes = http.executeHttpRequest();
assertEquals(4, responses.size());
assertEquals(429, actualRes.code());

ArgumentCaptor<Long> sleepCapture = ArgumentCaptor.forClass(Long.class);
Mockito.verify(http, Mockito.times(3)).sleep(sleepCapture.capture());
List<Long> sleepTimes = sleepCapture.getAllValues();
assertEquals(1000L + RANDOM_INT, (long) sleepTimes.get(0));
assertEquals(2000L + RANDOM_INT, (long) sleepTimes.get(1));
assertEquals(4000L + RANDOM_INT, (long) sleepTimes.get(2));
}

@Test(expected = IllegalArgumentException.class)
public void testMaxBackoffNegativeThrows() {
new Http.HttpBuilder("GET", "example.test", "/foo/bar")
.useMaxBackoffMs(-1)
.build();
}
}