@@ -96,35 +96,57 @@ static int linuxkm_test_aesxts(void);
9696#include <wolfssl/wolfcrypt/aes.h>
9797
9898struct km_AesCtx {
99- Aes * aes ; /* must be pointer to control alignment, needed for AESNI. */
100- u8 key [AES_MAX_KEY_SIZE / 8 ];
101- unsigned int keylen ;
99+ Aes * aes_encrypt ; /* must be pointer to control alignment, needed for AESNI. */
100+ Aes * aes_decrypt ; /* same. */
102101};
103102
104- static inline void km_ForceZero (struct km_AesCtx * ctx )
105- {
106- memzero_explicit (ctx -> key , sizeof (ctx -> key ));
107- ctx -> keylen = 0 ;
108- }
109-
110103#if defined(LINUXKM_LKCAPI_REGISTER_ALL ) || \
111104 defined(LINUXKM_LKCAPI_REGISTER_AESCBC ) || \
112105 defined(LINUXKM_LKCAPI_REGISTER_AESCFB ) || \
113106 defined(LINUXKM_LKCAPI_REGISTER_AESGCM )
114107
115- static int km_AesInitCommon (struct km_AesCtx * ctx , const char * name )
108+ static void km_AesExitCommon (struct km_AesCtx * ctx );
109+
110+ static int km_AesInitCommon (struct km_AesCtx * ctx , const char * name , int need_decryption )
116111{
117112 int err ;
118113
119- ctx -> aes = (Aes * )malloc (sizeof (* ctx -> aes ));
114+ ctx -> aes_encrypt = (Aes * )malloc (sizeof (* ctx -> aes_encrypt ));
120115
121- if (! ctx -> aes )
116+ if (! ctx -> aes_encrypt ) {
117+ pr_err ("error: km_AesInitCommon %s failed: %d\n" , name , MEMORY_E );
122118 return MEMORY_E ;
119+ }
123120
124- err = wc_AesInit (ctx -> aes , NULL , INVALID_DEVID );
121+ err = wc_AesInit (ctx -> aes_encrypt , NULL , INVALID_DEVID );
125122
126123 if (unlikely (err )) {
127124 pr_err ("error: km_AesInitCommon %s failed: %d\n" , name , err );
125+ free (ctx -> aes_encrypt );
126+ ctx -> aes_encrypt = NULL ;
127+ return err ;
128+ }
129+
130+ if (! need_decryption ) {
131+ ctx -> aes_decrypt = NULL ;
132+ return 0 ;
133+ }
134+
135+ ctx -> aes_decrypt = (Aes * )malloc (sizeof (* ctx -> aes_decrypt ));
136+
137+ if (! ctx -> aes_encrypt ) {
138+ pr_err ("error: km_AesInitCommon %s failed: %d\n" , name , MEMORY_E );
139+ km_AesExitCommon (ctx );
140+ return MEMORY_E ;
141+ }
142+
143+ err = wc_AesInit (ctx -> aes_decrypt , NULL , INVALID_DEVID );
144+
145+ if (unlikely (err )) {
146+ pr_err ("error: km_AesInitCommon %s failed: %d\n" , name , err );
147+ free (ctx -> aes_decrypt );
148+ ctx -> aes_decrypt = NULL ;
149+ km_AesExitCommon (ctx );
128150 return err ;
129151 }
130152
@@ -133,26 +155,38 @@ static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name)
133155
134156static void km_AesExitCommon (struct km_AesCtx * ctx )
135157{
136- wc_AesFree (ctx -> aes );
137- free (ctx -> aes );
138- ctx -> aes = NULL ;
139- km_ForceZero (ctx );
158+ if (ctx -> aes_encrypt ) {
159+ wc_AesFree (ctx -> aes_encrypt );
160+ free (ctx -> aes_encrypt );
161+ ctx -> aes_encrypt = NULL ;
162+ }
163+ if (ctx -> aes_decrypt ) {
164+ wc_AesFree (ctx -> aes_decrypt );
165+ free (ctx -> aes_decrypt );
166+ ctx -> aes_decrypt = NULL ;
167+ }
140168}
141169
142170static int km_AesSetKeyCommon (struct km_AesCtx * ctx , const u8 * in_key ,
143171 unsigned int key_len , const char * name )
144172{
145173 int err ;
146174
147- err = wc_AesSetKey (ctx -> aes , in_key , key_len , NULL , 0 );
175+ err = wc_AesSetKey (ctx -> aes_encrypt , in_key , key_len , NULL , AES_ENCRYPTION );
148176
149177 if (unlikely (err )) {
150178 pr_err ("error: km_AesSetKeyCommon %s failed: %d\n" , name , err );
151179 return err ;
152180 }
153181
154- XMEMCPY (ctx -> key , in_key , key_len );
155- ctx -> keylen = key_len ;
182+ if (ctx -> aes_decrypt ) {
183+ err = wc_AesSetKey (ctx -> aes_decrypt , in_key , key_len , NULL , AES_DECRYPTION );
184+
185+ if (unlikely (err )) {
186+ pr_err ("error: km_AesSetKeyCommon %s failed: %d\n" , name , err );
187+ return err ;
188+ }
189+ }
156190
157191 return 0 ;
158192}
@@ -161,25 +195,12 @@ static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key,
161195 defined(LINUXKM_LKCAPI_REGISTER_AESCBC ) || \
162196 defined(LINUXKM_LKCAPI_REGISTER_AESCFB )
163197
164- static int km_AesInit (struct crypto_skcipher * tfm )
165- {
166- struct km_AesCtx * ctx = crypto_skcipher_ctx (tfm );
167- return km_AesInitCommon (ctx , WOLFKM_AESCBC_DRIVER );
168- }
169-
170198static void km_AesExit (struct crypto_skcipher * tfm )
171199{
172200 struct km_AesCtx * ctx = crypto_skcipher_ctx (tfm );
173201 km_AesExitCommon (ctx );
174202}
175203
176- static int km_AesSetKey (struct crypto_skcipher * tfm , const u8 * in_key ,
177- unsigned int key_len )
178- {
179- struct km_AesCtx * ctx = crypto_skcipher_ctx (tfm );
180- return km_AesSetKeyCommon (ctx , in_key , key_len , WOLFKM_AESCBC_DRIVER );
181- }
182-
183204#endif /* LINUXKM_LKCAPI_REGISTER_ALL ||
184205 * LINUXKM_LKCAPI_REGISTER_AESCBC ||
185206 * LINUXKM_LKCAPI_REGISTER_AESCFB
@@ -192,6 +213,19 @@ static int km_AesSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
192213#if defined(HAVE_AES_CBC ) && \
193214 (defined(LINUXKM_LKCAPI_REGISTER_ALL ) || defined(LINUXKM_LKCAPI_REGISTER_AESCBC ))
194215
216+ static int km_AesCbcInit (struct crypto_skcipher * tfm )
217+ {
218+ struct km_AesCtx * ctx = crypto_skcipher_ctx (tfm );
219+ return km_AesInitCommon (ctx , WOLFKM_AESCBC_DRIVER , 1 );
220+ }
221+
222+ static int km_AesCbcSetKey (struct crypto_skcipher * tfm , const u8 * in_key ,
223+ unsigned int key_len )
224+ {
225+ struct km_AesCtx * ctx = crypto_skcipher_ctx (tfm );
226+ return km_AesSetKeyCommon (ctx , in_key , key_len , WOLFKM_AESCBC_DRIVER );
227+ }
228+
195229static int km_AesCbcEncrypt (struct skcipher_request * req )
196230{
197231 struct crypto_skcipher * tfm = NULL ;
@@ -206,15 +240,14 @@ static int km_AesCbcEncrypt(struct skcipher_request *req)
206240 err = skcipher_walk_virt (& walk , req , false);
207241
208242 while ((nbytes = walk .nbytes )) {
209- err = wc_AesSetKey (ctx -> aes , ctx -> key , ctx -> keylen , walk .iv ,
210- AES_ENCRYPTION );
243+ err = wc_AesSetIV (ctx -> aes_encrypt , walk .iv );
211244
212245 if (unlikely (err )) {
213- pr_err ("wc_AesSetKey failed: %d\n" , err );
246+ pr_err ("wc_AesSetIV failed: %d\n" , err );
214247 return err ;
215248 }
216249
217- err = wc_AesCbcEncrypt (ctx -> aes , walk .dst .virt .addr ,
250+ err = wc_AesCbcEncrypt (ctx -> aes_encrypt , walk .dst .virt .addr ,
218251 walk .src .virt .addr , nbytes );
219252
220253 if (unlikely (err )) {
@@ -242,15 +275,14 @@ static int km_AesCbcDecrypt(struct skcipher_request *req)
242275 err = skcipher_walk_virt (& walk , req , false);
243276
244277 while ((nbytes = walk .nbytes )) {
245- err = wc_AesSetKey (ctx -> aes , ctx -> key , ctx -> keylen , walk .iv ,
246- AES_DECRYPTION );
278+ err = wc_AesSetIV (ctx -> aes_decrypt , walk .iv );
247279
248280 if (unlikely (err )) {
249281 pr_err ("wc_AesSetKey failed" );
250282 return err ;
251283 }
252284
253- err = wc_AesCbcDecrypt (ctx -> aes , walk .dst .virt .addr ,
285+ err = wc_AesCbcDecrypt (ctx -> aes_decrypt , walk .dst .virt .addr ,
254286 walk .src .virt .addr , nbytes );
255287
256288 if (unlikely (err )) {
@@ -271,12 +303,12 @@ static struct skcipher_alg cbcAesAlg = {
271303 .base .cra_blocksize = AES_BLOCK_SIZE ,
272304 .base .cra_ctxsize = sizeof (struct km_AesCtx ),
273305 .base .cra_module = THIS_MODULE ,
274- .init = km_AesInit ,
306+ .init = km_AesCbcInit ,
275307 .exit = km_AesExit ,
276308 .min_keysize = AES_128_KEY_SIZE ,
277309 .max_keysize = AES_256_KEY_SIZE ,
278310 .ivsize = AES_BLOCK_SIZE ,
279- .setkey = km_AesSetKey ,
311+ .setkey = km_AesCbcSetKey ,
280312 .encrypt = km_AesCbcEncrypt ,
281313 .decrypt = km_AesCbcDecrypt ,
282314};
@@ -289,6 +321,19 @@ static int cbcAesAlg_loaded = 0;
289321#if defined(WOLFSSL_AES_CFB ) && \
290322 (defined(LINUXKM_LKCAPI_REGISTER_ALL ) || defined(LINUXKM_LKCAPI_REGISTER_AESCFB ))
291323
324+ static int km_AesCfbInit (struct crypto_skcipher * tfm )
325+ {
326+ struct km_AesCtx * ctx = crypto_skcipher_ctx (tfm );
327+ return km_AesInitCommon (ctx , WOLFKM_AESCFB_DRIVER , 0 );
328+ }
329+
330+ static int km_AesCfbSetKey (struct crypto_skcipher * tfm , const u8 * in_key ,
331+ unsigned int key_len )
332+ {
333+ struct km_AesCtx * ctx = crypto_skcipher_ctx (tfm );
334+ return km_AesSetKeyCommon (ctx , in_key , key_len , WOLFKM_AESCFB_DRIVER );
335+ }
336+
292337static int km_AesCfbEncrypt (struct skcipher_request * req )
293338{
294339 struct crypto_skcipher * tfm = NULL ;
@@ -303,15 +348,14 @@ static int km_AesCfbEncrypt(struct skcipher_request *req)
303348 err = skcipher_walk_virt (& walk , req , false);
304349
305350 while ((nbytes = walk .nbytes )) {
306- err = wc_AesSetKey (ctx -> aes , ctx -> key , ctx -> keylen , walk .iv ,
307- AES_ENCRYPTION );
351+ err = wc_AesSetIV (ctx -> aes_encrypt , walk .iv );
308352
309353 if (unlikely (err )) {
310354 pr_err ("wc_AesSetKey failed: %d\n" , err );
311355 return err ;
312356 }
313357
314- err = wc_AesCfbEncrypt (ctx -> aes , walk .dst .virt .addr ,
358+ err = wc_AesCfbEncrypt (ctx -> aes_encrypt , walk .dst .virt .addr ,
315359 walk .src .virt .addr , nbytes );
316360
317361 if (unlikely (err )) {
@@ -339,15 +383,14 @@ static int km_AesCfbDecrypt(struct skcipher_request *req)
339383 err = skcipher_walk_virt (& walk , req , false);
340384
341385 while ((nbytes = walk .nbytes )) {
342- err = wc_AesSetKey (ctx -> aes , ctx -> key , ctx -> keylen , walk .iv ,
343- AES_ENCRYPTION );
386+ err = wc_AesSetIV (ctx -> aes_encrypt , walk .iv );
344387
345388 if (unlikely (err )) {
346389 pr_err ("wc_AesSetKey failed" );
347390 return err ;
348391 }
349392
350- err = wc_AesCfbDecrypt (ctx -> aes , walk .dst .virt .addr ,
393+ err = wc_AesCfbDecrypt (ctx -> aes_encrypt , walk .dst .virt .addr ,
351394 walk .src .virt .addr , nbytes );
352395
353396 if (unlikely (err )) {
@@ -368,12 +411,12 @@ static struct skcipher_alg cfbAesAlg = {
368411 .base .cra_blocksize = AES_BLOCK_SIZE ,
369412 .base .cra_ctxsize = sizeof (struct km_AesCtx ),
370413 .base .cra_module = THIS_MODULE ,
371- .init = km_AesInit ,
414+ .init = km_AesCfbInit ,
372415 .exit = km_AesExit ,
373416 .min_keysize = AES_128_KEY_SIZE ,
374417 .max_keysize = AES_256_KEY_SIZE ,
375418 .ivsize = AES_BLOCK_SIZE ,
376- .setkey = km_AesSetKey ,
419+ .setkey = km_AesCfbSetKey ,
377420 .encrypt = km_AesCfbEncrypt ,
378421 .decrypt = km_AesCfbDecrypt ,
379422};
@@ -390,8 +433,7 @@ static int cfbAesAlg_loaded = 0;
390433static int km_AesGcmInit (struct crypto_aead * tfm )
391434{
392435 struct km_AesCtx * ctx = crypto_aead_ctx (tfm );
393- km_ForceZero (ctx );
394- return km_AesInitCommon (ctx , WOLFKM_AESGCM_DRIVER );
436+ return km_AesInitCommon (ctx , WOLFKM_AESGCM_DRIVER , 0 );
395437}
396438
397439static void km_AesGcmExit (struct crypto_aead * tfm )
@@ -403,8 +445,16 @@ static void km_AesGcmExit(struct crypto_aead * tfm)
403445static int km_AesGcmSetKey (struct crypto_aead * tfm , const u8 * in_key ,
404446 unsigned int key_len )
405447{
448+ int err ;
406449 struct km_AesCtx * ctx = crypto_aead_ctx (tfm );
407- return km_AesSetKeyCommon (ctx , in_key , key_len , WOLFKM_AESGCM_DRIVER );
450+
451+ err = wc_AesGcmSetKey (ctx -> aes_encrypt , in_key , key_len );
452+
453+ if (err ) {
454+ pr_err ("error: km_AesGcmSetKey %s failed: %d\n" , WOLFKM_AESGCM_DRIVER , err );
455+ }
456+
457+ return err ;
408458}
409459
410460static int km_AesGcmSetAuthsize (struct crypto_aead * tfm , unsigned int authsize )
@@ -454,7 +504,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
454504 return -1 ;
455505 }
456506
457- err = wc_AesGcmInit (ctx -> aes , ctx -> key , ctx -> keylen , walk .iv ,
507+ err = wc_AesGcmInit (ctx -> aes_encrypt , NULL /* key */ , 0 /* keylen */ , walk .iv ,
458508 AES_BLOCK_SIZE );
459509 if (unlikely (err )) {
460510 pr_err ("error: wc_AesGcmInit failed with return code %d.\n" , err );
@@ -467,7 +517,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
467517 return err ;
468518 }
469519
470- err = wc_AesGcmEncryptUpdate (ctx -> aes , NULL , NULL , 0 , assoc , assocLeft );
520+ err = wc_AesGcmEncryptUpdate (ctx -> aes_encrypt , NULL , NULL , 0 , assoc , assocLeft );
471521 assocLeft -= assocLeft ;
472522 scatterwalk_unmap (assoc );
473523 assoc = NULL ;
@@ -483,7 +533,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
483533 if (likely (cryptLeft && nbytes )) {
484534 n = cryptLeft < nbytes ? cryptLeft : nbytes ;
485535
486- err = wc_AesGcmEncryptUpdate (ctx -> aes , walk .dst .virt .addr ,
536+ err = wc_AesGcmEncryptUpdate (ctx -> aes_encrypt , walk .dst .virt .addr ,
487537 walk .src .virt .addr , cryptLeft , NULL , 0 );
488538 nbytes -= n ;
489539 cryptLeft -= n ;
@@ -497,7 +547,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
497547 err = skcipher_walk_done (& walk , nbytes );
498548 }
499549
500- err = wc_AesGcmEncryptFinal (ctx -> aes , authTag , tfm -> authsize );
550+ err = wc_AesGcmEncryptFinal (ctx -> aes_encrypt , authTag , tfm -> authsize );
501551 if (unlikely (err )) {
502552 pr_err ("error: wc_AesGcmEncryptFinal failed with return code %d\n" , err );
503553 return err ;
@@ -542,7 +592,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
542592 return -1 ;
543593 }
544594
545- err = wc_AesGcmInit (ctx -> aes , ctx -> key , ctx -> keylen , walk .iv ,
595+ err = wc_AesGcmInit (ctx -> aes_encrypt , NULL /* key */ , 0 /* keylen */ , walk .iv ,
546596 AES_BLOCK_SIZE );
547597 if (unlikely (err )) {
548598 pr_err ("error: wc_AesGcmInit failed with return code %d.\n" , err );
@@ -555,7 +605,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
555605 return err ;
556606 }
557607
558- err = wc_AesGcmDecryptUpdate (ctx -> aes , NULL , NULL , 0 , assoc , assocLeft );
608+ err = wc_AesGcmDecryptUpdate (ctx -> aes_encrypt , NULL , NULL , 0 , assoc , assocLeft );
559609 assocLeft -= assocLeft ;
560610 scatterwalk_unmap (assoc );
561611 assoc = NULL ;
@@ -571,7 +621,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
571621 if (likely (cryptLeft && nbytes )) {
572622 n = cryptLeft < nbytes ? cryptLeft : nbytes ;
573623
574- err = wc_AesGcmDecryptUpdate (ctx -> aes , walk .dst .virt .addr ,
624+ err = wc_AesGcmDecryptUpdate (ctx -> aes_encrypt , walk .dst .virt .addr ,
575625 walk .src .virt .addr , cryptLeft , NULL , 0 );
576626 nbytes -= n ;
577627 cryptLeft -= n ;
@@ -585,7 +635,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
585635 err = skcipher_walk_done (& walk , nbytes );
586636 }
587637
588- err = wc_AesGcmDecryptFinal (ctx -> aes , origAuthTag , tfm -> authsize );
638+ err = wc_AesGcmDecryptFinal (ctx -> aes_encrypt , origAuthTag , tfm -> authsize );
589639 if (unlikely (err )) {
590640 pr_err ("error: wc_AesGcmDecryptFinal failed with return code %d\n" , err );
591641
0 commit comments