Skip to content

Commit 2ae492f

Browse files
committed
Type-hinting improvements
1 parent 1939eea commit 2ae492f

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

durabletask/task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,8 @@ def task_id(self) -> int:
538538
return self._task_id
539539

540540

541-
# Orchestrators are generators that yield tasks and receive/return any type
542-
Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task, Any, Any], TOutput]]
541+
# Orchestrators are generators that yield tasks, recieve any type, and return TOutput
542+
Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task[Any], Any, TOutput], TOutput]]
543543

544544
# Activities are simple functions that can be scheduled by orchestrators
545545
Activity = Callable[[ActivityContext, TInput], TOutput]

durabletask/worker.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,42 +150,42 @@ def __init__(self):
150150
self.entities = {}
151151
self.entity_instances = {}
152152

153-
def add_orchestrator(self, fn: task.Orchestrator) -> str:
153+
def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str:
154154
if fn is None:
155155
raise ValueError("An orchestrator function argument is required.")
156156

157157
name = task.get_name(fn)
158158
self.add_named_orchestrator(name, fn)
159159
return name
160160

161-
def add_named_orchestrator(self, name: str, fn: task.Orchestrator) -> None:
161+
def add_named_orchestrator(self, name: str, fn: task.Orchestrator[TInput, TOutput]) -> None:
162162
if not name:
163163
raise ValueError("A non-empty orchestrator name is required.")
164164
if name in self.orchestrators:
165165
raise ValueError(f"A '{name}' orchestrator already exists.")
166166

167167
self.orchestrators[name] = fn
168168

169-
def get_orchestrator(self, name: str) -> Optional[task.Orchestrator]:
169+
def get_orchestrator(self, name: str) -> Optional[task.Orchestrator[Any, Any]]:
170170
return self.orchestrators.get(name)
171171

172-
def add_activity(self, fn: task.Activity) -> str:
172+
def add_activity(self, fn: task.Activity[TInput, TOutput]) -> str:
173173
if fn is None:
174174
raise ValueError("An activity function argument is required.")
175175

176176
name = task.get_name(fn)
177177
self.add_named_activity(name, fn)
178178
return name
179179

180-
def add_named_activity(self, name: str, fn: task.Activity) -> None:
180+
def add_named_activity(self, name: str, fn: task.Activity[TInput, TOutput]) -> None:
181181
if not name:
182182
raise ValueError("A non-empty activity name is required.")
183183
if name in self.activities:
184184
raise ValueError(f"A '{name}' activity already exists.")
185185

186186
self.activities[name] = fn
187187

188-
def get_activity(self, name: str) -> Optional[task.Activity]:
188+
def get_activity(self, name: str) -> Optional[task.Activity[Any, Any]]:
189189
return self.activities.get(name)
190190

191191
def add_entity(self, fn: task.Entity) -> str:
@@ -362,7 +362,7 @@ def __enter__(self):
362362
def __exit__(self, type, value, traceback):
363363
self.stop()
364364

365-
def add_orchestrator(self, fn: task.Orchestrator) -> str:
365+
def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str:
366366
"""Registers an orchestrator function with the worker."""
367367
if self._is_running:
368368
raise RuntimeError(

0 commit comments

Comments
 (0)