Skip to content

Commit a76544b

Browse files
yoshikipomJamie Tanna
andauthored
fix: avoid stack overflow errors when using heavily recursive types (oapi-codegen#1377)
As noted in oapi-codegen#1373, we have cases where a heavily recursive `allOf` could lead to a stack overflow error. To avoid this, we can track the references that we've seen, and not recurse further if we have already traversed that ref. Closes oapi-codegen#1373. Co-authored-by: Jamie Tanna <jamie.tanna@elastic.co>
1 parent af135a9 commit a76544b

6 files changed

Lines changed: 77 additions & 12 deletions

File tree

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
package: issue1373
2+
generate:
3+
models: true
4+
output: issue.gen.go
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package issue1373
2+
3+
//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen --config=config.yaml spec.yaml

internal/test/issues/issue-1373/issue.gen.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
openapi: 3.0.2
2+
info:
3+
version: '0.0.1'
4+
title: example
5+
description: |
6+
Make sure that recursive $ref in allOf are handled properly
7+
paths:
8+
/example:
9+
get:
10+
operationId: exampleGet
11+
responses:
12+
'200':
13+
description: "OK"
14+
content:
15+
'application/json':
16+
schema:
17+
$ref: '#/components/schemas/RecursiveObject'
18+
19+
components:
20+
schemas:
21+
RecursiveObject:
22+
allOf:
23+
- $ref: "#/components/schemas/NonRecursiveObject"
24+
- $ref: "#/components/schemas/RecursiveObject"
25+
- properties:
26+
FieldInRecursive::
27+
type: string
28+
29+
NonRecursiveObject:
30+
properties:
31+
FieldInNonRecursive:
32+
type: string

pkg/codegen/merge_schemas.go

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@ func mergeSchemas(allOf []*openapi3.SchemaRef, path []string) (Schema, error) {
3838
if err != nil {
3939
return Schema{}, err
4040
}
41-
schema, err = mergeOpenapiSchemas(schema, oneOfSchema, true)
41+
42+
seenSchemaRef := make(map[string]bool)
43+
if allOf[i].Ref != "" {
44+
seenSchemaRef[allOf[i].Ref] = true
45+
}
46+
schema, err = mergeOpenapiSchemas(schema, oneOfSchema, true, seenSchemaRef)
4247
if err != nil {
4348
return Schema{}, fmt.Errorf("error merging schemas for AllOf: %w", err)
4449
}
@@ -71,11 +76,17 @@ func valueWithPropagatedRef(ref *openapi3.SchemaRef) (openapi3.Schema, error) {
7176
return schema, nil
7277
}
7378

74-
func mergeAllOf(allOf []*openapi3.SchemaRef) (openapi3.Schema, error) {
79+
func mergeAllOf(allOf []*openapi3.SchemaRef, seenSchemaRef map[string]bool) (openapi3.Schema, error) {
7580
var schema openapi3.Schema
7681
for _, schemaRef := range allOf {
7782
var err error
78-
schema, err = mergeOpenapiSchemas(schema, *schemaRef.Value, true)
83+
if schemaRef.Ref != "" && seenSchemaRef[schemaRef.Ref] {
84+
continue
85+
}
86+
if schemaRef.Ref != "" {
87+
seenSchemaRef[schemaRef.Ref] = true
88+
}
89+
schema, err = mergeOpenapiSchemas(schema, *schemaRef.Value, true, seenSchemaRef)
7990
if err != nil {
8091
return openapi3.Schema{}, fmt.Errorf("error merging schemas for AllOf: %w", err)
8192
}
@@ -85,7 +96,7 @@ func mergeAllOf(allOf []*openapi3.SchemaRef) (openapi3.Schema, error) {
8596

8697
// mergeOpenapiSchemas merges two openAPI schemas and returns the schema
8798
// all of whose fields are composed.
88-
func mergeOpenapiSchemas(s1, s2 openapi3.Schema, allOf bool) (openapi3.Schema, error) {
99+
func mergeOpenapiSchemas(s1, s2 openapi3.Schema, allOf bool, seenSchemaRef map[string]bool) (openapi3.Schema, error) {
89100
var result openapi3.Schema
90101

91102
result.Extensions = make(map[string]any, len(s1.Extensions)+len(s2.Extensions))
@@ -100,15 +111,15 @@ func mergeOpenapiSchemas(s1, s2 openapi3.Schema, allOf bool) (openapi3.Schema, e
100111
var err error
101112
if s1.AllOf != nil {
102113
var merged openapi3.Schema
103-
merged, err = mergeAllOf(s1.AllOf)
114+
merged, err = mergeAllOf(s1.AllOf, seenSchemaRef)
104115
if err != nil {
105116
return openapi3.Schema{}, fmt.Errorf("error transitive merging AllOf on schema 1")
106117
}
107118
s1 = merged
108119
}
109120
if s2.AllOf != nil {
110121
var merged openapi3.Schema
111-
merged, err = mergeAllOf(s2.AllOf)
122+
merged, err = mergeAllOf(s2.AllOf, seenSchemaRef)
112123
if err != nil {
113124
return openapi3.Schema{}, fmt.Errorf("error transitive merging AllOf on schema 2")
114125
}

pkg/codegen/merge_schemas_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func TestMergeOpenapiSchemas_DiscriminatorPropagation(t *testing.T) {
1717
s1 := openapi3.Schema{Discriminator: disc}
1818
s2 := openapi3.Schema{}
1919

20-
result, err := mergeOpenapiSchemas(s1, s2, true)
20+
result, err := mergeOpenapiSchemas(s1, s2, true, make(map[string]bool))
2121
require.NoError(t, err)
2222
assert.Equal(t, disc, result.Discriminator)
2323
})
@@ -26,7 +26,7 @@ func TestMergeOpenapiSchemas_DiscriminatorPropagation(t *testing.T) {
2626
s1 := openapi3.Schema{}
2727
s2 := openapi3.Schema{Discriminator: disc}
2828

29-
result, err := mergeOpenapiSchemas(s1, s2, true)
29+
result, err := mergeOpenapiSchemas(s1, s2, true, make(map[string]bool))
3030
require.NoError(t, err)
3131
assert.Equal(t, disc, result.Discriminator)
3232
})
@@ -36,7 +36,7 @@ func TestMergeOpenapiSchemas_DiscriminatorPropagation(t *testing.T) {
3636
s1 := openapi3.Schema{Discriminator: disc}
3737
s2 := openapi3.Schema{Discriminator: disc2}
3838

39-
_, err := mergeOpenapiSchemas(s1, s2, true)
39+
_, err := mergeOpenapiSchemas(s1, s2, true, make(map[string]bool))
4040
require.Error(t, err)
4141
assert.Contains(t, err.Error(), "discriminators")
4242
})
@@ -45,7 +45,7 @@ func TestMergeOpenapiSchemas_DiscriminatorPropagation(t *testing.T) {
4545
s1 := openapi3.Schema{}
4646
s2 := openapi3.Schema{}
4747

48-
result, err := mergeOpenapiSchemas(s1, s2, true)
48+
result, err := mergeOpenapiSchemas(s1, s2, true, make(map[string]bool))
4949
require.NoError(t, err)
5050
assert.Nil(t, result.Discriminator)
5151
})
@@ -54,15 +54,15 @@ func TestMergeOpenapiSchemas_DiscriminatorPropagation(t *testing.T) {
5454
s1 := openapi3.Schema{Discriminator: disc}
5555
s2 := openapi3.Schema{}
5656

57-
_, err := mergeOpenapiSchemas(s1, s2, false)
57+
_, err := mergeOpenapiSchemas(s1, s2, false, make(map[string]bool))
5858
require.Error(t, err)
5959
})
6060

6161
t.Run("non-allOf with discriminator on s2 errors", func(t *testing.T) {
6262
s1 := openapi3.Schema{}
6363
s2 := openapi3.Schema{Discriminator: disc}
6464

65-
_, err := mergeOpenapiSchemas(s1, s2, false)
65+
_, err := mergeOpenapiSchemas(s1, s2, false, make(map[string]bool))
6666
require.Error(t, err)
6767
})
6868
}

0 commit comments

Comments
 (0)