Skip to content

Commit 957fc74

Browse files
committed
linuxkm/lkcapi_glue.c: refactor AES-CBC, AES-CFB, and AES-GCM glue around struct km_AesCtx with separate aes_encrypt and aes_decrypt Aes pointers, and no cached key, to avoid AesSetKey operations at encrypt/decrypt time.
1 parent 8ae031a commit 957fc74

1 file changed

Lines changed: 111 additions & 61 deletions

File tree

linuxkm/lkcapi_glue.c

Lines changed: 111 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -96,35 +96,57 @@ static int linuxkm_test_aesxts(void);
9696
#include <wolfssl/wolfcrypt/aes.h>
9797

9898
struct km_AesCtx {
99-
Aes *aes; /* must be pointer to control alignment, needed for AESNI. */
100-
u8 key[AES_MAX_KEY_SIZE / 8];
101-
unsigned int keylen;
99+
Aes *aes_encrypt; /* must be pointer to control alignment, needed for AESNI. */
100+
Aes *aes_decrypt; /* same. */
102101
};
103102

104-
static inline void km_ForceZero(struct km_AesCtx * ctx)
105-
{
106-
memzero_explicit(ctx->key, sizeof(ctx->key));
107-
ctx->keylen = 0;
108-
}
109-
110103
#if defined(LINUXKM_LKCAPI_REGISTER_ALL) || \
111104
defined(LINUXKM_LKCAPI_REGISTER_AESCBC) || \
112105
defined(LINUXKM_LKCAPI_REGISTER_AESCFB) || \
113106
defined(LINUXKM_LKCAPI_REGISTER_AESGCM)
114107

115-
static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name)
108+
static void km_AesExitCommon(struct km_AesCtx * ctx);
109+
110+
static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name, int need_decryption)
116111
{
117112
int err;
118113

119-
ctx->aes = (Aes *)malloc(sizeof(*ctx->aes));
114+
ctx->aes_encrypt = (Aes *)malloc(sizeof(*ctx->aes_encrypt));
120115

121-
if (! ctx->aes)
116+
if (! ctx->aes_encrypt) {
117+
pr_err("error: km_AesInitCommon %s failed: %d\n", name, MEMORY_E);
122118
return MEMORY_E;
119+
}
123120

124-
err = wc_AesInit(ctx->aes, NULL, INVALID_DEVID);
121+
err = wc_AesInit(ctx->aes_encrypt, NULL, INVALID_DEVID);
125122

126123
if (unlikely(err)) {
127124
pr_err("error: km_AesInitCommon %s failed: %d\n", name, err);
125+
free(ctx->aes_encrypt);
126+
ctx->aes_encrypt = NULL;
127+
return err;
128+
}
129+
130+
if (! need_decryption) {
131+
ctx->aes_decrypt = NULL;
132+
return 0;
133+
}
134+
135+
ctx->aes_decrypt = (Aes *)malloc(sizeof(*ctx->aes_decrypt));
136+
137+
if (! ctx->aes_encrypt) {
138+
pr_err("error: km_AesInitCommon %s failed: %d\n", name, MEMORY_E);
139+
km_AesExitCommon(ctx);
140+
return MEMORY_E;
141+
}
142+
143+
err = wc_AesInit(ctx->aes_decrypt, NULL, INVALID_DEVID);
144+
145+
if (unlikely(err)) {
146+
pr_err("error: km_AesInitCommon %s failed: %d\n", name, err);
147+
free(ctx->aes_decrypt);
148+
ctx->aes_decrypt = NULL;
149+
km_AesExitCommon(ctx);
128150
return err;
129151
}
130152

@@ -133,26 +155,38 @@ static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name)
133155

134156
static void km_AesExitCommon(struct km_AesCtx * ctx)
135157
{
136-
wc_AesFree(ctx->aes);
137-
free(ctx->aes);
138-
ctx->aes = NULL;
139-
km_ForceZero(ctx);
158+
if (ctx->aes_encrypt) {
159+
wc_AesFree(ctx->aes_encrypt);
160+
free(ctx->aes_encrypt);
161+
ctx->aes_encrypt = NULL;
162+
}
163+
if (ctx->aes_decrypt) {
164+
wc_AesFree(ctx->aes_decrypt);
165+
free(ctx->aes_decrypt);
166+
ctx->aes_decrypt = NULL;
167+
}
140168
}
141169

