Skip to content

Commit 7adcf67

Browse files
Add tests for put_along_axis
1 parent e88382c commit 7adcf67

1 file changed

Lines changed: 70 additions & 1 deletion

File tree

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1578,7 +1578,7 @@ def test_take_along_axis_validation():
15781578
def_dtypes = info_.default_dtypes(device=x_dev)
15791579
ind_dt = def_dtypes["indexing"]
15801580
ind = dpt.zeros(1, dtype=ind_dt)
1581-
# axis valudation
1581+
# axis validation
15821582
with pytest.raises(ValueError):
15831583
dpt.take_along_axis(x, ind, axis=1)
15841584
# mode validation
@@ -1594,6 +1594,71 @@ def test_take_along_axis_validation():
15941594
dpt.take_along_axis(x, ind2)
15951595

15961596

1597+
def test_put_along_axis():
1598+
get_queue_or_skip()
1599+
1600+
n0, n1, n2 = 3, 5, 7
1601+
x = dpt.reshape(dpt.arange(n0 * n1 * n2), (n0, n1, n2))
1602+
ind_dt = dpt.__array_namespace_info__().default_dtypes(
1603+
device=x.sycl_device
1604+
)["indexing"]
1605+
ind0 = dpt.ones((1, n1, n2), dtype=ind_dt)
1606+
ind1 = dpt.ones((n0, 1, n2), dtype=ind_dt)
1607+
ind2 = dpt.ones((n0, n1, 1), dtype=ind_dt)
1608+
1609+
xc = dpt.copy(x)
1610+
vals = dpt.ones(ind0.shape, dtype=x.dtype)
1611+
dpt.put_along_axis(xc, ind0, vals, axis=0)
1612+
assert dpt.all(dpt.take_along_axis(xc, ind0, axis=0) == vals)
1613+
1614+
xc = dpt.copy(x)
1615+
vals = dpt.ones(ind1.shape, dtype=x.dtype)
1616+
dpt.put_along_axis(xc, ind1, vals, axis=1)
1617+
assert dpt.all(dpt.take_along_axis(xc, ind1, axis=1) == vals)
1618+
1619+
xc = dpt.copy(x)
1620+
vals = dpt.ones(ind2.shape, dtype=x.dtype)
1621+
dpt.put_along_axis(xc, ind2, vals, axis=2)
1622+
assert dpt.all(dpt.take_along_axis(xc, ind2, axis=2) == vals)
1623+
1624+
xc = dpt.copy(x)
1625+
vals = dpt.ones(ind2.shape, dtype=x.dtype)
1626+
dpt.put_along_axis(xc, ind2, dpt.asnumpy(vals), axis=2)
1627+
assert dpt.all(dpt.take_along_axis(xc, ind2, axis=2) == vals)
1628+
1629+
1630+
def test_put_along_axis_validation():
1631+
# type check on the first argument
1632+
with pytest.raises(TypeError):
1633+
dpt.put_along_axis(tuple(), list(), list())
1634+
get_queue_or_skip()
1635+
n1, n2 = 2, 5
1636+
x = dpt.ones(n1 * n2)
1637+
# type check on the second argument
1638+
with pytest.raises(TypeError):
1639+
dpt.put_along_axis(x, list(), list())
1640+
x_dev = x.sycl_device
1641+
info_ = dpt.__array_namespace_info__()
1642+
def_dtypes = info_.default_dtypes(device=x_dev)
1643+
ind_dt = def_dtypes["indexing"]
1644+
ind = dpt.zeros(1, dtype=ind_dt)
1645+
vals = dpt.zeros(1, dtype=x.dtype)
1646+
# axis validation
1647+
with pytest.raises(ValueError):
1648+
dpt.put_along_axis(x, ind, vals, axis=1)
1649+
# mode validation
1650+
with pytest.raises(ValueError):
1651+
dpt.put_along_axis(x, ind, vals, axis=0, mode="invalid")
1652+
# same array-ranks validation
1653+
with pytest.raises(ValueError):
1654+
dpt.put_along_axis(dpt.reshape(x, (n1, n2)), ind, vals)
1655+
# check compute-follows-data
1656+
q2 = dpctl.SyclQueue(x_dev, property="enable_profiling")
1657+
ind2 = dpt.zeros(1, dtype=ind_dt, sycl_queue=q2)
1658+
with pytest.raises(ExecutionPlacementError):
1659+
dpt.put_along_axis(x, ind2, vals)
1660+
1661+
15971662
def check__extract_impl_validation(fn):
15981663
x = dpt.ones(10)
15991664
ind = dpt.ones(10, dtype="?")
@@ -1670,7 +1735,11 @@ def check__put_multi_index_validation(fn):
16701735
with pytest.raises(ValueError):
16711736
fn(x2, (ind1, ind2), 0, x2)
16721737
with pytest.raises(TypeError):
1738+
# invalid index type
16731739
fn(x2, (ind1, list()), 0, x2)
1740+
with pytest.raises(ValueError):
1741+
# invalid mode keyword value
1742+
fn(x, inds, 0, vals, mode=100)
16741743

16751744

16761745
def test__copy_utils():

0 commit comments

Comments
 (0)