diff --git a/libfreerdp/core/gateway/rdg.c b/libfreerdp/core/gateway/rdg.c old mode 100644 new mode 100755 index 03bbe28b4..ce288f337 --- a/libfreerdp/core/gateway/rdg.c +++ b/libfreerdp/core/gateway/rdg.c @@ -75,57 +75,61 @@ static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket) return TRUE; } -static wStream* rdg_receive_packet(rdpRdg* rdg) +static BOOL rdg_read_all(rdpTls* tls, BYTE* buffer, int size) { int status; + int readCount = 0; + + while (readCount < size) + { + status = BIO_read(tls->bio, buffer, size - readCount); + + if (status <= 0) + { + if (!BIO_should_retry(tls->bio)) + return FALSE; + + continue; + } + + readCount += status; + } + return TRUE; +} + +static wStream* rdg_receive_packet(rdpRdg* rdg) +{ wStream* s; - RdgPacketHeader* packet; - UINT32 readCount = 0; + size_t packetLength; + s = Stream_New(NULL, 1024); if (!s) return NULL; - packet = (RdgPacketHeader*) Stream_Buffer(s); - - while (readCount < sizeof(RdgPacketHeader)) + if (!rdg_read_all(rdg->tlsOut, Stream_Buffer(s), sizeof(RdgPacketHeader))) { - status = BIO_read(rdg->tlsOut->bio, Stream_Pointer(s), sizeof(RdgPacketHeader) - readCount); - - if (status < 0) - { - continue; - } - - readCount += status; - Stream_Seek(s, readCount); + Stream_Free(s, TRUE); + return NULL; } - if (Stream_Capacity(s) < packet->packetLength) - { - if (!Stream_EnsureCapacity(s, packet->packetLength)) - { - Stream_Free(s, TRUE); - return NULL; - } + packetLength = ((RdgPacketHeader*)Stream_Buffer(s))->packetLength; - packet = (RdgPacketHeader*) Stream_Buffer(s); + if (!Stream_EnsureCapacity(s, packetLength)) + { + Stream_Free(s, TRUE); + return NULL; } - while (readCount < packet->packetLength) + if (!rdg_read_all(rdg->tlsOut, Stream_Buffer(s) + sizeof(RdgPacketHeader), + packetLength - sizeof(RdgPacketHeader))) { - status = BIO_read(rdg->tlsOut->bio, Stream_Pointer(s), packet->packetLength - readCount); - - if (status < 0) - { - continue; - } - - readCount += status; - Stream_Seek(s, readCount); + Stream_Free(s, TRUE); + return NULL; } - Stream_SealLength(s); + Stream_SetLength(s, packetLength); + return s; }