diff --git a/winpr/libwinpr/crypto/cipher.c b/winpr/libwinpr/crypto/cipher.c index 7336bb9c3..ff27b3826 100644 --- a/winpr/libwinpr/crypto/cipher.c +++ b/winpr/libwinpr/crypto/cipher.c @@ -20,7 +20,6 @@ #include #include - #include #include "../log.h" @@ -45,6 +44,19 @@ * RC4 */ +struct winpr_rc4_ctx_private_st +{ + union + { +#if defined(WITH_OPENSSL) + EVP_CIPHER_CTX* ctx; +#endif +#if defined(WITH_MBEDTLS) && defined(MBEDTLS_ARC4_C) + mbedtls_arc4_context mctx; +#endif + } u; +}; + static WINPR_RC4_CTX* winpr_RC4_New_Internal(const BYTE* key, size_t keylen, BOOL override_fips) { WINPR_RC4_CTX* ctx = NULL; @@ -55,48 +67,51 @@ static WINPR_RC4_CTX* winpr_RC4_New_Internal(const BYTE* key, size_t keylen, BOO if (!key || (keylen == 0)) return NULL; + ctx = calloc(1, sizeof(WINPR_RC4_CTX)); + if (!ctx) + return NULL; + #if defined(WITH_OPENSSL) if (keylen > INT_MAX) - return NULL; + goto fail; - if (!(ctx = (WINPR_RC4_CTX*)EVP_CIPHER_CTX_new())) - return NULL; + ctx->u.ctx = EVP_CIPHER_CTX_new(); + if (!ctx->u.ctx) + goto fail; evp = EVP_rc4(); if (!evp) - return NULL; + goto fail; - EVP_CIPHER_CTX_init((EVP_CIPHER_CTX*)ctx); - if (EVP_EncryptInit_ex((EVP_CIPHER_CTX*)ctx, evp, NULL, NULL, NULL) != 1) - { - EVP_CIPHER_CTX_free((EVP_CIPHER_CTX*)ctx); - return NULL; - } + EVP_CIPHER_CTX_init(ctx->u.ctx); + if (EVP_EncryptInit_ex(ctx->u.ctx, evp, NULL, NULL, NULL) != 1) + goto fail; - /* EVP_CIPH_FLAG_NON_FIPS_ALLOW does not exist before openssl 1.0.1 */ + /* EVP_CIPH_FLAG_NON_FIPS_ALLOW does not exist before openssl 1.0.1 */ #if !(OPENSSL_VERSION_NUMBER < 0x10001000L) if (override_fips == TRUE) - EVP_CIPHER_CTX_set_flags((EVP_CIPHER_CTX*)ctx, EVP_CIPH_FLAG_NON_FIPS_ALLOW); + EVP_CIPHER_CTX_set_flags(ctx->u.ctx, EVP_CIPH_FLAG_NON_FIPS_ALLOW); #endif - EVP_CIPHER_CTX_set_key_length((EVP_CIPHER_CTX*)ctx, (int)keylen); - if (EVP_EncryptInit_ex((EVP_CIPHER_CTX*)ctx, NULL, NULL, key, NULL) != 1) - { - EVP_CIPHER_CTX_free((EVP_CIPHER_CTX*)ctx); - return NULL; - } + EVP_CIPHER_CTX_set_key_length(ctx->u.ctx, (int)keylen); + if (EVP_EncryptInit_ex(ctx->u.ctx, NULL, NULL, key, NULL) != 1) + goto fail; + #elif defined(WITH_MBEDTLS) && defined(MBEDTLS_ARC4_C) - if (!(ctx = (WINPR_RC4_CTX*)calloc(1, sizeof(mbedtls_arc4_context)))) - return NULL; + ctx->u.mctx = calloc(1, sizeof(mbedtls_arc4_context)) if (!ctx->u.mctx) goto fail; - mbedtls_arc4_init((mbedtls_arc4_context*)ctx); - mbedtls_arc4_setup((mbedtls_arc4_context*)ctx, key, (unsigned int)keylen); + mbedtls_arc4_init(ctx->u.mctx); + mbedtls_arc4_setup(ctx->u.mctx, key, (unsigned int)keylen); #endif return ctx; + +fail: + winpr_RC4_Free(ctx); + return NULL; } WINPR_RC4_CTX* winpr_RC4_New_Allow_FIPS(const BYTE* key, size_t keylen) @@ -111,18 +126,19 @@ WINPR_RC4_CTX* winpr_RC4_New(const BYTE* key, size_t keylen) BOOL winpr_RC4_Update(WINPR_RC4_CTX* ctx, size_t length, const BYTE* input, BYTE* output) { + WINPR_ASSERT(ctx); #if defined(WITH_OPENSSL) + WINPR_ASSERT(ctx->u.ctx); int outputLength; if (length > INT_MAX) return FALSE; - WINPR_ASSERT(ctx); - if (EVP_CipherUpdate((EVP_CIPHER_CTX*)ctx, output, &outputLength, input, (int)length) != 1) - return FALSE; + EVP_CipherUpdate(ctx->u.ctx, output, &outputLength, input, (int)length); return TRUE; #elif defined(WITH_MBEDTLS) && defined(MBEDTLS_ARC4_C) - if (mbedtls_arc4_crypt((mbedtls_arc4_context*)ctx, length, input, output) == 0) + WINPR_ASSERT(ctx->u.mctx); + if (mbedtls_arc4_crypt(ctx->u.mctx, length, input, output) == 0) return TRUE; #endif @@ -135,11 +151,11 @@ void winpr_RC4_Free(WINPR_RC4_CTX* ctx) return; #if defined(WITH_OPENSSL) - EVP_CIPHER_CTX_free((EVP_CIPHER_CTX*)ctx); + EVP_CIPHER_CTX_free(ctx->u.ctx); #elif defined(WITH_MBEDTLS) && defined(MBEDTLS_ARC4_C) - mbedtls_arc4_free((mbedtls_arc4_context*)ctx); - free(ctx); + mbedtls_arc4_free(ctx->u.mctx); #endif + free(ctx); } /**