Updated reconnect to handle cases where PostConnect was not called

freerdp_reconnect might be called after a freerdp_connect failed due
to a TCP timeout waiting for user input.
In such cases we need to know if PostConect was already called and
do that if not.
This commit is contained in:
Armin Novak
2018-07-10 12:04:27 +02:00
parent c9cebf6ed6
commit 7a39dcd7e2
4 changed files with 63 additions and 53 deletions

View File

@@ -351,6 +351,35 @@ BOOL rdp_client_disconnect_and_clear(rdpRdp* rdp)
return TRUE;
}
static BOOL rdp_client_reconnect_channels(rdpRdp* rdp)
{
BOOL status;
rdpContext* context;
rdpChannels* channels;
if (!rdp || !rdp->context || !rdp->context->channels)
return FALSE;
context = rdp->context;
channels = context->channels;
if (context->instance->ConnectionCallbackState == CLIENT_STATE_INITIAL)
return FALSE;
if (context->instance->ConnectionCallbackState == CLIENT_STATE_PRECONNECT_PASSED)
{
if (!IFCALLRESULT(FALSE, context->instance->PostConnect, context->instance))
return FALSE;
context->instance->ConnectionCallbackState = CLIENT_STATE_POSTCONNECT_PASSED;
}
if (context->instance->ConnectionCallbackState == CLIENT_STATE_POSTCONNECT_PASSED)
status = (freerdp_channels_post_connect(context->channels, context->instance) == CHANNEL_RC_OK);
return status;
}
BOOL rdp_client_redirect(rdpRdp* rdp)
{
BOOL status;
@@ -424,8 +453,8 @@ BOOL rdp_client_redirect(rdpRdp* rdp)
status = rdp_client_connect(rdp);
if (status && (context->instance->ConnectionCallbackState == CLIENT_STATE_POSTCONNECT_PASSED))
status = (freerdp_channels_post_connect(context->channels, context->instance) == CHANNEL_RC_OK);
if (status)
status = rdp_client_reconnect_channels(rdp);
return status;
}
@@ -447,8 +476,8 @@ BOOL rdp_client_reconnect(rdpRdp* rdp)
status = rdp_client_connect(rdp);
if (status && (context->instance->ConnectionCallbackState == CLIENT_STATE_POSTCONNECT_PASSED))
status = (freerdp_channels_post_connect(channels, context->instance) == CHANNEL_RC_OK);
if (status)
status = rdp_client_reconnect_channels(rdp);
return status;
}

View File

@@ -53,8 +53,7 @@ enum CLIENT_CONNECTION_STATE
{
CLIENT_STATE_INITIAL,
CLIENT_STATE_PRECONNECT_PASSED,
CLIENT_STATE_POSTCONNECT_PASSED,
CLIENT_STATE_POSTDISCONNECT_PASSED
CLIENT_STATE_POSTCONNECT_PASSED
};
FREERDP_LOCAL BOOL rdp_client_connect(rdpRdp* rdp);

View File

