libfreerdp-core: transport refactoring

This commit is contained in:
Marc-André Moreau
2013-11-07 17:37:58 -05:00
parent 61f95fbe16
commit 5536033a8a
6 changed files with 90 additions and 58 deletions

21
docs/valgrind.supp Normal file
View File

@@ -0,0 +1,21 @@
{
glibc_getaddrinfo
Memcheck:Param
sendmsg(mmsg[0].msg_hdr)
fun:sendmmsg
fun:__libc_res_nsend
fun:__libc_res_nquery
fun:__libc_res_nsearch
fun:_nss_dns_gethostbyname4_r
fun:gaih_inet
fun:getaddrinfo
fun:freerdp_tcp_connect
fun:tcp_connect
fun:transport_connect
fun:nego_tcp_connect
fun:nego_transport_connect
}

View File

@@ -491,6 +491,7 @@ BOOL nego_recv_response(rdpNego* nego)
status = nego_recv(nego->transport, s, nego);
Stream_Free(s, TRUE);
if (status < 0)
return FALSE;

View File

@@ -57,13 +57,14 @@
#include "tcp.h"
void tcp_get_ip_address(rdpTcp * tcp)
void tcp_get_ip_address(rdpTcp* tcp)
{
BYTE* ip;
socklen_t length;
struct sockaddr_in sockaddr;
length = sizeof(sockaddr);
ZeroMemory(&sockaddr, length);
if (getsockname(tcp->sockfd, (struct sockaddr*) &sockaddr, &length) == 0)
{
@@ -73,19 +74,12 @@ void tcp_get_ip_address(rdpTcp * tcp)
}
else
{
strncpy(tcp->ip_address, "127.0.0.1", sizeof(tcp->ip_address));
strcpy(tcp->ip_address, "127.0.0.1");
}
tcp->ip_address[sizeof(tcp->ip_address) - 1] = 0;
tcp->settings->IPv6Enabled = 0;
if (tcp->settings->ClientAddress)
{
free(tcp->settings->ClientAddress);
tcp->settings->ClientAddress = NULL;
}
free(tcp->settings->ClientAddress);
tcp->settings->ClientAddress = _strdup(tcp->ip_address);
}
@@ -122,7 +116,7 @@ void tcp_get_mac_address(rdpTcp* tcp)
mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]); */
}
BOOL tcp_connect(rdpTcp* tcp, const char* hostname, UINT16 port)
BOOL tcp_connect(rdpTcp* tcp, const char* hostname, int port)
{
UINT32 option_value;
socklen_t option_len;

View File

@@ -41,14 +41,14 @@ struct rdp_tcp
int sockfd;
char ip_address[32];
BYTE mac_address[6];
struct rdp_settings* settings;
rdpSettings* settings;
#ifdef _WIN32
WSAEVENT wsa_event;
#endif
HANDLE event;
};
BOOL tcp_connect(rdpTcp* tcp, const char* hostname, UINT16 port);
BOOL tcp_connect(rdpTcp* tcp, const char* hostname, int port);
BOOL tcp_disconnect(rdpTcp* tcp);
int tcp_read(rdpTcp* tcp, BYTE* data, int length);
int tcp_write(rdpTcp* tcp, BYTE* data, int length);

View File

