From af2a74cbbbd9bc8c51e54008bbf9faef3bf9ecb5 Mon Sep 17 00:00:00 2001 From: akallabeth Date: Wed, 8 Mar 2023 16:33:49 +0100 Subject: [PATCH] [core,aad] refactor aad parser * split functions into smaller elements * improve return code checks * add log messages for error results --- libfreerdp/core/aad.c | 956 +++++++++++++++++++++++++++++------------- 1 file changed, 661 insertions(+), 295 deletions(-) diff --git a/libfreerdp/core/aad.c b/libfreerdp/core/aad.c index 24d7612b0..bbaa6c5fa 100644 --- a/libfreerdp/core/aad.c +++ b/libfreerdp/core/aad.c @@ -20,6 +20,7 @@ #include #include +#include #include @@ -37,25 +38,6 @@ #include "aad.h" -#define LOG_ERROR_AND_GOTO(wlog, label, ...) \ - do \ - { \ - WLog_Print(wlog, WLOG_ERROR, __VA_ARGS__); \ - goto label; \ - } while (0); -#define LOG_ERROR_AND_RETURN(wlog, ret, ...) \ - do \ - { \ - WLog_Print(wlog, WLOG_ERROR, __VA_ARGS__); \ - return ret; \ - } while (0); -#define XFREE(x) \ - do \ - { \ - free(x); \ - x = NULL; \ - } while (0); - #define OAUTH2_CLIENT_ID "5177bc73-fd99-4c77-a90c-76844c9b6999" static const char* auth_server = "login.microsoftonline.com"; @@ -99,12 +81,60 @@ struct rdp_aad wLog* log; }; -static int alloc_sprintf(char** s, const char* template, ...); -static BOOL get_encoded_rsa_params(EVP_PKEY* pkey, char** e, char** n); +static BOOL get_encoded_rsa_params(wLog* wlog, EVP_PKEY* pkey, char** e, char** n); static BOOL generate_pop_key(rdpAad* aad); static BOOL read_http_message(rdpAad* aad, BIO* bio, long* status_code, char** content, size_t* content_length); +static int alloc_sprintf(char** s, size_t* slen, const char* template, ...) +{ + va_list ap; + + WINPR_ASSERT(s); + WINPR_ASSERT(slen); + *s = NULL; + *slen = 0; + + va_start(ap, template); + const int length = vsnprintf(NULL, 0, template, ap); + va_end(ap); + + char* str = calloc(length + 1, sizeof(char)); + if (!str) + return -1; + + va_start(ap, template); + const int plen = vsprintf(str, template, ap); + va_end(ap); + + WINPR_ASSERT(length == plen); + *s = str; + *slen = length; + return length; +} + +static SSIZE_T stream_sprintf(wStream* s, const char* fmt, ...) +{ + va_list ap; + va_start(ap, fmt); + const int rc = vsnprintf(NULL, 0, fmt, ap); + va_end(ap); + + if (rc < 0) + return rc; + + if (!Stream_EnsureRemainingCapacity(s, (size_t)rc)) + return -1; + + char* ptr = Stream_PointerAs(s, char); + const int rc2 = vsnprintf(ptr, rc, fmt, ap); + if (rc != rc2) + return -23; + if (!Stream_SafeSeek(s, (size_t)rc2)) + return -3; + return rc2; +} + static int print_error(const char* str, size_t len, void* u) { wLog* wlog = (wLog*)u; @@ -194,21 +224,163 @@ static BOOL json_get_string_alloc(wLog* wlog, cJSON* json, const char* key, char return FALSE; free(*result); *result = _strdup(str); + if (!*result) + WLog_Print(wlog, WLOG_ERROR, "[json] object for key '%s' strdup is NULL", key); return *result != NULL; } +static BIO* aad_connect_https(rdpAad* aad, SSL_CTX* ssl_ctx) +{ + WINPR_ASSERT(aad); + WINPR_ASSERT(ssl_ctx); + + const int vprc = SSL_CTX_set_default_verify_paths(ssl_ctx); + const long mrc = SSL_CTX_set_mode(ssl_ctx, SSL_MODE_AUTO_RETRY); + + BIO* bio = BIO_new_ssl_connect(ssl_ctx); + if (!bio) + { + WLog_Print(aad->log, WLOG_ERROR, "Error setting up connection"); + return NULL; + } + const long chrc = BIO_set_conn_hostname(bio, auth_server); + const long cprc = BIO_set_conn_port(bio, "https"); + return bio; +} + +static BOOL aad_logging_bio_write(rdpAad* aad, BIO* bio, const char* str) +{ + WINPR_ASSERT(aad); + WINPR_ASSERT(bio); + WINPR_ASSERT(str); + + ERR_clear_error(); + if (BIO_write(bio, str, strlen(str)) < 0) + { + ERR_print_errors_cb(print_error, aad->log); + return FALSE; + } + return TRUE; +} + +static char* aad_read_response(rdpAad* aad, BIO* bio, size_t* plen, const char* what) +{ + WINPR_ASSERT(plen); + + long status_code; + char* buffer = NULL; + size_t length = 0; + + *plen = 0; + if (!read_http_message(aad, bio, &status_code, &buffer, &length)) + { + WLog_Print(aad->log, WLOG_ERROR, "Unable to read %s HTTP response", what); + return NULL; + } + WLog_Print(aad->log, WLOG_DEBUG, "%s HTTP response: %s", buffer); + + if (status_code != 200) + { + WLog_Print(aad->log, WLOG_ERROR, "%s HTTP status code: %li", status_code); + free(buffer); + return NULL; + } + *plen = length; + return buffer; +} + +static BOOL aad_read_and_extract_token_from_json(rdpAad* aad, BIO* bio) +{ + BOOL rc = FALSE; + size_t blen = 0; + char* buffer = aad_read_response(aad, bio, &blen, "access token"); + if (!buffer) + return FALSE; + + cJSON* json = cJSON_ParseWithLength(buffer, blen); + if (!json) + { + WLog_Print(aad->log, WLOG_ERROR, "Failed to parse JSON response"); + goto fail; + } + + if (!json_get_string_alloc(aad->log, json, "access_token", &aad->access_token)) + { + WLog_Print(aad->log, WLOG_ERROR, + "Could not find \"access_token\" property in JSON response"); + goto fail; + } + + rc = TRUE; +fail: + free(buffer); + cJSON_free(json); + return rc; +} + +static BOOL aad_read_and_extrace_nonce_from_json(rdpAad* aad, BIO* bio) +{ + BOOL rc = FALSE; + size_t blen = 0; + char* buffer = aad_read_response(aad, bio, &blen, "Nonce"); + if (!buffer) + return FALSE; + + /* Extract the nonce from the response */ + cJSON* json = cJSON_ParseWithLength(buffer, blen); + if (!json) + { + WLog_Print(aad->log, WLOG_ERROR, "Failed to parse JSON response"); + goto fail; + } + + if (!json_get_string_alloc(aad->log, json, "Nonce", &aad->nonce)) + { + WLog_Print(aad->log, WLOG_ERROR, "Could not find \"Nonce\" property in JSON response"); + goto fail; + } + rc = TRUE; +fail: + free(buffer); + cJSON_free(json); + return rc; +} + +static BOOL aad_send_token_request(rdpAad* aad, BIO* bio, const char* auth_code) +{ + BOOL rc = FALSE; + + char* req_body = NULL; + char* req_header = NULL; + size_t req_body_len = 0; + size_t req_header_len = 0; + const int trc = alloc_sprintf(&req_body, &req_body_len, token_http_request_body, auth_code, + aad->hostname, aad->kid); + if (trc < 0) + goto fail; + const int trh = alloc_sprintf(&req_header, &req_header_len, token_http_request_header, trc); + if (trh < 0) + goto fail; + + WLog_Print(aad->log, WLOG_DEBUG, "HTTP access token request: %s%s", req_header, req_body); + + if (!aad_logging_bio_write(aad, bio, req_header)) + goto fail; + if (!aad_logging_bio_write(aad, bio, req_body)) + goto fail; + rc = TRUE; +fail: + free(req_body); + free(req_header); + return rc; +} + int aad_client_begin(rdpAad* aad) { int ret = -1; SSL_CTX* ssl_ctx = NULL; BIO* bio = NULL; char* auth_code = NULL; - char *buffer = NULL, *req_header = NULL, *req_body = NULL; - size_t length = 0; - const char* hostname = NULL; - char* p = NULL; - long status_code; - cJSON* json = NULL; WINPR_ASSERT(aad); WINPR_ASSERT(aad->rdpcontext); @@ -220,209 +392,273 @@ int aad_client_begin(rdpAad* aad) WINPR_ASSERT(instance); /* Get the host part of the hostname */ - hostname = freerdp_settings_get_string(settings, FreeRDP_ServerHostname); - if (!hostname || !(aad->hostname = _strdup(hostname))) - LOG_ERROR_AND_GOTO(aad->log, fail, "Unable to get hostname"); - if ((p = strchr(aad->hostname, '.'))) + const char* hostname = freerdp_settings_get_string(settings, FreeRDP_ServerHostname); + if (!hostname) + { + WLog_Print(aad->log, WLOG_ERROR, "FreeRDP_ServerHostname == NULL"); + return -1; + } + + aad->hostname = _strdup(hostname); + if (aad->hostname) + { + WLog_Print(aad->log, WLOG_ERROR, "_strdup(FreeRDP_ServerHostname) == NULL"); + return -1; + } + + char* p = strchr(aad->hostname, '.'); + if (p) *p = '\0'; if (!generate_pop_key(aad)) - LOG_ERROR_AND_GOTO(aad->log, fail, "Unable to generate pop key"); + goto fail; /* Obtain an oauth authorization code */ - if (!instance->GetAadAuthCode || !instance->GetAadAuthCode(instance, aad->hostname, &auth_code)) - LOG_ERROR_AND_GOTO(aad->log, fail, "Unable to obtain authorization code"); + if (!instance->GetAadAuthCode) + { + WLog_Print(aad->log, WLOG_ERROR, "instance->GetAadAuthCode == NULL"); + goto fail; + } + const BOOL arc = instance->GetAadAuthCode(instance, aad->hostname, &auth_code); + if (!arc) + { + WLog_Print(aad->log, WLOG_ERROR, "Unable to obtain authorization code"); + goto fail; + } /* Set up an ssl connection to the authorization server */ - if (!(ssl_ctx = SSL_CTX_new(TLS_client_method()))) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error setting up SSL context"); - SSL_CTX_set_default_verify_paths(ssl_ctx); - SSL_CTX_set_mode(ssl_ctx, SSL_MODE_AUTO_RETRY); + ssl_ctx = SSL_CTX_new(TLS_client_method()); + if (!ssl_ctx) + { + WLog_Print(aad->log, WLOG_ERROR, "Error setting up SSL context"); + goto fail; + } - if (!(bio = BIO_new_ssl_connect(ssl_ctx))) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error setting up connection"); - BIO_set_conn_hostname(bio, auth_server); - BIO_set_conn_port(bio, "https"); + bio = aad_connect_https(aad, ssl_ctx); + if (!bio) + goto fail; /* Construct and send the token request message */ - length = alloc_sprintf(&req_body, token_http_request_body, auth_code, aad->hostname, aad->kid); - if (length < 0) + if (aad_send_token_request(aad, bio, auth_code)) goto fail; - if (alloc_sprintf(&req_header, token_http_request_header, length) < 0) - goto fail; - - WLog_Print(aad->log, WLOG_DEBUG, "HTTP access token request: %s%s", req_header, req_body); - - ERR_clear_error(); - if (BIO_write(bio, req_header, strlen(req_header)) < 0) - { - ERR_print_errors_cb(print_error, aad->log); - goto fail; - } - - ERR_clear_error(); - if (BIO_write(bio, req_body, strlen(req_body)) < 0) - { - ERR_print_errors_cb(print_error, aad->log); - goto fail; - } - - /* Read in the response */ - if (!read_http_message(aad, bio, &status_code, &buffer, &length)) - LOG_ERROR_AND_GOTO(aad->log, fail, "Unable to read access token HTTP response"); - WLog_Print(aad->log, WLOG_DEBUG, "HTTP access token response: %s", buffer); - - if (status_code != 200) - LOG_ERROR_AND_GOTO(aad->log, fail, "Received status code: %li", status_code); /* Extract the access token from the JSON response */ - json = cJSON_ParseWithLength(buffer, length); - if (!json) - LOG_ERROR_AND_GOTO(aad->log, fail, "Failed to parse JSON response"); - - if (!json_get_string_alloc(aad->log, json, "access_token", &aad->access_token)) - LOG_ERROR_AND_GOTO(aad->log, fail, - "Could not find \"access_token\" property in JSON response"); - - XFREE(buffer); - cJSON_free(json); - json = NULL; + if (!aad_read_and_extract_token_from_json(aad, bio)) + goto fail; /* Send the nonce request message */ WLog_Print(aad->log, WLOG_DEBUG, "HTTP nonce request: %s", nonce_http_request); - ERR_clear_error(); - if (BIO_write(bio, nonce_http_request, strlen(nonce_http_request)) < 0) - { - ERR_print_errors_cb(print_error, aad->log); + if (!aad_logging_bio_write(aad, bio, nonce_http_request)) goto fail; - } /* Read in the response */ - if (!read_http_message(aad, bio, &status_code, &buffer, &length)) - LOG_ERROR_AND_GOTO(aad->log, fail, "Unable to read HTTP response"); - WLog_Print(aad->log, WLOG_DEBUG, "HTTP nonce response: %s", buffer); - - if (status_code != 200) - LOG_ERROR_AND_GOTO(aad->log, fail, "Received status code: %li", status_code); - - /* Extract the nonce from the response */ - json = cJSON_ParseWithLength(buffer, length); - if (!json) - LOG_ERROR_AND_GOTO(aad->log, fail, "Failed to parse JSON response"); - - if (!json_get_string_alloc(aad->log, json, "Nonce", &aad->nonce)) - LOG_ERROR_AND_GOTO(aad->log, fail, "Could not find \"Nonce\" property in JSON response"); + if (!aad_read_and_extrace_nonce_from_json(aad, bio)) + goto fail; ret = 1; fail: - cJSON_free(json); - free(buffer); - free(req_body); - free(req_header); BIO_free_all(bio); SSL_CTX_free(ssl_ctx); free(auth_code); return ret; } +static char* aad_create_jws_header(rdpAad* aad) +{ + WINPR_ASSERT(aad); + + /* Construct the base64url encoded JWS header */ + char* buffer = NULL; + size_t bufferlen = 0; + const int length = + alloc_sprintf(&buffer, &bufferlen, "{\"alg\":\"RS256\",\"kid\":\"%s\"}", aad->kid); + if (length < 0) + return NULL; + + char* jws_header = crypto_base64url_encode((BYTE*)buffer, bufferlen); + free(buffer); + return jws_header; +} + +static char* aad_create_jws_payload(rdpAad* aad, const char* ts_nonce) +{ + const time_t ts = time(NULL); + + WINPR_ASSERT(aad); + + char* e = NULL; + char* n = NULL; + if (!get_encoded_rsa_params(aad->log, aad->pop_key, &e, &n)) + return NULL; + + /* Construct the base64url encoded JWS payload */ + char* buffer = NULL; + size_t bufferlen = 0; + const int length = + alloc_sprintf(&buffer, &bufferlen, + "{" + "\"ts\":\"%li\"," + "\"at\":\"%s\"," + "\"u\":\"ms-device-service://termsrv.wvd.microsoft.com/name/%s\"," + "\"nonce\":\"%s\"," + "\"cnf\":{\"jwk\":{\"kty\":\"RSA\",\"e\":\"%s\",\"n\":\"%s\"}}," + "\"client_claims\":\"{\\\"aad_nonce\\\":\\\"%s\\\"}\"" + "}", + ts, aad->access_token, aad->hostname, ts_nonce, e, n, aad->nonce); + free(e); + free(n); + + if (length < 0) + return NULL; + + char* jws_payload = crypto_base64url_encode((BYTE*)buffer, bufferlen); + free(buffer); + return jws_payload; +} + +static BOOL aad_update_digest(rdpAad* aad, EVP_MD_CTX* ctx, const char* what) +{ + WINPR_ASSERT(aad); + WINPR_ASSERT(ctx); + WINPR_ASSERT(what); + + const int dsu1 = EVP_DigestSignUpdate(ctx, what, strlen(what)); + if (dsu1 <= 0) + { + WLog_Print(aad->log, WLOG_ERROR, "EVP_DigestSignUpdate [%s] failed with %d", what, dsu1); + return FALSE; + } + return TRUE; +} + +static char* aad_final_digest(rdpAad* aad, EVP_MD_CTX* ctx) +{ + char* jws_signature = NULL; + + WINPR_ASSERT(aad); + WINPR_ASSERT(ctx); + + size_t siglen = 0; + const int dsf = EVP_DigestSignFinal(ctx, NULL, &siglen); + if (dsf <= 0) + { + WLog_Print(aad->log, WLOG_ERROR, "EVP_DigestSignFinal failed with %d", dsf); + return FALSE; + } + + char* buffer = calloc(siglen + 1, sizeof(char)); + if (!buffer) + { + WLog_Print(aad->log, WLOG_ERROR, "calloc %" PRIuz " bytes failed", siglen + 1); + goto fail; + } + + size_t fsiglen = 0; + const int dsf2 = EVP_DigestSignFinal(ctx, (BYTE*)buffer, &fsiglen); + if (dsf2 <= 0) + { + WLog_Print(aad->log, WLOG_ERROR, "EVP_DigestSignFinal failed with %d", dsf2); + goto fail; + } + + if (siglen != fsiglen) + { + WLog_Print(aad->log, WLOG_ERROR, + "EVP_DigestSignFinal returned different sizes, first %" PRIuz " then %" PRIuz, + siglen, fsiglen); + goto fail; + } + jws_signature = crypto_base64url_encode((BYTE*)buffer, fsiglen); +fail: + free(buffer); + return jws_signature; +} + +static char* aad_create_jws_signature(rdpAad* aad, const char* jws_header, const char* jws_payload) +{ + char* jws_signature = NULL; + + WINPR_ASSERT(aad); + + EVP_MD_CTX* md_ctx = EVP_MD_CTX_new(); + if (!md_ctx) + { + WLog_Print(aad->log, WLOG_ERROR, "EVP_MD_CTX_new failed"); + goto fail; + } + + const int rdsi = EVP_DigestSignInit(md_ctx, NULL, EVP_sha256(), NULL, aad->pop_key); + if (rdsi <= 0) + { + WLog_Print(aad->log, WLOG_ERROR, "EVP_DigestSignInit failed with %d", rdsi); + goto fail; + } + + if (!aad_update_digest(aad, md_ctx, jws_header)) + goto fail; + if (!aad_update_digest(aad, md_ctx, ".")) + goto fail; + if (!aad_update_digest(aad, md_ctx, jws_payload)) + goto fail; + + jws_signature = aad_final_digest(aad, md_ctx); +fail: + EVP_MD_CTX_free(md_ctx); + return jws_signature; +} + static int aad_send_auth_request(rdpAad* aad, const char* ts_nonce) { int ret = -1; - char* jws_header = NULL; char* jws_payload = NULL; char* jws_signature = NULL; - char* buffer = NULL; - wStream* s = NULL; - time_t ts = time(NULL); - char *e = NULL, *n = NULL; - size_t length = 0; - EVP_MD_CTX* md_ctx = NULL; WINPR_ASSERT(aad); WINPR_ASSERT(ts_nonce); - /* Construct the base64url encoded JWS header */ - if ((length = alloc_sprintf(&buffer, "{\"alg\":\"RS256\",\"kid\":\"%s\"}", aad->kid)) < 0) + wStream* s = Stream_New(NULL, 1024); + if (!s) goto fail; - if (!(jws_header = crypto_base64url_encode((BYTE*)buffer, strlen(buffer)))) - goto fail; - XFREE(buffer); - if (!get_encoded_rsa_params(aad->pop_key, &e, &n)) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error getting RSA key params"); + /* Construct the base64url encoded JWS header */ + char* jws_header = aad_create_jws_header(aad); + if (!jws_header) + goto fail; /* Construct the base64url encoded JWS payload */ - length = alloc_sprintf(&buffer, - "{" - "\"ts\":\"%li\"," - "\"at\":\"%s\"," - "\"u\":\"ms-device-service://termsrv.wvd.microsoft.com/name/%s\"," - "\"nonce\":\"%s\"," - "\"cnf\":{\"jwk\":{\"kty\":\"RSA\",\"e\":\"%s\",\"n\":\"%s\"}}," - "\"client_claims\":\"{\\\"aad_nonce\\\":\\\"%s\\\"}\"" - "}", - ts, aad->access_token, aad->hostname, ts_nonce, e, n, aad->nonce); - if (length < 0) + jws_payload = aad_create_jws_payload(aad, ts_nonce); + if (!jws_payload) goto fail; - if (!(jws_payload = crypto_base64url_encode((BYTE*)buffer, strlen(buffer)))) - goto fail; - XFREE(buffer); /* Sign the JWS with the pop key */ - if (!(md_ctx = EVP_MD_CTX_new())) - goto fail; - - if (!(EVP_DigestSignInit(md_ctx, NULL, EVP_sha256(), NULL, aad->pop_key))) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error while initializing signature context"); - - if (!(EVP_DigestSignUpdate(md_ctx, jws_header, strlen(jws_header)))) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error while signing data"); - if (!(EVP_DigestSignUpdate(md_ctx, ".", 1))) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error while signing data"); - if (!(EVP_DigestSignUpdate(md_ctx, jws_payload, strlen(jws_payload)))) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error while signing data"); - - if (!(EVP_DigestSignFinal(md_ctx, NULL, &length))) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error while signing data"); - - if (!(buffer = malloc(length))) - goto fail; - if (!(EVP_DigestSignFinal(md_ctx, (BYTE*)buffer, &length))) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error while signing data"); - - if (!(jws_signature = crypto_base64url_encode((BYTE*)buffer, length))) + jws_signature = aad_create_jws_signature(aad, jws_header, jws_payload); + if (!jws_signature) goto fail; /* Construct the Authentication Request PDU with the JWS as the RDP Assertion */ - length = _snprintf(NULL, 0, "{\"rdp_assertion\":\"%s.%s.%s\"}", jws_header, jws_payload, - jws_signature) + - 1; - if (length < 0) + if (stream_sprintf(s, "{\"rdp_assertion\":\"%s.%s.%s\"}", jws_header, jws_payload, + jws_signature) < 0) goto fail; - - if (!(s = Stream_New(NULL, length))) - goto fail; - _snprintf(Stream_PointerAs(s, char), length, "{\"rdp_assertion\":\"%s.%s.%s\"}", jws_header, - jws_payload, jws_signature); - Stream_Seek(s, length); + Stream_SealLength(s); if (transport_write(aad->transport, s) < 0) - LOG_ERROR_AND_GOTO(aad->log, fail, "Failed to send Authentication Request PDU"); - - ret = 1; - aad->state = AAD_STATE_AUTH; - + { + WLog_Print(aad->log, WLOG_ERROR, "transport_write [%" PRIdz " bytes] failed", + Stream_Length(s)); + } + else + { + ret = 1; + aad->state = AAD_STATE_AUTH; + } fail: Stream_Free(s, TRUE); - free(e); - free(n); - free(buffer); free(jws_header); free(jws_payload); free(jws_signature); - EVP_MD_CTX_free(md_ctx); + return ret; } @@ -468,36 +704,38 @@ static int aad_parse_state_auth(rdpAad* aad, wStream* s) if (!json_get_number(aad->log, json, "authentication_result", &result)) goto fail; + if (result != 0.0) + { + WLog_Print(aad->log, WLOG_ERROR, "Authentication result: %lf", result); + goto fail; + } + aad->state = AAD_STATE_FINAL; rc = 1; fail: cJSON_free(json); - - if (result != 0.0) - LOG_ERROR_AND_RETURN(aad->log, -1, "Authentication result: %d", (int)result); - - aad->state = AAD_STATE_FINAL; return rc; } + int aad_recv(rdpAad* aad, wStream* s) { - cJSON* json; - cJSON* prop; - WINPR_ASSERT(aad); WINPR_ASSERT(s); - if (aad->state == AAD_STATE_INITIAL) - return aad_parse_state_initial(aad, s); - else if (aad->state == AAD_STATE_AUTH) - return aad_parse_state_auth(aad, s); - else - LOG_ERROR_AND_RETURN(aad->log, -1, "Invalid state"); + switch (aad->state) + { + case AAD_STATE_INITIAL: + return aad_parse_state_initial(aad, s); + case AAD_STATE_AUTH: + return aad_parse_state_auth(aad, s); + default: + WLog_Print(aad->log, WLOG_ERROR, "Invalid AAD_STATE %d", aad->state); + return -1; + } } AAD_STATE aad_get_state(rdpAad* aad) { - if (!aad) - return AAD_STATE_FINAL; + WINPR_ASSERT(aad); return aad->state; } @@ -525,187 +763,315 @@ static BOOL read_http_message(rdpAad* aad, BIO* bio, long* status_code, char** c WINPR_ASSERT(content); WINPR_ASSERT(content_length); - if (BIO_get_line(bio, buffer, sizeof(buffer)) <= 0) - LOG_ERROR_AND_RETURN(aad->log, FALSE, "Error reading HTTP response"); + const int rb = BIO_get_line(bio, buffer, sizeof(buffer)); + if (rb <= 0) + { + WLog_Print(aad->log, WLOG_ERROR, "Error reading HTTP response"); + return FALSE; + } if (sscanf(buffer, "HTTP/%*u.%*u %li %*[^\r\n]\r\n", status_code) < 1) - LOG_ERROR_AND_RETURN(aad->log, FALSE, "Invalid HTTP response status line"); + { + WLog_Print(aad->log, WLOG_ERROR, "Invalid HTTP response status line"); + return FALSE; + } do { - char* name = NULL; + const int rb = BIO_get_line(bio, buffer, sizeof(buffer)); + if (rb <= 0) + { + WLog_Print(aad->log, WLOG_ERROR, "Error reading HTTP response"); + return FALSE; + } + char* val = NULL; - - if (BIO_get_line(bio, buffer, sizeof(buffer)) <= 0) - LOG_ERROR_AND_RETURN(aad->log, FALSE, "Error reading HTTP response"); - - name = strtok_r(buffer, ":", &val); - if (name && _stricmp(name, "content-length") == 0) + char* name = strtok_r(buffer, ":", &val); + if (name && (_stricmp(name, "content-length") == 0)) + { + errno = 0; *content_length = strtoul(val, NULL, 10); + switch (errno) + { + case 0: + break; + default: + WLog_Print(aad->log, WLOG_ERROR, "strtoul(%s) returned %s [%d]", val, + strerror(errno), errno); + return FALSE; + } + } } while (strcmp(buffer, "\r\n") != 0); if (*content_length == 0) return TRUE; - if (!(*content = malloc(*content_length + 1))) + *content = calloc(*content_length + 1, sizeof(char)); + if (!*content) return FALSE; - (*content)[*content_length] = '\0'; - if (BIO_read(bio, *content, *content_length) < *content_length) + const int brc = BIO_read(bio, *content, *content_length); + if (brc < *content_length) { free(*content); - LOG_ERROR_AND_RETURN(aad->log, FALSE, "Error reading HTTP response body"); + WLog_Print(aad->log, WLOG_ERROR, "Error reading HTTP response body (BIO_read returned %d)", + brc); + return FALSE; } return TRUE; } -static BOOL generate_pop_key(rdpAad* aad) +static BOOL generate_rsa_2048(rdpAad* aad) { - EVP_PKEY_CTX* ctx = NULL; - BOOL ret = FALSE; - size_t length = 0; - char* buffer = NULL; - char *e = NULL, *n = NULL; - WINPR_DIGEST_CTX* digest = NULL; + BOOL rc = FALSE; + WINPR_ASSERT(aad); + + EVP_PKEY_CTX* ctx = EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, NULL); + if (!aad) + { + WLog_Print(aad->log, WLOG_ERROR, "EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, NULL) failed"); + goto fail; + } + + const int rki = EVP_PKEY_keygen_init(ctx); + if (rki <= 0) + { + WLog_Print(aad->log, WLOG_ERROR, "EVP_PKEY_keygen_init failed with %d", rki); + goto fail; + } + + const int key_bits = 2048; + const int rkb = EVP_PKEY_CTX_set_rsa_keygen_bits(ctx, key_bits); + if (rkb <= 0) + { + WLog_Print(aad->log, WLOG_ERROR, "EVP_PKEY_CTX_set_rsa_keygen_bits(%d) failed with %d", + key_bits, rkb); + goto fail; + } + + const int rkg = EVP_PKEY_keygen(ctx, &aad->pop_key); + if (rkg <= 0) + { + WLog_Print(aad->log, WLOG_ERROR, "EVP_PKEY_keygen failed with %d", rkg); + goto fail; + } + + rc = TRUE; +fail: + + EVP_PKEY_CTX_free(ctx); + return rc; +} + +static char* generate_rsa_digest_base64_str(rdpAad* aad, const char* input, size_t ilen) +{ + char* b64 = NULL; + WINPR_DIGEST_CTX* digest = winpr_Digest_New(); + if (!digest) + { + WLog_Print(aad->log, WLOG_ERROR, "winpr_Digest_New failed"); + goto fail; + } + + if (!winpr_Digest_Init(digest, WINPR_MD_SHA256)) + { + WLog_Print(aad->log, WLOG_ERROR, "winpr_Digest_Init(WINPR_MD_SHA256) failed"); + goto fail; + } + + if (!winpr_Digest_Update(digest, (const BYTE*)input, ilen)) + { + WLog_Print(aad->log, WLOG_ERROR, "winpr_Digest_Update(%" PRIuz ") failed", ilen); + goto fail; + } + BYTE hash[WINPR_SHA256_DIGEST_LENGTH] = { 0 }; + if (!winpr_Digest_Final(digest, hash, sizeof(hash))) + { + WLog_Print(aad->log, WLOG_ERROR, "winpr_Digest_Final(%" PRIuz ") failed", sizeof(hash)); + goto fail; + } + + /* Base64url encode the hash */ + b64 = crypto_base64url_encode(hash, sizeof(hash)); + +fail: + winpr_Digest_Free(digest); + return b64; +} + +static BOOL generate_json_base64_str(rdpAad* aad, const char* b64_hash) +{ + WINPR_ASSERT(aad); + + char* buffer = NULL; + size_t blen = 0; + const int length = alloc_sprintf(&buffer, &blen, "{\"kid\":\"%s\"}", b64_hash); + if (length < 0) + return FALSE; + + /* Finally, base64url encode the JSON text to form the kid */ + free(aad->kid); + aad->kid = crypto_base64url_encode((BYTE*)buffer, length); + free(buffer); + + if (!aad->kid) + { + return FALSE; + } + return TRUE; +} + +BOOL generate_pop_key(rdpAad* aad) +{ + BOOL ret = FALSE; + char* buffer = NULL; + char* b64_hash = NULL; + char *e = NULL, *n = NULL; WINPR_ASSERT(aad); /* Generate a 2048-bit RSA key pair */ - if (!(ctx = EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, NULL))) - return FALSE; - - if (EVP_PKEY_keygen_init(ctx) <= 0) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error initializing keygen"); - if (EVP_PKEY_CTX_set_rsa_keygen_bits(ctx, 2048) <= 0) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error setting RSA keygen bits"); - if (EVP_PKEY_keygen(ctx, &aad->pop_key) <= 0) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error generating RSA pop token key"); + if (!generate_rsa_2048(aad)) + goto fail; /* Encode the public key as a JWK */ - if (!get_encoded_rsa_params(aad->pop_key, &e, &n)) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error getting RSA key params"); + if (!get_encoded_rsa_params(aad->log, aad->pop_key, &e, &n)) + goto fail; - if ((length = alloc_sprintf(&buffer, "{\"e\":\"%s\",\"kty\":\"RSA\",\"n\":\"%s\"}", e, n)) < 0) + size_t blen = 0; + const int alen = + alloc_sprintf(&buffer, &blen, "{\"e\":\"%s\",\"kty\":\"RSA\",\"n\":\"%s\"}", e, n); + if (alen < 0) goto fail; /* Hash the encoded public key */ - if (!(digest = winpr_Digest_New())) - goto fail; - - if (!winpr_Digest_Init(digest, WINPR_MD_SHA256)) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error initializing SHA256 digest"); - if (!winpr_Digest_Update(digest, (BYTE*)buffer, length)) - LOG_ERROR_AND_GOTO(aad->log, fail, "Unable to get hash of JWK"); - if (!winpr_Digest_Final(digest, hash, WINPR_SHA256_DIGEST_LENGTH)) - LOG_ERROR_AND_GOTO(aad->log, fail, "Unable to get hash of JWK"); - - XFREE(buffer); - - /* Base64url encode the hash */ - if (!(buffer = crypto_base64url_encode(hash, WINPR_SHA256_DIGEST_LENGTH))) + b64_hash = generate_rsa_digest_base64_str(aad, buffer, blen); + if (!b64_hash) goto fail; /* Encode a JSON object with a single property "kid" whose value is the encoded hash */ - { - char* buf2 = NULL; - if ((length = alloc_sprintf(&buf2, "{\"kid\":\"%s\"}", buffer)) < 0) - goto fail; - free(buffer); - buffer = buf2; - } - - /* Finally, base64url encode the JSON text to form the kid */ - if (!(aad->kid = crypto_base64url_encode((BYTE*)buffer, length))) - LOG_ERROR_AND_GOTO(aad->log, fail, "Error base64url encoding kid"); - - ret = TRUE; + ret = generate_json_base64_str(aad, b64_hash); fail: + free(b64_hash); free(buffer); free(e); free(n); - winpr_Digest_Free(digest); - EVP_PKEY_CTX_free(ctx); return ret; } -static BOOL get_encoded_rsa_params(EVP_PKEY* pkey, char** e, char** n) +static char* bn_to_base64_url(wLog* wlog, BIGNUM* bn) { + WINPR_ASSERT(wlog); + WINPR_ASSERT(bn); + + const int length = BN_num_bytes(bn); + if (length < 0) + { + WLog_Print(wlog, WLOG_ERROR, "BN_num_bytes failed with %d", length); + return NULL; + } + + const size_t alloc_size = (size_t)length + 1ull; + BYTE* buf = calloc(alloc_size, sizeof(BYTE)); + if (!buf) + { + return NULL; + } + + const int bnlen = BN_bn2bin(bn, buf); + if (bnlen != length) + { + free(buf); + WLog_Print(wlog, WLOG_ERROR, "BN_bn2bin returned %d, expected result %d", bnlen, length); + return NULL; + } + char* b64 = crypto_base64url_encode(buf, length); + free(buf); + + if (!b64) + WLog_Print(wlog, WLOG_ERROR, "failed base64 url encode BIGNUM"); + + return b64; +} + +BOOL get_encoded_rsa_params(wLog* wlog, EVP_PKEY* pkey, char** pe, char** pn) +{ + BOOL rc = FALSE; BIGNUM *bn_e = NULL, *bn_n = NULL; - BYTE buf[2048] = { 0 }; - size_t length = 0; + char* e = NULL; + char* n = NULL; + WINPR_ASSERT(wlog); WINPR_ASSERT(pkey); - WINPR_ASSERT(e); - WINPR_ASSERT(n); + WINPR_ASSERT(pe); + WINPR_ASSERT(pn); - *e = NULL; - *n = NULL; + *pe = NULL; + *pn = NULL; #if OPENSSL_VERSION_NUMBER >= 0x30000000L if (!EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_E, &bn_e)) + { + WLog_Print(wlog, WLOG_ERROR, "failed to get RSA E"); goto fail; + } if (!EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_N, &bn_n)) + { + WLog_Print(wlog, WLOG_ERROR, "failed to get RSA N"); goto fail; + } #else { const RSA* rsa = NULL; if (!(rsa = EVP_PKEY_get0_RSA(pkey))) + { + WLog_Print(wlog, WLOG_ERROR, "failed to get RSA"); goto fail; + } if (!(bn_e = BN_dup(RSA_get0_e(rsa)))) + { + WLog_Print(wlog, WLOG_ERROR, "failed to get RSA E"); goto fail; + } if (!(bn_n = BN_dup(RSA_get0_n(rsa)))) + { + WLog_Print(wlog, WLOG_ERROR, "failed to get RSA N"); goto fail; + } } #endif - length = BN_num_bytes(bn_e); - if (length > sizeof(buf)) + e = bn_to_base64_url(wlog, bn_e); + if (!e) + { + WLog_Print(wlog, WLOG_ERROR, "failed base64 url encode RSA E"); goto fail; - if (BN_bn2bin(bn_e, buf) < length) - goto fail; - *e = crypto_base64url_encode(buf, length); + } - length = BN_num_bytes(bn_n); - if (length > sizeof(buf)) + n = bn_to_base64_url(wlog, bn_n); + if (!n) + { + WLog_Print(wlog, WLOG_ERROR, "failed base64 url encode RSA N"); goto fail; - if (BN_bn2bin(bn_n, buf) < length) - goto fail; - *n = crypto_base64url_encode(buf, length); + } + rc = TRUE; fail: BN_free(bn_e); BN_free(bn_n); - if (!(*e) || !(*n)) + if (!rc) { - free(*e); - free(*n); - return FALSE; + free(e); + free(n); } - return TRUE; -} - -static int alloc_sprintf(char** s, const char* template, ...) -{ - int length; - va_list ap; - - WINPR_ASSERT(s); - *s = NULL; - - va_start(ap, template); - length = vsnprintf(NULL, 0, template, ap); - va_end(ap); - - if (!(*s = calloc(length + 1, sizeof(char)))) - return -1; - - va_start(ap, template); - vsprintf(*s, template, ap); - va_end(ap); - - return length; + else + { + *pe = e; + *pn = n; + } + return rc; }