diff --git a/CMakeLists.txt b/CMakeLists.txt index fde3f8446..328e1c8f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -744,6 +744,7 @@ find_feature(PCSC ${PCSC_FEATURE_TYPE} ${PCSC_FEATURE_PURPOSE} ${PCSC_FEATURE_DE option(WITH_AAD "Compile with support for Azure AD authentication" ON) if (WITH_AAD) find_package(CJSON REQUIRED) + include_directories(${CJSON_INCLUDE_DIRS}) endif() if (WITH_DSP_FFMPEG OR WITH_VIDEO_FFMPEG OR WITH_FFMPEG) diff --git a/client/SDL/sdl_freerdp.cpp b/client/SDL/sdl_freerdp.cpp index 836db5093..a6cba11de 100644 --- a/client/SDL/sdl_freerdp.cpp +++ b/client/SDL/sdl_freerdp.cpp @@ -1139,11 +1139,9 @@ static BOOL sdl_client_new(freerdp* instance, rdpContext* context) instance->VerifyChangedCertificateEx = client_cli_verify_changed_certificate_ex; instance->LogonErrorInfo = sdl_logon_error_info; #ifdef WITH_WEBVIEW - instance->GetRDSAADAccessToken = sdl_webview_get_rdsaad_access_token; - instance->GetAVDAccessToken = sdl_webview_get_avd_access_token; + instance->GetAccessToken = sdl_webview_get_access_token; #else - instance->GetRDSAADAccessToken = client_cli_get_rdsaad_access_token; - instance->GetAVDAccessToken = client_cli_get_avd_access_token; + instance->GetAccessToken = client_cli_get_access_token; #endif /* TODO: Client display set up */ diff --git a/client/SDL/sdl_webview.cpp b/client/SDL/sdl_webview.cpp index 4a840a9ce..53e704140 100644 --- a/client/SDL/sdl_webview.cpp +++ b/client/SDL/sdl_webview.cpp @@ -26,10 +26,14 @@ #include #include +#include #include +#include #include "sdl_webview.hpp" +#define TAG CLIENT_TAG("sdl.webview") + class SchemeHandler : public QWebEngineUrlSchemeHandler { public: @@ -91,8 +95,8 @@ static std::string sdl_webview_get_auth_code(QString url) return handler.code(); } -BOOL sdl_webview_get_rdsaad_access_token(freerdp* instance, const char* scope, const char* req_cnf, - char** token) +static BOOL sdl_webview_get_rdsaad_access_token(freerdp* instance, const char* scope, + const char* req_cnf, char** token) { WINPR_ASSERT(instance); WINPR_ASSERT(scope); @@ -121,7 +125,7 @@ BOOL sdl_webview_get_rdsaad_access_token(freerdp* instance, const char* scope, c return client_common_get_access_token(instance, token_request.c_str(), token); } -BOOL sdl_webview_get_avd_access_token(freerdp* instance, char** token) +static BOOL sdl_webview_get_avd_access_token(freerdp* instance, char** token) { WINPR_ASSERT(token); @@ -143,3 +147,46 @@ BOOL sdl_webview_get_avd_access_token(freerdp* instance, char** token) "&scope=" + scope + "&redirect_uri=" + redirect_uri; return client_common_get_access_token(instance, token_request.c_str(), token); } + +BOOL sdl_webview_get_access_token(freerdp* instance, AccessTokenType tokenType, char** token, + size_t count, ...) +{ + WINPR_ASSERT(instance); + WINPR_ASSERT(token); + switch (tokenType) + { + case ACCESS_TOKEN_TYPE_AAD: + { + if (count < 2) + { + WLog_ERR(TAG, + "ACCESS_TOKEN_TYPE_AAD expected 2 additional arguments, but got %" PRIuz + ", aborting", + count); + return FALSE; + } + else if (count > 2) + WLog_WARN(TAG, + "ACCESS_TOKEN_TYPE_AAD expected 2 additional arguments, but got %" PRIuz + ", ignoring", + count); + va_list ap; + va_start(ap, count); + const char* scope = va_arg(ap, const char*); + const char* req_cnf = va_arg(ap, const char*); + const BOOL rc = sdl_webview_get_rdsaad_access_token(instance, scope, req_cnf, token); + va_end(ap); + return rc; + } + case ACCESS_TOKEN_TYPE_AVD: + if (count != 0) + WLog_WARN(TAG, + "ACCESS_TOKEN_TYPE_AVD expected 0 additional arguments, but got %" PRIuz + ", ignoring", + count); + return sdl_webview_get_avd_access_token(instance, token); + default: + WLog_ERR(TAG, "Unexpected value for AccessTokenType [%" PRIuz "], aborting", tokenType); + return FALSE; + } +} diff --git a/client/SDL/sdl_webview.hpp b/client/SDL/sdl_webview.hpp index 93e229721..8ece62d25 100644 --- a/client/SDL/sdl_webview.hpp +++ b/client/SDL/sdl_webview.hpp @@ -26,9 +26,8 @@ extern "C" { #endif - BOOL sdl_webview_get_avd_access_token(freerdp* instance, char** token); - BOOL sdl_webview_get_rdsaad_access_token(freerdp* instance, const char* scope, - const char* req_cnf, char** token); + BOOL sdl_webview_get_access_token(freerdp* instance, AccessTokenType tokenType, char** token, + size_t count, ...); #ifdef __cplusplus } diff --git a/client/X11/xf_client.c b/client/X11/xf_client.c index 8ca5df935..ec66c6d72 100644 --- a/client/X11/xf_client.c +++ b/client/X11/xf_client.c @@ -1976,8 +1976,7 @@ static BOOL xfreerdp_client_new(freerdp* instance, rdpContext* context) instance->PostDisconnect = xf_post_disconnect; instance->PostFinalDisconnect = xf_post_final_disconnect; instance->LogonErrorInfo = xf_logon_error_info; - instance->GetRDSAADAccessToken = client_cli_get_rdsaad_access_token; - instance->GetAVDAccessToken = client_cli_get_avd_access_token; + instance->GetAccessToken = client_cli_get_access_token; PubSub_SubscribeTerminate(context->pubSub, xf_TerminateEventHandler); #ifdef WITH_XRENDER PubSub_SubscribeZoomingChange(context->pubSub, xf_ZoomingChangeEventHandler); diff --git a/client/common/client.c b/client/common/client.c index 01d6f3682..e9f6731d7 100644 --- a/client/common/client.c +++ b/client/common/client.c @@ -964,8 +964,8 @@ static char* extract_authorization_code(char* url) return NULL; } -BOOL client_cli_get_rdsaad_access_token(freerdp* instance, const char* scope, const char* req_cnf, - char** token) +static BOOL client_cli_get_rdsaad_access_token(freerdp* instance, const char* scope, + const char* req_cnf, char** token) { size_t size = 0; char* url = NULL; @@ -1009,7 +1009,7 @@ cleanup: return (*token != NULL); } -BOOL client_cli_get_avd_access_token(freerdp* instance, char** token) +static BOOL client_cli_get_avd_access_token(freerdp* instance, char** token) { size_t size = 0; char* url = NULL; @@ -1052,6 +1052,49 @@ cleanup: return (*token != NULL); } +BOOL client_cli_get_access_token(freerdp* instance, AccessTokenType tokenType, char** token, + size_t count, ...) +{ + WINPR_ASSERT(instance); + WINPR_ASSERT(token); + switch (tokenType) + { + case ACCESS_TOKEN_TYPE_AAD: + { + if (count < 2) + { + WLog_ERR(TAG, + "ACCESS_TOKEN_TYPE_AAD expected 2 additional arguments, but got %" PRIuz + ", aborting", + count); + return FALSE; + } + else if (count > 2) + WLog_WARN(TAG, + "ACCESS_TOKEN_TYPE_AAD expected 2 additional arguments, but got %" PRIuz + ", ignoring", + count); + va_list ap; + va_start(ap, count); + const char* scope = va_arg(ap, const char*); + const char* req_cnf = va_arg(ap, const char*); + const BOOL rc = client_cli_get_rdsaad_access_token(instance, scope, req_cnf, token); + va_end(ap); + return rc; + } + case ACCESS_TOKEN_TYPE_AVD: + if (count != 0) + WLog_WARN(TAG, + "ACCESS_TOKEN_TYPE_AVD expected 0 additional arguments, but got %" PRIuz + ", ignoring", + count); + return client_cli_get_avd_access_token(instance, token); + default: + WLog_ERR(TAG, "Unexpected value for AccessTokenType [%" PRIuz "], aborting", tokenType); + return FALSE; + } +} + BOOL client_common_get_access_token(freerdp* instance, const char* request, char** token) { #ifdef WITH_AAD diff --git a/include/freerdp/client.h b/include/freerdp/client.h index 76bc33b19..217afdd84 100644 --- a/include/freerdp/client.h +++ b/include/freerdp/client.h @@ -170,9 +170,8 @@ extern "C" FREERDP_API int client_cli_logon_error_info(freerdp* instance, UINT32 data, UINT32 type); - FREERDP_API BOOL client_cli_get_rdsaad_access_token(freerdp* instance, const char* scope, - const char* req_cnf, char** token); - FREERDP_API BOOL client_cli_get_avd_access_token(freerdp* instance, char** token); + FREERDP_API BOOL client_cli_get_access_token(freerdp* instance, AccessTokenType tokenType, + char** token, size_t count, ...); FREERDP_API BOOL client_common_get_access_token(freerdp* instance, const char* request, char** token); diff --git a/include/freerdp/freerdp.h b/include/freerdp/freerdp.h index 4758287f6..dbbb84edf 100644 --- a/include/freerdp/freerdp.h +++ b/include/freerdp/freerdp.h @@ -126,9 +126,15 @@ extern "C" char** domain, rdp_auth_reason reason); typedef BOOL (*pChooseSmartcard)(freerdp* instance, SmartcardCertInfo** cert_list, DWORD count, DWORD* choice, BOOL gateway); - typedef BOOL (*pGetRDSAADAccessToken)(freerdp* instance, const char* scope, const char* req_cnf, - char** token); - typedef BOOL (*pGetAVDAccessToken)(freerdp* instance, char** token); + + typedef enum + { + ACCESS_TOKEN_TYPE_AAD, /**!< oauth2 access token for RDS AAD authentication */ + ACCESS_TOKEN_TYPE_AVD /**!< oauth2 access token for Azure Virtual Desktop */ + } AccessTokenType; + + typedef BOOL (*pGetAccessToken)(freerdp* instance, AccessTokenType tokenType, char** token, + size_t count, ...); /** @brief Callback used if user interaction is required to accept * an unknown certificate. @@ -522,13 +528,10 @@ owned by rdpRdp */ Callback for choosing a smartcard for logon. Used when multiple smartcards are available. Returns an index into a list of SmartcardCertInfo pointers */ - ALIGN64 pGetRDSAADAccessToken GetRDSAADAccessToken; /* (offset 71) - Callback for obtaining an oauth2 access token - for RDS AAD authentication */ - ALIGN64 pGetAVDAccessToken GetAVDAccessToken; /* (offset 72) - Callback for obtaining an oauth2 access token - for Azure Virtual Desktop */ - UINT64 paddingE[80 - 73]; /* 73 */ + ALIGN64 pGetAccessToken GetAccessToken; /* (offset 71) + Callback for obtaining an access token + for \b AccessTokenType authentication */ + UINT64 paddingE[80 - 72]; /* 72 */ }; struct rdp_channel_handles diff --git a/libfreerdp/core/aad.c b/libfreerdp/core/aad.c index 8837730da..bc44111d7 100644 --- a/libfreerdp/core/aad.c +++ b/libfreerdp/core/aad.c @@ -282,13 +282,13 @@ int aad_client_begin(rdpAad* aad) return -1; /* Obtain an oauth authorization code */ - if (!instance->GetRDSAADAccessToken) + if (!instance->GetAccessToken) { - WLog_Print(aad->log, WLOG_ERROR, "instance->GetRDSAADAccessToken == NULL"); + WLog_Print(aad->log, WLOG_ERROR, "instance->GetAccessToken == NULL"); return -1; } - const BOOL arc = - instance->GetRDSAADAccessToken(instance, aad->scope, aad->kid, &aad->access_token); + const BOOL arc = instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AAD, &aad->access_token, + 2, aad->scope, aad->kid); if (!arc) { WLog_Print(aad->log, WLOG_ERROR, "Unable to obtain access token"); diff --git a/libfreerdp/core/gateway/arm.c b/libfreerdp/core/gateway/arm.c index a1ac5d799..bb23d7a90 100644 --- a/libfreerdp/core/gateway/arm.c +++ b/libfreerdp/core/gateway/arm.c @@ -174,13 +174,13 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method, { char* token = NULL; - if (!instance->GetAVDAccessToken) + if (!instance->GetAccessToken) { WLog_ERR(TAG, "No authorization token provided"); goto out; } - if (!instance->GetAVDAccessToken(instance, &token)) + if (!instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AVD, &token, 0)) { WLog_ERR(TAG, "Unable to obtain access token"); goto out;