142170
static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key,
143171
unsigned int key_len, const char * name)
144172
{
145173
int err;
146174

147-
err = wc_AesSetKey(ctx->aes, in_key, key_len, NULL, 0);
175+
err = wc_AesSetKey(ctx->aes_encrypt, in_key, key_len, NULL, AES_ENCRYPTION);
148176

149177
if (unlikely(err)) {
150178
pr_err("error: km_AesSetKeyCommon %s failed: %d\n", name, err);
151179
return err;
152180
}
153181

154-
XMEMCPY(ctx->key, in_key, key_len);
155-
ctx->keylen = key_len;
182+
if (ctx->aes_decrypt) {
183+
err = wc_AesSetKey(ctx->aes_decrypt, in_key, key_len, NULL, AES_DECRYPTION);
184+
185+
if (unlikely(err)) {
186+
pr_err("error: km_AesSetKeyCommon %s failed: %d\n", name, err);
187+
return err;
188+
}
189+
}
156190

157191
return 0;
158192
}
@@ -161,25 +195,12 @@ static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key,
161195
defined(LINUXKM_LKCAPI_REGISTER_AESCBC) || \
162196
defined(LINUXKM_LKCAPI_REGISTER_AESCFB)
163197

164-
static int km_AesInit(struct crypto_skcipher *tfm)
165-
{
166-
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
167-
return km_AesInitCommon(ctx, WOLFKM_AESCBC_DRIVER);
168-
}
169-
170198
static void km_AesExit(struct crypto_skcipher *tfm)
171199
{
172200
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
173201
km_AesExitCommon(ctx);
174202
}
175203

176-
static int km_AesSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
177-
unsigned int key_len)
178-
{
179-
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
180-
return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESCBC_DRIVER);
181-
}
182-
183204
#endif /* LINUXKM_LKCAPI_REGISTER_ALL ||
184205
* LINUXKM_LKCAPI_REGISTER_AESCBC ||
185206
* LINUXKM_LKCAPI_REGISTER_AESCFB
@@ -192,6 +213,19 @@ static int km_AesSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
192213
#if defined(HAVE_AES_CBC) && \
193214
(defined(LINUXKM_LKCAPI_REGISTER_ALL) || defined(LINUXKM_LKCAPI_REGISTER_AESCBC))
194215

