[aad,avc] unify callbacks to GetAccessToken

The AAD and AVD authentication mechanisms both need an OAuth2 token.
They only differ in the provided arguments, so unify the callbacks into
a single one with variable argument lists.
This commit is contained in:
akallabeth
2023-07-20 08:35:55 +02:00
committed by akallabeth
parent d309fcd6e8
commit 734117351d
10 changed files with 123 additions and 34 deletions

View File

@@ -744,6 +744,7 @@ find_feature(PCSC ${PCSC_FEATURE_TYPE} ${PCSC_FEATURE_PURPOSE} ${PCSC_FEATURE_DE
option(WITH_AAD "Compile with support for Azure AD authentication" ON)
if (WITH_AAD)
find_package(CJSON REQUIRED)
include_directories(${CJSON_INCLUDE_DIRS})
endif()
if (WITH_DSP_FFMPEG OR WITH_VIDEO_FFMPEG OR WITH_FFMPEG)

View File

@@ -1139,11 +1139,9 @@ static BOOL sdl_client_new(freerdp* instance, rdpContext* context)
instance->VerifyChangedCertificateEx = client_cli_verify_changed_certificate_ex;
instance->LogonErrorInfo = sdl_logon_error_info;
#ifdef WITH_WEBVIEW
instance->GetRDSAADAccessToken = sdl_webview_get_rdsaad_access_token;
instance->GetAVDAccessToken = sdl_webview_get_avd_access_token;
instance->GetAccessToken = sdl_webview_get_access_token;
#else
instance->GetRDSAADAccessToken = client_cli_get_rdsaad_access_token;
instance->GetAVDAccessToken = client_cli_get_avd_access_token;
instance->GetAccessToken = client_cli_get_access_token;
#endif
/* TODO: Client display set up */

View File

@@ -26,10 +26,14 @@
#include <string>
#include <stdlib.h>
#include <stdarg.h>
#include <winpr/string.h>
#include <freerdp/log.h>
#include "sdl_webview.hpp"
#define TAG CLIENT_TAG("sdl.webview")
class SchemeHandler : public QWebEngineUrlSchemeHandler
{
public:
@@ -91,8 +95,8 @@ static std::string sdl_webview_get_auth_code(QString url)
return handler.code();
}
BOOL sdl_webview_get_rdsaad_access_token(freerdp* instance, const char* scope, const char* req_cnf,
char** token)
static BOOL sdl_webview_get_rdsaad_access_token(freerdp* instance, const char* scope,
const char* req_cnf, char** token)
{
WINPR_ASSERT(instance);
WINPR_ASSERT(scope);
@@ -121,7 +125,7 @@ BOOL sdl_webview_get_rdsaad_access_token(freerdp* instance, const char* scope, c
return client_common_get_access_token(instance, token_request.c_str(), token);
}
BOOL sdl_webview_get_avd_access_token(freerdp* instance, char** token)
static BOOL sdl_webview_get_avd_access_token(freerdp* instance, char** token)
{
WINPR_ASSERT(token);
@@ -143,3 +147,46 @@ BOOL sdl_webview_get_avd_access_token(freerdp* instance, char** token)
"&scope=" + scope + "&redirect_uri=" + redirect_uri;
return client_common_get_access_token(instance, token_request.c_str(), token);
}
BOOL sdl_webview_get_access_token(freerdp* instance, AccessTokenType tokenType, char** token,
size_t count, ...)
{
WINPR_ASSERT(instance);
WINPR_ASSERT(token);
switch (tokenType)
{
case ACCESS_TOKEN_TYPE_AAD:
{
if (count < 2)
{
WLog_ERR(TAG,
"ACCESS_TOKEN_TYPE_AAD expected 2 additional arguments, but got %" PRIuz
", aborting",
count);
return FALSE;
}
else if (count > 2)
WLog_WARN(TAG,
"ACCESS_TOKEN_TYPE_AAD expected 2 additional arguments, but got %" PRIuz
", ignoring",
count);
va_list ap;
va_start(ap, count);
const char* scope = va_arg(ap, const char*);
const char* req_cnf = va_arg(ap, const char*);
const BOOL rc = sdl_webview_get_rdsaad_access_token(instance, scope, req_cnf, token);
va_end(ap);
return rc;
}
case ACCESS_TOKEN_TYPE_AVD:
if (count != 0)
WLog_WARN(TAG,
"ACCESS_TOKEN_TYPE_AVD expected 0 additional arguments, but got %" PRIuz
", ignoring",
count);
return sdl_webview_get_avd_access_token(instance, token);
default:
WLog_ERR(TAG, "Unexpected value for AccessTokenType [%" PRIuz "], aborting", tokenType);
return FALSE;
}
}

View File

@@ -26,9 +26,8 @@ extern "C"
{
#endif
BOOL sdl_webview_get_avd_access_token(freerdp* instance, char** token);
BOOL sdl_webview_get_rdsaad_access_token(freerdp* instance, const char* scope,
const char* req_cnf, char** token);
BOOL sdl_webview_get_access_token(freerdp* instance, AccessTokenType tokenType, char** token,
size_t count, ...);
#ifdef __cplusplus
}

View File

@@ -1976,8 +1976,7 @@ static BOOL xfreerdp_client_new(freerdp* instance, rdpContext* context)
instance->PostDisconnect = xf_post_disconnect;
instance->PostFinalDisconnect = xf_post_final_disconnect;
instance->LogonErrorInfo = xf_logon_error_info;
instance->GetRDSAADAccessToken = client_cli_get_rdsaad_access_token;
instance->GetAVDAccessToken = client_cli_get_avd_access_token;
instance->GetAccessToken = client_cli_get_access_token;
PubSub_SubscribeTerminate(context->pubSub, xf_TerminateEventHandler);
#ifdef WITH_XRENDER
PubSub_SubscribeZoomingChange(context->pubSub, xf_ZoomingChangeEventHandler);

View File

@@ -964,8 +964,8 @@ static char* extract_authorization_code(char* url)
return NULL;
}
BOOL client_cli_get_rdsaad_access_token(freerdp* instance, const char* scope, const char* req_cnf,
char** token)
static BOOL client_cli_get_rdsaad_access_token(freerdp* instance, const char* scope,
const char* req_cnf, char** token)
{
size_t size = 0;
char* url = NULL;
@@ -1009,7 +1009,7 @@ cleanup:
return (*token != NULL);
}
BOOL client_cli_get_avd_access_token(freerdp* instance, char** token)
static BOOL client_cli_get_avd_access_token(freerdp* instance, char** token)
{
size_t size = 0;
char* url = NULL;
@@ -1052,6 +1052,49 @@ cleanup:
return (*token != NULL);
}
BOOL client_cli_get_access_token(freerdp* instance, AccessTokenType tokenType, char** token,
size_t count, ...)
{
WINPR_ASSERT(instance);
WINPR_ASSERT(token);
switch (tokenType)
{
case ACCESS_TOKEN_TYPE_AAD:
{
if (count < 2)
{
WLog_ERR(TAG,
"ACCESS_TOKEN_TYPE_AAD expected 2 additional arguments, but got %" PRIuz
", aborting",
count);
return FALSE;
}
else if (count > 2)
WLog_WARN(TAG,
"ACCESS_TOKEN_TYPE_AAD expected 2 additional arguments, but got %" PRIuz
", ignoring",
count);
va_list ap;
va_start(ap, count);
const char* scope = va_arg(ap, const char*);
const char* req_cnf = va_arg(ap, const char*);
const BOOL rc = client_cli_get_rdsaad_access_token(instance, scope, req_cnf, token);
va_end(ap);
return rc;
}
case ACCESS_TOKEN_TYPE_AVD:
if (count != 0)
WLog_WARN(TAG,
"ACCESS_TOKEN_TYPE_AVD expected 0 additional arguments, but got %" PRIuz
", ignoring",
count);
return client_cli_get_avd_access_token(instance, token);
default:
WLog_ERR(TAG, "Unexpected value for AccessTokenType [%" PRIuz "], aborting", tokenType);
return FALSE;
}
}
BOOL client_common_get_access_token(freerdp* instance, const char* request, char** token)
{
#ifdef WITH_AAD

View File

@@ -170,9 +170,8 @@ extern "C"
FREERDP_API int client_cli_logon_error_info(freerdp* instance, UINT32 data, UINT32 type);
FREERDP_API BOOL client_cli_get_rdsaad_access_token(freerdp* instance, const char* scope,
const char* req_cnf, char** token);
FREERDP_API BOOL client_cli_get_avd_access_token(freerdp* instance, char** token);
FREERDP_API BOOL client_cli_get_access_token(freerdp* instance, AccessTokenType tokenType,
char** token, size_t count, ...);
FREERDP_API BOOL client_common_get_access_token(freerdp* instance, const char* request,
char** token);

View File

@@ -126,9 +126,15 @@ extern "C"
char** domain, rdp_auth_reason reason);
typedef BOOL (*pChooseSmartcard)(freerdp* instance, SmartcardCertInfo** cert_list, DWORD count,
DWORD* choice, BOOL gateway);
typedef BOOL (*pGetRDSAADAccessToken)(freerdp* instance, const char* scope, const char* req_cnf,
char** token);
typedef BOOL (*pGetAVDAccessToken)(freerdp* instance, char** token);
typedef enum
{
ACCESS_TOKEN_TYPE_AAD, /**!< oauth2 access token for RDS AAD authentication */
ACCESS_TOKEN_TYPE_AVD /**!< oauth2 access token for Azure Virtual Desktop */
} AccessTokenType;
typedef BOOL (*pGetAccessToken)(freerdp* instance, AccessTokenType tokenType, char** token,
size_t count, ...);
/** @brief Callback used if user interaction is required to accept
* an unknown certificate.
@@ -522,13 +528,10 @@ owned by rdpRdp */
Callback for choosing a smartcard for logon.
Used when multiple smartcards are available. Returns an index into a list
of SmartcardCertInfo pointers */
ALIGN64 pGetRDSAADAccessToken GetRDSAADAccessToken; /* (offset 71)
Callback for obtaining an oauth2 access token
for RDS AAD authentication */
ALIGN64 pGetAVDAccessToken GetAVDAccessToken; /* (offset 72)
Callback for obtaining an oauth2 access token
for Azure Virtual Desktop */
UINT64 paddingE[80 - 73]; /* 73 */
ALIGN64 pGetAccessToken GetAccessToken; /* (offset 71)
Callback for obtaining an access token
for \b AccessTokenType authentication */
UINT64 paddingE[80 - 72]; /* 72 */
};
struct rdp_channel_handles

View File

@@ -282,13 +282,13 @@ int aad_client_begin(rdpAad* aad)
return -1;
/* Obtain an oauth authorization code */
if (!instance->GetRDSAADAccessToken)
if (!instance->GetAccessToken)
{
WLog_Print(aad->log, WLOG_ERROR, "instance->GetRDSAADAccessToken == NULL");
WLog_Print(aad->log, WLOG_ERROR, "instance->GetAccessToken == NULL");
return -1;
}
const BOOL arc =
instance->GetRDSAADAccessToken(instance, aad->scope, aad->kid, &aad->access_token);
const BOOL arc = instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AAD, &aad->access_token,
2, aad->scope, aad->kid);
if (!arc)
{
WLog_Print(aad->log, WLOG_ERROR, "Unable to obtain access token");

View File

@@ -174,13 +174,13 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method,
{
char* token = NULL;
if (!instance->GetAVDAccessToken)
if (!instance->GetAccessToken)
{
WLog_ERR(TAG, "No authorization token provided");
goto out;
}
if (!instance->GetAVDAccessToken(instance, &token))
if (!instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AVD, &token, 0))
{
WLog_ERR(TAG, "Unable to obtain access token");
goto out;