diff --git a/libfreerdp/core/capabilities.c b/libfreerdp/core/capabilities.c index 837974565..6b839c02d 100644 --- a/libfreerdp/core/capabilities.c +++ b/libfreerdp/core/capabilities.c @@ -4258,56 +4258,16 @@ BOOL rdp_recv_get_active_header(rdpRdp* rdp, wStream* s, UINT16* pChannelId, UIN return TRUE; } -BOOL rdp_recv_demand_active(rdpRdp* rdp, wStream* s) +BOOL rdp_recv_demand_active(rdpRdp* rdp, wStream* s, UINT16 pduSource, UINT16 length) { - UINT16 channelId; - UINT16 pduType; - UINT16 pduSource; - UINT16 length; - UINT16 lengthSourceDescriptor; - UINT16 lengthCombinedCapabilities; + UINT16 lengthSourceDescriptor = 0; + UINT16 lengthCombinedCapabilities = 0; WINPR_ASSERT(rdp); + WINPR_ASSERT(rdp->settings); WINPR_ASSERT(rdp->context); WINPR_ASSERT(s); - if (!rdp_recv_get_active_header(rdp, s, &channelId, &length)) - return FALSE; - - if (freerdp_shall_disconnect_context(rdp->context)) - return TRUE; - - if (!rdp_read_share_control_header(rdp, s, NULL, NULL, &pduType, &pduSource)) - return FALSE; - - if (pduType == PDU_TYPE_DATA) - { - /* - * We can receive a Save Session Info Data PDU containing a LogonErrorInfo - * structure at this point from the server to indicate a connection error. - */ - state_run_t rc = rdp_recv_data_pdu(rdp, s); - if (state_run_failed(rc)) - return FALSE; - - return FALSE; - } - - if (pduType != PDU_TYPE_DEMAND_ACTIVE) - { - if (pduType != PDU_TYPE_SERVER_REDIRECTION) - { - char buffer1[256] = { 0 }; - char buffer2[256] = { 0 }; - - WLog_ERR(TAG, "expected %s, got %s", - pdu_type_to_str(PDU_TYPE_DEMAND_ACTIVE, buffer1, sizeof(buffer1)), - pdu_type_to_str(pduType, buffer2, sizeof(buffer2))); - } - - return FALSE; - } - rdp->settings->PduSource = pduSource; if (!Stream_CheckAndLogRequiredLength(TAG, s, 8)) diff --git a/libfreerdp/core/capabilities.h b/libfreerdp/core/capabilities.h index 738ea37e2..b416a1640 100644 --- a/libfreerdp/core/capabilities.h +++ b/libfreerdp/core/capabilities.h @@ -159,7 +159,7 @@ FREERDP_LOCAL BOOL rdp_recv_get_active_header(rdpRdp* rdp, wStream* s, UINT16* pChannelId, UINT16* length); -FREERDP_LOCAL BOOL rdp_recv_demand_active(rdpRdp* rdp, wStream* s); +FREERDP_LOCAL BOOL rdp_recv_demand_active(rdpRdp* rdp, wStream* s, UINT16 pduSource, UINT16 length); FREERDP_LOCAL BOOL rdp_send_demand_active(rdpRdp* rdp); FREERDP_LOCAL BOOL rdp_recv_confirm_active(rdpRdp* rdp, wStream* s, UINT16 pduLength); FREERDP_LOCAL BOOL rdp_send_confirm_active(rdpRdp* rdp); diff --git a/libfreerdp/core/connection.c b/libfreerdp/core/connection.c index 508e6da2a..3e7b32add 100644 --- a/libfreerdp/core/connection.c +++ b/libfreerdp/core/connection.c @@ -1245,33 +1245,32 @@ state_run_t rdp_client_connect_license(rdpRdp* rdp, wStream* s) state_run_t rdp_client_connect_demand_active(rdpRdp* rdp, wStream* s) { - size_t pos; - UINT16 length; + UINT16 length = 0; + UINT16 channelId = 0; + UINT16 pduType = 0; + UINT16 pduSource = 0; WINPR_ASSERT(rdp); WINPR_ASSERT(s); WINPR_ASSERT(rdp->settings); - pos = Stream_GetPosition(s); + if (!rdp_recv_get_active_header(rdp, s, &channelId, &length)) + return STATE_RUN_FAILED; - if (!rdp_recv_demand_active(rdp, s)) + if (freerdp_shall_disconnect_context(rdp->context)) + return STATE_RUN_QUIT_SESSION; + + if (!rdp_read_share_control_header(rdp, s, NULL, NULL, &pduType, &pduSource)) + return STATE_RUN_FAILED; + + switch (pduType) { - state_run_t rc; - UINT16 channelId; - - Stream_SetPosition(s, pos); - if (!rdp_recv_get_active_header(rdp, s, &channelId, &length)) - return STATE_RUN_FAILED; - /* Was Stream_Seek(s, RDP_PACKET_HEADER_MAX_LENGTH); - * but the headers aren't always that length, - * so that could result in a bad offset. - */ - rc = rdp_recv_out_of_sequence_pdu(rdp, s); - if (state_run_failed(rc)) - return rc; - if (!tpkt_ensure_stream_consumed(s, length)) - return STATE_RUN_FAILED; - return rc; + case PDU_TYPE_DEMAND_ACTIVE: + if (!rdp_recv_demand_active(rdp, s, pduSource, length)) + return STATE_RUN_FAILED; + break; + default: + return rdp_recv_out_of_sequence_pdu(rdp, s, pduType, length); } return STATE_RUN_SUCCESS; diff --git a/libfreerdp/core/rdp.c b/libfreerdp/core/rdp.c index 58d771e72..a1fa16fa1 100644 --- a/libfreerdp/core/rdp.c +++ b/libfreerdp/core/rdp.c @@ -1340,34 +1340,40 @@ state_run_t rdp_recv_message_channel_pdu(rdpRdp* rdp, wStream* s, UINT16 securit return STATE_RUN_SUCCESS; } -state_run_t rdp_recv_out_of_sequence_pdu(rdpRdp* rdp, wStream* s) +state_run_t rdp_recv_out_of_sequence_pdu(rdpRdp* rdp, wStream* s, UINT16 pduType, UINT16 length) { - UINT16 type; - UINT16 length; - UINT16 channelId; - + state_run_t rc; WINPR_ASSERT(rdp); - if (!rdp_read_share_control_header(rdp, s, &length, NULL, &type, &channelId)) - return STATE_RUN_FAILED; + switch (pduType) + { + case PDU_TYPE_DATA: + rc = rdp_recv_data_pdu(rdp, s); + break; + case PDU_TYPE_SERVER_REDIRECTION: + rc = rdp_recv_enhanced_security_redirection_packet(rdp, s); + break; + case PDU_TYPE_FLOW_RESPONSE: + case PDU_TYPE_FLOW_STOP: + case PDU_TYPE_FLOW_TEST: + rc = STATE_RUN_SUCCESS; + break; + default: + { + char buffer1[256] = { 0 }; + char buffer2[256] = { 0 }; - if (type == PDU_TYPE_DATA) - { - return rdp_recv_data_pdu(rdp, s); + WLog_Print(rdp->log, WLOG_ERROR, "expected %s, got %s", + pdu_type_to_str(PDU_TYPE_DEMAND_ACTIVE, buffer1, sizeof(buffer1)), + pdu_type_to_str(pduType, buffer2, sizeof(buffer2))); + rc = STATE_RUN_FAILED; + } + break; } - else if (type == PDU_TYPE_SERVER_REDIRECTION) - { - return rdp_recv_enhanced_security_redirection_packet(rdp, s); - } - else if (type == PDU_TYPE_FLOW_RESPONSE || type == PDU_TYPE_FLOW_STOP || - type == PDU_TYPE_FLOW_TEST) - { - return STATE_RUN_SUCCESS; - } - else - { + + if (!tpkt_ensure_stream_consumed(s, length)) return STATE_RUN_FAILED; - } + return rc; } BOOL rdp_read_flow_control_pdu(rdpRdp* rdp, wStream* s, UINT16* type, UINT16* channel_id) diff --git a/libfreerdp/core/rdp.h b/libfreerdp/core/rdp.h index 67e7ed7db..d4f1e44b8 100644 --- a/libfreerdp/core/rdp.h +++ b/libfreerdp/core/rdp.h @@ -243,7 +243,8 @@ FREERDP_LOCAL BOOL rdp_send_message_channel_pdu(rdpRdp* rdp, wStream* s, UINT16 FREERDP_LOCAL state_run_t rdp_recv_message_channel_pdu(rdpRdp* rdp, wStream* s, UINT16 securityFlags); -FREERDP_LOCAL state_run_t rdp_recv_out_of_sequence_pdu(rdpRdp* rdp, wStream* s); +FREERDP_LOCAL state_run_t rdp_recv_out_of_sequence_pdu(rdpRdp* rdp, wStream* s, UINT16 pduType, + UINT16 length); FREERDP_LOCAL state_run_t rdp_recv_callback(rdpTransport* transport, wStream* s, void* extra);