@@ -1578,7 +1578,7 @@ def test_take_along_axis_validation():
15781578 def_dtypes = info_ .default_dtypes (device = x_dev )
15791579 ind_dt = def_dtypes ["indexing" ]
15801580 ind = dpt .zeros (1 , dtype = ind_dt )
1581- # axis valudation
1581+ # axis validation
15821582 with pytest .raises (ValueError ):
15831583 dpt .take_along_axis (x , ind , axis = 1 )
15841584 # mode validation
@@ -1594,6 +1594,71 @@ def test_take_along_axis_validation():
15941594 dpt .take_along_axis (x , ind2 )
15951595
15961596
1597+ def test_put_along_axis ():
1598+ get_queue_or_skip ()
1599+
1600+ n0 , n1 , n2 = 3 , 5 , 7
1601+ x = dpt .reshape (dpt .arange (n0 * n1 * n2 ), (n0 , n1 , n2 ))
1602+ ind_dt = dpt .__array_namespace_info__ ().default_dtypes (
1603+ device = x .sycl_device
1604+ )["indexing" ]
1605+ ind0 = dpt .ones ((1 , n1 , n2 ), dtype = ind_dt )
1606+ ind1 = dpt .ones ((n0 , 1 , n2 ), dtype = ind_dt )
1607+ ind2 = dpt .ones ((n0 , n1 , 1 ), dtype = ind_dt )
1608+
1609+ xc = dpt .copy (x )
1610+ vals = dpt .ones (ind0 .shape , dtype = x .dtype )
1611+ dpt .put_along_axis (xc , ind0 , vals , axis = 0 )
1612+ assert dpt .all (dpt .take_along_axis (xc , ind0 , axis = 0 ) == vals )
1613+
1614+ xc = dpt .copy (x )
1615+ vals = dpt .ones (ind1 .shape , dtype = x .dtype )
1616+ dpt .put_along_axis (xc , ind1 , vals , axis = 1 )
1617+ assert dpt .all (dpt .take_along_axis (xc , ind1 , axis = 1 ) == vals )
1618+
1619+ xc = dpt .copy (x )
1620+ vals = dpt .ones (ind2 .shape , dtype = x .dtype )
1621+ dpt .put_along_axis (xc , ind2 , vals , axis = 2 )
1622+ assert dpt .all (dpt .take_along_axis (xc , ind2 , axis = 2 ) == vals )
1623+
1624+ xc = dpt .copy (x )
1625+ vals = dpt .ones (ind2 .shape , dtype = x .dtype )
1626+ dpt .put_along_axis (xc , ind2 , dpt .asnumpy (vals ), axis = 2 )
1627+ assert dpt .all (dpt .take_along_axis (xc , ind2 , axis = 2 ) == vals )
1628+
1629+
1630+ def test_put_along_axis_validation ():
1631+ # type check on the first argument
1632+ with pytest .raises (TypeError ):
1633+ dpt .put_along_axis (tuple (), list (), list ())
1634+ get_queue_or_skip ()
1635+ n1 , n2 = 2 , 5
1636+ x = dpt .ones (n1 * n2 )
1637+ # type check on the second argument
1638+ with pytest .raises (TypeError ):
1639+ dpt .put_along_axis (x , list (), list ())
1640+ x_dev = x .sycl_device
1641+ info_ = dpt .__array_namespace_info__ ()
1642+ def_dtypes = info_ .default_dtypes (device = x_dev )
1643+ ind_dt = def_dtypes ["indexing" ]
1644+ ind = dpt .zeros (1 , dtype = ind_dt )
1645+ vals = dpt .zeros (1 , dtype = x .dtype )
1646+ # axis validation
1647+ with pytest .raises (ValueError ):
1648+ dpt .put_along_axis (x , ind , vals , axis = 1 )
1649+ # mode validation
1650+ with pytest .raises (ValueError ):
1651+ dpt .put_along_axis (x , ind , vals , axis = 0 , mode = "invalid" )
1652+ # same array-ranks validation
1653+ with pytest .raises (ValueError ):
1654+ dpt .put_along_axis (dpt .reshape (x , (n1 , n2 )), ind , vals )
1655+ # check compute-follows-data
1656+ q2 = dpctl .SyclQueue (x_dev , property = "enable_profiling" )
1657+ ind2 = dpt .zeros (1 , dtype = ind_dt , sycl_queue = q2 )
1658+ with pytest .raises (ExecutionPlacementError ):
1659+ dpt .put_along_axis (x , ind2 , vals )
1660+
1661+
15971662def check__extract_impl_validation (fn ):
15981663 x = dpt .ones (10 )
15991664 ind = dpt .ones (10 , dtype = "?" )
@@ -1670,7 +1735,11 @@ def check__put_multi_index_validation(fn):
16701735 with pytest .raises (ValueError ):
16711736 fn (x2 , (ind1 , ind2 ), 0 , x2 )
16721737 with pytest .raises (TypeError ):
1738+ # invalid index type
16731739 fn (x2 , (ind1 , list ()), 0 , x2 )
1740+ with pytest .raises (ValueError ):
1741+ # invalid mode keyword value
1742+ fn (x , inds , 0 , vals , mode = 100 )
16741743
16751744
16761745def test__copy_utils ():
0 commit comments