diff --git a/libfreerdp/core/freerdp.c b/libfreerdp/core/freerdp.c index 5d57815f1..86444f22a 100644 --- a/libfreerdp/core/freerdp.c +++ b/libfreerdp/core/freerdp.c @@ -159,7 +159,8 @@ static int freerdp_connect_begin(freerdp* instance) WINPR_ASSERT(context); freerdp_set_last_error_if_not(context, FREERDP_ERROR_PRE_CONNECT_FAILED); - WLog_Print(context->log, WLOG_ERROR, "freerdp_pre_connect failed"); + WLog_Print(context->log, WLOG_ERROR, "freerdp_pre_connect failed: %s", + rdp_client_connection_state_string(instance->ConnectionCallbackState)); return 0; } diff --git a/libfreerdp/core/gateway/rdg.c b/libfreerdp/core/gateway/rdg.c index 47247f46c..9595c1c73 100644 --- a/libfreerdp/core/gateway/rdg.c +++ b/libfreerdp/core/gateway/rdg.c @@ -112,7 +112,7 @@ typedef struct union context { http_encoding_chunked_context chunked; - websocket_context websocket; + websocket_context* websocket; } context; } rdg_http_encoding_context; @@ -327,11 +327,8 @@ static BOOL rdg_write_chunked(BIO* bio, wStream* sPacket) static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket) { if (rdg->transferEncoding.isWebsocketTransport) - { - if (rdg->transferEncoding.context.websocket.closeSent) - return FALSE; - return websocket_write_wstream(rdg->tlsOut->bio, sPacket, WebsocketBinaryOpcode); - } + return websocket_context_write_wstream(rdg->transferEncoding.context.websocket, + rdg->tlsOut->bio, sPacket, WebsocketBinaryOpcode); return rdg_write_chunked(rdg->tlsIn->bio, sPacket); } @@ -344,9 +341,7 @@ static int rdg_socket_read(BIO* bio, BYTE* pBuffer, size_t size, return -1; if (encodingContext->isWebsocketTransport) - { - return websocket_read(bio, pBuffer, size, &encodingContext->context.websocket); - } + return websocket_context_read(encodingContext->context.websocket, bio, pBuffer, size); switch (encodingContext->httpTransferEncoding) { @@ -1470,9 +1465,11 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char* } return FALSE; } + rdg->transferEncoding.isWebsocketTransport = TRUE; - rdg->transferEncoding.context.websocket.state = WebsocketStateOpcodeAndFin; - rdg->transferEncoding.context.websocket.responseStreamBuffer = NULL; + if (!websocket_context_reset(rdg->transferEncoding.context.websocket)) + return FALSE; + if (rdg->extAuth == HTTP_EXTENDED_AUTH_SSPI_NTLM) { /* create a new auth context for SSPI_NTLM. This must be done after the last @@ -1607,86 +1604,36 @@ BOOL rdg_connect(rdpRdg* rdg, DWORD timeout, BOOL* rpcFallback) static int rdg_write_websocket_data_packet(rdpRdg* rdg, const BYTE* buf, int isize) { - size_t fullLen = 0; - int status = 0; - wStream* sWS = NULL; - - uint32_t maskingKey = 0; - BYTE* maskingKeyByte1 = (BYTE*)&maskingKey; - BYTE* maskingKeyByte2 = maskingKeyByte1 + 1; - BYTE* maskingKeyByte3 = maskingKeyByte1 + 2; - BYTE* maskingKeyByte4 = maskingKeyByte1 + 3; - - winpr_RAND(&maskingKey, 4); - + WINPR_ASSERT(rdg); if (isize < 0) return -1; const size_t payloadSize = (size_t)isize + 10; + union + { + UINT32 u32; + UINT8 u8[4]; + } maskingKey; - if (payloadSize < 1) - return 0; - - if (payloadSize < 126) - fullLen = payloadSize + 6; /* 2 byte "mini header" + 4 byte masking key */ - else if (payloadSize < 0x10000) - fullLen = payloadSize + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */ - else - fullLen = payloadSize + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */ - - sWS = Stream_New(NULL, fullLen); + wStream* sWS = + websocket_context_packet_new(payloadSize, WebsocketBinaryOpcode, &maskingKey.u32); if (!sWS) return FALSE; - Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | WebsocketBinaryOpcode); - if (payloadSize < 126) - Stream_Write_UINT8(sWS, (UINT8)payloadSize | WEBSOCKET_MASK_BIT); - else if (payloadSize < 0x10000) - { - Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT); - Stream_Write_UINT16_BE(sWS, (UINT16)payloadSize); - } - else - { - Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT); - /* biggest packet possible is 0xffff + 0xa, so 32bit is always enough */ - Stream_Write_UINT32_BE(sWS, 0); - Stream_Write_UINT32_BE(sWS, (UINT32)payloadSize); - } - Stream_Write_UINT32(sWS, maskingKey); - - Stream_Write_UINT16(sWS, PKT_TYPE_DATA ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Type */ - Stream_Write_UINT16(sWS, 0 ^ (*maskingKeyByte3 | *maskingKeyByte4 << 8)); /* Reserved */ - Stream_Write_UINT32(sWS, (UINT32)payloadSize ^ maskingKey); /* Packet length */ + Stream_Write_UINT16(sWS, PKT_TYPE_DATA ^ (maskingKey.u8[0] | maskingKey.u8[1] << 8)); /* Type */ + Stream_Write_UINT16(sWS, 0 ^ (maskingKey.u8[2] | maskingKey.u8[3] << 8)); /* Reserved */ + Stream_Write_UINT32(sWS, (UINT32)payloadSize ^ maskingKey.u32); /* Packet length */ Stream_Write_UINT16(sWS, - (UINT16)isize ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Data size */ + (UINT16)isize ^ (maskingKey.u8[0] | maskingKey.u8[1] << 8)); /* Data size */ /* masking key is now off by 2 bytes. fix that */ - maskingKey = (maskingKey & 0xffff) << 16 | (maskingKey >> 16); + maskingKey.u32 = (maskingKey.u32 & 0xffff) << 16 | (maskingKey.u32 >> 16); - /* mask as much as possible with 32bit access */ - size_t streamPos = 0; - for (; streamPos + 4 <= (size_t)isize; streamPos += 4) - { - uint32_t masked = *((const uint32_t*)(buf + streamPos)) ^ maskingKey; - Stream_Write_UINT32(sWS, masked); - } - - /* mask the rest byte by byte */ - for (; streamPos < (size_t)isize; streamPos++) - { - BYTE* partialMask = (BYTE*)(&maskingKey) + streamPos % 4; - BYTE masked = *((buf + streamPos)) ^ *partialMask; - Stream_Write_UINT8(sWS, masked); - } - - Stream_SealLength(sWS); - - status = freerdp_tls_write_all(rdg->tlsOut, Stream_Buffer(sWS), Stream_Length(sWS)); - Stream_Free(sWS, TRUE); - - if (status < 0) - return status; + WINPR_ASSERT(rdg->tlsOut); + wStream sPacket = { 0 }; + Stream_StaticConstInit(&sPacket, buf, (size_t)isize); + if (!websocket_context_mask_and_send(rdg->tlsOut->bio, sWS, &sPacket, maskingKey.u32)) + return -1; return isize; } @@ -1733,12 +1680,9 @@ static int rdg_write_chunked_data_packet(rdpRdg* rdg, const BYTE* buf, int isize static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize) { + WINPR_ASSERT(rdg); if (rdg->transferEncoding.isWebsocketTransport) - { - if (rdg->transferEncoding.context.websocket.closeSent == TRUE) - return -1; return rdg_write_websocket_data_packet(rdg, buf, isize); - } else return rdg_write_chunked_data_packet(rdg, buf, isize); } @@ -2198,92 +2142,93 @@ static BIO_METHOD* BIO_s_rdg(void) rdpRdg* rdg_new(rdpContext* context) { - rdpRdg* rdg = NULL; - if (!context) return NULL; - rdg = (rdpRdg*)calloc(1, sizeof(rdpRdg)); + rdpRdg* rdg = (rdpRdg*)calloc(1, sizeof(rdpRdg)); + if (!rdg) + return NULL; - if (rdg) + rdg->log = WLog_Get(TAG); + rdg->state = RDG_CLIENT_STATE_INITIAL; + rdg->context = context; + rdg->extAuth = + (rdg->context->settings->GatewayHttpExtAuthSspiNtlm ? HTTP_EXTENDED_AUTH_SSPI_NTLM + : HTTP_EXTENDED_AUTH_NONE); + + if (rdg->context->settings->GatewayAccessToken) + rdg->extAuth = HTTP_EXTENDED_AUTH_PAA; + + UuidCreate(&rdg->guid); + + rdg->tlsOut = freerdp_tls_new(rdg->context); + + if (!rdg->tlsOut) + goto rdg_alloc_error; + + rdg->tlsIn = freerdp_tls_new(rdg->context); + + if (!rdg->tlsIn) + goto rdg_alloc_error; + + rdg->http = http_context_new(); + + if (!rdg->http) + goto rdg_alloc_error; + + if (!http_context_set_uri(rdg->http, "/remoteDesktopGateway/") || + !http_context_set_accept(rdg->http, "*/*") || + !http_context_set_cache_control(rdg->http, "no-cache") || + !http_context_set_pragma(rdg->http, "no-cache") || + !http_context_set_connection(rdg->http, "Keep-Alive") || + !http_context_set_user_agent(rdg->http, "MS-RDGateway/1.0") || + !http_context_set_host(rdg->http, rdg->context->settings->GatewayHostname) || + !http_context_set_rdg_connection_id(rdg->http, &rdg->guid) || + !http_context_set_rdg_correlation_id(rdg->http, &rdg->guid) || + !http_context_enable_websocket_upgrade( + rdg->http, + freerdp_settings_get_bool(rdg->context->settings, FreeRDP_GatewayHttpUseWebsockets))) { - rdg->log = WLog_Get(TAG); - rdg->state = RDG_CLIENT_STATE_INITIAL; - rdg->context = context; - rdg->extAuth = - (rdg->context->settings->GatewayHttpExtAuthSspiNtlm ? HTTP_EXTENDED_AUTH_SSPI_NTLM - : HTTP_EXTENDED_AUTH_NONE); - - if (rdg->context->settings->GatewayAccessToken) - rdg->extAuth = HTTP_EXTENDED_AUTH_PAA; - - UuidCreate(&rdg->guid); - - rdg->tlsOut = freerdp_tls_new(rdg->context); - - if (!rdg->tlsOut) - goto rdg_alloc_error; - - rdg->tlsIn = freerdp_tls_new(rdg->context); - - if (!rdg->tlsIn) - goto rdg_alloc_error; - - rdg->http = http_context_new(); - - if (!rdg->http) - goto rdg_alloc_error; - - if (!http_context_set_uri(rdg->http, "/remoteDesktopGateway/") || - !http_context_set_accept(rdg->http, "*/*") || - !http_context_set_cache_control(rdg->http, "no-cache") || - !http_context_set_pragma(rdg->http, "no-cache") || - !http_context_set_connection(rdg->http, "Keep-Alive") || - !http_context_set_user_agent(rdg->http, "MS-RDGateway/1.0") || - !http_context_set_host(rdg->http, rdg->context->settings->GatewayHostname) || - !http_context_set_rdg_connection_id(rdg->http, &rdg->guid) || - !http_context_set_rdg_correlation_id(rdg->http, &rdg->guid) || - !http_context_enable_websocket_upgrade( - rdg->http, freerdp_settings_get_bool(rdg->context->settings, - FreeRDP_GatewayHttpUseWebsockets))) - { - goto rdg_alloc_error; - } - - if (rdg->extAuth != HTTP_EXTENDED_AUTH_NONE) - { - switch (rdg->extAuth) - { - case HTTP_EXTENDED_AUTH_PAA: - if (!http_context_set_rdg_auth_scheme(rdg->http, "PAA")) - goto rdg_alloc_error; - - break; - - case HTTP_EXTENDED_AUTH_SSPI_NTLM: - if (!http_context_set_rdg_auth_scheme(rdg->http, "SSPI_NTLM")) - goto rdg_alloc_error; - - break; - - default: - WLog_Print(rdg->log, WLOG_DEBUG, - "RDG extended authentication method %d not supported", rdg->extAuth); - } - } - - rdg->frontBio = BIO_new(BIO_s_rdg()); - - if (!rdg->frontBio) - goto rdg_alloc_error; - - BIO_set_data(rdg->frontBio, rdg); - InitializeCriticalSection(&rdg->writeSection); - - rdg->transferEncoding.httpTransferEncoding = TransferEncodingIdentity; - rdg->transferEncoding.isWebsocketTransport = FALSE; + goto rdg_alloc_error; } + if (rdg->extAuth != HTTP_EXTENDED_AUTH_NONE) + { + switch (rdg->extAuth) + { + case HTTP_EXTENDED_AUTH_PAA: + if (!http_context_set_rdg_auth_scheme(rdg->http, "PAA")) + goto rdg_alloc_error; + + break; + + case HTTP_EXTENDED_AUTH_SSPI_NTLM: + if (!http_context_set_rdg_auth_scheme(rdg->http, "SSPI_NTLM")) + goto rdg_alloc_error; + + break; + + default: + WLog_Print(rdg->log, WLOG_DEBUG, + "RDG extended authentication method %d not supported", rdg->extAuth); + } + } + + rdg->frontBio = BIO_new(BIO_s_rdg()); + + if (!rdg->frontBio) + goto rdg_alloc_error; + + BIO_set_data(rdg->frontBio, rdg); + InitializeCriticalSection(&rdg->writeSection); + + rdg->transferEncoding.httpTransferEncoding = TransferEncodingIdentity; + rdg->transferEncoding.isWebsocketTransport = FALSE; + + rdg->transferEncoding.context.websocket = websocket_context_new(); + if (!rdg->transferEncoding.context.websocket) + goto rdg_alloc_error; + return rdg; rdg_alloc_error: WINPR_PRAGMA_DIAG_PUSH @@ -2308,14 +2253,10 @@ void rdg_free(rdpRdg* rdg) DeleteCriticalSection(&rdg->writeSection); - if (rdg->transferEncoding.isWebsocketTransport) - { - if (rdg->transferEncoding.context.websocket.responseStreamBuffer != NULL) - Stream_Free(rdg->transferEncoding.context.websocket.responseStreamBuffer, TRUE); - } - smartcardCertInfo_Free(rdg->smartcard); + websocket_context_free(rdg->transferEncoding.context.websocket); + free(rdg); } diff --git a/libfreerdp/core/gateway/websocket.c b/libfreerdp/core/gateway/websocket.c index cc9e9bcff..ccc36c3b4 100644 --- a/libfreerdp/core/gateway/websocket.c +++ b/libfreerdp/core/gateway/websocket.c @@ -23,25 +23,67 @@ #define TAG FREERDP_TAG("core.gateway.websocket") -BOOL websocket_write_wstream(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode) +struct s_websocket_context { - size_t fullLen = 0; - int status = 0; - wStream* sWS = NULL; + size_t payloadLength; + uint32_t maskingKey; + BOOL masking; + BOOL closeSent; + BYTE opcode; + BYTE fragmentOriginalOpcode; + BYTE lengthAndMaskPosition; + WEBSOCKET_STATE state; + wStream* responseStreamBuffer; +}; - uint32_t maskingKey = 0; +static int websocket_write_all(BIO* bio, const BYTE* data, size_t length); - size_t streamPos = 0; +BOOL websocket_context_mask_and_send(BIO* bio, wStream* sPacket, wStream* sDataPacket, + UINT32 maskingKey) +{ + const size_t len = Stream_Length(sDataPacket); + Stream_SetPosition(sDataPacket, 0); - WINPR_ASSERT(bio); - WINPR_ASSERT(sPacket); - - const size_t len = Stream_Length(sPacket); - Stream_SetPosition(sPacket, 0); - - if (len > INT_MAX) + if (!Stream_EnsureRemainingCapacity(sPacket, len)) return FALSE; + /* mask as much as possible with 32bit access */ + size_t streamPos = 0; + for (; streamPos + 4 <= len; streamPos += 4) + { + const uint32_t data = Stream_Get_UINT32(sDataPacket); + Stream_Write_UINT32(sPacket, data ^ maskingKey); + } + + /* mask the rest byte by byte */ + for (; streamPos < len; streamPos++) + { + BYTE data = 0; + BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4); + Stream_Read_UINT8(sDataPacket, data); + Stream_Write_UINT8(sPacket, data ^ *partialMask); + } + + Stream_SealLength(sPacket); + + ERR_clear_error(); + const size_t size = Stream_Length(sPacket); + const int status = websocket_write_all(bio, Stream_Buffer(sPacket), size); + Stream_Free(sPacket, TRUE); + + if ((status < 0) || ((size_t)status != size)) + return FALSE; + + return TRUE; +} + +wStream* websocket_context_packet_new(size_t len, WEBSOCKET_OPCODE opcode, UINT32* pMaskingKey) +{ + WINPR_ASSERT(pMaskingKey); + if (len > INT_MAX) + return NULL; + + size_t fullLen = 0; if (len < 126) fullLen = len + 6; /* 2 byte "mini header" + 4 byte masking key */ else if (len < 0x10000) @@ -49,13 +91,14 @@ BOOL websocket_write_wstream(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode else fullLen = len + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */ - sWS = Stream_New(NULL, fullLen); + wStream* sWS = Stream_New(NULL, fullLen); if (!sWS) - return FALSE; + return NULL; + UINT32 maskingKey = 0; winpr_RAND(&maskingKey, sizeof(maskingKey)); - Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | opcode); + Stream_Write_UINT8(sWS, (UINT8)(WEBSOCKET_FIN_BIT | opcode)); if (len < 126) Stream_Write_UINT8(sWS, (UINT8)len | WEBSOCKET_MASK_BIT); else if (len < 0x10000) @@ -70,39 +113,34 @@ BOOL websocket_write_wstream(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode Stream_Write_UINT32_BE(sWS, (UINT32)len); } Stream_Write_UINT32(sWS, maskingKey); - - /* mask as much as possible with 32bit access */ - for (streamPos = 0; streamPos + 4 <= len; streamPos += 4) - { - uint32_t data = 0; - Stream_Read_UINT32(sPacket, data); - Stream_Write_UINT32(sWS, data ^ maskingKey); - } - - /* mask the rest byte by byte */ - for (; streamPos < len; streamPos++) - { - BYTE data = 0; - BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4); - Stream_Read_UINT8(sPacket, data); - Stream_Write_UINT8(sWS, data ^ *partialMask); - } - - Stream_SealLength(sWS); - - ERR_clear_error(); - const size_t size = Stream_Length(sWS); - if (size <= INT32_MAX) - status = BIO_write(bio, Stream_Buffer(sWS), (int)size); - Stream_Free(sWS, TRUE); - - if (status != (SSIZE_T)fullLen) - return FALSE; - - return TRUE; + *pMaskingKey = maskingKey; + return sWS; } -static int websocket_write_all(BIO* bio, const BYTE* data, size_t length) +BOOL websocket_context_write_wstream(websocket_context* context, BIO* bio, wStream* sPacket, + WEBSOCKET_OPCODE opcode) +{ + WINPR_ASSERT(context); + + if (context->closeSent) + return FALSE; + + if (opcode == WebsocketCloseOpcode) + context->closeSent = TRUE; + + WINPR_ASSERT(bio); + WINPR_ASSERT(sPacket); + + const size_t len = Stream_Length(sPacket); + uint32_t maskingKey = 0; + wStream* sWS = websocket_context_packet_new(len, opcode, &maskingKey); + if (!sWS) + return FALSE; + + return websocket_context_mask_and_send(bio, sWS, sPacket, maskingKey); +} + +int websocket_write_all(BIO* bio, const BYTE* data, size_t length) { WINPR_ASSERT(bio); WINPR_ASSERT(data); @@ -140,75 +178,19 @@ static int websocket_write_all(BIO* bio, const BYTE* data, size_t length) return (int)length; } -int websocket_write(BIO* bio, const BYTE* buf, int isize, WEBSOCKET_OPCODE opcode) +int websocket_context_write(websocket_context* context, BIO* bio, const BYTE* buf, int isize, + WEBSOCKET_OPCODE opcode) { - size_t fullLen = 0; - int status = 0; - wStream* sWS = NULL; - - uint32_t maskingKey = 0; - WINPR_ASSERT(bio); WINPR_ASSERT(buf); - winpr_RAND(&maskingKey, sizeof(maskingKey)); - if (isize < 0) return -1; - const size_t payloadSize = (size_t)isize; - if (payloadSize < 126) - fullLen = payloadSize + 6; /* 2 byte "mini header" + 4 byte masking key */ - else if (payloadSize < 0x10000) - fullLen = payloadSize + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */ - else - fullLen = payloadSize + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */ - - sWS = Stream_New(NULL, fullLen); - if (!sWS) - return FALSE; - - Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | opcode); - if (payloadSize < 126) - Stream_Write_UINT8(sWS, (UINT8)payloadSize | WEBSOCKET_MASK_BIT); - else if (payloadSize < 0x10000) - { - Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT); - Stream_Write_UINT16_BE(sWS, (UINT16)payloadSize); - } - else - { - Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT); - /* biggest packet possible is 0xffff + 0xa, so 32bit is always enough */ - Stream_Write_UINT32_BE(sWS, 0); - Stream_Write_UINT32_BE(sWS, (UINT32)payloadSize); - } - Stream_Write_UINT32(sWS, maskingKey); - - /* mask as much as possible with 32bit access */ - size_t streamPos = 0; - for (; streamPos + 4 <= payloadSize; streamPos += 4) - { - uint32_t masked = *((const uint32_t*)(buf + streamPos)) ^ maskingKey; - Stream_Write_UINT32(sWS, masked); - } - - /* mask the rest byte by byte */ - for (; streamPos < payloadSize; streamPos++) - { - BYTE* partialMask = (BYTE*)(&maskingKey) + streamPos % 4; - BYTE masked = *((buf + streamPos)) ^ *partialMask; - Stream_Write_UINT8(sWS, masked); - } - - Stream_SealLength(sWS); - - status = websocket_write_all(bio, Stream_Buffer(sWS), Stream_Length(sWS)); - Stream_Free(sWS, TRUE); - - if (status < 0) - return status; - + wStream sbuffer = { 0 }; + wStream* s = Stream_StaticConstInit(&sbuffer, buf, (size_t)isize); + if (!websocket_context_write_wstream(context, bio, s, opcode)) + return -2; return isize; } @@ -234,10 +216,10 @@ static int websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size, ERR_clear_error(); status = BIO_read(bio, pBuffer, (int)rlen); - if (status <= 0) + if ((status <= 0) || ((size_t)status > encodingContext->payloadLength)) return status; - encodingContext->payloadLength -= status; + encodingContext->payloadLength -= (size_t)status; if (encodingContext->payloadLength == 0) encodingContext->state = WebsocketStateOpcodeAndFin; @@ -245,47 +227,21 @@ static int websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size, return status; } -static int websocket_read_discard(BIO* bio, websocket_context* encodingContext) +static int websocket_read_wstream(BIO* bio, websocket_context* encodingContext) { - char _dummy[256] = { 0 }; - int status = 0; - WINPR_ASSERT(bio); WINPR_ASSERT(encodingContext); - if (encodingContext->payloadLength == 0) - { - encodingContext->state = WebsocketStateOpcodeAndFin; - return 0; - } - - ERR_clear_error(); - status = BIO_read(bio, _dummy, sizeof(_dummy)); - if (status <= 0) - return status; - - encodingContext->payloadLength -= status; - - if (encodingContext->payloadLength == 0) - encodingContext->state = WebsocketStateOpcodeAndFin; - - return status; -} - -static int websocket_read_wstream(BIO* bio, wStream* s, websocket_context* encodingContext) -{ - int status = 0; - - WINPR_ASSERT(bio); + wStream* s = encodingContext->responseStreamBuffer; WINPR_ASSERT(s); - WINPR_ASSERT(encodingContext); if (encodingContext->payloadLength == 0) { encodingContext->state = WebsocketStateOpcodeAndFin; return 0; } - if (Stream_GetRemainingCapacity(s) != encodingContext->payloadLength) + + if (!Stream_EnsureRemainingCapacity(s, encodingContext->payloadLength)) { WLog_WARN(TAG, "wStream::capacity [%" PRIuz "] != encodingContext::paylaodLangth [%" PRIuz "]", @@ -293,111 +249,33 @@ static int websocket_read_wstream(BIO* bio, wStream* s, websocket_context* encod return -1; } - const size_t rlen = encodingContext->payloadLength; - if (rlen > INT32_MAX) - return -1; - - ERR_clear_error(); - status = BIO_read(bio, Stream_Pointer(s), (int)rlen); - if (status <= 0) + const int status = websocket_read_data(bio, Stream_Pointer(s), Stream_GetRemainingCapacity(s), + encodingContext); + if (status < 0) return status; - Stream_Seek(s, status); - - encodingContext->payloadLength -= status; - - if (encodingContext->payloadLength == 0) - { - encodingContext->state = WebsocketStateOpcodeAndFin; - Stream_SealLength(s); - Stream_SetPosition(s, 0); - } + if (!Stream_SafeSeek(s, (size_t)status)) + return -1; return status; } -static BOOL websocket_reply_close(BIO* bio, wStream* s) +static BOOL websocket_reply_close(BIO* bio, websocket_context* context, wStream* s) { - /* write back close */ - wStream* closeFrame = NULL; - uint16_t maskingKey1 = 0; - uint16_t maskingKey2 = 0; - size_t closeDataLen = 0; - WINPR_ASSERT(bio); - closeDataLen = 0; - if (s != NULL && Stream_Length(s) >= 2) - closeDataLen = 2; - - closeFrame = Stream_New(NULL, 6 + closeDataLen); - if (!closeFrame) - return FALSE; - - Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketCloseOpcode); - Stream_Write_UINT8(closeFrame, closeDataLen | WEBSOCKET_MASK_BIT); /* no payload */ - winpr_RAND(&maskingKey1, sizeof(maskingKey1)); - winpr_RAND(&maskingKey2, sizeof(maskingKey2)); - Stream_Write_UINT16(closeFrame, maskingKey1); - Stream_Write_UINT16(closeFrame, maskingKey2); /* unused half, max 2 bytes of data */ - - if (closeDataLen == 2) - { - uint16_t data = 0; - Stream_Read_UINT16(s, data); - Stream_Write_UINT16(closeFrame, data ^ maskingKey1); - } - Stream_SealLength(closeFrame); - - const size_t rlen = Stream_Length(closeFrame); - - int status = -1; - if (rlen <= INT32_MAX) - { - ERR_clear_error(); - status = BIO_write(bio, Stream_Buffer(closeFrame), (int)rlen); - } - Stream_Free(closeFrame, TRUE); - - /* server MUST close socket now. The server is not allowed anymore to - * send frames but if he does, nothing bad would happen */ - if (status < 0) - return FALSE; - return TRUE; + return websocket_context_write_wstream(context, bio, s, WebsocketCloseOpcode); } -static BOOL websocket_reply_pong(BIO* bio, wStream* s) +static BOOL websocket_reply_pong(BIO* bio, websocket_context* context, wStream* s) { - wStream* closeFrame = NULL; - uint32_t maskingKey = 0; - WINPR_ASSERT(bio); + WINPR_ASSERT(s); - if (s != NULL) - return websocket_write_wstream(bio, s, WebsocketPongOpcode); + if (Stream_GetPosition(s) != 0) + return websocket_context_write_wstream(context, bio, s, WebsocketPongOpcode); - closeFrame = Stream_New(NULL, 6); - if (!closeFrame) - return FALSE; - - Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketPongOpcode); - Stream_Write_UINT8(closeFrame, 0 | WEBSOCKET_MASK_BIT); /* no payload */ - winpr_RAND(&maskingKey, sizeof(maskingKey)); - Stream_Write_UINT32(closeFrame, maskingKey); /* dummy masking key. */ - Stream_SealLength(closeFrame); - - const size_t rlen = Stream_Length(closeFrame); - int status = -1; - if (rlen <= INT32_MAX) - { - ERR_clear_error(); - status = BIO_write(bio, Stream_Buffer(closeFrame), (int)rlen); - } - Stream_Free(closeFrame, TRUE); - - if (status < 0) - return FALSE; - return TRUE; + return websocket_reply_close(bio, context, NULL); } static int websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size, @@ -409,9 +287,9 @@ static int websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size, WINPR_ASSERT(pBuffer); WINPR_ASSERT(encodingContext); - BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode - ? encodingContext->fragmentOriginalOpcode & 0xf - : encodingContext->opcode & 0xf); + const BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode + ? encodingContext->fragmentOriginalOpcode & 0xf + : encodingContext->opcode & 0xf); switch (effectiveOpcode) { @@ -425,60 +303,55 @@ static int websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size, } case WebsocketPingOpcode: { - if (encodingContext->responseStreamBuffer == NULL) - encodingContext->responseStreamBuffer = - Stream_New(NULL, encodingContext->payloadLength); - - status = - websocket_read_wstream(bio, encodingContext->responseStreamBuffer, encodingContext); + status = websocket_read_wstream(bio, encodingContext); if (status < 0) return status; if (encodingContext->payloadLength == 0) { - if (!encodingContext->closeSent) - websocket_reply_pong(bio, encodingContext->responseStreamBuffer); - - Stream_Free(encodingContext->responseStreamBuffer, TRUE); - encodingContext->responseStreamBuffer = NULL; + websocket_reply_pong(bio, encodingContext, encodingContext->responseStreamBuffer); + Stream_SetPosition(encodingContext->responseStreamBuffer, 0); } } break; + case WebsocketPongOpcode: + { + status = websocket_read_wstream(bio, encodingContext); + if (status < 0) + return status; + /* We don“t care about pong response data, discard. */ + Stream_SetPosition(encodingContext->responseStreamBuffer, 0); + } + break; case WebsocketCloseOpcode: { - if (encodingContext->responseStreamBuffer == NULL) - encodingContext->responseStreamBuffer = - Stream_New(NULL, encodingContext->payloadLength); - - status = - websocket_read_wstream(bio, encodingContext->responseStreamBuffer, encodingContext); + status = websocket_read_wstream(bio, encodingContext); if (status < 0) return status; if (encodingContext->payloadLength == 0) { - websocket_reply_close(bio, encodingContext->responseStreamBuffer); + websocket_reply_close(bio, encodingContext, encodingContext->responseStreamBuffer); encodingContext->closeSent = TRUE; - - if (encodingContext->responseStreamBuffer) - Stream_Free(encodingContext->responseStreamBuffer, TRUE); - encodingContext->responseStreamBuffer = NULL; + Stream_SetPosition(encodingContext->responseStreamBuffer, 0); } } break; default: - WLog_WARN(TAG, "Unimplemented websocket opcode %x. Dropping", effectiveOpcode & 0xf); + WLog_WARN(TAG, "Unimplemented websocket opcode %" PRIx8 ". Dropping", effectiveOpcode); - status = websocket_read_discard(bio, encodingContext); + status = websocket_read_wstream(bio, encodingContext); if (status < 0) return status; + Stream_SetPosition(encodingContext->responseStreamBuffer, 0); + break; } /* return how many bytes have been written to pBuffer. * Only WebsocketBinaryOpcode writes into it and it returns directly */ return 0; } -int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* encodingContext) +int websocket_context_read(websocket_context* encodingContext, BIO* bio, BYTE* pBuffer, size_t size) { int status = 0; int effectiveDataLen = 0; @@ -493,7 +366,8 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco { case WebsocketStateOpcodeAndFin: { - BYTE buffer[1]; + BYTE buffer[1] = { 0 }; + ERR_clear_error(); status = BIO_read(bio, (char*)buffer, sizeof(buffer)); if (status <= 0) @@ -508,8 +382,8 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco break; case WebsocketStateLengthAndMasking: { - BYTE buffer[1]; - BYTE len = 0; + BYTE buffer[1] = { 0 }; + ERR_clear_error(); status = BIO_read(bio, (char*)buffer, sizeof(buffer)); if (status <= 0) @@ -518,7 +392,7 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT); encodingContext->lengthAndMaskPosition = 0; encodingContext->payloadLength = 0; - len = buffer[0] & 0x7f; + const BYTE len = buffer[0] & 0x7f; if (len < 126) { encodingContext->payloadLength = len; @@ -534,8 +408,9 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco case WebsocketStateShortLength: case WebsocketStateLongLength: { - BYTE buffer[1]; - BYTE lenLength = (encodingContext->state == WebsocketStateShortLength ? 2 : 8); + BYTE buffer[1] = { 0 }; + const BYTE lenLength = + (encodingContext->state == WebsocketStateShortLength ? 2 : 8); while (encodingContext->lengthAndMaskPosition < lenLength) { ERR_clear_error(); @@ -577,3 +452,39 @@ int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, websocket_context* enco } /* should be unreachable */ } + +websocket_context* websocket_context_new(void) +{ + websocket_context* context = calloc(1, sizeof(websocket_context)); + if (!context) + goto fail; + + context->responseStreamBuffer = Stream_New(NULL, 1024); + if (!context->responseStreamBuffer) + goto fail; + + if (!websocket_context_reset(context)) + goto fail; + + return context; +fail: + websocket_context_free(context); + return NULL; +} + +void websocket_context_free(websocket_context* context) +{ + if (!context) + return; + + Stream_Free(context->responseStreamBuffer, TRUE); + free(context); +} + +BOOL websocket_context_reset(websocket_context* context) +{ + WINPR_ASSERT(context); + + context->state = WebsocketStateOpcodeAndFin; + return Stream_SetPosition(context->responseStreamBuffer, 0); +} diff --git a/libfreerdp/core/gateway/websocket.h b/libfreerdp/core/gateway/websocket.h index de41e4908..edaf152cd 100644 --- a/libfreerdp/core/gateway/websocket.h +++ b/libfreerdp/core/gateway/websocket.h @@ -50,22 +50,27 @@ typedef enum WebSocketStatePayload, } WEBSOCKET_STATE; -typedef struct -{ - size_t payloadLength; - uint32_t maskingKey; - BOOL masking; - BOOL closeSent; - BYTE opcode; - BYTE fragmentOriginalOpcode; - BYTE lengthAndMaskPosition; - WEBSOCKET_STATE state; - wStream* responseStreamBuffer; -} websocket_context; +typedef struct s_websocket_context websocket_context; -FREERDP_LOCAL BOOL websocket_write_wstream(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode); -FREERDP_LOCAL int websocket_write(BIO* bio, const BYTE* buf, int isize, WEBSOCKET_OPCODE opcode); -FREERDP_LOCAL int websocket_read(BIO* bio, BYTE* pBuffer, size_t size, - websocket_context* encodingContext); +FREERDP_LOCAL void websocket_context_free(websocket_context* context); + +WINPR_ATTR_MALLOC(websocket_context_free, 1) +FREERDP_LOCAL websocket_context* websocket_context_new(void); + +FREERDP_LOCAL BOOL websocket_context_reset(websocket_context* context); + +FREERDP_LOCAL BOOL websocket_context_write_wstream(websocket_context* context, BIO* bio, + wStream* sPacket, WEBSOCKET_OPCODE opcode); +FREERDP_LOCAL int websocket_context_write(websocket_context* context, BIO* bio, const BYTE* buf, + int isize, WEBSOCKET_OPCODE opcode); +FREERDP_LOCAL int websocket_context_read(websocket_context* encodingContext, BIO* bio, + BYTE* pBuffer, size_t size); + +WINPR_ATTR_MALLOC(Stream_Free, 1) +FREERDP_LOCAL wStream* websocket_context_packet_new(size_t len, WEBSOCKET_OPCODE opcode, + UINT32* pMaskingKey); + +FREERDP_LOCAL BOOL websocket_context_mask_and_send(BIO* bio, wStream* sPacket, wStream* sDataPacket, + UINT32 maskingKey); #endif /* FREERDP_LIB_CORE_GATEWAY_WEBSOCKET_H */ diff --git a/libfreerdp/core/gateway/wst.c b/libfreerdp/core/gateway/wst.c index 9a98ec5de..079c005f7 100644 --- a/libfreerdp/core/gateway/wst.c +++ b/libfreerdp/core/gateway/wst.c @@ -65,7 +65,7 @@ struct rdp_wst char* gwhostname; uint16_t gwport; char* gwpath; - websocket_context wscontext; + websocket_context* wscontext; }; static const char arm_query_param[] = "%s%cClmTk=Bearer%%20%s&X-MS-User-Agent=FreeRDP%%2F3.0"; @@ -494,11 +494,7 @@ BOOL wst_connect(rdpWst* wst, DWORD timeout) return FALSE; if (isWebsocket) - { - wst->wscontext.state = WebsocketStateOpcodeAndFin; - wst->wscontext.responseStreamBuffer = NULL; - return TRUE; - } + return websocket_context_reset(wst->wscontext); else { char buffer[64] = { 0 }; @@ -537,7 +533,8 @@ static int wst_bio_write(BIO* bio, const char* buf, int num) WINPR_ASSERT(wst); BIO_clear_flags(bio, BIO_FLAGS_WRITE); EnterCriticalSection(&wst->writeSection); - status = websocket_write(wst->tls->bio, (const BYTE*)buf, num, WebsocketBinaryOpcode); + status = websocket_context_write(wst->wscontext, wst->tls->bio, (const BYTE*)buf, num, + WebsocketBinaryOpcode); LeaveCriticalSection(&wst->writeSection); if (status < 0) @@ -570,7 +567,7 @@ static int wst_bio_read(BIO* bio, char* buf, int size) while (status <= 0) { - status = websocket_read(wst->tls->bio, (BYTE*)buf, (size_t)size, &wst->wscontext); + status = websocket_context_read(wst->wscontext, wst->tls->bio, (BYTE*)buf, (size_t)size); if (status <= 0) { if (!BIO_should_retry(wst->tls->bio)) @@ -781,58 +778,59 @@ static BOOL wst_parse_url(rdpWst* wst, const char* url) rdpWst* wst_new(rdpContext* context) { - rdpWst* wst = NULL; - if (!context) return NULL; - wst = (rdpWst*)calloc(1, sizeof(rdpWst)); + rdpWst* wst = (rdpWst*)calloc(1, sizeof(rdpWst)); + if (!wst) + return NULL; - if (wst) + wst->context = context; + + wst->gwhostname = NULL; + wst->gwport = 443; + wst->gwpath = NULL; + + if (!wst_parse_url(wst, context->settings->GatewayUrl)) + goto wst_alloc_error; + + wst->tls = freerdp_tls_new(wst->context); + if (!wst->tls) + goto wst_alloc_error; + + wst->http = http_context_new(); + + if (!wst->http) + goto wst_alloc_error; + + if (!http_context_set_uri(wst->http, wst->gwpath) || + !http_context_set_accept(wst->http, "*/*") || + !http_context_set_cache_control(wst->http, "no-cache") || + !http_context_set_pragma(wst->http, "no-cache") || + !http_context_set_connection(wst->http, "Keep-Alive") || + !http_context_set_user_agent(wst->http, FREERDP_USER_AGENT) || + !http_context_set_x_ms_user_agent(wst->http, FREERDP_USER_AGENT) || + !http_context_set_host(wst->http, wst->gwhostname) || + !http_context_enable_websocket_upgrade(wst->http, TRUE)) { - wst->context = context; - - wst->gwhostname = NULL; - wst->gwport = 443; - wst->gwpath = NULL; - - if (!wst_parse_url(wst, context->settings->GatewayUrl)) - goto wst_alloc_error; - - wst->tls = freerdp_tls_new(wst->context); - if (!wst->tls) - goto wst_alloc_error; - - wst->http = http_context_new(); - - if (!wst->http) - goto wst_alloc_error; - - if (!http_context_set_uri(wst->http, wst->gwpath) || - !http_context_set_accept(wst->http, "*/*") || - !http_context_set_cache_control(wst->http, "no-cache") || - !http_context_set_pragma(wst->http, "no-cache") || - !http_context_set_connection(wst->http, "Keep-Alive") || - !http_context_set_user_agent(wst->http, FREERDP_USER_AGENT) || - !http_context_set_x_ms_user_agent(wst->http, FREERDP_USER_AGENT) || - !http_context_set_host(wst->http, wst->gwhostname) || - !http_context_enable_websocket_upgrade(wst->http, TRUE)) - { - goto wst_alloc_error; - } - - wst->frontBio = BIO_new(BIO_s_wst()); - - if (!wst->frontBio) - goto wst_alloc_error; - - BIO_set_data(wst->frontBio, wst); - InitializeCriticalSection(&wst->writeSection); - wst->auth = credssp_auth_new(context); - if (!wst->auth) - goto wst_alloc_error; + goto wst_alloc_error; } + wst->frontBio = BIO_new(BIO_s_wst()); + + if (!wst->frontBio) + goto wst_alloc_error; + + BIO_set_data(wst->frontBio, wst); + InitializeCriticalSection(&wst->writeSection); + wst->auth = credssp_auth_new(context); + if (!wst->auth) + goto wst_alloc_error; + + wst->wscontext = websocket_context_new(); + if (!wst->wscontext) + goto wst_alloc_error; + return wst; wst_alloc_error: WINPR_PRAGMA_DIAG_PUSH @@ -858,8 +856,7 @@ void wst_free(rdpWst* wst) DeleteCriticalSection(&wst->writeSection); - if (wst->wscontext.responseStreamBuffer != NULL) - Stream_Free(wst->wscontext.responseStreamBuffer, TRUE); + websocket_context_free(wst->wscontext); free(wst); }