diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index 330555ce5..f6c3f985a 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -691,7 +691,10 @@ static SSIZE_T rpc_client_default_out_channel_recv(rdpRpc* rpc) Stream_SetPosition(fragment, 0); /* Ignore errors, the PDU might not be complete. */ - (void)rts_read_common_pdu_header(fragment, &header, TRUE); + const rts_pdu_status_t rc = rts_read_common_pdu_header(fragment, &header, TRUE); + if (rc == RTS_PDU_FAIL) + return -1; + Stream_SetPosition(fragment, pos); if (header.frag_length > rpc->max_recv_frag) @@ -978,19 +981,19 @@ static void rpc_array_client_call_free(void* call) int rpc_in_channel_send_pdu(RpcInChannel* inChannel, const BYTE* buffer, size_t length) { - SSIZE_T status = 0; RpcClientCall* clientCall = NULL; - wStream s; + wStream s = Stream_Init(); rpcconn_common_hdr_t header = WINPR_C_ARRAY_INIT; - status = rpc_channel_write(&inChannel->common, buffer, length); + SSIZE_T status = rpc_channel_write(&inChannel->common, buffer, length); if (status <= 0) return -1; Stream_StaticConstInit(&s, buffer, length); - if (!rts_read_common_pdu_header(&s, &header, FALSE)) - return -1; + const rts_pdu_status_t rc = rts_read_common_pdu_header(&s, &header, FALSE); + if (rc != RTS_PDU_VALID) + return FALSE; clientCall = rpc_client_call_find_by_id(inChannel->common.client, header.call_id); if (!clientCall) diff --git a/libfreerdp/core/gateway/rts.c b/libfreerdp/core/gateway/rts.c index 958842a9a..130793acc 100644 --- a/libfreerdp/core/gateway/rts.c +++ b/libfreerdp/core/gateway/rts.c @@ -221,7 +221,8 @@ static BOOL rts_write_common_pdu_header(wStream* s, const rpcconn_common_hdr_t* return TRUE; } -BOOL rts_read_common_pdu_header(wStream* s, rpcconn_common_hdr_t* header, BOOL ignoreErrors) +rts_pdu_status_t rts_read_common_pdu_header(wStream* s, rpcconn_common_hdr_t* header, + BOOL ignoreErrors) { WINPR_ASSERT(s); WINPR_ASSERT(header); @@ -229,13 +230,13 @@ BOOL rts_read_common_pdu_header(wStream* s, rpcconn_common_hdr_t* header, BOOL i if (!ignoreErrors) { if (!Stream_CheckAndLogRequiredLength(TAG, s, sizeof(rpcconn_common_hdr_t))) - return FALSE; + return RTS_PDU_INCOMPLETE; } else { const size_t sz = Stream_GetRemainingLength(s); if (sz < sizeof(rpcconn_common_hdr_t)) - return FALSE; + return RTS_PDU_INCOMPLETE; } Stream_Read_UINT8(s, header->rpc_vers); @@ -252,22 +253,22 @@ BOOL rts_read_common_pdu_header(wStream* s, rpcconn_common_hdr_t* header, BOOL i if (!ignoreErrors) WLog_WARN(TAG, "Invalid header->frag_length of %" PRIu16 ", expected %" PRIuz, header->frag_length, sizeof(rpcconn_common_hdr_t)); - return FALSE; + return RTS_PDU_FAIL; } if (!ignoreErrors) { if (!Stream_CheckAndLogRequiredLength(TAG, s, header->frag_length - sizeof(rpcconn_common_hdr_t))) - return FALSE; + return RTS_PDU_INCOMPLETE; } else { const size_t sz2 = Stream_GetRemainingLength(s); if (sz2 < header->frag_length - sizeof(rpcconn_common_hdr_t)) - return FALSE; + return RTS_PDU_INCOMPLETE; } - return TRUE; + return RTS_PDU_VALID; } static BOOL rts_read_auth_verifier_no_checks(wStream* s, auth_verifier_co_t* auth, @@ -1189,7 +1190,8 @@ BOOL rts_read_pdu_header_ex(wStream* s, rpcconn_hdr_t* header, BOOL silent) WINPR_ASSERT(s); WINPR_ASSERT(header); - if (!rts_read_common_pdu_header(s, &header->common, silent)) + const rts_pdu_status_t status = rts_read_common_pdu_header(s, &header->common, silent); + if (status != RTS_PDU_VALID) return FALSE; WLog_DBG(TAG, "Reading PDU type %s", rts_pdu_ptype_to_string(header->common.ptype)); diff --git a/libfreerdp/core/gateway/rts.h b/libfreerdp/core/gateway/rts.h index e73e7fd66..b99153a33 100644 --- a/libfreerdp/core/gateway/rts.h +++ b/libfreerdp/core/gateway/rts.h @@ -93,9 +93,16 @@ FREERDP_LOCAL BOOL rts_read_pdu_header_ex(wStream* s, rpcconn_hdr_t* header, BOO FREERDP_LOCAL void rts_free_pdu_header(rpcconn_hdr_t* header, BOOL allocated); +typedef enum +{ + RTS_PDU_FAIL = -1, + RTS_PDU_INCOMPLETE = 0, + RTS_PDU_VALID = 1 +} rts_pdu_status_t; + WINPR_ATTR_NODISCARD -FREERDP_LOCAL BOOL rts_read_common_pdu_header(wStream* s, rpcconn_common_hdr_t* header, - BOOL ignoreErrors); +FREERDP_LOCAL rts_pdu_status_t rts_read_common_pdu_header(wStream* s, rpcconn_common_hdr_t* header, + BOOL ignoreErrors); WINPR_ATTR_NODISCARD FREERDP_LOCAL BOOL rts_command_length(UINT32 CommandType, wStream* s, size_t* length, BOOL silent);