Skip to content

Commit 8ac1876

Browse files
committed
Finish locking, add operationactions
1 parent 4107182 commit 8ac1876

File tree

8 files changed

+120
-51
lines changed

8 files changed

+120
-51
lines changed

durabletask/client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,15 @@ def purge_orchestration(self, instance_id: str, recursive: bool = True):
229229
self._logger.info(f"Purging instance '{instance_id}'.")
230230
self._stub.PurgeInstances(req)
231231

232-
def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, input: Optional[Any] = None, signal_entity_options=None, cancellation=None):
233-
scheduled_time = signal_entity_options.scheduled_time if signal_entity_options and signal_entity_options.scheduled_time else None
232+
def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, input: Optional[Any] = None):
234233
req = pb.SignalEntityRequest(
235234
instanceId=str(entity_instance_id),
236-
requestId=str(uuid.uuid4()),
237235
name=operation_name,
238236
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None,
239-
scheduledTime=scheduled_time,
237+
requestId=str(uuid.uuid4()),
238+
scheduledTime=None,
239+
parentTraceContext=None,
240240
requestTime=helpers.new_timestamp(datetime.now(timezone.utc))
241241
)
242242
self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.")
243-
self._stub.SignalEntity(req, timeout=cancellation.timeout if cancellation else None)
243+
self._stub.SignalEntity(req, None) # TODO: Cancellation timeout?
Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any, Optional, Type, TypeVar, overload
22

3+
from durabletask.entities.entity_instance_id import EntityInstanceId
4+
35
TState = TypeVar("TState")
46

57

@@ -9,12 +11,18 @@ def _initialize_entity_context(self, context):
911

1012
@overload
1113
def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ...
12-
14+
1315
@overload
1416
def get_state(self, intended_type: None = None) -> Any: ...
15-
17+
1618
def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TState] | Any:
1719
return self.entity_context.get_state(intended_type)
1820

1921
def set_state(self, state: Any):
20-
self.entity_context.set_state(state)
22+
self.entity_context.set_state(state)
23+
24+
def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None:
25+
self.entity_context.signal_entity(entity_instance_id, operation, input)
26+
27+
def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> None:
28+
self.entity_context.schedule_new_orchestration(orchestration_name, input, instance_id=instance_id)
Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import durabletask.internal.helpers as ph
2-
3-
from durabletask.entities.entity_instance_id import EntityInstanceId
41
import durabletask.internal.orchestrator_service_pb2 as pb
52

63

@@ -11,9 +8,9 @@ def __init__(self, context):
118
def __enter__(self):
129
return self
1310

14-
def __exit__(self, exc_type, exc_val, exc_tb): # TODO: Handle exceptions?
11+
def __exit__(self, exc_type, exc_val, exc_tb): # TODO: Handle exceptions?
1512
print(f"Unlocking entities: {self._context._entity_context.critical_section_locks}")
1613
for entity_unlock_message in self._context._entity_context.emit_lock_release_messages():
1714
task_id = self._context.next_sequence_number()
18-
action = pb.OrchestratorAction(task_id, sendEntityMessage=entity_unlock_message)
15+
action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message)
1916
self._context._pending_actions[task_id] = action

durabletask/internal/entity_state_shim.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from ctypes import Union
2-
from typing import Any, TypeVar, runtime_checkable
1+
from typing import Any, TypeVar
32
from typing import Optional, Type, overload
4-
from typing_extensions import Protocol
53

4+
import durabletask.internal.orchestrator_service_pb2 as pb
65

76
TState = TypeVar("TState")
87

@@ -11,10 +10,12 @@ class StateShim:
1110
def __init__(self, start_state):
1211
self._current_state: Any = start_state
1312
self._checkpoint_state: Any = start_state
13+
self._operation_actions: list[pb.OperationAction] = []
14+
self._actions_checkpoint_state: int = 0
1415

1516
@overload
1617
def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ...
17-
18+
1819
@overload
1920
def get_state(self, intended_type: None = None) -> Any: ...
2021

@@ -26,7 +27,7 @@ def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TS
2627
return self._current_state
2728

2829
try:
29-
return intended_type(self._current_state) # type: ignore[call-arg]
30+
return intended_type(self._current_state) # type: ignore[call-arg]
3031
except Exception as ex:
3132
raise TypeError(
3233
f"Could not convert state of type '{type(self._current_state).__name__}' to '{intended_type.__name__}'"
@@ -35,11 +36,22 @@ def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TS
3536
def set_state(self, state):
3637
self._current_state = state
3738

39+
def add_operation_action(self, action: pb.OperationAction):
40+
self._operation_actions.append(action)
41+
42+
def get_operation_actions(self) -> list[pb.OperationAction]:
43+
return self._operation_actions[:self._actions_checkpoint_state]
44+
3845
def commit(self):
3946
self._checkpoint_state = self._current_state
47+
self._actions_checkpoint_state = len(self._operation_actions)
4048

4149
def rollback(self):
4250
self._current_state = self._checkpoint_state
51+
self._operation_actions = self._operation_actions[:self._actions_checkpoint_state]
4352

4453
def reset(self):
4554
self._current_state = None
55+
self._checkpoint_state = None
56+
self._operation_actions = []
57+
self._actions_checkpoint_state = 0

durabletask/internal/helpers.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,19 +199,22 @@ def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str],
199199
def new_call_entity_action(id: int, parent_instance_id: str, entity_id: EntityInstanceId, operation: str, encoded_input: Optional[str]):
200200
return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationCalled=pb.EntityOperationCalledEvent(
201201
requestId=f"{parent_instance_id}:{id}",
202+
operation=operation,
203+
scheduledTime=None,
204+
input=get_string_value(encoded_input),
202205
parentInstanceId=get_string_value(parent_instance_id),
206+
parentExecutionId=None,
203207
targetInstanceId=get_string_value(str(entity_id)),
204-
input=get_string_value(encoded_input),
205-
operation=operation
206208
)))
207209

