/** * FreeRDP: A Remote Desktop Protocol Implementation * RDSTLS Security protocol * * Copyright 2023 Joan Torres * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include "settings.h" #include #include #include #include #include #include #include "rdstls.h" #include "transport.h" #include "utils.h" #define RDSTLS_VERSION_1 0x01 #define RDSTLS_TYPE_CAPABILITIES 0x01 #define RDSTLS_TYPE_AUTHREQ 0x02 #define RDSTLS_TYPE_AUTHRSP 0x04 #define RDSTLS_DATA_CAPABILITIES 0x01 #define RDSTLS_DATA_PASSWORD_CREDS 0x01 #define RDSTLS_DATA_AUTORECONNECT_COOKIE 0x02 #define RDSTLS_DATA_RESULT_CODE 0x01 typedef enum { RDSTLS_STATE_INITIAL, RDSTLS_STATE_CAPABILITIES, RDSTLS_STATE_AUTH_REQ, RDSTLS_STATE_AUTH_RSP, RDSTLS_STATE_FINAL, } RDSTLS_STATE; typedef enum { RDSTLS_RESULT_SUCCESS = 0x00000000, RDSTLS_RESULT_ACCESS_DENIED = 0x00000005, RDSTLS_RESULT_LOGON_FAILURE = 0x0000052e, RDSTLS_RESULT_INVALID_LOGON_HOURS = 0x00000530, RDSTLS_RESULT_PASSWORD_EXPIRED = 0x00000532, RDSTLS_RESULT_ACCOUNT_DISABLED = 0x00000533, RDSTLS_RESULT_PASSWORD_MUST_CHANGE = 0x00000773, RDSTLS_RESULT_ACCOUNT_LOCKED_OUT = 0x00000775 } RDSTLS_RESULT_CODE; struct rdp_rdstls { BOOL server; RDSTLS_STATE state; rdpContext* context; rdpTransport* transport; RDSTLS_RESULT_CODE resultCode; wLog* log; }; static const char* rdstls_result_code_str(UINT32 resultCode) { switch (resultCode) { case RDSTLS_RESULT_SUCCESS: return "RDSTLS_RESULT_SUCCESS"; case RDSTLS_RESULT_ACCESS_DENIED: return "RDSTLS_RESULT_ACCESS_DENIED"; case RDSTLS_RESULT_LOGON_FAILURE: return "RDSTLS_RESULT_LOGON_FAILURE"; case RDSTLS_RESULT_INVALID_LOGON_HOURS: return "RDSTLS_RESULT_INVALID_LOGON_HOURS"; case RDSTLS_RESULT_PASSWORD_EXPIRED: return "RDSTLS_RESULT_PASSWORD_EXPIRED"; case RDSTLS_RESULT_ACCOUNT_DISABLED: return "RDSTLS_RESULT_ACCOUNT_DISABLED"; case RDSTLS_RESULT_PASSWORD_MUST_CHANGE: return "RDSTLS_RESULT_PASSWORD_MUST_CHANGE"; case RDSTLS_RESULT_ACCOUNT_LOCKED_OUT: return "RDSTLS_RESULT_ACCOUNT_LOCKED_OUT"; default: return "RDSTLS_RESULT_UNKNOWN"; } } /** * Create new RDSTLS state machine. * * @param context A pointer to the rdp context to use * * @return new RDSTLS state machine. */ rdpRdstls* rdstls_new(rdpContext* context, rdpTransport* transport) { WINPR_ASSERT(context); WINPR_ASSERT(transport); rdpSettings* settings = context->settings; WINPR_ASSERT(settings); rdpRdstls* rdstls = (rdpRdstls*)calloc(1, sizeof(rdpRdstls)); if (!rdstls) return nullptr; rdstls->log = WLog_Get(FREERDP_TAG("core.rdstls")); rdstls->context = context; rdstls->transport = transport; rdstls->server = settings->ServerMode; rdstls->state = RDSTLS_STATE_INITIAL; return rdstls; } /** * Free RDSTLS state machine. * @param rdstls The RDSTLS instance to free */ void rdstls_free(rdpRdstls* rdstls) { free(rdstls); } static const char* rdstls_get_state_str(RDSTLS_STATE state) { switch (state) { case RDSTLS_STATE_INITIAL: return "RDSTLS_STATE_INITIAL"; case RDSTLS_STATE_CAPABILITIES: return "RDSTLS_STATE_CAPABILITIES"; case RDSTLS_STATE_AUTH_REQ: return "RDSTLS_STATE_AUTH_REQ"; case RDSTLS_STATE_AUTH_RSP: return "RDSTLS_STATE_AUTH_RSP"; case RDSTLS_STATE_FINAL: return "RDSTLS_STATE_FINAL"; default: return "UNKNOWN"; } } static RDSTLS_STATE rdstls_get_state(rdpRdstls* rdstls) { WINPR_ASSERT(rdstls); return rdstls->state; } static BOOL check_transition(wLog* log, RDSTLS_STATE current, RDSTLS_STATE expected, RDSTLS_STATE requested) { if (requested != expected) { WLog_Print(log, WLOG_ERROR, "Unexpected rdstls state transition from %s [%u] to %s [%u], expected %s [%u]", rdstls_get_state_str(current), current, rdstls_get_state_str(requested), requested, rdstls_get_state_str(expected), expected); return FALSE; } return TRUE; } static BOOL rdstls_set_state(rdpRdstls* rdstls, RDSTLS_STATE state) { BOOL rc = FALSE; WINPR_ASSERT(rdstls); WLog_Print(rdstls->log, WLOG_DEBUG, "-- %s\t--> %s", rdstls_get_state_str(rdstls->state), rdstls_get_state_str(state)); switch (rdstls->state) { case RDSTLS_STATE_INITIAL: rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state); break; case RDSTLS_STATE_CAPABILITIES: rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_REQ, state); break; case RDSTLS_STATE_AUTH_REQ: rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_RSP, state); break; case RDSTLS_STATE_AUTH_RSP: rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_FINAL, state); break; case RDSTLS_STATE_FINAL: rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state); break; default: WLog_Print(rdstls->log, WLOG_ERROR, "Invalid rdstls state %s [%u], requested transition to %s [%u]", rdstls_get_state_str(rdstls->state), rdstls->state, rdstls_get_state_str(state), state); break; } if (rc) rdstls->state = state; return rc; } static BOOL rdstls_write_capabilities(WINPR_ATTR_UNUSED rdpRdstls* rdstls, wStream* s) { if (!Stream_EnsureRemainingCapacity(s, 6)) return FALSE; Stream_Write_UINT16(s, RDSTLS_TYPE_CAPABILITIES); Stream_Write_UINT16(s, RDSTLS_DATA_CAPABILITIES); Stream_Write_UINT16(s, RDSTLS_VERSION_1); return TRUE; } static SSIZE_T rdstls_write_string(wStream* s, const char* str) { const size_t pos = Stream_GetPosition(s); if (!Stream_EnsureRemainingCapacity(s, 2)) return -1; if (!str) { /* Write unicode null */ Stream_Write_UINT16(s, 2); if (!Stream_EnsureRemainingCapacity(s, 2)) return -1; Stream_Write_UINT16(s, 0); return (SSIZE_T)(Stream_GetPosition(s) - pos); } const size_t length = (strlen(str) + 1); Stream_Write_UINT16(s, (UINT16)length * sizeof(WCHAR)); if (!Stream_EnsureRemainingCapacity(s, length * sizeof(WCHAR))) return -1; if (Stream_Write_UTF16_String_From_UTF8(s, length, str, length, TRUE) < 0) return -1; return (SSIZE_T)(Stream_GetPosition(s) - pos); } static BOOL rdstls_write_data(wStream* s, UINT32 length, const BYTE* data) { WINPR_ASSERT(data || (length == 0)); if (!Stream_EnsureRemainingCapacity(s, 2) || (length > UINT16_MAX)) return FALSE; Stream_Write_UINT16(s, (UINT16)length); if (!Stream_EnsureRemainingCapacity(s, length)) return FALSE; Stream_Write(s, data, length); return TRUE; } static BOOL rdstls_write_cookie(wStream* s, const ARC_SC_PRIVATE_PACKET* cookie) { WINPR_ASSERT(cookie); const uint16_t length = 28; if (!Stream_EnsureRemainingCapacity(s, 2)) return FALSE; Stream_Write_UINT16(s, length); if (!Stream_EnsureRemainingCapacity(s, length)) return FALSE; Stream_Write_UINT32(s, cookie->cbLen); Stream_Write_UINT32(s, cookie->version); Stream_Write_UINT32(s, cookie->logonId); Stream_Write(s, cookie->arcRandomBits, sizeof(cookie->arcRandomBits)); return TRUE; } static BOOL rdstls_write_authentication_request_with_password(rdpRdstls* rdstls, wStream* s) { WINPR_ASSERT(rdstls); WINPR_ASSERT(rdstls->context); WLog_Print(rdstls->log, WLOG_DEBUG, "Writing RDSTLS password authentication message"); rdpSettings* settings = rdstls->context->settings; WINPR_ASSERT(settings); if (!Stream_EnsureRemainingCapacity(s, 4)) return FALSE; Stream_Write_UINT16(s, RDSTLS_TYPE_AUTHREQ); Stream_Write_UINT16(s, RDSTLS_DATA_PASSWORD_CREDS); if (!rdstls_write_data(s, settings->RedirectionGuidLength, settings->RedirectionGuid)) return FALSE; if (rdstls_write_string(s, settings->Username) < 0) return FALSE; if (rdstls_write_string(s, settings->Domain) < 0) return FALSE; if (!rdstls_write_data(s, settings->RedirectionPasswordLength, settings->RedirectionPassword)) return FALSE; return TRUE; } static BOOL rdstls_write_authentication_request_with_cookie(WINPR_ATTR_UNUSED rdpRdstls* rdstls, WINPR_ATTR_UNUSED wStream* s) { WINPR_ASSERT(rdstls); WINPR_ASSERT(rdstls->context); WLog_Print(rdstls->log, WLOG_DEBUG, "Writing RDSTLS cookie authentication message"); rdpSettings* settings = rdstls->context->settings; WINPR_ASSERT(settings); if (!Stream_EnsureRemainingCapacity(s, 8)) return FALSE; Stream_Write_UINT16(s, RDSTLS_TYPE_AUTHREQ); Stream_Write_UINT16(s, RDSTLS_DATA_AUTORECONNECT_COOKIE); Stream_Write_UINT32(s, settings->RedirectedSessionId); return (rdstls_write_cookie(s, settings->ServerAutoReconnectCookie)); } static BOOL rdstls_write_authentication_response(rdpRdstls* rdstls, wStream* s) { WINPR_ASSERT(rdstls); if (!Stream_EnsureRemainingCapacity(s, 8)) return FALSE; Stream_Write_UINT16(s, RDSTLS_TYPE_AUTHRSP); Stream_Write_UINT16(s, RDSTLS_DATA_RESULT_CODE); Stream_Write_UINT32(s, rdstls->resultCode); return TRUE; } static BOOL rdstls_process_capabilities(rdpRdstls* rdstls, wStream* s) { WINPR_ASSERT(rdstls); if (!Stream_CheckAndLogRequiredLengthWLog(rdstls->log, s, 4)) return FALSE; const UINT16 dataType = Stream_Get_UINT16(s); if (dataType != RDSTLS_DATA_CAPABILITIES) { WLog_Print(rdstls->log, WLOG_ERROR, "received invalid DataType=0x%04" PRIX16 ", expected 0x%04" PRIX32, dataType, WINPR_CXX_COMPAT_CAST(UINT32, RDSTLS_DATA_CAPABILITIES)); return FALSE; } const UINT16 supportedVersions = Stream_Get_UINT16(s); if ((supportedVersions & RDSTLS_VERSION_1) == 0) { WLog_Print(rdstls->log, WLOG_ERROR, "received invalid supportedVersions=0x%04" PRIX16 ", expected 0x%04" PRIX32, supportedVersions, WINPR_CXX_COMPAT_CAST(UINT32, RDSTLS_VERSION_1)); return FALSE; } return TRUE; } static BOOL rdstls_read_unicode_string(WINPR_ATTR_UNUSED wLog* log, wStream* s, char** str) { WINPR_ASSERT(str); if (!Stream_CheckAndLogRequiredLengthWLog(log, s, 2)) return FALSE; const UINT16 length = Stream_Get_UINT16(s); if (!Stream_CheckAndLogRequiredLengthWLog(log, s, length)) return FALSE; if (length <= 2) { Stream_Seek(s, length); return TRUE; } *str = Stream_Read_UTF16_String_As_UTF8(s, length / sizeof(WCHAR), nullptr); return (*str) != nullptr; } static BOOL rdstls_read_data(WINPR_ATTR_UNUSED wLog* log, wStream* s, UINT16* pLength, const BYTE** pData) { WINPR_ASSERT(pLength); WINPR_ASSERT(pData); *pData = nullptr; *pLength = 0; if (!Stream_CheckAndLogRequiredLengthWLog(log, s, 2)) return FALSE; const UINT16 length = Stream_Get_UINT16(s); if (!Stream_CheckAndLogRequiredLengthWLog(log, s, length)) return FALSE; if (length <= 2) { Stream_Seek(s, length); return TRUE; } *pData = Stream_ConstPointer(s); *pLength = length; Stream_Seek(s, length); return TRUE; } static BOOL rdstls_cmp_data(wLog* log, const char* field, const BYTE* serverData, const UINT32 serverDataLength, const BYTE* clientData, const UINT16 clientDataLength) { if (serverDataLength > 0) { if (clientDataLength == 0) { WLog_Print(log, WLOG_ERROR, "expected %s", field); return FALSE; } if (serverDataLength > UINT16_MAX || serverDataLength != clientDataLength || memcmp(serverData, clientData, serverDataLength) != 0) { WLog_Print(log, WLOG_ERROR, "%s verification failed", field); return FALSE; } } return TRUE; } static BOOL rdstls_cmp_str(wLog* log, const char* field, const char* serverStr, const char* clientStr) { if (!utils_str_is_empty(serverStr)) { if (utils_str_is_empty(clientStr)) { WLog_Print(log, WLOG_ERROR, "expected %s", field); return FALSE; } WINPR_ASSERT(serverStr); WINPR_ASSERT(clientStr); if (strcmp(serverStr, clientStr) != 0) { WLog_Print(log, WLOG_ERROR, "%s verification failed", field); return FALSE; } } return TRUE; } static BOOL rdstls_process_authentication_request_with_password(rdpRdstls* rdstls, wStream* s) { WINPR_ASSERT(rdstls); WINPR_ASSERT(rdstls->context); BOOL rc = FALSE; const BYTE* clientRedirectionGuid = nullptr; UINT16 clientRedirectionGuidLength = 0; char* clientPassword = nullptr; char* clientUsername = nullptr; char* clientDomain = nullptr; const rdpSettings* settings = rdstls->context->settings; WINPR_ASSERT(settings); if (!rdstls_read_data(rdstls->log, s, &clientRedirectionGuidLength, &clientRedirectionGuid)) goto fail; if (!rdstls_read_unicode_string(rdstls->log, s, &clientUsername)) goto fail; if (!rdstls_read_unicode_string(rdstls->log, s, &clientDomain)) goto fail; if (!rdstls_read_unicode_string(rdstls->log, s, &clientPassword)) goto fail; { const BYTE* serverRedirectionGuid = freerdp_settings_get_pointer(settings, FreeRDP_RedirectionGuid); const UINT32 serverRedirectionGuidLength = freerdp_settings_get_uint32(settings, FreeRDP_RedirectionGuidLength); const char* serverUsername = freerdp_settings_get_string(settings, FreeRDP_Username); const char* serverDomain = freerdp_settings_get_string(settings, FreeRDP_Domain); const char* serverPassword = freerdp_settings_get_string(settings, FreeRDP_Password); rdstls->resultCode = RDSTLS_RESULT_SUCCESS; if (!rdstls_cmp_data(rdstls->log, "RedirectionGuid", serverRedirectionGuid, serverRedirectionGuidLength, clientRedirectionGuid, clientRedirectionGuidLength)) rdstls->resultCode = RDSTLS_RESULT_ACCESS_DENIED; if (!rdstls_cmp_str(rdstls->log, "UserName", serverUsername, clientUsername)) rdstls->resultCode = RDSTLS_RESULT_LOGON_FAILURE; if (!rdstls_cmp_str(rdstls->log, "Domain", serverDomain, clientDomain)) rdstls->resultCode = RDSTLS_RESULT_LOGON_FAILURE; if (!rdstls_cmp_str(rdstls->log, "Password", serverPassword, clientPassword)) rdstls->resultCode = RDSTLS_RESULT_LOGON_FAILURE; } rc = TRUE; fail: return rc; } static BOOL rdstls_process_authentication_request_with_cookie(WINPR_ATTR_UNUSED rdpRdstls* rdstls, WINPR_ATTR_UNUSED wStream* s) { // TODO WLog_Print(rdstls->log, WLOG_ERROR, "TODO: RDSTLS Cookie authentication not implemented"); return FALSE; } static BOOL rdstls_process_authentication_request(rdpRdstls* rdstls, wStream* s) { if (!Stream_CheckAndLogRequiredLengthWLog(rdstls->log, s, 2)) return FALSE; const UINT16 dataType = Stream_Get_UINT16(s); switch (dataType) { case RDSTLS_DATA_PASSWORD_CREDS: if (!rdstls_process_authentication_request_with_password(rdstls, s)) return FALSE; break; case RDSTLS_DATA_AUTORECONNECT_COOKIE: if (!rdstls_process_authentication_request_with_cookie(rdstls, s)) return FALSE; break; default: WLog_Print(rdstls->log, WLOG_ERROR, "received invalid DataType=0x%04" PRIX16 ", expected 0x%04" PRIX32 " or 0x%04" PRIX32, dataType, WINPR_CXX_COMPAT_CAST(UINT32, RDSTLS_DATA_PASSWORD_CREDS), WINPR_CXX_COMPAT_CAST(UINT32, RDSTLS_DATA_AUTORECONNECT_COOKIE)); return FALSE; } return TRUE; } static BOOL rdstls_process_authentication_response(rdpRdstls* rdstls, wStream* s) { if (!Stream_CheckAndLogRequiredLengthWLog(rdstls->log, s, 6)) return FALSE; const UINT16 dataType = Stream_Get_UINT16(s); if (dataType != RDSTLS_DATA_RESULT_CODE) { WLog_Print(rdstls->log, WLOG_ERROR, "received invalid DataType=0x%04" PRIX16 ", expected 0x%04" PRIX32, dataType, WINPR_CXX_COMPAT_CAST(UINT32, RDSTLS_DATA_RESULT_CODE)); return FALSE; } const UINT32 resultCode = Stream_Get_UINT32(s); if (resultCode != RDSTLS_RESULT_SUCCESS) { WLog_Print(rdstls->log, WLOG_ERROR, "resultCode: %s [0x%08" PRIX32 "]", rdstls_result_code_str(resultCode), resultCode); UINT32 error = FREERDP_ERROR_CONNECT_UNDEFINED; switch (resultCode) { case RDSTLS_RESULT_ACCESS_DENIED: error = FREERDP_ERROR_CONNECT_ACCESS_DENIED; break; case RDSTLS_RESULT_ACCOUNT_DISABLED: error = FREERDP_ERROR_CONNECT_ACCOUNT_DISABLED; break; case RDSTLS_RESULT_ACCOUNT_LOCKED_OUT: error = FREERDP_ERROR_CONNECT_ACCOUNT_LOCKED_OUT; break; case RDSTLS_RESULT_LOGON_FAILURE: error = FREERDP_ERROR_CONNECT_LOGON_FAILURE; break; case RDSTLS_RESULT_INVALID_LOGON_HOURS: error = FREERDP_ERROR_CONNECT_ACCOUNT_RESTRICTION; break; case RDSTLS_RESULT_PASSWORD_EXPIRED: error = FREERDP_ERROR_CONNECT_PASSWORD_EXPIRED; break; case RDSTLS_RESULT_PASSWORD_MUST_CHANGE: error = FREERDP_ERROR_CONNECT_PASSWORD_MUST_CHANGE; break; default: WLog_Print(rdstls->log, WLOG_ERROR, "Unexpected resultCode: [0x%08" PRIX32 "], NTSTATUS=%s, Win32Error=%s", resultCode, GetSecurityStatusString((SECURITY_STATUS)resultCode), Win32ErrorCode2Tag(resultCode & 0xFFFF)); error = FREERDP_ERROR_CONNECT_UNDEFINED; break; } freerdp_set_last_error_if_not(rdstls->context, error); return FALSE; } return TRUE; } static BOOL rdstls_send(WINPR_ATTR_UNUSED rdpTransport* transport, wStream* s, void* extra) { rdpRdstls* rdstls = (rdpRdstls*)extra; rdpSettings* settings = nullptr; WINPR_ASSERT(transport); WINPR_ASSERT(s); WINPR_ASSERT(rdstls); settings = rdstls->context->settings; WINPR_ASSERT(settings); if (!Stream_EnsureRemainingCapacity(s, 2)) return FALSE; Stream_Write_UINT16(s, RDSTLS_VERSION_1); const RDSTLS_STATE state = rdstls_get_state(rdstls); switch (state) { case RDSTLS_STATE_CAPABILITIES: if (!rdstls_write_capabilities(rdstls, s)) return FALSE; break; case RDSTLS_STATE_AUTH_REQ: if (settings->RedirectionFlags & LB_PASSWORD_IS_PK_ENCRYPTED) { if (!rdstls_write_authentication_request_with_password(rdstls, s)) return FALSE; } else if (settings->ServerAutoReconnectCookie != nullptr) { if (!rdstls_write_authentication_request_with_cookie(rdstls, s)) return FALSE; } else { WLog_Print(rdstls->log, WLOG_ERROR, "cannot authenticate with password or auto-reconnect cookie"); return FALSE; } break; case RDSTLS_STATE_AUTH_RSP: if (!rdstls_write_authentication_response(rdstls, s)) return FALSE; break; default: WLog_Print(rdstls->log, WLOG_ERROR, "Invalid rdstls state %s [%" PRIu32 "]", rdstls_get_state_str(state), state); return FALSE; } return (transport_write(rdstls->transport, s) >= 0); } static int rdstls_recv(WINPR_ATTR_UNUSED rdpTransport* transport, wStream* s, void* extra) { rdpRdstls* rdstls = (rdpRdstls*)extra; WINPR_ASSERT(transport); WINPR_ASSERT(s); WINPR_ASSERT(rdstls); if (!Stream_CheckAndLogRequiredLengthWLog(rdstls->log, s, 4)) return -1; const UINT16 version = Stream_Get_UINT16(s); if (version != RDSTLS_VERSION_1) { WLog_Print(rdstls->log, WLOG_ERROR, "received invalid RDSTLS Version=0x%04" PRIX16 ", expected 0x%04" PRIX16, version, WINPR_CXX_COMPAT_CAST(UINT32, RDSTLS_VERSION_1)); return -1; } const UINT16 pduType = Stream_Get_UINT16(s); switch (pduType) { case RDSTLS_TYPE_CAPABILITIES: if (!rdstls_process_capabilities(rdstls, s)) return -1; break; case RDSTLS_TYPE_AUTHREQ: if (!rdstls_process_authentication_request(rdstls, s)) return -1; break; case RDSTLS_TYPE_AUTHRSP: if (!rdstls_process_authentication_response(rdstls, s)) return -1; break; default: WLog_Print(rdstls->log, WLOG_ERROR, "unknown RDSTLS PDU type [0x%04" PRIx16 "]", pduType); return -1; } return 1; } #define rdstls_check_state_requirements(rdstls, expected) \ rdstls_check_state_requirements_((rdstls), (expected), __FILE__, __func__, __LINE__) static BOOL rdstls_check_state_requirements_(rdpRdstls* rdstls, RDSTLS_STATE expected, const char* file, const char* fkt, size_t line) { const RDSTLS_STATE current = rdstls_get_state(rdstls); if (current == expected) return TRUE; WINPR_ASSERT(rdstls); const DWORD log_level = WLOG_ERROR; if (WLog_IsLevelActive(rdstls->log, log_level)) WLog_PrintTextMessage(rdstls->log, log_level, line, file, fkt, "Unexpected rdstls state %s [%u], expected %s [%u]", rdstls_get_state_str(current), current, rdstls_get_state_str(expected), expected); return FALSE; } static BOOL rdstls_send_capabilities(rdpRdstls* rdstls) { BOOL rc = FALSE; if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_CAPABILITIES)) return FALSE; wStream* s = Stream_New(nullptr, 512); if (!s) goto fail; WINPR_ASSERT(rdstls); if (!rdstls_send(rdstls->transport, s, rdstls)) goto fail; rc = rdstls_set_state(rdstls, RDSTLS_STATE_AUTH_REQ); fail: Stream_Free(s, TRUE); return rc; } static BOOL rdstls_recv_authentication_request(rdpRdstls* rdstls) { BOOL rc = FALSE; if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_REQ)) return FALSE; wStream* s = Stream_New(nullptr, 4096); if (!s) goto fail; WINPR_ASSERT(rdstls); { const int res = transport_read_pdu(rdstls->transport, s); if (res < 0) goto fail; } { const int status = rdstls_recv(rdstls->transport, s, rdstls); if (status < 0) goto fail; } rc = rdstls_set_state(rdstls, RDSTLS_STATE_AUTH_RSP); fail: Stream_Free(s, TRUE); return rc; } static BOOL rdstls_send_authentication_response(rdpRdstls* rdstls) { BOOL rc = FALSE; if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_RSP)) return FALSE; wStream* s = Stream_New(nullptr, 512); if (!s) goto fail; WINPR_ASSERT(rdstls); if (!rdstls_send(rdstls->transport, s, rdstls)) goto fail; rc = rdstls_set_state(rdstls, RDSTLS_STATE_FINAL); fail: Stream_Free(s, TRUE); return rc; } static BOOL rdstls_recv_capabilities(rdpRdstls* rdstls) { BOOL rc = FALSE; if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_CAPABILITIES)) return FALSE; wStream* s = Stream_New(nullptr, 512); if (!s) goto fail; WINPR_ASSERT(rdstls); { const int res = transport_read_pdu(rdstls->transport, s); if (res < 0) goto fail; } { const int status = rdstls_recv(rdstls->transport, s, rdstls); if (status < 0) goto fail; } rc = rdstls_set_state(rdstls, RDSTLS_STATE_AUTH_REQ); fail: Stream_Free(s, TRUE); return rc; } static BOOL rdstls_send_authentication_request(rdpRdstls* rdstls) { BOOL rc = FALSE; if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_REQ)) return FALSE; wStream* s = Stream_New(nullptr, 4096); if (!s) goto fail; WINPR_ASSERT(rdstls); if (!rdstls_send(rdstls->transport, s, rdstls)) goto fail; rc = rdstls_set_state(rdstls, RDSTLS_STATE_AUTH_RSP); fail: Stream_Free(s, TRUE); return rc; } static BOOL rdstls_recv_authentication_response(rdpRdstls* rdstls) { BOOL rc = FALSE; WINPR_ASSERT(rdstls); if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_RSP)) return FALSE; wStream* s = Stream_New(nullptr, 512); if (!s) goto fail; { const int res = transport_read_pdu(rdstls->transport, s); if (res < 0) goto fail; } { const int status = rdstls_recv(rdstls->transport, s, rdstls); if (status < 0) goto fail; } rc = rdstls_set_state(rdstls, RDSTLS_STATE_FINAL); fail: Stream_Free(s, TRUE); return rc; } static int rdstls_server_authenticate(rdpRdstls* rdstls) { WINPR_ASSERT(rdstls); if (!rdstls_set_state(rdstls, RDSTLS_STATE_CAPABILITIES)) return -1; if (!rdstls_send_capabilities(rdstls)) return -1; if (!rdstls_recv_authentication_request(rdstls)) return -1; if (!rdstls_send_authentication_response(rdstls)) return -1; if (rdstls->resultCode != RDSTLS_RESULT_SUCCESS) return -1; return 1; } static int rdstls_client_authenticate(rdpRdstls* rdstls) { if (!rdstls_set_state(rdstls, RDSTLS_STATE_CAPABILITIES)) return -1; if (!rdstls_recv_capabilities(rdstls)) return -1; if (!rdstls_send_authentication_request(rdstls)) return -1; if (!rdstls_recv_authentication_response(rdstls)) return -1; return 1; } /** * Authenticate using RDSTLS. * @param rdstls The RDSTLS instance to use * * @return 1 if authentication is successful */ int rdstls_authenticate(rdpRdstls* rdstls) { WINPR_ASSERT(rdstls); if (rdstls->server) return rdstls_server_authenticate(rdstls); else return rdstls_client_authenticate(rdstls); } static SSIZE_T rdstls_parse_pdu_data_type(wLog* log, UINT16 dataType, wStream* s) { size_t pduLength = 0; switch (dataType) { case RDSTLS_DATA_PASSWORD_CREDS: { if (Stream_GetRemainingLength(s) < 2) return 0; const UINT16 redirGuidLength = Stream_Get_UINT16(s); if (Stream_GetRemainingLength(s) < redirGuidLength) return 0; Stream_Seek(s, redirGuidLength); if (Stream_GetRemainingLength(s) < 2) return 0; const UINT16 usernameLength = Stream_Get_UINT16(s); if (Stream_GetRemainingLength(s) < usernameLength) return 0; Stream_Seek(s, usernameLength); if (Stream_GetRemainingLength(s) < 2) return 0; const UINT16 domainLength = Stream_Get_UINT16(s); if (Stream_GetRemainingLength(s) < domainLength) return 0; Stream_Seek(s, domainLength); if (Stream_GetRemainingLength(s) < 2) return 0; const UINT16 passwordLength = Stream_Get_UINT16(s); pduLength = Stream_GetPosition(s) + passwordLength; } break; case RDSTLS_DATA_AUTORECONNECT_COOKIE: { if (Stream_GetRemainingLength(s) < 4) return 0; Stream_Seek(s, 4); if (Stream_GetRemainingLength(s) < 2) return 0; const UINT16 cookieLength = Stream_Get_UINT16(s); pduLength = Stream_GetPosition(s) + cookieLength; } break; default: WLog_Print(log, WLOG_ERROR, "invalid RDSLTS dataType"); return -1; } if (pduLength > SSIZE_MAX) return 0; return (SSIZE_T)pduLength; } SSIZE_T rdstls_parse_pdu(wLog* log, wStream* stream) { SSIZE_T pduLength = -1; wStream sbuffer = WINPR_C_ARRAY_INIT; wStream* s = Stream_StaticConstInit(&sbuffer, Stream_Buffer(stream), Stream_Length(stream)); if (Stream_GetRemainingLength(s) < 2) return 0; const UINT16 version = Stream_Get_UINT16(s); if (version != RDSTLS_VERSION_1) { WLog_Print(log, WLOG_ERROR, "invalid RDSTLS version"); return -1; } if (Stream_GetRemainingLength(s) < 2) return 0; const UINT16 pduType = Stream_Get_UINT16(s); switch (pduType) { case RDSTLS_TYPE_CAPABILITIES: pduLength = 8; break; case RDSTLS_TYPE_AUTHREQ: { if (Stream_GetRemainingLength(s) < 2) return 0; const UINT16 dataType = Stream_Get_UINT16(s); pduLength = rdstls_parse_pdu_data_type(log, dataType, s); } break; case RDSTLS_TYPE_AUTHRSP: pduLength = 10; break; default: WLog_Print(log, WLOG_ERROR, "invalid RDSTLS PDU type"); return -1; } return pduLength; }