@@ -965,7 +965,7 @@ init_channel_bindings(#{socket := Socket} = State) ->
965965process_sasl_request (# sasl_auth {mechanism = Mech , text = ClientIn },
966966 #{lserver := LServer } = State ) ->
967967 State1 = State #{sasl_mech => Mech },
968- Mechs = get_sasl_mechanisms (State1 , sasl ),
968+ { Mechs , PlusDisabled } = get_sasl_mechanisms (State1 , sasl ),
969969 case lists :member (Mech , Mechs ) of
970970 true when Mech == <<" EXTERNAL" >> ->
971971 Res = case xmpp_stream_pkix :authenticate (State1 , ClientIn ) of
@@ -989,7 +989,7 @@ process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
989989 end ,
990990 SASLState = xmpp_sasl :server_new (LServer , GetPW , CheckPW , CheckPWDigest , undefined ),
991991 CB = maps :get (sasl_channel_bindings , State1 , none ),
992- Res = xmpp_sasl :server_start (SASLState , Mech , ClientIn , CB , Mechs2 , undefined ),
992+ Res = xmpp_sasl :server_start (SASLState , Mech , ClientIn , CB , PlusDisabled , Mechs2 , undefined ),
993993 process_sasl_result (Res , disable_sasl2 (State1 #{sasl_state => SASLState }));
994994 false ->
995995 process_sasl_result ({error , unsupported_mechanism , <<" " >>}, disable_sasl2 (State1 ))
@@ -1078,7 +1078,7 @@ process_sasl2_request(#sasl2_authenticate{mechanism = Mech, initial_response = C
10781078 FastMechs = try callback (fast_mechanisms , State )
10791079 catch _ :{? MODULE , undef } -> []
10801080 end ,
1081- Mechs = get_sasl_mechanisms (State1 , sasl2 ),
1081+ { Mechs , PlusDisabled } = get_sasl_mechanisms (State1 , sasl2 ),
10821082 MechsAll = Mechs ++ FastMechs ,
10831083 UAId = case UA of
10841084 # sasl2_user_agent {id = ID } when ID /= <<>> ->
@@ -1116,7 +1116,7 @@ process_sasl2_request(#sasl2_authenticate{mechanism = Mech, initial_response = C
11161116 SASLState = xmpp_sasl :server_new (LServer , GetPW , CheckPW ,
11171117 CheckPWDigest , GetFastTokens ),
11181118 CB = maps :get (sasl_channel_bindings , State1 , none ),
1119- Res = xmpp_sasl :server_start (SASLState , Mech , ClientIn , CB , Mechs2 , UAId ),
1119+ Res = xmpp_sasl :server_start (SASLState , Mech , ClientIn , CB , PlusDisabled , Mechs2 , UAId ),
11201120 process_sasl2_result (Res , State1 #{sasl_state => SASLState ,
11211121 sasl2_inline_els => SaslInline ,
11221122 sasl2_ua_id => UAId });
@@ -1414,7 +1414,7 @@ get_fast_tokens_fun(Mech, State) ->
14141414 catch _ :{? MODULE , undef } -> fun (_ , _ ) -> [] end
14151415 end .
14161416
1417- -spec get_sasl_mechanisms (state (), sasl | sasl2 ) -> [xmpp_sasl :mechanism ()].
1417+ -spec get_sasl_mechanisms (state (), sasl | sasl2 ) -> { [xmpp_sasl :mechanism ()], boolean ()} .
14181418get_sasl_mechanisms (#{stream_encrypted := Encrypted ,
14191419 xmlns := NS } = State , Type ) ->
14201420 Mechs = if NS == ? NS_CLIENT -> xmpp_sasl :listmech ();
@@ -1423,15 +1423,18 @@ get_sasl_mechanisms(#{stream_encrypted := Encrypted,
14231423 Mechs1 = if Encrypted -> [<<" EXTERNAL" >> | Mechs ];
14241424 true -> Mechs
14251425 end ,
1426- Mechs2 = try callback (sasl_mechanisms , Mechs1 , State ) of
1427- {Sasl1 , _ } when Type == sasl -> Sasl1 ;
1428- {_ , Sasl2 } when Type == sasl2 -> Sasl2 ;
1429- Common -> Common
1430- catch _ :{? MODULE , undef } -> Mechs1
1431- end ,
1426+ {Mechs2 , PlusDisabled } =
1427+ try callback (sasl_mechanisms , Mechs1 , State ) of
1428+ {Sasl1 , _ } when Type == sasl -> {Sasl1 , false };
1429+ {_ , Sasl2 } when Type == sasl2 -> {Sasl2 , false };
1430+ {Sasl1 , _ , Disabled } when Type == sasl -> {Sasl1 , Disabled };
1431+ {_ , Sasl2 , Disabled } when Type == sasl2 -> {Sasl2 , Disabled };
1432+ Common -> {Common , false }
1433+ catch _ :{? MODULE , undef } -> {Mechs1 , false }
1434+ end ,
14321435 if Type == sasl2 ->
1433- filter_sasl2_user_mechs (Mechs2 , State );
1434- true -> Mechs2
1436+ filter_sasl2_user_mechs (Mechs2 , PlusDisabled , State );
1437+ true -> { Mechs2 , PlusDisabled }
14351438 end .
14361439
14371440pass_to_mech (_ , all ) -> all ;
@@ -1444,16 +1447,24 @@ pass_to_mech(#scram{hash = sha512}, Acc) -> [<<"SCRAM-SHA-512">>, <<"SCRAM-SHA-5
14441447pass_to_mech (List , Acc ) when is_list (List ) ->
14451448 lists :foldl (fun pass_to_mech /2 , Acc , List ).
14461449
1447- filter_sasl2_user_mechs (Mechs , State ) ->
1450+ filter_sasl2_user_mechs (Mechs , PlusDisabled , State ) ->
14481451 {Pass , _ } = (get_password_fun_sasl2 (<<>>, State ))(<<>>),
14491452 case pass_to_mech (Pass , []) of
1450- all -> Mechs ;
1453+ all -> { Mechs , PlusDisabled } ;
14511454 M ->
14521455 OtherMechs = Mechs -- [<<" SCRAM-SHA-1" >>, <<" SCRAM-SHA-1-PLUS" >>,
14531456 <<" SCRAM-SHA-256" >>, <<" SCRAM-SHA-256-PLUS" >>,
14541457 <<" SCRAM-SHA-512" >>, <<" SCRAM-SHA-512-PLUS" >>],
14551458 ScramMechs = Mechs -- OtherMechs ,
1456- OtherMechs ++ (ScramMechs -- (ScramMechs -- M ))
1459+ Mechs2 = OtherMechs ++ (ScramMechs -- (ScramMechs -- M )),
1460+ case PlusDisabled of
1461+ true -> {Mechs2 , true };
1462+ _ ->
1463+ PlusMechs = [<<" SCRAM-SHA-1-PLUS" >>, <<" SCRAM-SHA-256-PLUS" >>, <<" SCRAM-SHA-512-PLUS" >>],
1464+ HadPlusMechs = ScramMechs -- PlusMechs /= ScramMechs ,
1465+ HavePlusMechs = PlusMechs -- Mechs2 == PlusMechs ,
1466+ {Mechs2 , HadPlusMechs andalso not HavePlusMechs }
1467+ end
14571468 end .
14581469
14591470
@@ -1463,7 +1474,7 @@ get_sasl_feature(#{stream_authenticated := false,
14631474 TLSRequired = is_starttls_required (State ),
14641475 if
14651476 Encrypted or not TLSRequired ->
1466- Mechs = get_sasl_mechanisms (State , sasl ),
1477+ { Mechs , _ } = get_sasl_mechanisms (State , sasl ),
14671478 [# sasl_mechanisms {list = Mechs }] ++
14681479 case maps :get (sasl_channel_bindings , State , none ) of
14691480 none -> [];
@@ -1478,7 +1489,7 @@ get_sasl_feature(_) ->
14781489
14791490-spec get_sasl2_feature (state ()) -> [sasl2_authenticaton () | sasl_channel_binding ()].
14801491get_sasl2_feature (#{stream_authenticated := false } = State ) ->
1481- Mechs = get_sasl_mechanisms (State , sasl2 ),
1492+ { Mechs , _ } = get_sasl_mechanisms (State , sasl2 ),
14821493
14831494 {SASL2Features , Bind2Features , ExtraFeatures } =
14841495 try callback (inline_stream_features , State )
0 commit comments