208210

209211
def new_signal_entity_action(id: int, entity_id: EntityInstanceId, operation: str, encoded_input: Optional[str]):
210212
return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationSignaled=pb.EntityOperationSignaledEvent(
211213
requestId=f"{entity_id}:{id}",
212-
targetInstanceId=get_string_value(str(entity_id)),
213214
operation=operation,
214-
input=get_string_value(encoded_input)
215+
scheduledTime=None,
216+
input=get_string_value(encoded_input),
217+
targetInstanceId=get_string_value(str(entity_id)),
215218
)))
216219

217220

durabletask/internal/orchestration_entity_context.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def emit_lock_release_messages(self):
6262
for entity_id in self.critical_section_locks:
6363
unlock_event = pb.SendEntityMessageAction(entityUnlockSent=pb.EntityUnlockSentEvent(
6464
criticalSectionId=self.critical_section_id,
65-
targetInstanceId=get_string_value(str(entity_id))
65+
targetInstanceId=get_string_value(str(entity_id)),
66+
parentInstanceId=get_string_value(self.instance_id)
6667
))
6768
yield unlock_event
6869

@@ -79,7 +80,7 @@ def emit_request_message(self, target, operation_name: str, one_way: bool, opera
7980
def emit_acquire_message(self, critical_section_id: str, entities: List[EntityInstanceId]) -> Union[Tuple[None, None], Tuple[pb.SendEntityMessageAction, pb.OrchestrationInstance]]:
8081
if not entities:
8182
return None, None
82-
83+
8384
# Acquire the locks in a globally fixed order to avoid deadlocks
8485
# Also remove duplicates - this can be optimized for perf if necessary
8586
entity_ids = sorted(entities)
@@ -102,8 +103,10 @@ def emit_acquire_message(self, critical_section_id: str, entities: List[EntityIn
102103

103104
return request, target
104105

105-
def complete_acquire(self, result, critical_section_id):
106+
def complete_acquire(self, critical_section_id):
106107
# TODO: HashSet or equivalent
108+
if self.critical_section_id != critical_section_id:
109+
raise RuntimeError(f"Unexpected lock acquire for critical section ID '{critical_section_id}' (expected '{self.critical_section_id}')")
107110
self.available_locks = self.critical_section_locks
108111
self.lock_acquisition_pending = False
109112

durabletask/task.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
from abc import ABC, abstractmethod
99
from datetime import datetime, timedelta
1010
from typing import Any, Callable, Generator, Generic, Optional, Type, TypeVar, Union, overload
11+
import uuid
1112

1213
from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock
14+
from durabletask.internal import shared
1315
from durabletask.internal.entity_state_shim import StateShim
1416
import durabletask.internal.helpers as pbh
1517
import durabletask.internal.orchestrator_service_pb2 as pb
@@ -141,7 +143,7 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
141143
pass
142144

143145
@abstractmethod
144-
def call_entity(self, entity: EntityInstanceId,
146+
def call_entity(self, entity: EntityInstanceId,
145147
operation: str, *,
146148
input: Optional[TInput] = None):
147149
"""Schedule entity function for execution.
@@ -545,10 +547,10 @@ def operation(self) -> str:
545547
The operation associated with this entity invocation.
546548
"""
547549
return self._operation
548-
550+
549551
@overload
550552
def get_state(self, intended_type: Type[TState]) -> Optional[TState]: ...
551-
553+
552554
@overload
553555
def get_state(self, intended_type: None = None) -> Any: ...
554556

@@ -558,6 +560,37 @@ def get_state(self, intended_type: Optional[Type[TState]] = None) -> Optional[TS
558560
def set_state(self, new_state: Any):
559561
self._state.set_state(new_state)
560562

563+
def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None:
564+
encoded_input = shared.to_json(input) if input is not None else None
565+
self._state.add_operation_action(
566+
pb.OperationAction(
567+
sendSignal=pb.SendSignalAction(
568+
instanceId=str(entity_instance_id),
569+
name=operation,
570+
input=pbh.get_string_value(encoded_input),
571+
scheduledTime=None,
572+
requestTime=None,
573+
parentTraceContext=None,
574+
)
575+
)
576+
)
577+
578+
def schedule_new_orchestration(self, orchestration_name: str, input: Optional[Any] = None, instance_id: Optional[str] = None) -> None:
579+
encoded_input = shared.to_json(input) if input is not None else None
580+
self._state.add_operation_action(
581+
pb.OperationAction(
582+
startNewOrchestration=pb.StartNewOrchestrationAction(
583+
instanceId=instance_id if instance_id else uuid.uuid4().hex, # TODO: Should this be non-none?
584+
name=orchestration_name,
585+
input=pbh.get_string_value(encoded_input),
586+
version=None,
587+
scheduledTime=None,
588+
requestTime=None,
589+
parentTraceContext=None
590+
)
591+
)
592+
)
593+
561594
@property
562595
def entity_id(self) -> EntityInstanceId:
563596
"""Get the ID of the entity instance.

0 commit comments

Comments
 (0)