Skip to content

Commit abc6edf

Browse files
authored
Merge pull request #7796 from SparkiDev/dtls_read_write_threaded
SSL asynchronous read/write and encrypt
2 parents 8803f3d + e4a661f commit abc6edf

8 files changed

Lines changed: 664 additions & 158 deletions

File tree

src/dtls13.c

Lines changed: 93 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -341,9 +341,17 @@ static void Dtls13MsgWasProcessed(WOLFSSL* ssl, enum HandShakeType hs)
341341
if (ssl->options.dtlsStateful)
342342
ssl->keys.dtls_expected_peer_handshake_number++;
343343

344-
/* we need to send ACKs on the last message of a flight that needs explicit
345-
acknowledgment */
346-
ssl->dtls13Rtx.sendAcks = Dtls13RtxMsgNeedsAck(ssl, hs);
344+
#ifdef WOLFSSL_RW_THREADED
345+
if (wc_LockMutex(&ssl->dtls13Rtx.mutex) == 0)
346+
#endif
347+
{
348+
/* we need to send ACKs on the last message of a flight that needs
349+
* explicit acknowledgment */
350+
ssl->dtls13Rtx.sendAcks = Dtls13RtxMsgNeedsAck(ssl, hs);
351+
#ifdef WOLFSSL_RW_THREADED
352+
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
353+
#endif
354+
}
347355
}
348356

349357
int Dtls13ProcessBufferedMessages(WOLFSSL* ssl)
@@ -654,8 +662,17 @@ static void Dtls13RtxRecordUnlink(WOLFSSL* ssl, Dtls13RtxRecord** prevNext,
654662
Dtls13RtxRecord* r)
655663
{
656664
/* if r was at the tail of the list, update the tail pointer */
657-
if (r->next == NULL)
658-
ssl->dtls13Rtx.rtxRecordTailPtr = prevNext;
665+
if (r->next == NULL) {
666+
#ifdef WOLFSSL_RW_THREADED
667+
if (wc_LockMutex(&ssl->dtls13Rtx.mutex) == 0)
668+
#endif
669+
{
670+
ssl->dtls13Rtx.rtxRecordTailPtr = prevNext;
671+
#ifdef WOLFSSL_RW_THREADED
672+
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
673+
#endif
674+
}
675+
}
659676

660677
/* unlink */
661678
*prevNext = r->next;
@@ -712,12 +729,20 @@ static int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq)
712729

713730
WOLFSSL_ENTER("Dtls13RtxAddAck");
714731

715-
rn = Dtls13NewRecordNumber(epoch, seq, ssl->heap);
716-
if (rn == NULL)
717-
return MEMORY_E;
732+
#ifdef WOLFSSL_RW_THREADED
733+
if (wc_LockMutex(&ssl->dtls13Rtx.mutex) == 0)
734+
#endif
735+
{
736+
rn = Dtls13NewRecordNumber(epoch, seq, ssl->heap);
737+
if (rn == NULL)
738+
return MEMORY_E;
718739

719-
rn->next = ssl->dtls13Rtx.seenRecords;
720-
ssl->dtls13Rtx.seenRecords = rn;
740+
rn->next = ssl->dtls13Rtx.seenRecords;
741+
ssl->dtls13Rtx.seenRecords = rn;
742+
#ifdef WOLFSSL_RW_THREADED
743+
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
744+
#endif
745+
}
721746

722747
return 0;
723748
}
@@ -730,15 +755,23 @@ static void Dtls13RtxFlushAcks(WOLFSSL* ssl)
730755

731756
WOLFSSL_ENTER("Dtls13RtxFlushAcks");
732757

733-
list = ssl->dtls13Rtx.seenRecords;
758+
#ifdef WOLFSSL_RW_THREADED
759+
if (wc_LockMutex(&ssl->dtls13Rtx.mutex) == 0)
760+
#endif
761+
{
762+
list = ssl->dtls13Rtx.seenRecords;
734763

735-
while (list != NULL) {
736-
rn = list;
737-
list = rn->next;
738-
XFREE(rn, ssl->heap, DYNAMIC_TYPE_DTLS_MSG);
739-
}
764+
while (list != NULL) {
765+
rn = list;
766+
list = rn->next;
767+
XFREE(rn, ssl->heap, DYNAMIC_TYPE_DTLS_MSG);
768+
}
740769

741-
ssl->dtls13Rtx.seenRecords = NULL;
770+
ssl->dtls13Rtx.seenRecords = NULL;
771+
#ifdef WOLFSSL_RW_THREADED
772+
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
773+
#endif
774+
}
742775
}
743776

