@@ -1535,5 +1535,151 @@ def test_advanced_integer_indexing_cast_indices():
15351535 inds1 = dpt .astype (inds0 , "u4" )
15361536 inds2 = dpt .astype (inds0 , "u8" )
15371537 x = dpt .ones ((3 , 4 , 5 , 6 ), dtype = "i4" )
1538+ # test getitem
15381539 with pytest .raises (ValueError ):
15391540 x [inds0 , inds1 , inds2 , ...]
1541+ # test setitem
1542+ with pytest .raises (ValueError ):
1543+ x [inds0 , inds1 , inds2 , ...] = 1
1544+
1545+
1546+ def test_take_along_axis ():
1547+ get_queue_or_skip ()
1548+
1549+ n0 , n1 , n2 = 3 , 5 , 7
1550+ x = dpt .reshape (dpt .arange (n0 * n1 * n2 ), (n0 , n1 , n2 ))
1551+ ind_dt = dpt .__array_namespace_info__ ().default_dtypes (
1552+ device = x .sycl_device
1553+ )["indexing" ]
1554+ ind0 = dpt .ones ((1 , n1 , n2 ), dtype = ind_dt )
1555+ ind1 = dpt .ones ((n0 , 1 , n2 ), dtype = ind_dt )
1556+ ind2 = dpt .ones ((n0 , n1 , 1 ), dtype = ind_dt )
1557+
1558+ y0 = dpt .take_along_axis (x , ind0 , axis = 0 )
1559+ assert y0 .shape == ind0 .shape
1560+ y1 = dpt .take_along_axis (x , ind1 , axis = 1 )
1561+ assert y1 .shape == ind1 .shape
1562+ y2 = dpt .take_along_axis (x , ind2 , axis = 2 )
1563+ assert y2 .shape == ind2 .shape
1564+
1565+
1566+ def test_take_along_axis_validation ():
1567+ # type check on the first argument
1568+ with pytest .raises (TypeError ):
1569+ dpt .take_along_axis (tuple (), list ())
1570+ get_queue_or_skip ()
1571+ n1 , n2 = 2 , 5
1572+ x = dpt .ones (n1 * n2 )
1573+ # type check on the second argument
1574+ with pytest .raises (TypeError ):
1575+ dpt .take_along_axis (x , list ())
1576+ x_dev = x .sycl_device
1577+ info_ = dpt .__array_namespace_info__ ()
1578+ def_dtypes = info_ .default_dtypes (device = x_dev )
1579+ ind_dt = def_dtypes ["indexing" ]
1580+ ind = dpt .zeros (1 , dtype = ind_dt )
1581+ # axis valudation
1582+ with pytest .raises (ValueError ):
1583+ dpt .take_along_axis (x , ind , axis = 1 )
1584+ # mode validation
1585+ with pytest .raises (ValueError ):
1586+ dpt .take_along_axis (x , ind , axis = 0 , mode = "invalid" )
1587+ # same array-ranks validation
1588+ with pytest .raises (ValueError ):
1589+ dpt .take_along_axis (dpt .reshape (x , (n1 , n2 )), ind )
1590+ # check compute-follows-data
1591+ q2 = dpctl .SyclQueue (x_dev , property = "enable_profiling" )
1592+ ind2 = dpt .zeros (1 , dtype = ind_dt , sycl_queue = q2 )
1593+ with pytest .raises (ExecutionPlacementError ):
1594+ dpt .take_along_axis (x , ind2 )
1595+
1596+
1597+ def check__extract_impl_validation (fn ):
1598+ x = dpt .ones (10 )
1599+ ind = dpt .ones (10 , dtype = "?" )
1600+ with pytest .raises (TypeError ):
1601+ fn (list (), ind )
1602+ with pytest .raises (TypeError ):
1603+ fn (x , list ())
1604+ q2 = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1605+ ind2 = dpt .ones (10 , dtype = "?" , sycl_queue = q2 )
1606+ with pytest .raises (ExecutionPlacementError ):
1607+ fn (x , ind2 )
1608+ with pytest .raises (ValueError ):
1609+ fn (x , ind , 1 )
1610+
1611+
1612+ def check__nonzero_impl_validation (fn ):
1613+ with pytest .raises (TypeError ):
1614+ fn (list ())
1615+
1616+
1617+ def check__take_multi_index (fn ):
1618+ x = dpt .ones (10 )
1619+ x_dev = x .sycl_device
1620+ info_ = dpt .__array_namespace_info__ ()
1621+ def_dtypes = info_ .default_dtypes (device = x_dev )
1622+ ind_dt = def_dtypes ["indexing" ]
1623+ ind = dpt .arange (10 , dtype = ind_dt )
1624+ with pytest .raises (TypeError ):
1625+ fn (list (), tuple (), 1 )
1626+ with pytest .raises (ValueError ):
1627+ fn (x , (ind ,), 0 , mode = 2 )
1628+ with pytest .raises (ValueError ):
1629+ fn (x , (None ,), 1 )
1630+ with pytest .raises (IndexError ):
1631+ fn (x , (x ,), 1 )
1632+ q2 = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1633+ ind2 = dpt .arange (10 , dtype = ind_dt , sycl_queue = q2 )
1634+ with pytest .raises (ExecutionPlacementError ):
1635+ fn (x , (ind2 ,), 0 )
1636+ m = dpt .ones ((10 , 10 ))
1637+ ind_1 = dpt .arange (10 , dtype = "i8" )
1638+ ind_2 = dpt .arange (10 , dtype = "u8" )
1639+ with pytest .raises (ValueError ):
1640+ fn (m , (ind_1 , ind_2 ), 0 )
1641+
1642+
1643+ def check__place_impl_validation (fn ):
1644+ with pytest .raises (TypeError ):
1645+ fn (list (), list (), list ())
1646+ x = dpt .ones (10 )
1647+ with pytest .raises (TypeError ):
1648+ fn (x , list (), list ())
1649+ q2 = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1650+ mask2 = dpt .ones (10 , dtype = "?" , sycl_queue = q2 )
1651+ with pytest .raises (ExecutionPlacementError ):
1652+ fn (x , mask2 , 1 )
1653+ x2 = dpt .ones ((5 , 5 ))
1654+ mask2 = dpt .ones ((5 , 5 ), dtype = "?" )
1655+ with pytest .raises (ValueError ):
1656+ fn (x2 , mask2 , x2 , axis = 1 )
1657+
1658+
1659+ def check__put_multi_index_validation (fn ):
1660+ with pytest .raises (TypeError ):
1661+ fn (list (), list (), 0 , list ())
1662+ x = dpt .ones (10 )
1663+ inds = dpt .arange (10 , dtype = "i8" )
1664+ vals = dpt .zeros (10 )
1665+ # test inds which is not a tuple/list
1666+ fn (x , inds , 0 , vals )
1667+ x2 = dpt .ones ((5 , 5 ))
1668+ ind1 = dpt .arange (5 , dtype = "i8" )
1669+ ind2 = dpt .arange (5 , dtype = "u8" )
1670+ with pytest .raises (ValueError ):
1671+ fn (x2 , (ind1 , ind2 ), 0 , x2 )
1672+ with pytest .raises (TypeError ):
1673+ fn (x2 , (ind1 , list ()), 0 , x2 )
1674+
1675+
1676+ def test__copy_utils ():
1677+ import dpctl .tensor ._copy_utils as cu
1678+
1679+ get_queue_or_skip ()
1680+
1681+ check__extract_impl_validation (cu ._extract_impl )
1682+ check__nonzero_impl_validation (cu ._nonzero_impl )
1683+ check__take_multi_index (cu ._take_multi_index )
1684+ check__place_impl_validation (cu ._place_impl )
1685+ check__put_multi_index_validation (cu ._put_multi_index )
0 commit comments