Skip to content

Commit 391c849

Browse files
authored
Merge pull request #1 from plamorg/xforward
Xforward
2 parents 5d75b02 + c255d61 commit 391c849

8 files changed

Lines changed: 155 additions & 53 deletions

File tree

integration/examples/examples_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ func TestExamples(t *testing.T) {
1212
examples := []string{
1313
"./middlewares/auth-forward.yml",
1414
"./middlewares/ip-allow.yml",
15+
"./middlewares/x-forward.yml",
1516
"./additional-configuration.yml",
1617
"./basic.yml",
1718
"./health-check.yml",

integration/examples/middlewares/auth-forward.yml

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,3 @@ services:
1919

2020
# Forward X-Forwarded-* headers to the authentication server.
2121
xForwarded: true # Default: false.
22-
23-
# === Integrating with Authelia ===
24-
# The authForward middleware allows the use of external authenticators like Authelia to work with voltproxy proxied services.
25-
# Here is an example:
26-
# 1. Firstly, proxy the Authelia service:
27-
authelia:
28-
host: authelia.example.com
29-
tls: true
30-
# Redirect to local Authelia instance. Alternatively, specify container service if it's running in a Docker container.
31-
redirect: "http://172.22.0.2:9091"
32-
33-
# 2. Proxy the service you want to protect with Authelia.
34-
protected:
35-
host: protected.example.com
36-
tls: true
37-
redirect: "http://localhost:3000"
38-
middlewares:
39-
# Specify authForward middleware on the to-be-protected service.
40-
authForward:
41-
address: "https://authelia.example.com/api/verify?rd=https://authelia.example.com"
42-
responseHeaders:
43-
["Remote-User", "Remote-Groups", "Remote-Name", "Remote-Email"]
44-
xForwarded: true
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# The xForward middleware is used to forward X-Forwarded-* headers to an upstream service.
2+
3+
services:
4+
secureService:
5+
host: private.example.com
6+
redirect: "http://172.3.2.1:4321"
7+
middlewares:
8+
xForward:
9+
enable: true

integration/integration_test.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ services:
228228
defer res.Body.Close()
229229

230230
if res.StatusCode != http.StatusForbidden {
231-
t.Fatalf("expected status code %d, got %d", http.StatusOK, res.StatusCode)
231+
t.Fatalf("expected status code %d, got %d", http.StatusForbidden, res.StatusCode)
232232
}
233233

234234
if !authServerRan {
@@ -446,3 +446,29 @@ services:
446446
t.Fatalf("expected status codes %v, got %v", expected, received)
447447
}
448448
}
449+
450+
func TestXForward(t *testing.T) {
451+
server := NewMockServer(t, func(w http.ResponseWriter, r *http.Request) {
452+
if r.Header.Get("X-Forwarded-Host") == "" {
453+
t.Errorf("expected X-Forwarded-Host header to be set")
454+
}
455+
w.WriteHeader(http.StatusOK)
456+
})
457+
458+
conf := fmt.Sprintf(`
459+
services:
460+
server:
461+
host: server.example.com
462+
redirect: "%s"
463+
middlewares:
464+
xForward:
465+
enable: true`, server.URL())
466+
i := NewInstance(t, []byte(conf), nil)
467+
468+
res := i.RequestHost("server.example.com")
469+
defer res.Body.Close()
470+
471+
if res.StatusCode != http.StatusOK {
472+
t.Fatalf("expected status code %d, got %d", http.StatusOK, res.StatusCode)
473+
}
474+
}

middlewares/authforward.go

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,9 @@ package middlewares
22

33
import (
44
"log/slog"
5-
"net"
65
"net/http"
76
)
87

9-
const (
10-
xForwardedFor = "X-Forwarded-For"
11-
xForwardedMethod = "X-Forwarded-Method"
12-
xForwardedProto = "X-Forwarded-Proto"
13-
xForwardedHost = "X-Forwarded-Host"
14-
xForwardedURI = "X-Forwarded-Uri"
15-
)
16-
17-
var xForwardedHeaders = []string{
18-
xForwardedFor,
19-
xForwardedMethod,
20-
xForwardedProto,
21-
xForwardedHost,
22-
xForwardedURI,
23-
}
24-
258
// AuthForward is a middleware that forwards the request to an authentication server and
269
// proxies to the service if the authentication is successful.
2710
type AuthForward struct {
@@ -68,18 +51,7 @@ func (a *AuthForward) Handle(next http.Handler) http.Handler {
6851
}
6952

7053
if a.XForwarded {
71-
host, _, err := net.SplitHostPort(r.RemoteAddr)
72-
if err == nil {
73-
authReq.Header.Set(xForwardedFor, host)
74-
}
75-
authReq.Header.Set(xForwardedMethod, r.Method)
76-
if r.TLS != nil {
77-
authReq.Header.Set(xForwardedProto, "https")
78-
} else {
79-
authReq.Header.Set(xForwardedProto, "http")
80-
}
81-
authReq.Header.Set(xForwardedHost, r.Host)
82-
authReq.Header.Set(xForwardedURI, r.RequestURI)
54+
xForward(authReq, *r)
8355
} else {
8456
for _, header := range xForwardedHeaders {
8557
authReq.Header.Del(header)

middlewares/middleware.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
type Middlewares struct {
1111
IPAllow *IPAllow `yaml:"ipAllow"`
1212
AuthForward *AuthForward `yaml:"authForward"`
13+
XForward *XForward `yaml:"xForward"`
1314
}
1415

1516
// List returns a list of middlewares that are not nil.

middlewares/xforward.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package middlewares
2+
3+
import (
4+
"log/slog"
5+
"net"
6+
"net/http"
7+
)
8+
9+
const (
10+
xForwardedFor = "X-Forwarded-For"
11+
xForwardedMethod = "X-Forwarded-Method"
12+
xForwardedProto = "X-Forwarded-Proto"
13+
xForwardedHost = "X-Forwarded-Host"
14+
xForwardedURI = "X-Forwarded-Uri"
15+
)
16+
17+
var xForwardedHeaders = []string{
18+
xForwardedFor,
19+
xForwardedMethod,
20+
xForwardedProto,
21+
xForwardedHost,
22+
xForwardedURI,
23+
}
24+
25+
// XForward is a middleware that adds X-Forwarded headers to the request.
26+
type XForward struct {
27+
Enable bool `yaml:"enable"`
28+
}
29+
30+
func xForward(newReq *http.Request, r http.Request) {
31+
host, _, err := net.SplitHostPort(r.RemoteAddr)
32+
if err == nil {
33+
newReq.Header.Set(xForwardedFor, host)
34+
}
35+
newReq.Header.Set(xForwardedMethod, r.Method)
36+
if r.TLS != nil {
37+
newReq.Header.Set(xForwardedProto, "https")
38+
} else {
39+
newReq.Header.Set(xForwardedProto, "http")
40+
}
41+
newReq.Header.Set(xForwardedHost, r.Host)
42+
newReq.Header.Set(xForwardedURI, r.RequestURI)
43+
}
44+
45+
// Handle adds X-Forwarded headers to the request.
46+
func (x *XForward) Handle(next http.Handler) http.Handler {
47+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
48+
logger := slog.Default().With(
49+
slog.String("host", r.Host),
50+
slog.Any("xForward", x))
51+
52+
if !x.Enable {
53+
next.ServeHTTP(w, r)
54+
return
55+
}
56+
57+
xForward(r, *r)
58+
59+
logger.Debug("Added X-Forwarded headers")
60+
61+
next.ServeHTTP(w, r)
62+
})
63+
}

middlewares/xforward_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package middlewares
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
)
8+
9+
func TestXForwardDisable(t *testing.T) {
10+
x := XForward{
11+
Enable: false,
12+
}
13+
14+
server := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
15+
for _, header := range xForwardedHeaders {
16+
if r.Header.Get(header) != "" {
17+
t.Errorf("expected empty header %s, got %s", header, r.Header.Get(header))
18+
}
19+
}
20+
})
21+
22+
handler := x.Handle(server)
23+
24+
w, r := httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil)
25+
handler.ServeHTTP(w, r)
26+
27+
if w.Code != http.StatusOK {
28+
t.Errorf("expected %d, got %d", http.StatusOK, w.Code)
29+
}
30+
}
31+
32+
func TestXForwardEnable(t *testing.T) {
33+
x := XForward{
34+
Enable: true,
35+
}
36+
37+
server := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
38+
for _, header := range xForwardedHeaders {
39+
if r.Header.Get(header) == "" {
40+
t.Errorf("expected non-empty header %s", header)
41+
}
42+
}
43+
})
44+
45+
handler := x.Handle(server)
46+
47+
w, r := httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil)
48+
handler.ServeHTTP(w, r)
49+
50+
if w.Code != http.StatusOK {
51+
t.Errorf("expected %d, got %d", http.StatusOK, w.Code)
52+
}
53+
}

0 commit comments

Comments
 (0)