@@ -1578,7 +1578,7 @@ def test_take_along_axis_validation():
15781578 def_dtypes = info_ .default_dtypes (device = x_dev )
15791579 ind_dt = def_dtypes ["indexing" ]
15801580 ind = dpt .zeros (1 , dtype = ind_dt )
1581- # axis valudation
1581+ # axis validation
15821582 with pytest .raises (ValueError ):
15831583 dpt .take_along_axis (x , ind , axis = 1 )
15841584 # mode validation
@@ -1594,6 +1594,116 @@ def test_take_along_axis_validation():
15941594 dpt .take_along_axis (x , ind2 )
15951595
15961596
1597+ def test_put_along_axis ():
1598+ get_queue_or_skip ()
1599+
1600+ n0 , n1 , n2 = 3 , 5 , 7
1601+ x = dpt .reshape (dpt .arange (n0 * n1 * n2 ), (n0 , n1 , n2 ))
1602+ ind_dt = dpt .__array_namespace_info__ ().default_dtypes (
1603+ device = x .sycl_device
1604+ )["indexing" ]
1605+ ind0 = dpt .ones ((1 , n1 , n2 ), dtype = ind_dt )
1606+ ind1 = dpt .ones ((n0 , 1 , n2 ), dtype = ind_dt )
1607+ ind2 = dpt .ones ((n0 , n1 , 1 ), dtype = ind_dt )
1608+
1609+ xc = dpt .copy (x )
1610+ vals = dpt .ones (ind0 .shape , dtype = x .dtype )
1611+ dpt .put_along_axis (xc , ind0 , vals , axis = 0 )
1612+ assert dpt .all (dpt .take_along_axis (xc , ind0 , axis = 0 ) == vals )
1613+
1614+ xc = dpt .copy (x )
1615+ vals = dpt .ones (ind1 .shape , dtype = x .dtype )
1616+ dpt .put_along_axis (xc , ind1 , vals , axis = 1 )
1617+ assert dpt .all (dpt .take_along_axis (xc , ind1 , axis = 1 ) == vals )
1618+
1619+ xc = dpt .copy (x )
1620+ vals = dpt .ones (ind2 .shape , dtype = x .dtype )
1621+ dpt .put_along_axis (xc , ind2 , vals , axis = 2 )
1622+ assert dpt .all (dpt .take_along_axis (xc , ind2 , axis = 2 ) == vals )
1623+
1624+ xc = dpt .copy (x )
1625+ vals = dpt .ones (ind2 .shape , dtype = x .dtype )
1626+ dpt .put_along_axis (xc , ind2 , dpt .asnumpy (vals ), axis = 2 )
1627+ assert dpt .all (dpt .take_along_axis (xc , ind2 , axis = 2 ) == vals )
1628+
1629+
1630+ def test_put_along_axis_validation ():
1631+ # type check on the first argument
1632+ with pytest .raises (TypeError ):
1633+ dpt .put_along_axis (tuple (), list (), list ())
1634+ get_queue_or_skip ()
1635+ n1 , n2 = 2 , 5
1636+ x = dpt .ones (n1 * n2 )
1637+ # type check on the second argument
1638+ with pytest .raises (TypeError ):
1639+ dpt .put_along_axis (x , list (), list ())
1640+ x_dev = x .sycl_device
1641+ info_ = dpt .__array_namespace_info__ ()
1642+ def_dtypes = info_ .default_dtypes (device = x_dev )
1643+ ind_dt = def_dtypes ["indexing" ]
1644+ ind = dpt .zeros (1 , dtype = ind_dt )
1645+ vals = dpt .zeros (1 , dtype = x .dtype )
1646+ # axis validation
1647+ with pytest .raises (ValueError ):
1648+ dpt .put_along_axis (x , ind , vals , axis = 1 )
1649+ # mode validation
1650+ with pytest .raises (ValueError ):
1651+ dpt .put_along_axis (x , ind , vals , axis = 0 , mode = "invalid" )
1652+ # same array-ranks validation
1653+ with pytest .raises (ValueError ):
1654+ dpt .put_along_axis (dpt .reshape (x , (n1 , n2 )), ind , vals )
1655+ # check compute-follows-data
1656+ q2 = dpctl .SyclQueue (x_dev , property = "enable_profiling" )
1657+ ind2 = dpt .zeros (1 , dtype = ind_dt , sycl_queue = q2 )
1658+ with pytest .raises (ExecutionPlacementError ):
1659+ dpt .put_along_axis (x , ind2 , vals )
1660+
1661+
1662+ def test_put_along_axis_application ():
1663+ get_queue_or_skip ()
1664+ info_ = dpt .__array_namespace_info__ ()
1665+ def_dtypes = info_ .default_dtypes (device = None )
1666+ ind_dt = def_dtypes ["indexing" ]
1667+ all_perms = dpt .asarray (
1668+ [
1669+ [0 , 1 , 2 , 3 ],
1670+ [0 , 2 , 1 , 3 ],
1671+ [2 , 0 , 1 , 3 ],
1672+ [2 , 1 , 0 , 3 ],
1673+ [1 , 0 , 2 , 3 ],
1674+ [1 , 2 , 0 , 3 ],
1675+ [0 , 1 , 3 , 2 ],
1676+ [0 , 2 , 3 , 1 ],
1677+ [2 , 0 , 3 , 1 ],
1678+ [2 , 1 , 3 , 0 ],
1679+ [1 , 0 , 3 , 2 ],
1680+ [1 , 2 , 3 , 0 ],
1681+ [0 , 3 , 1 , 2 ],
1682+ [0 , 3 , 2 , 1 ],
1683+ [2 , 3 , 0 , 1 ],
1684+ [2 , 3 , 1 , 0 ],
1685+ [1 , 3 , 0 , 2 ],
1686+ [1 , 3 , 2 , 0 ],
1687+ [3 , 0 , 1 , 2 ],
1688+ [3 , 0 , 2 , 1 ],
1689+ [3 , 2 , 0 , 1 ],
1690+ [3 , 2 , 1 , 0 ],
1691+ [3 , 1 , 0 , 2 ],
1692+ [3 , 1 , 2 , 0 ],
1693+ ],
1694+ dtype = ind_dt ,
1695+ )
1696+ p_mats = dpt .zeros ((24 , 4 , 4 ), dtype = dpt .int64 )
1697+ vals = dpt .ones ((24 , 4 , 1 ), dtype = p_mats .dtype )
1698+ # form 24 permutation matrices
1699+ dpt .put_along_axis (p_mats , all_perms [..., dpt .newaxis ], vals , axis = 2 )
1700+ p2 = p_mats @ p_mats
1701+ p4 = p2 @ p2
1702+ p8 = p4 @ p4
1703+ expected = dpt .eye (4 , dtype = p_mats .dtype )[dpt .newaxis , ...]
1704+ assert dpt .all (p8 @ p4 == expected )
1705+
1706+
15971707def check__extract_impl_validation (fn ):
15981708 x = dpt .ones (10 )
15991709 ind = dpt .ones (10 , dtype = "?" )
@@ -1670,7 +1780,11 @@ def check__put_multi_index_validation(fn):
16701780 with pytest .raises (ValueError ):
16711781 fn (x2 , (ind1 , ind2 ), 0 , x2 )
16721782 with pytest .raises (TypeError ):
1783+ # invalid index type
16731784 fn (x2 , (ind1 , list ()), 0 , x2 )
1785+ with pytest .raises (ValueError ):
1786+ # invalid mode keyword value
1787+ fn (x , inds , 0 , vals , mode = 100 )
16741788
16751789
16761790def test__copy_utils ():
0 commit comments