From 1f08cb9a7d2d3d9c5c966b5d308c36ebe126ed35 Mon Sep 17 00:00:00 2001 From: David Fort Date: Tue, 26 Jul 2022 12:53:41 +0200 Subject: [PATCH] Drdynvc needs love (#8059) * winpr: add lock operation on HashTables * drdynvc: change the listeners array for a hashtable and other micro cleanups * logonInfo: drop warning that is shown at every connection Let's avoid this log, we can't do anything if at Microsoft they don't respect their own specs. * rdpei: fix terminate of rdpei * drdynvc: implement the channel list with a hashtable by channelId --- channels/client/generic_dynvc.c | 3 + channels/drdynvc/client/drdynvc_main.c | 232 ++++++++++--------- channels/drdynvc/client/drdynvc_main.h | 4 +- libfreerdp/core/info.c | 7 - winpr/include/winpr/collections.h | 2 + winpr/libwinpr/utils/collections/HashTable.c | 12 + 6 files changed, 145 insertions(+), 115 deletions(-) diff --git a/channels/client/generic_dynvc.c b/channels/client/generic_dynvc.c index b2e941e48..0503c168d 100644 --- a/channels/client/generic_dynvc.c +++ b/channels/client/generic_dynvc.c @@ -106,6 +106,9 @@ static UINT generic_plugin_terminated(IWTSPlugin* pPlugin) WLog_Print(plugin->log, WLOG_TRACE, "..."); + /* some channels (namely rdpei), look at initialized to see if they should continue to run */ + plugin->initialized = FALSE; + if (plugin->terminatePluginFn) plugin->terminatePluginFn(plugin); diff --git a/channels/drdynvc/client/drdynvc_main.c b/channels/drdynvc/client/drdynvc_main.c index bfb80e8d3..2ee6255cb 100644 --- a/channels/drdynvc/client/drdynvc_main.c +++ b/channels/drdynvc/client/drdynvc_main.c @@ -70,7 +70,7 @@ static UINT dvcman_create_listener(IWTSVirtualChannelManager* pChannelMgr, DVCMAN* dvcman = (DVCMAN*)pChannelMgr; DVCMAN_LISTENER* listener; - WLog_DBG(TAG, "create_listener: %d.%s.", ArrayList_Count(dvcman->listeners) + 1, + WLog_DBG(TAG, "create_listener: %d.%s.", HashTable_Count(dvcman->listeners) + 1, pszChannelName); listener = (DVCMAN_LISTENER*)calloc(1, sizeof(DVCMAN_LISTENER)); @@ -98,7 +98,7 @@ static UINT dvcman_create_listener(IWTSVirtualChannelManager* pChannelMgr, if (ppListener) *ppListener = (IWTSListener*)listener; - if (!ArrayList_Append(dvcman->listeners, listener)) + if (!HashTable_Insert(dvcman->listeners, listener->channel_name, listener)) return ERROR_INTERNAL_ERROR; return CHANNEL_RC_OK; } @@ -113,7 +113,7 @@ static UINT dvcman_destroy_listener(IWTSVirtualChannelManager* pChannelMgr, IWTS { DVCMAN* dvcman = listener->dvcman; if (dvcman) - ArrayList_Remove(dvcman->listeners, listener); + HashTable_Remove(dvcman->listeners, listener->channel_name); } return CHANNEL_RC_OK; @@ -202,21 +202,16 @@ static const char* dvcman_get_channel_name(IWTSVirtualChannel* channel) static IWTSVirtualChannel* dvcman_find_channel_by_id(IWTSVirtualChannelManager* pChannelMgr, UINT32 ChannelId) { - size_t index; IWTSVirtualChannel* channel = NULL; DVCMAN* dvcman = (DVCMAN*)pChannelMgr; - ArrayList_Lock(dvcman->channels); - for (index = 0; index < ArrayList_Count(dvcman->channels); index++) - { - DVCMAN_CHANNEL* cur = (DVCMAN_CHANNEL*)ArrayList_GetItem(dvcman->channels, index); - if (cur->channel_id == ChannelId) - { - channel = &cur->iface; - break; - } - } + DVCMAN_CHANNEL* dvcChannel; - ArrayList_Unlock(dvcman->channels); + HashTable_Lock(dvcman->channelsById); + dvcChannel = HashTable_GetItemValue(dvcman->channelsById, &ChannelId); + if (dvcChannel) + channel = &dvcChannel->iface; + + HashTable_Unlock(dvcman->channelsById); return channel; } @@ -234,6 +229,17 @@ static void wts_listener_free(void* arg) DVCMAN_LISTENER* listener = (DVCMAN_LISTENER*)arg; dvcman_wtslistener_free(listener); } + +static BOOL channelIdMatch(const void* k1, const void* k2) +{ + return *((UINT32*)k1) == *((UINT32*)k2); +} + +static UINT32 channelIdHash(const void* id) +{ + return *((UINT32*)id); +} + static IWTSVirtualChannelManager* dvcman_new(drdynvcPlugin* plugin) { wObject* obj; @@ -249,22 +255,31 @@ static IWTSVirtualChannelManager* dvcman_new(drdynvcPlugin* plugin) dvcman->iface.GetChannelId = dvcman_get_channel_id; dvcman->iface.GetChannelName = dvcman_get_channel_name; dvcman->drdynvc = plugin; - dvcman->channels = ArrayList_New(TRUE); + dvcman->channelsById = HashTable_New(TRUE); - if (!dvcman->channels) + if (!dvcman->channelsById) goto fail; - obj = ArrayList_Object(dvcman->channels); + HashTable_SetHashFunction(dvcman->channelsById, channelIdHash); + obj = HashTable_KeyObject(dvcman->channelsById); + obj->fnObjectEquals = channelIdMatch; + + obj = HashTable_ValueObject(dvcman->channelsById); obj->fnObjectFree = dvcman_channel_free; dvcman->pool = StreamPool_New(TRUE, 10); if (!dvcman->pool) goto fail; - dvcman->listeners = ArrayList_New(TRUE); + dvcman->listeners = HashTable_New(TRUE); if (!dvcman->listeners) goto fail; - obj = ArrayList_Object(dvcman->listeners); + HashTable_SetHashFunction(dvcman->listeners, HashTable_StringHash); + + obj = HashTable_KeyObject(dvcman->listeners); + obj->fnObjectEquals = HashTable_StringCompare; + + obj = HashTable_ValueObject(dvcman->listeners); obj->fnObjectFree = wts_listener_free; dvcman->plugin_names = ArrayList_New(TRUE); @@ -406,10 +421,10 @@ static void dvcman_clear(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pCha WINPR_UNUSED(drdynvc); - ArrayList_Clear(dvcman->channels); + HashTable_Clear(dvcman->channelsById); ArrayList_Clear(dvcman->plugins); ArrayList_Clear(dvcman->plugin_names); - ArrayList_Clear(dvcman->listeners); + HashTable_Clear(dvcman->listeners); } static void dvcman_free(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChannelMgr) { @@ -418,9 +433,9 @@ static void dvcman_free(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChan WINPR_UNUSED(drdynvc); ArrayList_Free(dvcman->plugins); - ArrayList_Free(dvcman->channels); + HashTable_Free(dvcman->channelsById); ArrayList_Free(dvcman->plugin_names); - ArrayList_Free(dvcman->listeners); + HashTable_Free(dvcman->listeners); StreamPool_Free(dvcman->pool); free(dvcman); @@ -506,76 +521,75 @@ static UINT dvcman_close_channel_iface(IWTSVirtualChannel* pChannel) static UINT dvcman_create_channel(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChannelMgr, UINT32 ChannelId, const char* ChannelName) { - size_t i; BOOL bAccept; DVCMAN_CHANNEL* channel; DrdynvcClientContext* context; DVCMAN* dvcman = (DVCMAN*)pChannelMgr; + DVCMAN_LISTENER* listener; + IWTSVirtualChannelCallback* pCallback = NULL; UINT error; + HashTable_Lock(dvcman->listeners); + listener = (DVCMAN_LISTENER*)HashTable_GetItemValue(dvcman->listeners, ChannelName); + if (!listener) + { + error = ERROR_NOT_FOUND; + goto out; + } + if (!(channel = dvcman_channel_new(drdynvc, pChannelMgr, ChannelId, ChannelName))) { WLog_Print(drdynvc->log, WLOG_ERROR, "dvcman_channel_new failed!"); - return CHANNEL_RC_NO_MEMORY; + error = CHANNEL_RC_NO_MEMORY; + goto out; } channel->status = ERROR_NOT_CONNECTED; - if (!ArrayList_Append(dvcman->channels, channel)) - return ERROR_INTERNAL_ERROR; - - ArrayList_Lock(dvcman->listeners); - for (i = 0; i < ArrayList_Count(dvcman->listeners); i++) + if (!HashTable_Insert(dvcman->channelsById, &channel->channel_id, channel)) { - DVCMAN_LISTENER* listener = (DVCMAN_LISTENER*)ArrayList_GetItem(dvcman->listeners, i); - - if (strcmp(listener->channel_name, ChannelName) == 0) - { - IWTSVirtualChannelCallback* pCallback = NULL; - channel->iface.Write = dvcman_write_channel; - channel->iface.Close = dvcman_close_channel_iface; - bAccept = TRUE; - - if ((error = listener->listener_callback->OnNewChannelConnection( - listener->listener_callback, &channel->iface, NULL, &bAccept, &pCallback)) == - CHANNEL_RC_OK && - bAccept) - { - WLog_Print(drdynvc->log, WLOG_DEBUG, "listener %s created new channel %" PRIu32 "", - listener->channel_name, channel->channel_id); - channel->status = CHANNEL_RC_OK; - channel->channel_callback = pCallback; - channel->pInterface = listener->iface.pInterface; - context = dvcman->drdynvc->context; - IFCALLRET(context->OnChannelConnected, error, context, ChannelName, - listener->iface.pInterface); - - if (error) - WLog_Print(drdynvc->log, WLOG_ERROR, - "context.OnChannelConnected failed with error %" PRIu32 "", error); - - goto fail; - } - else - { - if (error) - { - WLog_Print(drdynvc->log, WLOG_ERROR, - "OnNewChannelConnection failed with error %" PRIu32 "!", error); - goto fail; - } - else - { - WLog_Print(drdynvc->log, WLOG_ERROR, - "OnNewChannelConnection returned with bAccept FALSE!"); - error = ERROR_INTERNAL_ERROR; - goto fail; - } - } - } + WLog_Print(drdynvc->log, WLOG_ERROR, "unable to register channel in our channel list"); + error = ERROR_INTERNAL_ERROR; + goto out; } - error = ERROR_INTERNAL_ERROR; -fail: - ArrayList_Unlock(dvcman->listeners); + + channel->iface.Write = dvcman_write_channel; + channel->iface.Close = dvcman_close_channel_iface; + bAccept = TRUE; + + error = listener->listener_callback->OnNewChannelConnection( + listener->listener_callback, &channel->iface, NULL, &bAccept, &pCallback); + + if (error != CHANNEL_RC_OK) + { + WLog_Print(drdynvc->log, WLOG_ERROR, + "OnNewChannelConnection failed with error %" PRIu32 "!", error); + error = ERROR_INTERNAL_ERROR; + goto out; + } + + if (!bAccept) + { + WLog_Print(drdynvc->log, WLOG_ERROR, "OnNewChannelConnection returned with bAccept FALSE!"); + error = ERROR_INTERNAL_ERROR; + goto out; + } + + WLog_Print(drdynvc->log, WLOG_DEBUG, "listener %s created new channel %" PRIu32 "", + listener->channel_name, channel->channel_id); + channel->status = CHANNEL_RC_OK; + channel->channel_callback = pCallback; + channel->pInterface = listener->iface.pInterface; + context = dvcman->drdynvc->context; + + IFCALLRET(context->OnChannelConnected, error, context, ChannelName, listener->iface.pInterface); + if (error != CHANNEL_RC_OK) + { + WLog_Print(drdynvc->log, WLOG_ERROR, + "context.OnChannelConnected failed with error %" PRIu32 "", error); + } + +out: + HashTable_Unlock(dvcman->listeners); return error; } @@ -632,8 +646,8 @@ UINT dvcman_close_channel(IWTSVirtualChannelManager* pChannelMgr, UINT32 Channel UINT error = CHANNEL_RC_OK; DVCMAN* dvcman = (DVCMAN*)pChannelMgr; drdynvcPlugin* drdynvc = dvcman->drdynvc; - channel = (DVCMAN_CHANNEL*)dvcman_find_channel_by_id(pChannelMgr, ChannelId); + channel = (DVCMAN_CHANNEL*)dvcman_find_channel_by_id(pChannelMgr, ChannelId); if (!channel) { // WLog_Print(drdynvc->log, WLOG_ERROR, "ChannelId %"PRIu32" not found!", ChannelId); @@ -660,7 +674,7 @@ UINT dvcman_close_channel(IWTSVirtualChannelManager* pChannelMgr, UINT32 Channel } } - ArrayList_Remove(dvcman->channels, channel); + HashTable_Remove(dvcman->channelsById, &ChannelId); return error; } @@ -1048,6 +1062,7 @@ static UINT drdynvc_process_create_request(drdynvcPlugin* drdynvc, int Sp, int c char* name; size_t length; DVCMAN* dvcman; + UINT32 retStatus; WINPR_UNUSED(Sp); if (!drdynvc) @@ -1086,8 +1101,8 @@ static UINT drdynvc_process_create_request(drdynvcPlugin* drdynvc, int Sp, int c WLog_Print(drdynvc->log, WLOG_DEBUG, "process_create_request: ChannelId=%" PRIu32 " ChannelName=%s", ChannelId, name); channel_status = dvcman_create_channel(drdynvc, drdynvc->channel_mgr, ChannelId, name); - data_out = StreamPool_Take(dvcman->pool, pos + 4); + data_out = StreamPool_Take(dvcman->pool, pos + 4); if (!data_out) { WLog_Print(drdynvc->log, WLOG_ERROR, "StreamPool_Take failed!"); @@ -1098,16 +1113,26 @@ static UINT drdynvc_process_create_request(drdynvcPlugin* drdynvc, int Sp, int c Stream_SetPosition(s, 1); Stream_Copy(s, data_out, pos - 1); - if (channel_status == CHANNEL_RC_OK) + switch (channel_status) { - WLog_Print(drdynvc->log, WLOG_DEBUG, "channel created"); - Stream_Write_UINT32(data_out, 0); - } - else - { - WLog_Print(drdynvc->log, WLOG_DEBUG, "no listener"); - Stream_Write_UINT32(data_out, (UINT32)0xC0000001); /* same code used by mstsc */ + case CHANNEL_RC_OK: + WLog_Print(drdynvc->log, WLOG_DEBUG, "channel created"); + retStatus = 0; + break; + case CHANNEL_RC_NO_MEMORY: + WLog_Print(drdynvc->log, WLOG_DEBUG, "not enough memory for channel creation"); + retStatus = STATUS_NO_MEMORY; + break; + case ERROR_NOT_FOUND: + WLog_Print(drdynvc->log, WLOG_DEBUG, "no listener for '%s'", name); + retStatus = (UINT32)0xC0000001; /* same code used by mstsc, STATUS_UNSUCCESSFUL */ + break; + default: + WLog_Print(drdynvc->log, WLOG_DEBUG, "channel creation error"); + retStatus = (UINT32)0xC0000001; /* same code used by mstsc, STATUS_UNSUCCESSFUL */ + break; } + Stream_Write_UINT32(data_out, retStatus); status = drdynvc_send(drdynvc, data_out); @@ -1386,6 +1411,15 @@ static void VCAPITYPE drdynvc_virtual_channel_open_event_ex(LPVOID lpUserParam, "drdynvc_virtual_channel_open_event reported an error"); } +static BOOL channelByIdCleanerFn(const void* key, void* value, void* arg) +{ + drdynvcPlugin* drdynvc = (drdynvcPlugin*)arg; + DVCMAN_CHANNEL* channel = (DVCMAN_CHANNEL*)value; + + dvcman_close_channel(drdynvc->channel_mgr, channel->channel_id, FALSE); + return TRUE; +} + static DWORD WINAPI drdynvc_virtual_channel_client_thread(LPVOID arg) { /* TODO: rewrite this */ @@ -1438,23 +1472,9 @@ static DWORD WINAPI drdynvc_virtual_channel_client_thread(LPVOID arg) /* Disconnect remaining dynamic channels that the server did not. * This is required to properly shut down channels by calling the appropriate * event handlers. */ - size_t count = 0; DVCMAN* drdynvcMgr = (DVCMAN*)drdynvc->channel_mgr; - do - { - ArrayList_Lock(drdynvcMgr->channels); - count = ArrayList_Count(drdynvcMgr->channels); - if (count > 0) - { - IWTSVirtualChannel* channel = - (IWTSVirtualChannel*)ArrayList_GetItem(drdynvcMgr->channels, 0); - const UINT32 ChannelId = drdynvc->channel_mgr->GetChannelId(channel); - dvcman_close_channel(drdynvc->channel_mgr, ChannelId, FALSE); - count--; - } - ArrayList_Unlock(drdynvcMgr->channels); - } while (count > 0); + HashTable_Foreach(drdynvcMgr->channelsById, channelByIdCleanerFn, drdynvc); } if (error && drdynvc->rdpcontext) diff --git a/channels/drdynvc/client/drdynvc_main.h b/channels/drdynvc/client/drdynvc_main.h index 5ded6132f..b0785b969 100644 --- a/channels/drdynvc/client/drdynvc_main.h +++ b/channels/drdynvc/client/drdynvc_main.h @@ -46,8 +46,8 @@ typedef struct wArrayList* plugin_names; wArrayList* plugins; - wArrayList* listeners; - wArrayList* channels; + wHashTable* listeners; + wHashTable* channelsById; wStreamPool* pool; } DVCMAN; diff --git a/libfreerdp/core/info.c b/libfreerdp/core/info.c index 773b87aa6..a2ffbd3a3 100644 --- a/libfreerdp/core/info.c +++ b/libfreerdp/core/info.c @@ -1098,13 +1098,6 @@ static BOOL rdp_recv_logon_info_v2(rdpRdp* rdp, wStream* s, logon_info* info) logonInfoV2TotalSize, Size); return FALSE; } - else - { - WLog_WARN(TAG, - "[SERVER-BUG] 2.2.10.1.1.2 Logon Info Version 2 (TS_LOGON_INFO_VERSION_2) " - "Size expected %" PRIu32 " bytes, got %" PRIu32 ", ignoring", - logonInfoV2TotalSize, Size); - } } Stream_Read_UINT32(s, info->sessionId); /* SessionId (4 bytes) */ diff --git a/winpr/include/winpr/collections.h b/winpr/include/winpr/collections.h index 9d156606b..40c8abc99 100644 --- a/winpr/include/winpr/collections.h +++ b/winpr/include/winpr/collections.h @@ -356,6 +356,8 @@ extern "C" WINPR_API wHashTable* HashTable_New(BOOL synchronized); WINPR_API void HashTable_Free(wHashTable* table); + WINPR_API void HashTable_Lock(wHashTable* table); + WINPR_API void HashTable_Unlock(wHashTable* table); WINPR_API wObject* HashTable_KeyObject(wHashTable* table); WINPR_API wObject* HashTable_ValueObject(wHashTable* table); diff --git a/winpr/libwinpr/utils/collections/HashTable.c b/winpr/libwinpr/utils/collections/HashTable.c index 5335e1894..a6b78427a 100644 --- a/winpr/libwinpr/utils/collections/HashTable.c +++ b/winpr/libwinpr/utils/collections/HashTable.c @@ -829,6 +829,18 @@ void HashTable_Free(wHashTable* table) free(table); } +void HashTable_Lock(wHashTable* table) +{ + WINPR_ASSERT(table); + EnterCriticalSection(&table->lock); +} + +void HashTable_Unlock(wHashTable* table) +{ + WINPR_ASSERT(table); + LeaveCriticalSection(&table->lock); +} + wObject* HashTable_KeyObject(wHashTable* table) { WINPR_ASSERT(table);