libfreerdp-core: transport refactoring to split in/out channels

This commit is contained in:
Marc-André Moreau
2012-11-14 20:46:51 -05:00
parent ac319e72ae
commit 98dcdcfb8f
9 changed files with 93 additions and 96 deletions

View File

@@ -35,8 +35,9 @@ typedef struct rdp_credssp rdpCredssp;
struct rdp_credssp
{
rdpTls* tls;
BOOL server;
rdpTls* TlsIn;
rdpTls* TlsOut;
int send_seq_num;
int recv_seq_num;
freerdp* instance;
@@ -57,7 +58,7 @@ struct rdp_credssp
FREERDP_API int credssp_authenticate(rdpCredssp* credssp);
FREERDP_API rdpCredssp* credssp_new(freerdp* instance, rdpTls* tls, rdpSettings* settings);
FREERDP_API rdpCredssp* credssp_new(freerdp* instance, rdpTls* TlsIn, rdpTls* TlsOut, rdpSettings* settings);
FREERDP_API void credssp_free(rdpCredssp* credssp);
#endif /* FREERDP_SSPI_CREDSSP_H */

View File

@@ -153,7 +153,7 @@ BOOL rdp_client_connect(rdpRdp* rdp)
return FALSE;
}
rdp->transport->process_single_pdu = TRUE;
rdp->transport->ProcessSinglePdu = TRUE;
while (rdp->state != CONNECTION_STATE_ACTIVE)
{
@@ -161,7 +161,7 @@ BOOL rdp_client_connect(rdpRdp* rdp)
return FALSE;
}
rdp->transport->process_single_pdu = FALSE;
rdp->transport->ProcessSinglePdu = FALSE;
return TRUE;
}

View File

@@ -298,7 +298,7 @@ void license_generate_hwid(rdpLicense* license)
BYTE* mac_address;
memset(license->hwid, 0, HWID_LENGTH);
mac_address = license->rdp->transport->tcp->mac_address;
mac_address = license->rdp->transport->TcpIn->mac_address;
md5 = crypto_md5_init();
crypto_md5_update(md5, mac_address, 6);

View File

