[core,rdstls] prefer Stream_Get_* functions

This commit is contained in:
Armin Novak
2025-09-24 10:55:55 +02:00
parent 23b2c3bfae
commit 5ac53c1fd8

View File

@@ -280,6 +280,9 @@ static BOOL rdstls_write_data(wStream* s, UINT32 length, const BYTE* data)
static BOOL rdstls_write_authentication_request_with_password(rdpRdstls* rdstls, wStream* s)
{
WINPR_ASSERT(rdstls);
WINPR_ASSERT(rdstls->context);
rdpSettings* settings = rdstls->context->settings;
WINPR_ASSERT(settings);
@@ -313,6 +316,7 @@ static BOOL rdstls_write_authentication_request_with_cookie(WINPR_ATTR_UNUSED rd
static BOOL rdstls_write_authentication_response(rdpRdstls* rdstls, wStream* s)
{
WINPR_ASSERT(rdstls);
if (!Stream_EnsureRemainingCapacity(s, 8))
return FALSE;
@@ -325,13 +329,11 @@ static BOOL rdstls_write_authentication_response(rdpRdstls* rdstls, wStream* s)
static BOOL rdstls_process_capabilities(rdpRdstls* rdstls, wStream* s)
{
UINT16 dataType = 0;
UINT16 supportedVersions = 0;
WINPR_ASSERT(rdstls);
if (Stream_GetRemainingLength(s) < 4)
return FALSE;
Stream_Read_UINT16(s, dataType);
const UINT16 dataType = Stream_Get_UINT16(s);
if (dataType != RDSTLS_DATA_CAPABILITIES)
{
WLog_Print(rdstls->log, WLOG_ERROR,
@@ -340,7 +342,7 @@ static BOOL rdstls_process_capabilities(rdpRdstls* rdstls, wStream* s)
return FALSE;
}
Stream_Read_UINT16(s, supportedVersions);
const UINT16 supportedVersions = Stream_Get_UINT16(s);
if ((supportedVersions & RDSTLS_VERSION_1) == 0)
{
WLog_Print(rdstls->log, WLOG_ERROR,
@@ -354,14 +356,12 @@ static BOOL rdstls_process_capabilities(rdpRdstls* rdstls, wStream* s)
static BOOL rdstls_read_unicode_string(WINPR_ATTR_UNUSED wLog* log, wStream* s, char** str)
{
UINT16 length = 0;
WINPR_ASSERT(str);
if (Stream_GetRemainingLength(s) < 2)
return FALSE;
Stream_Read_UINT16(s, length);
const UINT16 length = Stream_Get_UINT16(s);
if (Stream_GetRemainingLength(s) < length)
return FALSE;
@@ -382,8 +382,6 @@ static BOOL rdstls_read_unicode_string(WINPR_ATTR_UNUSED wLog* log, wStream* s,
static BOOL rdstls_read_data(WINPR_ATTR_UNUSED wLog* log, wStream* s, UINT16* pLength,
const BYTE** pData)
{
UINT16 length = 0;
WINPR_ASSERT(pLength);
WINPR_ASSERT(pData);
@@ -392,7 +390,7 @@ static BOOL rdstls_read_data(WINPR_ATTR_UNUSED wLog* log, wStream* s, UINT16* pL
if (Stream_GetRemainingLength(s) < 2)
return FALSE;
Stream_Read_UINT16(s, length);
const UINT16 length = Stream_Get_UINT16(s);
if (Stream_GetRemainingLength(s) < length)
return FALSE;
@@ -457,6 +455,9 @@ static BOOL rdstls_cmp_str(wLog* log, const char* field, const char* serverStr,
static BOOL rdstls_process_authentication_request_with_password(rdpRdstls* rdstls, wStream* s)
{
WINPR_ASSERT(rdstls);
WINPR_ASSERT(rdstls->context);
BOOL rc = FALSE;
const BYTE* clientRedirectionGuid = NULL;
@@ -465,12 +466,7 @@ static BOOL rdstls_process_authentication_request_with_password(rdpRdstls* rdstl
char* clientUsername = NULL;
char* clientDomain = NULL;
const BYTE* serverRedirectionGuid = NULL;
const char* serverPassword = NULL;
const char* serverUsername = NULL;
const char* serverDomain = NULL;
rdpSettings* settings = rdstls->context->settings;
const rdpSettings* settings = rdstls->context->settings;
WINPR_ASSERT(settings);
if (!rdstls_read_data(rdstls->log, s, &clientRedirectionGuidLength, &clientRedirectionGuid))
@@ -485,12 +481,13 @@ static BOOL rdstls_process_authentication_request_with_password(rdpRdstls* rdstl
if (!rdstls_read_unicode_string(rdstls->log, s, &clientPassword))
goto fail;
serverRedirectionGuid = freerdp_settings_get_pointer(settings, FreeRDP_RedirectionGuid);
const BYTE* serverRedirectionGuid =
freerdp_settings_get_pointer(settings, FreeRDP_RedirectionGuid);
const UINT32 serverRedirectionGuidLength =
freerdp_settings_get_uint32(settings, FreeRDP_RedirectionGuidLength);
serverUsername = freerdp_settings_get_string(settings, FreeRDP_Username);
serverDomain = freerdp_settings_get_string(settings, FreeRDP_Domain);
serverPassword = freerdp_settings_get_string(settings, FreeRDP_Password);
const char* serverUsername = freerdp_settings_get_string(settings, FreeRDP_Username);
const char* serverDomain = freerdp_settings_get_string(settings, FreeRDP_Domain);
const char* serverPassword = freerdp_settings_get_string(settings, FreeRDP_Password);
rdstls->resultCode = RDSTLS_RESULT_SUCCESS;
@@ -522,12 +519,10 @@ static BOOL rdstls_process_authentication_request_with_cookie(WINPR_ATTR_UNUSED
static BOOL rdstls_process_authentication_request(rdpRdstls* rdstls, wStream* s)
{
UINT16 dataType = 0;
if (Stream_GetRemainingLength(s) < 2)
return FALSE;
Stream_Read_UINT16(s, dataType);
const UINT16 dataType = Stream_Get_UINT16(s);
switch (dataType)
{
case RDSTLS_DATA_PASSWORD_CREDS:
@@ -552,13 +547,10 @@ static BOOL rdstls_process_authentication_request(rdpRdstls* rdstls, wStream* s)
static BOOL rdstls_process_authentication_response(rdpRdstls* rdstls, wStream* s)
{
UINT16 dataType = 0;
UINT32 resultCode = 0;
if (Stream_GetRemainingLength(s) < 6)
return FALSE;
Stream_Read_UINT16(s, dataType);
const UINT16 dataType = Stream_Get_UINT16(s);
if (dataType != RDSTLS_DATA_RESULT_CODE)
{
WLog_Print(rdstls->log, WLOG_ERROR,
@@ -567,7 +559,7 @@ static BOOL rdstls_process_authentication_response(rdpRdstls* rdstls, wStream* s
return FALSE;
}
Stream_Read_UINT32(s, resultCode);
const UINT32 resultCode = Stream_Get_UINT32(s);
if (resultCode != RDSTLS_RESULT_SUCCESS)
{
WLog_Print(rdstls->log, WLOG_ERROR, "resultCode: %s [0x%08" PRIX32 "]",
@@ -669,8 +661,6 @@ static BOOL rdstls_send(WINPR_ATTR_UNUSED rdpTransport* transport, wStream* s, v
static int rdstls_recv(WINPR_ATTR_UNUSED rdpTransport* transport, wStream* s, void* extra)
{
UINT16 version = 0;
UINT16 pduType = 0;
rdpRdstls* rdstls = (rdpRdstls*)extra;
WINPR_ASSERT(transport);
@@ -680,7 +670,7 @@ static int rdstls_recv(WINPR_ATTR_UNUSED rdpTransport* transport, wStream* s, vo
if (Stream_GetRemainingLength(s) < 4)
return FALSE;
Stream_Read_UINT16(s, version);
const UINT16 version = Stream_Get_UINT16(s);
if (version != RDSTLS_VERSION_1)
{
WLog_Print(rdstls->log, WLOG_ERROR,
@@ -689,7 +679,7 @@ static int rdstls_recv(WINPR_ATTR_UNUSED rdpTransport* transport, wStream* s, vo
return -1;
}
Stream_Read_UINT16(s, pduType);
const UINT16 pduType = Stream_Get_UINT16(s);
switch (pduType)
{
case RDSTLS_TYPE_CAPABILITIES:
@@ -722,6 +712,8 @@ static BOOL rdstls_check_state_requirements_(rdpRdstls* rdstls, RDSTLS_STATE exp
if (current == expected)
return TRUE;
WINPR_ASSERT(rdstls);
const DWORD log_level = WLOG_ERROR;
if (WLog_IsLevelActive(rdstls->log, log_level))
WLog_PrintTextMessage(rdstls->log, log_level, line, file, fkt,
@@ -735,15 +727,15 @@ static BOOL rdstls_check_state_requirements_(rdpRdstls* rdstls, RDSTLS_STATE exp
static BOOL rdstls_send_capabilities(rdpRdstls* rdstls)
{
BOOL rc = FALSE;
wStream* s = NULL;
if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_CAPABILITIES))
goto fail;
return FALSE;
s = Stream_New(NULL, 512);
wStream* s = Stream_New(NULL, 512);
if (!s)
goto fail;
WINPR_ASSERT(rdstls);
if (!rdstls_send(rdstls->transport, s, rdstls))
goto fail;
@@ -756,22 +748,21 @@ fail:
static BOOL rdstls_recv_authentication_request(rdpRdstls* rdstls)
{
BOOL rc = FALSE;
int status = 0;
wStream* s = NULL;
if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_REQ))
goto fail;
return FALSE;
s = Stream_New(NULL, 4096);
wStream* s = Stream_New(NULL, 4096);
if (!s)
goto fail;
status = transport_read_pdu(rdstls->transport, s);
WINPR_ASSERT(rdstls);
const int res = transport_read_pdu(rdstls->transport, s);
if (status < 0)
if (res < 0)
goto fail;
status = rdstls_recv(rdstls->transport, s, rdstls);
const int status = rdstls_recv(rdstls->transport, s, rdstls);
if (status < 0)
goto fail;
@@ -785,15 +776,15 @@ fail:
static BOOL rdstls_send_authentication_response(rdpRdstls* rdstls)
{
BOOL rc = FALSE;
wStream* s = NULL;
if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_RSP))
goto fail;
return FALSE;
s = Stream_New(NULL, 512);
wStream* s = Stream_New(NULL, 512);
if (!s)
goto fail;
WINPR_ASSERT(rdstls);
if (!rdstls_send(rdstls->transport, s, rdstls))
goto fail;
@@ -806,22 +797,21 @@ fail:
static BOOL rdstls_recv_capabilities(rdpRdstls* rdstls)
{
BOOL rc = FALSE;
int status = 0;
wStream* s = NULL;
if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_CAPABILITIES))
goto fail;
return FALSE;
s = Stream_New(NULL, 512);
wStream* s = Stream_New(NULL, 512);
if (!s)
goto fail;
status = transport_read_pdu(rdstls->transport, s);
WINPR_ASSERT(rdstls);
const int res = transport_read_pdu(rdstls->transport, s);
if (status < 0)
if (res < 0)
goto fail;
status = rdstls_recv(rdstls->transport, s, rdstls);
const int status = rdstls_recv(rdstls->transport, s, rdstls);
if (status < 0)
goto fail;
@@ -835,15 +825,15 @@ fail:
static BOOL rdstls_send_authentication_request(rdpRdstls* rdstls)
{
BOOL rc = FALSE;
wStream* s = NULL;
if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_REQ))
goto fail;
return FALSE;
s = Stream_New(NULL, 4096);
wStream* s = Stream_New(NULL, 4096);
if (!s)
goto fail;
WINPR_ASSERT(rdstls);
if (!rdstls_send(rdstls->transport, s, rdstls))
goto fail;
@@ -856,24 +846,22 @@ fail:
static BOOL rdstls_recv_authentication_response(rdpRdstls* rdstls)
{
BOOL rc = FALSE;
int status = 0;
wStream* s = NULL;
WINPR_ASSERT(rdstls);
if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_RSP))
goto fail;
return FALSE;
s = Stream_New(NULL, 512);
wStream* s = Stream_New(NULL, 512);
if (!s)
goto fail;
status = transport_read_pdu(rdstls->transport, s);
const int res = transport_read_pdu(rdstls->transport, s);
if (status < 0)
if (res < 0)
goto fail;
status = rdstls_recv(rdstls->transport, s, rdstls);
const int status = rdstls_recv(rdstls->transport, s, rdstls);
if (status < 0)
goto fail;
@@ -886,6 +874,8 @@ fail:
static int rdstls_server_authenticate(rdpRdstls* rdstls)
{
WINPR_ASSERT(rdstls);
if (!rdstls_set_state(rdstls, RDSTLS_STATE_CAPABILITIES))
return -1;
@@ -946,37 +936,35 @@ static SSIZE_T rdstls_parse_pdu_data_type(wLog* log, UINT16 dataType, wStream* s
{
case RDSTLS_DATA_PASSWORD_CREDS:
{
UINT16 redirGuidLength = 0;
if (Stream_GetRemainingLength(s) < 2)
return 0;
Stream_Read_UINT16(s, redirGuidLength);
const UINT16 redirGuidLength = Stream_Get_UINT16(s);
if (Stream_GetRemainingLength(s) < redirGuidLength)
return 0;
Stream_Seek(s, redirGuidLength);
UINT16 usernameLength = 0;
if (Stream_GetRemainingLength(s) < 2)
return 0;
Stream_Read_UINT16(s, usernameLength);
const UINT16 usernameLength = Stream_Get_UINT16(s);
if (Stream_GetRemainingLength(s) < usernameLength)
return 0;
Stream_Seek(s, usernameLength);
UINT16 domainLength = 0;
if (Stream_GetRemainingLength(s) < 2)
return 0;
Stream_Read_UINT16(s, domainLength);
const UINT16 domainLength = Stream_Get_UINT16(s);
if (Stream_GetRemainingLength(s) < domainLength)
return 0;
Stream_Seek(s, domainLength);
UINT16 passwordLength = 0;
if (Stream_GetRemainingLength(s) < 2)
return 0;
Stream_Read_UINT16(s, passwordLength);
const UINT16 passwordLength = Stream_Get_UINT16(s);
pduLength = Stream_GetPosition(s) + passwordLength;
}
@@ -987,10 +975,9 @@ static SSIZE_T rdstls_parse_pdu_data_type(wLog* log, UINT16 dataType, wStream* s
return 0;
Stream_Seek(s, 4);
UINT16 cookieLength = 0;
if (Stream_GetRemainingLength(s) < 2)
return 0;
Stream_Read_UINT16(s, cookieLength);
const UINT16 cookieLength = Stream_Get_UINT16(s);
pduLength = Stream_GetPosition(s) + cookieLength;
}
@@ -1011,20 +998,20 @@ SSIZE_T rdstls_parse_pdu(wLog* log, wStream* stream)
wStream sbuffer = { 0 };
wStream* s = Stream_StaticConstInit(&sbuffer, Stream_Buffer(stream), Stream_Length(stream));
UINT16 version = 0;
if (Stream_GetRemainingLength(s) < 2)
return 0;
Stream_Read_UINT16(s, version);
const UINT16 version = Stream_Get_UINT16(s);
if (version != RDSTLS_VERSION_1)
{
WLog_Print(log, WLOG_ERROR, "invalid RDSTLS version");
return -1;
}
UINT16 pduType = 0;
if (Stream_GetRemainingLength(s) < 2)
return 0;
Stream_Read_UINT16(s, pduType);
const UINT16 pduType = Stream_Get_UINT16(s);
switch (pduType)
{
case RDSTLS_TYPE_CAPABILITIES:
@@ -1033,8 +1020,8 @@ SSIZE_T rdstls_parse_pdu(wLog* log, wStream* stream)
case RDSTLS_TYPE_AUTHREQ:
if (Stream_GetRemainingLength(s) < 2)
return 0;
UINT16 dataType = 0;
Stream_Read_UINT16(s, dataType);
const UINT16 dataType = Stream_Get_UINT16(s);
pduLength = rdstls_parse_pdu_data_type(log, dataType, s);
break;