Skip to content

Commit 95d0df9

Browse files
authored
Patching (#524)
Patching is a mechanism for safely upgrading workflow code. It is an alternative to [workflow versioning](https://docs.dbos.dev/python/tutorials/workflow-tutorial#workflow-versioning-and-recovery) (though they can be used together). The problem patching solves is "How do I make a breaking change to a workflow's code but continue execution of long-running workflows that started on the old code version?" A breaking change is any change in what steps run or the order in which they run. To use patching, first enable it in configuration: ```python config: DBOSConfig = { "name": "dbos-starter", "system_database_url": os.environ.get("DBOS_SYSTEM_DATABASE_URL"), "enable_patching": True, } DBOS(config=config) ``` Next, when making a breaking change, use an `if DBOS.patch():` conditional. `DBOS.patch()` returns `True` for new workflows (those started after the breaking change) and `False` for old workflows (those started before the breaking change). Therefore, if `DBOS.patch()` returns `True`, call the new code; otherwise, call the old code. So let's say our workflow is: ```python @DBOS.workflow() def workflow(): foo() bar() ``` We want to replace the call to `foo()` with a call to `baz()`, which is a breaking change. We can do this safely using a patch: ```python @DBOS.workflow() def workflow(): if DBOS.patch("use-baz"): baz() else: foo() bar() ``` Now, new workflows will run `baz()`, while old workflows will safely recover through `foo()`. Once all workflows of the pre-patch code version are complete, we can remove patches from our code. First, we deprecate the patch. This will safely run workflows containing the patch marker, but will not insert the patch marker into new workflows: ```python @DBOS.workflow() def workflow(): DBOS.deprecate_patch("use-baz") baz() bar() ``` Then, when all workflows containing the patch marker are complete, we can remove the patch entirely and complete the workflow upgrade! 
```python @DBOS.workflow() def workflow(): baz() bar() ``` If any mistakes happen during the process (a breaking change is not patched, or a patch is deprecated or removed prematurely), the workflow will raise a descriptive `DBOSUnexpectedStepError` pointing to the step where the problem occurred. As an advanced feature, if you need to make consecutive breaking changes to the same code, you can stack patches: ```python @DBOS.workflow() def workflow(): if DBOS.patch("use-qux"): qux() elif DBOS.patch("use-baz"): baz() else: foo() bar() ```
1 parent 85b8073 commit 95d0df9

File tree

6 files changed

+414
-5
lines changed

6 files changed

+414
-5
lines changed

dbos/_dbos.py

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -341,6 +341,7 @@ def __init__(
341341
self.conductor_key: Optional[str] = conductor_key
342342
if config.get("conductor_key"):
343343
self.conductor_key = config.get("conductor_key")
344+
self.enable_patching = config.get("enable_patching") == True
344345
self.conductor_websocket: Optional[ConductorWebsocket] = None
345346
self._background_event_loop: BackgroundEventLoop = BackgroundEventLoop()
346347
self._active_workflows_set: set[str] = set()
@@ -350,6 +351,8 @@ def __init__(
350351
# Globally set the application version and executor ID.
351352
# In DBOS Cloud, instead use the values supplied through environment variables.
352353
if not os.environ.get("DBOS__CLOUD") == "true":
354+
if self.enable_patching:
355+
GlobalParams.app_version = "PATCHING_ENABLED"
353356
if (
354357
"application_version" in config
355358
and config["application_version"] is not None
@@ -1524,6 +1527,50 @@ async def read_stream_async(
15241527
await asyncio.sleep(1.0)
15251528
continue
15261529

1530+
@classmethod
1531+
def patch(cls, patch_name: str) -> bool:
1532+
if not _get_dbos_instance().enable_patching:
1533+
raise DBOSException("enable_patching must be True in DBOS configuration")
1534+
ctx = get_local_dbos_context()
1535+
if ctx is None or not ctx.is_workflow():
1536+
raise DBOSException("DBOS.patch must be called from a workflow")
1537+
workflow_id = ctx.workflow_id
1538+
function_id = ctx.function_id
1539+
patch_name = f"DBOS.patch-{patch_name}"
1540+
patched = _get_dbos_instance()._sys_db.patch(
1541+
workflow_id=workflow_id, function_id=function_id + 1, patch_name=patch_name
1542+
)
1543+
# If the patch was applied, increment function ID
1544+
if patched:
1545+
ctx.function_id += 1
1546+
return patched
1547+
1548+
@classmethod
1549+
def patch_async(cls, patch_name: str) -> Coroutine[Any, Any, bool]:
1550+
return asyncio.to_thread(cls.patch, patch_name)
1551+
1552+
@classmethod
1553+
def deprecate_patch(cls, patch_name: str) -> bool:
1554+
if not _get_dbos_instance().enable_patching:
1555+
raise DBOSException("enable_patching must be True in DBOS configuration")
1556+
ctx = get_local_dbos_context()
1557+
if ctx is None or not ctx.is_workflow():
1558+
raise DBOSException("DBOS.deprecate_patch must be called from a workflow")
1559+
workflow_id = ctx.workflow_id
1560+
function_id = ctx.function_id
1561+
patch_name = f"DBOS.patch-{patch_name}"
1562+
patch_exists = _get_dbos_instance()._sys_db.deprecate_patch(
1563+
workflow_id=workflow_id, function_id=function_id + 1, patch_name=patch_name
1564+
)
1565+
# If the patch is already in history, increment function ID
1566+
if patch_exists:
1567+
ctx.function_id += 1
1568+
return True
1569+
1570+
@classmethod
1571+
def deprecate_patch_async(cls, patch_name: str) -> Coroutine[Any, Any, bool]:
1572+
return asyncio.to_thread(cls.deprecate_patch, patch_name)
1573+
15271574
@classproperty
15281575
def tracer(self) -> DBOSTracer:
15291576
"""Return the DBOS OpenTelemetry tracer."""

dbos/_dbos_config.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@ class DBOSConfig(TypedDict, total=False):
6363
conductor_key: Optional[str]
6464
conductor_url: Optional[str]
6565
serializer: Optional[Serializer]
66+
enable_patching: Optional[bool]
6667

6768

6869
class RuntimeConfig(TypedDict, total=False):

dbos/_error.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -143,7 +143,7 @@ def __init__(self, msg: str):
143143
self.status_code = 403
144144

145145
def __reduce__(self) -> Any:
146-
# Tell jsonpickle how to reconstruct this object
146+
# Tell pickle how to reconstruct this object
147147
return (self.__class__, (self.msg,))
148148

149149

@@ -162,7 +162,7 @@ def __init__(
162162
)
163163

164164
def __reduce__(self) -> Any:
165-
# Tell jsonpickle how to reconstruct this object
165+
# Tell pickle how to reconstruct this object
166166
return (self.__class__, (self.step_name, self.max_retries, self.errors))
167167

168168

@@ -182,11 +182,19 @@ class DBOSUnexpectedStepError(DBOSException):
182182
def __init__(
183183
self, workflow_id: str, step_id: int, expected_name: str, recorded_name: str
184184
) -> None:
185+
self.inputs = (workflow_id, step_id, expected_name, recorded_name)
185186
super().__init__(
186187
f"During execution of workflow {workflow_id} step {step_id}, function {recorded_name} was recorded when {expected_name} was expected. Check that your workflow is deterministic.",
187188
dbos_error_code=DBOSErrorCode.UnexpectedStep.value,
188189
)
189190

191+
def __reduce__(self) -> Any:
192+
# Tell pickle how to reconstruct this object
193+
return (
194+
self.__class__,
195+
self.inputs,
196+
)
197+
190198

191199
class DBOSQueueDeduplicatedError(DBOSException):
192200
"""Exception raised when a workflow is deduplicated in the queue."""
@@ -203,7 +211,7 @@ def __init__(
203211
)
204212

205213
def __reduce__(self) -> Any:
206-
# Tell jsonpickle how to reconstruct this object
214+
# Tell pickle how to reconstruct this object
207215
return (
208216
self.__class__,
209217
(self.workflow_id, self.queue_name, self.deduplication_id),
@@ -219,7 +227,7 @@ def __init__(self, workflow_id: str):
219227
)
220228

221229
def __reduce__(self) -> Any:
222-
# Tell jsonpickle how to reconstruct this object
230+
# Tell pickle how to reconstruct this object
223231
return (self.__class__, (self.workflow_id,))
224232

225233

dbos/_sys_db.py

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2361,3 +2361,43 @@ def get_metrics(self, start_time: str, end_time: str) -> List[MetricData]:
23612361
)
23622362

23632363
return metrics
2364+
2365+
@db_retry()
2366+
def patch(self, *, workflow_id: str, function_id: int, patch_name: str) -> bool:
2367+
"""If there is no checkpoint for this point in history,
2368+
insert a patch marker and return True.
2369+
Otherwise, return whether the checkpoint is this patch marker."""
2370+
with self.engine.begin() as c:
2371+
checkpoint_name: str | None = c.execute(
2372+
sa.select(SystemSchema.operation_outputs.c.function_name).where(
2373+
(SystemSchema.operation_outputs.c.workflow_uuid == workflow_id)
2374+
& (SystemSchema.operation_outputs.c.function_id == function_id)
2375+
)
2376+
).scalar()
2377+
if checkpoint_name is None:
2378+
result: OperationResultInternal = {
2379+
"workflow_uuid": workflow_id,
2380+
"function_id": function_id,
2381+
"function_name": patch_name,
2382+
"output": None,
2383+
"error": None,
2384+
"started_at_epoch_ms": int(time.time() * 1000),
2385+
}
2386+
self._record_operation_result_txn(result, c)
2387+
return True
2388+
else:
2389+
return checkpoint_name == patch_name
2390+
2391+
@db_retry()
2392+
def deprecate_patch(
2393+
self, *, workflow_id: str, function_id: int, patch_name: str
2394+
) -> bool:
2395+
"""Respect patch markers in history, but do not introduce new patch markers"""
2396+
with self.engine.begin() as c:
2397+
checkpoint_name: str | None = c.execute(
2398+
sa.select(SystemSchema.operation_outputs.c.function_name).where(
2399+
(SystemSchema.operation_outputs.c.workflow_uuid == workflow_id)
2400+
& (SystemSchema.operation_outputs.c.function_id == function_id)
2401+
)
2402+
).scalar()
2403+
return checkpoint_name == patch_name

tests/test_dbos.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2078,7 +2078,7 @@ class JsonSerializer(Serializer):
20782078
def serialize(self, data: Any) -> str:
20792079
return json.dumps(data)
20802080

2081-
def deserialize(cls, serialized_data: str) -> Any:
2081+
def deserialize(self, serialized_data: str) -> Any:
20822082
return json.loads(serialized_data)
20832083

20842084
# Configure DBOS with a JSON-based custom serializer

0 commit comments

Comments
 (0)