[core,gateway] refactor websocket support

* Unify read/write functions
* Unify websocket_context setup/teardown/reset
This commit is contained in:
Armin Novak
2024-11-20 13:33:57 +01:00
committed by akallabeth
parent 2c461d0ea9
commit 4cbdd2c176
4 changed files with 358 additions and 504 deletions

View File

@@ -112,7 +112,7 @@ typedef struct
union context
{
http_encoding_chunked_context chunked;
websocket_context websocket;
websocket_context* websocket;
} context;
} rdg_http_encoding_context;
@@ -327,11 +327,8 @@ static BOOL rdg_write_chunked(BIO* bio, wStream* sPacket)
static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket)
{
if (rdg->transferEncoding.isWebsocketTransport)
{
if (rdg->transferEncoding.context.websocket.closeSent)
return FALSE;
return websocket_write_wstream(rdg->tlsOut->bio, sPacket, WebsocketBinaryOpcode);
}
return websocket_context_write_wstream(rdg->transferEncoding.context.websocket,
rdg->tlsOut->bio, sPacket, WebsocketBinaryOpcode);
return rdg_write_chunked(rdg->tlsIn->bio, sPacket);
}
@@ -344,9 +341,7 @@ static int rdg_socket_read(BIO* bio, BYTE* pBuffer, size_t size,
return -1;
if (encodingContext->isWebsocketTransport)
{
return websocket_read(bio, pBuffer, size, &encodingContext->context.websocket);
}
return websocket_context_read(encodingContext->context.websocket, bio, pBuffer, size);
switch (encodingContext->httpTransferEncoding)
{
@@ -1470,9 +1465,11 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
}
return FALSE;
}
rdg->transferEncoding.isWebsocketTransport = TRUE;
rdg->transferEncoding.context.websocket.state = WebsocketStateOpcodeAndFin;
rdg->transferEncoding.context.websocket.responseStreamBuffer = NULL;
if (!websocket_context_reset(rdg->transferEncoding.context.websocket))
return FALSE;
if (rdg->extAuth == HTTP_EXTENDED_AUTH_SSPI_NTLM)
{
/* create a new auth context for SSPI_NTLM. This must be done after the last
@@ -1607,86 +1604,36 @@ BOOL rdg_connect(rdpRdg* rdg, DWORD timeout, BOOL* rpcFallback)
static int rdg_write_websocket_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
{
size_t fullLen = 0;
int status = 0;
wStream* sWS = NULL;
uint32_t maskingKey = 0;
BYTE* maskingKeyByte1 = (BYTE*)&maskingKey;
BYTE* maskingKeyByte2 = maskingKeyByte1 + 1;
BYTE* maskingKeyByte3 = maskingKeyByte1 + 2;
BYTE* maskingKeyByte4 = maskingKeyByte1 + 3;
winpr_RAND(&maskingKey, 4);
WINPR_ASSERT(rdg);
if (isize < 0)
return -1;
const size_t payloadSize = (size_t)isize + 10;
union
{
UINT32 u32;
UINT8 u8[4];
} maskingKey;
if (payloadSize < 1)
return 0;
if (payloadSize < 126)
fullLen = payloadSize + 6; /* 2 byte "mini header" + 4 byte masking key */
else if (payloadSize < 0x10000)
fullLen = payloadSize + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
else
fullLen = payloadSize + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
sWS = Stream_New(NULL, fullLen);
wStream* sWS =
websocket_context_packet_new(payloadSize, WebsocketBinaryOpcode, &maskingKey.u32);
if (!sWS)
return FALSE;
Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | WebsocketBinaryOpcode);
if (payloadSize < 126)
Stream_Write_UINT8(sWS, (UINT8)payloadSize | WEBSOCKET_MASK_BIT);
else if (payloadSize < 0x10000)
{
Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
Stream_Write_UINT16_BE(sWS, (UINT16)payloadSize);
}
else
{
Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
/* biggest packet possible is 0xffff + 0xa, so 32bit is always enough */
Stream_Write_UINT32_BE(sWS, 0);
Stream_Write_UINT32_BE(sWS, (UINT32)payloadSize);
}
Stream_Write_UINT32(sWS, maskingKey);
Stream_Write_UINT16(sWS, PKT_TYPE_DATA ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Type */
Stream_Write_UINT16(sWS, 0 ^ (*maskingKeyByte3 | *maskingKeyByte4 << 8)); /* Reserved */
Stream_Write_UINT32(sWS, (UINT32)payloadSize ^ maskingKey); /* Packet length */
Stream_Write_UINT16(sWS, PKT_TYPE_DATA ^ (maskingKey.u8[0] | maskingKey.u8[1] << 8)); /* Type */
Stream_Write_UINT16(sWS, 0 ^ (maskingKey.u8[2] | maskingKey.u8[3] << 8)); /* Reserved */
Stream_Write_UINT32(sWS, (UINT32)payloadSize ^ maskingKey.u32); /* Packet length */
Stream_Write_UINT16(sWS,
(UINT16)isize ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Data size */
(UINT16)isize ^ (maskingKey.u8[0] | maskingKey.u8[1] << 8)); /* Data size */
/* masking key is now off by 2 bytes. fix that */
maskingKey = (maskingKey & 0xffff) << 16 | (maskingKey >> 16);
maskingKey.u32 = (maskingKey.u32 & 0xffff) << 16 | (maskingKey.u32 >> 16);
/* mask as much as possible with 32bit access */
size_t streamPos = 0;
for (; streamPos + 4 <= (size_t)isize; streamPos += 4)
{
uint32_t masked = *((const uint32_t*)(buf + streamPos)) ^ maskingKey;
Stream_Write_UINT32(sWS, masked);
}
/* mask the rest byte by byte */
for (; streamPos < (size_t)isize; streamPos++)
{
BYTE* partialMask = (BYTE*)(&maskingKey) + streamPos % 4;
BYTE masked = *((buf + streamPos)) ^ *partialMask;
Stream_Write_UINT8(sWS, masked);
}
Stream_SealLength(sWS);
status = freerdp_tls_write_all(rdg->tlsOut, Stream_Buffer(sWS), Stream_Length(sWS));
Stream_Free(sWS, TRUE);
if (status < 0)
return status;
WINPR_ASSERT(rdg->tlsOut);
wStream sPacket = { 0 };
Stream_StaticConstInit(&sPacket, buf, (size_t)isize);
if (!websocket_context_mask_and_send(rdg->tlsOut->bio, sWS, &sPacket, maskingKey.u32))
return -1;
return isize;
}
@@ -1733,12 +1680,9 @@ static int rdg_write_chunked_data_packet(rdpRdg* rdg, const BYTE* buf, int isize
static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
{
WINPR_ASSERT(rdg);
if (rdg->transferEncoding.isWebsocketTransport)
{
if (rdg->transferEncoding.context.websocket.closeSent == TRUE)
return -1;
return rdg_write_websocket_data_packet(rdg, buf, isize);
}
else
return rdg_write_chunked_data_packet(rdg, buf, isize);
}
@@ -2198,92 +2142,93 @@ static BIO_METHOD* BIO_s_rdg(void)
rdpRdg* rdg_new(rdpContext* context)
{
rdpRdg* rdg = NULL;
if (!context)
return NULL;
rdg = (rdpRdg*)calloc(1, sizeof(rdpRdg));
rdpRdg* rdg = (rdpRdg*)calloc(1, sizeof(rdpRdg));
if (!rdg)
return NULL;
if (rdg)
rdg->log = WLog_Get(TAG);
rdg->state = RDG_CLIENT_STATE_INITIAL;
rdg->context = context;
rdg->extAuth =
(rdg->context->settings->GatewayHttpExtAuthSspiNtlm ? HTTP_EXTENDED_AUTH_SSPI_NTLM
: HTTP_EXTENDED_AUTH_NONE);
if (rdg->context->settings->GatewayAccessToken)
rdg->extAuth = HTTP_EXTENDED_AUTH_PAA;
UuidCreate(&rdg->guid);
rdg->tlsOut = freerdp_tls_new(rdg->context);
if (!rdg->tlsOut)
goto rdg_alloc_error;
rdg->tlsIn = freerdp_tls_new(rdg->context);
if (!rdg->tlsIn)
goto rdg_alloc_error;
rdg->http = http_context_new();
if (!rdg->http)
goto rdg_alloc_error;
if (!http_context_set_uri(rdg->http, "/remoteDesktopGateway/") ||
!http_context_set_accept(rdg->http, "*/*") ||
!http_context_set_cache_control(rdg->http, "no-cache") ||
!http_context_set_pragma(rdg->http, "no-cache") ||
!http_context_set_connection(rdg->http, "Keep-Alive") ||
!http_context_set_user_agent(rdg->http, "MS-RDGateway/1.0") ||
!http_context_set_host(rdg->http, rdg->context->settings->GatewayHostname) ||
!http_context_set_rdg_connection_id(rdg->http, &rdg->guid) ||
!http_context_set_rdg_correlation_id(rdg->http, &rdg->guid) ||
!http_context_enable_websocket_upgrade(
rdg->http,
freerdp_settings_get_bool(rdg->context->settings, FreeRDP_GatewayHttpUseWebsockets)))
{
rdg->log = WLog_Get(TAG);
rdg->state = RDG_CLIENT_STATE_INITIAL;
rdg->context = context;
rdg->extAuth =
(rdg->context->settings->GatewayHttpExtAuthSspiNtlm ? HTTP_EXTENDED_AUTH_SSPI_NTLM
: HTTP_EXTENDED_AUTH_NONE);
if (rdg->context->settings->GatewayAccessToken)
rdg->extAuth = HTTP_EXTENDED_AUTH_PAA;
UuidCreate(&rdg->guid);
rdg->tlsOut = freerdp_tls_new(rdg->context);
if (!rdg->tlsOut)
goto rdg_alloc_error;
rdg->tlsIn = freerdp_tls_new(rdg->context);
if (!rdg->tlsIn)
goto rdg_alloc_error;
rdg->http = http_context_new();
if (!rdg->http)
goto rdg_alloc_error;
if (!http_context_set_uri(rdg->http, "/remoteDesktopGateway/") ||
!http_context_set_accept(rdg->http, "*/*") ||
!http_context_set_cache_control(rdg->http, "no-cache") ||
!http_context_set_pragma(rdg->http, "no-cache") ||
!http_context_set_connection(rdg->http, "Keep-Alive") ||
!http_context_set_user_agent(rdg->http, "MS-RDGateway/1.0") ||
!http_context_set_host(rdg->http, rdg->context->settings->GatewayHostname) ||
!http_context_set_rdg_connection_id(rdg->http, &rdg->guid) ||
!http_context_set_rdg_correlation_id(rdg->http, &rdg->guid) ||
!http_context_enable_websocket_upgrade(
rdg->http, freerdp_settings_get_bool(rdg->context->settings,
FreeRDP_GatewayHttpUseWebsockets)))
{
goto rdg_alloc_error;
}
if (rdg->extAuth != HTTP_EXTENDED_AUTH_NONE)
{
switch (rdg->extAuth)
{
case HTTP_EXTENDED_AUTH_PAA:
if (!http_context_set_rdg_auth_scheme(rdg->http, "PAA"))
goto rdg_alloc_error;
break;
case HTTP_EXTENDED_AUTH_SSPI_NTLM:
if (!http_context_set_rdg_auth_scheme(rdg->http, "SSPI_NTLM"))
goto rdg_alloc_error;
break;
default:
WLog_Print(rdg->log, WLOG_DEBUG,
"RDG extended authentication method %d not supported", rdg->extAuth);
}
}
rdg->frontBio = BIO_new(BIO_s_rdg());
if (!rdg->frontBio)
goto rdg_alloc_error;
BIO_set_data(rdg->frontBio, rdg);
InitializeCriticalSection(&rdg->writeSection);
rdg->transferEncoding.httpTransferEncoding = TransferEncodingIdentity;
rdg->transferEncoding.isWebsocketTransport = FALSE;
goto rdg_alloc_error;
}
if (rdg->extAuth != HTTP_EXTENDED_AUTH_NONE)
{
switch (rdg->extAuth)
{
case HTTP_EXTENDED_AUTH_PAA:
if (!http_context_set_rdg_auth_scheme(rdg->http, "PAA"))
goto rdg_alloc_error;
break;
case HTTP_EXTENDED_AUTH_SSPI_NTLM:
if (!http_context_set_rdg_auth_scheme(rdg->http, "SSPI_NTLM"))
goto rdg_alloc_error;
break;
default:
WLog_Print(rdg->log, WLOG_DEBUG,
"RDG extended authentication method %d not supported", rdg->extAuth);
}
}
rdg->frontBio = BIO_new(BIO_s_rdg());
if (!rdg->frontBio)
goto rdg_alloc_error;
BIO_set_data(rdg->frontBio, rdg);
InitializeCriticalSection(&rdg->writeSection);
rdg->transferEncoding.httpTransferEncoding = TransferEncodingIdentity;
rdg->transferEncoding.isWebsocketTransport = FALSE;
rdg->transferEncoding.context.websocket = websocket_context_new();
if (!rdg->transferEncoding.context.websocket)
goto rdg_alloc_error;
return rdg;
rdg_alloc_error:
WINPR_PRAGMA_DIAG_PUSH
@@ -2308,14 +2253,10 @@ void rdg_free(rdpRdg* rdg)
DeleteCriticalSection(&rdg->writeSection);
if (rdg->transferEncoding.isWebsocketTransport)
{
if (rdg->transferEncoding.context.websocket.responseStreamBuffer != NULL)
Stream_Free(rdg->transferEncoding.context.websocket.responseStreamBuffer, TRUE);
}
smartcardCertInfo_Free(rdg->smartcard);
websocket_context_free(rdg->transferEncoding.context.websocket);
free(rdg);
}

View File

@@ -23,25 +23,67 @@
#define TAG FREERDP_TAG("core.gateway.websocket")
BOOL websocket_write_wstream(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode)
struct s_websocket_context
{
size_t fullLen = 0;
int status = 0;
wStream* sWS = NULL;
size_t payloadLength;
uint32_t maskingKey;
BOOL masking;
BOOL closeSent;
BYTE opcode;
BYTE fragmentOriginalOpcode;
BYTE lengthAndMaskPosition;
WEBSOCKET_STATE state;
wStream* responseStreamBuffer;
};
uint32_t maskingKey = 0;
static int websocket_write_all(BIO* bio, const BYTE* data, size_t length);
size_t streamPos = 0;
BOOL websocket_context_mask_and_send(BIO* bio, wStream* sPacket, wStream* sDataPacket,
UINT32 maskingKey)
{
const size_t len = Stream_Length(sDataPacket);
Stream_SetPosition(sDataPacket, 0);
WINPR_ASSERT(bio);
WINPR_ASSERT(sPacket);
const size_t len = Stream_Length(sPacket);
Stream_SetPosition(sPacket, 0);
if (len > INT_MAX)
if (!Stream_EnsureRemainingCapacity(sPacket, len))
return FALSE;
/* mask as much as possible with 32bit access */
size_t streamPos = 0;
for (; streamPos + 4 <= len; streamPos += 4)
{
const uint32_t data = Stream_Get_UINT32(sDataPacket);
Stream_Write_UINT32(sPacket, data ^ maskingKey);
}
/* mask the rest byte by byte */
for (; streamPos < len; streamPos++)
{
BYTE data = 0;
BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
Stream_Read_UINT8(sDataPacket, data);
Stream_Write_UINT8(sPacket, data ^ *partialMask);
}
Stream_SealLength(sPacket);
ERR_clear_error();
const size_t size = Stream_Length(sPacket);
const int status = websocket_write_all(bio, Stream_Buffer(sPacket), size);
Stream_Free(sPacket, TRUE);
if ((status < 0) || ((size_t)status != size))
return FALSE;
return TRUE;
}
wStream* websocket_context_packet_new(size_t len, WEBSOCKET_OPCODE opcode, UINT32* pMaskingKey)
{
WINPR_ASSERT(pMaskingKey);
if (len > INT_MAX)
return NULL;
size_t fullLen = 0;
if (len < 126)
fullLen = len + 6; /* 2 byte "mini header" + 4 byte masking key */
else if (len < 0x10000)
@@ -49,13 +91,14 @@ BOOL websocket_write_wstream(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode
else
fullLen = len + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
sWS = Stream_New(NULL, fullLen);
wStream* sWS = Stream_New(NULL, fullLen);
if (!sWS)
return FALSE;
return NULL;
UINT32 maskingKey = 0;
winpr_RAND(&maskingKey, sizeof(maskingKey));
Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | opcode);
Stream_Write_UINT8(sWS, (UINT8)(WEBSOCKET_FIN_BIT | opcode));
if (len < 126)
Stream_Write_UINT8(sWS, (UINT8)len | WEBSOCKET_MASK_BIT);
else if (len < 0x10000)
@@ -70,39 +113,34 @@ BOOL websocket_write_wstream(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode
Stream_Write_UINT32_BE(sWS, (UINT32)len);
}
Stream_Write_UINT32(sWS, maskingKey);
/* mask as much as possible with 32bit access */
for (streamPos = 0; streamPos + 4 <= len; streamPos += 4)
{
uint32_t data = 0;
Stream_Read_UINT32(sPacket, data);
Stream_Write_UINT32(sWS, data ^ maskingKey);
}
/* mask the rest byte by byte */
for (; streamPos < len; streamPos++)
{
BYTE data = 0;
BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
Stream_Read_UINT8(sPacket, data);
Stream_Write_UINT8(sWS, data ^ *partialMask);
}
Stream_SealLength(sWS);
ERR_clear_error();
const size_t size = Stream_Length(sWS);
if (size <= INT32_MAX)
status = BIO_write(bio, Stream_Buffer(sWS), (int)size);
Stream_Free(sWS, TRUE);
if (status != (SSIZE_T)fullLen)
return FALSE;
return TRUE;
*pMaskingKey = maskingKey;
return sWS;
}
static int websocket_write_all(BIO* bio, const BYTE* data, size_t length)
BOOL websocket_context_write_wstream(websocket_context* context, BIO* bio, wStream* sPacket,
WEBSOCKET_OPCODE opcode)
{
WINPR_ASSERT(context);
if (context->closeSent)
return FALSE;
if (opcode == WebsocketCloseOpcode)
context->closeSent = TRUE;
WINPR_ASSERT(bio);
WINPR_ASSERT(sPacket);
const size_t len = Stream_Length(sPacket);
uint32_t maskingKey = 0;
wStream* sWS = websocket_context_packet_new(len, opcode, &maskingKey);
if (!sWS)
return FALSE;
return websocket_context_mask_and_send(bio, sWS, sPacket, maskingKey);
}
int websocket_write_all(BIO* bio, const BYTE* data, size_t length)
{
WINPR_ASSERT(bio);
WINPR_ASSERT(data);
@@ -140,75 +178,19 @@ static int websocket_write_all(BIO* bio, const BYTE* data, size_t length)
return (int)length;
}
int websocket_write(BIO* bio, const BYTE* buf, int isize, WEBSOCKET_OPCODE opcode)
int websocket_context_write(websocket_context* context, BIO* bio, const BYTE* buf, int isize,
WEBSOCKET_OPCODE opcode)
{
size_t fullLen = 0;
int status = 0;
wStream* sWS = NULL;
uint32_t maskingKey = 0;
WINPR_ASSERT(bio);
WINPR_ASSERT(buf);
winpr_RAND(&maskingKey, sizeof(maskingKey));
if (isize < 0)
return -1;
const size_t payloadSize = (size_t)isize;
if (payloadSize < 126)
fullLen = payloadSize + 6; /* 2 byte "mini header" + 4 byte masking key */
else if (payloadSize < 0x10000)
fullLen = payloadSize + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
else
fullLen = payloadSize + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
sWS = Stream_New(NULL, fullLen);
if (!sWS)
return FALSE;
Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | opcode);
if (payloadSize < 126)
Stream_Write_UINT8(sWS, (UINT8)payloadSize | WEBSOCKET_MASK_BIT);
else if (payloadSize < 0x10000)
{
Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
Stream_Write_UINT16_BE(sWS, (UINT16)payloadSize);
}
else
{
Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
/* biggest packet possible is 0xffff + 0xa, so 32bit is always enough */
Stream_Write_UINT32_BE(sWS, 0);
Stream_Write_UINT32_BE(sWS, (UINT32)payloadSize);
}
Stream_Write_UINT32(sWS, maskingKey);
/* mask as much as possible with 32bit access */
size_t streamPos = 0;
for (; streamPos + 4 <= payloadSize; streamPos += 4)
{
uint32_t masked = *((const uint32_t*)(buf + streamPos)) ^ maskingKey;
Stream_Write_UINT32(sWS, masked);
}
/* mask the rest byte by byte */
for (; streamPos < payloadSize; streamPos++)
{
BYTE* partialMask = (BYTE*)(&maskingKey) + streamPos % 4;
BYTE masked = *((buf + streamPos)) ^ *partialMask;
Stream_Write_UINT8(sWS, masked);
}
Stream_SealLength(sWS);
status = websocket_write_all(bio, Stream_Buffer(sWS), Stream_Length(sWS));
Stream_Free(sWS, TRUE);
if (status < 0)
return status;
wStream sbuffer = { 0 };
wStream* s = Stream_StaticConstInit(&sbuffer, buf, (size_t)isize);
if (!websocket_context_write_wstream(context, bio, s, opcode))
return -2;
return isize;
}
@@ -234,10 +216,10 @@ static int websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size,
ERR_clear_error();
status = BIO_read(bio, pBuffer, (int)rlen);
if (status <= 0)
if ((status <= 0) || ((size_t)status > encodingContext->payloadLength))
return status;
encodingContext->payloadLength -= status;
encodingContext->payloadLength -= (size_t)status;
if (encodingContext->payloadLength == 0)
encodingContext->state = WebsocketStateOpcodeAndFin;
@@ -245,47 +227,21 @@ static int websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size,
return status;
}
static int websocket_read_discard(BIO* bio, websocket_context* encodingContext)
static int websocket_read_wstream(BIO* bio, websocket_context* encodingContext)
{
char _dummy[256] = { 0 };
int status = 0;
WINPR_ASSERT(bio);
WINPR_ASSERT(encodingContext);
if (encodingContext->payloadLength == 0)
{
encodingContext->state = WebsocketStateOpcodeAndFin;
return 0;
}
ERR_clear_error();
status = BIO_read(bio, _dummy, sizeof(_dummy));
if (status <= 0)
return status;
encodingContext->payloadLength -= status;
if (encodingContext->payloadLength == 0)
encodingContext->state = WebsocketStateOpcodeAndFin;
return status;
}
static int websocket_read_wstream(BIO* bio, wStream* s, websocket_context* encodingContext)
{
int status = 0;
WINPR_ASSERT(bio);
wStream* s = encodingContext->responseStreamBuffer;
WINPR_ASSERT(s);
WINPR_ASSERT(encodingContext);
if (encodingContext->payloadLength == 0)
{
encodingContext->state = WebsocketStateOpcodeAndFin;
return 0;
}
if (Stream_GetRemainingCapacity(s) != encodingContext->payloadLength)
if (!Stream_EnsureRemainingCapacity(s, encodingContext->payloadLength))
{
WLog_WARN(TAG,
"wStream::capacity [%" PRIuz "] != encodingContext::paylaodLangth [%" PRIuz "]",
@@ -293,111 +249,33 @@ static int websocket_read_wstream(BIO* bio, wStream* s, websocket_context* encod
return -1;
}
const size_t rlen = encodingContext->payloadLength;
if (rlen > INT32_MAX)
return -1;
ERR_clear_error();
status = BIO_read(bio, Stream_Pointer(s), (int)rlen);
if (status <= 0)
const int status = websocket_read_data(bio, Stream_Pointer(s), Stream_GetRemainingCapacity(s),
encodingContext);
if (status < 0)
return status;
Stream_Seek(s, status);
encodingContext->payloadLength -= status;
if (encodingContext->payloadLength == 0)
{
encodingContext->state = WebsocketStateOpcodeAndFin;
Stream_SealLength(s);
Stream_SetPosition(s, 0);
}
if (!Stream_SafeSeek(s, (size_t)status))
return -1;
return status;
}
static BOOL websocket_reply_close(BIO* bio, wStream* s)
static BOOL websocket_reply_close(BIO* bio, websocket_context* context, wStream* s)
{
/* write back close */
wStream* closeFrame = NULL;
uint16_t maskingKey1 = 0;
uint16_t maskingKey2 = 0;
size_t closeDataLen = 0;
WINPR_ASSERT(bio);
closeDataLen = 0;
if (s != NULL && Stream_Length(s) >= 2)
closeDataLen = 2;
closeFrame = Stream_New(NULL, 6 + closeDataLen);
if (!closeFrame)
return FALSE;
Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketCloseOpcode);
Stream_Write_UINT8(closeFrame, closeDataLen | WEBSOCKET_MASK_BIT); /* no payload */
winpr_RAND(&maskingKey1, sizeof(maskingKey1));
winpr_RAND(&maskingKey2, sizeof(maskingKey2));
Stream_Write_UINT16(closeFrame, maskingKey1);
Stream_Write_UINT16(closeFrame, maskingKey2); /* unused half, max 2 bytes of data */
if (closeDataLen == 2)
{
uint16_t data = 0;
Stream_Read_UINT16(s, data);
Stream_Write_UINT16(closeFrame, data ^ maskingKey1);
}
Stream_SealLength(closeFrame);
const size_t rlen = Stream_Length(closeFrame);
int status = -1;
if (rlen <= INT32_MAX)
{
ERR_clear_error();
status = BIO_write(bio, Stream_Buffer(closeFrame), (int)rlen);
}
Stream_Free(closeFrame, TRUE);
/* server MUST close socket now. The server is not allowed anymore to
* send frames but if he does, nothing bad would happen */
if (status < 0)
return FALSE;
return TRUE;
return websocket_context_write_wstream(context, bio, s, WebsocketCloseOpcode);
}
static BOOL websocket_reply_pong(BIO* bio, wStream* s)
static BOOL websocket_reply_pong(BIO* bio, websocket_context* context, wStream* s)
{
wStream* closeFrame = NULL;
uint32_t maskingKey = 0;
WINPR_ASSERT(bio);
WINPR_ASSERT(s);
if (s != NULL)
return websocket_write_wstream(bio, s, WebsocketPongOpcode);
if (Stream_GetPosition(s) != 0)
return websocket_context_write_wstream(context, bio, s, WebsocketPongOpcode);
closeFrame = Stream_New(NULL, 6);
if (!closeFrame)
return FALSE;
Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketPongOpcode);
Stream_Write_UINT8(closeFrame, 0 | WEBSOCKET_MASK_BIT); /* no payload */
winpr_RAND(&maskingKey, sizeof(maskingKey));
Stream_Write_UINT32(closeFrame, maskingKey); /* dummy masking key. */
Stream_SealLength(closeFrame);
const size_t rlen = Stream_Length(closeFrame);
int status = -1;
if (rlen <= INT32_MAX)
{
ERR_clear_error();
status = BIO_write(bio, Stream_Buffer(closeFrame), (int)rlen);
}
Stream_Free(closeFrame, TRUE);
if (status < 0)
return FALSE;
return TRUE;
return websocket_reply_close(bio, context, NULL);
}
static int websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size,
@@ -409,9 +287,9 @@ static int websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size,
WINPR_ASSERT(pBuffer);
WINPR_ASSERT(encodingContext);
BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
? encodingContext->fragmentOriginalOpcode & 0xf
: encodingContext->opcode & 0xf);
const BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
? encodingContext->fragmentOriginalOpcode & 0xf
: encodingContext->opcode & 0xf);
switch (effectiveOpcode)
{
@@ -425,60 +303,55 @@ static int websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size,
}
case WebsocketPingOpcode:
{
if (encodingContext->responseStreamBuffer == NULL)
encodingContext->responseStreamBuffer =
Stream_New(NULL, encodingContext->payloadLength);
status =
websocket_read_wstream(bio, encodingContext->responseStreamBuffer, encodingContext);
status = websocket_read_wstream(bio, encodingContext);
if (status < 0)
return status;
if (encodingContext->payloadLength == 0)
{
if (!encodingContext->closeSent)
websocket_reply_pong(bio, encodingContext->responseStreamBuffer);
Stream_Free(encodingContext->responseStreamBuffer, TRUE);
encodingContext->responseStreamBuffer = NULL;
websocket_reply_pong(bio, encodingContext, encodingContext->responseStreamBuffer);
Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
}
}
break;
case WebsocketPongOpcode:
{
status = websocket_read_wstream(bio, encodingContext);
if (status < 0)
return status;
/* We don´t care about pong response data, discard. */
Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
}
break;
case WebsocketCloseOpcode:
{
if (encodingContext->responseStreamBuffer == NULL)
encodingContext->responseStreamBuffer =
Stream_New(NULL, encodingContext->payloadLength);
status =
websocket_read_wstream(bio, encodingContext->responseStreamBuffer, encodingContext);
status = websocket_read_wstream(bio, encodingContext);
if (status < 0)
return status;
if (encodingContext->payloadLength == 0)
{
websocket_reply_close(bio, encodingContext->responseStreamBuffer);
websocket_reply_close(bio, encodingContext, encodingContext->responseStreamBuffer);
encodingContext->closeSent = TRUE;
if (encodingContext->responseStreamBuffer)
Stream_Free(encodingContext->responseStreamBuffer, TRUE);
encodingContext->responseStreamBuffer = NULL;
Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
}
}
break;
default:
WLog_WARN(TAG, "Unimplemented websocket opcode %x. Dropping", effectiveOpcode & 0xf);
WLog_WARN(TAG, "Unimplemented websocket opcode %" PRIx8 ". Dropping", effectiveOpcode);
status = websocket_read_discard(bio, encodingContext);
status = websocket_read_wstream(bio, encodingContext);
if (status < 0)
return status;
Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
break;
}
/* return how many bytes have been written to pBuffer.
* Only WebsocketBinaryOpcode writes into it and it returns directly */
return 0;
}
int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* encodingContext)
int websocket_context_read(websocket_context* encodingContext, BIO* bio, BYTE* pBuffer, size_t size)
{
int status = 0;
int effectiveDataLen = 0;
@@ -493,7 +366,8 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco
{
case WebsocketStateOpcodeAndFin:
{
BYTE buffer[1];
BYTE buffer[1] = { 0 };
ERR_clear_error();
status = BIO_read(bio, (char*)buffer, sizeof(buffer));
if (status <= 0)
@@ -508,8 +382,8 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco
break;
case WebsocketStateLengthAndMasking:
{
BYTE buffer[1];
BYTE len = 0;
BYTE buffer[1] = { 0 };
ERR_clear_error();
status = BIO_read(bio, (char*)buffer, sizeof(buffer));
if (status <= 0)
@@ -518,7 +392,7 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco
encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT);
encodingContext->lengthAndMaskPosition = 0;
encodingContext->payloadLength = 0;
len = buffer[0] & 0x7f;
const BYTE len = buffer[0] & 0x7f;
if (len < 126)
{
encodingContext->payloadLength = len;
@@ -534,8 +408,9 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco
case WebsocketStateShortLength:
case WebsocketStateLongLength:
{
BYTE buffer[1];
BYTE lenLength = (encodingContext->state == WebsocketStateShortLength ? 2 : 8);
BYTE buffer[1] = { 0 };
const BYTE lenLength =
(encodingContext->state == WebsocketStateShortLength ? 2 : 8);
while (encodingContext->lengthAndMaskPosition < lenLength)
{
ERR_clear_error();
@@ -577,3 +452,39 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco
}
/* should be unreachable */
}
websocket_context* websocket_context_new(void)
{
websocket_context* context = calloc(1, sizeof(websocket_context));
if (!context)
goto fail;
context->responseStreamBuffer = Stream_New(NULL, 1024);
if (!context->responseStreamBuffer)
goto fail;
if (!websocket_context_reset(context))
goto fail;
return context;
fail:
websocket_context_free(context);
return NULL;
}
void websocket_context_free(websocket_context* context)
{
if (!context)
return;
Stream_Free(context->responseStreamBuffer, TRUE);
free(context);
}
BOOL websocket_context_reset(websocket_context* context)
{
WINPR_ASSERT(context);
context->state = WebsocketStateOpcodeAndFin;
return Stream_SetPosition(context->responseStreamBuffer, 0);
}

View File

@@ -50,22 +50,27 @@ typedef enum
WebSocketStatePayload,
} WEBSOCKET_STATE;
typedef struct
{
size_t payloadLength;
uint32_t maskingKey;
BOOL masking;
BOOL closeSent;
BYTE opcode;
BYTE fragmentOriginalOpcode;
BYTE lengthAndMaskPosition;
WEBSOCKET_STATE state;
wStream* responseStreamBuffer;
} websocket_context;
typedef struct s_websocket_context websocket_context;
FREERDP_LOCAL BOOL websocket_write_wstream(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode);
FREERDP_LOCAL int websocket_write(BIO* bio, const BYTE* buf, int isize, WEBSOCKET_OPCODE opcode);
FREERDP_LOCAL int websocket_read(BIO* bio, BYTE* pBuffer, size_t size,
websocket_context* encodingContext);
FREERDP_LOCAL void websocket_context_free(websocket_context* context);
WINPR_ATTR_MALLOC(websocket_context_free, 1)
FREERDP_LOCAL websocket_context* websocket_context_new(void);
FREERDP_LOCAL BOOL websocket_context_reset(websocket_context* context);
FREERDP_LOCAL BOOL websocket_context_write_wstream(websocket_context* context, BIO* bio,
wStream* sPacket, WEBSOCKET_OPCODE opcode);
FREERDP_LOCAL int websocket_context_write(websocket_context* context, BIO* bio, const BYTE* buf,
int isize, WEBSOCKET_OPCODE opcode);
FREERDP_LOCAL int websocket_context_read(websocket_context* encodingContext, BIO* bio,
BYTE* pBuffer, size_t size);
WINPR_ATTR_MALLOC(Stream_Free, 1)
FREERDP_LOCAL wStream* websocket_context_packet_new(size_t len, WEBSOCKET_OPCODE opcode,
UINT32* pMaskingKey);
FREERDP_LOCAL BOOL websocket_context_mask_and_send(BIO* bio, wStream* sPacket, wStream* sDataPacket,
UINT32 maskingKey);
#endif /* FREERDP_LIB_CORE_GATEWAY_WEBSOCKET_H */

View File

@@ -65,7 +65,7 @@ struct rdp_wst
char* gwhostname;
uint16_t gwport;
char* gwpath;
websocket_context wscontext;
websocket_context* wscontext;
};
static const char arm_query_param[] = "%s%cClmTk=Bearer%%20%s&X-MS-User-Agent=FreeRDP%%2F3.0";
@@ -494,11 +494,7 @@ BOOL wst_connect(rdpWst* wst, DWORD timeout)
return FALSE;
if (isWebsocket)
{
wst->wscontext.state = WebsocketStateOpcodeAndFin;
wst->wscontext.responseStreamBuffer = NULL;
return TRUE;
}
return websocket_context_reset(wst->wscontext);
else
{
char buffer[64] = { 0 };
@@ -537,7 +533,8 @@ static int wst_bio_write(BIO* bio, const char* buf, int num)
WINPR_ASSERT(wst);
BIO_clear_flags(bio, BIO_FLAGS_WRITE);
EnterCriticalSection(&wst->writeSection);
status = websocket_write(wst->tls->bio, (const BYTE*)buf, num, WebsocketBinaryOpcode);
status = websocket_context_write(wst->wscontext, wst->tls->bio, (const BYTE*)buf, num,
WebsocketBinaryOpcode);
LeaveCriticalSection(&wst->writeSection);
if (status < 0)
@@ -570,7 +567,7 @@ static int wst_bio_read(BIO* bio, char* buf, int size)
while (status <= 0)
{
status = websocket_read(wst->tls->bio, (BYTE*)buf, (size_t)size, &wst->wscontext);
status = websocket_context_read(wst->wscontext, wst->tls->bio, (BYTE*)buf, (size_t)size);
if (status <= 0)
{
if (!BIO_should_retry(wst->tls->bio))
@@ -781,58 +778,59 @@ static BOOL wst_parse_url(rdpWst* wst, const char* url)
rdpWst* wst_new(rdpContext* context)
{
rdpWst* wst = NULL;
if (!context)
return NULL;
wst = (rdpWst*)calloc(1, sizeof(rdpWst));
rdpWst* wst = (rdpWst*)calloc(1, sizeof(rdpWst));
if (!wst)
return NULL;
if (wst)
wst->context = context;
wst->gwhostname = NULL;
wst->gwport = 443;
wst->gwpath = NULL;
if (!wst_parse_url(wst, context->settings->GatewayUrl))
goto wst_alloc_error;
wst->tls = freerdp_tls_new(wst->context);
if (!wst->tls)
goto wst_alloc_error;
wst->http = http_context_new();
if (!wst->http)
goto wst_alloc_error;
if (!http_context_set_uri(wst->http, wst->gwpath) ||
!http_context_set_accept(wst->http, "*/*") ||
!http_context_set_cache_control(wst->http, "no-cache") ||
!http_context_set_pragma(wst->http, "no-cache") ||
!http_context_set_connection(wst->http, "Keep-Alive") ||
!http_context_set_user_agent(wst->http, FREERDP_USER_AGENT) ||
!http_context_set_x_ms_user_agent(wst->http, FREERDP_USER_AGENT) ||
!http_context_set_host(wst->http, wst->gwhostname) ||
!http_context_enable_websocket_upgrade(wst->http, TRUE))
{
wst->context = context;
wst->gwhostname = NULL;
wst->gwport = 443;
wst->gwpath = NULL;
if (!wst_parse_url(wst, context->settings->GatewayUrl))
goto wst_alloc_error;
wst->tls = freerdp_tls_new(wst->context);
if (!wst->tls)
goto wst_alloc_error;
wst->http = http_context_new();
if (!wst->http)
goto wst_alloc_error;
if (!http_context_set_uri(wst->http, wst->gwpath) ||
!http_context_set_accept(wst->http, "*/*") ||
!http_context_set_cache_control(wst->http, "no-cache") ||
!http_context_set_pragma(wst->http, "no-cache") ||
!http_context_set_connection(wst->http, "Keep-Alive") ||
!http_context_set_user_agent(wst->http, FREERDP_USER_AGENT) ||
!http_context_set_x_ms_user_agent(wst->http, FREERDP_USER_AGENT) ||
!http_context_set_host(wst->http, wst->gwhostname) ||
!http_context_enable_websocket_upgrade(wst->http, TRUE))
{
goto wst_alloc_error;
}
wst->frontBio = BIO_new(BIO_s_wst());
if (!wst->frontBio)
goto wst_alloc_error;
BIO_set_data(wst->frontBio, wst);
InitializeCriticalSection(&wst->writeSection);
wst->auth = credssp_auth_new(context);
if (!wst->auth)
goto wst_alloc_error;
goto wst_alloc_error;
}
wst->frontBio = BIO_new(BIO_s_wst());
if (!wst->frontBio)
goto wst_alloc_error;
BIO_set_data(wst->frontBio, wst);
InitializeCriticalSection(&wst->writeSection);
wst->auth = credssp_auth_new(context);
if (!wst->auth)
goto wst_alloc_error;
wst->wscontext = websocket_context_new();
if (!wst->wscontext)
goto wst_alloc_error;
return wst;
wst_alloc_error:
WINPR_PRAGMA_DIAG_PUSH
@@ -858,8 +856,7 @@ void wst_free(rdpWst* wst)
DeleteCriticalSection(&wst->writeSection);
if (wst->wscontext.responseStreamBuffer != NULL)
Stream_Free(wst->wscontext.responseStreamBuffer, TRUE);
websocket_context_free(wst->wscontext);
free(wst);
}