redis/jvm-redis-authx-entraid

Using AzureIdentityProvider with WorkloadIdentityCredential causes high CPU usage and object allocations in a token refresh busy-wait


nluk commented

Hi,
We've been using this extension since beta-1, before support for DefaultAzureCredential was introduced, to authenticate to an Azure Managed Redis instance via an AKS workload identity. In this specific scenario, a discrepancy between the built-in MSAL caching and the TokenManager + RenewalScheduler renewal process causes excessive CPU usage (busy-wait) and rapid object allocation. The scenario in detail:

  1. Your AzureIdentityProvider delegates to the DefaultAzureCredential::getToken
    accessTokenSupplier = () -> defaultAzureCredential.getToken(ctx).block(Duration.ofMillis(timeout));
  2. DefaultAzureCredential selects WorkloadIdentityCredential from the credentials chain
  3. WorkloadIdentityCredential invokes an IdentityClient::authenticateWithWorkloadIdentityConfidentialClient:
    https://github.com/Azure/azure-sdk-for-java/blob/f39bb5bf7bbea85b4591a01fa4a99c2fcec94072/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java#L113
  4. Down the invocation chain, that translates to MSAL's AcquireTokenByClientCredentialSupplier:
    https://github.com/AzureAD/microsoft-authentication-library-for-java/blob/5a4f9fcffb9d0bf9d8c2c15e29a056213a967d32/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByClientCredentialSupplier.java#L9
  5. That delegates to an AcquireTokenSilentSupplier, which has a built-in token cache:
    https://github.com/AzureAD/microsoft-authentication-library-for-java/blob/5a4f9fcffb9d0bf9d8c2c15e29a056213a967d32/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenSilentSupplier.java#L35

In our experience, the issued access token is valid for 24 hours. The cache won't refresh the token until it is about to expire, i.e. with less than five minutes left (https://github.com/AzureAD/microsoft-authentication-library-for-java/blob/5a4f9fcffb9d0bf9d8c2c15e29a056213a967d32/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenSilentSupplier.java#L16).

With the default ratio and lower bound:

public static final float DEFAULT_EXPIRATION_REFRESH_RATIO = 0.75F;

public static final int DEFAULT_LOWER_REFRESH_BOUND_MILLIS = 2 * 60 * 1000;

The calculation in:

public long calculateRenewalDelay(long expireDate, long issueDate) {

resolves to:

  1. Lower bound is 24h - 2m - elapsed 0h = ~23h58m
  2. Ratio is 0.75 * 24h - elapsed 0h = ~18h00m
  3. Min(bound,ratio) = 18h of delay.

That means an attempt to refresh the token will be made after 18 hours, with 6 hours left until token expiry. But that attempt results in WorkloadIdentityCredential returning the cached token, since we're not yet in MSAL's 5-minute refresh window. That triggers the next calculation:

  1. Lower bound is 24h - 2m - elapsed 18h = ~5h58m
  2. Ratio is 0.75 * 24h - elapsed 18h = ~0
  3. Min(bound, ratio) = 0, immediate refresh

The immediate refresh is submitted to an executor, hits the cache again, and the entire process loops in a busy-wait.
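
The two rounds of arithmetic above can be checked with a small standalone sketch. It only mirrors the calculation as described in this issue and is not the library's code: the class name, the method `renewalDelay` and the explicit `now` parameter are hypothetical, introduced so the result does not depend on the wall clock.

```java
import java.time.Duration;

// Hypothetical sketch of the renewal-delay arithmetic described above.
final class RenewalDelaySketch {
    // Values taken from the defaults quoted in this issue.
    static final double REFRESH_RATIO = 0.75;
    static final long LOWER_REFRESH_BOUND_MILLIS = 2 * 60 * 1000;

    // All arguments are epoch-style millisecond timestamps.
    static long renewalDelay(long expireDate, long issueDate, long now) {
        // Time left until we are within the lower bound (2 minutes) of expiry.
        long lowerBoundDelay = expireDate - LOWER_REFRESH_BOUND_MILLIS - now;
        // Time left until 75% of the token lifetime has been used.
        long ratioDelay = issueDate + (long) ((expireDate - issueDate) * REFRESH_RATIO) - now;
        // A negative delay means "refresh immediately".
        return Math.max(0, Math.min(lowerBoundDelay, ratioDelay));
    }

    public static void main(String[] args) {
        long issued = 0L;
        long expires = Duration.ofHours(24).toMillis();

        // First calculation: the token has just been issued.
        long first = renewalDelay(expires, issued, issued);
        System.out.println("first delay:  " + Duration.ofMillis(first));

        // Second calculation: 18h later the provider returns the same cached token,
        // so issueDate and expireDate are unchanged.
        long second = renewalDelay(expires, issued, issued + first);
        System.out.println("second delay: " + Duration.ofMillis(second));
    }
}
```

With an unchanged issue date, every calculation after the 18-hour mark yields a delay of zero, which matches the loop described above.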

JFR method profiling shows a busy-wait in TokenManager:
[Image: JFR method profiling screenshot]

And rapid allocation of renewed tokens:

[Image: JFR allocation profiling screenshot]

Minimal example without MSAL, with a provider that replicates the caching behaviour:

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.authentication.core.IdentityProvider;
import redis.clients.authentication.core.SimpleToken;
import redis.clients.authentication.core.Token;
import redis.clients.authentication.core.TokenListener;
import redis.clients.authentication.core.TokenManager;
import redis.clients.authentication.core.TokenManagerConfig;
import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder;

import java.time.Duration;
import java.time.Instant;
import java.util.Collections;

public class MinimalExample {

    private static final Logger log = LoggerFactory.getLogger(MinimalExample.class);

    static class IdentityProviderWith5MinuteBound implements IdentityProvider {

        private static final Logger log = LoggerFactory.getLogger(IdentityProviderWith5MinuteBound.class);

        private Token token = issueToken();

        @Override
        public Token requestToken() {
            if(System.currentTimeMillis() > (token.getExpiresAt() - (5 * 60 * 1000))) {
                token = issueToken();
            }
            log.info("Returning cached token");
            return token;
        }

        private static Token issueToken() {
            log.info("Issuing token");
            Instant now = Instant.now(); // needed as a reference point-in-time, because TokenManager uses System.currentTimeMillis
            Instant iat = now.minus(Duration.ofHours(19)); // Issued 19hrs ago = we're hitting the default 0.75 ratio of our 24h token
            Instant exp = iat.plus(Duration.ofHours(24));
            return new SimpleToken("user", "value", exp.toEpochMilli(), iat.toEpochMilli(), Collections.emptyMap());
        }
    }

    public static void main(String[] args) throws InterruptedException {
        TokenManagerConfig defaultEntraConfig = EntraIDTokenAuthConfigBuilder.builder().build().getTokenManagerConfig();
        IdentityProviderWith5MinuteBound identityProvider = new IdentityProviderWith5MinuteBound();
        TokenManager tokenManager = new TokenManager(identityProvider, defaultEntraConfig);
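        // Log the delay computed for a fresh 24h token (the result was previously discarded)
        log.info("Renewal delay for a fresh 24h token: {}", Duration.ofMillis(tokenManager.calculateRenewalDelay(
                Instant.now().plus(Duration.ofHours(24)).toEpochMilli(), Instant.now().toEpochMilli())));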
        tokenManager.start(new TokenListener() {
            @Override
            public void onTokenRenewed(Token newToken) {
                log.info("Token renewed");
            }

            @Override
            public void onError(Exception reason) {
                log.info("Error", reason);
            }
        }, false);
        Thread.sleep(Duration.ofDays(1).toMillis()); // Thread.sleep(Duration) only exists on Java 19+
    }
}

Hi @nluk,

Thank you for the detailed report, and for your patience.
I'm trying to learn more about the CPU peaks you are seeing.
Did the profiling results above come from your minimal example, or from the actual environment, where WorkloadIdentityCredential kicks in and triggers the calls to AcquireTokenSilentSupplier?

I'm asking because the calculations are bound to go wrong once we bypass the way the library instantiates and interprets its internal objects. The minimal example uses SimpleToken, which is a general-purpose implementation, but it's not meant to be used for EntraID.

The provider implementation for EntraID only instantiates and uses JWToken, never SimpleToken.
And JWToken always overwrites the receive timestamp here:

        this.receivedAt = System.currentTimeMillis();

To reproduce the behaviour of a TokenManager backed by a WorkloadIdentityCredential under the hood, the example should look more like the one below:

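import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.microsoft.aad.msal4j.IAccount;
import com.microsoft.aad.msal4j.IAuthenticationResult;
import com.microsoft.aad.msal4j.ITenantProfile;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.authentication.core.Token;
import redis.clients.authentication.core.TokenAuthConfig;
import redis.clients.authentication.core.TokenListener;
import redis.clients.authentication.core.TokenManager;
import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder;

import java.time.Duration;
import java.time.Instant;
import java.util.Date;
import java.util.function.Supplier;

public class MinimalExample2 {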

    private static final Logger log = LoggerFactory.getLogger(MinimalExample2.class);

    // Reference point-in-time, needed because TokenManager uses System.currentTimeMillis
    private static Instant now = Instant.now();
    // Issued 19hrs ago = we're hitting the default 0.75 ratio of our 24h token
    private static Instant iat = now.minus(Duration.ofHours(19));
    private static Instant exp = iat.plus(Duration.ofHours(24));
    private static String cachedToken = JWT.create().withExpiresAt(new Date(exp.toEpochMilli()))
            .withClaim("oid", "user").sign(Algorithm.none());

    private static IAuthenticationResult cachedAuthenticationResult = new IAuthenticationResult() {
        @Override
        public String accessToken() {
            return cachedToken;
        }

        @Override
        public String idToken() {
            throw new UnsupportedOperationException("Unimplemented method 'idToken'");
        }

        @Override
        public IAccount account() {
            throw new UnsupportedOperationException("Unimplemented method 'account'");
        }

        @Override
        public ITenantProfile tenantProfile() {
            throw new UnsupportedOperationException("Unimplemented method 'tenantProfile'");
        }

        @Override
        public String environment() {
            throw new UnsupportedOperationException("Unimplemented method 'environment'");
        }

        @Override
        public String scopes() {
            throw new UnsupportedOperationException("Unimplemented method 'scopes'");
        }

        @Override
        public Date expiresOnDate() {
            throw new UnsupportedOperationException("Unimplemented method 'expiresOnDate'");
        }
    };

    public static void main(String[] args) throws InterruptedException {
        TokenAuthConfig config = EntraIDTokenAuthConfigBuilder.builder()
                .customEntraIdAuthenticationSupplier(getAuthenticationSupplier()).build();
        TokenManager tokenManager = new TokenManager(config.getIdentityProviderConfig().getProvider(),
                config.getTokenManagerConfig());

        tokenManager.start(new TokenListener() {
            @Override
            public void onTokenRenewed(Token newToken) {
                log.info("Token renewed");
            }

            @Override
            public void onError(Exception reason) {
                log.info("Error", reason);
            }
        }, true);
        Thread.sleep(60000);
    }

    private static Supplier<IAuthenticationResult> getAuthenticationSupplier() {
        return () -> {
            return cachedAuthenticationResult;
        };
    }
}

With this setup, the remaining time and TTL are recalculated against new timestamps each time TokenManager receives a token, regardless of whether it is a new token or the same cached one as before. So CPU peaks are unlikely to develop in the same fashion as in the SimpleToken case above.
Please let me know if this makes sense.
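
To illustrate the difference, here is a rough simulation. It assumes, as described in this reply, that the reference timestamp of every received token is reset to the time of receipt, and that the provider keeps returning the same cached token until 5 minutes before expiry. The class and method names are hypothetical, and the arithmetic is a sketch of the ratio and lower-bound rules quoted earlier in the thread, not the library's code.

```java
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

// Hypothetical simulation: renewal delays when the reference time is reset on every receipt.
final class ReceivedAtRenewalSketch {
    static final double REFRESH_RATIO = 0.75;
    static final long LOWER_REFRESH_BOUND_MILLIS = Duration.ofMinutes(2).toMillis();
    // Assumed provider-side window in which the cache starts handing out a new token.
    static final long PROVIDER_REFRESH_WINDOW_MILLIS = Duration.ofMinutes(5).toMillis();

    static long renewalDelay(long expireDate, long referenceDate, long now) {
        long lowerBoundDelay = expireDate - LOWER_REFRESH_BOUND_MILLIS - now;
        long ratioDelay = referenceDate + (long) ((expireDate - referenceDate) * REFRESH_RATIO) - now;
        return Math.max(0, Math.min(lowerBoundDelay, ratioDelay));
    }

    // Collects the delays scheduled while the provider still returns the cached token.
    static List<Long> delaysUntilProviderRefresh(long expireDate, long start) {
        List<Long> delays = new ArrayList<>();
        long now = start;
        while (expireDate - now > PROVIDER_REFRESH_WINDOW_MILLIS) {
            // The cached token is "received" now, so the reference time equals now.
            long delay = renewalDelay(expireDate, now, now);
            delays.add(delay);
            now += delay;
        }
        return delays;
    }

    public static void main(String[] args) {
        long expires = Duration.ofHours(24).toMillis();
        for (long delay : delaysUntilProviderRefresh(expires, 0L)) {
            System.out.println(Duration.ofMillis(delay));
        }
    }
}
```

Under these assumptions every scheduled delay is strictly positive: each renewal waits for 75% of the remaining lifetime (or until the 2-minute lower bound), so the delays shrink step by step instead of collapsing to zero.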

Having said that, I'm willing to explore what causes this CPU consumption! Once the lowerRefreshBoundMillis threshold kicks in, TokenManager will try to obtain a new token very aggressively, similar to what you described.
But this is intended, since an expired token would cause Redis connections (existing ones still running on the old token) to be dropped immediately.

What I suspect is that TokenManager doesn't receive any result from defaultAzureCredential.getToken(ctx).block(Duration.ofMillis(timeout)) within the given timeouts, or that the provider raises an error during the call, so it doesn't receive even a cached/old token.
Could you check whether there are any error logs right before or during the CPU peaks, or any timeouts or rejections on the provider side?
It would also help if you could share any specific configuration and/or customization you have.