744777
static int Dtls13DetectDisruption(WOLFSSL* ssl, word32 fragOffset)
@@ -2519,13 +2552,25 @@ static void Dtls13RtxRemoveRecord(WOLFSSL* ssl, w64wrapper epoch,
25192552
int Dtls13DoScheduledWork(WOLFSSL* ssl)
25202553
{
25212554
int ret;
2555+
int sendAcks;
25222556

25232557
WOLFSSL_ENTER("Dtls13DoScheduledWork");
25242558

25252559
ssl->dtls13SendingAckOrRtx = 1;
25262560

2527-
if (ssl->dtls13Rtx.sendAcks) {
2561+
#ifdef WOLFSSL_RW_THREADED
2562+
ret = wc_LockMutex(&ssl->dtls13Rtx.mutex);
2563+
if (ret < 0)
2564+
return ret;
2565+
#endif
2566+
sendAcks = ssl->dtls13Rtx.sendAcks;
2567+
if (sendAcks) {
25282568
ssl->dtls13Rtx.sendAcks = 0;
2569+
}
2570+
#ifdef WOLFSSL_RW_THREADED
2571+
ret = wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
2572+
#endif
2573+
if (sendAcks) {
25292574
ret = SendDtls13Ack(ssl);
25302575
if (ret != 0)
25312576
return ret;
@@ -2601,13 +2646,28 @@ static int Dtls13RtxHasKeyUpdateBuffered(WOLFSSL* ssl)
26012646
return 0;
26022647
}
26032648

2649+
int DoDtls13KeyUpdateAck(WOLFSSL* ssl)
2650+
{
2651+
int ret = 0;
2652+
2653+
if (!Dtls13RtxHasKeyUpdateBuffered(ssl)) {
2654+
/* we removed the KeyUpdate message because it was ACKed */
2655+
ssl->dtls13WaitKeyUpdateAck = 0;
2656+
ret = Dtls13KeyUpdateAckReceived(ssl);
2657+
}
2658+
2659+
return ret;
2660+
}
2661+
26042662
int DoDtls13Ack(WOLFSSL* ssl, const byte* input, word32 inputSize,
26052663
word32* processedSize)
26062664
{
26072665
const byte* ackMessage;
26082666
w64wrapper epoch, seq;
26092667
word16 length;
2668+
#ifndef WOLFSSL_RW_THREADED
26102669
int ret;
2670+
#endif
26112671
int i;
26122672

26132673
if (inputSize < OPAQUE16_LEN)
@@ -2639,15 +2699,13 @@ int DoDtls13Ack(WOLFSSL* ssl, const byte* input, word32 inputSize,
26392699
ssl->options.serverState = SERVER_FINISHED_ACKED;
26402700
}
26412701

2702+
#ifndef WOLFSSL_RW_THREADED
26422703
if (ssl->dtls13WaitKeyUpdateAck) {
2643-
if (!Dtls13RtxHasKeyUpdateBuffered(ssl)) {
2644-
/* we removed the KeyUpdate message because it was ACKed */
2645-
ssl->dtls13WaitKeyUpdateAck = 0;
2646-
ret = Dtls13KeyUpdateAckReceived(ssl);
2647-
if (ret != 0)
2648-
return ret;
2649-
}
2704+
ret = DoDtls13KeyUpdateAck(ssl);
2705+
if (ret != 0)
2706+
return ret;
26502707
}
2708+
#endif
26512709

26522710
*processedSize = length + OPAQUE16_LEN;
26532711

@@ -2698,9 +2756,17 @@ int SendDtls13Ack(WOLFSSL* ssl)
26982756
if (ret != 0)
26992757
return ret;
27002758

2701-
ret = Dtls13WriteAckMessage(ssl, ssl->dtls13Rtx.seenRecords, &length);
2702-
if (ret != 0)
2759+
#ifdef WOLFSSL_RW_THREADED
2760+
ret = wc_LockMutex(&ssl->dtls13Rtx.mutex);
2761+
if (ret < 0)
27032762
return ret;
2763+
#endif
2764+
ret = Dtls13WriteAckMessage(ssl, ssl->dtls13Rtx.seenRecords, &length);
2765+
#ifdef WOLFSSL_RW_THREADED
2766+
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
2767+
#endif
2768+
if (ret != 0)
2769+
return ret;
27042770

27052771
output = GetOutputBuffer(ssl);
27062772

0 commit comments

Comments
 (0)