diff --git a/include/freerdp/crypto/tls.h b/include/freerdp/crypto/tls.h index bf5521300..180007e5e 100644 --- a/include/freerdp/crypto/tls.h +++ b/include/freerdp/crypto/tls.h @@ -70,7 +70,6 @@ struct rdp_tls SSL* ssl; BIO* bio; void* tsg; - int sockfd; SSL_CTX* ctx; BYTE* PublicKey; BIO_METHOD* methods; @@ -84,17 +83,11 @@ struct rdp_tls int alertDescription; }; -FREERDP_API int tls_connect(rdpTls* tls); -FREERDP_API BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file); +FREERDP_API int tls_connect(rdpTls* tls, BIO *underlying); +FREERDP_API BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file); FREERDP_API BOOL tls_disconnect(rdpTls* tls); -FREERDP_API int tls_read(rdpTls* tls, BYTE* data, int length); -FREERDP_API int tls_write(rdpTls* tls, BYTE* data, int length); - -FREERDP_API int tls_write_all(rdpTls* tls, BYTE* data, int length); - -FREERDP_API int tls_wait_read(rdpTls* tls); -FREERDP_API int tls_wait_write(rdpTls* tls); +FREERDP_API int tls_write_all(rdpTls* tls, const BYTE* data, int length); FREERDP_API int tls_set_alert_code(rdpTls* tls, int level, int description); diff --git a/include/freerdp/peer.h b/include/freerdp/peer.h index c89d37a07..4fbe75bfc 100644 --- a/include/freerdp/peer.h +++ b/include/freerdp/peer.h @@ -34,7 +34,10 @@ typedef void (*psPeerContextFree)(freerdp_peer* client, rdpContext* context); typedef BOOL (*psPeerInitialize)(freerdp_peer* client); typedef BOOL (*psPeerGetFileDescriptor)(freerdp_peer* client, void** rfds, int* rcount); typedef HANDLE (*psPeerGetEventHandle)(freerdp_peer* client); +typedef HANDLE (*psPeerGetReceiveEventHandle)(freerdp_peer* client); typedef BOOL (*psPeerCheckFileDescriptor)(freerdp_peer* client); +typedef BOOL (*psPeerIsWriteBlocked)(freerdp_peer* client); +typedef int (*psPeerDrainOutputBuffer)(freerdp_peer* client); typedef BOOL (*psPeerClose)(freerdp_peer* client); typedef void (*psPeerDisconnect)(freerdp_peer* client); typedef BOOL (*psPeerCapabilities)(freerdp_peer* client); @@ -62,6 +65,7 @@ struct rdp_freerdp_peer psPeerInitialize Initialize; psPeerGetFileDescriptor GetFileDescriptor; psPeerGetEventHandle GetEventHandle; + psPeerGetReceiveEventHandle GetReceiveEventHandle; psPeerCheckFileDescriptor CheckFileDescriptor; psPeerClose Close; psPeerDisconnect Disconnect; @@ -81,6 +85,9 @@ struct rdp_freerdp_peer BOOL activated; BOOL authenticated; SEC_WINNT_AUTH_IDENTITY identity; + + psPeerIsWriteBlocked IsWriteBlocked; + psPeerDrainOutputBuffer DrainOutputBuffer; }; #ifdef __cplusplus diff --git a/include/freerdp/settings.h b/include/freerdp/settings.h index cc609d1a0..1f0ebb278 100644 --- a/include/freerdp/settings.h +++ b/include/freerdp/settings.h @@ -801,7 +801,8 @@ struct rdp_settings ALIGN64 char* Password; /* 22 */ ALIGN64 char* Domain; /* 23 */ ALIGN64 char* PasswordHash; /* 24 */ - UINT64 padding0064[64 - 25]; /* 25 */ + ALIGN64 BOOL WaitForOutputBufferFlush; /* 25 */ + UINT64 padding0064[64 - 26]; /* 26 */ UINT64 padding0128[128 - 64]; /* 64 */ /** diff --git a/libfreerdp/core/gateway/http.c b/libfreerdp/core/gateway/http.c index c9f33f01a..610b23091 100644 --- a/libfreerdp/core/gateway/http.c +++ b/libfreerdp/core/gateway/http.c @@ -26,6 +26,10 @@ #include #include +#ifdef HAVE_VALGRIND_MEMCHECK_H +#include +#endif + #include "http.h" HttpContext* http_context_new() @@ -472,7 +476,7 @@ HttpResponse* http_response_recv(rdpTls* tls) nbytes = 0; length = 10000; content = NULL; - buffer = malloc(length); + buffer = calloc(length, 1); if (!buffer) return NULL; @@ -487,14 +491,20 @@ HttpResponse* http_response_recv(rdpTls* tls) { while (nbytes < 5) { - status = tls_read(tls, p, length - nbytes); + status = BIO_read(tls->bio, p, length - nbytes); - if (status < 0) - goto out_error; + if (status <= 0) + { + if (!BIO_should_retry(tls->bio)) + goto out_error; - if (!status) + USleep(100); continue; + } +#ifdef HAVE_VALGRIND_MEMCHECK_H + VALGRIND_MAKE_MEM_DEFINED(p, status); +#endif nbytes += status; p = (BYTE*) &buffer[nbytes]; } @@ -503,7 +513,7 @@ HttpResponse* http_response_recv(rdpTls* tls) if (!header_end) { - fprintf(stderr, "http_response_recv: invalid response:\n"); + fprintf(stderr, "%s: invalid response:\n", __FUNCTION__); winpr_HexDump(buffer, status); goto out_error; } @@ -517,7 +527,7 @@ HttpResponse* http_response_recv(rdpTls* tls) header_end[0] = '\0'; header_end[1] = '\0'; - content = &header_end[2]; + content = header_end + 2; count = 0; line = (char*) buffer; @@ -552,11 +562,14 @@ HttpResponse* http_response_recv(rdpTls* tls) if (!http_response_parse_header(http_response)) goto out_error; - if (http_response->ContentLength > 0) + http_response->bodyLen = nbytes - (content - (char *)buffer); + if (http_response->bodyLen > 0) { - http_response->Content = _strdup(content); - if (!http_response->Content) + http_response->BodyContent = (BYTE *)malloc(http_response->bodyLen); + if (!http_response->BodyContent) goto out_error; + + CopyMemory(http_response->BodyContent, content, http_response->bodyLen); } break; @@ -627,7 +640,7 @@ void http_response_free(HttpResponse* http_response) ListDictionary_Free(http_response->Authenticates); if (http_response->ContentLength > 0) - free(http_response->Content); + free(http_response->BodyContent); free(http_response); } diff --git a/libfreerdp/core/gateway/http.h b/libfreerdp/core/gateway/http.h index 748b45a36..ded9ba214 100644 --- a/libfreerdp/core/gateway/http.h +++ b/libfreerdp/core/gateway/http.h @@ -84,7 +84,8 @@ struct _http_response wListDictionary *Authenticates; int ContentLength; - char* Content; + BYTE *BodyContent; + int bodyLen; }; void http_response_print(HttpResponse* http_response); diff --git a/libfreerdp/core/gateway/ncacn_http.c b/libfreerdp/core/gateway/ncacn_http.c index 270dafbcf..b5beff4b2 100644 --- a/libfreerdp/core/gateway/ncacn_http.c +++ b/libfreerdp/core/gateway/ncacn_http.c @@ -98,6 +98,8 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc) rdpNtlm* ntlm = rpc->NtlmHttpIn->ntlm; http_response = http_response_recv(rpc->TlsIn); + if (!http_response) + return -1; if (ListDictionary_Contains(http_response->Authenticates, "NTLM")) { @@ -105,14 +107,12 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc) if (!token64) goto out; - ntlm_token_data = NULL; crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length); } +out: ntlm->inputBuffer[0].pvBuffer = ntlm_token_data; ntlm->inputBuffer[0].cbBuffer = ntlm_token_length; - -out: http_response_free(http_response); return 0; @@ -123,25 +123,19 @@ int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, TSG_CHANNEL channel) rdpNtlm* ntlm = NULL; rdpSettings* settings = rpc->settings; freerdp* instance = (freerdp*) rpc->settings->instance; - BOOL promptPassword = FALSE; if (channel == TSG_CHANNEL_IN) ntlm = rpc->NtlmHttpIn->ntlm; else if (channel == TSG_CHANNEL_OUT) ntlm = rpc->NtlmHttpOut->ntlm; - if ((!settings->GatewayPassword) || (!settings->GatewayUsername) - || (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername))) - { - promptPassword = TRUE; - } - - if (promptPassword) + if (!settings->GatewayPassword || !settings->GatewayUsername || + !strlen(settings->GatewayPassword) || !strlen(settings->GatewayUsername)) { if (instance->GatewayAuthenticate) { - BOOL proceed = instance->GatewayAuthenticate(instance, - &settings->GatewayUsername, &settings->GatewayPassword, &settings->GatewayDomain); + BOOL proceed = instance->GatewayAuthenticate(instance, &settings->GatewayUsername, + &settings->GatewayPassword, &settings->GatewayDomain); if (!proceed) { @@ -240,12 +234,10 @@ int rpc_ncacn_http_recv_out_channel_response(rdpRpc* rpc) char *token64 = ListDictionary_GetItemValue(http_response->Authenticates, "NTLM"); crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length); } - ntlm->inputBuffer[0].pvBuffer = ntlm_token_data; ntlm->inputBuffer[0].cbBuffer = ntlm_token_length; - + http_response_free(http_response); - return 0; } @@ -259,15 +251,12 @@ BOOL rpc_ntlm_http_out_connect(rdpRpc* rpc) success = TRUE; /* Send OUT Channel Request */ - rpc_ncacn_http_send_out_channel_request(rpc); /* Receive OUT Channel Response */ - rpc_ncacn_http_recv_out_channel_response(rpc); /* Send OUT Channel Request */ - rpc_ncacn_http_send_out_channel_request(rpc); ntlm_client_uninit(ntlm); @@ -296,13 +285,11 @@ void rpc_ntlm_http_init_channel(rdpRpc* rpc, rdpNtlmHttp* ntlm_http, TSG_CHANNEL if (channel == TSG_CHANNEL_IN) { - http_context_set_pragma(ntlm_http->context, - "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729"); + http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729"); } else if (channel == TSG_CHANNEL_OUT) { - http_context_set_pragma(ntlm_http->context, - "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729" ", " + http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729, " "SessionId=fbd9c34f-397d-471d-a109-1b08cc554624"); } } diff --git a/libfreerdp/core/gateway/rpc.c b/libfreerdp/core/gateway/rpc.c index c91a71071..2432ab06c 100644 --- a/libfreerdp/core/gateway/rpc.c +++ b/libfreerdp/core/gateway/rpc.c @@ -33,6 +33,11 @@ #include #include +#include + +#ifdef HAVE_VALGRIND_MEMCHECK_H +#include +#endif #include "http.h" #include "ntlm.h" @@ -235,80 +240,77 @@ BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* l { UINT32 alloc_hint = 0; rpcconn_hdr_t* header; + UINT32 frag_length; + UINT32 auth_length; + UINT32 auth_pad_length; + UINT32 sec_trailer_offset; + rpc_sec_trailer* sec_trailer; *offset = RPC_COMMON_FIELDS_LENGTH; header = ((rpcconn_hdr_t*) buffer); - if (header->common.ptype == PTYPE_RESPONSE) + switch (header->common.ptype) { - *offset += 8; - rpc_offset_align(offset, 8); - alloc_hint = header->response.alloc_hint; - } - else if (header->common.ptype == PTYPE_REQUEST) - { - *offset += 4; - rpc_offset_align(offset, 8); - alloc_hint = header->request.alloc_hint; - } - else if (header->common.ptype == PTYPE_RTS) - { - *offset += 4; - } - else - { - return FALSE; + case PTYPE_RESPONSE: + *offset += 8; + rpc_offset_align(offset, 8); + alloc_hint = header->response.alloc_hint; + break; + case PTYPE_REQUEST: + *offset += 4; + rpc_offset_align(offset, 8); + alloc_hint = header->request.alloc_hint; + break; + case PTYPE_RTS: + *offset += 4; + break; + default: + fprintf(stderr, "%s: unknown ptype=0x%x\n", __FUNCTION__, header->common.ptype); + return FALSE; } - if (length) + if (!length) + return TRUE; + + if (header->common.ptype == PTYPE_REQUEST) { - if (header->common.ptype == PTYPE_REQUEST) - { - UINT32 sec_trailer_offset; + UINT32 sec_trailer_offset; - sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8; - *length = sec_trailer_offset - *offset; - } - else - { - UINT32 frag_length; - UINT32 auth_length; - UINT32 auth_pad_length; - UINT32 sec_trailer_offset; - rpc_sec_trailer* sec_trailer; + sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8; + *length = sec_trailer_offset - *offset; + return TRUE; + } - frag_length = header->common.frag_length; - auth_length = header->common.auth_length; - sec_trailer_offset = frag_length - auth_length - 8; - sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset]; - auth_pad_length = sec_trailer->auth_pad_length; + frag_length = header->common.frag_length; + auth_length = header->common.auth_length; + + sec_trailer_offset = frag_length - auth_length - 8; + sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset]; + auth_pad_length = sec_trailer->auth_pad_length; #if 0 - fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n", - sec_trailer->auth_type, - sec_trailer->auth_level, - sec_trailer->auth_pad_length, - sec_trailer->auth_reserved, - sec_trailer->auth_context_id); + fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n", + sec_trailer->auth_type, + sec_trailer->auth_level, + sec_trailer->auth_pad_length, + sec_trailer->auth_reserved, + sec_trailer->auth_context_id); #endif - /** - * According to [MS-RPCE], auth_pad_length is the number of padding - * octets used to 4-byte align the security trailer, but in practice - * we get values up to 15, which indicates 16-byte alignment. - */ + /** + * According to [MS-RPCE], auth_pad_length is the number of padding + * octets used to 4-byte align the security trailer, but in practice + * we get values up to 15, which indicates 16-byte alignment. + */ - if ((frag_length - (sec_trailer_offset + 8)) != auth_length) - { - fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length, - (frag_length - (sec_trailer_offset + 8))); - } - - *length = frag_length - auth_length - 24 - 8 - auth_pad_length; - } + if ((frag_length - (sec_trailer_offset + 8)) != auth_length) + { + fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length, + (frag_length - (sec_trailer_offset + 8))); } + *length = frag_length - auth_length - 24 - 8 - auth_pad_length; return TRUE; } @@ -316,12 +318,23 @@ int rpc_out_read(rdpRpc* rpc, BYTE* data, int length) { int status; - status = tls_read(rpc->TlsOut, data, length); + status = BIO_read(rpc->TlsOut->bio, data, length); + /* fprintf(stderr, "%s: length=%d => status=%d shouldRetry=%d\n", __FUNCTION__, length, + * status, BIO_should_retry(rpc->TlsOut->bio)); */ + if (status > 0) { +#ifdef HAVE_VALGRIND_MEMCHECK_H + VALGRIND_MAKE_MEM_DEFINED(data, status); +#endif + return status; + } - return status; + if (BIO_should_retry(rpc->TlsOut->bio)) + return 0; + + return -1; } -int rpc_out_write(rdpRpc* rpc, BYTE* data, int length) +int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length) { int status; @@ -330,7 +343,7 @@ int rpc_out_write(rdpRpc* rpc, BYTE* data, int length) return status; } -int rpc_in_write(rdpRpc* rpc, BYTE* data, int length) +int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length) { int status; @@ -360,20 +373,21 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) ntlm = rpc->ntlm; - if ((!ntlm) || (!ntlm->table)) + if (!ntlm || !ntlm->table) { - fprintf(stderr, "rpc_write: invalid ntlm context\n"); + fprintf(stderr, "%s: invalid ntlm context\n", __FUNCTION__); return -1; } if (ntlm->table->QueryContextAttributes(&ntlm->context, SECPKG_ATTR_SIZES, &ntlm->ContextSizes) != SEC_E_OK) { - fprintf(stderr, "QueryContextAttributes SECPKG_ATTR_SIZES failure\n"); + fprintf(stderr, "%s: QueryContextAttributes SECPKG_ATTR_SIZES failure\n", __FUNCTION__); return -1; } - request_pdu = (rpcconn_request_hdr_t*) malloc(sizeof(rpcconn_request_hdr_t)); - ZeroMemory(request_pdu, sizeof(rpcconn_request_hdr_t)); + request_pdu = (rpcconn_request_hdr_t*) calloc(1, sizeof(rpcconn_request_hdr_t)); + if (!request_pdu) + return -1; rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) request_pdu); @@ -386,7 +400,11 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) request_pdu->opnum = opnum; clientCall = rpc_client_call_new(request_pdu->call_id, request_pdu->opnum); - ArrayList_Add(rpc->client->ClientCallList, clientCall); + if (!clientCall) + goto out_free_pdu; + + if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0) + goto out_free_clientCall; if (request_pdu->opnum == TsProxySetupReceivePipeOpnum) rpc->PipeCallId = request_pdu->call_id; @@ -407,8 +425,9 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) request_pdu->frag_length = offset; - buffer = (BYTE*) malloc(request_pdu->frag_length); - + buffer = (BYTE*) calloc(1, request_pdu->frag_length); + if (!buffer) + goto out_free_pdu; CopyMemory(buffer, request_pdu, 24); offset = 24; @@ -427,15 +446,15 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) Buffers[0].cbBuffer = offset; Buffers[1].cbBuffer = ntlm->ContextSizes.cbMaxSignature; - Buffers[1].pvBuffer = malloc(Buffers[1].cbBuffer); - ZeroMemory(Buffers[1].pvBuffer, Buffers[1].cbBuffer); + Buffers[1].pvBuffer = calloc(1, Buffers[1].cbBuffer); + if (!Buffers[1].pvBuffer) + return -1; Message.cBuffers = 2; Message.ulVersion = SECBUFFER_VERSION; Message.pBuffers = (PSecBuffer) &Buffers; encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, 0, &Message, rpc->SendSeqNum++); - if (encrypt_status != SEC_E_OK) { fprintf(stderr, "EncryptMessage status: 0x%08X\n", encrypt_status); @@ -447,12 +466,18 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) offset += Buffers[1].cbBuffer; free(Buffers[1].pvBuffer); - if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) != 0) + if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) < 0) length = -1; free(request_pdu); return length; + +out_free_clientCall: + rpc_client_call_free(clientCall); +out_free_pdu: + free(request_pdu); + return -1; } BOOL rpc_connect(rdpRpc* rpc) @@ -592,13 +617,17 @@ rdpRpc* rpc_new(rdpTransport* transport) rpc->CallId = 2; - rpc_client_new(rpc); + if (rpc_client_new(rpc) < 0) + goto out_free_virtualConnectionCookieTable; rpc->client->SynchronousSend = TRUE; rpc->client->SynchronousReceive = TRUE; return rpc; +out_free_virtualConnectionCookieTable: + rpc_client_free(rpc); + ArrayList_Free(rpc->VirtualConnectionCookieTable); out_free_virtual_connection: rpc_client_virtual_connection_free(rpc->VirtualConnection); out_free_ntlm_http_out: diff --git a/libfreerdp/core/gateway/rpc.h b/libfreerdp/core/gateway/rpc.h index d10d665c7..c86a8618f 100644 --- a/libfreerdp/core/gateway/rpc.h +++ b/libfreerdp/core/gateway/rpc.h @@ -772,8 +772,8 @@ UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad); int rpc_out_read(rdpRpc* rpc, BYTE* data, int length); -int rpc_out_write(rdpRpc* rpc, BYTE* data, int length); -int rpc_in_write(rdpRpc* rpc, BYTE* data, int length); +int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length); +int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length); BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset, UINT32* length); diff --git a/libfreerdp/core/gateway/rpc_bind.c b/libfreerdp/core/gateway/rpc_bind.c index cf02a802a..ceae95159 100644 --- a/libfreerdp/core/gateway/rpc_bind.c +++ b/libfreerdp/core/gateway/rpc_bind.c @@ -103,6 +103,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc) DEBUG_RPC("Sending bind PDU"); rpc->ntlm = ntlm_new(); + if (!rpc->ntlm) + return -1; if ((!settings->GatewayPassword) || (!settings->GatewayUsername) || (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername))) @@ -129,17 +131,22 @@ int rpc_send_bind_pdu(rdpRpc* rpc) settings->Username = _strdup(settings->GatewayUsername); settings->Domain = _strdup(settings->GatewayDomain); settings->Password = _strdup(settings->GatewayPassword); + + if (!settings->Username || !settings->Domain || settings->Password) + return -1; } } } - ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL); - ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname); + if (!ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL) || + !ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname) || + !ntlm_authenticate(rpc->ntlm) + ) + return -1; - ntlm_authenticate(rpc->ntlm); - - bind_pdu = (rpcconn_bind_hdr_t*) malloc(sizeof(rpcconn_bind_hdr_t)); - ZeroMemory(bind_pdu, sizeof(rpcconn_bind_hdr_t)); + bind_pdu = (rpcconn_bind_hdr_t*) calloc(1, sizeof(rpcconn_bind_hdr_t)); + if (!bind_pdu) + return -1; rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) bind_pdu); @@ -159,6 +166,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc) bind_pdu->p_context_elem.reserved2 = 0; bind_pdu->p_context_elem.p_cont_elem = malloc(sizeof(p_cont_elem_t) * bind_pdu->p_context_elem.n_context_elem); + if (!bind_pdu->p_context_elem.p_cont_elem) + return -1; p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[0]; @@ -196,6 +205,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc) bind_pdu->frag_length = offset; buffer = (BYTE*) malloc(bind_pdu->frag_length); + if (!buffer) + return -1; CopyMemory(buffer, bind_pdu, 24); CopyMemory(&buffer[24], &bind_pdu->p_context_elem, 4); @@ -214,7 +225,10 @@ int rpc_send_bind_pdu(rdpRpc* rpc) length = bind_pdu->frag_length; clientCall = rpc_client_call_new(bind_pdu->call_id, 0); - ArrayList_Add(rpc->client->ClientCallList, clientCall); + if (!clientCall) + return -1; + if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0) + return -1; if (rpc_send_enqueue_pdu(rpc, buffer, length) != 0) length = -1; diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index dff88b3e5..c3613f6be 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -34,9 +34,7 @@ #include #include "rpc_fault.h" - #include "rpc_client.h" - #include "../rdp.h" #define SYNCHRONOUS_TIMEOUT 5000 @@ -69,8 +67,15 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc) if (!pdu) { - pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU)); + pdu = (RPC_PDU *)malloc(sizeof(RPC_PDU)); + if (!pdu) + return NULL; pdu->s = Stream_New(NULL, rpc->max_recv_frag); + if (!pdu->s) + { + free(pdu); + return NULL; + } } pdu->CallId = 0; @@ -84,8 +89,7 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc) int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu) { - Queue_Enqueue(rpc->client->ReceivePool, pdu); - return 0; + return Queue_Enqueue(rpc->client->ReceivePool, pdu) == TRUE ? 0 : -1; } int rpc_client_on_fragment_received_event(rdpRpc* rpc) @@ -97,7 +101,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc) rpcconn_hdr_t* header; freerdp* instance; - instance = (freerdp*) rpc->transport->settings->instance; + instance = (freerdp *)rpc->transport->settings->instance; if (!rpc->client->pdu) rpc->client->pdu = rpc_client_receive_pool_take(rpc); @@ -125,34 +129,29 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc) return 0; } - if (header->common.ptype == PTYPE_RTS) + switch (header->common.ptype) { - if (rpc->VirtualConnection->State >= VIRTUAL_CONNECTION_STATE_OPENED) - { - //fprintf(stderr, "Receiving Out-of-Sequence RTS PDU\n"); + case PTYPE_RTS: + if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED) + { + fprintf(stderr, "%s: warning: unhandled RTS PDU\n", __FUNCTION__); + return 0; + } + fprintf(stderr, "%s: Receiving Out-of-Sequence RTS PDU\n", __FUNCTION__); rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length); - rpc_client_fragment_pool_return(rpc, fragment); - } - else - { - fprintf(stderr, "warning: unhandled RTS PDU\n"); - } + return 0; - return 0; - } - else if (header->common.ptype == PTYPE_FAULT) - { - rpc_recv_fault_pdu(header); - Queue_Enqueue(rpc->client->ReceiveQueue, NULL); - return -1; - } - - if (header->common.ptype != PTYPE_RESPONSE) - { - fprintf(stderr, "Unexpected RPC PDU type: %d\n", header->common.ptype); - Queue_Enqueue(rpc->client->ReceiveQueue, NULL); - return -1; + case PTYPE_FAULT: + rpc_recv_fault_pdu(header); + Queue_Enqueue(rpc->client->ReceiveQueue, NULL); + return -1; + case PTYPE_RESPONSE: + break; + default: + fprintf(stderr, "%s: unexpected RPC PDU type %d\n", __FUNCTION__, header->common.ptype); + Queue_Enqueue(rpc->client->ReceiveQueue, NULL); + return -1; } rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->common.frag_length; @@ -160,7 +159,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc) if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength)) { - fprintf(stderr, "rpc_recv_pdu_fragment: expected stub\n"); + fprintf(stderr, "%s: expected stub\n", __FUNCTION__); Queue_Enqueue(rpc->client->ReceiveQueue, NULL); return -1; } @@ -196,7 +195,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc) if (rpc->StubCallId != header->common.call_id) { - fprintf(stderr, "invalid call_id: actual: %d, expected: %d, frag_count: %d\n", + fprintf(stderr, "%s: invalid call_id: actual: %d, expected: %d, frag_count: %d\n", __FUNCTION__, rpc->StubCallId, header->common.call_id, rpc->StubFragCount); } @@ -243,27 +242,34 @@ int rpc_client_on_read_event(rdpRpc* rpc) int status = -1; rpcconn_common_hdr_t* header; - if (!rpc->client->RecvFrag) - rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc); - - position = Stream_GetPosition(rpc->client->RecvFrag); - - if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH) + while (1) { - status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), - RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag)); + if (!rpc->client->RecvFrag) + rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc); - if (status < 0) + position = Stream_GetPosition(rpc->client->RecvFrag); + + while (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH) { - fprintf(stderr, "rpc_client_frag_read: error reading header\n"); - return -1; + status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), + RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag)); + + if (status < 0) + { + fprintf(stderr, "rpc_client_frag_read: error reading header\n"); + return -1; + } + + if (!status) + return 0; + + Stream_Seek(rpc->client->RecvFrag, status); } - Stream_Seek(rpc->client->RecvFrag, status); - } + if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH) + return status; + - if (Stream_GetPosition(rpc->client->RecvFrag) >= RPC_COMMON_FIELDS_LENGTH) - { header = (rpcconn_common_hdr_t*) Stream_Buffer(rpc->client->RecvFrag); if (header->frag_length > rpc->max_recv_frag) @@ -274,45 +280,44 @@ int rpc_client_on_read_event(rdpRpc* rpc) return -1; } - if (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length) + while (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length) { status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), header->frag_length - Stream_GetPosition(rpc->client->RecvFrag)); if (status < 0) { - fprintf(stderr, "rpc_client_frag_read: error reading fragment body\n"); + fprintf(stderr, "%s: error reading fragment body\n", __FUNCTION__); return -1; } + if (!status) + return 0; + Stream_Seek(rpc->client->RecvFrag, status); } - } - else - { - return status; - } - if (status < 0) - return -1; - - status = Stream_GetPosition(rpc->client->RecvFrag) - position; - - if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length) - { - /* complete fragment received */ - - Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag); - Stream_SetPosition(rpc->client->RecvFrag, 0); - - Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag); - rpc->client->RecvFrag = NULL; - - if (rpc_client_on_fragment_received_event(rpc) < 0) + if (status < 0) return -1; + + status = Stream_GetPosition(rpc->client->RecvFrag) - position; + + if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length) + { + /* complete fragment received */ + + Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag); + Stream_SetPosition(rpc->client->RecvFrag, 0); + + Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag); + rpc->client->RecvFrag = NULL; + + if (rpc_client_on_fragment_received_event(rpc) < 0) + return -1; + } } - return status; + return 0; } /** @@ -349,13 +354,12 @@ RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum) RpcClientCall* clientCall; clientCall = (RpcClientCall*) malloc(sizeof(RpcClientCall)); + if (!clientCall) + return NULL; - if (clientCall) - { - clientCall->CallId = CallId; - clientCall->OpNum = OpNum; - clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS; - } + clientCall->CallId = CallId; + clientCall->OpNum = OpNum; + clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS; return clientCall; } @@ -371,16 +375,22 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) int status; pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU)); - pdu->s = Stream_New(buffer, length); + if (!pdu) + return -1; - Queue_Enqueue(rpc->client->SendQueue, pdu); + pdu->s = Stream_New(buffer, length); + if (!pdu->s) + goto out_free; + + if (!Queue_Enqueue(rpc->client->SendQueue, pdu)) + goto out_free_stream; if (rpc->client->SynchronousSend) { status = WaitForSingleObject(rpc->client->PduSentEvent, SYNCHRONOUS_TIMEOUT); if (status == WAIT_TIMEOUT) { - fprintf(stderr, "rpc_send_enqueue_pdu: timed out waiting for pdu sent event\n"); + fprintf(stderr, "%s: timed out waiting for pdu sent event %p\n", __FUNCTION__, rpc->client->PduSentEvent); return -1; } @@ -388,6 +398,12 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) } return 0; + +out_free_stream: + Stream_Free(pdu->s, TRUE); +out_free: + free(pdu); + return -1; } int rpc_send_dequeue_pdu(rdpRpc* rpc) @@ -396,13 +412,14 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc) RPC_PDU* pdu; RpcClientCall* clientCall; rpcconn_common_hdr_t* header; + RpcInChannel *inChannel; pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->SendQueue); - if (!pdu) return 0; - WaitForSingleObject(rpc->VirtualConnection->DefaultInChannel->Mutex, INFINITE); + inChannel = rpc->VirtualConnection->DefaultInChannel; + WaitForSingleObject(inChannel->Mutex, INFINITE); status = rpc_in_write(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)); @@ -410,7 +427,7 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc) clientCall = rpc_client_call_find_by_id(rpc, header->call_id); clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED; - ReleaseMutex(rpc->VirtualConnection->DefaultInChannel->Mutex); + ReleaseMutex(inChannel->Mutex); /* * This protocol specifies that only RPC PDUs are subject to the flow control abstract @@ -421,8 +438,8 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc) if (header->ptype == PTYPE_REQUEST) { - rpc->VirtualConnection->DefaultInChannel->BytesSent += status; - rpc->VirtualConnection->DefaultInChannel->SenderAvailableWindow -= status; + inChannel->BytesSent += status; + inChannel->SenderAvailableWindow -= status; } Stream_Free(pdu->s, TRUE); @@ -440,57 +457,48 @@ RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc) DWORD dwMilliseconds; DWORD result; - pdu = NULL; - dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0; + dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT * 4 : 0; result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds); if (result == WAIT_TIMEOUT) { - fprintf(stderr, "rpc_recv_dequeue_pdu: timed out waiting for receive event\n"); + fprintf(stderr, "%s: timed out waiting for receive event\n", __FUNCTION__); return NULL; } - if (result == WAIT_OBJECT_0) - { - pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->ReceiveQueue); + if (result != WAIT_OBJECT_0) + return NULL; + + pdu = (RPC_PDU *)Queue_Dequeue(rpc->client->ReceiveQueue); #ifdef WITH_DEBUG_TSG - if (pdu) - { - fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId); - winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s)); - fprintf(stderr, "\n"); - } -#endif - - return pdu; + if (pdu) + { + fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId); + winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s)); + fprintf(stderr, "\n"); } + else + { + fprintf(stderr, "Receiving a NULL PDU\n"); + } +#endif return pdu; } RPC_PDU* rpc_recv_peek_pdu(rdpRpc* rpc) { - RPC_PDU* pdu; DWORD dwMilliseconds; DWORD result; - pdu = NULL; dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0; result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds); - if (result == WAIT_TIMEOUT) - { + if (result != WAIT_OBJECT_0) return NULL; - } - if (result == WAIT_OBJECT_0) - { - pdu = (RPC_PDU*) Queue_Peek(rpc->client->ReceiveQueue); - return pdu; - } - - return pdu; + return (RPC_PDU *)Queue_Peek(rpc->client->ReceiveQueue); } static void* rpc_client_thread(void* arg) @@ -500,40 +508,52 @@ static void* rpc_client_thread(void* arg) DWORD nCount; HANDLE events[3]; HANDLE ReadEvent; + int fd; rpc = (rdpRpc*) arg; + fd = BIO_get_fd(rpc->TlsOut->bio, NULL); - ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, rpc->TlsOut->sockfd); + ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, fd); nCount = 0; events[nCount++] = rpc->client->StopEvent; events[nCount++] = Queue_Event(rpc->client->SendQueue); events[nCount++] = ReadEvent; + /* Do a first free run in case some bytes were set from the HTTP headers. + * We also have to do it because most of the time the underlying socket has notified, + * and the ssl layer has eaten all bytes, so we won't be notified any more even if the + * bytes are buffered locally + */ + if (rpc_client_on_read_event(rpc) < 0) + { + fprintf(stderr, "%s: an error occured when treating first packet\n", __FUNCTION__); + goto out; + } + while (rpc->transport->layer != TRANSPORT_LAYER_CLOSED) { status = WaitForMultipleObjects(nCount, events, FALSE, 100); - if (status != WAIT_TIMEOUT) + if (status == WAIT_TIMEOUT) + continue; + + if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0) + break; + + if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0) { - if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0) - { + if (rpc_client_on_read_event(rpc) < 0) break; - } + } - if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0) - { - if (rpc_client_on_read_event(rpc) < 0) - break; - } - - if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0) - { - rpc_send_dequeue_pdu(rpc); - } + if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0) + { + rpc_send_dequeue_pdu(rpc); } } +out: CloseHandle(ReadEvent); return NULL; @@ -541,6 +561,9 @@ static void* rpc_client_thread(void* arg) static void rpc_pdu_free(RPC_PDU* pdu) { + if (!pdu) + return; + Stream_Free(pdu->s, TRUE); free(pdu); } @@ -554,35 +577,55 @@ int rpc_client_new(rdpRpc* rpc) { RpcClient* client = NULL; - client = (RpcClient*) calloc(1, sizeof(RpcClient)); - - if (client) - { - client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL); - client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL); - - client->SendQueue = Queue_New(TRUE, -1, -1); - Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; - - client->pdu = NULL; - client->ReceivePool = Queue_New(TRUE, -1, -1); - client->ReceiveQueue = Queue_New(TRUE, -1, -1); - Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; - Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; - - client->RecvFrag = NULL; - client->FragmentPool = Queue_New(TRUE, -1, -1); - client->FragmentQueue = Queue_New(TRUE, -1, -1); - - Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free; - Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free; - - client->ClientCallList = ArrayList_New(TRUE); - ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free; - } - + client = (RpcClient *)calloc(1, sizeof(RpcClient)); rpc->client = client; + if (!client) + return -1; + client->Thread = CreateThread(NULL, 0, + (LPTHREAD_START_ROUTINE) rpc_client_thread, + rpc, CREATE_SUSPENDED, NULL); + if (!client->Thread) + return -1; + + client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!client->StopEvent) + return -1; + client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!client->PduSentEvent) + return -1; + + client->SendQueue = Queue_New(TRUE, -1, -1); + if (!client->SendQueue) + return -1; + Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; + + client->pdu = NULL; + client->ReceivePool = Queue_New(TRUE, -1, -1); + if (!client->ReceivePool) + return -1; + Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; + + client->ReceiveQueue = Queue_New(TRUE, -1, -1); + if (!client->ReceiveQueue) + return -1; + Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; + + client->RecvFrag = NULL; + client->FragmentPool = Queue_New(TRUE, -1, -1); + if (!client->FragmentPool) + return -1; + Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free; + + client->FragmentQueue = Queue_New(TRUE, -1, -1); + if (!client->FragmentQueue) + return -1; + Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free; + + client->ClientCallList = ArrayList_New(TRUE); + if (!client->ClientCallList) + return -1; + ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free; return 0; } @@ -604,9 +647,7 @@ int rpc_client_stop(rdpRpc* rpc) rpc->client->Thread = NULL; } - rpc_client_free(rpc); - - return 0; + return rpc_client_free(rpc); } int rpc_client_free(rdpRpc* rpc) @@ -615,31 +656,39 @@ int rpc_client_free(rdpRpc* rpc) client = rpc->client; - if (client) - { + if (!client) + return 0; + + if (client->SendQueue) Queue_Free(client->SendQueue); - if (client->RecvFrag) - rpc_fragment_free(client->RecvFrag); + if (client->RecvFrag) + rpc_fragment_free(client->RecvFrag); + if (client->FragmentPool) Queue_Free(client->FragmentPool); + if (client->FragmentQueue) Queue_Free(client->FragmentQueue); - if (client->pdu) - rpc_pdu_free(client->pdu); + if (client->pdu) + rpc_pdu_free(client->pdu); + if (client->ReceivePool) Queue_Free(client->ReceivePool); + if (client->ReceiveQueue) Queue_Free(client->ReceiveQueue); + if (client->ClientCallList) ArrayList_Free(client->ClientCallList); + if (client->StopEvent) CloseHandle(client->StopEvent); + if (client->PduSentEvent) CloseHandle(client->PduSentEvent); + if (client->Thread) CloseHandle(client->Thread); - free(client); - } - + free(client); return 0; } diff --git a/libfreerdp/core/gateway/rts.c b/libfreerdp/core/gateway/rts.c index 42ce2ad4e..d57a4240d 100644 --- a/libfreerdp/core/gateway/rts.c +++ b/libfreerdp/core/gateway/rts.c @@ -93,25 +93,25 @@ BOOL rts_connect(rdpRpc* rpc) if (!rpc_ntlm_http_out_connect(rpc)) { - fprintf(stderr, "rpc_out_connect_http error!\n"); + fprintf(stderr, "%s: rpc_out_connect_http error!\n", __FUNCTION__); return FALSE; } if (rts_send_CONN_A1_pdu(rpc) != 0) { - fprintf(stderr, "rpc_send_CONN_A1_pdu error!\n"); + fprintf(stderr, "%s: rpc_send_CONN_A1_pdu error!\n", __FUNCTION__); return FALSE; } if (!rpc_ntlm_http_in_connect(rpc)) { - fprintf(stderr, "rpc_in_connect_http error!\n"); + fprintf(stderr, "%s: rpc_in_connect_http error!\n", __FUNCTION__); return FALSE; } - if (rts_send_CONN_B1_pdu(rpc) != 0) + if (rts_send_CONN_B1_pdu(rpc) < 0) { - fprintf(stderr, "rpc_send_CONN_B1_pdu error!\n"); + fprintf(stderr, "%s: rpc_send_CONN_B1_pdu error!\n", __FUNCTION__); return FALSE; } @@ -147,10 +147,15 @@ BOOL rts_connect(rdpRpc* rpc) */ http_response = http_response_recv(rpc->TlsOut); + if (!http_response) + { + fprintf(stderr, "%s: unable to retrieve OUT Channel Response!\n", __FUNCTION__); + return FALSE; + } if (http_response->StatusCode != HTTP_STATUS_OK) { - fprintf(stderr, "rts_connect error! Status Code: %d\n", http_response->StatusCode); + fprintf(stderr, "%s: error! Status Code: %d\n", __FUNCTION__, http_response->StatusCode); http_response_print(http_response); http_response_free(http_response); @@ -170,6 +175,14 @@ BOOL rts_connect(rdpRpc* rpc) return FALSE; } + if (http_response->bodyLen) + { + /* inject bytes we have read in the body as a received packet for the RPC client */ + rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc); + Stream_EnsureCapacity(rpc->client->RecvFrag, http_response->bodyLen); + CopyMemory(rpc->client->RecvFrag, http_response->BodyContent, http_response->bodyLen); + } + //http_response_print(http_response); http_response_free(http_response); @@ -195,7 +208,6 @@ BOOL rts_connect(rdpRpc* rpc) rpc_client_start(rpc); pdu = rpc_recv_dequeue_pdu(rpc); - if (!pdu) return FALSE; @@ -203,7 +215,7 @@ BOOL rts_connect(rdpRpc* rpc) if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_A3_SIGNATURE, rts)) { - fprintf(stderr, "Unexpected RTS PDU: Expected CONN/A3\n"); + fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/A3\n", __FUNCTION__); return FALSE; } @@ -236,7 +248,6 @@ BOOL rts_connect(rdpRpc* rpc) */ pdu = rpc_recv_dequeue_pdu(rpc); - if (!pdu) return FALSE; @@ -244,7 +255,7 @@ BOOL rts_connect(rdpRpc* rpc) if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_C2_SIGNATURE, rts)) { - fprintf(stderr, "Unexpected RTS PDU: Expected CONN/C2\n"); + fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/C2\n", __FUNCTION__); return FALSE; } @@ -261,7 +272,7 @@ BOOL rts_connect(rdpRpc* rpc) return TRUE; } -#if defined WITH_DEBUG_RTS && 0 +#ifdef WITH_DEBUG_RTS static const char* const RTS_CMD_STRINGS[] = { @@ -317,6 +328,7 @@ static const char* const RTS_CMD_STRINGS[] = void rts_pdu_header_init(rpcconn_rts_hdr_t* header) { + ZeroMemory(header, sizeof(*header)); header->rpc_vers = 5; header->rpc_vers_minor = 0; header->ptype = PTYPE_RTS; @@ -681,6 +693,8 @@ int rts_send_CONN_A1_pdu(rdpRpc* rpc) ReceiveWindowSize = rpc->VirtualConnection->DefaultOutChannel->ReceiveWindow; buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_version_command_write(&buffer[20]); /* Version (8 bytes) */ @@ -718,6 +732,7 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc) BYTE* INChannelCookie; BYTE* AssociationGroupId; BYTE* VirtualConnectionCookie; + int status; rts_pdu_header_init(&header); header.frag_length = 104; @@ -734,6 +749,8 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc) AssociationGroupId = (BYTE*) &(rpc->VirtualConnection->AssociationGroupId); buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_version_command_write(&buffer[20]); /* Version (8 bytes) */ @@ -745,11 +762,11 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc) length = header.frag_length; - rpc_in_write(rpc, buffer, length); + status = rpc_in_write(rpc, buffer, length); free(buffer); - return 0; + return status; } /* CONN/C Sequence */ @@ -795,12 +812,15 @@ int rts_send_keep_alive_pdu(rdpRpc* rpc) DEBUG_RPC("Sending Keep-Alive RTS PDU"); buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_client_keepalive_command_write(&buffer[20], rpc->CurrentKeepAliveInterval); /* ClientKeepAlive (8 bytes) */ length = header.frag_length; - rpc_in_write(rpc, buffer, length); + if (rpc_in_write(rpc, buffer, length) < 0) + return -1; free(buffer); return length; @@ -830,6 +850,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc) rpc->VirtualConnection->DefaultOutChannel->AvailableWindowAdvertised; buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_destination_command_write(&buffer[20], FDOutProxy); /* Destination Command (8 bytes) */ @@ -839,7 +861,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc) length = header.frag_length; - rpc_in_write(rpc, buffer, length); + if (rpc_in_write(rpc, buffer, length) < 0) + return -1; free(buffer); return 0; @@ -923,12 +946,15 @@ int rts_send_ping_pdu(rdpRpc* rpc) DEBUG_RPC("Sending Ping RTS PDU"); buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ length = header.frag_length; - rpc_in_write(rpc, buffer, length); + if (rpc_in_write(rpc, buffer, length) < 0) + return -1; free(buffer); return length; @@ -1020,22 +1046,18 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) rts_extract_pdu_signature(rpc, &signature, rts); SignatureId = rts_identify_pdu_signature(rpc, &signature, NULL); - if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK) + switch (SignatureId) { - return rts_recv_flow_control_ack_pdu(rpc, buffer, length); - } - else if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION) - { - return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length); - } - else if (SignatureId == RTS_PDU_PING) - { - rts_send_ping_pdu(rpc); - } - else - { - fprintf(stderr, "Unimplemented signature id: 0x%08X\n", SignatureId); - rts_print_pdu_signature(rpc, &signature); + case RTS_PDU_FLOW_CONTROL_ACK: + return rts_recv_flow_control_ack_pdu(rpc, buffer, length); + case RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION: + return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length); + case RTS_PDU_PING: + return rts_send_ping_pdu(rpc); + default: + fprintf(stderr, "%s: unimplemented signature id: 0x%08X\n", __FUNCTION__, SignatureId); + rts_print_pdu_signature(rpc, &signature); + break; } return 0; diff --git a/libfreerdp/core/gateway/rts_signature.c b/libfreerdp/core/gateway/rts_signature.c index 34598fe71..47242ca63 100644 --- a/libfreerdp/core/gateway/rts_signature.c +++ b/libfreerdp/core/gateway/rts_signature.c @@ -234,7 +234,6 @@ BOOL rts_match_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_rt return FALSE; status = rts_command_length(rpc, CommandType, &buffer[offset], length); - if (status < 0) return FALSE; @@ -272,7 +271,6 @@ int rts_extract_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_r signature->CommandTypes[i] = CommandType; status = rts_command_length(rpc, CommandType, &buffer[offset], length); - if (status < 0) return FALSE; @@ -294,22 +292,22 @@ UINT32 rts_identify_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, RTS_P { pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature; - if (signature->Flags == pSignature->Flags) + if (signature->Flags != pSignature->Flags) + continue; + + if (signature->NumberOfCommands != pSignature->NumberOfCommands) + continue; + + for (j = 0; j < signature->NumberOfCommands; j++) { - if (signature->NumberOfCommands == pSignature->NumberOfCommands) - { - for (j = 0; j < signature->NumberOfCommands; j++) - { - if (signature->CommandTypes[j] != pSignature->CommandTypes[j]) - continue; - } - - if (entry) - *entry = &RTS_PDU_SIGNATURE_TABLE[i]; - - return RTS_PDU_SIGNATURE_TABLE[i].SignatureId; - } + if (signature->CommandTypes[j] != pSignature->CommandTypes[j]) + continue; } + + if (entry) + *entry = &RTS_PDU_SIGNATURE_TABLE[i]; + + return RTS_PDU_SIGNATURE_TABLE[i].SignatureId; } return 0; diff --git a/libfreerdp/core/gateway/tsg.c b/libfreerdp/core/gateway/tsg.c index f130f73ab..5dd68886d 100644 --- a/libfreerdp/core/gateway/tsg.c +++ b/libfreerdp/core/gateway/tsg.c @@ -33,9 +33,9 @@ #include #include "rpc_client.h" - #include "tsg.h" + /** * RPC Functions: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378623/ * Remote Procedure Call: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378651/ @@ -96,7 +96,9 @@ DWORD TsProxySendToServer(handle_t IDL_handle, byte pRpcMessage[], UINT32 count, } length = 28 + totalDataBytes; - buffer = (BYTE*) malloc(length); + buffer = (BYTE*) calloc(1, length); + if (!buffer) + return -1; s = Stream_New(buffer, length); @@ -228,8 +230,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) buffer = &buffer[24]; - packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET)); - ZeroMemory(packet, sizeof(TSG_PACKET)); + packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET)); + if (!packet) + return FALSE; offset = 4; // Skip Packet Pointer packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */ @@ -237,8 +240,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) if ((packet->packetId == TSG_PACKET_TYPE_CAPS_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_CAPS_RESPONSE)) { - packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) malloc(sizeof(TSG_PACKET_CAPS_RESPONSE)); - ZeroMemory(packetCapsResponse, sizeof(TSG_PACKET_CAPS_RESPONSE)); + packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) calloc(1, sizeof(TSG_PACKET_CAPS_RESPONSE)); + if (!packetCapsResponse) // TODO: correct cleanup + return FALSE; packet->tsgPacket.packetCapsResponse = packetCapsResponse; /* PacketQuarResponsePtr (4 bytes) */ @@ -258,8 +262,7 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) IsMessagePresent = *((UINT32*) &buffer[offset]); offset += 4; MessageSwitchValue = *((UINT32*) &buffer[offset]); - DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d", - IsMessagePresent, MessageSwitchValue); + DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d", IsMessagePresent, MessageSwitchValue); offset += 4; } @@ -289,8 +292,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) offset += 4; } - versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS)); - ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS)); + versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS)); + if (!versionCaps) // TODO: correct cleanup + return FALSE; packetCapsResponse->pktQuarEncResponse.versionCaps = versionCaps; versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */ @@ -317,8 +321,10 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) /* 4-byte alignment */ rpc_offset_align(&offset, 4); - tsgCaps = (PTSG_PACKET_CAPABILITIES) malloc(sizeof(TSG_PACKET_CAPABILITIES)); - ZeroMemory(tsgCaps, sizeof(TSG_PACKET_CAPABILITIES)); + tsgCaps = (PTSG_PACKET_CAPABILITIES) calloc(1, sizeof(TSG_PACKET_CAPABILITIES)); + if (!tsgCaps) + return FALSE; + versionCaps->tsgCaps = tsgCaps; offset += 4; /* MaxCount (4 bytes) */ @@ -406,8 +412,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) } else if ((packet->packetId == TSG_PACKET_TYPE_QUARENC_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_QUARENC_RESPONSE)) { - packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) malloc(sizeof(TSG_PACKET_QUARENC_RESPONSE)); - ZeroMemory(packetQuarEncResponse, sizeof(TSG_PACKET_QUARENC_RESPONSE)); + packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) calloc(1, sizeof(TSG_PACKET_QUARENC_RESPONSE)); + if (!packetQuarEncResponse) // TODO: handle cleanup + return FALSE; packet->tsgPacket.packetQuarEncResponse = packetQuarEncResponse; /* PacketQuarResponsePtr (4 bytes) */ @@ -443,8 +450,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) offset += 4; } - versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS)); - ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS)); + versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS)); + if (!versionCaps) // TODO: handle cleanup + return FALSE; packetQuarEncResponse->versionCaps = versionCaps; versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */ @@ -779,8 +787,9 @@ BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu) if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) buffer = &buffer[24]; - packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET)); - ZeroMemory(packet, sizeof(TSG_PACKET)); + packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET)); + if (!packet) + return FALSE; offset = 4; packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */ @@ -923,6 +932,8 @@ BOOL TsProxyCreateChannelWriteRequest(rdpTsg* tsg, PTUNNEL_CONTEXT_HANDLE_NOSERI length = 60 + (count * 2); buffer = (BYTE*) malloc(length); + if (!buffer) + return FALSE; /* TunnelContext */ handle = (CONTEXT_HANDLE*) tunnelContext; @@ -1526,48 +1537,53 @@ int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length) return CopyLength; } - else + + + tsg->pdu = rpc_recv_peek_pdu(rpc); + if (!tsg->pdu) { - tsg->pdu = rpc_recv_peek_pdu(rpc); + if (!tsg->rpc->client->SynchronousReceive) + return 0; - if (!tsg->pdu) - { - if (tsg->rpc->client->SynchronousReceive) - return tsg_read(tsg, data, length); - else - return 0; - } - - tsg->PendingPdu = TRUE; - tsg->BytesAvailable = Stream_Length(tsg->pdu->s); - tsg->BytesRead = 0; - - CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable; - - CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength); - tsg->BytesAvailable -= CopyLength; - tsg->BytesRead += CopyLength; - - if (tsg->BytesAvailable < 1) - { - tsg->PendingPdu = FALSE; - rpc_recv_dequeue_pdu(rpc); - rpc_client_receive_pool_return(rpc, tsg->pdu); - } - - return CopyLength; + // weird !!!! + return tsg_read(tsg, data, length); } + + tsg->PendingPdu = TRUE; + tsg->BytesAvailable = Stream_Length(tsg->pdu->s); + tsg->BytesRead = 0; + + CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable; + + CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength); + tsg->BytesAvailable -= CopyLength; + tsg->BytesRead += CopyLength; + + if (tsg->BytesAvailable < 1) + { + tsg->PendingPdu = FALSE; + rpc_recv_dequeue_pdu(rpc); + rpc_client_receive_pool_return(rpc, tsg->pdu); + } + + return CopyLength; + } int tsg_write(rdpTsg* tsg, BYTE* data, UINT32 length) { + int status; + if (tsg->rpc->transport->layer == TRANSPORT_LAYER_CLOSED) { - fprintf(stderr, "tsg_write error: connection lost\n"); + fprintf(stderr, "%s: error, connection lost\n", __FUNCTION__); return -1; } - return TsProxySendToServer((handle_t) tsg, data, 1, &length); + status = TsProxySendToServer((handle_t) tsg, data, 1, &length); + if (status < 0) + return -1; + return length; } BOOL tsg_set_blocking_mode(rdpTsg* tsg, BOOL blocking) @@ -1584,18 +1600,21 @@ rdpTsg* tsg_new(rdpTransport* transport) { rdpTsg* tsg; - tsg = (rdpTsg*) malloc(sizeof(rdpTsg)); - ZeroMemory(tsg, sizeof(rdpTsg)); - - if (tsg != NULL) - { - tsg->transport = transport; - tsg->settings = transport->settings; - tsg->rpc = rpc_new(tsg->transport); - tsg->PendingPdu = FALSE; - } + tsg = (rdpTsg*) calloc(1, sizeof(rdpTsg)); + if (!tsg) + return NULL; + tsg->transport = transport; + tsg->settings = transport->settings; + tsg->rpc = rpc_new(tsg->transport); + if (!tsg->rpc) + goto out_free; + tsg->PendingPdu = FALSE; return tsg; + +out_free: + free(tsg); + return NULL; } void tsg_free(rdpTsg* tsg) diff --git a/libfreerdp/core/peer.c b/libfreerdp/core/peer.c index e1662d335..bc7431f47 100644 --- a/libfreerdp/core/peer.c +++ b/libfreerdp/core/peer.c @@ -52,13 +52,13 @@ static BOOL freerdp_peer_initialize(freerdp_peer* client) fprintf(stderr, "%s: inavlid RDP key file %s\n", __FUNCTION__, settings->RdpKeyFile); return FALSE; } + if (settings->RdpServerRsaKey->ModulusLength > 256) { fprintf(stderr, "%s: Key sizes > 2048 are currently not supported for RDP security.\n", __FUNCTION__); fprintf(stderr, "%s: Set a different key file than %s\n", __FUNCTION__, settings->RdpKeyFile); exit(1); } - } return TRUE; @@ -77,12 +77,13 @@ static HANDLE freerdp_peer_get_event_handle(freerdp_peer* client) return client->context->rdp->transport->TcpIn->event; } -static BOOL freerdp_peer_check_fds(freerdp_peer* client) + +static BOOL freerdp_peer_check_fds(freerdp_peer* peer) { int status; rdpRdp* rdp; - rdp = client->context->rdp; + rdp = peer->context->rdp; status = rdp_check_fds(rdp); @@ -413,6 +414,19 @@ static int freerdp_peer_send_channel_data(freerdp_peer* client, UINT16 channelId return rdp_send_channel_data(client->context->rdp, channelId, data, size); } +static BOOL freerdp_peer_is_write_blocked(freerdp_peer* peer) +{ + return tranport_is_write_blocked(peer->context->rdp->transport); +} + +static int freerdp_peer_drain_output_buffer(freerdp_peer* peer) +{ + + rdpTransport *transport = peer->context->rdp->transport; + + return tranport_drain_output_buffer(transport); +} + void freerdp_peer_context_new(freerdp_peer* client) { rdpRdp* rdp; @@ -445,6 +459,9 @@ void freerdp_peer_context_new(freerdp_peer* client) rdp->transport->ReceiveExtra = client; transport_set_blocking_mode(rdp->transport, FALSE); + client->IsWriteBlocked = freerdp_peer_is_write_blocked; + client->DrainOutputBuffer = freerdp_peer_drain_output_buffer; + IFCALL(client->ContextNew, client, client->context); } @@ -473,6 +490,8 @@ freerdp_peer* freerdp_peer_new(int sockfd) client->Close = freerdp_peer_close; client->Disconnect = freerdp_peer_disconnect; client->SendChannelData = freerdp_peer_send_channel_data; + client->IsWriteBlocked = freerdp_peer_is_write_blocked; + client->DrainOutputBuffer = freerdp_peer_drain_output_buffer; } return client; @@ -480,10 +499,10 @@ freerdp_peer* freerdp_peer_new(int sockfd) void freerdp_peer_free(freerdp_peer* client) { - if (client) - { - rdp_free(client->context->rdp); - free(client->context); - free(client); - } + if (!client) + return; + + rdp_free(client->context->rdp); + free(client->context); + free(client); } diff --git a/libfreerdp/core/proxy.c b/libfreerdp/core/proxy.c index d21af9f66..4fdf3e8fe 100644 --- a/libfreerdp/core/proxy.c +++ b/libfreerdp/core/proxy.c @@ -18,13 +18,13 @@ */ +#include "proxy.h" #include "freerdp/settings.h" #include "tcp.h" #include "winpr/environment.h" /* For GetEnvironmentVariableA */ -/* TODO move into core/tcp.c? */ void http_proxy_read_environment(rdpSettings *settings, char *envname) { char env[256]; @@ -60,7 +60,7 @@ void http_proxy_read_environment(rdpSettings *settings, char *envname) freerdp_set_param_string(settings, FreeRDP_HTTPProxyHostname, hostname); } -BOOL http_proxy_connect(rdpTcp* tcp, const char* hostname, UINT16 port) +BOOL http_proxy_connect(BIO* bio, const char* hostname, UINT16 port) { int status; wStream* s; @@ -84,7 +84,7 @@ BOOL http_proxy_connect(rdpTcp* tcp, const char* hostname, UINT16 port) send_length = Stream_GetPosition(s); Stream_SetPosition(s, 0); while (send_length > 0) { - status = tcp_write(tcp, Stream_Pointer(s), send_length); + status = BIO_write(bio, Stream_Pointer(s), send_length); if (status < 0) { fprintf(stderr, "HTTP Proxy connection: error while writing: %d\n", status); return FALSE; @@ -111,14 +111,14 @@ BOOL http_proxy_connect(rdpTcp* tcp, const char* hostname, UINT16 port) return FALSE; } - status = tcp_read(tcp, (BYTE*)str + resultsize, sizeof(str)-resultsize-1); + status = BIO_read(bio, (BYTE*)str + resultsize, sizeof(str)-resultsize-1); if (status < 0) { /* Error? */ return FALSE; } else if (status == 0) { /* Error? */ - fprintf(stderr, "tcp_read() returned zero\n"); + fprintf(stderr, "BIO_read() returned zero\n"); return FALSE; } fprintf(stderr, "HTTP Proxy: received %d bytes\n", status); diff --git a/libfreerdp/core/proxy.h b/libfreerdp/core/proxy.h index 51613d587..f868bbca5 100644 --- a/libfreerdp/core/proxy.h +++ b/libfreerdp/core/proxy.h @@ -20,7 +20,10 @@ #ifndef __HTTP_PROXY_H #define __HTTP_PROXY_H +#include "freerdp/settings.h" +#include + void http_proxy_read_environment(rdpSettings *settings, char *envname); -BOOL http_proxy_connect(rdpTcp* tcp, const char* hostname, UINT16 port); +BOOL http_proxy_connect(BIO *bio, const char* hostname, UINT16 port); #endif diff --git a/libfreerdp/core/settings.c b/libfreerdp/core/settings.c index 4dcf87457..f504cf076 100644 --- a/libfreerdp/core/settings.c +++ b/libfreerdp/core/settings.c @@ -209,6 +209,7 @@ rdpSettings* freerdp_settings_new(DWORD flags) ZeroMemory(settings, sizeof(rdpSettings)); settings->ServerMode = (flags & FREERDP_SETTINGS_SERVER_MODE) ? TRUE : FALSE; + settings->WaitForOutputBufferFlush = TRUE; settings->DesktopWidth = 1024; settings->DesktopHeight = 768; @@ -581,6 +582,7 @@ rdpSettings* freerdp_settings_clone(rdpSettings* settings) /* BOOL values */ _settings->ServerMode = settings->ServerMode; /* 16 */ + _settings->WaitForOutputBufferFlush = settings->WaitForOutputBufferFlush; /* 25 */ _settings->NetworkAutoDetect = settings->NetworkAutoDetect; /* 137 */ _settings->SupportAsymetricKeys = settings->SupportAsymetricKeys; /* 138 */ _settings->SupportErrorInfoPdu = settings->SupportErrorInfoPdu; /* 139 */ diff --git a/libfreerdp/core/tcp.c b/libfreerdp/core/tcp.c index a51f9d567..8ca127ff7 100644 --- a/libfreerdp/core/tcp.c +++ b/libfreerdp/core/tcp.c @@ -67,6 +67,165 @@ #include "tcp.h" #include "proxy.h" +long transport_bio_buffered_callback(BIO* bio, int mode, const char* argp, int argi, long argl, long ret) +{ + return 1; +} + +static int transport_bio_buffered_write(BIO* bio, const char* buf, int num) +{ + int status, ret; + rdpTcp *tcp = (rdpTcp *)bio->ptr; + int nchunks, committedBytes, i; + DataChunk chunks[2]; + + ret = num; + BIO_clear_retry_flags(bio); + tcp->writeBlocked = FALSE; + + /* we directly append extra bytes in the xmit buffer, this could be prevented + * but for now it makes the code more simple. + */ + if (buf && num && !ringbuffer_write(&tcp->xmitBuffer, buf, num)) + { + fprintf(stderr, "%s: an error occured when writing(toWrite=%d)\n", __FUNCTION__, num); + return -1; + } + + committedBytes = 0; + nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer)); + for (i = 0; i < nchunks; i++) + { + while (chunks[i].size) + { + status = BIO_write(bio->next_bio, chunks[i].data, chunks[i].size); + /*fprintf(stderr, "%s: i=%d/%d size=%d/%d status=%d retry=%d\n", __FUNCTION__, i, nchunks, + chunks[i].size, ringbuffer_used(&tcp->xmitBuffer), status, + BIO_should_retry(bio->next_bio) + );*/ + if (status <= 0) + { + if (BIO_should_retry(bio->next_bio)) + { + tcp->writeBlocked = TRUE; + goto out; /* EWOULDBLOCK */ + } + + /* any other is an error, but we still have to commit written bytes */ + ret = -1; + goto out; + } + + committedBytes += status; + chunks[i].size -= status; + chunks[i].data += status; + } + } + +out: + ringbuffer_commit_read_bytes(&tcp->xmitBuffer, committedBytes); + return ret; +} + +static int transport_bio_buffered_read(BIO* bio, char* buf, int size) +{ + int status; + rdpTcp *tcp = (rdpTcp *)bio->ptr; + + tcp->readBlocked = FALSE; + BIO_clear_retry_flags(bio); + + status = BIO_read(bio->next_bio, buf, size); + /*fprintf(stderr, "%s: size=%d status=%d shouldRetry=%d\n", __FUNCTION__, size, status, BIO_should_retry(bio->next_bio)); */ + + if (status <= 0 && BIO_should_retry(bio->next_bio)) + { + BIO_set_retry_read(bio); + tcp->readBlocked = TRUE; + } + + return status; +} + +static int transport_bio_buffered_puts(BIO* bio, const char* str) +{ + return 1; +} + +static int transport_bio_buffered_gets(BIO* bio, char* str, int size) +{ + return 1; +} + +static long transport_bio_buffered_ctrl(BIO* bio, int cmd, long arg1, void* arg2) +{ + rdpTcp *tcp = (rdpTcp *)bio->ptr; + + switch (cmd) + { + case BIO_CTRL_FLUSH: + return 1; + case BIO_CTRL_WPENDING: + return ringbuffer_used(&tcp->xmitBuffer); + case BIO_CTRL_PENDING: + return 0; + default: + /*fprintf(stderr, "%s: passing to next BIO, bio=%p cmd=%d arg1=%d arg2=%p\n", __FUNCTION__, bio, cmd, arg1, arg2); */ + return BIO_ctrl(bio->next_bio, cmd, arg1, arg2); + } + + return 0; +} + +static int transport_bio_buffered_new(BIO* bio) +{ + bio->init = 1; + bio->num = 0; + bio->ptr = NULL; + bio->flags = 0; + + return 1; +} + +static int transport_bio_buffered_free(BIO* bio) +{ + return 1; +} + + +static BIO_METHOD transport_bio_buffered_socket_methods = +{ + BIO_TYPE_BUFFERED, + "BufferedSocket", + transport_bio_buffered_write, + transport_bio_buffered_read, + transport_bio_buffered_puts, + transport_bio_buffered_gets, + transport_bio_buffered_ctrl, + transport_bio_buffered_new, + transport_bio_buffered_free, + NULL, +}; + +BIO_METHOD* BIO_s_buffered_socket(void) +{ + return &transport_bio_buffered_socket_methods; +} + +BOOL transport_bio_buffered_drain(BIO *bio) +{ + rdpTcp *tcp = (rdpTcp *)bio->ptr; + int status; + + if (!ringbuffer_used(&tcp->xmitBuffer)) + return 1; + + status = transport_bio_buffered_write(bio, NULL, 0); + return status >= 0; +} + + + void tcp_get_ip_address(rdpTcp* tcp) { BYTE* ip; @@ -137,70 +296,79 @@ BOOL tcp_connect(rdpTcp* tcp, const char* hostname, int port) if (hostname[0] == '/') { tcp->sockfd = freerdp_uds_connect(hostname); - if (tcp->sockfd < 0) return FALSE; + + tcp->socketBio = BIO_new_fd(tcp->sockfd, 1); + if (!tcp->socketBio) + return FALSE; } else { + tcp->socketBio = BIO_new(BIO_s_connect()); + if (!tcp->socketBio) + return FALSE; + if (tcp->settings->HTTPProxyEnabled) { printf("HTTP Proxy enabled: %s:%d!\n", tcp->settings->HTTPProxyHostname, tcp->settings->HTTPProxyPort); - tcp->sockfd = freerdp_tcp_connect(tcp->settings->HTTPProxyHostname, tcp->settings->HTTPProxyPort); - if (!http_proxy_connect(tcp, hostname, port)) + if (BIO_set_conn_hostname(tcp->socketBio, tcp->settings->HTTPProxyHostname) < 0 || + BIO_set_conn_int_port(tcp->socketBio, &tcp->settings->HTTPProxyPort) < 0) + return FALSE; + + if (BIO_do_connect(tcp->socketBio) <= 0) + return FALSE; + + if (!http_proxy_connect(tcp->socketBio, hostname, port)) return FALSE; } else { printf("HTTP Proxy disabled\n"); - tcp->sockfd = freerdp_tcp_connect(hostname, port); + if (BIO_set_conn_hostname(tcp->socketBio, hostname) < 0 || BIO_set_conn_int_port(tcp->socketBio, &port) < 0) + return FALSE; + + if (BIO_do_connect(tcp->socketBio) <= 0) + return FALSE; } - if (tcp->sockfd < 0) - return FALSE; - - SetEventFileDescriptor(tcp->event, tcp->sockfd); - - tcp_get_ip_address(tcp); - tcp_get_mac_address(tcp); - - option_value = 1; - option_len = sizeof(option_value); - setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len); - - /* receive buffer must be a least 32 K */ - if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0) - { - if (option_value < (1024 * 32)) - { - option_value = 1024 * 32; - option_len = sizeof(option_value); - setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len); - } - } - - tcp_set_keep_alive_mode(tcp); + tcp->sockfd = BIO_get_fd(tcp->socketBio, NULL); } + SetEventFileDescriptor(tcp->event, tcp->sockfd); + + tcp_get_ip_address(tcp); + tcp_get_mac_address(tcp); + + option_value = 1; + option_len = sizeof(option_value); + if (setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len) < 0) + fprintf(stderr, "%s: unable to set TCP_NODELAY\n", __FUNCTION__); + + /* receive buffer must be a least 32 K */ + if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0) + { + if (option_value < (1024 * 32)) + { + option_value = 1024 * 32; + option_len = sizeof(option_value); + if (setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len) < 0) + { + fprintf(stderr, "%s: unable to set receive buffer len\n", __FUNCTION__); + return FALSE; + } + } + } + + if (!tcp_set_keep_alive_mode(tcp)) + return FALSE; + + tcp->bufferedBio = BIO_new(BIO_s_buffered_socket()); + if (!tcp->bufferedBio) + return FALSE; + tcp->bufferedBio->ptr = tcp; + + tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio); return TRUE; } -int tcp_read(rdpTcp* tcp, BYTE* data, int length) -{ - return freerdp_tcp_read(tcp->sockfd, data, length); -} - -int tcp_write(rdpTcp* tcp, BYTE* data, int length) -{ - return freerdp_tcp_write(tcp->sockfd, data, length); -} - -int tcp_wait_read(rdpTcp* tcp) -{ - return freerdp_tcp_wait_read(tcp->sockfd); -} - -int tcp_wait_write(rdpTcp* tcp) -{ - return freerdp_tcp_wait_write(tcp->sockfd); -} BOOL tcp_disconnect(rdpTcp* tcp) { @@ -218,7 +386,7 @@ BOOL tcp_set_blocking_mode(rdpTcp* tcp, BOOL blocking) if (flags == -1) { - fprintf(stderr, "tcp_set_blocking_mode: fcntl failed.\n"); + fprintf(stderr, "%s: fcntl failed, %s.\n", __FUNCTION__, strerror(errno)); return FALSE; } @@ -306,6 +474,31 @@ int tcp_attach(rdpTcp* tcp, int sockfd) { tcp->sockfd = sockfd; SetEventFileDescriptor(tcp->event, tcp->sockfd); + + ringbuffer_commit_read_bytes(&tcp->xmitBuffer, ringbuffer_used(&tcp->xmitBuffer)); + + if (tcp->socketBio) + { + if (BIO_set_fd(tcp->socketBio, sockfd, 1) < 0) + return -1; + } + else + { + tcp->socketBio = BIO_new_socket(sockfd, 1); + if (!tcp->socketBio) + return -1; + } + + if (!tcp->bufferedBio) + { + tcp->bufferedBio = BIO_new(BIO_s_buffered_socket()); + if (!tcp->bufferedBio) + return FALSE; + tcp->bufferedBio->ptr = tcp; + + tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio); + } + return 0; } @@ -325,25 +518,34 @@ rdpTcp* tcp_new(rdpSettings* settings) { rdpTcp* tcp; - tcp = (rdpTcp*) malloc(sizeof(rdpTcp)); + tcp = (rdpTcp *)calloc(1, sizeof(rdpTcp)); + if (!tcp) + return NULL; - if (tcp) - { - ZeroMemory(tcp, sizeof(rdpTcp)); + if (!ringbuffer_init(&tcp->xmitBuffer, 0x10000)) + goto out_free; - tcp->sockfd = -1; - tcp->settings = settings; - tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd); - } + tcp->sockfd = -1; + tcp->settings = settings; + + tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd); + if (!tcp->event || tcp->event == INVALID_HANDLE_VALUE) + goto out_ringbuffer; return tcp; +out_ringbuffer: + ringbuffer_destroy(&tcp->xmitBuffer); +out_free: + free(tcp); + return NULL; } void tcp_free(rdpTcp* tcp) { - if (tcp) - { - CloseHandle(tcp->event); - free(tcp); - } + if (!tcp) + return; + + ringbuffer_destroy(&tcp->xmitBuffer); + CloseHandle(tcp->event); + free(tcp); } diff --git a/libfreerdp/core/tcp.h b/libfreerdp/core/tcp.h index b43fbaf1c..a8b3153b9 100644 --- a/libfreerdp/core/tcp.h +++ b/libfreerdp/core/tcp.h @@ -31,10 +31,15 @@ #include #include +#include +#include + #ifndef MSG_NOSIGNAL #define MSG_NOSIGNAL 0 #endif +#define BIO_TYPE_BUFFERED 66 + typedef struct rdp_tcp rdpTcp; struct rdp_tcp @@ -46,6 +51,12 @@ struct rdp_tcp #ifdef _WIN32 WSAEVENT wsa_event; #endif + BIO *socketBio; + BIO *bufferedBio; + RingBuffer xmitBuffer; + BOOL writeBlocked; + BOOL readBlocked; + HANDLE event; }; diff --git a/libfreerdp/core/transport.c b/libfreerdp/core/transport.c index 4cd7995ea..e4b4fb6ee 100644 --- a/libfreerdp/core/transport.c +++ b/libfreerdp/core/transport.c @@ -33,7 +33,9 @@ #include #include +#include +#include #include #include #include @@ -41,6 +43,12 @@ #ifndef _WIN32 #include #include +#include +#include +#endif + +#ifdef HAVE_VALGRIND_MEMCHECK_H +#include #endif #include "tpkt.h" @@ -70,6 +78,7 @@ void transport_attach(rdpTransport* transport, int sockfd) tcp_attach(transport->TcpIn, sockfd); transport->SplitInputOutput = FALSE; transport->TcpOut = transport->TcpIn; + transport->frontBio = transport->TcpIn->bufferedBio; } void transport_stop(rdpTransport* transport) @@ -99,18 +108,9 @@ BOOL transport_disconnect(rdpTransport* transport) transport_stop(transport); - if (transport->layer == TRANSPORT_LAYER_TLS) - status &= tls_disconnect(transport->TlsIn); - - if ((transport->layer == TRANSPORT_LAYER_TSG) || (transport->layer == TRANSPORT_LAYER_TSG_TLS)) - { - status &= tsg_disconnect(transport->tsg); - } - else - { - status &= tcp_disconnect(transport->TcpIn); - } + BIO_free_all(transport->frontBio); + transport->frontBio = 0; return status; } @@ -132,16 +132,16 @@ static int transport_bio_tsg_write(BIO* bio, const char* buf, int num) rdpTsg* tsg; tsg = (rdpTsg*) bio->ptr; - status = tsg_write(tsg, (BYTE*) buf, num); BIO_clear_retry_flags(bio); + status = tsg_write(tsg, (BYTE*) buf, num); + if (status > 0) + return status; if (status == 0) - { BIO_set_retry_write(bio); - } - return status < 0 ? 0 : num; + return -1; } static int transport_bio_tsg_read(BIO* bio, char* buf, int size) @@ -223,8 +223,13 @@ BIO_METHOD* BIO_s_tsg(void) return &transport_bio_tsg_methods; } + + BOOL transport_connect_tls(rdpTransport* transport) { + rdpSettings *settings = transport->settings; + rdpTls *targetTls; + BIO *targetBio; int tls_status; freerdp* instance; rdpContext* context; @@ -235,61 +240,33 @@ BOOL transport_connect_tls(rdpTransport* transport) if (transport->layer == TRANSPORT_LAYER_TSG) { transport->TsgTls = tls_new(transport->settings); - - transport->TsgTls->methods = BIO_s_tsg(); - transport->TsgTls->tsg = (void*) transport->tsg; - transport->layer = TRANSPORT_LAYER_TSG_TLS; - transport->TsgTls->hostname = transport->settings->ServerHostname; - transport->TsgTls->port = transport->settings->ServerPort; + targetTls = transport->TsgTls; + targetBio = transport->frontBio; + } + else + { + if (!transport->TlsIn) + transport->TlsIn = tls_new(settings); - if (transport->TsgTls->port == 0) - transport->TsgTls->port = 3389; + if (!transport->TlsOut) + transport->TlsOut = transport->TlsIn; - tls_status = tls_connect(transport->TsgTls); + targetTls = transport->TlsIn; + targetBio = transport->TcpIn->bufferedBio; - if (tls_status < 1) - { - if (tls_status < 0) - { - if (!connectErrorCode) - connectErrorCode = TLSCONNECTERROR; - - if (!freerdp_get_last_error(context)) - freerdp_set_last_error(context, FREERDP_ERROR_TLS_CONNECT_FAILED); - } - else - { - if (!freerdp_get_last_error(context)) - freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED); - } - - tls_free(transport->TsgTls); - transport->TsgTls = NULL; - - return FALSE; - } - - return TRUE; + transport->layer = TRANSPORT_LAYER_TLS; } - if (!transport->TlsIn) - transport->TlsIn = tls_new(transport->settings); - if (!transport->TlsOut) - transport->TlsOut = transport->TlsIn; + targetTls->hostname = settings->ServerHostname; + targetTls->port = settings->ServerPort; - transport->layer = TRANSPORT_LAYER_TLS; - transport->TlsIn->sockfd = transport->TcpIn->sockfd; + if (targetTls->port == 0) + targetTls->port = 3389; - transport->TlsIn->hostname = transport->settings->ServerHostname; - transport->TlsIn->port = transport->settings->ServerPort; - - if (transport->TlsIn->port == 0) - transport->TlsIn->port = 3389; - - tls_status = tls_connect(transport->TlsIn); + tls_status = tls_connect(targetTls, targetBio); if (tls_status < 1) { @@ -307,13 +284,13 @@ BOOL transport_connect_tls(rdpTransport* transport) freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED); } - tls_free(transport->TlsIn); - - if (transport->TlsIn == transport->TlsOut) - transport->TlsIn = transport->TlsOut = NULL; - else - transport->TlsIn = NULL; + return FALSE; + } + transport->frontBio = targetTls->bio; + if (!transport->frontBio) + { + fprintf(stderr, "%s: unable to prepend a filtering TLS bio"); return FALSE; } @@ -324,6 +301,7 @@ BOOL transport_connect_nla(rdpTransport* transport) { freerdp* instance; rdpSettings* settings; + rdpCredssp *credSsp; settings = transport->settings; instance = (freerdp*) settings->instance; @@ -339,16 +317,22 @@ BOOL transport_connect_nla(rdpTransport* transport) if (!transport->credssp) { transport->credssp = credssp_new(instance, transport, settings); + if (!transport->credssp) + return FALSE; + transport_set_nla_mode(transport, TRUE); if (settings->AuthenticationServiceClass) { transport->credssp->ServicePrincipalName = credssp_make_spn(settings->AuthenticationServiceClass, settings->ServerHostname); + if (!transport->credssp->ServicePrincipalName) + return FALSE; } } - if (credssp_authenticate(transport->credssp) < 0) + credSsp = transport->credssp; + if (credssp_authenticate(credSsp) < 0) { if (!connectErrorCode) connectErrorCode = AUTHENTICATIONERROR; @@ -362,14 +346,14 @@ BOOL transport_connect_nla(rdpTransport* transport) "If credentials are valid, the NTLMSSP implementation may be to blame.\n"); transport_set_nla_mode(transport, FALSE); - credssp_free(transport->credssp); + credssp_free(credSsp); transport->credssp = NULL; return FALSE; } transport_set_nla_mode(transport, FALSE); - credssp_free(transport->credssp); + credssp_free(credSsp); transport->credssp = NULL; return TRUE; @@ -381,38 +365,41 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16 int tls_status; freerdp* instance; rdpContext* context; + rdpSettings *settings = transport->settings; instance = (freerdp*) transport->settings->instance; context = instance->context; tsg = tsg_new(transport); + if (!tsg) + return FALSE; tsg->transport = transport; transport->tsg = tsg; transport->SplitInputOutput = TRUE; if (!transport->TlsIn) - transport->TlsIn = tls_new(transport->settings); - - transport->TlsIn->sockfd = transport->TcpIn->sockfd; - transport->TlsIn->hostname = transport->settings->GatewayHostname; - transport->TlsIn->port = transport->settings->GatewayPort; - - if (transport->TlsIn->port == 0) - transport->TlsIn->port = 443; - + { + transport->TlsIn = tls_new(settings); + if (!transport->TlsIn) + return FALSE; + } if (!transport->TlsOut) - transport->TlsOut = tls_new(transport->settings); + { + transport->TlsOut = tls_new(settings); + if (!transport->TlsOut) + return FALSE; + } - transport->TlsOut->sockfd = transport->TcpOut->sockfd; - transport->TlsOut->hostname = transport->settings->GatewayHostname; - transport->TlsOut->port = transport->settings->GatewayPort; + /* put a decent default value for gateway port */ + if (!settings->GatewayPort) + settings->GatewayPort = 443; - if (transport->TlsOut->port == 0) - transport->TlsOut->port = 443; + transport->TlsIn->hostname = transport->TlsOut->hostname = settings->GatewayHostname; + transport->TlsIn->port = transport->TlsOut->port = settings->GatewayPort; - tls_status = tls_connect(transport->TlsIn); + tls_status = tls_connect(transport->TlsIn, transport->TcpIn->bufferedBio); if (tls_status < 1) { if (tls_status < 0) @@ -429,8 +416,7 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16 return FALSE; } - tls_status = tls_connect(transport->TlsOut); - + tls_status = tls_connect(transport->TlsOut, transport->TcpOut->bufferedBio); if (tls_status < 1) { if (tls_status < 0) @@ -450,6 +436,8 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16 if (!tsg_connect(tsg, hostname, port)) return FALSE; + transport->frontBio = BIO_new(BIO_s_tsg()); + transport->frontBio->ptr = tsg; return TRUE; } @@ -472,18 +460,20 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por if (transport->GatewayEnabled) { transport->layer = TRANSPORT_LAYER_TSG; + transport->SplitInputOutput = TRUE; transport->TcpOut = tcp_new(settings); - status = tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort); + if (!tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort) || + !tcp_set_blocking_mode(transport->TcpIn, FALSE)) + return FALSE; - if (status) - { - /* Connect second channel */ - status = tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort); - } + if (!tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort) || + !tcp_set_blocking_mode(transport->TcpOut, FALSE)) + return FALSE; - if (status) - status = transport_tsg_connect(transport, hostname, port); + if (!transport_tsg_connect(transport, hostname, port)) + return FALSE; + status = TRUE; } else { @@ -491,6 +481,7 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por transport->SplitInputOutput = FALSE; transport->TcpOut = transport->TcpIn; + transport->frontBio = transport->TcpIn->bufferedBio; } if (status) @@ -523,11 +514,11 @@ BOOL transport_accept_tls(rdpTransport* transport) transport->TlsOut = transport->TlsIn; transport->layer = TRANSPORT_LAYER_TLS; - transport->TlsIn->sockfd = transport->TcpIn->sockfd; - if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile)) + if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, transport->settings->CertificateFile, transport->settings->PrivateKeyFile)) return FALSE; + transport->frontBio = transport->TlsIn->bio; return TRUE; } @@ -546,10 +537,10 @@ BOOL transport_accept_nla(rdpTransport* transport) transport->TlsOut = transport->TlsIn; transport->layer = TRANSPORT_LAYER_TLS; - transport->TlsIn->sockfd = transport->TcpIn->sockfd; - if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile)) + if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, settings->CertificateFile, settings->PrivateKeyFile)) return FALSE; + transport->frontBio = transport->TlsIn->bio; /* Network Level Authentication */ @@ -643,56 +634,131 @@ UINT32 nla_header_length(wStream* s) return length; } +static int transport_wait_for_read(rdpTransport* transport) +{ + struct timeval tv; + fd_set rset, wset; + fd_set *rsetPtr = NULL, *wsetPtr = NULL; + rdpTcp *tcpIn; + + tcpIn = transport->TcpIn; + if (tcpIn->readBlocked) + { + rsetPtr = &rset; + FD_ZERO(rsetPtr); + FD_SET(tcpIn->sockfd, rsetPtr); + } + else if (tcpIn->writeBlocked) + { + wsetPtr = &wset; + FD_ZERO(wsetPtr); + FD_SET(tcpIn->sockfd, wsetPtr); + } + + if (!wsetPtr && !rsetPtr) + { + USleep(1000); + return 0; + } + + tv.tv_sec = 0; + tv.tv_usec = 1000; + + return select(tcpIn->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv); +} + + +static int transport_wait_for_write(rdpTransport* transport) +{ + struct timeval tv; + fd_set rset, wset; + fd_set *rsetPtr = NULL, *wsetPtr = NULL; + rdpTcp *tcpOut; + + tcpOut = transport->SplitInputOutput ? transport->TcpOut : transport->TcpIn; + if (tcpOut->writeBlocked) + { + wsetPtr = &wset; + FD_ZERO(wsetPtr); + FD_SET(tcpOut->sockfd, wsetPtr); + } + else if (tcpOut->readBlocked) + { + rsetPtr = &rset; + FD_ZERO(rsetPtr); + FD_SET(tcpOut->sockfd, rsetPtr); + } + + if (!wsetPtr && !rsetPtr) + { + USleep(1000); + return 0; + } + + tv.tv_sec = 0; + tv.tv_usec = 1000; + + return select(tcpOut->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv); +} + + int transport_read_layer(rdpTransport* transport, BYTE* data, int bytes) { int read = 0; int status = -1; + while (read < bytes) { - if (transport->layer == TRANSPORT_LAYER_TLS) - status = tls_read(transport->TlsIn, data + read, bytes - read); - else if (transport->layer == TRANSPORT_LAYER_TCP) - status = tcp_read(transport->TcpIn, data + read, bytes - read); - else if (transport->layer == TRANSPORT_LAYER_TSG) - status = tsg_read(transport->tsg, data + read, bytes - read); - else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) { - status = tls_read(transport->TsgTls, data + read, bytes - read); + status = BIO_read(transport->frontBio, data + read, bytes - read); + + if (!status) + { + transport->layer = TRANSPORT_LAYER_CLOSED; + return -1; } - /* blocking means that we can't continue until this is read */ - - if (!transport->blocking) - return status; - if (status < 0) { - /* A read error indicates that the peer has dropped the connection */ - transport->layer = TRANSPORT_LAYER_CLOSED; - return status; + if (!BIO_should_retry(transport->frontBio)) + { + /* something unexpected happened, let's close */ + transport->layer = TRANSPORT_LAYER_CLOSED; + return -1; + } + + /* non blocking will survive a partial read */ + if (!transport->blocking) + return read; + + /* blocking means that we can't continue until we have read the number of + * requested bytes */ + if (transport_wait_for_read(transport) < 0) + { + fprintf(stderr, "%s: error when selecting for read\n", __FUNCTION__); + return -1; + } + continue; } +#ifdef HAVE_VALGRIND_MEMCHECK_H + VALGRIND_MAKE_MEM_DEFINED(data + read, bytes - read); +#endif + read += status; - - if (status == 0) - { - /* - * instead of sleeping, we should wait timeout on the - * socket but this only happens on initial connection - */ - USleep(transport->SleepInterval); - } } return read; } + + int transport_read(rdpTransport* transport, wStream* s) { int status; int position; int pduLength; - BYTE header[4]; + BYTE *header; int transport_status; position = 0; @@ -723,7 +789,7 @@ int transport_read(rdpTransport* transport, wStream* s) position += status; } - CopyMemory(header, Stream_Buffer(s), 4); /* peek at first 4 bytes */ + header = Stream_Buffer(s); /* if header is present, read exactly one PDU */ @@ -815,6 +881,8 @@ static int transport_read_nonblocking(rdpTransport* transport) return status; } +BOOL transport_bio_buffered_drain(BIO *bio); + int transport_write(rdpTransport* transport, wStream* s) { int length; @@ -840,36 +908,48 @@ int transport_write(rdpTransport* transport, wStream* s) while (length > 0) { - if (transport->layer == TRANSPORT_LAYER_TLS) - status = tls_write(transport->TlsOut, Stream_Pointer(s), length); - else if (transport->layer == TRANSPORT_LAYER_TCP) - status = tcp_write(transport->TcpOut, Stream_Pointer(s), length); - else if (transport->layer == TRANSPORT_LAYER_TSG) - status = tsg_write(transport->tsg, Stream_Pointer(s), length); - else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) - status = tls_write(transport->TsgTls, Stream_Pointer(s), length); + status = BIO_write(transport->frontBio, Stream_Pointer(s), length); - if (status < 0) - break; /* error occurred */ - - if (status == 0) + if (status <= 0) { - /* when sending is blocked in nonblocking mode, the receiving buffer should be checked */ - if (!transport->blocking) - { - /* and in case we do have buffered some data, we set the event so next loop will get it */ - if (transport_read_nonblocking(transport) > 0) - SetEvent(transport->ReceiveEvent); - } + /* the buffered BIO that is at the end of the chain always says OK for writing, + * so a retry means that for any reason we need to read. The most probable + * is a SSL or TSG BIO in the chain. + */ + if (!BIO_should_retry(transport->frontBio)) + return status; - if (transport->layer == TRANSPORT_LAYER_TLS) - tls_wait_write(transport->TlsOut); - else if (transport->layer == TRANSPORT_LAYER_TCP) - tcp_wait_write(transport->TcpOut); - else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) - tls_wait_write(transport->TsgTls); - else - USleep(transport->SleepInterval); + /* non-blocking can live with blocked IOs */ + if (!transport->blocking) + return status; + + if (transport_wait_for_write(transport) < 0) + { + fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__); + return -1; + } + continue; + } + + if (transport->blocking || transport->settings->WaitForOutputBufferFlush) + { + /* blocking transport, we must ensure the write buffer is really empty */ + rdpTcp *out = transport->TcpOut; + + while (out->writeBlocked) + { + if (transport_wait_for_write(transport) < 0) + { + fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__); + return -1; + } + + if (!transport_bio_buffered_drain(out->bufferedBio)) + { + fprintf(stderr, "%s: error when draining outputBuffer\n", __FUNCTION__); + return -1; + } + } } length -= status; @@ -958,6 +1038,38 @@ void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* } } +BOOL tranport_is_write_blocked(rdpTransport* transport) +{ + if (transport->TcpIn->writeBlocked) + return TRUE; + + return transport->SplitInputOutput && + transport->TcpOut && + transport->TcpOut->writeBlocked; +} + +int tranport_drain_output_buffer(rdpTransport* transport) +{ + BOOL ret = FALSE; + + /* First try to send some accumulated bytes in the send buffer */ + if (transport->TcpIn->writeBlocked) + { + if (!transport_bio_buffered_drain(transport->TcpIn->bufferedBio)) + return -1; + ret |= transport->TcpIn->writeBlocked; + } + + if (transport->SplitInputOutput && transport->TcpOut && transport->TcpOut->writeBlocked) + { + if (!transport_bio_buffered_drain(transport->TcpOut->bufferedBio)) + return -1; + ret |= transport->TcpOut->writeBlocked; + } + + return ret; +} + int transport_check_fds(rdpTransport* transport) { int pos; @@ -1092,15 +1204,14 @@ int transport_check_fds(rdpTransport* transport) recv_status = transport->ReceiveCallback(transport, received, transport->ReceiveExtra); - Stream_Release(received); - - if (recv_status < 0) - return -1; - if (recv_status == 1) { return 1; /* session redirection */ } + Stream_Release(received); + + if (recv_status < 0) + return -1; } return 0; @@ -1211,80 +1322,107 @@ rdpTransport* transport_new(rdpSettings* settings) { rdpTransport* transport; - transport = (rdpTransport*) malloc(sizeof(rdpTransport)); + transport = (rdpTransport *)calloc(1, sizeof(rdpTransport)); + if (!transport) + return NULL; - if (transport) - { - ZeroMemory(transport, sizeof(rdpTransport)); + WLog_Init(); + transport->log = WLog_Get("com.freerdp.core.transport"); + if (!transport->log) + goto out_free; - WLog_Init(); - transport->log = WLog_Get("com.freerdp.core.transport"); + transport->TcpIn = tcp_new(settings); + if (!transport->TcpIn) + goto out_free; - transport->TcpIn = tcp_new(settings); + transport->settings = settings; - transport->settings = settings; + /* a small 0.1ms delay when transport is blocking. */ + transport->SleepInterval = 100; - /* a small 0.1ms delay when transport is blocking. */ - transport->SleepInterval = 100; + transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE); + if (!transport->ReceivePool) + goto out_free_tcpin; - transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE); + /* receive buffer for non-blocking read. */ + transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0); + if (!transport->ReceiveBuffer) + goto out_free_receivepool; - /* receive buffer for non-blocking read. */ - transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0); - transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!transport->ReceiveEvent || transport->ReceiveEvent == INVALID_HANDLE_VALUE) + goto out_free_receivebuffer; - transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!transport->connectedEvent || transport->connectedEvent == INVALID_HANDLE_VALUE) + goto out_free_receiveEvent; - transport->blocking = TRUE; - transport->GatewayEnabled = FALSE; + transport->blocking = TRUE; + transport->GatewayEnabled = FALSE; + transport->layer = TRANSPORT_LAYER_TCP; - InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000); - InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000); - - transport->layer = TRANSPORT_LAYER_TCP; - } + if (!InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000)) + goto out_free_connectedEvent; + if (!InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000)) + goto out_free_readlock; return transport; + +out_free_readlock: + DeleteCriticalSection(&(transport->ReadLock)); +out_free_connectedEvent: + CloseHandle(transport->connectedEvent); +out_free_receiveEvent: + CloseHandle(transport->ReceiveEvent); +out_free_receivebuffer: + StreamPool_Return(transport->ReceivePool, transport->ReceiveBuffer); +out_free_receivepool: + StreamPool_Free(transport->ReceivePool); +out_free_tcpin: + tcp_free(transport->TcpIn); +out_free: + free(transport); + return NULL; } void transport_free(rdpTransport* transport) { - if (transport) - { - transport_stop(transport); + if (!transport) + return; - if (transport->ReceiveBuffer) - Stream_Release(transport->ReceiveBuffer); + transport_stop(transport); - StreamPool_Free(transport->ReceivePool); + if (transport->ReceiveBuffer) + Stream_Release(transport->ReceiveBuffer); - CloseHandle(transport->ReceiveEvent); - CloseHandle(transport->connectedEvent); + StreamPool_Free(transport->ReceivePool); - if (transport->TlsIn) - tls_free(transport->TlsIn); + CloseHandle(transport->ReceiveEvent); + CloseHandle(transport->connectedEvent); - if (transport->TlsOut != transport->TlsIn) - tls_free(transport->TlsOut); + if (transport->TlsIn) + tls_free(transport->TlsIn); - transport->TlsIn = NULL; - transport->TlsOut = NULL; + if (transport->TlsOut != transport->TlsIn) + tls_free(transport->TlsOut); - if (transport->TcpIn) - tcp_free(transport->TcpIn); + transport->TlsIn = NULL; + transport->TlsOut = NULL; - if (transport->TcpOut != transport->TcpIn) - tcp_free(transport->TcpOut); + if (transport->TcpIn) + tcp_free(transport->TcpIn); - transport->TcpIn = NULL; - transport->TcpOut = NULL; + if (transport->TcpOut != transport->TcpIn) + tcp_free(transport->TcpOut); - tsg_free(transport->tsg); - transport->tsg = NULL; + transport->TcpIn = NULL; + transport->TcpOut = NULL; - DeleteCriticalSection(&(transport->ReadLock)); - DeleteCriticalSection(&(transport->WriteLock)); + tsg_free(transport->tsg); + transport->tsg = NULL; - free(transport); - } + DeleteCriticalSection(&(transport->ReadLock)); + DeleteCriticalSection(&(transport->WriteLock)); + + free(transport); } diff --git a/libfreerdp/core/transport.h b/libfreerdp/core/transport.h index b8834ce7a..829807405 100644 --- a/libfreerdp/core/transport.h +++ b/libfreerdp/core/transport.h @@ -49,11 +49,13 @@ typedef struct rdp_transport rdpTransport; #include #include + typedef int (*TransportRecv) (rdpTransport* transport, wStream* stream, void* extra); struct rdp_transport { TRANSPORT_LAYER layer; + BIO *frontBio; rdpTsg* tsg; rdpTcp* TcpIn; rdpTcp* TcpOut; @@ -102,6 +104,8 @@ BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blocking); void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled); void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode); void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* count); +BOOL tranport_is_write_blocked(rdpTransport* transport); +BOOL tranport_drain_output_buffer(rdpTransport* transport); wStream* transport_receive_pool_take(rdpTransport* transport); int transport_receive_pool_return(rdpTransport* transport, wStream* pdu); diff --git a/libfreerdp/crypto/tls.c b/libfreerdp/crypto/tls.c index 52c217782..016584fcc 100644 --- a/libfreerdp/crypto/tls.c +++ b/libfreerdp/crypto/tls.c @@ -28,34 +28,35 @@ #include #include +#include #include - -#ifdef HAVE_VALGRIND_MEMCHECK_H -#include -#endif +#include "../core/tcp.h" static CryptoCert tls_get_certificate(rdpTls* tls, BOOL peer) { CryptoCert cert; - X509* server_cert; + X509* remote_cert; if (peer) - server_cert = SSL_get_peer_certificate(tls->ssl); + remote_cert = SSL_get_peer_certificate(tls->ssl); else - server_cert = SSL_get_certificate(tls->ssl); + remote_cert = SSL_get_certificate(tls->ssl); - if (!server_cert) + if (!remote_cert) { - fprintf(stderr, "tls_get_certificate: failed to get the server TLS certificate\n"); - cert = NULL; - } - else - { - cert = malloc(sizeof(*cert)); - cert->px509 = server_cert; + fprintf(stderr, "%s: failed to get the server TLS certificate\n", __FUNCTION__); + return NULL; } + cert = malloc(sizeof(*cert)); + if (!cert) + { + X509_free(remote_cert); + return NULL; + } + + cert->px509 = remote_cert; return cert; } @@ -83,12 +84,14 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert) PrefixLength = strlen(TLS_SERVER_END_POINT); ChannelBindingTokenLength = PrefixLength + CertificateHashLength; - ContextBindings = (SecPkgContext_Bindings*) malloc(sizeof(SecPkgContext_Bindings)); - ZeroMemory(ContextBindings, sizeof(SecPkgContext_Bindings)); + ContextBindings = (SecPkgContext_Bindings*) calloc(1, sizeof(SecPkgContext_Bindings)); + if (!ContextBindings) + return NULL; ContextBindings->BindingsLength = sizeof(SEC_CHANNEL_BINDINGS) + ChannelBindingTokenLength; - ChannelBindings = (SEC_CHANNEL_BINDINGS*) malloc(ContextBindings->BindingsLength); - ZeroMemory(ChannelBindings, ContextBindings->BindingsLength); + ChannelBindings = (SEC_CHANNEL_BINDINGS*) calloc(1, ContextBindings->BindingsLength); + if (!ChannelBindings) + goto out_free; ContextBindings->Bindings = ChannelBindings; ChannelBindings->cbApplicationDataLength = ChannelBindingTokenLength; @@ -99,32 +102,121 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert) CopyMemory(&ChannelBindingToken[PrefixLength], CertificateHash, CertificateHashLength); return ContextBindings; + +out_free: + free(ContextBindings); + return NULL; } -static void tls_ssl_info_callback(const SSL* ssl, int type, int val) + +BOOL tls_prepare(rdpTls* tls, BIO *underlying, const SSL_METHOD *method, int options, BOOL clientMode) { - if (type & SSL_CB_HANDSHAKE_START) - { - - } -} - -int tls_connect(rdpTls* tls) -{ - CryptoCert cert; - long options = 0; - int verify_status; - int connection_status; - - tls->ctx = SSL_CTX_new(TLSv1_client_method()); - + tls->ctx = SSL_CTX_new(method); if (!tls->ctx) { - fprintf(stderr, "SSL_CTX_new failed\n"); + fprintf(stderr, "%s: SSL_CTX_new failed\n", __FUNCTION__); + return FALSE; + } + + SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE); + + SSL_CTX_set_options(tls->ctx, options); + SSL_CTX_set_read_ahead(tls->ctx, 1); + + tls->bio = BIO_new_ssl(tls->ctx, clientMode); + if (BIO_get_ssl(tls->bio, &tls->ssl) < 0) + { + fprintf(stderr, "%s: unable to retrieve the SSL of the connection\n", __FUNCTION__); + return FALSE; + } + + BIO_push(tls->bio, underlying); + return TRUE; +} + +int tls_do_handshake(rdpTls* tls, BOOL clientMode) +{ + CryptoCert cert; + int verify_status, status; + + do + { + struct timeval tv; + fd_set rset; + int fd; + + status = BIO_do_handshake(tls->bio); + if (status == 1) + break; + if (!BIO_should_retry(tls->bio)) + return -1; + + /* we select() only for read even if we should test both read and write + * depending of what have blocked */ + FD_ZERO(&rset); + + fd = BIO_get_fd(tls->bio, NULL); + if (fd < 0) + { + fprintf(stderr, "%s: unable to retrieve BIO fd\n", __FUNCTION__); + return -1; + } + + FD_SET(fd, &rset); + tv.tv_sec = 0; + tv.tv_usec = 10 * 1000; /* 10ms */ + + status = select(fd + 1, &rset, NULL, NULL, &tv); + if (status < 0) + { + fprintf(stderr, "%s: error during select()\n", __FUNCTION__); + return -1; + } + } + while (TRUE); + + if (!clientMode) + return 1; + + cert = tls_get_certificate(tls, clientMode); + if (!cert) + { + fprintf(stderr, "%s: tls_get_certificate failed to return the server certificate.\n", __FUNCTION__); return -1; } - //SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE); + tls->Bindings = tls_get_channel_bindings(cert->px509); + if (!tls->Bindings) + { + fprintf(stderr, "%s: unable to retrieve bindings\n", __FUNCTION__); + return -1; + } + + if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength)) + { + fprintf(stderr, "%s: crypto_cert_get_public_key failed to return the server public key.\n", __FUNCTION__); + tls_free_certificate(cert); + return -1; + } + + verify_status = tls_verify_certificate(tls, cert, tls->hostname, tls->port); + + if (verify_status < 1) + { + fprintf(stderr, "%s: certificate not trusted, aborting.\n", __FUNCTION__); + tls_disconnect(tls); + tls_free_certificate(cert); + return 0; + } + + tls_free_certificate(cert); + + return verify_status; +} + +int tls_connect(rdpTls* tls, BIO *underlying) +{ + int options = 0; /** * SSL_OP_NO_COMPRESSION: @@ -138,7 +230,7 @@ int tls_connect(rdpTls* tls) #ifdef SSL_OP_NO_COMPRESSION options |= SSL_OP_NO_COMPRESSION; #endif - + /** * SSL_OP_TLS_BLOCK_PADDING_BUG: * @@ -155,96 +247,19 @@ int tls_connect(rdpTls* tls) */ options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS; - SSL_CTX_set_options(tls->ctx, options); + if (!tls_prepare(tls, underlying, TLSv1_client_method(), options, TRUE)) + return FALSE; - tls->ssl = SSL_new(tls->ctx); - - if (!tls->ssl) - { - fprintf(stderr, "SSL_new failed\n"); - return -1; - } - - if (tls->tsg) - { - tls->bio = BIO_new(tls->methods); - - if (!tls->bio) - { - fprintf(stderr, "BIO_new failed\n"); - return -1; - } - - tls->bio->ptr = tls->tsg; - - SSL_set_bio(tls->ssl, tls->bio, tls->bio); - - SSL_CTX_set_info_callback(tls->ctx, tls_ssl_info_callback); - } - else - { - if (SSL_set_fd(tls->ssl, tls->sockfd) < 1) - { - fprintf(stderr, "SSL_set_fd failed\n"); - return -1; - } - } - - connection_status = SSL_connect(tls->ssl); - - if (connection_status <= 0) - { - if (tls_print_error("SSL_connect", tls->ssl, connection_status)) - { - return -1; - } - } - - cert = tls_get_certificate(tls, TRUE); - - if (!cert) - { - fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n"); - return -1; - } - - tls->Bindings = tls_get_channel_bindings(cert->px509); - - if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength)) - { - fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n"); - tls_free_certificate(cert); - return -1; - } - - verify_status = tls_verify_certificate(tls, cert, tls->hostname, tls->port); - - if (verify_status < 1) - { - fprintf(stderr, "tls_connect: certificate not trusted, aborting.\n"); - tls_disconnect(tls); - } - - tls_free_certificate(cert); - - return verify_status; + return tls_do_handshake(tls, TRUE); } -BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file) + + +BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file) { - CryptoCert cert; long options = 0; - int connection_status; - tls->ctx = SSL_CTX_new(SSLv23_server_method()); - - if (tls->ctx == NULL) - { - fprintf(stderr, "SSL_CTX_new failed\n"); - return FALSE; - } - - /* + /** * SSL_OP_NO_SSLv2: * * We only want SSLv3 and TLSv1, so disable SSLv2. @@ -281,80 +296,23 @@ BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file) */ options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS; - SSL_CTX_set_options(tls->ctx, options); - - if (SSL_CTX_use_RSAPrivateKey_file(tls->ctx, privatekey_file, SSL_FILETYPE_PEM) <= 0) - { - fprintf(stderr, "SSL_CTX_use_RSAPrivateKey_file failed\n"); - fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file); + if (!tls_prepare(tls, underlying, SSLv23_server_method(), options, FALSE)) return FALSE; - } - tls->ssl = SSL_new(tls->ctx); - - if (!tls->ssl) + if (SSL_use_RSAPrivateKey_file(tls->ssl, privatekey_file, SSL_FILETYPE_PEM) <= 0) { - fprintf(stderr, "SSL_new failed\n"); + fprintf(stderr, "%s: SSL_CTX_use_RSAPrivateKey_file failed\n", __FUNCTION__); + fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file); return FALSE; } if (SSL_use_certificate_file(tls->ssl, cert_file, SSL_FILETYPE_PEM) <= 0) { - fprintf(stderr, "SSL_use_certificate_file failed\n"); + fprintf(stderr, "%s: SSL_use_certificate_file failed\n", __FUNCTION__); return FALSE; } - if (SSL_set_fd(tls->ssl, tls->sockfd) < 1) - { - fprintf(stderr, "SSL_set_fd failed\n"); - return FALSE; - } - - while (1) - { - connection_status = SSL_accept(tls->ssl); - - if (connection_status <= 0) - { - switch (SSL_get_error(tls->ssl, connection_status)) - { - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - break; - - default: - if (tls_print_error("SSL_accept", tls->ssl, connection_status)) - return FALSE; - break; - - } - } - else - { - break; - } - } - - cert = tls_get_certificate(tls, FALSE); - - if (!cert) - { - fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n"); - return FALSE; - } - - if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength)) - { - fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n"); - tls_free_certificate(cert); - return FALSE; - } - - free(cert); - - fprintf(stderr, "TLS connection accepted\n"); - - return TRUE; + return tls_do_handshake(tls, FALSE) > 0; } BOOL tls_disconnect(rdpTls* tls) @@ -362,256 +320,161 @@ BOOL tls_disconnect(rdpTls* tls) if (!tls) return FALSE; - if (tls->ssl) + if (!tls->ssl) + return TRUE; + + if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY) { - if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY) - { - /** - * OpenSSL doesn't really expose an API for sending a TLS alert manually. - * - * The following code disables the sending of the default "close notify" - * and then proceeds to force sending a custom TLS alert before shutting down. - * - * Manually sending a TLS alert is necessary in certain cases, - * like when server-side NLA results in an authentication failure. - */ + /** + * OpenSSL doesn't really expose an API for sending a TLS alert manually. + * + * The following code disables the sending of the default "close notify" + * and then proceeds to force sending a custom TLS alert before shutting down. + * + * Manually sending a TLS alert is necessary in certain cases, + * like when server-side NLA results in an authentication failure. + */ - SSL_set_quiet_shutdown(tls->ssl, 1); + SSL_set_quiet_shutdown(tls->ssl, 1); - if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session)) - SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session); + if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session)) + SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session); - tls->ssl->s3->alert_dispatch = 1; - tls->ssl->s3->send_alert[0] = tls->alertLevel; - tls->ssl->s3->send_alert[1] = tls->alertDescription; + tls->ssl->s3->alert_dispatch = 1; + tls->ssl->s3->send_alert[0] = tls->alertLevel; + tls->ssl->s3->send_alert[1] = tls->alertDescription; - if (tls->ssl->s3->wbuf.left == 0) - tls->ssl->method->ssl_dispatch_alert(tls->ssl); + if (tls->ssl->s3->wbuf.left == 0) + tls->ssl->method->ssl_dispatch_alert(tls->ssl); - SSL_shutdown(tls->ssl); - } - else - { - SSL_shutdown(tls->ssl); - } + SSL_shutdown(tls->ssl); + } + else + { + SSL_shutdown(tls->ssl); } return TRUE; } -int tls_read(rdpTls* tls, BYTE* data, int length) + +BIO *findBufferedBio(BIO *front) { - int error; - int status; + BIO *ret = front; - if (!tls) - return -1; - - if (!tls->ssl) - return -1; - - status = SSL_read(tls->ssl, data, length); - - if (status == 0) + while (ret) { - return -1; /* peer disconnected */ + if (BIO_method_type(ret) == BIO_TYPE_BUFFERED) + return ret; + ret = ret->next_bio; } - if (status <= 0) - { - error = SSL_get_error(tls->ssl, status); - - //fprintf(stderr, "tls_read: length: %d status: %d error: 0x%08X\n", - // length, status, error); - - switch (error) - { - case SSL_ERROR_NONE: - break; - - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - status = 0; - break; - - case SSL_ERROR_SYSCALL: -#ifdef _WIN32 - if (WSAGetLastError() == WSAEWOULDBLOCK) -#else - if ((errno == EAGAIN) || (errno == 0)) -#endif - { - status = 0; - } - else - { - if (tls_print_error("SSL_read", tls->ssl, status)) - { - status = -1; - } - else - { - status = 0; - } - } - break; - - default: - if (tls_print_error("SSL_read", tls->ssl, status)) - { - status = -1; - } - else - { - status = 0; - } - break; - } - } - -#ifdef HAVE_VALGRIND_MEMCHECK_H - VALGRIND_MAKE_MEM_DEFINED(data, status); -#endif - - return status; + return ret; } -int tls_write(rdpTls* tls, BYTE* data, int length) +int tls_write_all(rdpTls* tls, const BYTE* data, int length) { - int error; - int status; + int status, nchunks, commitedBytes; + rdpTcp *tcp; + fd_set rset, wset; + fd_set *rsetPtr, *wsetPtr; + struct timeval tv; + BIO *bio = tls->bio; + DataChunk chunks[2]; - if (!tls) - return -1; - - if (!tls->ssl) - return -1; - - status = SSL_write(tls->ssl, data, length); - - if (status == 0) + BIO *bufferedBio = findBufferedBio(bio); + if (!bufferedBio) { - return -1; /* peer disconnected */ + fprintf(stderr, "%s: error unable to retrieve the bufferedBio in the BIO chain\n", __FUNCTION__); + return -1; } - if (status < 0) - { - error = SSL_get_error(tls->ssl, status); - - //fprintf(stderr, "tls_write: length: %d status: %d error: 0x%08X\n", length, status, error); - - switch (error) - { - case SSL_ERROR_NONE: - break; - - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - status = 0; - break; - - case SSL_ERROR_SYSCALL: - if (errno == EAGAIN) - { - status = 0; - } - else - { - tls_print_error("SSL_write", tls->ssl, status); - status = -1; - } - break; - - default: - tls_print_error("SSL_write", tls->ssl, status); - status = -1; - break; - } - } - - return status; -} - -int tls_write_all(rdpTls* tls, BYTE* data, int length) -{ - int status; - int sent = 0; + tcp = (rdpTcp *)bufferedBio->ptr; do { - status = tls_write(tls, &data[sent], length - sent); - + status = BIO_write(bio, data, length); + /*fprintf(stderr, "%s: BIO_write(len=%d) = %d (retry=%d)\n", __FUNCTION__, length, status, BIO_should_retry(bio));*/ if (status > 0) - sent += status; - else if (status == 0) - tls_wait_write(tls); - - if (sent >= length) break; + + if (!BIO_should_retry(bio)) + return -1; + + /* we try to handle SSL want_read and want_write nicely */ + rsetPtr = wsetPtr = 0; + if (tcp->writeBlocked) + { + wsetPtr = &wset; + FD_ZERO(&wset); + FD_SET(tcp->sockfd, &wset); + } + else if (tcp->readBlocked) + { + rsetPtr = &rset; + FD_ZERO(&rset); + FD_SET(tcp->sockfd, &rset); + } + else + { + fprintf(stderr, "%s: weird we're blocked but the underlying is not read or write blocked !\n", __FUNCTION__); + USleep(10); + continue; + } + + tv.tv_sec = 0; + tv.tv_usec = 100 * 1000; + + status = select(tcp->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv); + if (status < 0) + return -1; } - while (status >= 0); + while (TRUE); - if (status > 0) - return length; - else - return status; -} - -int tls_wait_read(rdpTls* tls) -{ - return freerdp_tcp_wait_read(tls->sockfd); -} - -int tls_wait_write(rdpTls* tls) -{ - return freerdp_tcp_wait_write(tls->sockfd); -} - -static void tls_errors(const char *prefix) -{ - unsigned long error; - - while ((error = ERR_get_error()) != 0) - fprintf(stderr, "%s: %s\n", prefix, ERR_error_string(error, NULL)); -} - -BOOL tls_print_error(char* func, SSL* connection, int value) -{ - switch (SSL_get_error(connection, value)) + /* make sure the output buffer is empty */ + commitedBytes = 0; + while ((nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer)))) { - case SSL_ERROR_ZERO_RETURN: - fprintf(stderr, "%s: Server closed TLS connection\n", func); - return TRUE; + int i; - case SSL_ERROR_WANT_READ: - fprintf(stderr, "%s: SSL_ERROR_WANT_READ\n", func); - return FALSE; + for (i = 0; i < nchunks; i++) + { + while (chunks[i].size) + { + status = BIO_write(tcp->socketBio, chunks[i].data, chunks[i].size); + if (status > 0) + { + chunks[i].size -= status; + chunks[i].data += status; + commitedBytes += status; + continue; + } - case SSL_ERROR_WANT_WRITE: - fprintf(stderr, "%s: SSL_ERROR_WANT_WRITE\n", func); - return FALSE; + if (!BIO_should_retry(tcp->socketBio)) + goto out_fail; + FD_ZERO(&rset); + FD_SET(tcp->sockfd, &rset); + tv.tv_sec = 0; + tv.tv_usec = 100 * 1000; - case SSL_ERROR_SYSCALL: -#ifdef _WIN32 - fprintf(stderr, "%s: I/O error: %d\n", func, WSAGetLastError()); -#else - fprintf(stderr, "%s: I/O error: %s (%d)\n", func, strerror(errno), errno); -#endif - tls_errors(func); - return TRUE; + status = select(tcp->sockfd + 1, &rset, NULL, NULL, &tv); + if (status < 0) + goto out_fail; + } - case SSL_ERROR_SSL: - fprintf(stderr, "%s: Failure in SSL library (protocol error?)\n", func); - tls_errors(func); - return TRUE; - - default: - fprintf(stderr, "%s: Unknown error\n", func); - tls_errors(func); - return TRUE; + } } + + ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes); + return length; + +out_fail: + ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes); + return -1; } + + int tls_set_alert_code(rdpTls* tls, int level, int description) { tls->alertLevel = level; @@ -672,7 +535,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por if (!bio) { - fprintf(stderr, "tls_verify_certificate: BIO_new() failure\n"); + fprintf(stderr, "%s: BIO_new() failure\n", __FUNCTION__); return -1; } @@ -680,7 +543,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por if (status < 0) { - fprintf(stderr, "tls_verify_certificate: PEM_write_bio_X509 failure: %d\n", status); + fprintf(stderr, "%s: PEM_write_bio_X509 failure: %d\n", __FUNCTION__, status); return -1; } @@ -692,7 +555,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por if (status < 0) { - fprintf(stderr, "tls_verify_certificate: failed to read certificate\n"); + fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__); return -1; } @@ -713,7 +576,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por if (status < 0) { - fprintf(stderr, "tls_verify_certificate: failed to read certificate\n"); + fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__); return -1; } @@ -727,8 +590,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por status = instance->VerifyX509Certificate(instance, pemCert, length, hostname, port, 0); } - fprintf(stderr, "VerifyX509Certificate: (length = %d) status: %d\n%s\n", - length, status, pemCert); + fprintf(stderr, "%s: (length = %d) status: %d\n%s\n", __FUNCTION__, length, status, pemCert); free(pemCert); BIO_free(bio); @@ -932,57 +794,53 @@ rdpTls* tls_new(rdpSettings* settings) { rdpTls* tls; - tls = (rdpTls*) malloc(sizeof(rdpTls)); + tls = (rdpTls *)calloc(1, sizeof(rdpTls)); + if (!tls) + return NULL; - if (tls) - { - ZeroMemory(tls, sizeof(rdpTls)); + SSL_load_error_strings(); + SSL_library_init(); - SSL_load_error_strings(); - SSL_library_init(); - - tls->settings = settings; - tls->certificate_store = certificate_store_new(settings); - - tls->alertLevel = TLS_ALERT_LEVEL_WARNING; - tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY; - } + tls->settings = settings; + tls->certificate_store = certificate_store_new(settings); + if (!tls->certificate_store) + goto out_free; + tls->alertLevel = TLS_ALERT_LEVEL_WARNING; + tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY; return tls; + +out_free: + free(tls); + return NULL; } void tls_free(rdpTls* tls) { - if (tls) + if (!tls) + return; + + if (tls->ctx) { - if (tls->ssl) - { - SSL_free(tls->ssl); - tls->ssl = NULL; - } - - if (tls->ctx) - { - SSL_CTX_free(tls->ctx); - tls->ctx = NULL; - } - - if (tls->PublicKey) - { - free(tls->PublicKey); - tls->PublicKey = NULL; - } - - if (tls->Bindings) - { - free(tls->Bindings->Bindings); - free(tls->Bindings); - tls->Bindings = NULL; - } - - certificate_store_free(tls->certificate_store); - tls->certificate_store = NULL; - - free(tls); + SSL_CTX_free(tls->ctx); + tls->ctx = NULL; } + + if (tls->PublicKey) + { + free(tls->PublicKey); + tls->PublicKey = NULL; + } + + if (tls->Bindings) + { + free(tls->Bindings->Bindings); + free(tls->Bindings); + tls->Bindings = NULL; + } + + certificate_store_free(tls->certificate_store); + tls->certificate_store = NULL; + + free(tls); }