Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions src/workflows/recipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from collections.abc import Callable
from typing import Any

from opentelemetry import trace

from workflows.recipe.recipe import Recipe
from workflows.recipe.validate import validate_recipe
from workflows.recipe.wrapper import RecipeWrapper
Expand Down Expand Up @@ -69,6 +71,68 @@ def unwrap_recipe(header, message):
message = mangle_for_receiving(message)
if header.get("workflows-recipe") in {True, "True", "true", 1}:
rw = RecipeWrapper(message=message, transport=transport_layer)
logger.debug("RecipeWrapper created: %s", rw)

# Extract and set DCID and recipe_id on the current span
span = trace.get_current_span()
dcid = None
recipe_id = None

# Extract recipe ID from environment
if isinstance(message, dict):
environment = message.get("environment", {})
if isinstance(environment, dict):
recipe_id = environment.get("ID")

# Try multiple locations where DCID might be stored
top_level_params = {}
if isinstance(message, dict):
# Direct parameters (top-level or in recipe)
top_level_params = message.get("parameters", {})

# Payload parameters (most common location)
payload = message.get("payload", {})
payload_params = {}
if isinstance(payload, dict):
payload_params = payload.get("parameters", {})

# Try all common locations
dcid = (
top_level_params.get("ispyb_dcid")
or top_level_params.get("dcid")
or payload_params.get("ispyb_dcid")
or payload_params.get("dcid")
or payload.get("ispyb_dcid")
or payload.get("dcid")
)

if dcid:
span.set_attribute("dcid", dcid)
span.add_event("recipe.dcid_extracted", attributes={"dcid": dcid})

if recipe_id:
span.set_attribute("recipe_id", recipe_id)
span.add_event(
"recipe.id_extracted", attributes={"recipe_id": recipe_id}
)

# Extract span_id and trace_id for logging
span_context = span.get_span_context()
if span_context and span_context.is_valid:
span_id = format(span_context.span_id, "016x")
trace_id = format(span_context.trace_id, "032x")

log_extra = {
"span_id": span_id,
"trace_id": trace_id,
}
if dcid:
log_extra["dcid"] = dcid
if recipe_id:
log_extra["recipe_id"] = recipe_id

logger.info("Processing recipe message", extra=log_extra)

if log_extender and rw.environment and rw.environment.get("ID"):
with log_extender("recipe_ID", rw.environment["ID"]):
return callback(rw, header, message.get("payload"))
Expand Down
38 changes: 38 additions & 0 deletions src/workflows/services/common_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,15 @@
import time
from typing import Any

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

import workflows
import workflows.logging
from workflows.transport.middleware.otel_tracing import OTELTracingMiddleware


class Status(enum.Enum):
Expand Down Expand Up @@ -185,6 +192,37 @@ def start_transport(self):
self.transport.subscription_callback_set_intercept(
self._transport_interceptor
)

# Configure OTELTracing
resource = Resource.create(
{
SERVICE_NAME: self._service_name,
}
)

self.log.debug("Configuring OTELTracing")
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)

# Configure BatchProcessor and OTLPSpanExporter to point to OTELCollector
otlp_exporter = OTLPSpanExporter(
endpoint="https://otel.tracing.diamond.ac.uk:4318/v1/traces", timeout=10
)
span_processor = BatchSpanProcessor(otlp_exporter)
provider.add_span_processor(span_processor)

# Add OTELTracingMiddleware to the transport layer
tracer = trace.get_tracer(__name__)
otel_middleware = OTELTracingMiddleware(
tracer, service_name=self._service_name
)
self._transport.add_middleware(otel_middleware)

self.log.debug(
"OTELTracingMiddleware added to transport layer of %s",
self._service_name,
)

metrics = self._environment.get("metrics")
if metrics:
import prometheus_client
Expand Down
39 changes: 39 additions & 0 deletions src/workflows/transport/middleware/otel_tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

import functools
from collections.abc import Callable

from opentelemetry import trace
from opentelemetry.propagate import extract

from workflows.transport.middleware import BaseTransportMiddleware


class OTELTracingMiddleware(BaseTransportMiddleware):
    """Transport middleware that runs subscription callbacks inside a span.

    Every callback registered through :meth:`subscribe` is wrapped so that
    it executes within a span named ``"transport.subscribe"`` created by the
    tracer given at construction time.
    """

    def __init__(self, tracer: trace.Tracer, service_name: str):
        """
        Initialize the OpenTelemetry Tracing Middleware.

        :param tracer: An OpenTelemetry tracer instance used to create spans.
        :param service_name: Value stored in the ``service_name`` attribute
                             of every span this middleware creates.
        """
        self.tracer = tracer
        self.service_name = service_name

    def subscribe(self, call_next: Callable, channel, callback, **kwargs) -> int:
        """
        Register a span-creating wrapper around ``callback``.

        :param call_next: The next ``subscribe`` handler in the middleware
                          chain; called with the channel, the wrapped
                          callback and any extra keyword arguments.
        :param channel: Channel being subscribed to; also stored in the
                        ``channel`` attribute of each span.
        :param callback: Function invoked as ``callback(header, message)``
                         for every received message.
        :param kwargs: Passed through to ``call_next`` untouched.
        :return: Whatever ``call_next`` returns.
        """

        @functools.wraps(callback)
        def wrapped_callback(header, message):
            # Only attempt to read a propagated trace context when a
            # non-empty header was delivered; otherwise pass no explicit
            # context to the tracer.
            parent_context = None
            if header:
                parent_context = extract(header)

            with self.tracer.start_as_current_span(
                "transport.subscribe", context=parent_context
            ) as active_span:
                active_span.set_attribute("service_name", self.service_name)
                active_span.set_attribute("channel", channel)

                # The user callback runs while the span is current, and its
                # return value is handed back to the transport unchanged.
                return callback(header, message)

        # Hand the wrapped callback on to the rest of the middleware chain.
        return call_next(channel, wrapped_callback, **kwargs)
Loading