@@ -400,6 +400,7 @@ BOOL freerdp_check_event_handles(rdpContext* context)
{
if (freerdp_get_last_error(context) == FREERDP_ERROR_SUCCESS)
WLog_ERR(TAG, "checkChannelErrorEvent() failed - %"PRIi32"", status);
return FALSE;
}
@@ -407,6 +408,7 @@ BOOL freerdp_check_event_handles(rdpContext* context)
{
int rc = freerdp_message_queue_process_pending_messages(
context->instance, FREERDP_INPUT_MESSAGE_QUEUE);
if (rc < 0)
return FALSE;
else
@@ -512,7 +514,6 @@ BOOL freerdp_disconnect(freerdp* instance)
}
IFCALL(instance->PostDisconnect, instance);
instance->ConnectionCallbackState = CLIENT_STATE_POSTDISCONNECT_PASSED;
if (instance->update->pcap_rfx)
{

View File

@@ -103,7 +103,6 @@ BOOL nego_connect(rdpNego* nego)
{
WLog_DBG(TAG, "Security Layer Negotiation is disabled");
/* attempt only the highest enabled protocol (see nego_attempt_*) */
nego->EnabledProtocols[PROTOCOL_NLA] = FALSE;
nego->EnabledProtocols[PROTOCOL_TLS] = FALSE;
nego->EnabledProtocols[PROTOCOL_RDP] = FALSE;
@@ -152,13 +151,13 @@ BOOL nego_connect(rdpNego* nego)
do
{
WLog_DBG(TAG, "state: %s", NEGO_STATE_STRINGS[nego->state]);
nego_send(nego);
if (nego->state == NEGO_STATE_FAIL)
{
if (freerdp_get_last_error(nego->transport->context) == FREERDP_ERROR_SUCCESS)
WLog_ERR(TAG, "Protocol Security Negotiation Failure");
nego->state = NEGO_STATE_FINAL;
return FALSE;
}
@@ -167,7 +166,6 @@ BOOL nego_connect(rdpNego* nego)
}
WLog_DBG(TAG, "Negotiated %s security", PROTOCOL_SECURITY_STRINGS[nego->SelectedProtocol]);
/* update settings with negotiated protocol security */
settings->RequestedProtocols = nego->RequestedProtocols;
settings->SelectedProtocol = nego->SelectedProtocol;
@@ -183,14 +181,16 @@ BOOL nego_connect(rdpNego* nego)
* Advertise all supported encryption methods if the client
* implementation did not set any security methods
*/
settings->EncryptionMethods = ENCRYPTION_METHOD_40BIT | ENCRYPTION_METHOD_56BIT | ENCRYPTION_METHOD_128BIT | ENCRYPTION_METHOD_FIPS;
settings->EncryptionMethods = ENCRYPTION_METHOD_40BIT | ENCRYPTION_METHOD_56BIT |
ENCRYPTION_METHOD_128BIT | ENCRYPTION_METHOD_FIPS;
}
}
/* finally connect security layer (if not already done) */
if (!nego_security_connect(nego))
{
WLog_DBG(TAG, "Failed to connect with %s security", PROTOCOL_SECURITY_STRINGS[nego->SelectedProtocol]);
WLog_DBG(TAG, "Failed to connect with %s security",
PROTOCOL_SECURITY_STRINGS[nego->SelectedProtocol]);
return FALSE;
}
@@ -242,7 +242,7 @@ BOOL nego_security_connect(rdpNego* nego)
* @return
*/
BOOL nego_tcp_connect(rdpNego* nego)
static BOOL nego_tcp_connect(rdpNego* nego)
{
if (!nego->TcpConnected)
{
@@ -280,7 +280,8 @@ BOOL nego_tcp_connect(rdpNego* nego)
BOOL nego_transport_connect(rdpNego* nego)
{
nego_tcp_connect(nego);
if (!nego_tcp_connect(nego))
return FALSE;
if (nego->TcpConnected && !nego->NegotiateSecurityLayer)
return nego_security_connect(nego);
@@ -301,7 +302,6 @@ BOOL nego_transport_disconnect(rdpNego* nego)
nego->TcpConnected = FALSE;
nego->SecurityConnected = FALSE;
return TRUE;
}
@@ -317,7 +317,6 @@ BOOL nego_send_preconnection_pdu(rdpNego* nego)
UINT32 cbSize;
UINT16 cchPCB = 0;
WCHAR* wszPCB = NULL;
WLog_DBG(TAG, "Sending preconnection PDU");
if (!nego_tcp_connect(nego))
@@ -334,6 +333,7 @@ BOOL nego_send_preconnection_pdu(rdpNego* nego)
}
s = Stream_New(NULL, cbSize);
if (!s)
{
WLog_ERR(TAG, "Stream_New failed!");
@@ -361,7 +361,6 @@ BOOL nego_send_preconnection_pdu(rdpNego* nego)
}
Stream_Free(s, TRUE);
return TRUE;
}
@@ -373,7 +372,6 @@ BOOL nego_send_preconnection_pdu(rdpNego* nego)
void nego_attempt_ext(rdpNego* nego)
{
nego->RequestedProtocols = PROTOCOL_NLA | PROTOCOL_TLS | PROTOCOL_EXT;
WLog_DBG(TAG, "Attempting NLA extended security");
if (!nego_transport_connect(nego))
@@ -419,7 +417,6 @@ void nego_attempt_ext(rdpNego* nego)
void nego_attempt_nla(rdpNego* nego)
{
nego->RequestedProtocols = PROTOCOL_NLA | PROTOCOL_TLS;
WLog_DBG(TAG, "Attempting NLA security");
if (!nego_transport_connect(nego))
@@ -463,7 +460,6 @@ void nego_attempt_nla(rdpNego* nego)
void nego_attempt_tls(rdpNego* nego)
{
nego->RequestedProtocols = PROTOCOL_TLS;
WLog_DBG(TAG, "Attempting TLS security");
if (!nego_transport_connect(nego))
@@ -503,7 +499,6 @@ void nego_attempt_tls(rdpNego* nego)
void nego_attempt_rdp(rdpNego* nego)
{
nego->RequestedProtocols = PROTOCOL_RDP;
WLog_DBG(TAG, "Attempting RDP security");
if (!nego_transport_connect(nego))
@@ -534,7 +529,6 @@ BOOL nego_recv_response(rdpNego* nego)
{
int status;
wStream* s;
s = Stream_New(NULL, 1024);
if (!s)
@@ -552,7 +546,6 @@ BOOL nego_recv_response(rdpNego* nego)
}
status = nego_recv(nego->transport, s, nego);
Stream_Free(s, TRUE);
if (status < 0)
@@ -588,14 +581,12 @@ int nego_recv(rdpTransport* transport, wStream* s, void* extra)
if (li > 6)
{
/* rdpNegData (optional) */
Stream_Read_UINT8(s, type); /* Type */
switch (type)
{
case TYPE_RDP_NEG_RSP:
nego_process_negotiation_response(nego, s);
WLog_DBG(TAG, "selected_protocol: %"PRIu32"", nego->SelectedProtocol);
/* enhanced security selected ? */
@@ -603,12 +594,13 @@ int nego_recv(rdpTransport* transport, wStream* s, void* extra)
if (nego->SelectedProtocol)
{
if ((nego->SelectedProtocol == PROTOCOL_NLA) &&
(!nego->EnabledProtocols[PROTOCOL_NLA]))
(!nego->EnabledProtocols[PROTOCOL_NLA]))
{
nego->state = NEGO_STATE_FAIL;
}
if ((nego->SelectedProtocol == PROTOCOL_TLS) &&
(!nego->EnabledProtocols[PROTOCOL_TLS]))
(!nego->EnabledProtocols[PROTOCOL_TLS]))
{
nego->state = NEGO_STATE_FAIL;
}
@@ -617,6 +609,7 @@ int nego_recv(rdpTransport* transport, wStream* s, void* extra)
{
nego->state = NEGO_STATE_FAIL;
}
break;
case TYPE_RDP_NEG_FAILURE:
@@ -663,13 +656,11 @@ static BOOL nego_read_request_token_or_cookie(rdpNego* nego, wStream* s)
* string terminated by a 0x0D0A two-byte sequence:
* Cookie:[space]mstshash=[ANSISTRING][\x0D\x0A]
*/
BYTE *str = NULL;
BYTE* str = NULL;
UINT16 crlf = 0;
size_t pos, len;
BOOL result = FALSE;
BOOL isToken = FALSE;
str = Stream_Pointer(s);
pos = Stream_GetPosition(s);
@@ -693,8 +684,10 @@ static BOOL nego_read_request_token_or_cookie(rdpNego* nego, wStream* s)
while (Stream_GetRemainingLength(s) >= 2)
{
Stream_Read_UINT16(s, crlf);
if (crlf == 0x0A0D)
break;
Stream_Rewind(s, 1);
}
@@ -703,6 +696,7 @@ static BOOL nego_read_request_token_or_cookie(rdpNego* nego, wStream* s)
Stream_Rewind(s, 2);
len = Stream_GetPosition(s) - pos;
Stream_Write_UINT16(s, 0);
if (strlen((char*)str) == len)
{
if (isToken)
@@ -716,12 +710,12 @@ static BOOL nego_read_request_token_or_cookie(rdpNego* nego, wStream* s)
{
Stream_SetPosition(s, pos);
WLog_ERR(TAG, "invalid %s received",
isToken ? "routing token" : "cookie");
isToken ? "routing token" : "cookie");
}
else
{
WLog_DBG(TAG, "received %s [%s]",
isToken ? "routing token" : "cookie", str);
isToken ? "routing token" : "cookie", str);
}
return result;
@@ -760,7 +754,6 @@ BOOL nego_read_request(rdpNego* nego, wStream* s)
if (Stream_GetRemainingLength(s) >= 8)
{
/* rdpNegData (optional) */
Stream_Read_UINT8(s, type); /* Type */
if (type != TYPE_RDP_NEG_REQ)
@@ -808,8 +801,8 @@ BOOL nego_send_negotiation_request(rdpNego* nego)
size_t bm, em;
BYTE flags = 0;
int cookie_length;
s = Stream_New(NULL, 512);
if (!s)
{
WLog_ERR(TAG, "Stream_New failed!");
@@ -827,8 +820,8 @@ BOOL nego_send_negotiation_request(rdpNego* nego)
/* Ensure Routing Token is correctly terminated - may already be present in string */
if ((nego->RoutingTokenLength > 2) &&
(nego->RoutingToken[nego->RoutingTokenLength - 2] == 0x0D) &&
(nego->RoutingToken[nego->RoutingTokenLength - 1] == 0x0A))
(nego->RoutingToken[nego->RoutingTokenLength - 2] == 0x0D) &&
(nego->RoutingToken[nego->RoutingTokenLength - 1] == 0x0A))
{
WLog_DBG(TAG, "Routing token looks correctly terminated - use verbatim");
length += nego->RoutingTokenLength;
@@ -860,7 +853,6 @@ BOOL nego_send_negotiation_request(rdpNego* nego)
if ((nego->RequestedProtocols > PROTOCOL_RDP) || (nego->sendNegoData))
{
/* RDP_NEG_DATA must be present for TLS and NLA */
if (nego->RestrictedAdminModeRequired)
flags |= RESTRICTED_ADMIN_MODE_REQUIRED;
@@ -876,7 +868,6 @@ BOOL nego_send_negotiation_request(rdpNego* nego)
tpkt_write_header(s, length);
tpdu_write_connection_request(s, length - 5);
Stream_SetPosition(s, em);
Stream_SealLength(s);
if (transport_write(nego->transport, s) < 0)
@@ -886,7 +877,6 @@ BOOL nego_send_negotiation_request(rdpNego* nego)
}
Stream_Free(s, TRUE);
return TRUE;
}
@@ -900,13 +890,10 @@ void nego_process_negotiation_request(rdpNego* nego, wStream* s)
{
BYTE flags;
UINT16 length;
Stream_Read_UINT8(s, flags);
Stream_Read_UINT16(s, length);
Stream_Read_UINT32(s, nego->RequestedProtocols);
WLog_DBG(TAG, "RDP_NEG_REQ: RequestedProtocol: 0x%08"PRIX32"", nego->RequestedProtocols);
nego->state = NEGO_STATE_FINAL;
}
@@ -919,7 +906,6 @@ void nego_process_negotiation_request(rdpNego* nego, wStream* s)
void nego_process_negotiation_response(rdpNego* nego, wStream* s)
{
UINT16 length;
WLog_DBG(TAG, "RDP_NEG_RSP");
if (Stream_GetRemainingLength(s) < 7)
@@ -932,7 +918,6 @@ void nego_process_negotiation_response(rdpNego* nego, wStream* s)
Stream_Read_UINT8(s, nego->flags);
Stream_Read_UINT16(s, length);
Stream_Read_UINT32(s, nego->SelectedProtocol);
nego->state = NEGO_STATE_FINAL;
}
@@ -947,9 +932,7 @@ void nego_process_negotiation_failure(rdpNego* nego, wStream* s)
BYTE flags;
UINT16 length;
UINT32 failureCode;
WLog_DBG(TAG, "RDP_NEG_FAILURE");
Stream_Read_UINT8(s, flags);
Stream_Read_UINT16(s, length);
Stream_Read_UINT32(s, failureCode);
@@ -999,11 +982,10 @@ BOOL nego_send_negotiation_response(rdpNego* nego)
wStream* s;
BYTE flags;
rdpSettings* settings;
status = TRUE;
settings = nego->transport->settings;
s = Stream_New(NULL, 512);
if (!s)
{
WLog_ERR(TAG, "Stream_New failed!");
@@ -1018,11 +1000,9 @@ BOOL nego_send_negotiation_response(rdpNego* nego)
{
UINT32 errorCode = (nego->SelectedProtocol & ~PROTOCOL_FAILED_NEGO);
flags = 0;
Stream_Write_UINT8(s, TYPE_RDP_NEG_FAILURE);
Stream_Write_UINT8(s, flags); /* flags */
Stream_Write_UINT16(s, 8); /* RDP_NEG_DATA length (8) */
Stream_Write_UINT32(s, errorCode);
length += 8;
status = FALSE;
@@ -1047,7 +1027,6 @@ BOOL nego_send_negotiation_response(rdpNego* nego)
tpkt_write_header(s, length);
tpdu_write_connection_confirm(s, length - 5);
Stream_SetPosition(s, em);
Stream_SealLength(s);
if (transport_write(nego->transport, s) < 0)
@@ -1148,9 +1127,7 @@ rdpNego* nego_new(rdpTransport* transport)
return NULL;
nego->transport = transport;
nego_init(nego);
return nego;
}
@@ -1276,8 +1253,10 @@ BOOL nego_set_routing_token(rdpNego* nego, BYTE* RoutingToken, DWORD RoutingToke
free(nego->RoutingToken);
nego->RoutingTokenLength = RoutingTokenLength;
nego->RoutingToken = (BYTE*) malloc(nego->RoutingTokenLength);
if (!nego->RoutingToken)
return FALSE;
CopyMemory(nego->RoutingToken, RoutingToken, nego->RoutingTokenLength);
return TRUE;
}
@@ -1300,8 +1279,10 @@ BOOL nego_set_cookie(rdpNego* nego, char* cookie)
return TRUE;
nego->cookie = _strdup(cookie);
if (!nego->cookie)
return FALSE;
return TRUE;
}