Skip to content

Commit 570c1fc

Browse files
Merge pull request #8824 from JeremiahM37/tlsCurveFix
tls fix for set_groups
2 parents bfc55d9 + edfc536 commit 570c1fc

4 files changed

Lines changed: 189 additions & 37 deletions

File tree

src/tls.c

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5048,7 +5048,8 @@ int TLSX_SupportedCurve_Parse(const WOLFSSL* ssl, const byte* input,
50485048
{
50495049
word16 offset;
50505050
word16 name;
5051-
int ret;
5051+
int ret = 0;
5052+
TLSX* extension;
50525053

50535054
if(!isRequest && !IsAtLeastTLSv1_3(ssl->version)) {
50545055
#ifdef WOLFSSL_ALLOW_SERVER_SC_EXT
@@ -5057,57 +5058,66 @@ int TLSX_SupportedCurve_Parse(const WOLFSSL* ssl, const byte* input,
50575058
return BUFFER_ERROR; /* servers doesn't send this extension. */
50585059
#endif
50595060
}
5060-
50615061
if (OPAQUE16_LEN > length || length % OPAQUE16_LEN)
50625062
return BUFFER_ERROR;
5063-
50645063
ato16(input, &offset);
5065-
50665064
/* validating curve list length */
50675065
if (length != OPAQUE16_LEN + offset)
50685066
return BUFFER_ERROR;
5069-
50705067
offset = OPAQUE16_LEN;
50715068
if (offset == length)
50725069
return 0;
50735070

5074-
#if defined(WOLFSSL_TLS13) && !defined(WOLFSSL_NO_SERVER_GROUPS_EXT)
5075-
if (!isRequest) {
5076-
TLSX* extension;
5077-
SupportedCurve* curve;
5078-
5079-
extension = TLSX_Find(*extensions, TLSX_SUPPORTED_GROUPS);
5080-
if (extension != NULL) {
5081-
/* Replace client list with server list of supported groups. */
5082-
curve = (SupportedCurve*)extension->data;
5083-
extension->data = NULL;
5084-
TLSX_SupportedCurve_FreeAll(curve, ssl->heap);
5085-
5071+
extension = TLSX_Find(*extensions, TLSX_SUPPORTED_GROUPS);
5072+
if (extension == NULL) {
5073+
/* Just accept what the peer wants to use */
5074+
for (; offset < length; offset += OPAQUE16_LEN) {
50865075
ato16(input + offset, &name);
5087-
offset += OPAQUE16_LEN;
50885076

5089-
ret = TLSX_SupportedCurve_New(&curve, name, ssl->heap);
5090-
if (ret != 0)
5091-
return ret; /* throw error */
5092-
extension->data = (void*)curve;
5077+
ret = TLSX_UseSupportedCurve(extensions, name, ssl->heap);
5078+
/* If it is BAD_FUNC_ARG then it is a group we do not support, but
5079+
* that is fine. */
5080+
if (ret != WOLFSSL_SUCCESS &&
5081+
ret != WC_NO_ERR_TRACE(BAD_FUNC_ARG))
5082+
break;
5083+
ret = 0;
50935084
}
50945085
}
5095-
#endif
5086+
else {
5087+
/* Find the intersection with what the user has set */
5088+
SupportedCurve* commonCurves = NULL;
5089+
for (; offset < length; offset += OPAQUE16_LEN) {
5090+
SupportedCurve* foundCurve = (SupportedCurve*)extension->data;
5091+
ato16(input + offset, &name);
50965092

5097-
for (; offset < length; offset += OPAQUE16_LEN) {
5098-
ato16(input + offset, &name);
5093+
while (foundCurve != NULL && foundCurve->name != name)
5094+
foundCurve = foundCurve->next;
50995095

5100-
ret = TLSX_UseSupportedCurve(extensions, name, ssl->heap);
5101-
/* If it is BAD_FUNC_ARG then it is a group we do not support, but
5102-
* that is fine. */
5103-
if (ret != WOLFSSL_SUCCESS && ret != WC_NO_ERR_TRACE(BAD_FUNC_ARG)) {
5104-
return ret;
5096+
if (foundCurve != NULL) {
5097+
ret = commonCurves == NULL ?
5098+
TLSX_SupportedCurve_New(&commonCurves, name, ssl->heap) :
5099+
TLSX_SupportedCurve_Append(commonCurves, name, ssl->heap);
5100+
if (ret != 0)
5101+
break;
5102+
}
51055103
}
5104+
/* If no common curves return error. In TLS 1.3 we can still try to save
5105+
* this by using HRR. */
5106+
if (ret == 0 && commonCurves == NULL &&
5107+
!IsAtLeastTLSv1_3(ssl->version))
5108+
ret = ECC_CURVE_ERROR;
5109+
if (ret == 0) {
5110+
/* Now swap out the curves in the extension */
5111+
TLSX_SupportedCurve_FreeAll((SupportedCurve*)extension->data,
5112+
ssl->heap);
5113+
extension->data = commonCurves;
5114+
commonCurves = NULL;
5115+
}
5116+
TLSX_SupportedCurve_FreeAll(commonCurves, ssl->heap);
51065117
}
51075118

5108-
return 0;
5119+
return ret;
51095120
}
5110-
51115121
#endif
51125122

51135123
#if !defined(NO_WOLFSSL_SERVER)
@@ -10798,15 +10808,17 @@ int TLSX_KeyShare_SetSupported(const WOLFSSL* ssl, TLSX** extensions)
1079810808
TLSX* extension;
1079910809
SupportedCurve* curve = NULL;
1080010810
SupportedCurve* preferredCurve = NULL;
10811+
word16 name = WOLFSSL_NAMED_GROUP_INVALID;
1080110812
KeyShareEntry* kse = NULL;
1080210813
int preferredRank = WOLFSSL_MAX_GROUP_COUNT;
1080310814
int rank;
1080410815

1080510816
extension = TLSX_Find(*extensions, TLSX_SUPPORTED_GROUPS);
1080610817
if (extension != NULL)
1080710818
curve = (SupportedCurve*)extension->data;
10808-
/* Use server's preference order. */
1080910819
for (; curve != NULL; curve = curve->next) {
10820+
/* Use server's preference order. Common group was found but key share
10821+
* was missing */
1081010822
if (!TLSX_IsGroupSupported(curve->name))
1081110823
continue;
1081210824
if (wolfSSL_curve_is_disabled(ssl, curve->name))
@@ -10823,8 +10835,26 @@ int TLSX_KeyShare_SetSupported(const WOLFSSL* ssl, TLSX** extensions)
1082310835
curve = preferredCurve;
1082410836

1082510837
if (curve == NULL) {
10826-
WOLFSSL_ERROR_VERBOSE(BAD_KEY_SHARE_DATA);
10827-
return BAD_KEY_SHARE_DATA;
10838+
byte i;
10839+
/* Fallback to user selected group */
10840+
preferredRank = WOLFSSL_MAX_GROUP_COUNT;
10841+
for (i = 0; i < ssl->numGroups; i++) {
10842+
rank = TLSX_KeyShare_GroupRank(ssl, ssl->group[i]);
10843+
if (rank == -1)
10844+
continue;
10845+
if (rank < preferredRank) {
10846+
name = ssl->group[i];
10847+
preferredRank = rank;
10848+
}
10849+
}
10850+
if (name == WOLFSSL_NAMED_GROUP_INVALID) {
10851+
/* No group selected or specified by the server */
10852+
WOLFSSL_ERROR_VERBOSE(BAD_KEY_SHARE_DATA);
10853+
return BAD_KEY_SHARE_DATA;
10854+
}
10855+
}
10856+
else {
10857+
name = curve->name;
1082810858
}
1082910859

1083010860
#ifdef WOLFSSL_ASYNC_CRYPT
@@ -10848,7 +10878,7 @@ int TLSX_KeyShare_SetSupported(const WOLFSSL* ssl, TLSX** extensions)
1084810878
/* Extension got pushed to head */
1084910879
extension = *extensions;
1085010880
/* Push the selected curve */
10851-
ret = TLSX_KeyShare_New((KeyShareEntry**)&extension->data, curve->name,
10881+
ret = TLSX_KeyShare_New((KeyShareEntry**)&extension->data, name,
1085210882
ssl->heap, &kse);
1085310883
if (ret != 0)
1085410884
return ret;

tests/api.c

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30635,11 +30635,16 @@ static int test_wolfSSL_curves_mismatch(void)
3063530635
} test_params[] = {
3063630636
#ifdef WOLFSSL_TLS13
3063730637
{wolfTLSv1_3_client_method, wolfTLSv1_3_server_method, "TLS 1.3",
30638-
WC_NO_ERR_TRACE(FATAL_ERROR), WC_NO_ERR_TRACE(BAD_KEY_SHARE_DATA)},
30638+
/* Client gets error because server will attempt HRR */
30639+
WC_NO_ERR_TRACE(BAD_KEY_SHARE_DATA),
30640+
WC_NO_ERR_TRACE(FATAL_ERROR)
30641+
},
3063930642
#endif
3064030643
#ifndef WOLFSSL_NO_TLS12
3064130644
{wolfTLSv1_2_client_method, wolfTLSv1_2_server_method, "TLS 1.2",
3064230645
WC_NO_ERR_TRACE(FATAL_ERROR),
30646+
/* Server gets error because <=1.2 doesn't have a mechanism
30647+
* to negotiate curves. */
3064330648
#ifdef OPENSSL_EXTRA
3064430649
WC_NO_ERR_TRACE(WOLFSSL_ERROR_SYSCALL)
3064530650
#else
@@ -68270,6 +68275,8 @@ TEST_CASE testCases[] = {
6827068275
TEST_DECL(test_ocsp_certid_enc_dec),
6827168276
TEST_DECL(test_tls12_unexpected_ccs),
6827268277
TEST_DECL(test_tls13_unexpected_ccs),
68278+
TEST_DECL(test_tls12_curve_intersection),
68279+
TEST_DECL(test_tls13_curve_intersection),
6827368280
TEST_DECL(test_wc_DhSetNamedKey),
6827468281
/* This test needs to stay at the end to clean up any caches allocated. */
6827568282
TEST_DECL(test_wolfSSL_Cleanup)

tests/api/test_tls.c

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,116 @@ int test_tls13_unexpected_ccs(void)
141141
#endif
142142
return EXPECT_RESULT();
143143
}
144+
int test_tls12_curve_intersection(void) {
145+
EXPECT_DECLS;
146+
#if defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && \
147+
!defined(WOLFSSL_NO_TLS12) && defined(HAVE_ECC) && \
148+
defined(HAVE_CURVE25519)
149+
WOLFSSL_CTX *ctx_c = NULL, *ctx_s = NULL;
150+
WOLFSSL *ssl_c = NULL, *ssl_s = NULL;
151+
struct test_memio_ctx test_ctx;
152+
int ret;
153+
const char* curve_name;
154+
int test1[] = {WOLFSSL_ECC_SECP256R1};
155+
int test2[] = {WOLFSSL_ECC_SECP384R1};
156+
int test3[] = {WOLFSSL_ECC_SECP256R1, WOLFSSL_ECC_SECP384R1};
157+
int test4[] = {WOLFSSL_ECC_SECP384R1, WOLFSSL_ECC_SECP256R1};
158+
XMEMSET(&test_ctx, 0, sizeof(test_ctx));
159+
ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, &ctx_s, &ssl_c, &ssl_s,
160+
wolfTLSv1_2_client_method, wolfTLSv1_2_server_method), 0);
161+
ExpectIntEQ(wolfSSL_set_groups(ssl_c,
162+
test1, 1), WOLFSSL_SUCCESS);
163+
ExpectIntEQ(test_memio_do_handshake(ssl_c, ssl_s, 10, NULL), 0);
164+
165+
/* Fix: Get curve name and compare with string comparison or use curve
166+
* ID function */
167+
curve_name = wolfSSL_get_curve_name(ssl_s);
168+
/* or use appropriate string comparison */
169+
ExpectStrEQ(curve_name, "SECP256R1");
170+
curve_name = wolfSSL_get_curve_name(ssl_c);
171+
ExpectStrEQ(curve_name, "SECP256R1");
172+
173+
wolfSSL_free(ssl_c);
174+
wolfSSL_free(ssl_s);
175+
wolfSSL_CTX_free(ctx_c);
176+
wolfSSL_CTX_free(ctx_s);
177+
ssl_c = NULL;
178+
ssl_s = NULL;
179+
ctx_c = NULL;
180+
ctx_s = NULL;
181+
182+
XMEMSET(&test_ctx, 0, sizeof(test_ctx));
183+
ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, &ctx_s, &ssl_c, &ssl_s,
184+
wolfTLSv1_2_client_method, wolfTLSv1_2_server_method), 0);
185+
ExpectIntEQ(wolfSSL_set_groups(ssl_c,
186+
test2, 1), WOLFSSL_SUCCESS);
187+
ExpectIntEQ(wolfSSL_set_groups(ssl_s,
188+
test1, 1), WOLFSSL_SUCCESS);
189+
ExpectIntNE(test_memio_do_handshake(ssl_c, ssl_s, 10, NULL), 0);
190+
ret = wolfSSL_get_error(ssl_s, WOLFSSL_FATAL_ERROR);
191+
192+
/* Fix: Use proper constant or define HANDSHAKE_FAILURE */
193+
ExpectTrue(ret == WC_NO_ERR_TRACE(ECC_CURVE_ERROR));
194+
195+
wolfSSL_free(ssl_c);
196+
wolfSSL_free(ssl_s);
197+
wolfSSL_CTX_free(ctx_c);
198+
wolfSSL_CTX_free(ctx_s);
199+
ssl_c = NULL;
200+
ssl_s = NULL;
201+
ctx_c = NULL;
202+
ctx_s = NULL;
203+
204+
XMEMSET(&test_ctx, 0, sizeof(test_ctx));
205+
ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, &ctx_s, &ssl_c, &ssl_s,
206+
wolfTLSv1_2_client_method, wolfTLSv1_2_server_method), 0);
207+
ExpectIntEQ(wolfSSL_set_groups(ssl_c,
208+
test3, 2),
209+
WOLFSSL_SUCCESS);
210+
ExpectIntEQ(wolfSSL_set_groups(ssl_s,
211+
test4, 2),
212+
WOLFSSL_SUCCESS);
213+
ExpectIntEQ(test_memio_do_handshake(ssl_c, ssl_s, 10, NULL), 0);
214+
215+
curve_name = wolfSSL_get_curve_name(ssl_s);
216+
ExpectStrEQ(curve_name, "SECP256R1");
217+
curve_name = wolfSSL_get_curve_name(ssl_c);
218+
ExpectStrEQ(curve_name, "SECP256R1");
219+
220+
wolfSSL_free(ssl_c);
221+
wolfSSL_free(ssl_s);
222+
wolfSSL_CTX_free(ctx_c);
223+
wolfSSL_CTX_free(ctx_s);
224+
#endif
225+
return EXPECT_RESULT();
226+
}
227+
228+
int test_tls13_curve_intersection(void) {
229+
EXPECT_DECLS;
230+
#if defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && \
231+
defined(WOLFSSL_TLS13) && defined(HAVE_ECC) && defined(HAVE_CURVE25519)
232+
WOLFSSL_CTX *ctx_c = NULL, *ctx_s = NULL;
233+
WOLFSSL *ssl_c = NULL, *ssl_s = NULL;
234+
struct test_memio_ctx test_ctx;
235+
const char* curve_name;
236+
int test1[] ={WOLFSSL_ECC_SECP256R1};
237+
238+
XMEMSET(&test_ctx, 0, sizeof(test_ctx));
239+
ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, &ctx_s, &ssl_c, &ssl_s,
240+
wolfTLSv1_3_client_method, wolfTLSv1_3_server_method), 0);
241+
ExpectIntEQ(wolfSSL_set_groups(ssl_c,
242+
test1, 1), WOLFSSL_SUCCESS);
243+
ExpectIntEQ(test_memio_do_handshake(ssl_c, ssl_s, 10, NULL), 0);
244+
245+
curve_name = wolfSSL_get_curve_name(ssl_s);
246+
ExpectStrEQ(curve_name, "SECP256R1");
247+
curve_name = wolfSSL_get_curve_name(ssl_c);
248+
ExpectStrEQ(curve_name, "SECP256R1");
249+
250+
wolfSSL_free(ssl_c);
251+
wolfSSL_free(ssl_s);
252+
wolfSSL_CTX_free(ctx_c);
253+
wolfSSL_CTX_free(ctx_s);
254+
#endif
255+
return EXPECT_RESULT();
256+
}

tests/api/test_tls.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,7 @@
2424

2525
int test_tls12_unexpected_ccs(void);
2626
int test_tls13_unexpected_ccs(void);
27+
int test_tls12_curve_intersection(void);
28+
int test_tls13_curve_intersection(void);
2729

2830
#endif /* TESTS_API_TEST_TLS_EMS_H */

0 commit comments

Comments
 (0)