@@ -219,7 +219,7 @@ BOOL transport_connect_tls(rdpTransport* transport)
transport->layer = TRANSPORT_LAYER_TSG_TLS;
if (tls_connect(transport->TsgTls) != TRUE)
if (!tls_connect(transport->TsgTls))
{
if (!connectErrorCode)
connectErrorCode = TLSCONNECTERROR;
@@ -233,16 +233,16 @@ BOOL transport_connect_tls(rdpTransport* transport)
return TRUE;
}
if (transport->TlsIn == NULL)
if (!transport->TlsIn)
transport->TlsIn = tls_new(transport->settings);
if (transport->TlsOut == NULL)
if (!transport->TlsOut)
transport->TlsOut = transport->TlsIn;
transport->layer = TRANSPORT_LAYER_TLS;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (tls_connect(transport->TlsIn) != TRUE)
if (!tls_connect(transport->TlsIn))
{
if (!connectErrorCode)
connectErrorCode = TLSCONNECTERROR;
@@ -265,17 +265,17 @@ BOOL transport_connect_nla(rdpTransport* transport)
freerdp* instance;
rdpSettings* settings;
settings = transport->settings;
instance = (freerdp*) settings->instance;
if (!transport_connect_tls(transport))
return FALSE;
/* Network Level Authentication */
if (transport->settings->Authentication != TRUE)
if (!settings->Authentication)
return TRUE;
settings = transport->settings;
instance = (freerdp*) settings->instance;
if (!transport->credssp)
transport->credssp = credssp_new(instance, transport, settings);
@@ -293,6 +293,7 @@ BOOL transport_connect_nla(rdpTransport* transport)
}
credssp_free(transport->credssp);
transport->credssp = NULL;
return TRUE;
}
@@ -315,10 +316,10 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
transport->TlsOut->sockfd = transport->TcpOut->sockfd;
if (tls_connect(transport->TlsIn) != TRUE)
if (!tls_connect(transport->TlsIn))
return FALSE;
if (tls_connect(transport->TlsOut) != TRUE)
if (!tls_connect(transport->TlsOut))
return FALSE;
if (!tsg_connect(tsg, hostname, port))
@@ -387,7 +388,7 @@ BOOL transport_accept_tls(rdpTransport* transport)
transport->layer = TRANSPORT_LAYER_TLS;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile) != TRUE)
if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
return FALSE;
return TRUE;
@@ -398,6 +399,9 @@ BOOL transport_accept_nla(rdpTransport* transport)
freerdp* instance;
rdpSettings* settings;
settings = transport->settings;
instance = (freerdp*) settings->instance;
if (!transport->TlsIn)
transport->TlsIn = tls_new(transport->settings);
@@ -407,18 +411,15 @@ BOOL transport_accept_nla(rdpTransport* transport)
transport->layer = TRANSPORT_LAYER_TLS;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile) != TRUE)
if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
return FALSE;
/* Network Level Authentication */
if (transport->settings->Authentication != TRUE)
if (!settings->Authentication)
return TRUE;
settings = transport->settings;
instance = (freerdp*) settings->instance;
if (transport->credssp == NULL)
if (!transport->credssp)
transport->credssp = credssp_new(instance, transport, settings);
if (credssp_authenticate(transport->credssp) < 0)
@@ -496,7 +497,7 @@ UINT32 nla_header_length(wStream* s)
return length;
}
int transport_read_layer(rdpTransport* transport, UINT8* data, int bytes)
int transport_read_layer(rdpTransport* transport, BYTE* data, int bytes)
{
int read = 0;
int status = -1;
@@ -539,10 +540,12 @@ int transport_read_layer(rdpTransport* transport, UINT8* data, int bytes)
int transport_read(rdpTransport* transport, wStream* s)
{
int status;
int position;
int pduLength;
int streamPosition;
BYTE header[4];
int transport_status;
position = 0;
pduLength = 0;
transport_status = 0;
@@ -553,54 +556,57 @@ int transport_read(rdpTransport* transport, wStream* s)
return -1;
/* first check if we have header */
streamPosition = Stream_GetPosition(s);
position = Stream_GetPosition(s);
if (streamPosition < 4)
if (position < 4)
{
status = transport_read_layer(transport, Stream_Buffer(s) + streamPosition, 4 - streamPosition);
status = transport_read_layer(transport, Stream_Buffer(s) + position, 4 - position);
if (status < 0)
return status;
transport_status += status;
if ((status + streamPosition) < 4)
if ((status + position) < 4)
return transport_status;
streamPosition += status;
position += status;
}
Stream_Peek(s, header, 4); /* peek at first 4 bytes */
/* if header is present, read in exactly one PDU */
if (Stream_Buffer(s)[0] == 0x03)
if (header[0] == 0x03)
{
/* TPKT header */
pduLength = (Stream_Buffer(s)[2] << 8) | Stream_Buffer(s)[3];
pduLength = (header[2] << 8) | header[3];
}
else if (Stream_Buffer(s)[0] == 0x30)
else if (header[0] == 0x30)
{
/* TSRequest (NLA) */
if (Stream_Buffer(s)[1] & 0x80)
if (header[1] & 0x80)
{
if ((Stream_Buffer(s)[1] & ~(0x80)) == 1)
if ((header[1] & ~(0x80)) == 1)
{
pduLength = Stream_Buffer(s)[2];
pduLength = header[2];
pduLength += 3;
}
else if ((Stream_Buffer(s)[1] & ~(0x80)) == 2)
else if ((header[1] & ~(0x80)) == 2)
{
pduLength = (Stream_Buffer(s)[2] << 8) | Stream_Buffer(s)[3];
pduLength = (header[2] << 8) | header[3];
pduLength += 4;
}
else
{
fprintf(stderr, "Error reading TSRequest!\n");
return -1;
}
}
else
{
pduLength = Stream_Buffer(s)[1];
pduLength = header[1];
pduLength += 2;
}
}
@@ -608,13 +614,13 @@ int transport_read(rdpTransport* transport, wStream* s)
{
/* Fast-Path Header */
if (Stream_Buffer(s)[1] & 0x80)
pduLength = ((Stream_Buffer(s)[1] & 0x7F) << 8) | Stream_Buffer(s)[2];
if (header[1] & 0x80)
pduLength = ((header[1] & 0x7F) << 8) | header[2];
else
pduLength = Stream_Buffer(s)[1];
pduLength = header[1];
}
status = transport_read_layer(transport, Stream_Buffer(s) + streamPosition, pduLength - streamPosition);
status = transport_read_layer(transport, Stream_Buffer(s) + position, pduLength - position);
if (status < 0)
return status;
@@ -623,14 +629,14 @@ int transport_read(rdpTransport* transport, wStream* s)
#ifdef WITH_DEBUG_TRANSPORT
/* dump when whole PDU is read */
if (streamPosition + status >= pduLength)
if (position + status >= pduLength)
{
fprintf(stderr, "Local < Remote\n");
winpr_HexDump(Stream_Buffer(s), pduLength);
}
#endif
if (streamPosition + status >= pduLength)
if (position + status >= pduLength)
{
WLog_Packet(transport->log, WLOG_TRACE, Stream_Buffer(s), pduLength, WLOG_PACKET_INBOUND);
}
@@ -799,7 +805,7 @@ int transport_check_fds(rdpTransport* transport)
{
int pos;
int status;
UINT16 length;
int length;
int recv_status;
wStream* received;

View File

@@ -45,7 +45,11 @@ void StreamPool_ShiftUsed(wStreamPool* pool, int index, int count)
else if (count < 0)
{
if (pool->uSize - index + count > 0)
MoveMemory(&pool->uArray[index], &pool->uArray[index - count], (pool->uSize - index + count) * sizeof(wStream*));
{
MoveMemory(&pool->uArray[index], &pool->uArray[index - count],
(pool->uSize - index + count) * sizeof(wStream*));
}
pool->uSize += count;
}
}
@@ -103,7 +107,11 @@ void StreamPool_ShiftAvailable(wStreamPool* pool, int index, int count)
else if (count < 0)
{
if (pool->aSize - index + count > 0)
MoveMemory(&pool->aArray[index], &pool->aArray[index - count], (pool->aSize - index + count) * sizeof(wStream*));
{
MoveMemory(&pool->aArray[index], &pool->aArray[index - count],
(pool->aSize - index + count) * sizeof(wStream*));
}
pool->aSize += count;
}
}
@@ -117,7 +125,6 @@ wStream* StreamPool_Take(wStreamPool* pool, size_t size)
int index;
int foundIndex;
wStream* s = NULL;
BOOL found = FALSE;
if (pool->synchronized)
EnterCriticalSection(&pool->lock);
@@ -125,6 +132,8 @@ wStream* StreamPool_Take(wStreamPool* pool, size_t size)
if (size == 0)
size = pool->defaultSize;
foundIndex = -1;
for (index = 0; index < pool->aSize; index++)
{
s = pool->aArray[index];
@@ -132,12 +141,11 @@ wStream* StreamPool_Take(wStreamPool* pool, size_t size)
if (Stream_Capacity(s) >= size)
{
foundIndex = index;
found = TRUE;
break;
}
}
if (!found)
if (foundIndex < 0)
{
s = Stream_New(NULL, size);
}
@@ -330,10 +338,12 @@ wStreamPool* StreamPool_New(BOOL synchronized, size_t defaultSize)
pool->aSize = 0;
pool->aCapacity = 32;
pool->aArray = (wStream**) malloc(sizeof(wStream*) * pool->aCapacity);
ZeroMemory(pool->aArray, sizeof(wStream*) * pool->aCapacity);
pool->uSize = 0;
pool->uCapacity = 32;
pool->uArray = (wStream**) malloc(sizeof(wStream*) * pool->uCapacity);
ZeroMemory(pool->uArray, sizeof(wStream*) * pool->uCapacity);
}
return pool;