Skip to content

Commit cd7cb80

Browse files
authored
🥅 Handle cancellation (#94)
* 💡 Update test comments on CI Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * ⏪ Put back CI comment Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * ♻️ Separate integration tests Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * 🥅 Handle cancellations Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * ✅ Add tests for shutdown Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * 🔧 Configure probes Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * ⚡ Cache envoy protos for integration tests Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --------- Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
1 parent 288620d commit cd7cb80

6 files changed

Lines changed: 304 additions & 95 deletions

File tree

.github/workflows/ci.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,26 @@ jobs:
8585
- name: Install uv
8686
run: pip install uv
8787

88+
# Cache compiled protobuf files across CI runs
89+
- name: Extract proto commit hash
90+
id: proto-hash
91+
run: |
92+
echo "hash=$(grep 'ENVOY_DATA_PLANE_COMMIT=' proto-build.sh | cut -d'"' -f2)" >> "$GITHUB_OUTPUT"
93+
94+
- name: Cache protobuf files
95+
id: proto-cache
96+
uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
97+
with:
98+
path: |
99+
src/envoy
100+
src/xds
101+
src/validate
102+
src/udpa
103+
key: protos-${{ steps.proto-hash.outputs.hash }}
104+
88105
# Build generated protos (gitignored, needed for real envoy imports)
89106
- name: Build protobuf files
107+
if: steps.proto-cache.outputs.cache-hit != 'true'
90108
run: |
91109
uv sync --group proto
92110
USE_HTTPS=true ./proto-build.sh

ext-proc.yaml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ spec:
2727
labels:
2828
app: plugins-adapter
2929
spec:
30+
# Allow 35s for graceful shutdown: 5s preStop + 15s gRPC drain + margin
31+
terminationGracePeriodSeconds: 35
3032
securityContext:
3133
runAsNonRoot: true
3234
runAsUser: 1000
@@ -67,3 +69,23 @@ spec:
6769
value: "./"
6870
ports:
6971
- containerPort: 50052
72+
lifecycle:
73+
preStop:
74+
exec:
75+
# Delay SIGTERM so Envoy/Istio can remove this pod from
76+
# its upstream list before we start draining streams.
77+
command: ["/bin/sleep", "5"]
78+
# gRPC health probes rely on the grpc-health-checking service
79+
# registered in serve()
80+
readinessProbe:
81+
grpc:
82+
port: 50052
83+
initialDelaySeconds: 5
84+
periodSeconds: 10
85+
failureThreshold: 3
86+
livenessProbe:
87+
grpc:
88+
port: 50052
89+
initialDelaySeconds: 10
90+
periodSeconds: 30
91+
failureThreshold: 3

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ requires-python = ">=3.11"
66
dependencies = [
77
"grpcio>=1.78.0",
88
"grpcio-tools>=1.78.0",
9+
"grpcio-health-checking>=1.78.0",
910
"betterproto2==0.9.1",
1011
"cpex==0.1.0.dev10",
1112
]

src/server.py

Lines changed: 119 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import logging
55
import os
6+
import signal
67
from typing import AsyncIterator
78

89
import grpc
@@ -23,6 +24,8 @@
2324
from envoy.service.ext_proc.v3 import external_processor_pb2 as ep
2425
from envoy.service.ext_proc.v3 import external_processor_pb2_grpc as ep_grpc
2526
from envoy.type.v3 import http_status_pb2 as http_status_pb2
27+
from grpc_health.v1 import health as grpc_health
28+
from grpc_health.v1 import health_pb2, health_pb2_grpc
2629

2730
# ============================================================================
2831
# LOGGING CONFIGURATION
@@ -330,108 +333,111 @@ async def Process(
330333
req_body_buf = bytearray()
331334
resp_body_buf = bytearray()
332335

333-
async for request in request_iterator:
334-
# ----------------------------------------------------------------
335-
# Request Headers Processing
336-
# ----------------------------------------------------------------
337-
if request.HasField("request_headers"):
338-
_headers = request.request_headers.headers
339-
yield ep.ProcessingResponse(
340-
request_headers=ep.HeadersResponse(
341-
response=ep.CommonResponse(
342-
header_mutation=ep.HeaderMutation(
343-
set_headers=[
344-
core.HeaderValueOption(
345-
header=core.HeaderValue(
346-
key="x-ext-proc-header",
347-
raw_value="hello-from-ext-proc".encode("utf-8"),
348-
),
349-
append_action=core.HeaderValueOption.APPEND_IF_EXISTS_OR_ADD,
350-
)
351-
]
336+
try:
337+
async for request in request_iterator:
338+
# ----------------------------------------------------------------
339+
# Request Headers Processing
340+
# ----------------------------------------------------------------
341+
if request.HasField("request_headers"):
342+
_headers = request.request_headers.headers
343+
yield ep.ProcessingResponse(
344+
request_headers=ep.HeadersResponse(
345+
response=ep.CommonResponse(
346+
header_mutation=ep.HeaderMutation(
347+
set_headers=[
348+
core.HeaderValueOption(
349+
header=core.HeaderValue(
350+
key="x-ext-proc-header",
351+
raw_value="hello-from-ext-proc".encode("utf-8"),
352+
),
353+
append_action=core.HeaderValueOption.APPEND_IF_EXISTS_OR_ADD,
354+
)
355+
]
356+
)
352357
)
353358
)
354359
)
355-
)
356-
# ----------------------------------------------------------------
357-
# Response Headers Processing
358-
# ----------------------------------------------------------------
359-
elif request.HasField("response_headers"):
360-
_headers = request.response_headers.headers
361-
yield ep.ProcessingResponse(
362-
response_headers=ep.HeadersResponse(
363-
response=ep.CommonResponse(
364-
header_mutation=ep.HeaderMutation(
365-
set_headers=[
366-
core.HeaderValueOption(
367-
header=core.HeaderValue(
368-
key="x-ext-proc-response-header",
369-
raw_value="processed-by-ext-proc".encode("utf-8"),
370-
),
371-
append_action=core.HeaderValueOption.APPEND_IF_EXISTS_OR_ADD,
372-
)
373-
]
360+
# ----------------------------------------------------------------
361+
# Response Headers Processing
362+
# ----------------------------------------------------------------
363+
elif request.HasField("response_headers"):
364+
_headers = request.response_headers.headers
365+
yield ep.ProcessingResponse(
366+
response_headers=ep.HeadersResponse(
367+
response=ep.CommonResponse(
368+
header_mutation=ep.HeaderMutation(
369+
set_headers=[
370+
core.HeaderValueOption(
371+
header=core.HeaderValue(
372+
key="x-ext-proc-response-header",
373+
raw_value="processed-by-ext-proc".encode("utf-8"),
374+
),
375+
append_action=core.HeaderValueOption.APPEND_IF_EXISTS_OR_ADD,
376+
)
377+
]
378+
)
374379
)
375380
)
376381
)
377-
)
378-
379-
# ----------------------------------------------------------------
380-
# Request Body Processing (MCP Tool/Prompt Invocations)
381-
# ----------------------------------------------------------------
382-
elif request.HasField("request_body") and request.request_body.body:
383-
chunk = request.request_body.body
384-
req_body_buf.extend(chunk)
385-
386-
if getattr(request.request_body, "end_of_stream", False):
387-
try:
388-
text = req_body_buf.decode("utf-8")
389-
except UnicodeDecodeError:
390-
logger.debug("Request body not UTF-8; skipping")
391-
else:
392-
logger.info(json.loads(text))
393-
body = json.loads(text)
394-
if "method" in body and body["method"] == "tools/call":
395-
body_resp = await getToolPreInvokeResponse(body)
396-
elif "method" in body and body["method"] == "prompts/get":
397-
body_resp = await getPromptPreFetchResponse(body)
382+
383+
# ----------------------------------------------------------------
384+
# Request Body Processing (MCP Tool/Prompt Invocations)
385+
# ----------------------------------------------------------------
386+
elif request.HasField("request_body") and request.request_body.body:
387+
chunk = request.request_body.body
388+
req_body_buf.extend(chunk)
389+
390+
if getattr(request.request_body, "end_of_stream", False):
391+
try:
392+
text = req_body_buf.decode("utf-8")
393+
except UnicodeDecodeError:
394+
logger.debug("Request body not UTF-8; skipping")
398395
else:
399-
body_resp = ep.ProcessingResponse(
400-
request_body=ep.BodyResponse(response=ep.CommonResponse())
401-
)
396+
logger.info(json.loads(text))
397+
body = json.loads(text)
398+
if "method" in body and body["method"] == "tools/call":
399+
body_resp = await getToolPreInvokeResponse(body)
400+
elif "method" in body and body["method"] == "prompts/get":
401+
body_resp = await getPromptPreFetchResponse(body)
402+
else:
403+
body_resp = ep.ProcessingResponse(
404+
request_body=ep.BodyResponse(response=ep.CommonResponse())
405+
)
406+
yield body_resp
407+
408+
req_body_buf.clear()
409+
410+
# ----------------------------------------------------------------
411+
# Response Body Processing (MCP Tool Results)
412+
# ----------------------------------------------------------------
413+
elif request.HasField("response_body"):
414+
logger.debug(f"Processing response body: {request}")
415+
416+
# Buffer content if present in this chunk
417+
if request.response_body.body:
418+
chunk = request.response_body.body
419+
resp_body_buf.extend(chunk)
420+
logger.debug(f"Buffered chunk ({len(chunk)} bytes)")
421+
422+
# Check for end of stream (regardless of whether this chunk has content)
423+
if getattr(request.response_body, "end_of_stream", False):
424+
logger.debug("End of stream reached, processing complete buffered response")
425+
426+
# Process the buffered content
427+
body_resp = await process_response_body_buffer(resp_body_buf)
402428
yield body_resp
429+
resp_body_buf.clear()
430+
else:
431+
# Intermediate chunk - acknowledge but don't process yet
432+
logger.debug("Buffering intermediate chunk, waiting for end_of_stream")
433+
yield ep.ProcessingResponse(response_body=ep.BodyResponse(response=ep.CommonResponse()))
403434

404-
req_body_buf.clear()
405-
406-
# ----------------------------------------------------------------
407-
# Response Body Processing (MCP Tool Results)
408-
# ----------------------------------------------------------------
409-
elif request.HasField("response_body"):
410-
logger.debug(f"Processing response body: {request}")
411-
412-
# Buffer content if present in this chunk
413-
if request.response_body.body:
414-
chunk = request.response_body.body
415-
resp_body_buf.extend(chunk)
416-
logger.debug(f"Buffered chunk ({len(chunk)} bytes)")
417-
418-
# Check for end of stream (regardless of whether this chunk has content)
419-
if getattr(request.response_body, "end_of_stream", False):
420-
logger.debug("End of stream reached, processing complete buffered response")
421-
422-
# Process the buffered content
423-
body_resp = await process_response_body_buffer(resp_body_buf)
424-
yield body_resp
425-
resp_body_buf.clear()
426435
else:
427-
# Intermediate chunk - acknowledge but don't process yet
428-
logger.debug("Buffering intermediate chunk, waiting for end_of_stream")
429-
yield ep.ProcessingResponse(response_body=ep.BodyResponse(response=ep.CommonResponse()))
430-
431-
else:
432-
# Unhandled request types
433-
logger.warning("Not processed")
434-
logger.warning(request)
436+
# Unhandled request types
437+
logger.warning("Not processed")
438+
logger.warning(request)
439+
except asyncio.CancelledError:
440+
logger.info("Process stream cancelled (client disconnect or pod rollover)")
435441

436442

437443
# ============================================================================
@@ -452,13 +458,31 @@ async def serve(host: str = "0.0.0.0", port: int = 50052):
452458
logger.debug(f"Loaded {manager.plugin_count} plugins")
453459

454460
server = grpc.aio.server()
455-
# server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
456461
ep_grpc.add_ExternalProcessorServicer_to_server(ExtProcServicer(), server)
462+
463+
# Register gRPC health check service for Kubernetes readiness/liveness probes
464+
health_servicer = grpc_health.HealthServicer()
465+
health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)
466+
457467
listen_addr = f"{host}:{port}"
458468
server.add_insecure_port(listen_addr)
459-
logger.info("Starting ext_proc MY server on %s", listen_addr)
469+
logger.info("Starting ext_proc server on %s", listen_addr)
460470
await server.start()
461-
# wait forever
471+
472+
# Mark server as healthy after startup
473+
health_servicer.set("", health_pb2.HealthCheckResponse.SERVING)
474+
475+
# Install SIGTERM handler for graceful drain on pod rollover
476+
loop = asyncio.get_running_loop()
477+
478+
async def _shutdown():
479+
logger.info("SIGTERM received — draining in-flight streams (grace=15s)")
480+
health_servicer.set("", health_pb2.HealthCheckResponse.NOT_SERVING)
481+
await server.stop(grace=15)
482+
483+
loop.add_signal_handler(signal.SIGTERM, lambda: asyncio.ensure_future(_shutdown()))
484+
logger.info("SIGTERM handler registered; waiting for termination")
485+
462486
await server.wait_for_termination()
463487

464488

tests/integration/test_ext_proc_e2e.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
and exercise the full request/response flow.
55
"""
66

7+
import asyncio
78
import json
89

910
import pytest
@@ -215,3 +216,38 @@ async def test_response_body_tool_result_blocked(grpc_stub):
215216
assert "Blocked by test" in error_body["error"]["message"]
216217
finally:
217218
PassthroughPlugin.reset()
219+
220+
221+
# ---------------------------------------------------------------------------
222+
# Stream Cancellation (simulates pod rollover)
223+
# ---------------------------------------------------------------------------
224+
225+
226+
@pytest.mark.asyncio
227+
async def test_stream_cancel_does_not_crash_server(grpc_stub):
228+
"""Cancelling a bidi stream mid-flight should not crash the server.
229+
230+
After cancellation, a subsequent request should still succeed,
231+
confirming the server is still healthy.
232+
"""
233+
# Open a stream and cancel it without finishing
234+
call = grpc_stub.Process()
235+
request = ep.ProcessingRequest(
236+
request_headers=ep.HttpHeaders(
237+
headers=core.HeaderMap(headers=[]),
238+
)
239+
)
240+
await call.write(request)
241+
call.cancel()
242+
243+
# Small delay for the server to process the cancellation
244+
await asyncio.sleep(0.1)
245+
246+
# Verify the server is still operational with a normal request
247+
follow_up = ep.ProcessingRequest(
248+
request_headers=ep.HttpHeaders(
249+
headers=core.HeaderMap(headers=[]),
250+
)
251+
)
252+
response = await send_one(grpc_stub, follow_up)
253+
assert response.HasField("request_headers")

0 commit comments

Comments
 (0)