mirror of
https://github.com/morgan9e/FreeRDP
synced 2026-04-15 00:44:19 +09:00
Merge pull request #10877 from akallabeth/websocket-simplify
[core,gateway] refactor websocket support
This commit is contained in:
@@ -159,7 +159,8 @@ static int freerdp_connect_begin(freerdp* instance)
|
||||
WINPR_ASSERT(context);
|
||||
freerdp_set_last_error_if_not(context, FREERDP_ERROR_PRE_CONNECT_FAILED);
|
||||
|
||||
WLog_Print(context->log, WLOG_ERROR, "freerdp_pre_connect failed");
|
||||
WLog_Print(context->log, WLOG_ERROR, "freerdp_pre_connect failed: %s",
|
||||
rdp_client_connection_state_string(instance->ConnectionCallbackState));
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -112,7 +112,7 @@ typedef struct
|
||||
union context
|
||||
{
|
||||
http_encoding_chunked_context chunked;
|
||||
websocket_context websocket;
|
||||
websocket_context* websocket;
|
||||
} context;
|
||||
} rdg_http_encoding_context;
|
||||
|
||||
@@ -327,11 +327,8 @@ static BOOL rdg_write_chunked(BIO* bio, wStream* sPacket)
|
||||
static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket)
|
||||
{
|
||||
if (rdg->transferEncoding.isWebsocketTransport)
|
||||
{
|
||||
if (rdg->transferEncoding.context.websocket.closeSent)
|
||||
return FALSE;
|
||||
return websocket_write_wstream(rdg->tlsOut->bio, sPacket, WebsocketBinaryOpcode);
|
||||
}
|
||||
return websocket_context_write_wstream(rdg->transferEncoding.context.websocket,
|
||||
rdg->tlsOut->bio, sPacket, WebsocketBinaryOpcode);
|
||||
|
||||
return rdg_write_chunked(rdg->tlsIn->bio, sPacket);
|
||||
}
|
||||
@@ -344,9 +341,7 @@ static int rdg_socket_read(BIO* bio, BYTE* pBuffer, size_t size,
|
||||
return -1;
|
||||
|
||||
if (encodingContext->isWebsocketTransport)
|
||||
{
|
||||
return websocket_read(bio, pBuffer, size, &encodingContext->context.websocket);
|
||||
}
|
||||
return websocket_context_read(encodingContext->context.websocket, bio, pBuffer, size);
|
||||
|
||||
switch (encodingContext->httpTransferEncoding)
|
||||
{
|
||||
@@ -1470,9 +1465,11 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
|
||||
}
|
||||
return FALSE;
|
||||
}
|
||||
|
||||
rdg->transferEncoding.isWebsocketTransport = TRUE;
|
||||
rdg->transferEncoding.context.websocket.state = WebsocketStateOpcodeAndFin;
|
||||
rdg->transferEncoding.context.websocket.responseStreamBuffer = NULL;
|
||||
if (!websocket_context_reset(rdg->transferEncoding.context.websocket))
|
||||
return FALSE;
|
||||
|
||||
if (rdg->extAuth == HTTP_EXTENDED_AUTH_SSPI_NTLM)
|
||||
{
|
||||
/* create a new auth context for SSPI_NTLM. This must be done after the last
|
||||
@@ -1607,86 +1604,36 @@ BOOL rdg_connect(rdpRdg* rdg, DWORD timeout, BOOL* rpcFallback)
|
||||
|
||||
static int rdg_write_websocket_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
|
||||
{
|
||||
size_t fullLen = 0;
|
||||
int status = 0;
|
||||
wStream* sWS = NULL;
|
||||
|
||||
uint32_t maskingKey = 0;
|
||||
BYTE* maskingKeyByte1 = (BYTE*)&maskingKey;
|
||||
BYTE* maskingKeyByte2 = maskingKeyByte1 + 1;
|
||||
BYTE* maskingKeyByte3 = maskingKeyByte1 + 2;
|
||||
BYTE* maskingKeyByte4 = maskingKeyByte1 + 3;
|
||||
|
||||
winpr_RAND(&maskingKey, 4);
|
||||
|
||||
WINPR_ASSERT(rdg);
|
||||
if (isize < 0)
|
||||
return -1;
|
||||
|
||||
const size_t payloadSize = (size_t)isize + 10;
|
||||
union
|
||||
{
|
||||
UINT32 u32;
|
||||
UINT8 u8[4];
|
||||
} maskingKey;
|
||||
|
||||
if (payloadSize < 1)
|
||||
return 0;
|
||||
|
||||
if (payloadSize < 126)
|
||||
fullLen = payloadSize + 6; /* 2 byte "mini header" + 4 byte masking key */
|
||||
else if (payloadSize < 0x10000)
|
||||
fullLen = payloadSize + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
|
||||
else
|
||||
fullLen = payloadSize + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
|
||||
|
||||
sWS = Stream_New(NULL, fullLen);
|
||||
wStream* sWS =
|
||||
websocket_context_packet_new(payloadSize, WebsocketBinaryOpcode, &maskingKey.u32);
|
||||
if (!sWS)
|
||||
return FALSE;
|
||||
|
||||
Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | WebsocketBinaryOpcode);
|
||||
if (payloadSize < 126)
|
||||
Stream_Write_UINT8(sWS, (UINT8)payloadSize | WEBSOCKET_MASK_BIT);
|
||||
else if (payloadSize < 0x10000)
|
||||
{
|
||||
Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
|
||||
Stream_Write_UINT16_BE(sWS, (UINT16)payloadSize);
|
||||
}
|
||||
else
|
||||
{
|
||||
Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
|
||||
/* biggest packet possible is 0xffff + 0xa, so 32bit is always enough */
|
||||
Stream_Write_UINT32_BE(sWS, 0);
|
||||
Stream_Write_UINT32_BE(sWS, (UINT32)payloadSize);
|
||||
}
|
||||
Stream_Write_UINT32(sWS, maskingKey);
|
||||
|
||||
Stream_Write_UINT16(sWS, PKT_TYPE_DATA ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Type */
|
||||
Stream_Write_UINT16(sWS, 0 ^ (*maskingKeyByte3 | *maskingKeyByte4 << 8)); /* Reserved */
|
||||
Stream_Write_UINT32(sWS, (UINT32)payloadSize ^ maskingKey); /* Packet length */
|
||||
Stream_Write_UINT16(sWS, PKT_TYPE_DATA ^ (maskingKey.u8[0] | maskingKey.u8[1] << 8)); /* Type */
|
||||
Stream_Write_UINT16(sWS, 0 ^ (maskingKey.u8[2] | maskingKey.u8[3] << 8)); /* Reserved */
|
||||
Stream_Write_UINT32(sWS, (UINT32)payloadSize ^ maskingKey.u32); /* Packet length */
|
||||
Stream_Write_UINT16(sWS,
|
||||
(UINT16)isize ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Data size */
|
||||
(UINT16)isize ^ (maskingKey.u8[0] | maskingKey.u8[1] << 8)); /* Data size */
|
||||
|
||||
/* masking key is now off by 2 bytes. fix that */
|
||||
maskingKey = (maskingKey & 0xffff) << 16 | (maskingKey >> 16);
|
||||
maskingKey.u32 = (maskingKey.u32 & 0xffff) << 16 | (maskingKey.u32 >> 16);
|
||||
|
||||
/* mask as much as possible with 32bit access */
|
||||
size_t streamPos = 0;
|
||||
for (; streamPos + 4 <= (size_t)isize; streamPos += 4)
|
||||
{
|
||||
uint32_t masked = *((const uint32_t*)(buf + streamPos)) ^ maskingKey;
|
||||
Stream_Write_UINT32(sWS, masked);
|
||||
}
|
||||
|
||||
/* mask the rest byte by byte */
|
||||
for (; streamPos < (size_t)isize; streamPos++)
|
||||
{
|
||||
BYTE* partialMask = (BYTE*)(&maskingKey) + streamPos % 4;
|
||||
BYTE masked = *((buf + streamPos)) ^ *partialMask;
|
||||
Stream_Write_UINT8(sWS, masked);
|
||||
}
|
||||
|
||||
Stream_SealLength(sWS);
|
||||
|
||||
status = freerdp_tls_write_all(rdg->tlsOut, Stream_Buffer(sWS), Stream_Length(sWS));
|
||||
Stream_Free(sWS, TRUE);
|
||||
|
||||
if (status < 0)
|
||||
return status;
|
||||
WINPR_ASSERT(rdg->tlsOut);
|
||||
wStream sPacket = { 0 };
|
||||
Stream_StaticConstInit(&sPacket, buf, (size_t)isize);
|
||||
if (!websocket_context_mask_and_send(rdg->tlsOut->bio, sWS, &sPacket, maskingKey.u32))
|
||||
return -1;
|
||||
|
||||
return isize;
|
||||
}
|
||||
@@ -1733,12 +1680,9 @@ static int rdg_write_chunked_data_packet(rdpRdg* rdg, const BYTE* buf, int isize
|
||||
|
||||
static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
|
||||
{
|
||||
WINPR_ASSERT(rdg);
|
||||
if (rdg->transferEncoding.isWebsocketTransport)
|
||||
{
|
||||
if (rdg->transferEncoding.context.websocket.closeSent == TRUE)
|
||||
return -1;
|
||||
return rdg_write_websocket_data_packet(rdg, buf, isize);
|
||||
}
|
||||
else
|
||||
return rdg_write_chunked_data_packet(rdg, buf, isize);
|
||||
}
|
||||
@@ -2198,15 +2142,13 @@ static BIO_METHOD* BIO_s_rdg(void)
|
||||
|
||||
rdpRdg* rdg_new(rdpContext* context)
|
||||
{
|
||||
rdpRdg* rdg = NULL;
|
||||
|
||||
if (!context)
|
||||
return NULL;
|
||||
|
||||
rdg = (rdpRdg*)calloc(1, sizeof(rdpRdg));
|
||||
rdpRdg* rdg = (rdpRdg*)calloc(1, sizeof(rdpRdg));
|
||||
if (!rdg)
|
||||
return NULL;
|
||||
|
||||
if (rdg)
|
||||
{
|
||||
rdg->log = WLog_Get(TAG);
|
||||
rdg->state = RDG_CLIENT_STATE_INITIAL;
|
||||
rdg->context = context;
|
||||
@@ -2244,8 +2186,8 @@ rdpRdg* rdg_new(rdpContext* context)
|
||||
!http_context_set_rdg_connection_id(rdg->http, &rdg->guid) ||
|
||||
!http_context_set_rdg_correlation_id(rdg->http, &rdg->guid) ||
|
||||
!http_context_enable_websocket_upgrade(
|
||||
rdg->http, freerdp_settings_get_bool(rdg->context->settings,
|
||||
FreeRDP_GatewayHttpUseWebsockets)))
|
||||
rdg->http,
|
||||
freerdp_settings_get_bool(rdg->context->settings, FreeRDP_GatewayHttpUseWebsockets)))
|
||||
{
|
||||
goto rdg_alloc_error;
|
||||
}
|
||||
@@ -2282,7 +2224,10 @@ rdpRdg* rdg_new(rdpContext* context)
|
||||
|
||||
rdg->transferEncoding.httpTransferEncoding = TransferEncodingIdentity;
|
||||
rdg->transferEncoding.isWebsocketTransport = FALSE;
|
||||
}
|
||||
|
||||
rdg->transferEncoding.context.websocket = websocket_context_new();
|
||||
if (!rdg->transferEncoding.context.websocket)
|
||||
goto rdg_alloc_error;
|
||||
|
||||
return rdg;
|
||||
rdg_alloc_error:
|
||||
@@ -2308,14 +2253,10 @@ void rdg_free(rdpRdg* rdg)
|
||||
|
||||
DeleteCriticalSection(&rdg->writeSection);
|
||||
|
||||
if (rdg->transferEncoding.isWebsocketTransport)
|
||||
{
|
||||
if (rdg->transferEncoding.context.websocket.responseStreamBuffer != NULL)
|
||||
Stream_Free(rdg->transferEncoding.context.websocket.responseStreamBuffer, TRUE);
|
||||
}
|
||||
|
||||
smartcardCertInfo_Free(rdg->smartcard);
|
||||
|
||||
websocket_context_free(rdg->transferEncoding.context.websocket);
|
||||
|
||||
free(rdg);
|
||||
}
|
||||
|
||||
|
||||
@@ -23,25 +23,67 @@
|
||||
|
||||
#define TAG FREERDP_TAG("core.gateway.websocket")
|
||||
|
||||
BOOL websocket_write_wstream(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode)
|
||||
struct s_websocket_context
|
||||
{
|
||||
size_t fullLen = 0;
|
||||
int status = 0;
|
||||
wStream* sWS = NULL;
|
||||
size_t payloadLength;
|
||||
uint32_t maskingKey;
|
||||
BOOL masking;
|
||||
BOOL closeSent;
|
||||
BYTE opcode;
|
||||
BYTE fragmentOriginalOpcode;
|
||||
BYTE lengthAndMaskPosition;
|
||||
WEBSOCKET_STATE state;
|
||||
wStream* responseStreamBuffer;
|
||||
};
|
||||
|
||||
uint32_t maskingKey = 0;
|
||||
static int websocket_write_all(BIO* bio, const BYTE* data, size_t length);
|
||||
|
||||
size_t streamPos = 0;
|
||||
BOOL websocket_context_mask_and_send(BIO* bio, wStream* sPacket, wStream* sDataPacket,
|
||||
UINT32 maskingKey)
|
||||
{
|
||||
const size_t len = Stream_Length(sDataPacket);
|
||||
Stream_SetPosition(sDataPacket, 0);
|
||||
|
||||
WINPR_ASSERT(bio);
|
||||
WINPR_ASSERT(sPacket);
|
||||
|
||||
const size_t len = Stream_Length(sPacket);
|
||||
Stream_SetPosition(sPacket, 0);
|
||||
|
||||
if (len > INT_MAX)
|
||||
if (!Stream_EnsureRemainingCapacity(sPacket, len))
|
||||
return FALSE;
|
||||
|
||||
/* mask as much as possible with 32bit access */
|
||||
size_t streamPos = 0;
|
||||
for (; streamPos + 4 <= len; streamPos += 4)
|
||||
{
|
||||
const uint32_t data = Stream_Get_UINT32(sDataPacket);
|
||||
Stream_Write_UINT32(sPacket, data ^ maskingKey);
|
||||
}
|
||||
|
||||
/* mask the rest byte by byte */
|
||||
for (; streamPos < len; streamPos++)
|
||||
{
|
||||
BYTE data = 0;
|
||||
BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
|
||||
Stream_Read_UINT8(sDataPacket, data);
|
||||
Stream_Write_UINT8(sPacket, data ^ *partialMask);
|
||||
}
|
||||
|
||||
Stream_SealLength(sPacket);
|
||||
|
||||
ERR_clear_error();
|
||||
const size_t size = Stream_Length(sPacket);
|
||||
const int status = websocket_write_all(bio, Stream_Buffer(sPacket), size);
|
||||
Stream_Free(sPacket, TRUE);
|
||||
|
||||
if ((status < 0) || ((size_t)status != size))
|
||||
return FALSE;
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
wStream* websocket_context_packet_new(size_t len, WEBSOCKET_OPCODE opcode, UINT32* pMaskingKey)
|
||||
{
|
||||
WINPR_ASSERT(pMaskingKey);
|
||||
if (len > INT_MAX)
|
||||
return NULL;
|
||||
|
||||
size_t fullLen = 0;
|
||||
if (len < 126)
|
||||
fullLen = len + 6; /* 2 byte "mini header" + 4 byte masking key */
|
||||
else if (len < 0x10000)
|
||||
@@ -49,13 +91,14 @@ BOOL websocket_write_wstream(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode
|
||||
else
|
||||
fullLen = len + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
|
||||
|
||||
sWS = Stream_New(NULL, fullLen);
|
||||
wStream* sWS = Stream_New(NULL, fullLen);
|
||||
if (!sWS)
|
||||
return FALSE;
|
||||
return NULL;
|
||||
|
||||
UINT32 maskingKey = 0;
|
||||
winpr_RAND(&maskingKey, sizeof(maskingKey));
|
||||
|
||||
Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | opcode);
|
||||
Stream_Write_UINT8(sWS, (UINT8)(WEBSOCKET_FIN_BIT | opcode));
|
||||
if (len < 126)
|
||||
Stream_Write_UINT8(sWS, (UINT8)len | WEBSOCKET_MASK_BIT);
|
||||
else if (len < 0x10000)
|
||||
@@ -70,39 +113,34 @@ BOOL websocket_write_wstream(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode
|
||||
Stream_Write_UINT32_BE(sWS, (UINT32)len);
|
||||
}
|
||||
Stream_Write_UINT32(sWS, maskingKey);
|
||||
|
||||
/* mask as much as possible with 32bit access */
|
||||
for (streamPos = 0; streamPos + 4 <= len; streamPos += 4)
|
||||
{
|
||||
uint32_t data = 0;
|
||||
Stream_Read_UINT32(sPacket, data);
|
||||
Stream_Write_UINT32(sWS, data ^ maskingKey);
|
||||
*pMaskingKey = maskingKey;
|
||||
return sWS;
|
||||
}
|
||||
|
||||
/* mask the rest byte by byte */
|
||||
for (; streamPos < len; streamPos++)
|
||||
BOOL websocket_context_write_wstream(websocket_context* context, BIO* bio, wStream* sPacket,
|
||||
WEBSOCKET_OPCODE opcode)
|
||||
{
|
||||
BYTE data = 0;
|
||||
BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
|
||||
Stream_Read_UINT8(sPacket, data);
|
||||
Stream_Write_UINT8(sWS, data ^ *partialMask);
|
||||
}
|
||||
WINPR_ASSERT(context);
|
||||
|
||||
Stream_SealLength(sWS);
|
||||
|
||||
ERR_clear_error();
|
||||
const size_t size = Stream_Length(sWS);
|
||||
if (size <= INT32_MAX)
|
||||
status = BIO_write(bio, Stream_Buffer(sWS), (int)size);
|
||||
Stream_Free(sWS, TRUE);
|
||||
|
||||
if (status != (SSIZE_T)fullLen)
|
||||
if (context->closeSent)
|
||||
return FALSE;
|
||||
|
||||
return TRUE;
|
||||
if (opcode == WebsocketCloseOpcode)
|
||||
context->closeSent = TRUE;
|
||||
|
||||
WINPR_ASSERT(bio);
|
||||
WINPR_ASSERT(sPacket);
|
||||
|
||||
const size_t len = Stream_Length(sPacket);
|
||||
uint32_t maskingKey = 0;
|
||||
wStream* sWS = websocket_context_packet_new(len, opcode, &maskingKey);
|
||||
if (!sWS)
|
||||
return FALSE;
|
||||
|
||||
return websocket_context_mask_and_send(bio, sWS, sPacket, maskingKey);
|
||||
}
|
||||
|
||||
static int websocket_write_all(BIO* bio, const BYTE* data, size_t length)
|
||||
int websocket_write_all(BIO* bio, const BYTE* data, size_t length)
|
||||
{
|
||||
WINPR_ASSERT(bio);
|
||||
WINPR_ASSERT(data);
|
||||
@@ -140,75 +178,19 @@ static int websocket_write_all(BIO* bio, const BYTE* data, size_t length)
|
||||
return (int)length;
|
||||
}
|
||||
|
||||
int websocket_write(BIO* bio, const BYTE* buf, int isize, WEBSOCKET_OPCODE opcode)
|
||||
int websocket_context_write(websocket_context* context, BIO* bio, const BYTE* buf, int isize,
|
||||
WEBSOCKET_OPCODE opcode)
|
||||
{
|
||||
size_t fullLen = 0;
|
||||
int status = 0;
|
||||
wStream* sWS = NULL;
|
||||
|
||||
uint32_t maskingKey = 0;
|
||||
|
||||
WINPR_ASSERT(bio);
|
||||
WINPR_ASSERT(buf);
|
||||
|
||||
winpr_RAND(&maskingKey, sizeof(maskingKey));
|
||||
|
||||
if (isize < 0)
|
||||
return -1;
|
||||
|
||||
const size_t payloadSize = (size_t)isize;
|
||||
if (payloadSize < 126)
|
||||
fullLen = payloadSize + 6; /* 2 byte "mini header" + 4 byte masking key */
|
||||
else if (payloadSize < 0x10000)
|
||||
fullLen = payloadSize + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
|
||||
else
|
||||
fullLen = payloadSize + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
|
||||
|
||||
sWS = Stream_New(NULL, fullLen);
|
||||
if (!sWS)
|
||||
return FALSE;
|
||||
|
||||
Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | opcode);
|
||||
if (payloadSize < 126)
|
||||
Stream_Write_UINT8(sWS, (UINT8)payloadSize | WEBSOCKET_MASK_BIT);
|
||||
else if (payloadSize < 0x10000)
|
||||
{
|
||||
Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
|
||||
Stream_Write_UINT16_BE(sWS, (UINT16)payloadSize);
|
||||
}
|
||||
else
|
||||
{
|
||||
Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
|
||||
/* biggest packet possible is 0xffff + 0xa, so 32bit is always enough */
|
||||
Stream_Write_UINT32_BE(sWS, 0);
|
||||
Stream_Write_UINT32_BE(sWS, (UINT32)payloadSize);
|
||||
}
|
||||
Stream_Write_UINT32(sWS, maskingKey);
|
||||
|
||||
/* mask as much as possible with 32bit access */
|
||||
size_t streamPos = 0;
|
||||
for (; streamPos + 4 <= payloadSize; streamPos += 4)
|
||||
{
|
||||
uint32_t masked = *((const uint32_t*)(buf + streamPos)) ^ maskingKey;
|
||||
Stream_Write_UINT32(sWS, masked);
|
||||
}
|
||||
|
||||
/* mask the rest byte by byte */
|
||||
for (; streamPos < payloadSize; streamPos++)
|
||||
{
|
||||
BYTE* partialMask = (BYTE*)(&maskingKey) + streamPos % 4;
|
||||
BYTE masked = *((buf + streamPos)) ^ *partialMask;
|
||||
Stream_Write_UINT8(sWS, masked);
|
||||
}
|
||||
|
||||
Stream_SealLength(sWS);
|
||||
|
||||
status = websocket_write_all(bio, Stream_Buffer(sWS), Stream_Length(sWS));
|
||||
Stream_Free(sWS, TRUE);
|
||||
|
||||
if (status < 0)
|
||||
return status;
|
||||
|
||||
wStream sbuffer = { 0 };
|
||||
wStream* s = Stream_StaticConstInit(&sbuffer, buf, (size_t)isize);
|
||||
if (!websocket_context_write_wstream(context, bio, s, opcode))
|
||||
return -2;
|
||||
return isize;
|
||||
}
|
||||
|
||||
@@ -234,10 +216,10 @@ static int websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size,
|
||||
|
||||
ERR_clear_error();
|
||||
status = BIO_read(bio, pBuffer, (int)rlen);
|
||||
if (status <= 0)
|
||||
if ((status <= 0) || ((size_t)status > encodingContext->payloadLength))
|
||||
return status;
|
||||
|
||||
encodingContext->payloadLength -= status;
|
||||
encodingContext->payloadLength -= (size_t)status;
|
||||
|
||||
if (encodingContext->payloadLength == 0)
|
||||
encodingContext->state = WebsocketStateOpcodeAndFin;
|
||||
@@ -245,47 +227,21 @@ static int websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size,
|
||||
return status;
|
||||
}
|
||||
|
||||
static int websocket_read_discard(BIO* bio, websocket_context* encodingContext)
|
||||
static int websocket_read_wstream(BIO* bio, websocket_context* encodingContext)
|
||||
{
|
||||
char _dummy[256] = { 0 };
|
||||
int status = 0;
|
||||
|
||||
WINPR_ASSERT(bio);
|
||||
WINPR_ASSERT(encodingContext);
|
||||
|
||||
if (encodingContext->payloadLength == 0)
|
||||
{
|
||||
encodingContext->state = WebsocketStateOpcodeAndFin;
|
||||
return 0;
|
||||
}
|
||||
|
||||
ERR_clear_error();
|
||||
status = BIO_read(bio, _dummy, sizeof(_dummy));
|
||||
if (status <= 0)
|
||||
return status;
|
||||
|
||||
encodingContext->payloadLength -= status;
|
||||
|
||||
if (encodingContext->payloadLength == 0)
|
||||
encodingContext->state = WebsocketStateOpcodeAndFin;
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
static int websocket_read_wstream(BIO* bio, wStream* s, websocket_context* encodingContext)
|
||||
{
|
||||
int status = 0;
|
||||
|
||||
WINPR_ASSERT(bio);
|
||||
wStream* s = encodingContext->responseStreamBuffer;
|
||||
WINPR_ASSERT(s);
|
||||
WINPR_ASSERT(encodingContext);
|
||||
|
||||
if (encodingContext->payloadLength == 0)
|
||||
{
|
||||
encodingContext->state = WebsocketStateOpcodeAndFin;
|
||||
return 0;
|
||||
}
|
||||
if (Stream_GetRemainingCapacity(s) != encodingContext->payloadLength)
|
||||
|
||||
if (!Stream_EnsureRemainingCapacity(s, encodingContext->payloadLength))
|
||||
{
|
||||
WLog_WARN(TAG,
|
||||
"wStream::capacity [%" PRIuz "] != encodingContext::paylaodLangth [%" PRIuz "]",
|
||||
@@ -293,111 +249,33 @@ static int websocket_read_wstream(BIO* bio, wStream* s, websocket_context* encod
|
||||
return -1;
|
||||
}
|
||||
|
||||
const size_t rlen = encodingContext->payloadLength;
|
||||
if (rlen > INT32_MAX)
|
||||
const int status = websocket_read_data(bio, Stream_Pointer(s), Stream_GetRemainingCapacity(s),
|
||||
encodingContext);
|
||||
if (status < 0)
|
||||
return status;
|
||||
|
||||
if (!Stream_SafeSeek(s, (size_t)status))
|
||||
return -1;
|
||||
|
||||
ERR_clear_error();
|
||||
status = BIO_read(bio, Stream_Pointer(s), (int)rlen);
|
||||
if (status <= 0)
|
||||
return status;
|
||||
|
||||
Stream_Seek(s, status);
|
||||
|
||||
encodingContext->payloadLength -= status;
|
||||
|
||||
if (encodingContext->payloadLength == 0)
|
||||
{
|
||||
encodingContext->state = WebsocketStateOpcodeAndFin;
|
||||
Stream_SealLength(s);
|
||||
Stream_SetPosition(s, 0);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
static BOOL websocket_reply_close(BIO* bio, wStream* s)
|
||||
static BOOL websocket_reply_close(BIO* bio, websocket_context* context, wStream* s)
|
||||
{
|
||||
/* write back close */
|
||||
wStream* closeFrame = NULL;
|
||||
uint16_t maskingKey1 = 0;
|
||||
uint16_t maskingKey2 = 0;
|
||||
size_t closeDataLen = 0;
|
||||
|
||||
WINPR_ASSERT(bio);
|
||||
|
||||
closeDataLen = 0;
|
||||
if (s != NULL && Stream_Length(s) >= 2)
|
||||
closeDataLen = 2;
|
||||
|
||||
closeFrame = Stream_New(NULL, 6 + closeDataLen);
|
||||
if (!closeFrame)
|
||||
return FALSE;
|
||||
|
||||
Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketCloseOpcode);
|
||||
Stream_Write_UINT8(closeFrame, closeDataLen | WEBSOCKET_MASK_BIT); /* no payload */
|
||||
winpr_RAND(&maskingKey1, sizeof(maskingKey1));
|
||||
winpr_RAND(&maskingKey2, sizeof(maskingKey2));
|
||||
Stream_Write_UINT16(closeFrame, maskingKey1);
|
||||
Stream_Write_UINT16(closeFrame, maskingKey2); /* unused half, max 2 bytes of data */
|
||||
|
||||
if (closeDataLen == 2)
|
||||
{
|
||||
uint16_t data = 0;
|
||||
Stream_Read_UINT16(s, data);
|
||||
Stream_Write_UINT16(closeFrame, data ^ maskingKey1);
|
||||
}
|
||||
Stream_SealLength(closeFrame);
|
||||
|
||||
const size_t rlen = Stream_Length(closeFrame);
|
||||
|
||||
int status = -1;
|
||||
if (rlen <= INT32_MAX)
|
||||
{
|
||||
ERR_clear_error();
|
||||
status = BIO_write(bio, Stream_Buffer(closeFrame), (int)rlen);
|
||||
}
|
||||
Stream_Free(closeFrame, TRUE);
|
||||
|
||||
/* server MUST close socket now. The server is not allowed anymore to
|
||||
* send frames but if he does, nothing bad would happen */
|
||||
if (status < 0)
|
||||
return FALSE;
|
||||
return TRUE;
|
||||
return websocket_context_write_wstream(context, bio, s, WebsocketCloseOpcode);
|
||||
}
|
||||
|
||||
static BOOL websocket_reply_pong(BIO* bio, wStream* s)
|
||||
static BOOL websocket_reply_pong(BIO* bio, websocket_context* context, wStream* s)
|
||||
{
|
||||
wStream* closeFrame = NULL;
|
||||
uint32_t maskingKey = 0;
|
||||
|
||||
WINPR_ASSERT(bio);
|
||||
WINPR_ASSERT(s);
|
||||
|
||||
if (s != NULL)
|
||||
return websocket_write_wstream(bio, s, WebsocketPongOpcode);
|
||||
if (Stream_GetPosition(s) != 0)
|
||||
return websocket_context_write_wstream(context, bio, s, WebsocketPongOpcode);
|
||||
|
||||
closeFrame = Stream_New(NULL, 6);
|
||||
if (!closeFrame)
|
||||
return FALSE;
|
||||
|
||||
Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketPongOpcode);
|
||||
Stream_Write_UINT8(closeFrame, 0 | WEBSOCKET_MASK_BIT); /* no payload */
|
||||
winpr_RAND(&maskingKey, sizeof(maskingKey));
|
||||
Stream_Write_UINT32(closeFrame, maskingKey); /* dummy masking key. */
|
||||
Stream_SealLength(closeFrame);
|
||||
|
||||
const size_t rlen = Stream_Length(closeFrame);
|
||||
int status = -1;
|
||||
if (rlen <= INT32_MAX)
|
||||
{
|
||||
ERR_clear_error();
|
||||
status = BIO_write(bio, Stream_Buffer(closeFrame), (int)rlen);
|
||||
}
|
||||
Stream_Free(closeFrame, TRUE);
|
||||
|
||||
if (status < 0)
|
||||
return FALSE;
|
||||
return TRUE;
|
||||
return websocket_reply_close(bio, context, NULL);
|
||||
}
|
||||
|
||||
static int websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size,
|
||||
@@ -409,7 +287,7 @@ static int websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size,
|
||||
WINPR_ASSERT(pBuffer);
|
||||
WINPR_ASSERT(encodingContext);
|
||||
|
||||
BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
|
||||
const BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
|
||||
? encodingContext->fragmentOriginalOpcode & 0xf
|
||||
: encodingContext->opcode & 0xf);
|
||||
|
||||
@@ -425,60 +303,55 @@ static int websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size,
|
||||
}
|
||||
case WebsocketPingOpcode:
|
||||
{
|
||||
if (encodingContext->responseStreamBuffer == NULL)
|
||||
encodingContext->responseStreamBuffer =
|
||||
Stream_New(NULL, encodingContext->payloadLength);
|
||||
|
||||
status =
|
||||
websocket_read_wstream(bio, encodingContext->responseStreamBuffer, encodingContext);
|
||||
status = websocket_read_wstream(bio, encodingContext);
|
||||
if (status < 0)
|
||||
return status;
|
||||
|
||||
if (encodingContext->payloadLength == 0)
|
||||
{
|
||||
if (!encodingContext->closeSent)
|
||||
websocket_reply_pong(bio, encodingContext->responseStreamBuffer);
|
||||
|
||||
Stream_Free(encodingContext->responseStreamBuffer, TRUE);
|
||||
encodingContext->responseStreamBuffer = NULL;
|
||||
websocket_reply_pong(bio, encodingContext, encodingContext->responseStreamBuffer);
|
||||
Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case WebsocketPongOpcode:
|
||||
{
|
||||
status = websocket_read_wstream(bio, encodingContext);
|
||||
if (status < 0)
|
||||
return status;
|
||||
/* We don´t care about pong response data, discard. */
|
||||
Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
|
||||
}
|
||||
break;
|
||||
case WebsocketCloseOpcode:
|
||||
{
|
||||
if (encodingContext->responseStreamBuffer == NULL)
|
||||
encodingContext->responseStreamBuffer =
|
||||
Stream_New(NULL, encodingContext->payloadLength);
|
||||
|
||||
status =
|
||||
websocket_read_wstream(bio, encodingContext->responseStreamBuffer, encodingContext);
|
||||
status = websocket_read_wstream(bio, encodingContext);
|
||||
if (status < 0)
|
||||
return status;
|
||||
|
||||
if (encodingContext->payloadLength == 0)
|
||||
{
|
||||
websocket_reply_close(bio, encodingContext->responseStreamBuffer);
|
||||
websocket_reply_close(bio, encodingContext, encodingContext->responseStreamBuffer);
|
||||
encodingContext->closeSent = TRUE;
|
||||
|
||||
if (encodingContext->responseStreamBuffer)
|
||||
Stream_Free(encodingContext->responseStreamBuffer, TRUE);
|
||||
encodingContext->responseStreamBuffer = NULL;
|
||||
Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
WLog_WARN(TAG, "Unimplemented websocket opcode %x. Dropping", effectiveOpcode & 0xf);
|
||||
WLog_WARN(TAG, "Unimplemented websocket opcode %" PRIx8 ". Dropping", effectiveOpcode);
|
||||
|
||||
status = websocket_read_discard(bio, encodingContext);
|
||||
status = websocket_read_wstream(bio, encodingContext);
|
||||
if (status < 0)
|
||||
return status;
|
||||
Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
|
||||
break;
|
||||
}
|
||||
/* return how many bytes have been written to pBuffer.
|
||||
* Only WebsocketBinaryOpcode writes into it and it returns directly */
|
||||
return 0;
|
||||
}
|
||||
|
||||
int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* encodingContext)
|
||||
int websocket_context_read(websocket_context* encodingContext, BIO* bio, BYTE* pBuffer, size_t size)
|
||||
{
|
||||
int status = 0;
|
||||
int effectiveDataLen = 0;
|
||||
@@ -493,7 +366,8 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco
|
||||
{
|
||||
case WebsocketStateOpcodeAndFin:
|
||||
{
|
||||
BYTE buffer[1];
|
||||
BYTE buffer[1] = { 0 };
|
||||
|
||||
ERR_clear_error();
|
||||
status = BIO_read(bio, (char*)buffer, sizeof(buffer));
|
||||
if (status <= 0)
|
||||
@@ -508,8 +382,8 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco
|
||||
break;
|
||||
case WebsocketStateLengthAndMasking:
|
||||
{
|
||||
BYTE buffer[1];
|
||||
BYTE len = 0;
|
||||
BYTE buffer[1] = { 0 };
|
||||
|
||||
ERR_clear_error();
|
||||
status = BIO_read(bio, (char*)buffer, sizeof(buffer));
|
||||
if (status <= 0)
|
||||
@@ -518,7 +392,7 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco
|
||||
encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT);
|
||||
encodingContext->lengthAndMaskPosition = 0;
|
||||
encodingContext->payloadLength = 0;
|
||||
len = buffer[0] & 0x7f;
|
||||
const BYTE len = buffer[0] & 0x7f;
|
||||
if (len < 126)
|
||||
{
|
||||
encodingContext->payloadLength = len;
|
||||
@@ -534,8 +408,9 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco
|
||||
case WebsocketStateShortLength:
|
||||
case WebsocketStateLongLength:
|
||||
{
|
||||
BYTE buffer[1];
|
||||
BYTE lenLength = (encodingContext->state == WebsocketStateShortLength ? 2 : 8);
|
||||
BYTE buffer[1] = { 0 };
|
||||
const BYTE lenLength =
|
||||
(encodingContext->state == WebsocketStateShortLength ? 2 : 8);
|
||||
while (encodingContext->lengthAndMaskPosition < lenLength)
|
||||
{
|
||||
ERR_clear_error();
|
||||
@@ -577,3 +452,39 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco
|
||||
}
|
||||
/* should be unreachable */
|
||||
}
|
||||
|
||||
websocket_context* websocket_context_new(void)
|
||||
{
|
||||
websocket_context* context = calloc(1, sizeof(websocket_context));
|
||||
if (!context)
|
||||
goto fail;
|
||||
|
||||
context->responseStreamBuffer = Stream_New(NULL, 1024);
|
||||
if (!context->responseStreamBuffer)
|
||||
goto fail;
|
||||
|
||||
if (!websocket_context_reset(context))
|
||||
goto fail;
|
||||
|
||||
return context;
|
||||
fail:
|
||||
websocket_context_free(context);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
void websocket_context_free(websocket_context* context)
|
||||
{
|
||||
if (!context)
|
||||
return;
|
||||
|
||||
Stream_Free(context->responseStreamBuffer, TRUE);
|
||||
free(context);
|
||||
}
|
||||
|
||||
BOOL websocket_context_reset(websocket_context* context)
|
||||
{
|
||||
WINPR_ASSERT(context);
|
||||
|
||||
context->state = WebsocketStateOpcodeAndFin;
|
||||
return Stream_SetPosition(context->responseStreamBuffer, 0);
|
||||
}
|
||||
|
||||
@@ -50,22 +50,27 @@ typedef enum
|
||||
WebSocketStatePayload,
|
||||
} WEBSOCKET_STATE;
|
||||
|
||||
typedef struct
|
||||
{
|
||||
size_t payloadLength;
|
||||
uint32_t maskingKey;
|
||||
BOOL masking;
|
||||
BOOL closeSent;
|
||||
BYTE opcode;
|
||||
BYTE fragmentOriginalOpcode;
|
||||
BYTE lengthAndMaskPosition;
|
||||
WEBSOCKET_STATE state;
|
||||
wStream* responseStreamBuffer;
|
||||
} websocket_context;
|
||||
typedef struct s_websocket_context websocket_context;
|
||||
|
||||
FREERDP_LOCAL BOOL websocket_write_wstream(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode);
|
||||
FREERDP_LOCAL int websocket_write(BIO* bio, const BYTE* buf, int isize, WEBSOCKET_OPCODE opcode);
|
||||
FREERDP_LOCAL int websocket_read(BIO* bio, BYTE* pBuffer, size_t size,
|
||||
websocket_context* encodingContext);
|
||||
FREERDP_LOCAL void websocket_context_free(websocket_context* context);
|
||||
|
||||
WINPR_ATTR_MALLOC(websocket_context_free, 1)
|
||||
FREERDP_LOCAL websocket_context* websocket_context_new(void);
|
||||
|
||||
FREERDP_LOCAL BOOL websocket_context_reset(websocket_context* context);
|
||||
|
||||
FREERDP_LOCAL BOOL websocket_context_write_wstream(websocket_context* context, BIO* bio,
|
||||
wStream* sPacket, WEBSOCKET_OPCODE opcode);
|
||||
FREERDP_LOCAL int websocket_context_write(websocket_context* context, BIO* bio, const BYTE* buf,
|
||||
int isize, WEBSOCKET_OPCODE opcode);
|
||||
FREERDP_LOCAL int websocket_context_read(websocket_context* encodingContext, BIO* bio,
|
||||
BYTE* pBuffer, size_t size);
|
||||
|
||||
WINPR_ATTR_MALLOC(Stream_Free, 1)
|
||||
FREERDP_LOCAL wStream* websocket_context_packet_new(size_t len, WEBSOCKET_OPCODE opcode,
|
||||
UINT32* pMaskingKey);
|
||||
|
||||
FREERDP_LOCAL BOOL websocket_context_mask_and_send(BIO* bio, wStream* sPacket, wStream* sDataPacket,
|
||||
UINT32 maskingKey);
|
||||
|
||||
#endif /* FREERDP_LIB_CORE_GATEWAY_WEBSOCKET_H */
|
||||
|
||||
@@ -65,7 +65,7 @@ struct rdp_wst
|
||||
char* gwhostname;
|
||||
uint16_t gwport;
|
||||
char* gwpath;
|
||||
websocket_context wscontext;
|
||||
websocket_context* wscontext;
|
||||
};
|
||||
|
||||
static const char arm_query_param[] = "%s%cClmTk=Bearer%%20%s&X-MS-User-Agent=FreeRDP%%2F3.0";
|
||||
@@ -494,11 +494,7 @@ BOOL wst_connect(rdpWst* wst, DWORD timeout)
|
||||
return FALSE;
|
||||
|
||||
if (isWebsocket)
|
||||
{
|
||||
wst->wscontext.state = WebsocketStateOpcodeAndFin;
|
||||
wst->wscontext.responseStreamBuffer = NULL;
|
||||
return TRUE;
|
||||
}
|
||||
return websocket_context_reset(wst->wscontext);
|
||||
else
|
||||
{
|
||||
char buffer[64] = { 0 };
|
||||
@@ -537,7 +533,8 @@ static int wst_bio_write(BIO* bio, const char* buf, int num)
|
||||
WINPR_ASSERT(wst);
|
||||
BIO_clear_flags(bio, BIO_FLAGS_WRITE);
|
||||
EnterCriticalSection(&wst->writeSection);
|
||||
status = websocket_write(wst->tls->bio, (const BYTE*)buf, num, WebsocketBinaryOpcode);
|
||||
status = websocket_context_write(wst->wscontext, wst->tls->bio, (const BYTE*)buf, num,
|
||||
WebsocketBinaryOpcode);
|
||||
LeaveCriticalSection(&wst->writeSection);
|
||||
|
||||
if (status < 0)
|
||||
@@ -570,7 +567,7 @@ static int wst_bio_read(BIO* bio, char* buf, int size)
|
||||
|
||||
while (status <= 0)
|
||||
{
|
||||
status = websocket_read(wst->tls->bio, (BYTE*)buf, (size_t)size, &wst->wscontext);
|
||||
status = websocket_context_read(wst->wscontext, wst->tls->bio, (BYTE*)buf, (size_t)size);
|
||||
if (status <= 0)
|
||||
{
|
||||
if (!BIO_should_retry(wst->tls->bio))
|
||||
@@ -781,15 +778,13 @@ static BOOL wst_parse_url(rdpWst* wst, const char* url)
|
||||
|
||||
rdpWst* wst_new(rdpContext* context)
|
||||
{
|
||||
rdpWst* wst = NULL;
|
||||
|
||||
if (!context)
|
||||
return NULL;
|
||||
|
||||
wst = (rdpWst*)calloc(1, sizeof(rdpWst));
|
||||
rdpWst* wst = (rdpWst*)calloc(1, sizeof(rdpWst));
|
||||
if (!wst)
|
||||
return NULL;
|
||||
|
||||
if (wst)
|
||||
{
|
||||
wst->context = context;
|
||||
|
||||
wst->gwhostname = NULL;
|
||||
@@ -831,7 +826,10 @@ rdpWst* wst_new(rdpContext* context)
|
||||
wst->auth = credssp_auth_new(context);
|
||||
if (!wst->auth)
|
||||
goto wst_alloc_error;
|
||||
}
|
||||
|
||||
wst->wscontext = websocket_context_new();
|
||||
if (!wst->wscontext)
|
||||
goto wst_alloc_error;
|
||||
|
||||
return wst;
|
||||
wst_alloc_error:
|
||||
@@ -858,8 +856,7 @@ void wst_free(rdpWst* wst)
|
||||
|
||||
DeleteCriticalSection(&wst->writeSection);
|
||||
|
||||
if (wst->wscontext.responseStreamBuffer != NULL)
|
||||
Stream_Free(wst->wscontext.responseStreamBuffer, TRUE);
|
||||
websocket_context_free(wst->wscontext);
|
||||
|
||||
free(wst);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user