This commit is contained in:
Dan Bungert
2013-10-25 10:43:21 -06:00
parent db890d9bf2
commit f13c8a0be7
4 changed files with 91 additions and 9 deletions

View File

@@ -49,6 +49,7 @@ struct rdp_tls
rdpSettings* settings;
SecPkgContext_Bindings* Bindings;
rdpCertificateStore* certificate_store;
char desc[20];
};
FREERDP_API BOOL tls_connect(rdpTls* tls);

View File

@@ -49,6 +49,8 @@
#define BUFFER_SIZE 16384
#include "lwd.h"
static void* transport_client_thread(void* arg);
wStream* transport_send_stream_init(rdpTransport* transport, int size)
@@ -221,6 +223,7 @@ BOOL transport_connect_tls(rdpTransport* transport)
if (transport->layer == TRANSPORT_LAYER_TSG)
{
transport->TsgTls = tls_new(transport->settings);
sprintf(transport->TsgTls->desc, "TsgTls");
transport->TsgTls->methods = BIO_s_tsg();
transport->TsgTls->tsg = (void*) transport->tsg;
@@ -242,8 +245,10 @@ BOOL transport_connect_tls(rdpTransport* transport)
return TRUE;
}
if (transport->TlsIn == NULL)
if (transport->TlsIn == NULL) {
transport->TlsIn = tls_new(transport->settings);
sprintf(transport->TlsIn->desc, "TlsIn");
}
if (transport->TlsOut == NULL)
transport->TlsOut = transport->TlsIn;
@@ -317,13 +322,17 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
transport->tsg = tsg;
transport->SplitInputOutput = TRUE;
if (transport->TlsIn == NULL)
if (transport->TlsIn == NULL) {
transport->TlsIn = tls_new(transport->settings);
sprintf(transport->TlsIn->desc, "TlsIn");
}
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
if (transport->TlsOut == NULL)
if (transport->TlsOut == NULL) {
transport->TlsOut = tls_new(transport->settings);
sprintf(transport->TlsOut->desc, "TlsOut");
}
transport->TlsOut->sockfd = transport->TcpOut->sockfd;
@@ -387,8 +396,10 @@ BOOL transport_accept_rdp(rdpTransport* transport)
BOOL transport_accept_tls(rdpTransport* transport)
{
if (transport->TlsIn == NULL)
if (transport->TlsIn == NULL) {
transport->TlsIn = tls_new(transport->settings);
sprintf(transport->TlsIn->desc, "TlsIn");
}
if (transport->TlsOut == NULL)
transport->TlsOut = transport->TlsIn;
@@ -410,8 +421,10 @@ BOOL transport_accept_nla(rdpTransport* transport)
if (transport->TlsIn == NULL)
transport->TlsIn = tls_new(transport->settings);
if (transport->TlsOut == NULL)
if (transport->TlsOut == NULL) {
transport->TlsOut = transport->TlsIn;
sprintf(transport->TlsIn->desc, "TlsIn");
}
transport->layer = TRANSPORT_LAYER_TLS;
transport->TlsIn->sockfd = transport->TcpIn->sockfd;
@@ -509,9 +522,21 @@ int transport_read_layer(rdpTransport* transport, UINT8* data, int bytes)
{
int read = 0;
int status = -1;
char *layer = "UNKNOWN";
if (transport->layer == TRANSPORT_LAYER_TLS)
layer = "TLS";
else if (transport->layer == TRANSPORT_LAYER_TCP)
layer = "TCP";
else if (transport->layer == TRANSPORT_LAYER_TSG)
layer = "TSG";
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
layer = "TSG_TLS";
while (read < bytes)
{
LWD("layer %s bytes %d read %d", layer, bytes, read);
if (transport->layer == TRANSPORT_LAYER_TLS)
status = tls_read(transport->TlsIn, data + read, bytes - read);
else if (transport->layer == TRANSPORT_LAYER_TCP)
@@ -523,11 +548,15 @@ int transport_read_layer(rdpTransport* transport, UINT8* data, int bytes)
/* blocking means that we can't continue until this is read */
if (!transport->blocking)
if (!transport->blocking) {
LWD("layer %s return %d not blocking", layer, status);
return status;
}
if (status < 0)
if (status < 0) {
LWD("layer %s return %d negative status", layer, status);
return status;
}
read += status;
@@ -541,6 +570,7 @@ int transport_read_layer(rdpTransport* transport, UINT8* data, int bytes)
}
}
LWD("layer %s return %d normal", layer, status);
return read;
}
@@ -653,6 +683,7 @@ int transport_write(rdpTransport* transport, wStream* s)
{
int length;
int status = -1;
char *layer = "UNKNOWN";
WaitForSingleObject(transport->WriteMutex, INFINITE);
@@ -667,8 +698,19 @@ int transport_write(rdpTransport* transport, wStream* s)
}
#endif
if (transport->layer == TRANSPORT_LAYER_TLS)
layer = "TLS";
else if (transport->layer == TRANSPORT_LAYER_TCP)
layer = "TCP";
else if (transport->layer == TRANSPORT_LAYER_TSG)
layer = "TSG";
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
layer = "TSG_TLS";
while (length > 0)
{
LWD("layer %s length %d", layer, length);
if (transport->layer == TRANSPORT_LAYER_TLS)
status = tls_write(transport->TlsOut, Stream_Pointer(s), length);
else if (transport->layer == TRANSPORT_LAYER_TCP)
@@ -678,11 +720,15 @@ int transport_write(rdpTransport* transport, wStream* s)
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
status = tls_write(transport->TsgTls, Stream_Pointer(s), length);
if (status < 0)
if (status < 0) {
LWD("layer %s length %d break %d negative status",
layer, length, status);
break; /* error occurred */
}
if (status == 0)
{
LWD("layer %s status 0", layer);
/* when sending is blocked in nonblocking mode, the receiving buffer should be checked */
if (!transport->blocking)
{
@@ -695,6 +741,8 @@ int transport_write(rdpTransport* transport, wStream* s)
tls_wait_write(transport->TlsOut);
else if (transport->layer == TRANSPORT_LAYER_TCP)
tcp_wait_write(transport->TcpOut);
else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
tls_wait_write(transport->TsgTls);
else
USleep(transport->SleepInterval);
}
@@ -714,6 +762,7 @@ int transport_write(rdpTransport* transport, wStream* s)
ReleaseMutex(transport->WriteMutex);
LWD("layer %s return %d", layer, status);
return status;
}

View File

@@ -31,6 +31,8 @@
#include <freerdp/crypto/tls.h>
#include <lwd.h>
static CryptoCert tls_get_certificate(rdpTls* tls, BOOL peer)
{
CryptoCert cert;
@@ -99,7 +101,7 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert)
static void tls_ssl_info_callback(const SSL* ssl, int type, int val)
{
printf("tls_ssl_info_callback: type: %d val: %d\n");
/* printf("tls_ssl_info_callback: type: %d val: %d\n", type, val); */
if (type & SSL_CB_HANDSHAKE_START)
{
@@ -373,6 +375,8 @@ int tls_read(rdpTls* tls, BYTE* data, int length)
int error;
int status;
LWD("length %d", length);
status = SSL_read(tls->ssl, data, length);
if (status <= 0)
@@ -411,6 +415,8 @@ int tls_read(rdpTls* tls, BYTE* data, int length)
}
}
LWD("ret %d", status);
return status;
}
@@ -434,6 +440,8 @@ int tls_write(rdpTls* tls, BYTE* data, int length)
int error;
int status;
LWD("length %d", length);
status = SSL_write(tls->ssl, data, length);
if (status <= 0)
@@ -471,6 +479,8 @@ int tls_write(rdpTls* tls, BYTE* data, int length)
}
}
LWD("ret %d", status);
return status;
}

22
lwd.h Normal file
View File

@@ -0,0 +1,22 @@
#ifndef __LWD_H__
#define __LWD_H__
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#define LWD(fmt, ...) do { \
time_t tod = time(NULL); \
char buf[25]; \
struct tm* tm_info = localtime(&tod); \
strftime(buf, 25, "%Y:%m:%d %H:%M:%S", tm_info); \
fprintf(stderr, "%s [%s] ", __FUNCTION__, buf); \
fprintf(stderr, fmt, ## __VA_ARGS__); \
fprintf(stderr, "\n"); \
fflush(stderr); \
} while( 0 )
#endif