@@ -152,7 +152,7 @@ BOOL nego_connect(rdpNego* nego)
nego->transport->settings->SelectedProtocol = nego->selected_protocol;
nego->transport->settings->NegotiationFlags = nego->flags;
if(nego->selected_protocol == PROTOCOL_RDP)
if (nego->selected_protocol == PROTOCOL_RDP)
{
nego->transport->settings->DisableEncryption = TRUE;
nego->transport->settings->EncryptionMethods = ENCRYPTION_METHOD_40BIT | ENCRYPTION_METHOD_128BIT | ENCRYPTION_METHOD_FIPS;
@@ -160,7 +160,7 @@ BOOL nego_connect(rdpNego* nego)
}
/* finally connect security layer (if not already done) */
if(!nego_security_connect(nego))
if (!nego_security_connect(nego))
{
DEBUG_NEGO("Failed to connect with %s security", PROTOCOL_SECURITY_STRINGS[nego->selected_protocol]);
return FALSE;

View File

@@ -44,7 +44,7 @@ static BOOL freerdp_peer_initialize(freerdp_peer* client)
static BOOL freerdp_peer_get_fds(freerdp_peer* client, void** rfds, int* rcount)
{
rfds[*rcount] = (void*)(long)(client->context->rdp->transport->tcp->sockfd);
rfds[*rcount] = (void*)(long)(client->context->rdp->transport->TcpIn->sockfd);
(*rcount)++;
return TRUE;

View File

@@ -1229,8 +1229,8 @@ int rpc_read(rdpRpc* rpc, BYTE* data, int length)
BOOL rpc_connect(rdpRpc* rpc)
{
rpc->tls_in = rpc->transport->tls_in;
rpc->tls_out = rpc->transport->tls_out;
rpc->tls_in = rpc->transport->TlsIn;
rpc->tls_out = rpc->transport->TlsOut;
if (!rts_connect(rpc))
{

View File

@@ -25,6 +25,8 @@
#include <stdlib.h>
#include <string.h>
#include <winpr/crt.h>
#include <freerdp/utils/tcp.h>
#include <freerdp/utils/sleep.h>
#include <freerdp/utils/stream.h>
@@ -67,15 +69,15 @@ STREAM* transport_send_stream_init(rdpTransport* transport, int size)
void transport_attach(rdpTransport* transport, int sockfd)
{
transport->tcp->sockfd = sockfd;
transport->TcpIn->sockfd = sockfd;
}
BOOL transport_disconnect(rdpTransport* transport)
{
if (transport->layer == TRANSPORT_LAYER_TLS)
tls_disconnect(transport->tls);
tls_disconnect(transport->TlsIn);
return tcp_disconnect(transport->tcp);
return tcp_disconnect(transport->TcpIn);
}
BOOL transport_connect_rdp(rdpTransport* transport)
@@ -87,19 +89,23 @@ BOOL transport_connect_rdp(rdpTransport* transport)
BOOL transport_connect_tls(rdpTransport* transport)
{
if (transport->tls == NULL)
transport->tls = tls_new(transport->settings);
if (transport->TlsIn == NULL)
transport->TlsIn = tls_new(transport->settings);
if (transport->TlsOut == NULL)
transport->TlsOut = transport->TlsIn;
transport->layer = TRANSPORT_LAYER_TLS;
transport->tls->sockfd = transport->tcp->sockfd;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (tls_connect(transport->tls) != TRUE)
if (tls_connect(transport->TlsIn) != TRUE)
{
if (!connectErrorCode)
connectErrorCode = TLSCONNECTERROR;
if (!connectErrorCode)
connectErrorCode = TLSCONNECTERROR;
tls_free(transport->TlsIn);
transport->TlsIn = NULL;
tls_free(transport->tls);
transport->tls = NULL;
return FALSE;
}
@@ -111,21 +117,8 @@ BOOL transport_connect_nla(rdpTransport* transport)
freerdp* instance;
rdpSettings* settings;
if (transport->tls == NULL)
transport->tls = tls_new(transport->settings);
transport->layer = TRANSPORT_LAYER_TLS;
transport->tls->sockfd = transport->tcp->sockfd;
if (tls_connect(transport->tls) != TRUE)
{
if (!connectErrorCode)
connectErrorCode = TLSCONNECTERROR;
tls_free(transport->tls);
transport->tls = NULL;
if (!transport_connect_tls(transport))
return FALSE;
}
/* Network Level Authentication */
@@ -136,7 +129,7 @@ BOOL transport_connect_nla(rdpTransport* transport)
instance = (freerdp*) settings->instance;
if (transport->credssp == NULL)
transport->credssp = credssp_new(instance, transport->tls, settings);
transport->credssp = credssp_new(instance, transport->TlsIn, transport->TlsOut, settings);
if (credssp_authenticate(transport->credssp) < 0)
{
@@ -161,21 +154,22 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
tsg->transport = transport;
transport->tsg = tsg;
transport->SplitInputOutput = TRUE;
if (transport->tls_in == NULL)
transport->tls_in = tls_new(transport->settings);
if (transport->TlsIn == NULL)
transport->TlsIn = tls_new(transport->settings);
transport->tls_in->sockfd = transport->tcp_in->sockfd;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (transport->tls_out == NULL)
transport->tls_out = tls_new(transport->settings);
if (transport->TlsOut == NULL)
transport->TlsOut = tls_new(transport->settings);
transport->tls_out->sockfd = transport->tcp_out->sockfd;
transport->TlsOut->sockfd = transport->TcpOut->sockfd;
if (tls_connect(transport->tls_in) != TRUE)
if (tls_connect(transport->TlsIn) != TRUE)
return FALSE;
if (tls_connect(transport->tls_out) != TRUE)
if (tls_connect(transport->TlsOut) != TRUE)
return FALSE;
if (!tsg_connect(tsg, hostname, port))
@@ -192,19 +186,22 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por
if (transport->settings->GatewayUsageMethod)
{
transport->layer = TRANSPORT_LAYER_TSG;
transport->tcp_out = tcp_new(settings);
transport->TcpOut = tcp_new(settings);
status = tcp_connect(transport->tcp_in, settings->GatewayHostname, 443);
status = tcp_connect(transport->TcpIn, settings->GatewayHostname, 443);
if (status)
status = tcp_connect(transport->tcp_out, settings->GatewayHostname, 443);
status = tcp_connect(transport->TcpOut, settings->GatewayHostname, 443);
if (status)
status = transport_tsg_connect(transport, hostname, port);
}
else
{
status = tcp_connect(transport->tcp, hostname, port);
status = tcp_connect(transport->TcpIn, hostname, port);
transport->SplitInputOutput = FALSE;
transport->TcpOut = transport->TcpIn;
}
return status;
@@ -219,13 +216,13 @@ BOOL transport_accept_rdp(rdpTransport* transport)
BOOL transport_accept_tls(rdpTransport* transport)
{
if (transport->tls == NULL)
transport->tls = tls_new(transport->settings);
if (transport->TlsIn == NULL)
transport->TlsIn = tls_new(transport->settings);
transport->layer = TRANSPORT_LAYER_TLS;
transport->tls->sockfd = transport->tcp->sockfd;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (tls_accept(transport->tls, transport->settings->CertificateFile, transport->settings->PrivateKeyFile) != TRUE)
if (tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile) != TRUE)
return FALSE;
return TRUE;
@@ -236,13 +233,13 @@ BOOL transport_accept_nla(rdpTransport* transport)
freerdp* instance;
rdpSettings* settings;
if (transport->tls == NULL)
transport->tls = tls_new(transport->settings);
if (transport->TlsIn == NULL)
transport->TlsIn = tls_new(transport->settings);
transport->layer = TRANSPORT_LAYER_TLS;
transport->tls->sockfd = transport->tcp->sockfd;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (tls_accept(transport->tls, transport->settings->CertificateFile, transport->settings->PrivateKeyFile) != TRUE)
if (tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile) != TRUE)
return FALSE;
/* Network Level Authentication */
@@ -254,7 +251,7 @@ BOOL transport_accept_nla(rdpTransport* transport)
instance = (freerdp*) settings->instance;
if (transport->credssp == NULL)
transport->credssp = credssp_new(instance, transport->tls, settings);
transport->credssp = credssp_new(instance, transport->TlsIn, transport->TlsOut, settings);
if (credssp_authenticate(transport->credssp) < 0)
{
@@ -275,9 +272,9 @@ int transport_read(rdpTransport* transport, STREAM* s)
while (TRUE)
{
if (transport->layer == TRANSPORT_LAYER_TLS)
status = tls_read(transport->tls, stream_get_tail(s), stream_get_left(s));
status = tls_read(transport->TlsIn, stream_get_tail(s), stream_get_left(s));
else if (transport->layer == TRANSPORT_LAYER_TCP)
status = tcp_read(transport->tcp, stream_get_tail(s), stream_get_left(s));
status = tcp_read(transport->TcpIn, stream_get_tail(s), stream_get_left(s));
else if (transport->layer == TRANSPORT_LAYER_TSG)
status = tsg_read(transport->tsg, stream_get_tail(s), stream_get_left(s));
@@ -335,9 +332,9 @@ int transport_write(rdpTransport* transport, STREAM* s)
while (length > 0)
{
if (transport->layer == TRANSPORT_LAYER_TLS)
status = tls_write(transport->tls, stream_get_tail(s), length);
status = tls_write(transport->TlsOut, stream_get_tail(s), length);
else if (transport->layer == TRANSPORT_LAYER_TCP)
status = tcp_write(transport->tcp, stream_get_tail(s), length);
status = tcp_write(transport->TcpOut, stream_get_tail(s), length);
else if (transport->layer == TRANSPORT_LAYER_TSG)
status = tsg_write(transport->tsg, stream_get_tail(s), length);
@@ -376,7 +373,7 @@ void transport_get_fds(rdpTransport* transport, void** rfds, int* rcount)
#ifdef _WIN32
rfds[*rcount] = transport->tcp->wsa_event;
#else
rfds[*rcount] = (void*)(long)(transport->tcp->sockfd);
rfds[*rcount] = (void*)(long)(transport->TcpIn->sockfd);
#endif
(*rcount)++;
wait_obj_get_fds(transport->recv_event, rfds, rcount);
@@ -403,6 +400,7 @@ int transport_check_fds(rdpTransport** ptransport)
while ((pos = stream_get_pos(transport->recv_buffer)) > 0)
{
stream_set_pos(transport->recv_buffer, 0);
if (tpkt_verify_header(transport->recv_buffer)) /* TPKT */
{
/* Ensure the TPKT header is available. */
@@ -411,6 +409,7 @@ int transport_check_fds(rdpTransport** ptransport)
stream_set_pos(transport->recv_buffer, pos);
return 0;
}
length = tpkt_read_header(transport->recv_buffer);
}
else /* Fast Path */
@@ -421,13 +420,16 @@ int transport_check_fds(rdpTransport** ptransport)
stream_set_pos(transport->recv_buffer, pos);
return 0;
}
/* Fastpath header can be two or three bytes long. */
length = fastpath_header_length(transport->recv_buffer);
if (pos < length)
{
stream_set_pos(transport->recv_buffer, pos);
return 0;
}
length = fastpath_read_header(NULL, transport->recv_buffer);
}
@@ -473,7 +475,7 @@ int transport_check_fds(rdpTransport** ptransport)
/* transport might now have been freed by rdp_client_redirect and a new rdp->transport created */
transport = *ptransport;
if (transport->process_single_pdu)
if (transport->ProcessSinglePdu)
{
/* one at a time but set event if data buffered
* so the main loop will call freerdp_check_fds asap */
@@ -490,19 +492,19 @@ int transport_check_fds(rdpTransport** ptransport)
BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blocking)
{
transport->blocking = blocking;
return tcp_set_blocking_mode(transport->tcp, blocking);
return tcp_set_blocking_mode(transport->TcpIn, blocking);
}
rdpTransport* transport_new(rdpSettings* settings)
{
rdpTransport* transport;
transport = (rdpTransport*) xzalloc(sizeof(rdpTransport));
transport = (rdpTransport*) malloc(sizeof(rdpTransport));
ZeroMemory(transport, sizeof(rdpTransport));
if (transport != NULL)
{
transport->tcp = tcp_new(settings);
transport->tcp_in = tcp_new(settings);
transport->TcpIn = tcp_new(settings);
transport->settings = settings;
@@ -534,11 +536,10 @@ void transport_free(rdpTransport* transport)
stream_free(transport->send_stream);
wait_obj_free(transport->recv_event);
if (transport->tls)
tls_free(transport->tls);
if (transport->TlsIn)
tls_free(transport->TlsIn);
tcp_free(transport->tcp);
tcp_free(transport->tcp_in);
tcp_free(transport->TcpIn);
tsg_free(transport->tsg);
free(transport);

View File

@@ -50,22 +50,21 @@ struct rdp_transport
STREAM* recv_stream;
STREAM* send_stream;
TRANSPORT_LAYER layer;
struct rdp_tcp* tcp;
struct rdp_tls* tls;
struct rdp_tsg* tsg;
struct rdp_tcp* tcp_in;
struct rdp_tcp* tcp_out;
struct rdp_tls* tls_in;
struct rdp_tls* tls_out;
struct rdp_credssp* credssp;
struct rdp_settings* settings;
rdpTsg* tsg;
rdpTcp* TcpIn;
rdpTcp* TcpOut;
rdpTls* TlsIn;
rdpTls* TlsOut;
rdpCredssp* credssp;
rdpSettings* settings;
UINT32 usleep_interval;
void* recv_extra;
STREAM* recv_buffer;
TransportRecv recv_callback;
struct wait_obj* recv_event;
BOOL blocking;
BOOL process_single_pdu; /* process single pdu in transport_check_fds */
BOOL ProcessSinglePdu;
BOOL SplitInputOutput;
};
STREAM* transport_recv_stream_init(rdpTransport* transport, int size);

View File

@@ -133,8 +133,8 @@ int credssp_ntlm_client_init(rdpCredssp* credssp)
(char*) credssp->identity.User, (char*) credssp->identity.Domain, (char*) credssp->identity.Password);
#endif
sspi_SecBufferAlloc(&credssp->PublicKey, credssp->tls->PublicKeyLength);
CopyMemory(credssp->PublicKey.pvBuffer, credssp->tls->PublicKey, credssp->tls->PublicKeyLength);
sspi_SecBufferAlloc(&credssp->PublicKey, credssp->TlsIn->PublicKeyLength);
CopyMemory(credssp->PublicKey.pvBuffer, credssp->TlsIn->PublicKey, credssp->TlsIn->PublicKeyLength);
length = sizeof(TERMSRV_SPN_PREFIX) + strlen(settings->ServerHostname);
@@ -164,8 +164,8 @@ int credssp_ntlm_server_init(rdpCredssp* credssp)
rdpSettings* settings = credssp->settings;
instance = (freerdp*) settings->instance;
sspi_SecBufferAlloc(&credssp->PublicKey, credssp->tls->PublicKeyLength);
CopyMemory(credssp->PublicKey.pvBuffer, credssp->tls->PublicKey, credssp->tls->PublicKeyLength);
sspi_SecBufferAlloc(&credssp->PublicKey, credssp->TlsIn->PublicKeyLength);
CopyMemory(credssp->PublicKey.pvBuffer, credssp->TlsIn->PublicKey, credssp->TlsIn->PublicKeyLength);
return 1;
}
@@ -1127,10 +1127,7 @@ void credssp_send(rdpCredssp* credssp)
ber_write_octet_string(s, credssp->pubKeyAuth.pvBuffer, length);
}
//printf("Sending TSRequest: (%d)\n", stream_get_length(s));
//freerdp_hexdump(s->data, stream_get_length(s));
tls_write(credssp->tls, s->data, stream_get_length(s));
tls_write(credssp->TlsOut, s->data, stream_get_length(s));
stream_free(s);
}
@@ -1149,7 +1146,7 @@ int credssp_recv(rdpCredssp* credssp)
s = stream_new(4096);
status = tls_read_all(credssp->tls, s->p, stream_get_left(s));
status = tls_read_all(credssp->TlsIn, s->p, stream_get_left(s));
s->size = status;
if (status < 0)
@@ -1159,9 +1156,6 @@ int credssp_recv(rdpCredssp* credssp)
return -1;
}
//printf("Receiving TSRequest: (%d)\n", s->size);
//freerdp_hexdump(s->data, s->size);
/* TSRequest */
ber_read_sequence_tag(s, &length);
ber_read_contextual_tag(s, 0, &length, TRUE);
@@ -1236,11 +1230,12 @@ void credssp_buffer_free(rdpCredssp* credssp)
* @return new CredSSP state machine.
*/
rdpCredssp* credssp_new(freerdp* instance, rdpTls* tls, rdpSettings* settings)
rdpCredssp* credssp_new(freerdp* instance, rdpTls* TlsIn, rdpTls* TlsOut, rdpSettings* settings)
{
rdpCredssp* credssp;
credssp = (rdpCredssp*) xzalloc(sizeof(rdpCredssp));
credssp = (rdpCredssp*) malloc(sizeof(rdpCredssp));
ZeroMemory(credssp, sizeof(rdpCredssp));
if (credssp != NULL)
{
@@ -1252,7 +1247,8 @@ rdpCredssp* credssp_new(freerdp* instance, rdpTls* tls, rdpSettings* settings)
credssp->instance = instance;
credssp->settings = settings;
credssp->server = settings->ServerMode;
credssp->tls = tls;
credssp->TlsIn = TlsIn;
credssp->TlsOut = TlsOut;
credssp->send_seq_num = 0;
credssp->recv_seq_num = 0;
ZeroMemory(&credssp->negoToken, sizeof(SecBuffer));