diff --git a/server/proxy/pf_channel.c b/server/proxy/pf_channel.c index 50caef70d..858c4d7c9 100644 --- a/server/proxy/pf_channel.c +++ b/server/proxy/pf_channel.c @@ -40,6 +40,28 @@ struct _ChannelStateTracker proxyData* pdata; }; +static BOOL channelTracker_resetCurrentPacket(ChannelStateTracker* tracker) +{ + WINPR_ASSERT(tracker); + + BOOL create = TRUE; + if (tracker->currentPacket) + { + const size_t cap = Stream_Capacity(tracker->currentPacket); + if (cap < 1 * 1000 * 1000) + create = FALSE; + else + Stream_Free(tracker->currentPacket, TRUE); + } + + if (create) + tracker->currentPacket = Stream_New(NULL, 10 * 1024); + if (!tracker->currentPacket) + return FALSE; + Stream_SetPosition(tracker->currentPacket, 0); + return TRUE; +} + ChannelStateTracker* channelTracker_new(pServerStaticChannelContext* channel, ChannelTrackerPeekFn fn, void* data) { @@ -55,8 +77,7 @@ ChannelStateTracker* channelTracker_new(pServerStaticChannelContext* channel, if (!channelTracker_setCustomData(ret, data)) goto fail; - ret->currentPacket = Stream_New(NULL, 10 * 1024); - if (!ret->currentPacket) + if (!channelTracker_resetCurrentPacket(ret)) goto fail; return ret; @@ -79,45 +100,35 @@ PfChannelResult channelTracker_update(ChannelStateTracker* tracker, const BYTE* tracker->channel->channel_name, xsize, firstPacket, lastPacket); if (flags & CHANNEL_FLAG_FIRST) { - /* don't keep a too big currentPacket */ - if (Stream_Capacity(tracker->currentPacket) > 1 * 1000 * 1000) - { - Stream_Free(tracker->currentPacket, TRUE); - - tracker->currentPacket = Stream_New(NULL, 10 * 1024); - if (!tracker->currentPacket) - { - return PF_CHANNEL_RESULT_ERROR; - } - } - else - { - Stream_SetPosition(tracker->currentPacket, 0); - } - - tracker->currentPacketSize = totalSize; + if (!channelTracker_resetCurrentPacket(tracker)) + return FALSE; + channelTracker_setCurrentPacketSize(tracker, totalSize); tracker->currentPacketReceived = 0; tracker->currentPacketFragments = 0; } - if (tracker->currentPacketReceived + xsize > tracker->currentPacketSize) + const size_t currentPacketSize = channelTracker_getCurrentPacketSize(tracker); + if (tracker->currentPacketReceived + xsize > currentPacketSize) WLog_INFO(TAG, "cumulated size is bigger (%" PRIuz ") than total size (%" PRIuz ")", - tracker->currentPacketReceived + xsize, tracker->currentPacketSize); + tracker->currentPacketReceived + xsize, currentPacketSize); tracker->currentPacketReceived += xsize; tracker->currentPacketFragments++; - switch (tracker->mode) + switch (channelTracker_getMode(tracker)) { case CHANNEL_TRACKER_PEEK: - if (!Stream_EnsureRemainingCapacity(tracker->currentPacket, xsize)) + { + wStream* currentPacket = channelTracker_getCurrentPacket(tracker); + if (!Stream_EnsureRemainingCapacity(currentPacket, xsize)) return PF_CHANNEL_RESULT_ERROR; - Stream_Write(tracker->currentPacket, xdata, xsize); + Stream_Write(currentPacket, xdata, xsize); WINPR_ASSERT(tracker->peekFn); result = tracker->peekFn(tracker, firstPacket, lastPacket); - break; + } + break; case CHANNEL_TRACKER_PASS: result = PF_CHANNEL_RESULT_PASS; break; @@ -128,10 +139,12 @@ PfChannelResult channelTracker_update(ChannelStateTracker* tracker, const BYTE* if (lastPacket) { - tracker->mode = CHANNEL_TRACKER_PEEK; - if (tracker->currentPacketReceived != tracker->currentPacketSize) + const size_t currentPacketSize = channelTracker_getCurrentPacketSize(tracker); + channelTracker_setMode(tracker, CHANNEL_TRACKER_PEEK); + + if (tracker->currentPacketReceived != currentPacketSize) WLog_INFO(TAG, "cumulated size(%" PRIuz ") does not match total size (%" PRIuz ")", - tracker->currentPacketReceived, tracker->currentPacketSize); + tracker->currentPacketReceived, currentPacketSize); } return result; @@ -161,12 +174,13 @@ PfChannelResult channelTracker_flushCurrent(ChannelStateTracker* t, BOOL first, UINT32 flags = CHANNEL_FLAG_FIRST; BOOL r; const char* direction = toBack ? "F->B" : "B->F"; + const size_t currentPacketSize = channelTracker_getCurrentPacketSize(t); + wStream* currentPacket = channelTracker_getCurrentPacket(t); WINPR_ASSERT(t); WLog_VRB(TAG, "channelTracker_flushCurrent(%s): %s sz=%" PRIuz " first=%d last=%d", - t->channel->channel_name, direction, Stream_GetPosition(t->currentPacket), first, - last); + t->channel->channel_name, direction, Stream_GetPosition(currentPacket), first, last); if (first) return PF_CHANNEL_RESULT_PASS; @@ -182,10 +196,10 @@ PfChannelResult channelTracker_flushCurrent(ChannelStateTracker* t, BOOL first, ev.channel_id = channel->front_channel_id; ev.channel_name = channel->channel_name; - ev.data = Stream_Buffer(t->currentPacket); - ev.data_len = Stream_GetPosition(t->currentPacket); + ev.data = Stream_Buffer(currentPacket); + ev.data_len = Stream_GetPosition(currentPacket); ev.flags = flags; - ev.total_size = t->currentPacketSize; + ev.total_size = currentPacketSize; if (!pdata->pc->sendChannelData) return PF_CHANNEL_RESULT_ERROR; @@ -195,9 +209,9 @@ PfChannelResult channelTracker_flushCurrent(ChannelStateTracker* t, BOOL first, } ps = pdata->ps; - r = ps->context.peer->SendChannelPacket( - ps->context.peer, channel->front_channel_id, t->currentPacketSize, flags, - Stream_Buffer(t->currentPacket), Stream_GetPosition(t->currentPacket)); + r = ps->context.peer->SendChannelPacket(ps->context.peer, channel->front_channel_id, + currentPacketSize, flags, Stream_Buffer(currentPacket), + Stream_GetPosition(currentPacket)); return r ? PF_CHANNEL_RESULT_DROP : PF_CHANNEL_RESULT_ERROR; } @@ -321,3 +335,15 @@ void* channelTracker_getCustomData(ChannelStateTracker* tracker) WINPR_ASSERT(tracker); return tracker->trackerData; } + +size_t channelTracker_getCurrentPacketSize(ChannelStateTracker* tracker) +{ + WINPR_ASSERT(tracker); + return tracker->currentPacketSize; +} + +BOOL channelTracker_setCurrentPacketSize(ChannelStateTracker* tracker, size_t size) +{ + WINPR_ASSERT(tracker); + tracker->currentPacketSize = size; +} diff --git a/server/proxy/pf_channel.h b/server/proxy/pf_channel.h index 0168c9708..4498bd420 100644 --- a/server/proxy/pf_channel.h +++ b/server/proxy/pf_channel.h @@ -49,6 +49,9 @@ void* channelTracker_getCustomData(ChannelStateTracker* tracker); wStream* channelTracker_getCurrentPacket(ChannelStateTracker* tracker); +size_t channelTracker_getCurrentPacketSize(ChannelStateTracker* tracker); +BOOL channelTracker_setCurrentPacketSize(ChannelStateTracker* tracker, size_t size); + PfChannelResult channelTracker_update(ChannelStateTracker* tracker, const BYTE* xdata, size_t xsize, UINT32 flags, size_t totalSize);