Skip to content

Commit 27316c7

Browse files
committed
Filter mechanism is sasl2 based on user stored passwords
1 parent 9b028c1 commit 27316c7

1 file changed

Lines changed: 61 additions & 15 deletions

File tree

src/xmpp_stream_in.erl

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535

3636
-deprecated([{stop, 1}]).
3737

38+
-include("scram.hrl").
39+
3840
%%-define(DBGFSM, true).
3941
-ifdef(DBGFSM).
4042
-define(FSMOPTS, [{debug, [trace]}]).
@@ -963,7 +965,7 @@ init_channel_bindings(#{socket := Socket} = State) ->
963965
process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
964966
#{lserver := LServer} = State) ->
965967
State1 = State#{sasl_mech => Mech},
966-
Mechs = get_sasl_mechanisms(State1),
968+
Mechs = get_sasl_mechanisms(State1, sasl),
967969
case lists:member(Mech, Mechs) of
968970
true when Mech == <<"EXTERNAL">> ->
969971
Res = case xmpp_stream_pkix:authenticate(State1, ClientIn) of
@@ -1076,7 +1078,7 @@ process_sasl2_request(#sasl2_authenticate{mechanism = Mech, initial_response = C
10761078
FastMechs = try callback(fast_mechanisms, State)
10771079
catch _:{?MODULE, undef} -> []
10781080
end,
1079-
Mechs = get_sasl_mechanisms(State1),
1081+
Mechs = get_sasl_mechanisms(State1, sasl2),
10801082
MechsAll = Mechs ++ FastMechs,
10811083
UAId = case UA of
10821084
#sasl2_user_agent{id = ID} when ID /= <<>> ->
@@ -1099,7 +1101,7 @@ process_sasl2_request(#sasl2_authenticate{mechanism = Mech, initial_response = C
10991101
process_sasl2_result(Res, State1#{sasl2_inline_els => SaslInline,
11001102
sasl2_ua_id => UAId});
11011103
true ->
1102-
GetPW = get_password_fun(Mech, State1),
1104+
GetPW = get_password_fun_sasl2(Mech, State1),
11031105
CheckPW = check_password_fun(Mech, State1),
11041106
CheckPWDigest = check_password_digest_fun(Mech, State1),
11051107
GetFastTokens = get_fast_tokens_fun(Mech, State1),
@@ -1342,9 +1344,11 @@ send_features(#{stream_version := {1,0},
13421344
{Features, State3} =
13431345
case {Encrypted, Sasl2, AllowUnencryptedSasl2, TLSAvailable} of
13441346
{false, true, true, true} ->
1345-
{get_tls_feature(State2) ++ get_sasl2_feature(State2), State2};
1347+
St = prepare_password_fun_sasl2(State2),
1348+
{get_tls_feature(St) ++ get_sasl2_feature(St), St};
13461349
{_, true, _, _} when Encrypted; AllowUnencryptedSasl2 ->
1347-
{get_sasl2_feature(State2), State2};
1350+
St = prepare_password_fun_sasl2(State2),
1351+
{get_sasl2_feature(St), St};
13481352
{false, true, false, false} ->
13491353
{[], disable_sasl2(State2)};
13501354
{false, _, _, true} ->
@@ -1377,6 +1381,21 @@ get_password_fun(Mech, State) ->
13771381
catch _:{?MODULE, undef} -> fun(_) -> {false, undefined} end
13781382
end.
13791383

1384+
1385+
-spec prepare_password_fun_sasl2(state()) -> state().
1386+
prepare_password_fun_sasl2(#{sasl2_stream_from := #jid{luser = User}} = State) ->
1387+
Fun = try callback(get_password_fun, <<>>, State) of
1388+
F ->
1389+
Res = F(User),
1390+
fun(_) -> Res end
1391+
catch _:{?MODULE, undef} -> fun(_) -> {false, undefined} end
1392+
end,
1393+
State#{sasl2_password_fun => Fun}.
1394+
1395+
-spec get_password_fun_sasl2(xmpp_sasl:mechanism(), state()) -> fun().
1396+
get_password_fun_sasl2(_Mech, #{sasl2_password_fun := Fun}) ->
1397+
Fun.
1398+
13801399
-spec check_password_fun(xmpp_sasl:mechanism(), state()) -> fun().
13811400
check_password_fun(Mech, State) ->
13821401
try callback(check_password_fun, Mech, State)
@@ -1395,26 +1414,53 @@ get_fast_tokens_fun(Mech, State) ->
13951414
catch _:{?MODULE, undef} -> fun(_, _) -> [] end
13961415
end.
13971416

1398-
-spec get_sasl_mechanisms(state()) -> [xmpp_sasl:mechanism()].
1417+
-spec get_sasl_mechanisms(state(), sasl | sasl2) -> {[xmpp_sasl:mechanism()], state()}.
13991418
get_sasl_mechanisms(#{stream_encrypted := Encrypted,
1400-
xmlns := NS} = State) ->
1419+
xmlns := NS} = State, Type) ->
14011420
Mechs = if NS == ?NS_CLIENT -> xmpp_sasl:listmech();
1402-
true -> []
1421+
true -> []
14031422
end,
1404-
Mechs1 = if Encrypted -> [<<"EXTERNAL">>|Mechs];
1405-
true -> Mechs
1423+
Mechs1 = if Encrypted -> [<<"EXTERNAL">> | Mechs];
1424+
true -> Mechs
1425+
end,
1426+
Mechs2 = try callback(sasl_mechanisms, Mechs1, State)
1427+
catch _:{?MODULE, undef} -> Mechs1
14061428
end,
1407-
try callback(sasl_mechanisms, Mechs1, State)
1408-
catch _:{?MODULE, undef} -> Mechs1
1429+
if Type == sasl2 ->
1430+
filter_sasl2_user_mechs(Mechs2, State);
1431+
true -> Mechs2
14091432
end.
14101433

1434+
pass_to_mech(_, all) -> all;
1435+
pass_to_mech(false, _) -> all;
1436+
pass_to_mech({false, _, _}, _) -> all;
1437+
pass_to_mech(Bin, _) when is_binary(Bin) -> all;
1438+
pass_to_mech(#scram{hash = sha}, Acc) -> [<<"SCRAM-SHA-1">>, <<"SCRAM-SHA-1-PLUS">> | Acc];
1439+
pass_to_mech(#scram{hash = sha256}, Acc) -> [<<"SCRAM-SHA-256">>, <<"SCRAM-SHA-256-PLUS">> | Acc];
1440+
pass_to_mech(#scram{hash = sha512}, Acc) -> [<<"SCRAM-SHA-512">>, <<"SCRAM-SHA-512-PLUS">> | Acc];
1441+
pass_to_mech(List, Acc) when is_list(List) ->
1442+
lists:foldl(fun pass_to_mech/2, Acc, List).
1443+
1444+
filter_sasl2_user_mechs(Mechs, State) ->
1445+
{Pass, _} = (get_password_fun_sasl2(<<>>, State))(<<>>),
1446+
case pass_to_mech(Pass, []) of
1447+
all -> Mechs;
1448+
M ->
1449+
OtherMechs = Mechs -- [<<"SCRAM-SHA-1">>, <<"SCRAM-SHA-1-PLUS">>,
1450+
<<"SCRAM-SHA-256">>, <<"SCRAM-SHA-256-PLUS">>,
1451+
<<"SCRAM-SHA-512">>, <<"SCRAM-SHA-512-PLUS">>],
1452+
ScramMechs = Mechs -- OtherMechs,
1453+
OtherMechs ++ (ScramMechs -- (ScramMechs -- M))
1454+
end.
1455+
1456+
14111457
-spec get_sasl_feature(state()) -> [sasl_mechanisms() | sasl_channel_binding()].
14121458
get_sasl_feature(#{stream_authenticated := false,
1413-
stream_encrypted := Encrypted} = State) ->
1459+
stream_encrypted := Encrypted} = State) ->
14141460
TLSRequired = is_starttls_required(State),
14151461
if
14161462
Encrypted or not TLSRequired ->
1417-
Mechs = get_sasl_mechanisms(State),
1463+
Mechs = get_sasl_mechanisms(State, sasl),
14181464
[#sasl_mechanisms{list = Mechs}] ++
14191465
case maps:get(sasl_channel_bindings, State, none) of
14201466
none -> [];
@@ -1429,7 +1475,7 @@ get_sasl_feature(_) ->
14291475

14301476
-spec get_sasl2_feature(state()) -> [sasl2_authenticaton() | sasl_channel_binding()].
14311477
get_sasl2_feature(#{stream_authenticated := false} = State) ->
1432-
Mechs = get_sasl_mechanisms(State),
1478+
Mechs = get_sasl_mechanisms(State, sasl2),
14331479

14341480
{SASL2Features, Bind2Features, ExtraFeatures} =
14351481
try callback(inline_stream_features, State)

0 commit comments

Comments
 (0)