From 1882cebbce3032d1dc05edb8067be233b1b5ac5f Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Wed, 21 May 2025 10:01:37 +0200 Subject: [PATCH] [core,aad] Split GetAccessToken callback To allow client-common library to override the GetAccessToken callback introduce a new GetCommonAccessToken callback. This callback defaults to call the existing GetAccessToken callback, but client-common library can override if desired, so that a common token retrieval method is executed before a client UI is invoked. --- include/freerdp/freerdp.h | 45 ++++++++++++++++++++ libfreerdp/core/aad.c | 13 +++--- libfreerdp/core/aad.h | 3 +- libfreerdp/core/freerdp.c | 79 ++++++++++++++++++++++++++++++++++- libfreerdp/core/gateway/arm.c | 6 ++- libfreerdp/core/rdp.c | 2 +- libfreerdp/core/rdp.h | 1 + 7 files changed, 138 insertions(+), 11 deletions(-) diff --git a/include/freerdp/freerdp.h b/include/freerdp/freerdp.h index eba065c5f..6a32e2cdf 100644 --- a/include/freerdp/freerdp.h +++ b/include/freerdp/freerdp.h @@ -134,9 +134,35 @@ extern "C" ACCESS_TOKEN_TYPE_AVD /**!< oauth2 access token for Azure Virtual Desktop */ } AccessTokenType; + /** @brief A function to be implemented by a client. It is called whenever the connection + * requires an access token. + * @param instance The instance the function is called for + * @param tokenType The type of token requested + * @param token A pointer that will hold the (allocated) token string + * @param count The number of arguments following + * + * @return \b TRUE for success, \b FALSE otherwise + * @since version 3.0.0 + */ typedef BOOL (*pGetAccessToken)(freerdp* instance, AccessTokenType tokenType, char** token, size_t count, ...); + /** @brief The function is called whenever the connection requires an access token. + * It differs from \ref pGetAccessToken and is not meant to be implemented by a client + * directly. The client-common library will use this to provide common means to retrieve a token + * and only if that fails the instanc->GetAccessToken callback will be called. + * + * @param context The context the function is called for + * @param tokenType The type of token requested + * @param token A pointer that will hold the (allocated) token string + * @param count The number of arguments following + * + * @return \b TRUE for success, \b FALSE otherwise + * @since version 3.16.0 + */ + typedef BOOL (*pGetCommonAccessToken)(rdpContext* context, AccessTokenType tokenType, + char** token, size_t count, ...); + /** @brief Callback used to inform about a reconnection attempt * * @param instance The instance the information is for @@ -769,6 +795,25 @@ owned by rdpRdp */ */ FREERDP_API BOOL freerdp_persist_credentials(rdpContext* context); + /** @brief set a new function to be called when an access token is requested. + * + * @param context The rdp context to set the function for. Must not be \b NULL + * @param GetCommonAccessToken The function pointer to set, \b NULL to disable + * + * @return \b TRUE for success, \b FALSE otherwise + * @since version 3.16.0 + */ + FREERDP_API BOOL freerdp_set_common_access_token(rdpContext* context, + pGetCommonAccessToken GetCommonAccessToken); + + /** @brief get the current function pointer set as GetCommonAccessToken + * + * @param context The rdp context to set the function for. Must not be \b NULL + * @return The current function pointer set or \b NULL + * @since version 3.16.0 + */ + FREERDP_API pGetCommonAccessToken freerdp_get_common_access_token(rdpContext* context); + #ifdef __cplusplus } #endif diff --git a/libfreerdp/core/aad.c b/libfreerdp/core/aad.c index bc7511360..92c0d0943 100644 --- a/libfreerdp/core/aad.c +++ b/libfreerdp/core/aad.c @@ -48,6 +48,7 @@ struct rdp_aad char* hostname; char* scope; wLog* log; + pGetCommonAccessToken GetCommonAccessToken; }; #ifdef WITH_AAD @@ -303,17 +304,17 @@ int aad_client_begin(rdpAad* aad) return -1; /* Obtain an oauth authorization code */ - if (!instance->GetAccessToken) + if (!aad->GetCommonAccessToken) { - WLog_Print(aad->log, WLOG_ERROR, "instance->GetAccessToken == NULL"); + WLog_Print(aad->log, WLOG_ERROR, "aad->rdpcontext->GetCommonAccessToken == NULL"); return -1; } if (!aad_fetch_wellknown(aad->log, aad->rdpcontext)) return -1; - const BOOL arc = instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AAD, &aad->access_token, - 2, aad->scope, aad->kid); + const BOOL arc = aad->GetCommonAccessToken(aad->rdpcontext, ACCESS_TOKEN_TYPE_AAD, + &aad->access_token, 2, aad->scope, aad->kid); if (!arc) { WLog_Print(aad->log, WLOG_ERROR, "Unable to obtain access token"); @@ -788,7 +789,8 @@ static BOOL ensure_wellknown(WINPR_ATTR_UNUSED rdpContext* context) #endif -rdpAad* aad_new(rdpContext* context, rdpTransport* transport) +rdpAad* aad_new(rdpContext* context, rdpTransport* transport, + pGetCommonAccessToken GetCommonAccessToken) { WINPR_ASSERT(transport); WINPR_ASSERT(context); @@ -799,6 +801,7 @@ rdpAad* aad_new(rdpContext* context, rdpTransport* transport) return NULL; aad->log = WLog_Get(FREERDP_TAG("aad")); + aad->GetCommonAccessToken = GetCommonAccessToken; aad->key = freerdp_key_new(); if (!aad->key) goto fail; diff --git a/libfreerdp/core/aad.h b/libfreerdp/core/aad.h index 797043ced..912d1821f 100644 --- a/libfreerdp/core/aad.h +++ b/libfreerdp/core/aad.h @@ -42,6 +42,7 @@ FREERDP_LOCAL AAD_STATE aad_get_state(rdpAad* aad); FREERDP_LOCAL void aad_free(rdpAad* aad); WINPR_ATTR_MALLOC(aad_free, 1) -FREERDP_LOCAL rdpAad* aad_new(rdpContext* context, rdpTransport* transport); +FREERDP_LOCAL rdpAad* aad_new(rdpContext* context, rdpTransport* transport, + pGetCommonAccessToken GetCommonAccessToken); #endif /* FREERDP_LIB_CORE_AAD_H */ diff --git a/libfreerdp/core/freerdp.c b/libfreerdp/core/freerdp.c index 1f1436ccd..08f4c6720 100644 --- a/libfreerdp/core/freerdp.c +++ b/libfreerdp/core/freerdp.c @@ -789,6 +789,56 @@ BOOL freerdp_context_new(freerdp* instance) return freerdp_context_new_ex(instance, NULL); } +static BOOL freerdp_common_context(rdpContext* context, AccessTokenType tokenType, char** token, + size_t count, ...) +{ + BOOL rc = FALSE; + + WINPR_ASSERT(context); + if (!context->instance || !context->instance->GetAccessToken) + return TRUE; + + va_list ap; + va_start(ap, count); + switch (tokenType) + { + case ACCESS_TOKEN_TYPE_AAD: + if (count != 2) + { + WLog_ERR(TAG, + "ACCESS_TOKEN_TYPE_AAD expected 2 additional arguments, but got %" PRIuz + ", aborting", + count); + } + else + { + const char* scope = va_arg(ap, const char*); + const char* req_cnf = va_arg(ap, const char*); + rc = context->instance->GetAccessToken(context->instance, tokenType, token, count, + scope, req_cnf); + } + break; + case ACCESS_TOKEN_TYPE_AVD: + if (count != 0) + { + WLog_WARN(TAG, + "ACCESS_TOKEN_TYPE_AVD expected 0 additional arguments, but got %" PRIuz + ", ignoring", + count); + } + else + { + rc = context->instance->GetAccessToken(context->instance, tokenType, token, count); + } + break; + default: + break; + } + va_end(ap); + + return rc; +} + BOOL freerdp_context_new_ex(freerdp* instance, rdpSettings* settings) { rdpRdp* rdp = NULL; @@ -869,10 +919,19 @@ BOOL freerdp_context_new_ex(freerdp* instance, rdpSettings* settings) if (!context->dump) goto fail; + /* Fallback: + * Client common library might set a function pointer to handle this, but here we provide a + * default implementation that simply calls instance->GetAccessToken. + */ + if (!freerdp_set_common_access_token(context, freerdp_common_context)) + goto fail; + IFCALLRET(instance->ContextNew, ret, instance, context); - if (ret) - return TRUE; + if (!ret) + goto fail; + + return TRUE; fail: freerdp_context_free(instance); @@ -1507,3 +1566,19 @@ const char* freerdp_disconnect_reason_string(int reason) return "rn-unknown"; } } + +BOOL freerdp_set_common_access_token(rdpContext* context, + pGetCommonAccessToken GetCommonAccessToken) +{ + WINPR_ASSERT(context); + WINPR_ASSERT(context->rdp); + context->rdp->GetCommonAccessToken = GetCommonAccessToken; + return TRUE; +} + +pGetCommonAccessToken freerdp_get_common_access_token(rdpContext* context) +{ + WINPR_ASSERT(context); + WINPR_ASSERT(context->rdp); + return context->rdp->GetCommonAccessToken; +} diff --git a/libfreerdp/core/gateway/arm.c b/libfreerdp/core/gateway/arm.c index 5eabbb01f..aceca53f8 100644 --- a/libfreerdp/core/gateway/arm.c +++ b/libfreerdp/core/gateway/arm.c @@ -194,6 +194,7 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method, WINPR_ASSERT(content_type); WINPR_ASSERT(arm->context); + WINPR_ASSERT(arm->context->rdp); freerdp* instance = arm->context->instance; WINPR_ASSERT(instance); @@ -211,7 +212,7 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method, { char* token = NULL; - if (!instance->GetAccessToken) + if (!arm->context->rdp->GetCommonAccessToken) { WLog_Print(arm->log, WLOG_ERROR, "No authorization token provided"); goto out; @@ -220,7 +221,8 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method, if (!arm_fetch_wellknown(arm)) goto out; - if (!instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AVD, &token, 0)) + if (!arm->context->rdp->GetCommonAccessToken(arm->context, ACCESS_TOKEN_TYPE_AVD, &token, + 0)) { WLog_Print(arm->log, WLOG_ERROR, "Unable to obtain access token"); goto out; diff --git a/libfreerdp/core/rdp.c b/libfreerdp/core/rdp.c index 6d3b1ead6..a1eb83e3d 100644 --- a/libfreerdp/core/rdp.c +++ b/libfreerdp/core/rdp.c @@ -2318,7 +2318,7 @@ static bool rdp_new_common(rdpRdp* rdp) goto fail; } - rdp->aad = aad_new(rdp->context, rdp->transport); + rdp->aad = aad_new(rdp->context, rdp->transport, rdp->GetCommonAccessToken); if (!rdp->aad) goto fail; diff --git a/libfreerdp/core/rdp.h b/libfreerdp/core/rdp.h index cfa8995d8..e66d005b5 100644 --- a/libfreerdp/core/rdp.h +++ b/libfreerdp/core/rdp.h @@ -209,6 +209,7 @@ struct rdp_rdp char log_context[64]; WINPR_JSON* wellknown; FreeRDPTimer* timer; + pGetCommonAccessToken GetCommonAccessToken; }; FREERDP_LOCAL BOOL rdp_read_security_header(rdpRdp* rdp, wStream* s, UINT16* flags, UINT16* length);