Skip to content

Commit 2c2388b

Browse files
committed
feat: add full test
1 parent 61ebcec commit 2c2388b

4 files changed

Lines changed: 30 additions & 26 deletions

File tree

examples/rbac_model.conf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ g = _, _
1111
e = some(where (p.eft == allow))
1212

1313
[matchers]
14-
m = (p.sub == "*" || g(r.sub, p.sub)) && r.obj == p.obj && (p.act == "*" || r.act == p.act)
14+
m = (p.sub == "*" || g(r.sub, p.sub)) && (r.obj == p.obj || keyMatch(r.obj, p.obj)) && (p.act == "*" || r.act == p.act)

fastapi_authz/middleware.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
import typing
2-
31
from casbin.enforcer import Enforcer
4-
2+
from starlette.authentication import BaseUser
53
from starlette.requests import Request
64
from starlette.responses import JSONResponse
75
from starlette.status import HTTP_403_FORBIDDEN
86
from starlette.types import ASGIApp, Receive, Scope, Send
9-
from starlette.authentication import BaseUser
107

118

129
class CasbinMiddleware:
@@ -61,7 +58,10 @@ def _enforce(self, scope: Scope, receive: Receive) -> bool:
6158
if 'user' not in scope:
6259
raise RuntimeError("Casbin Middleware must work with an Authentication Middleware")
6360

64-
assert isinstance(request.user,BaseUser)
61+
assert isinstance(request.user, BaseUser)
6562

6663
user = request.user.display_name if request.user.is_authenticated else 'anonymous'
64+
65+
print(user, path, method)
66+
6767
return self.enforcer.enforce(user, path, method)

tests/conftest.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import base64
22
import binascii
3+
import os
34

45
import casbin
56
import pytest
6-
import os
77
from fastapi import FastAPI
88
from starlette.authentication import AuthenticationBackend, AuthenticationError, AuthCredentials, SimpleUser
99
from starlette.middleware.authentication import AuthenticationMiddleware
@@ -38,9 +38,7 @@ def app_fixture():
3838

3939
app = FastAPI()
4040

41-
@app.on_event('startup')
42-
async def startup():
43-
app.add_middleware(CasbinMiddleware, enforcer=enforcer)
44-
app.add_middleware(AuthenticationMiddleware, backend=BasicAuth())
41+
app.add_middleware(CasbinMiddleware, enforcer=enforcer)
42+
app.add_middleware(AuthenticationMiddleware, backend=BasicAuth())
4543

4644
yield app

tests/test_middleware.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,41 @@
11
import pytest
2-
from fastapi import Depends
3-
from starlette.authentication import requires
42
from starlette.testclient import TestClient
53

64

75
@pytest.mark.parametrize(
86
"test_server_path, test_client_path, method, status_code, user, response_body", [
97
('/dataset1/resource2', '/dataset1/resource2', 'GET', 200, 'alice', 'ok'),
8+
('/dataset1/resource2', '/dataset1/resource2', 'GET', 403, 'notalice', 'Forbidden'),
9+
('/dataset1/resource1', '/dataset1/resource1', 'POST', 200, 'alice', 'ok'),
1010
]
1111
)
1212
def test_middleware_authed(app_fixture, test_server_path, test_client_path, method, status_code, user, response_body):
13-
# if method == 'GET':
14-
# @app_fixture.get(test_server_path)
15-
# async def index():
16-
# return 'ok'
17-
# elif method == 'POST':
18-
# @app_fixture.post(test_server_path)
19-
# async def index():
20-
# return 'ok'
21-
# elif method == 'PUT':
22-
# @app_fixture.put(test_server_path)
23-
# async def index():
24-
# return 'ok'
13+
@getattr(app_fixture, method.lower())(test_server_path)
14+
async def index():
15+
return 'ok'
2516

17+
test_client = TestClient(app_fixture)
18+
19+
test_response = getattr(test_client, method.lower())(test_client_path, auth=(user, 'password'))
20+
21+
assert test_response.status_code == status_code
22+
assert test_response.json() == response_body
23+
24+
25+
@pytest.mark.parametrize(
26+
"test_server_path, test_client_path, method, status_code, response_body", [
27+
('/login', '/login', 'GET', 200, 'ok'),
28+
('/', '/', 'GET', 200, 'ok')
29+
]
30+
)
31+
def test_middleware_not_authed(app_fixture, test_server_path, test_client_path, method, status_code, response_body):
2632
@getattr(app_fixture, method.lower())(test_server_path)
2733
async def index():
2834
return 'ok'
2935

3036
test_client = TestClient(app_fixture)
3137

32-
test_response = test_client.get(test_client_path, auth=(user, 'password'))
38+
test_response = getattr(test_client, method.lower())(test_client_path)
3339

3440
assert test_response.status_code == status_code
3541
assert test_response.json() == response_body

0 commit comments

Comments
 (0)