diff --git a/libfreerdp/core/gateway/http.c b/libfreerdp/core/gateway/http.c index be53d8a40..b44818974 100644 --- a/libfreerdp/core/gateway/http.c +++ b/libfreerdp/core/gateway/http.c @@ -96,8 +96,8 @@ struct s_http_response size_t BodyLength; BYTE* BodyContent; - wListDictionary* Authenticates; - wListDictionary* SetCookie; + wHashTable* Authenticates; + wHashTable* SetCookie; wStream* data; }; @@ -910,10 +910,10 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char* if (!authScheme) return FALSE; - authValue = NULL; + authValue = ""; } - status = ListDictionary_Add(response->Authenticates, authScheme, authValue); + status = HashTable_Insert(response->Authenticates, authScheme, authValue); } else if (_stricmp(name, "Set-Cookie") == 0) { @@ -961,7 +961,7 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char* return FALSE; } - status = ListDictionary_Add(response->SetCookie, CookieName, CookieValue); + status = HashTable_Insert(response->SetCookie, CookieName, CookieValue); } return status; @@ -1539,16 +1539,20 @@ const BYTE* http_response_get_body(const HttpResponse* response) return response->BodyContent; } -static BOOL set_compare(wListDictionary* dict) +static wHashTable* HashTable_New_String(void) { - WINPR_ASSERT(dict); - wObject* key = ListDictionary_KeyObject(dict); - wObject* value = ListDictionary_KeyObject(dict); - if (!key || !value) - return FALSE; - key->fnObjectEquals = strings_equals_nocase; - value->fnObjectEquals = strings_equals_nocase; - return TRUE; + wHashTable* table = HashTable_New(FALSE); + if (!table) + return NULL; + + if (!HashTable_SetupForStringData(table, TRUE)) + { + HashTable_Free(table); + return NULL; + } + HashTable_KeyObject(table)->fnObjectEquals = strings_equals_nocase; + HashTable_ValueObject(table)->fnObjectEquals = strings_equals_nocase; + return table; } HttpResponse* http_response_new(void) @@ -1558,22 +1562,16 @@ HttpResponse* http_response_new(void) if (!response) return NULL; - response->Authenticates = ListDictionary_New(FALSE); + response->Authenticates = HashTable_New_String(); if (!response->Authenticates) goto fail; - if (!set_compare(response->Authenticates)) - goto fail; - - response->SetCookie = ListDictionary_New(FALSE); + response->SetCookie = HashTable_New_String(); if (!response->SetCookie) goto fail; - if (!set_compare(response->SetCookie)) - goto fail; - response->data = Stream_New(NULL, 2048); if (!response->data) @@ -1595,8 +1593,8 @@ void http_response_free(HttpResponse* response) return; free((void*)response->lines); - ListDictionary_Free(response->Authenticates); - ListDictionary_Free(response->SetCookie); + HashTable_Free(response->Authenticates); + HashTable_Free(response->SetCookie); Stream_Free(response->data, TRUE); free(response); } @@ -1645,10 +1643,7 @@ const char* http_response_get_auth_token(const HttpResponse* response, const cha if (!response || !method) return NULL; - if (!ListDictionary_Contains(response->Authenticates, method)) - return NULL; - - return ListDictionary_GetItemValue(response->Authenticates, method); + return HashTable_GetItemValue(response->Authenticates, method); } const char* http_response_get_setcookie(const HttpResponse* response, const char* cookie) @@ -1656,10 +1651,7 @@ const char* http_response_get_setcookie(const HttpResponse* response, const char if (!response || !cookie) return NULL; - if (!ListDictionary_Contains(response->SetCookie, cookie)) - return NULL; - - return ListDictionary_GetItemValue(response->SetCookie, cookie); + return HashTable_GetItemValue(response->SetCookie, cookie); } TRANSFER_ENCODING http_response_get_transfer_encoding(const HttpResponse* response) @@ -1752,3 +1744,24 @@ BOOL http_request_append_header(wStream* stream, const char* param, free(str); return rc; } + +static BOOL extract_cookie(const void* pkey, void* pvalue, void* arg) +{ + const char* key = pkey; + const char* value = pvalue; + HttpContext* context = arg; + + WINPR_ASSERT(arg); + WINPR_ASSERT(key); + WINPR_ASSERT(value); + + return http_context_set_cookie(context, key, value); +} + +BOOL http_response_extract_cookies(const HttpResponse* response, HttpContext* context) +{ + WINPR_ASSERT(response); + WINPR_ASSERT(context); + + return HashTable_Foreach(response->SetCookie, extract_cookie, context); +} diff --git a/libfreerdp/core/gateway/http.h b/libfreerdp/core/gateway/http.h index 954e669f7..07d4556ee 100644 --- a/libfreerdp/core/gateway/http.h +++ b/libfreerdp/core/gateway/http.h @@ -116,6 +116,7 @@ FREERDP_LOCAL void http_response_free(HttpResponse* response); WINPR_ATTR_MALLOC(http_response_free, 1) FREERDP_LOCAL HttpResponse* http_response_new(void); +WINPR_ATTR_MALLOC(http_response_free, 1) FREERDP_LOCAL HttpResponse* http_response_recv(rdpTls* tls, BOOL readContentLength); FREERDP_LOCAL UINT16 http_response_get_status_code(const HttpResponse* response); @@ -125,6 +126,8 @@ FREERDP_LOCAL const char* http_response_get_auth_token(const HttpResponse* respo const char* method); FREERDP_LOCAL const char* http_response_get_setcookie(const HttpResponse* response, const char* cookie); +FREERDP_LOCAL BOOL http_response_extract_cookies(const HttpResponse* response, + HttpContext* context); FREERDP_LOCAL TRANSFER_ENCODING http_response_get_transfer_encoding(const HttpResponse* response); FREERDP_LOCAL BOOL http_response_is_websocket(const HttpContext* http, const HttpResponse* response); diff --git a/libfreerdp/core/gateway/rdg.c b/libfreerdp/core/gateway/rdg.c index f4c8cff9e..701846064 100644 --- a/libfreerdp/core/gateway/rdg.c +++ b/libfreerdp/core/gateway/rdg.c @@ -1361,6 +1361,8 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char* return FALSE; } + (void)http_response_extract_cookies(response, rdg->http); + const UINT16 StatusCode = http_response_get_status_code(response); switch (StatusCode) { @@ -1409,6 +1411,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char* *rpcFallback = TRUE; return FALSE; } + (void)http_response_extract_cookies(response, rdg->http); } } credssp_auth_free(rdg->auth); @@ -1430,6 +1433,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char* *rpcFallback = TRUE; return FALSE; } + (void)http_response_extract_cookies(response, rdg->http); } const UINT16 statusCode = http_response_get_status_code(response); diff --git a/libfreerdp/core/gateway/wst.c b/libfreerdp/core/gateway/wst.c index 6019cb7a5..0a32d7db9 100644 --- a/libfreerdp/core/gateway/wst.c +++ b/libfreerdp/core/gateway/wst.c @@ -406,6 +406,7 @@ static BOOL wst_handle_ok_or_forbidden(rdpWst* wst, HttpResponse** ppresponse, D if (!*ppresponse) return FALSE; + (void)http_response_extract_cookies(*ppresponse, wst->http); *pStatusCode = http_response_get_status_code(*ppresponse); } @@ -432,6 +433,8 @@ static BOOL wst_handle_denied(rdpWst* wst, HttpResponse** ppresponse, UINT16* pS if (!*ppresponse) return FALSE; + (void)http_response_extract_cookies(*ppresponse, wst->http); + while (!credssp_auth_is_complete(wst->auth)) { if (!wst_recv_auth_token(wst->auth, *ppresponse)) @@ -446,6 +449,7 @@ static BOOL wst_handle_denied(rdpWst* wst, HttpResponse** ppresponse, UINT16* pS *ppresponse = http_response_recv(wst->tls, TRUE); if (!*ppresponse) return FALSE; + (void)http_response_extract_cookies(*ppresponse, wst->http); } } *pStatusCode = http_response_get_status_code(*ppresponse); @@ -521,6 +525,7 @@ BOOL wst_connect(rdpWst* wst, DWORD timeout) freerdp_set_last_error_if_not(wst->context, FREERDP_ERROR_CONNECT_FAILED); return FALSE; } + (void)http_response_extract_cookies(response, wst->http); UINT16 StatusCode = http_response_get_status_code(response); BOOL success = TRUE;