diff --git a/distributed/client.py b/distributed/client.py index 0d41254a044..9b00b6a8196 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -6300,10 +6300,16 @@ def __init__(self, client=None, plot=False, filename="task-stream.html"): self._filename = filename self.figure = None self.client = client or default_client() - self.client.get_task_stream(start=0, stop=0) # ensure plugin + self._init = False def __enter__(self): - self.start = time() + if not self._init: + self.client.get_task_stream(start=0, stop=0) # ensure plugin + self._init = True + + # Smooth over time differences of client vs. workers + # FIXME this is very crude. We should query TaskStreamPlugin.index instead. + self.start = time() - 0.1 return self def __exit__(self, exc_type, exc_value, traceback): @@ -6315,6 +6321,13 @@ def __exit__(self, exc_type, exc_value, traceback): self.data.extend(L) async def __aenter__(self): + if not self._init: + await self.client.get_task_stream(start=0, stop=0) # ensure plugin + self._init = True + + # Smooth over time differences of client vs. workers + # FIXME this is very crude. We should query TaskStreamPlugin.index instead. + self.start = time() - 0.1 return self async def __aexit__(self, exc_type, exc_value, traceback): diff --git a/distributed/diagnostics/task_stream.py b/distributed/diagnostics/task_stream.py index c632549f389..7fa21636c1a 100644 --- a/distributed/diagnostics/task_stream.py +++ b/distributed/diagnostics/task_stream.py @@ -31,31 +31,24 @@ def __init__(self, scheduler, maxlen=None): self.index = 0 def transition(self, key, start, finish, *args, **kwargs): - if start == "processing": - if key not in self.scheduler.tasks: - return - if not kwargs.get("startstops"): - # Other methods require `kwargs` to have a non-empty list of `startstops` - return + if start == "processing" and finish in ("memory", "erred"): + assert kwargs["startstops"] kwargs["key"] = key - if finish == "memory" or finish == "erred": - self.buffer.append(kwargs) - self.index += 1 + self.buffer.append(kwargs) + self.index += 1 def collect(self, start=None, stop=None, count=None): def bisect(target, left, right): - if left == right: - return left - - mid = (left + right) // 2 - value = max( - startstop["stop"] for startstop in self.buffer[mid]["startstops"] - ) - - if value < target: - return bisect(target, mid + 1, right) - else: - return bisect(target, left, mid) + while left != right: + mid = (left + right) // 2 + stop = max( + startstop["stop"] for startstop in self.buffer[mid]["startstops"] + ) + if stop < target: + left = mid + 1 + else: + right = mid + return left if isinstance(start, str): start = time() - parse_timedelta(start) diff --git a/distributed/diagnostics/tests/test_task_stream.py b/distributed/diagnostics/tests/test_task_stream.py index 0e0206952e3..e7daab7d3f2 100644 --- a/distributed/diagnostics/tests/test_task_stream.py +++ b/distributed/diagnostics/tests/test_task_stream.py @@ -1,8 +1,5 @@ from __future__ import annotations -import os -from time import sleep - import pytest from tlz import frequencies @@ -85,50 +82,36 @@ async def test_collect(c, s, a, b): assert tasks.collect(start=start, count=3) == list(tasks.buffer)[:3] -@gen_cluster(client=True) -async def test_no_startstops(c, s, a, b): - tasks = TaskStreamPlugin(s) - s.add_plugin(tasks) - # just to create the key on the scheduler - future = c.submit(inc, 1) - await wait(future) - assert len(tasks.buffer) == 1 - - tasks.transition(future.key, "processing", "erred", stimulus_id="s1") - # Transition was not recorded because it didn't contain `startstops` - assert len(tasks.buffer) == 1 - - tasks.transition(future.key, "processing", "erred", stimulus_id="s2", startstops=[]) - # Transition was not recorded because `startstops` was empty - assert len(tasks.buffer) == 1 - - tasks.transition( - future.key, - "processing", - "erred", - stimulus_id="s3", - startstops=[dict(start=time(), stop=time())], - ) - assert len(tasks.buffer) == 2 - - @gen_cluster(client=True) async def test_client(c, s, a, b): - L = await c.get_task_stream() - assert L == () + await c.get_task_stream() - futures = c.map(slowinc, range(10), delay=0.1) + futures = c.map(inc, range(10)) await wait(futures) - - tasks = s.plugins[TaskStreamPlugin.name] - L = await c.get_task_stream() - assert L == tuple(tasks.buffer) + data = await c.get_task_stream() + assert len(data) == 10 def test_client_sync(client): - with get_task_stream(client=client) as ts: - sleep(0.1) # to smooth over time differences on the scheduler - # to smooth over time differences on the scheduler + client.get_task_stream() + + futures = client.map(inc, range(10)) + wait(futures) + data = client.get_task_stream() + assert len(data) == 10 + + +@gen_cluster(client=True) +async def test_client_ctx(c, s, a, b): + async with get_task_stream() as ts: + futures = c.map(inc, range(10)) + await wait(futures) + + assert len(ts.data) == 10 + + +def test_client_ctx_sync(client): + with get_task_stream() as ts: futures = client.map(inc, range(10)) wait(futures) @@ -140,23 +123,29 @@ async def test_get_task_stream_plot(c, s, a, b): bkm = pytest.importorskip("bokeh.models") await c.get_task_stream() - futures = c.map(slowinc, range(10), delay=0.1) + futures = c.map(inc, range(10)) await wait(futures) data, figure = await c.get_task_stream(plot=True) + assert len(data) == 10 assert isinstance(figure, bkm.Plot) -def test_get_task_stream_save(client, tmp_path): +@gen_cluster(client=True) +async def test_get_task_stream_save(c, s, a, b, tmp_path): bkm = pytest.importorskip("bokeh.models") - tmpdir = str(tmp_path) - fn = os.path.join(tmpdir, "foo.html") + await c.get_task_stream() + + futures = c.map(inc, range(10)) + await wait(futures) + + fn = str(tmp_path / "foo.html") + data, figure = await c.get_task_stream(plot="save", filename=fn) + assert len(data) == 10 - with get_task_stream(plot="save", filename=fn) as ts: - wait(client.map(inc, range(10))) with open(fn) as f: data = f.read() assert "inc" in data assert "bokeh" in data - assert isinstance(ts.figure, bkm.Plot) + assert isinstance(figure, bkm.Plot)