Merged

Changes from 24 of 27 commits.

Commits:
b8d3226  fc (praateekmahajan, Nov 25, 2025)
36ff69b  Merge branch 'main' of github.com:NVIDIA/NeMo-Curator into praateek/a… (praateekmahajan, Nov 25, 2025)
f73ca9c  more changes (praateekmahajan, Nov 25, 2025)
6768fe0  restore extra files (praateekmahajan, Nov 25, 2025)
f6d0d53  more upstream (praateekmahajan, Nov 25, 2025)
8a95a94  revert doc (praateekmahajan, Nov 25, 2025)
efe5f55  add workflow.py (praateekmahajan, Nov 25, 2025)
8e61481  bug fixes (praateekmahajan, Nov 25, 2025)
b2d3013  more changes / tests (praateekmahajan, Nov 25, 2025)
3c998a8  update tests (praateekmahajan, Nov 25, 2025)
af0787c  Merge branch 'main' into praateek/add-workflow-results (praateekmahajan, Nov 26, 2025)
3c24f9a  Merge branch 'main' into praateek/add-workflow-results (sarahyurick, Dec 1, 2025)
ee06c5d  Apply suggestions from code review (sarahyurick, Dec 1, 2025)
57fd481  pr review (praateekmahajan, Dec 19, 2025)
b17b484  Merge branch 'main' of github.com:NVIDIA/NeMo-Curator into praateek/a… (praateekmahajan, Dec 19, 2025)
24daddb  Merge branch 'praateek/add-workflow-results' of github.com:praateekma… (praateekmahajan, Dec 19, 2025)
7646545  pr review 2 (praateekmahajan, Dec 19, 2025)
e39c01a  .. (praateekmahajan, Dec 22, 2025)
8a140fd  Merge branch 'main' of github.com:NVIDIA/NeMo-Curator into praateek/a… (praateekmahajan, Jan 9, 2026)
ef63a96  greptile (praateekmahajan, Jan 9, 2026)
1075bd3  .. (praateekmahajan, Jan 9, 2026)
20e717f  .. (praateekmahajan, Jan 9, 2026)
1b63a97  move id gen writing up near minhash for fuzzy (praateekmahajan, Jan 10, 2026)
2384628  Merge branch 'main' into praateek/add-workflow-results (ayushdg, Jan 12, 2026)
0459053  pr comments (praateekmahajan, Jan 13, 2026)
c94492a  Merge branch 'praateek/add-workflow-results' of github.com:praateekma… (praateekmahajan, Jan 13, 2026)
90c98cb  Merge branch 'main' into praateek/add-workflow-results (praateekmahajan, Jan 13, 2026)
56 changes: 56 additions & 0 deletions nemo_curator/pipeline/workflow.py
@@ -0,0 +1,56 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any

from nemo_curator.tasks import Task


@dataclass
class WorkflowRunResult:
    """Container returned by high-level workflows to expose pipeline outputs.

    Attributes:
        workflow_name: Human-readable workflow identifier (e.g., "fuzzy_dedup").
        pipeline_tasks: Mapping of pipeline names to the ``Task`` objects they produced.
        metadata: Free-form dictionary for workflow-specific timing or counters.
    """

    workflow_name: str
    pipeline_tasks: dict[str, list[Task]] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

    def add_pipeline_tasks(self, pipeline_name: str, tasks: list[Task] | None) -> None:
        """Record the tasks emitted by a pipeline run (empty list if None)."""
        self.pipeline_tasks[pipeline_name] = list(tasks or [])

    def extend_metadata(self, updates: dict[str, Any] | None = None) -> None:
        """Update the metadata dictionary in place."""
        if updates:
            self.metadata.update(updates)

    def add_metadata(self, key: str, value: Any) -> None:  # noqa: ANN401
        """Add a metadata key-value pair."""
        self.metadata[key] = value

    def get_metadata(self, key: str) -> Any:  # noqa: ANN401
        """Get a metadata value."""
        return self.metadata.get(key)


class WorkflowBase(ABC):
Contributor:

Curious about the need for this class. Is it to add more stuff while expanding in the future?

Contributor Author:

yup basically

    @abstractmethod
    def run(self, *args, **kwargs) -> WorkflowRunResult: ...
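
For orientation, here is a minimal sketch (not part of the diff) of how a concrete workflow might subclass WorkflowBase and populate a WorkflowRunResult. The workflow and pipeline names are hypothetical; only the API introduced in this file is exercised.

import time

from nemo_curator.pipeline.workflow import WorkflowBase, WorkflowRunResult


class MySketchWorkflow(WorkflowBase):
    """Hypothetical workflow used only to illustrate the new result container."""

    def run(self, *args, **kwargs) -> WorkflowRunResult:
        result = WorkflowRunResult(workflow_name="my_sketch_workflow")

        start = time.time()
        # In a real workflow this would be `pipeline.run(executor=...)`,
        # which returns a list[Task] (or None).
        stage_tasks = None
        result.add_pipeline_tasks("my_pipeline", stage_tasks)  # stored as []
        result.add_metadata("my_pipeline_time", time.time() - start)

        result.extend_metadata({"total_time": time.time() - start})
        return result


run_result = MySketchWorkflow().run()
print(run_result.workflow_name, run_result.get_metadata("total_time"))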
60 changes: 44 additions & 16 deletions nemo_curator/stages/deduplication/exact/workflow.py
@@ -21,6 +21,7 @@
from nemo_curator.backends.experimental.ray_actor_pool import RayActorPoolExecutor
from nemo_curator.backends.utils import merge_executor_configs, warn_on_env_var_override
from nemo_curator.pipeline import Pipeline
from nemo_curator.pipeline.workflow import WorkflowBase, WorkflowRunResult
from nemo_curator.stages.deduplication.exact.identification import ExactDuplicateIdentification
from nemo_curator.stages.deduplication.id_generator import (
create_id_generator_actor,
@@ -33,7 +34,7 @@
ID_GENERATOR_OUTPUT_FILENAME = "exact_id_generator.json"


class ExactDeduplicationWorkflow:
class ExactDeduplicationWorkflow(WorkflowBase):
"""
A pipeline that performs exact deduplication of a dataset.
It consists of the following stages:
@@ -165,17 +166,27 @@ def _validate_initial_tasks(self, initial_tasks: list[FileGroupTask] | None) ->
msg = "input_path to the dataset must be provided if initial_tasks are not provided manually."
raise ValueError(msg)

def run(
def run( # noqa: PLR0915
self, initial_tasks: list[FileGroupTask] | None = None, executor: RayActorPoolExecutor | None = None
) -> None:
) -> WorkflowRunResult:
"""Run the deduplication pipeline.

Args:
initial_tasks:
Set of FileGroupTasks generated by a previous stage pointing to the dataset to be deduplicated.
If not provided, the pipeline will generate the input tasks based on the input_dir and input_file_extensions.
executor: RayActorPoolExecutor | None
Executor to use for the pipeline.
If not provided, the default RayActorPoolExecutor will be used.

Returns:
WorkflowRunResult object containing the results and timing information
"""
self._validate_initial_tasks(initial_tasks)
workflow_result = WorkflowRunResult(workflow_name="exact_deduplication")
input_filegroups_time = 0.0
identification_time = 0.0

if executor is None:
executor = RayActorPoolExecutor(config=self.executor_config)
else:
@@ -185,6 +196,8 @@ def run(
previous_config = executor.config
executor.config = merge_executor_configs(executor.config, self.executor_config)
warn_on_env_var_override(previous_config, executor.config)
total_start_time = time.time()

if self.assign_id:
try:
create_id_generator_actor()
@@ -196,26 +209,31 @@
"""
raise RuntimeError(err_msg) from None

id_generator_path = None
try:
start_time = time.time()
if initial_tasks is None:
input_filegroups_pipeline = self._create_input_filegroups()
initial_tasks = input_filegroups_pipeline.run(executor=executor, initial_tasks=initial_tasks)
initial_filegroups_end_time = time.time()
logger.info(
f"Created input tasks from {self.input_path} in {(initial_filegroups_end_time - start_time):.2f} seconds"
)
input_start_time = time.time()
initial_tasks = input_filegroups_pipeline.run(executor=executor, initial_tasks=None)
input_filegroups_time = time.time() - input_start_time
workflow_result.add_metadata("input_filegroups_time", input_filegroups_time)
workflow_result.add_pipeline_tasks("input_filegroups", initial_tasks)
logger.info(f"Created input tasks from {self.input_path} in {input_filegroups_time:.2f} seconds")

initial_tasks = initial_tasks or []
identification_pipeline = self._create_identification_pipeline(num_input_tasks=len(initial_tasks))
identification_start_time = time.time()
removal_id_tasks = identification_pipeline.run(executor=executor, initial_tasks=initial_tasks)
identification_end_time = time.time()
logger.info(
f"Exact duplicate identification pipeline completed in {(identification_end_time - identification_start_time):.2f} seconds"
)
identification_time = identification_end_time - identification_start_time
workflow_result.add_metadata("identification_time", identification_time)
workflow_result.add_pipeline_tasks("identification", removal_id_tasks)
logger.info(f"Exact duplicate identification pipeline completed in {identification_time:.2f} seconds")

num_duplicates = sum(task._metadata.get("num_removal_ids", 0) for task in removal_id_tasks)
if num_duplicates == 0:
num_duplicates_identified = sum(
task._metadata.get("num_removal_ids", 0) for task in removal_id_tasks or []
)
if num_duplicates_identified == 0:
Comment on lines +233 to +236

Contributor:

The variable num_duplicates_identified is defined inside the try block (lines 233-235) but is referenced outside it in the workflow_summary at line 256. If an exception occurs before this variable is assigned (e.g., during input file group creation or id generator setup), the code will raise a NameError when num_duplicates_identified is accessed at line 256.

Fix: initialize num_duplicates_identified = 0 before the try block at line 213, similar to how input_filegroups_time and identification_time are initialized at lines 187-188.

Suggested change:

            num_duplicates_identified = sum(
                task._metadata.get("num_removal_ids", 0) for task in removal_id_tasks or []
            )
            if num_duplicates_identified == 0:

        id_generator_path = None
        num_duplicates_identified = 0
        try:

logger.info("No exact duplicates found in the dataset.")

if self.assign_id:
@@ -227,8 +245,18 @@ def run(
else None,
)
logger.info(f"Id generator written to {id_generator_path}")
end_time = time.time()
logger.info(f"Exact deduplication pipeline completed in {(end_time - start_time):.2f} seconds")
finally:
if self.assign_id:
kill_id_generator_actor()

total_end_time = time.time()
total_time = total_end_time - total_start_time
workflow_summary = {
"total_time": total_time,
"num_duplicates": num_duplicates_identified,
# paths
"id_generator_path": id_generator_path,
}
workflow_result.extend_metadata(workflow_summary)
Comment on lines +254 to +260

Contributor:

The variable id_generator_path is referenced in the workflow_summary at line 257, but it's only defined inside the if self.assign_id: block at line 239. If self.assign_id is False, this will raise an UnboundLocalError when the workflow_summary dictionary tries to access the undefined variable.

To fix this, initialize id_generator_path before the try block:

Suggested change:

        workflow_summary = {
            "total_time": total_time,
            "num_duplicates": num_duplicates_identified,
            # paths
            "id_generator_path": id_generator_path,
        }
        workflow_result.extend_metadata(workflow_summary)

    def run(  # noqa: PLR0915
        self, initial_tasks: list[FileGroupTask] | None = None, executor: RayActorPoolExecutor | None = None
    ) -> WorkflowRunResult:
        """Run the deduplication pipeline.
        ...
        """
        self._validate_initial_tasks(initial_tasks)
        workflow_result = WorkflowRunResult(workflow_name="exact_deduplication")
        input_filegroups_time = 0.0
        identification_time = 0.0
        id_generator_path = None

Then update the workflow_summary to conditionally include it only when it's not None.

logger.info(f"Exact deduplication pipeline completed in {total_time:.2f} seconds")
return workflow_result
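
As a usage sketch of the new return value: the metadata keys and pipeline names below match what run() records above, while the constructor arguments are illustrative guesses, since the full ExactDeduplicationWorkflow signature is not shown in this diff.

from nemo_curator.stages.deduplication.exact.workflow import ExactDeduplicationWorkflow

# Hypothetical paths; constructor arguments are assumed, not taken from this diff.
workflow = ExactDeduplicationWorkflow(
    input_path="/data/corpus",
    output_path="/data/exact_dedup_output",
)
result = workflow.run()

print(result.workflow_name)                      # "exact_deduplication"
print(result.get_metadata("num_duplicates"))     # duplicates identified
print(result.get_metadata("total_time"))         # seconds
print(result.get_metadata("id_generator_path"))  # None unless assign_id was used
removal_tasks = result.pipeline_tasks.get("identification", [])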
80 changes: 52 additions & 28 deletions nemo_curator/stages/deduplication/fuzzy/workflow.py
@@ -20,6 +20,7 @@
from nemo_curator.backends.experimental.ray_actor_pool import RayActorPoolExecutor
from nemo_curator.backends.utils import merge_executor_configs, warn_on_env_var_override
from nemo_curator.pipeline import Pipeline
from nemo_curator.pipeline.workflow import WorkflowBase, WorkflowRunResult
from nemo_curator.stages.deduplication.fuzzy.buckets_to_edges import BucketsToEdgesStage
from nemo_curator.stages.deduplication.fuzzy.connected_components import ConnectedComponentsStage
from nemo_curator.stages.deduplication.fuzzy.identify_duplicates import IdentifyDuplicatesStage
@@ -37,7 +38,7 @@
ID_GENERATOR_OUTPUT_FILENAME = "fuzzy_id_generator.json"


class FuzzyDeduplicationWorkflow:
class FuzzyDeduplicationWorkflow(WorkflowBase):
"""
A pipeline that performs fuzzy deduplication of a dataset.
It consists of the following stages:
@@ -267,9 +268,9 @@ def _validate_initial_tasks(self, initial_tasks: list[FileGroupTask] | None) ->
msg = "input_path to the dataset must be provided if initial_tasks are not provided manually."
raise ValueError(msg)

def run(
def run( # noqa: PLR0915
self, initial_tasks: list[FileGroupTask] | None = None, executor: RayActorPoolExecutor | None = None
) -> None:
) -> WorkflowRunResult:
"""Run the deduplication pipeline.

Args:
@@ -280,6 +281,11 @@ def run(

"""
self._validate_initial_tasks(initial_tasks)
workflow_result = WorkflowRunResult(workflow_name="fuzzy_deduplication")
minhash_time = 0.0
lsh_time = 0.0
connected_components_time = 0.0
Contributor:

nit: can we call this cc_pipeline_time or connected_components_pipeline_time? The nit is because the pipeline does much more than connected components (buckets to edges, connected components, shuffle on component, etc.).

if executor is None:
executor = RayActorPoolExecutor(config=self.executor_config)
else:
@@ -290,6 +296,8 @@ def run(
executor.config = merge_executor_configs(executor.config, self.executor_config)
warn_on_env_var_override(previous_config, executor.config)

total_start_time = time.time()

try:
create_id_generator_actor()
except ValueError:
@@ -300,50 +308,66 @@
"""
raise RuntimeError(err_msg) from None

id_generator_path = None
try:
# Step 1: Minhash
minhash_pipeline = self._create_minhash_pipeline(generate_input_filegroups=initial_tasks is None)
start_time = time.time()
minhash_pipeline.run(executor=executor, initial_tasks=initial_tasks)
minhash_start_time = time.time()
minhash_tasks = minhash_pipeline.run(executor=executor, initial_tasks=initial_tasks)
minhash_end_time = time.time()
logger.info(f"Minhash pipeline completed in {(minhash_end_time - start_time):.2f} seconds")
minhash_time = minhash_end_time - minhash_start_time
workflow_result.add_pipeline_tasks("minhash", minhash_tasks)
workflow_result.add_metadata("minhash_time", minhash_time)
logger.info(f"Minhash pipeline completed in {minhash_time:.2f} seconds")
output_fs = get_fs(
self.output_path,
self.write_kwargs.get("storage_options") if self.write_kwargs is not None else None,
)
id_generator_path = output_fs.sep.join([self.output_path, ID_GENERATOR_OUTPUT_FILENAME])
write_id_generator_to_disk(
id_generator_path,
storage_options=self.write_kwargs.get("storage_options") if self.write_kwargs is not None else None,
)
logger.info(f"Id generator written to {id_generator_path}")
workflow_result.add_metadata("id_generator_path", id_generator_path)

# Step 2: LSH
lsh_pipeline = self._create_lsh_pipeline()
lsh_start_time = time.time()
# LSH stage generates its own input tasks from the minhash directory
lsh_tasks = lsh_pipeline.run(executor=executor, initial_tasks=None)
lsh_end_time = time.time()
logger.info(f"LSH pipeline completed in {(lsh_end_time - lsh_start_time):.2f} seconds")
lsh_time = lsh_end_time - lsh_start_time
workflow_result.add_pipeline_tasks("lsh", lsh_tasks)
workflow_result.add_metadata("lsh_time", lsh_time)
logger.info(f"LSH pipeline completed in {lsh_time:.2f} seconds")

valid_lsh_tasks = [task for task in lsh_tasks if task._metadata.get("num_docs", 0) > 0]
valid_lsh_tasks = [task for task in lsh_tasks or [] if task._metadata.get("num_docs", 0) > 0]
if len(valid_lsh_tasks) == 0:
logger.info("No potential duplicates found in the dataset. Skipping connected components pipeline.")
workflow_result.add_metadata("num_duplicates", 0)
else:
# Step 3: Connected components
connected_components_pipeline = self._create_connected_components_pipeline()
connected_components_start_time = time.time()
connected_components_tasks = connected_components_pipeline.run(
executor=executor, initial_tasks=valid_lsh_tasks
)
connected_components_end_time = time.time()
logger.info(
f"Connected components pipeline completed in {(connected_components_end_time - connected_components_start_time):.2f} seconds"
)
num_removed_documents = sum(
task._metadata.get("num_removal_ids", 0) for task in connected_components_tasks
connected_components_time = connected_components_end_time - connected_components_start_time
workflow_result.add_pipeline_tasks("connected_components", connected_components_tasks)
workflow_result.add_metadata("connected_components_time", connected_components_time)
logger.info(f"Connected components pipeline completed in {connected_components_time:.2f} seconds")
num_duplicates_identified = sum(
task._metadata.get("num_removal_ids", 0) for task in (connected_components_tasks or [])
)
logger.info(f"Number of documents removed: {num_removed_documents}")
output_fs = get_fs(
self.output_path,
self.write_kwargs.get("storage_options") if self.write_kwargs is not None else None,
)
id_generator_path = output_fs.sep.join([self.output_path, ID_GENERATOR_OUTPUT_FILENAME])
write_id_generator_to_disk(
id_generator_path,
storage_options=self.write_kwargs.get("storage_options")
if self.write_kwargs is not None
else None,
)
logger.info(f"Id generator written to {id_generator_path}")
end_time = time.time()
logger.info(f"Fuzzy deduplication pipeline completed in {(end_time - start_time):.2f} seconds")
workflow_result.add_metadata("num_duplicates", num_duplicates_identified)
logger.info(f"Number of documents removed: {num_duplicates_identified}")
finally:
kill_id_generator_actor()

total_end_time = time.time()
total_time = total_end_time - total_start_time
workflow_result.add_metadata("total_time", total_time)
logger.info(f"Fuzzy deduplication pipeline completed in {total_time:.2f} seconds")
return workflow_result
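
And a similar sketch for the fuzzy workflow: the metadata keys and pipeline names ("minhash", "lsh", "connected_components") come from the code above, while the constructor arguments are again assumptions rather than part of this diff.

from nemo_curator.stages.deduplication.fuzzy.workflow import FuzzyDeduplicationWorkflow

# Hypothetical paths; constructor arguments are assumed, not taken from this diff.
workflow = FuzzyDeduplicationWorkflow(
    input_path="/data/corpus",
    output_path="/data/fuzzy_dedup_output",
)
result = workflow.run()

for key in ("minhash_time", "lsh_time", "connected_components_time", "total_time"):
    print(key, result.get_metadata(key))
print("num_duplicates:", result.get_metadata("num_duplicates"))
lsh_tasks = result.pipeline_tasks.get("lsh", [])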