|
1 | 1 | import dataclasses |
2 | 2 | import graphlib |
3 | 3 | import pathlib |
| 4 | +import threading |
| 5 | +import time |
4 | 6 | import typing |
5 | 7 |
|
6 | 8 | import pytest |
@@ -391,3 +393,78 @@ def test_e2e_parallel_graph( |
391 | 393 | "pyyaml==6.0.2", |
392 | 394 | }, |
393 | 395 | ] |
| 396 | + |
| 397 | + |
| 398 | +def test_tracking_topology_sorter_concurrent_access() -> None: |
| 399 | + """Test thread safety with concurrent get_available() and done() calls. |
| 400 | + EXPECTED: Should work correctly with multiple threads |
| 401 | + """ |
| 402 | + nodes = [mknode(f"node_{i}") for i in range(20)] |
| 403 | + |
| 404 | + graph: typing.Mapping[DependencyNode, typing.Iterable[DependencyNode]] |
| 405 | + graph_dict = {} |
| 406 | + for i in range(1, 20): |
| 407 | + graph_dict[nodes[i]] = [nodes[i - 1]] |
| 408 | + graph_dict[nodes[0]] = [] |
| 409 | + graph = graph_dict |
| 410 | + |
| 411 | + topo = TrackingTopologicalSorter(graph) |
| 412 | + topo.prepare() |
| 413 | + |
| 414 | + errors: list[Exception] = [] |
| 415 | + processed: list[DependencyNode] = [] |
| 416 | + process_lock = threading.Lock() |
| 417 | + |
| 418 | + def worker() -> None: |
| 419 | + try: |
| 420 | + while True: |
| 421 | + if not topo.is_active(): |
| 422 | + break |
| 423 | + |
| 424 | + try: |
| 425 | + available = topo.get_available() |
| 426 | + except ValueError as e: |
| 427 | + if "topology is not active" in str(e): |
| 428 | + break |
| 429 | + raise |
| 430 | + |
| 431 | + if not available: |
| 432 | + time.sleep(0.0001) |
| 433 | + continue |
| 434 | + |
| 435 | + node = sorted(available)[0] |
| 436 | + time.sleep(0.0001) |
| 437 | + |
| 438 | + with process_lock: |
| 439 | + if node not in processed: |
| 440 | + processed.append(node) |
| 441 | + topo.done(node) |
| 442 | + |
| 443 | + except Exception as e: |
| 444 | + errors.append(e) |
| 445 | + |
| 446 | + threads = [threading.Thread(target=worker) for _ in range(4)] |
| 447 | + |
| 448 | + for t in threads: |
| 449 | + t.start() |
| 450 | + |
| 451 | + for t in threads: |
| 452 | + t.join(timeout=5.0) |
| 453 | + if t.is_alive(): |
| 454 | + errors.append(TimeoutError("Thread did not complete in time")) |
| 455 | + |
| 456 | + assert not errors, f"Thread safety violated with {len(errors)} errors: {errors}" |
| 457 | + assert len(processed) == 20, f"Expected 20 nodes processed, got {len(processed)}" |
| 458 | + assert not topo.is_active() |
| 459 | + |
| 460 | + |
| 461 | +def test_tracking_topology_sorter_empty_graph() -> None: |
| 462 | + """Test with empty graph.""" |
| 463 | + topo = TrackingTopologicalSorter() |
| 464 | + topo.prepare() |
| 465 | + |
| 466 | + assert not topo.is_active() |
| 467 | + |
| 468 | + with pytest.raises(ValueError) as excinfo: |
| 469 | + topo.get_available() |
| 470 | + assert "topology is not active" in str(excinfo.value) |
0 commit comments