diff --git a/libfreerdp/core/rdstls.c b/libfreerdp/core/rdstls.c index 6bc88f2df..0e21fa38e 100644 --- a/libfreerdp/core/rdstls.c +++ b/libfreerdp/core/rdstls.c @@ -68,6 +68,8 @@ rdpRdstls* rdstls_new(rdpContext* context, rdpTransport* transport) rdstls->transport = transport; rdstls->server = settings->ServerMode; + rdstls->state = RDSTLS_STATE_INITIAL; + return rdstls; } @@ -88,17 +90,26 @@ static const char* rdstls_get_state_str(RDSTLS_STATE state) { switch (state) { + case RDSTLS_STATE_INITIAL: + return "RDSTLS_STATE_INITIAL"; case RDSTLS_STATE_CAPABILITIES: return "RDSTLS_STATE_CAPABILITIES"; case RDSTLS_STATE_AUTH_REQ: return "RDSTLS_STATE_AUTH_REQ"; case RDSTLS_STATE_AUTH_RSP: return "RDSTLS_STATE_AUTH_RSP"; + case RDSTLS_STATE_FINAL: + return "RDSTLS_STATE_FINAL"; default: return "UNKNOWN"; } } +static BOOL rdstls_valid_transition(RDSTLS_STATE originState, RDSTLS_STATE nextState) +{ + return originState + 1 == nextState; +} + static RDSTLS_STATE rdstls_get_state(rdpRdstls* rdstls) { WINPR_ASSERT(rdstls); @@ -111,6 +122,13 @@ static BOOL rdstls_set_state(rdpRdstls* rdstls, RDSTLS_STATE state) WLog_DBG(TAG, "-- %s\t--> %s", rdstls_get_state_str(rdstls->state), rdstls_get_state_str(state)); + + if (!rdstls_valid_transition(rdstls->state, state)) + { + WLog_ERR(TAG, "invalid state transition"); + return FALSE; + } + rdstls->state = state; return TRUE; } @@ -519,6 +537,10 @@ static BOOL rdstls_send(rdpTransport* transport, wStream* s, void* extra) if (!rdstls_write_authentication_response(rdstls, s)) return FALSE; break; + default: + WLog_ERR(TAG, "RDSTLS in invalid receive state %s", + rdstls_get_state_str(rdstls_get_state(rdstls))); + return -1; } if (transport_write(rdstls->transport, s) < 0) @@ -643,7 +665,7 @@ static BOOL rdstls_send_authentication_response(rdpRdstls* rdstls) if (!rdstls_send(rdstls->transport, s, rdstls)) goto fail; - rc = TRUE; + rc = rdstls_set_state(rdstls, RDSTLS_STATE_FINAL); fail: Stream_Free(s, TRUE); return rc; @@ -728,7 +750,7 @@ static BOOL rdstls_recv_authentication_response(rdpRdstls* rdstls) if (status < 0) goto fail; - rc = TRUE; + rc = rdstls_set_state(rdstls, RDSTLS_STATE_FINAL); fail: Stream_Free(s, TRUE); return rc; @@ -736,7 +758,8 @@ fail: static int rdstls_server_authenticate(rdpRdstls* rdstls) { - rdstls_set_state(rdstls, RDSTLS_STATE_CAPABILITIES); + if (!rdstls_set_state(rdstls, RDSTLS_STATE_CAPABILITIES)) + return -1; if (!rdstls_send_capabilities(rdstls)) return -1; @@ -755,7 +778,8 @@ static int rdstls_server_authenticate(rdpRdstls* rdstls) static int rdstls_client_authenticate(rdpRdstls* rdstls) { - rdstls_set_state(rdstls, RDSTLS_STATE_CAPABILITIES); + if (!rdstls_set_state(rdstls, RDSTLS_STATE_CAPABILITIES)) + return -1; if (!rdstls_recv_capabilities(rdstls)) return -1; diff --git a/libfreerdp/core/rdstls.h b/libfreerdp/core/rdstls.h index b47d88fa7..f5686e76f 100644 --- a/libfreerdp/core/rdstls.h +++ b/libfreerdp/core/rdstls.h @@ -37,9 +37,11 @@ typedef struct rdp_rdstls rdpRdstls; typedef enum { + RDSTLS_STATE_INITIAL, RDSTLS_STATE_CAPABILITIES, RDSTLS_STATE_AUTH_REQ, RDSTLS_STATE_AUTH_RSP, + RDSTLS_STATE_FINAL, } RDSTLS_STATE; FREERDP_LOCAL int rdstls_authenticate(rdpRdstls* rdstls);