216+
static int km_AesCbcInit(struct crypto_skcipher *tfm)
217+
{
218+
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
219+
return km_AesInitCommon(ctx, WOLFKM_AESCBC_DRIVER, 1);
220+
}
221+
222+
static int km_AesCbcSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
223+
unsigned int key_len)
224+
{
225+
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
226+
return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESCBC_DRIVER);
227+
}
228+
195229
static int km_AesCbcEncrypt(struct skcipher_request *req)
196230
{
197231
struct crypto_skcipher * tfm = NULL;
@@ -206,15 +240,14 @@ static int km_AesCbcEncrypt(struct skcipher_request *req)
206240
err = skcipher_walk_virt(&walk, req, false);
207241

208242
while ((nbytes = walk.nbytes)) {
209-
err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv,
210-
AES_ENCRYPTION);
243+
err = wc_AesSetIV(ctx->aes_encrypt, walk.iv);
211244

212245
if (unlikely(err)) {
213-
pr_err("wc_AesSetKey failed: %d\n", err);
246+
pr_err("wc_AesSetIV failed: %d\n", err);
214247
return err;
215248
}
216249

217-
err = wc_AesCbcEncrypt(ctx->aes, walk.dst.virt.addr,
250+
err = wc_AesCbcEncrypt(ctx->aes_encrypt, walk.dst.virt.addr,
218251
walk.src.virt.addr, nbytes);
219252

220253
if (unlikely(err)) {
@@ -242,15 +275,14 @@ static int km_AesCbcDecrypt(struct skcipher_request *req)
242275
err = skcipher_walk_virt(&walk, req, false);
243276

244277
while ((nbytes = walk.nbytes)) {
245-
err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv,
246-
AES_DECRYPTION);
278+
err = wc_AesSetIV(ctx->aes_decrypt, walk.iv);
247279

248280
if (unlikely(err)) {
249281
pr_err("wc_AesSetKey failed");
250282
return err;
251283
}
252284

253-
err = wc_AesCbcDecrypt(ctx->aes, walk.dst.virt.addr,
285+
err = wc_AesCbcDecrypt(ctx->aes_decrypt, walk.dst.virt.addr,
254286
walk.src.virt.addr, nbytes);
255287

256288
if (unlikely(err)) {
@@ -271,12 +303,12 @@ static struct skcipher_alg cbcAesAlg = {
271303
.base.cra_blocksize = AES_BLOCK_SIZE,
272304
.base.cra_ctxsize = sizeof(struct km_AesCtx),
273305
.base.cra_module = THIS_MODULE,
274-
.init = km_AesInit,
306+
.init = km_AesCbcInit,
275307
.exit = km_AesExit,
276308
.min_keysize = AES_128_KEY_SIZE,
277309
.max_keysize = AES_256_KEY_SIZE,
278310
.ivsize = AES_BLOCK_SIZE,
279-
.setkey = km_AesSetKey,
311+
.setkey = km_AesCbcSetKey,
280312
.encrypt = km_AesCbcEncrypt,
281313
.decrypt = km_AesCbcDecrypt,
282314
};
@@ -289,6 +321,19 @@ static int cbcAesAlg_loaded = 0;
289321
#if defined(WOLFSSL_AES_CFB) && \
290322
(defined(LINUXKM_LKCAPI_REGISTER_ALL) || defined(LINUXKM_LKCAPI_REGISTER_AESCFB))
291323

324+
static int km_AesCfbInit(struct crypto_skcipher *tfm)
325+
{
326+
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
327+
return km_AesInitCommon(ctx, WOLFKM_AESCFB_DRIVER, 0);
328+
}
329+
330+
static int km_AesCfbSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
331+
unsigned int key_len)
332+
{
333+
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
334+
return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESCFB_DRIVER);
335+
}
336+
292337
static int km_AesCfbEncrypt(struct skcipher_request *req)
293338
{
294339
struct crypto_skcipher * tfm = NULL;
@@ -303,15 +348,14 @@ static int km_AesCfbEncrypt(struct skcipher_request *req)
303348
err = skcipher_walk_virt(&walk, req, false);
304349

305350
while ((nbytes = walk.nbytes)) {
306-
err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv,
307-
AES_ENCRYPTION);
351+
err = wc_AesSetIV(ctx->aes_encrypt, walk.iv);
308352

309353
if (unlikely(err)) {
310354
pr_err("wc_AesSetKey failed: %d\n", err);
311355
return err;
312356
}
313357

314-
err = wc_AesCfbEncrypt(ctx->aes, walk.dst.virt.addr,
358+
err = wc_AesCfbEncrypt(ctx->aes_encrypt, walk.dst.virt.addr,
315359
walk.src.virt.addr, nbytes);
316360

317361
if (unlikely(err)) {
@@ -339,15 +383,14 @@ static int km_AesCfbDecrypt(struct skcipher_request *req)
339383
err = skcipher_walk_virt(&walk, req, false);
340384

341385
while ((nbytes = walk.nbytes)) {
342-
err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv,
343-
AES_ENCRYPTION);
386+
err = wc_AesSetIV(ctx->aes_encrypt, walk.iv);
344387

345388
if (unlikely(err)) {
346389
pr_err("wc_AesSetKey failed");
347390
return err;
348391
}
349392

350-
err = wc_AesCfbDecrypt(ctx->aes, walk.dst.virt.addr,
393+
err = wc_AesCfbDecrypt(ctx->aes_encrypt, walk.dst.virt.addr,
351394
walk.src.virt.addr, nbytes);
352395

353396
if (unlikely(err)) {
@@ -368,12 +411,12 @@ static struct skcipher_alg cfbAesAlg = {
368411
.base.cra_blocksize = AES_BLOCK_SIZE,
369412
.base.cra_ctxsize = sizeof(struct km_AesCtx),
370413
.base.cra_module = THIS_MODULE,
371-
.init = km_AesInit,
414+
.init = km_AesCfbInit,
372415
.exit = km_AesExit,
373416
.min_keysize = AES_128_KEY_SIZE,
374417
.max_keysize = AES_256_KEY_SIZE,
375418
.ivsize = AES_BLOCK_SIZE,
376-
.setkey = km_AesSetKey,
419+
.setkey = km_AesCfbSetKey,
377420
.encrypt = km_AesCfbEncrypt,
378421
.decrypt = km_AesCfbDecrypt,
379422
};
@@ -390,8 +433,7 @@ static int cfbAesAlg_loaded = 0;
390433
static int km_AesGcmInit(struct crypto_aead * tfm)
391434
{
392435
struct km_AesCtx * ctx = crypto_aead_ctx(tfm);
393-
km_ForceZero(ctx);
394-
return km_AesInitCommon(ctx, WOLFKM_AESGCM_DRIVER);
436+
return km_AesInitCommon(ctx, WOLFKM_AESGCM_DRIVER, 0);
395437
}
396438

397439
static void km_AesGcmExit(struct crypto_aead * tfm)
@@ -403,8 +445,16 @@ static void km_AesGcmExit(struct crypto_aead * tfm)
403445
static int km_AesGcmSetKey(struct crypto_aead *tfm, const u8 *in_key,
404446
unsigned int key_len)
405447
{
448+
int err;
406449
struct km_AesCtx * ctx = crypto_aead_ctx(tfm);
407-
return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESGCM_DRIVER);
450+
451+
err = wc_AesGcmSetKey(ctx->aes_encrypt, in_key, key_len);
452+
453+
if (err) {
454+
pr_err("error: km_AesGcmSetKey %s failed: %d\n", WOLFKM_AESGCM_DRIVER, err);
455+
}
456+
457+
return err;
408458
}
409459

410460
static int km_AesGcmSetAuthsize(struct crypto_aead *tfm, unsigned int authsize)
@@ -454,7 +504,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
454504
return -1;
455505
}
456506

457-
err = wc_AesGcmInit(ctx->aes, ctx->key, ctx->keylen, walk.iv,
507+
err = wc_AesGcmInit(ctx->aes_encrypt, NULL /* key */, 0 /* keylen */, walk.iv,
458508
AES_BLOCK_SIZE);
459509
if (unlikely(err)) {
460510
pr_err("error: wc_AesGcmInit failed with return code %d.\n", err);
@@ -467,7 +517,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
467517
return err;
468518
}
469519

470-
err = wc_AesGcmEncryptUpdate(ctx->aes, NULL, NULL, 0, assoc, assocLeft);
520+
err = wc_AesGcmEncryptUpdate(ctx->aes_encrypt, NULL, NULL, 0, assoc, assocLeft);
471521
assocLeft -= assocLeft;
472522
scatterwalk_unmap(assoc);
473523
assoc = NULL;
@@ -483,7 +533,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
483533
if (likely(cryptLeft && nbytes)) {
484534
n = cryptLeft < nbytes ? cryptLeft : nbytes;
485535

486-
err = wc_AesGcmEncryptUpdate(ctx->aes, walk.dst.virt.addr,
536+
err = wc_AesGcmEncryptUpdate(ctx->aes_encrypt, walk.dst.virt.addr,
487537
walk.src.virt.addr, cryptLeft, NULL, 0);
488538
nbytes -= n;
489539
cryptLeft -= n;
@@ -497,7 +547,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
497547
err = skcipher_walk_done(&walk, nbytes);
498548
}
499549

500-
err = wc_AesGcmEncryptFinal(ctx->aes, authTag, tfm->authsize);
550+
err = wc_AesGcmEncryptFinal(ctx->aes_encrypt, authTag, tfm->authsize);
501551
if (unlikely(err)) {
502552
pr_err("error: wc_AesGcmEncryptFinal failed with return code %d\n", err);
503553
return err;
@@ -542,7 +592,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
542592
return -1;
543593
}
544594

545-
err = wc_AesGcmInit(ctx->aes, ctx->key, ctx->keylen, walk.iv,
595+
err = wc_AesGcmInit(ctx->aes_encrypt, NULL /* key */, 0 /* keylen */, walk.iv,
546596
AES_BLOCK_SIZE);
547597
if (unlikely(err)) {
548598
pr_err("error: wc_AesGcmInit failed with return code %d.\n", err);
@@ -555,7 +605,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
555605
return err;
556606
}
557607

558-
err = wc_AesGcmDecryptUpdate(ctx->aes, NULL, NULL, 0, assoc, assocLeft);
608+
err = wc_AesGcmDecryptUpdate(ctx->aes_encrypt, NULL, NULL, 0, assoc, assocLeft);
559609
assocLeft -= assocLeft;
560610
scatterwalk_unmap(assoc);
561611
assoc = NULL;
@@ -571,7 +621,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
571621
if (likely(cryptLeft && nbytes)) {
572622
n = cryptLeft < nbytes ? cryptLeft : nbytes;
573623

574-
err = wc_AesGcmDecryptUpdate(ctx->aes, walk.dst.virt.addr,
624+
err = wc_AesGcmDecryptUpdate(ctx->aes_encrypt, walk.dst.virt.addr,
575625
walk.src.virt.addr, cryptLeft, NULL, 0);
576626
nbytes -= n;
577627
cryptLeft -= n;
@@ -585,7 +635,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
585635
err = skcipher_walk_done(&walk, nbytes);
586636
}
587637

588-
err = wc_AesGcmDecryptFinal(ctx->aes, origAuthTag, tfm->authsize);
638+
err = wc_AesGcmDecryptFinal(ctx->aes_encrypt, origAuthTag, tfm->authsize);
589639
if (unlikely(err)) {
590640
pr_err("error: wc_AesGcmDecryptFinal failed with return code %d\n", err);
591641

0 commit comments

Comments
 (0)