3535
3636-deprecated ([{stop , 1 }]).
3737
38+ -include (" scram.hrl" ).
39+
3840% %-define(DBGFSM, true).
3941-ifdef (DBGFSM ).
4042-define (FSMOPTS , [{debug , [trace ]}]).
@@ -963,7 +965,7 @@ init_channel_bindings(#{socket := Socket} = State) ->
963965process_sasl_request (# sasl_auth {mechanism = Mech , text = ClientIn },
964966 #{lserver := LServer } = State ) ->
965967 State1 = State #{sasl_mech => Mech },
966- Mechs = get_sasl_mechanisms (State1 ),
968+ Mechs = get_sasl_mechanisms (State1 , sasl ),
967969 case lists :member (Mech , Mechs ) of
968970 true when Mech == <<" EXTERNAL" >> ->
969971 Res = case xmpp_stream_pkix :authenticate (State1 , ClientIn ) of
@@ -1076,7 +1078,7 @@ process_sasl2_request(#sasl2_authenticate{mechanism = Mech, initial_response = C
10761078 FastMechs = try callback (fast_mechanisms , State )
10771079 catch _ :{? MODULE , undef } -> []
10781080 end ,
1079- Mechs = get_sasl_mechanisms (State1 ),
1081+ Mechs = get_sasl_mechanisms (State1 , sasl2 ),
10801082 MechsAll = Mechs ++ FastMechs ,
10811083 UAId = case UA of
10821084 # sasl2_user_agent {id = ID } when ID /= <<>> ->
@@ -1099,7 +1101,7 @@ process_sasl2_request(#sasl2_authenticate{mechanism = Mech, initial_response = C
10991101 process_sasl2_result (Res , State1 #{sasl2_inline_els => SaslInline ,
11001102 sasl2_ua_id => UAId });
11011103 true ->
1102- GetPW = get_password_fun (Mech , State1 ),
1104+ GetPW = get_password_fun_sasl2 (Mech , State1 ),
11031105 CheckPW = check_password_fun (Mech , State1 ),
11041106 CheckPWDigest = check_password_digest_fun (Mech , State1 ),
11051107 GetFastTokens = get_fast_tokens_fun (Mech , State1 ),
@@ -1342,9 +1344,11 @@ send_features(#{stream_version := {1,0},
13421344 {Features , State3 } =
13431345 case {Encrypted , Sasl2 , AllowUnencryptedSasl2 , TLSAvailable } of
13441346 {false , true , true , true } ->
1345- {get_tls_feature (State2 ) ++ get_sasl2_feature (State2 ), State2 };
1347+ St = prepare_password_fun_sasl2 (State2 ),
1348+ {get_tls_feature (St ) ++ get_sasl2_feature (St ), St };
13461349 {_ , true , _ , _ } when Encrypted ; AllowUnencryptedSasl2 ->
1347- {get_sasl2_feature (State2 ), State2 };
1350+ St = prepare_password_fun_sasl2 (State2 ),
1351+ {get_sasl2_feature (St ), St };
13481352 {false , true , false , false } ->
13491353 {[], disable_sasl2 (State2 )};
13501354 {false , _ , _ , true } ->
@@ -1377,6 +1381,21 @@ get_password_fun(Mech, State) ->
13771381 catch _ :{? MODULE , undef } -> fun (_ ) -> {false , undefined } end
13781382 end .
13791383
1384+
1385+ -spec prepare_password_fun_sasl2 (state ()) -> state ().
1386+ prepare_password_fun_sasl2 (#{sasl2_stream_from := # jid {luser = User }} = State ) ->
1387+ Fun = try callback (get_password_fun , <<>>, State ) of
1388+ F ->
1389+ Res = F (User ),
1390+ fun (_ ) -> Res end
1391+ catch _ :{? MODULE , undef } -> fun (_ ) -> {false , undefined } end
1392+ end ,
1393+ State #{sasl2_password_fun => Fun }.
1394+
1395+ -spec get_password_fun_sasl2 (xmpp_sasl :mechanism (), state ()) -> fun ().
1396+ get_password_fun_sasl2 (_Mech , #{sasl2_password_fun := Fun }) ->
1397+ Fun .
1398+
13801399-spec check_password_fun (xmpp_sasl :mechanism (), state ()) -> fun ().
13811400check_password_fun (Mech , State ) ->
13821401 try callback (check_password_fun , Mech , State )
@@ -1395,26 +1414,53 @@ get_fast_tokens_fun(Mech, State) ->
13951414 catch _ :{? MODULE , undef } -> fun (_ , _ ) -> [] end
13961415 end .
13971416
1398- -spec get_sasl_mechanisms (state ()) -> [xmpp_sasl :mechanism ()].
1417+ -spec get_sasl_mechanisms (state (), sasl | sasl2 ) -> { [xmpp_sasl :mechanism ()], state ()} .
13991418get_sasl_mechanisms (#{stream_encrypted := Encrypted ,
1400- xmlns := NS } = State ) ->
1419+ xmlns := NS } = State , Type ) ->
14011420 Mechs = if NS == ? NS_CLIENT -> xmpp_sasl :listmech ();
1402- true -> []
1421+ true -> []
14031422 end ,
1404- Mechs1 = if Encrypted -> [<<" EXTERNAL" >>|Mechs ];
1405- true -> Mechs
1423+ Mechs1 = if Encrypted -> [<<" EXTERNAL" >> | Mechs ];
1424+ true -> Mechs
1425+ end ,
1426+ Mechs2 = try callback (sasl_mechanisms , Mechs1 , State )
1427+ catch _ :{? MODULE , undef } -> Mechs1
14061428 end ,
1407- try callback (sasl_mechanisms , Mechs1 , State )
1408- catch _ :{? MODULE , undef } -> Mechs1
1429+ if Type == sasl2 ->
1430+ filter_sasl2_user_mechs (Mechs2 , State );
1431+ true -> Mechs2
14091432 end .
14101433
1434+ pass_to_mech (_ , all ) -> all ;
1435+ pass_to_mech (false , _ ) -> all ;
1436+ pass_to_mech ({false , _ , _ }, _ ) -> all ;
1437+ pass_to_mech (Bin , _ ) when is_binary (Bin ) -> all ;
1438+ pass_to_mech (# scram {hash = sha }, Acc ) -> [<<" SCRAM-SHA-1" >>, <<" SCRAM-SHA-1-PLUS" >> | Acc ];
1439+ pass_to_mech (# scram {hash = sha256 }, Acc ) -> [<<" SCRAM-SHA-256" >>, <<" SCRAM-SHA-256-PLUS" >> | Acc ];
1440+ pass_to_mech (# scram {hash = sha512 }, Acc ) -> [<<" SCRAM-SHA-512" >>, <<" SCRAM-SHA-512-PLUS" >> | Acc ];
1441+ pass_to_mech (List , Acc ) when is_list (List ) ->
1442+ lists :foldl (fun pass_to_mech /2 , Acc , List ).
1443+
1444+ filter_sasl2_user_mechs (Mechs , State ) ->
1445+ {Pass , _ } = (get_password_fun_sasl2 (<<>>, State ))(<<>>),
1446+ case pass_to_mech (Pass , []) of
1447+ all -> Mechs ;
1448+ M ->
1449+ OtherMechs = Mechs -- [<<" SCRAM-SHA-1" >>, <<" SCRAM-SHA-1-PLUS" >>,
1450+ <<" SCRAM-SHA-256" >>, <<" SCRAM-SHA-256-PLUS" >>,
1451+ <<" SCRAM-SHA-512" >>, <<" SCRAM-SHA-512-PLUS" >>],
1452+ ScramMechs = Mechs -- OtherMechs ,
1453+ OtherMechs ++ (ScramMechs -- (ScramMechs -- M ))
1454+ end .
1455+
1456+
14111457-spec get_sasl_feature (state ()) -> [sasl_mechanisms () | sasl_channel_binding ()].
14121458get_sasl_feature (#{stream_authenticated := false ,
1413- stream_encrypted := Encrypted } = State ) ->
1459+ stream_encrypted := Encrypted } = State ) ->
14141460 TLSRequired = is_starttls_required (State ),
14151461 if
14161462 Encrypted or not TLSRequired ->
1417- Mechs = get_sasl_mechanisms (State ),
1463+ Mechs = get_sasl_mechanisms (State , sasl ),
14181464 [# sasl_mechanisms {list = Mechs }] ++
14191465 case maps :get (sasl_channel_bindings , State , none ) of
14201466 none -> [];
@@ -1429,7 +1475,7 @@ get_sasl_feature(_) ->
14291475
14301476-spec get_sasl2_feature (state ()) -> [sasl2_authenticaton () | sasl_channel_binding ()].
14311477get_sasl2_feature (#{stream_authenticated := false } = State ) ->
1432- Mechs = get_sasl_mechanisms (State ),
1478+ Mechs = get_sasl_mechanisms (State , sasl2 ),
14331479
14341480 {SASL2Features , Bind2Features , ExtraFeatures } =
14351481 try callback (inline_stream_features , State )
0 commit comments