Skip to content

Commit 61ebcec

Browse files
committed
feat: add pytest
1 parent 3a86f5e commit 61ebcec

7 files changed

Lines changed: 102 additions & 4 deletions

File tree

examples/basic_policy.csv

Lines changed: 0 additions & 2 deletions
This file was deleted.
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@ r = sub, obj, act
44
[policy_definition]
55
p = sub, obj, act
66

7+
[role_definition]
8+
g = _, _
9+
710
[policy_effect]
811
e = some(where (p.eft == allow))
912

1013
[matchers]
11-
m = r.sub == p.sub && r.obj == p.obj && r.act == p.act
14+
m = (p.sub == "*" || g(r.sub, p.sub)) && r.obj == p.obj && (p.act == "*" || r.act == p.act)

examples/rbac_policy.csv

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
p, alice, /dataset1/*, GET
2+
p, alice, /dataset1/resource1, POST
3+
p, bob, /dataset2/resource1, *
4+
p, bob, /dataset2/resource2, GET
5+
p, bob, /dataset2/folder1/*, POST
6+
p, dataset1_admin, /dataset1/*, *
7+
p, *, /login, *
8+
9+
p, anonymous, /, GET
10+
11+
g, cathy, dataset1_admin

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
"authorization",
5555
"permission"
5656
],
57-
packages=find_packages(exclude=["docs", "tests*"]),
57+
packages=find_packages(exclude=["docs", "pytest*"]),
5858
data_files=[desc_file],
5959
include_package_data=True,
6060
install_requires=install_requires,

tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import pytest

tests/conftest.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import base64
2+
import binascii
3+
4+
import casbin
5+
import pytest
6+
import os
7+
from fastapi import FastAPI
8+
from starlette.authentication import AuthenticationBackend, AuthenticationError, AuthCredentials, SimpleUser
9+
from starlette.middleware.authentication import AuthenticationMiddleware
10+
11+
from fastapi_authz import CasbinMiddleware
12+
13+
14+
def get_examples(path):
15+
examples_path = os.path.split(os.path.realpath(__file__))[0] + "/../examples/"
16+
return os.path.abspath(examples_path + path)
17+
18+
19+
class BasicAuth(AuthenticationBackend):
20+
async def authenticate(self, request):
21+
if "Authorization" not in request.headers:
22+
return None
23+
24+
auth = request.headers["Authorization"]
25+
try:
26+
scheme, credentials = auth.split()
27+
decoded = base64.b64decode(credentials).decode("ascii")
28+
except (ValueError, UnicodeDecodeError, binascii.Error):
29+
raise AuthenticationError("Invalid basic auth credentials")
30+
31+
username, _, password = decoded.partition(":")
32+
return AuthCredentials(["authenticated"]), SimpleUser(username)
33+
34+
35+
@pytest.fixture
36+
def app_fixture():
37+
enforcer = casbin.Enforcer(get_examples("rbac_model.conf"), get_examples("rbac_policy.csv"))
38+
39+
app = FastAPI()
40+
41+
@app.on_event('startup')
42+
async def startup():
43+
app.add_middleware(CasbinMiddleware, enforcer=enforcer)
44+
app.add_middleware(AuthenticationMiddleware, backend=BasicAuth())
45+
46+
yield app

tests/test_middleware.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pytest
2+
from fastapi import Depends
3+
from starlette.authentication import requires
4+
from starlette.testclient import TestClient
5+
6+
7+
@pytest.mark.parametrize(
8+
"test_server_path, test_client_path, method, status_code, user, response_body", [
9+
('/dataset1/resource2', '/dataset1/resource2', 'GET', 200, 'alice', 'ok'),
10+
]
11+
)
12+
def test_middleware_authed(app_fixture, test_server_path, test_client_path, method, status_code, user, response_body):
13+
# if method == 'GET':
14+
# @app_fixture.get(test_server_path)
15+
# async def index():
16+
# return 'ok'
17+
# elif method == 'POST':
18+
# @app_fixture.post(test_server_path)
19+
# async def index():
20+
# return 'ok'
21+
# elif method == 'PUT':
22+
# @app_fixture.put(test_server_path)
23+
# async def index():
24+
# return 'ok'
25+
26+
@getattr(app_fixture, method.lower())(test_server_path)
27+
async def index():
28+
return 'ok'
29+
30+
test_client = TestClient(app_fixture)
31+
32+
test_response = test_client.get(test_client_path, auth=(user, 'password'))
33+
34+
assert test_response.status_code == status_code
35+
assert test_response.json() == response_body
36+
37+
38+
if __name__ == '__main__':
39+
pytest.main()

0 commit comments

Comments
 (0)