diff --git a/nemo/data-flywheel/e2e-llm-evaluation/README.md b/nemo/data-flywheel/e2e-llm-evaluation/README.md new file mode 100644 index 000000000..e31c472db --- /dev/null +++ b/nemo/data-flywheel/e2e-llm-evaluation/README.md @@ -0,0 +1,79 @@ +# Emergency Triage LLM Pipeline + +Synthetic data generation and model evaluation for ESI triage classification — powered by NeMo Microservices Platform. + +## Notebooks + +| Notebook | Description | +|----------------------------------|---------------------------------------------------------------| +| `clinical_triage_pipeline.ipynb` | Combined end-to-end NDD + Evaluator use case in one notebook | + +Run `clinical_triage_pipeline.ipynb` for the full pipeline in one place. It can be run locally with no GPU required. + +## Quick Start + +### 1. Create virtual environment + +```bash +python3 -m venv .venv +source .venv/bin/activate + +# Option A: with VPN / artifactory access +pip install 'nemo-platform[data-designer]==2.0.0.dev1+nightly20260309' \ + --index-url https://urm.nvidia.com/artifactory/api/pypi/nv-shared-pypi/simple + +# Option B: without VPN (e.g. on Brev) — download the wheel first via NGC CLI +pip install 'nemo-platform-python-sdk_v2.0.0.dev1+nightly20260309/nemo_platform-2.0.0.dev1+nightly20260309-py3-none-any.whl[data-designer]' + +# Additional dependencies +pip install -r requirements.txt +python -m ipykernel install --user --name nmp --display-name "NMP (venv)" +``` + +### 2. Start NMP + +```bash +nmp quickstart configure # NGC key → NVIDIA Build inference → save +nmp quickstart up --image nvcr.io/nvidian/nemo-llm/nmp-api:nightly-20260309 +nmp quickstart status # Wait for health: ready +``` + +> **Important: Two different API keys are needed.** +> - `nmp quickstart configure` requires the **NGC key** from the `nvidian/nemo-llm` org (https://org.ngc.nvidia.com/setup/api-key). This authenticates Docker image pulls. +> - Step 3 below requires a **build.nvidia.com key** (https://build.nvidia.com). 
This authenticates LLM inference calls.
> 
> If `configure` fails with `NGC login failed: unauthorized`, the NGC key is wrong — make sure you're in the `nvidian/nemo-llm` org when generating it.

> **Troubleshooting:** If batch jobs (`data_designer.create()`) get stuck at `created`, verify the jobs controller started by checking `nmp quickstart logs` for `Backend registry initialized with: docker`. If missing, reinstall with the exact nightly version: `pip install 'nemo-platform[data-designer]==2.0.0.dev1+nightly20260309'` and do a clean `nmp quickstart destroy` + `configure` + `up`.

### 3. Register build.nvidia.com provider

Get your API key at [build.nvidia.com](https://build.nvidia.com) (separate from the NGC key used in step 2), and paste it into the `data` field of the first request below.

```bash
curl -s -X POST http://localhost:8080/apis/secrets/v2/workspaces/default/secrets \
  -H "Content-Type: application/json" \
  -d '{"name": "nvidia-build-api-key", "data": "<YOUR_BUILD_NVIDIA_COM_API_KEY>"}'

curl -s -X POST http://localhost:8080/apis/models/v2/workspaces/default/providers \
  -H "Content-Type: application/json" \
  -d '{"name": "nvidiabuild", "host_url": "https://integrate.api.nvidia.com/v1", "api_key_secret_name": "nvidia-build-api-key"}'
```

### 4. Run

Open the notebook in Jupyter or Cursor. Select the `NMP (venv)` kernel, then run `clinical_triage_pipeline.ipynb` for the full end-to-end flow. 
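The two `curl` calls in step 3 can also be scripted, which is handy when re-provisioning after `nmp quickstart destroy`. Below is a minimal Python sketch using only the standard library, with the same endpoints and payloads as the commands above; the `NVIDIA_BUILD_API_KEY` environment-variable name is an assumption, not something NMP defines:

```python
import json
import os
import urllib.request

BASE_URL = "http://localhost:8080"

def secret_payload(api_key: str) -> dict:
    # Secret holding the build.nvidia.com key (matches the first curl above)
    return {"name": "nvidia-build-api-key", "data": api_key}

def provider_payload() -> dict:
    # Provider pointing NMP at the NVIDIA Build OpenAI-compatible endpoint
    return {
        "name": "nvidiabuild",
        "host_url": "https://integrate.api.nvidia.com/v1",
        "api_key_secret_name": "nvidia-build-api-key",
    }

def post_json(url: str, payload: dict) -> int:
    # POST a JSON body and return the HTTP status code
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=30) as resp:
        return resp.status

if __name__ == "__main__":
    key = os.environ["NVIDIA_BUILD_API_KEY"]  # assumed variable name
    post_json(f"{BASE_URL}/apis/secrets/v2/workspaces/default/secrets", secret_payload(key))
    post_json(f"{BASE_URL}/apis/models/v2/workspaces/default/providers", provider_payload())
```

As with the `curl` version, re-running once the secret already exists may return an error; treat this as a provisioning sketch rather than an idempotent script.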
+ +## Pipeline + +Pipeline Diagram + +## NMP Commands + +```bash +nmp quickstart status # Check cluster health +nmp quickstart logs # View logs +nmp quickstart down # Stop cluster +nmp quickstart destroy # Stop and remove all data +nmp quickstart doctor # Diagnose issues +``` diff --git a/nemo/data-flywheel/e2e-llm-evaluation/clinical_triage_pipeline.ipynb b/nemo/data-flywheel/e2e-llm-evaluation/clinical_triage_pipeline.ipynb index b50343b3e..ec2ef3ce7 100644 --- a/nemo/data-flywheel/e2e-llm-evaluation/clinical_triage_pipeline.ipynb +++ b/nemo/data-flywheel/e2e-llm-evaluation/clinical_triage_pipeline.ipynb @@ -1,787 +1,721 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Emergency Triage LLM Evaluation 🏥\n", - "\n", - "**Problem:** Accurately labeling Emergency Severity Index (ESI) levels from nurse triage notes is crucial for clinical research and model development. However, real-world data is highly sensitive and access is often limited. Obtaining high-quality human annotations is also costly and slow, which makes it challenging to create large, diverse datasets for robust evaluation.\n", - "\n", - "**Solution:** Synthetic data provides a scalable and privacy-preserving approach. By simulating realistic triage notes and ESI labels, we can build rich datasets without exposing patient information or relying heavily on human annotators. 
This strategy accelerates iteration, benchmarking, and model improvement, addressing key bottlenecks caused by data scarcity.\n", - "\n", - "- **Use case:** Predict ESI levels from synthetic nurse triage notes using LLMs\n", - "- **Goal:** Evaluate model accuracy and the quality/complexity of generated notes across a range of clinical scenarios\n", - "- **Pipeline:** Synthetic data ➔ LLM-as-a-Judge scoring ➔ Filtering ➔ Evaluation\n", - "\n", - "```text\n", - " ┌───────────────────────────────┐ ┌─────────────────────────────┐\n", - " │ NeMo Data Designer │ │ NeMo Evaluator │\n", - " │ +------------------------+ │ │ +-----------------------+ │\n", - " │ | Nurse Triage Note 📝 |───┼───────▶| | LLM predicts ESI 🔍🤖 | │\n", - " │ +------------------------+ │ │ +-----------------------+ │\n", - " │ + │ │ | │\n", - " │ │ │ v │\n", - " │ +------------------------+ │ │ +-----------------------+ │\n", - " │ | Ground Truth (ESI) ✅ |───┼───────▶| | Predicted ESI 🏷️ | │\n", - " │ +------------------------+ │ │ +-----------------------+ │\n", - " └───────────────────────────────┘ │ | │\n", - " │ v │\n", - " │ +-----------------------+ │\n", - " │ | Metrics 📊 | │\n", - " │ | (Accuracy) | │\n", - " │ +-----------------------+ │\n", - " └─────────────────────────────┘\n", - "```\n", - "\n", - "**Workflow Overview:**\n", - "- 🏗️ Generate realistic, privacy-safe triage notes, evaluate their quality using LLM-as-a-Judge, and filter for high-value examples with Data Designer.\n", - "- ⬆️ Upload the resulting dataset to a compatible datastore (e.g., HuggingFace Datasets).\n", - "- 📈 Use the Evaluator to compute ESI classification accuracy and other relevant metrics.\n", - "\n", - "Tip: Run the cells below in order. 
You can re-run data preview/generation to explore different clinical scenarios and difficulty settings.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## **Step 1**: 🎨 NeMo Data Designer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from nemo_microservices.data_designer.essentials import (\n", - " DataDesignerConfigBuilder,\n", - " NeMoDataDesignerClient,\n", - " ModelConfig,\n", - " InferenceParameters,\n", - " SamplerColumnConfig,\n", - " SamplerType,\n", - " CategorySamplerParams,\n", - " SubcategorySamplerParams,\n", - " PersonSamplerParams,\n", - " LLMTextColumnConfig,\n", - " LLMJudgeColumnConfig,\n", - " Score,\n", - ")\n", - "\n", - "data_designer_client = NeMoDataDesignerClient(\n", - " base_url=\"http://localhost:8080\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# This name is set in the microservice deployment configuration.\n", - "MODEL_PROVIDER = \"nvidiabuild\"\n", - "\n", - "# The model ID is from build.nvidia.com.\n", - "MODEL_ID_GENERATOR = \"openai/gpt-oss-20b\"\n", - "MODEL_ID_JUDGE = \"openai/gpt-oss-120b\"\n", - "\n", - "# We choose these aliases to be descriptive for our use case.\n", - "MODEL_ALIAS_GENERATOR = \"content_generator\"\n", - "MODEL_ALIAS_JUDGE = \"judge\"\n", - "\n", - "model_configs = [\n", - " ModelConfig(\n", - " provider=MODEL_PROVIDER,\n", - " alias=MODEL_ALIAS_GENERATOR,\n", - " model=MODEL_ID_GENERATOR,\n", - " inference_parameters=InferenceParameters(\n", - " max_tokens=8000,\n", - " temperature=0.7,\n", - " top_p=0.95,\n", - " )\n", - " ),\n", - " ModelConfig(\n", - " provider=MODEL_PROVIDER,\n", - " alias=MODEL_ALIAS_JUDGE,\n", - " model=MODEL_ID_JUDGE,\n", - " inference_parameters=InferenceParameters(\n", - " max_tokens=4096,\n", - " temperature=0.1,\n", - " top_p=0.95,\n", - " )\n", - " )\n", - "]\n", - "\n", - "config_builder = 
DataDesignerConfigBuilder(model_configs=model_configs)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 🎲 Sampler columns" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# ESI levels\n", - "ESI_LEVELS = [\n", - " \"ESI 1: Resuscitation\",\n", - " \"ESI 2: Emergency\",\n", - " \"ESI 3: Urgent\",\n", - " \"ESI 4: Less Urgent\",\n", - " \"ESI 5: Non-urgent\",\n", - "]\n", - "\n", - "# Unique record ID\n", - "config_builder.add_column(\n", - " SamplerColumnConfig(\n", - " name=\"record_id\",\n", - " sampler_type=SamplerType.UUID,\n", - " params={\"short_form\": True, \"uppercase\": True}\n", - " )\n", - ")\n", - "\n", - "# ESI level (balanced sampling)\n", - "config_builder.add_column(\n", - " SamplerColumnConfig(\n", - " name=\"esi_level_description\",\n", - " sampler_type=SamplerType.CATEGORY,\n", - " params=CategorySamplerParams(\n", - " values=ESI_LEVELS,\n", - " ),\n", - " )\n", - ")\n", - "\n", - "# Clinical scenario (conditioned on ESI level)\n", - "config_builder.add_column(\n", - " SamplerColumnConfig(\n", - " name=\"clinical_scenario\",\n", - " sampler_type=SamplerType.SUBCATEGORY,\n", - " params=SubcategorySamplerParams(\n", - " category=\"esi_level_description\",\n", - " values={\n", - " ESI_LEVELS[0]: [\n", - " \"Cardiac arrest\",\n", - " \"Unresponsive with no pulse\",\n", - " \"Severe respiratory distress\",\n", - " \"Major trauma with signs of shock\",\n", - " \"Suspected narcotic overdose with shallow respirations\",\n", - " ],\n", - " ESI_LEVELS[1]: [\n", - " \"Crushing substernal chest pain radiating to the left arm\",\n", - " \"Sudden onset of facial droop and arm weakness\",\n", - " \"New onset confusion in an elderly patient\",\n", - " \"Active suicidal ideation with a plan\",\n", - " \"High-speed motor vehicle accident\",\n", - " \"Severe abdominal pain in a patient with a history of aortic aneurysm\",\n", - " ],\n", - " ESI_LEVELS[2]: [\n", - " 
\"Abdominal pain with fever and nausea\",\n", - " \"High fever with a productive cough and history of COPD\",\n", - " \"Displaced fracture with visible deformity\",\n", - " \"Asthma attack, responsive to initial treatment\",\n", - " \"Vaginal bleeding in a pregnant patient\",\n", - " \"Head injury with brief loss of consciousness\",\n", - " ],\n", - " ESI_LEVELS[3]: [\n", - " \"Simple laceration requiring sutures\",\n", - " \"Twisted ankle, unable to bear weight\",\n", - " \"Sore throat with fever\",\n", - " \"Symptoms of a urinary tract infection\",\n", - " \"Painful ear with fever in a child\",\n", - " ],\n", - " ESI_LEVELS[4]: [\n", - " \"Request for a prescription refill\",\n", - " \"Suture removal\",\n", - " \"Minor rash present for several days\",\n", - " \"Common cold symptoms\",\n", - " \"Follow-up for a minor wound check\",\n", - " ],\n", - " },\n", - " ),\n", - " )\n", - ")\n", - "\n", - "# Synthetic patient info\n", - "config_builder.add_column(\n", - " SamplerColumnConfig(\n", - " name=\"patient\",\n", - " sampler_type=SamplerType.PERSON,\n", - " params=PersonSamplerParams(age_range=[18, 70]),\n", - " )\n", - ")\n", - "\n", - "# Triage note writing style (captures range from poor to best quality notes)\n", - "config_builder.add_column(\n", - " SamplerColumnConfig(\n", - " name=\"writing_style\",\n", - " sampler_type=SamplerType.CATEGORY,\n", - " params=CategorySamplerParams(\n", - " values=[\"Draft\", \"Adequate\", \"Polished\"]\n", - " ),\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 🦜 LLM-generated columns" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# LLM-generated triage note\n", - "config_builder.add_column(\n", - " LLMTextColumnConfig(\n", - " name=\"content\",\n", - " prompt=(\n", - " \"You are an experienced triage nurse in a busy Emergency Department writing a draft note. 
\"\n", - " \"Write a realistic, concise triage note in a telegraphic style using common medical abbreviations. \"\n", - " \"The note is for a {{ patient.age }} y/o {{ 'M' if patient.sex == 'Male' else 'F' }}. \"\n", - " \"Triage classification: '{{ esi_level_description }}'. \"\n", - " \"Reason for visit: '{{ clinical_scenario }}'. \"\n", - " \"Desired writing style: '{{ writing_style }}'. \"\n", - " \"Structure the note with 'CC:' and 'HPI:'. \"\n", - " \"Adjust the style and level of clinical detail based on the 'writing_style': \"\n", - " \"- Draft: Use minimal structure, brief statements, and omit some details; clinical indicators may be less clear. \"\n", - " \"- Adequate: Use complete sentences, include all relevant clinical indicators, but avoid excessive detail. \"\n", - " \"- Polished: Be thorough, precise, and clear; include nuanced or subtle signs and show strong clinical reasoning. \"\n", - " \"Also, adjust level of detail based on urgency (ESI 1 is always brief). \"\n", - " \"Respond with ONLY the note text, starting with 'CC:'.\"\n", - " ),\n", - " model_alias=MODEL_ALIAS_GENERATOR,\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### ⚖️ LLM-as-a-Judge Evaluation Step" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Rubric: clinical coherence\n", - "clinical_coherence_rubric = Score(\n", - " name=\"clinical coherence\",\n", - " description=\"Evaluates how well the clinical details in the triage note align with the assigned ESI level and scenario.\",\n", - " options={\n", - " \"5\": \"Note is perfectly aligned with the ESI level and scenario; details are clinically plausible and specific.\",\n", - " \"4\": \"Note is well-aligned, with only minor details that might be slightly inconsistent.\",\n", - " \"3\": \"Note is generally consistent, but some key clinical indicators are missing or don't fully match the ESI level.\",\n", - " \"2\": 
\"Note shows significant inconsistency between the clinical details and the assigned ESI level.\",\n", - " \"1\": \"Note is clinically incoherent and does not reflect the assigned ESI level or scenario at all.\"\n", - " }\n", - ")\n", - "\n", - "# Rubric: ESI level complexity (reduced to 3 levels: Simple, Moderate, Complex)\n", - "esi_level_complexity_rubric = Score(\n", - " name=\"esi level complexity\",\n", - " description=\"Evaluates how difficult it is to infer the correct ESI level from the note. Higher scores indicate greater complexity, which is desirable for creating a challenging dataset.\",\n", - " options={\n", - " \"Complex\": \"Note contains subtle or conflicting information, requiring clinical reasoning to distinguish between ESI levels.\",\n", - " \"Moderate\": \"Note requires some clinical inference; indicators are present but not always immediately obvious.\",\n", - " \"Simple\": \"Note uses clear, direct, or textbook indicators that make the ESI level obvious.\"\n", - " }\n", - ")\n", - "\n", - "# LLM judge: triage note quality\n", - "EVAL_TRIAGE_NOTE_PROMPT = \"\"\"\\\n", - "You are an expert ER physician responsible for quality control. Your task is to evaluate a synthetic triage note for its realism and complexity.\n", - "\n", - "**Triage Situation:**\n", - "- ESI Level: '{{ esi_level_description }}'\n", - "- Clinical Scenario: '{{ clinical_scenario }}'\n", - "- Desired Writing Style: '{{ writing_style }}'\n", - "- Patient: {{ patient.age }}-year-old {{ patient.sex }}\n", - "\n", - "**Generated Triage Note:**\n", - "\"{{ content }}\"\n", - "\n", - "Take a deep breath and carefully evaluate the \"Generated Triage Note\". Assess its clinical coherence with the situation and how well it matches the desired complexity. 
The goal is to create a challenging dataset, so higher complexity scores are desirable.\n", - "\"\"\"\n", - "\n", - "config_builder.add_column(\n", - " LLMJudgeColumnConfig(\n", - " name=\"triage_note_quality\",\n", - " model_alias=MODEL_ALIAS_JUDGE,\n", - " prompt=EVAL_TRIAGE_NOTE_PROMPT,\n", - " scores=[clinical_coherence_rubric, esi_level_complexity_rubric],\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 🧪 Generate & Preview \n", - "\n", - "Tip: Re-run preview to cycle examples; adjust prompts, temperatures, or scenarios to tune realism and difficulty.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "preview = data_designer_client.preview(config_builder, num_records=10)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Run this cell multiple times to cycle through the 10 preview records.\n", - "preview.display_sample_record()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# The preview dataset is available as a pandas DataFrame.\n", - "preview.dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 🚀 Scale Up Generations\n", - "Once satisfied with the preview results, scale up to generate the full dataset." 
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Submit batch job\n",
- "job_results = data_designer_client.create(config_builder, num_records=100)\n",
- "\n",
- "job_results.wait_until_done()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset = job_results.load_dataset()\n",
- "print(\"\\nGenerated dataset shape:\", dataset.shape)\n",
- "\n",
- "dataset.head()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 🧹 Refinement [Optional]\n",
- "\n",
- "Filter the generated dataset to retain only higher-quality triage notes:\n",
- "\n",
- "- Keeps only notes with **Clinical Coherence ≥ 2** (as judged by the LLM).\n",
- "- Retrieves ESI level complexity directly from the LLM judge column (`triage_note_quality`)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import ast\n",
- "from rich import print\n",
- "\n",
- "def filter_by_scores(df, min_coherence=2, samples_per_complexity=100):\n",
- "    indices = []\n",
- "    for idx, k in enumerate(df['triage_note_quality']):\n",
- "        # If k is a string, parse it to dict\n",
- "        if isinstance(k, str):\n",
- "            try:\n",
- "                k_dict = ast.literal_eval(k)\n",
- "            except Exception:\n",
- "                continue\n",
- "        else:\n",
- "            k_dict = k\n",
- "        try:\n",
- "            coherence_score = int(k_dict['clinical coherence']['score'])\n",
- "            if coherence_score >= min_coherence:\n",
- "                indices.append(idx)\n",
- "        except Exception:\n",
- "            continue\n",
- "    filtered_df = df.iloc[indices].copy()  # copy to avoid SettingWithCopyWarning\n",
- "    filtered_df = filtered_df[[\"esi_level_description\", \"content\", \"triage_note_quality\"]]\n",
- "    filtered_df['esi_level_complexity'] = filtered_df['triage_note_quality'].apply(\n",
- "        lambda k: (ast.literal_eval(k) if isinstance(k, str) else k).get('esi level complexity', {}).get('score')\n",
- "    )\n",
- "    
filtered_df.drop(columns=['triage_note_quality'], inplace=True)\n",
- "    percent_filtered = 100 * len(filtered_df) / len(df) if len(df) > 0 else 0\n",
- "    print(f\"Filtered {len(filtered_df)} out of {len(df)} records ({percent_filtered:.1f}%)\")\n",
- "    # Sample up to N per complexity\n",
- "    sampled_df = (\n",
- "        filtered_df\n",
- "        .groupby('esi_level_complexity', group_keys=False)\n",
- "        .apply(lambda x: x.sample(min(len(x), samples_per_complexity), random_state=42))\n",
- "        .reset_index(drop=True)\n",
- "    )\n",
- "    print(f\"Sampled {len(sampled_df)} records total, {samples_per_complexity} (or fewer) per complexity level.\")\n",
- "    return sampled_df\n",
- "\n",
- "filtered_df = filter_by_scores(dataset, samples_per_complexity=100)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 👀 Inspect results"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def show_example_triage_notes(filtered_df, num_examples=5):\n",
- "    from rich.console import Console\n",
- "    from rich.panel import Panel\n",
- "    from rich.text import Text\n",
- "\n",
- "    console = Console()\n",
- "    examples = filtered_df.sample(num_examples)\n",
- "\n",
- "    console.print(f\"[italic]Showing {num_examples} randomly sampled filtered triage notes:[/italic]\\n\")\n",
- "    for _, row in examples.iterrows():\n",
- "        esi_level = str(row.get(\"esi_level_description\", \"\"))\n",
- "        esi_level_complexity = str(row.get(\"esi_level_complexity\", \"\"))\n",
- "        content = str(row.get(\"content\", \"\"))\n",
- "        # Use blue for the complexity level\n",
- "        panel_title = f\"ESI Level: {esi_level} [bold][blue]({esi_level_complexity})[/blue][/bold]\"\n",
- "        panel = Panel(\n",
- "            Text(content, style=\"green\"),\n",
- "            title=panel_title,\n",
- "            border_style=\"cyan\",\n",
- "            expand=False,\n",
- "            padding=(1, 2),\n",
- "        )\n",
- "        console.print(panel)\n",
- "        console.print()  # Extra newline for separation\n",
- "\n",
- "# Show some 
randomly sampled example records using rich\n",
- "show_example_triage_notes(filtered_df, num_examples=3)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## **Step 2**: 📊 NeMo Evaluator\n",
- "\n",
- "We evaluate the model on the filtered triage notes to see if it predicts the correct ESI level.\n",
- "\n",
- "- **Dataset**: HF-compatible JSONL served by the datastore\n",
- "- **Task**: Completion with structured output `{ \"esi_level_description\": \"...\" }`\n",
- "- **Metric**: String containment check against the ground-truth ESI"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "from huggingface_hub import HfApi\n",
- "from nemo_microservices import NeMoMicroservices\n",
- "\n",
- "# Service endpoint for the NeMo Evaluator (change if running elsewhere)\n",
- "BASE_URL = \"http://localhost:8080\"\n",
- "\n",
- "# Initialize the NeMoMicroservices client (does not trigger any action yet)\n",
- "client = NeMoMicroservices(base_url=BASE_URL)\n",
- "\n",
- "# Namespace for organizing datasets within the Hugging Face Hub\n",
- "NAMESPACE = \"triage-eval\"\n",
- "\n",
- "# Split the filtered dataframe into the different complexity levels\n",
- "df_complexities = {\n",
- "    \"simple\": filtered_df[filtered_df[\"esi_level_complexity\"] == \"Simple\"],\n",
- "    \"moderate\": filtered_df[filtered_df[\"esi_level_complexity\"] == \"Moderate\"],\n",
- "    \"complex\": filtered_df[filtered_df[\"esi_level_complexity\"] == \"Complex\"]\n",
- "}\n",
- "\n",
- "# Hugging Face Hub endpoint for the local server (set up in your datastore container)\n",
- "HF_ENDPOINT = \"http://localhost:3000/v1/hf\"  # Exposed from: 0.0.0.0:3000->3000/tcp\n",
- "# Initialize the Hugging Face API client\n",
- "hf_api = HfApi(endpoint=HF_ENDPOINT, token=os.environ[\"HF_TOKEN\"])\n",
- "\n",
- "# Create a dict to store files_url for each complexity 
level\n", - "files_url_dict = {}\n", - "\n", - "# Loop over each complexity level, preparing, saving, and uploading evaluation datasets\n", - "for level, df in df_complexities.items():\n", - " DATASET_NAME = f\"nurse-triage-notes-{level}\"\n", - " repo_id = f\"{NAMESPACE}/{DATASET_NAME}\"\n", - "\n", - " # Create (or get) the dataset repo for this complexity level\n", - " hf_api.create_repo(repo_id=repo_id, repo_type=\"dataset\", exist_ok=True)\n", - "\n", - " file_name = f\"dataset_{level}.jsonl\"\n", - " df.to_json(file_name, orient=\"records\", lines=True)\n", - " print(f\"Dataset prepared with {len(df)} samples for complexity '{level.capitalize()}'\")\n", - "\n", - " # Upload the dataset to the Hugging Face Hub\n", - " result = hf_api.upload_file(\n", - " path_or_fileobj=file_name,\n", - " path_in_repo=file_name,\n", - " repo_id=repo_id,\n", - " repo_type=\"dataset\",\n", - " revision=\"main\",\n", - " commit_message=f\"Eval dataset in {repo_id} for {level.capitalize()}\"\n", - " )\n", - "\n", - " print(f\"Dataset uploaded: {result}\") # Print result with the uploaded file URL/info\n", - "\n", - " # Construct files_url and store it for this complexity level\n", - " files_url = f\"hf://datasets/{NAMESPACE}/{DATASET_NAME}\"\n", - " files_url_dict[level] = files_url" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 🧪 Evaluator Flow\n", - "This section defines the evaluation configuration used to assess model performance on triage note classification using a custom evaluator." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "EVALUATOR_CONFIG = {\n", - " \"eval_config\": {\n", - " \"type\": \"custom\",\n", - " \"params\": {\n", - " \"parallelism\": 8\n", - " },\n", - " \"tasks\": {\n", - " \"triage_classification\": {\n", - " \"type\": \"completion\",\n", - " \"params\": {\n", - " \"template\": {\n", - " \"messages\": [\n", - " {\n", - " \"role\": \"system\",\n", - " \"content\": (\n", - " \"You are an expert ER triage nurse. Your task is to classify the following triage note into one of the five Emergency Severity Index (ESI) levels.\"\n", - " f\" The possible levels are: {', '.join([repr(level) for level in ESI_LEVELS])}.\"\n", - " \" Carefully analyze the clinical details in the triage note, focusing on patient acuity, resource needs, and risk of rapid deterioration.\"\n", - " \" Respond with only the selected ESI level description, exactly matching one of the listed possibilities. Do not provide extra text or explanation.\"\n", - " )\n", - " },\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": (\n", - " \"Triage Note: {{item.content}}\\n\"\n", - " \"Classify the ESI level for this note based on the provided definitions.\"\n", - " \" Respond in JSON format only: { \\\"esi_level_description\\\": \\\"...\\\" }\"\n", - " )\n", - " }\n", - " ],\n", - " }\n", - " },\n", - " \"metrics\": {\n", - " \"accuracy\": {\n", - " \"type\": \"string-check\",\n", - " \"params\": {\n", - " \"check\": [\n", - " \"{{sample.output_text}}\",\n", - " \"contains\",\n", - " \"{{item.esi_level_description}}\"\n", - " ]\n", - " }\n", - " }\n", - " },\n", - " \"dataset\": {\n", - " \"files_url\": None\n", - " }\n", - " }\n", - " }\n", - " },\n", - " \"target_config\": {\n", - " \"type\": \"model\",\n", - " \"model\": {\n", - " \"api_endpoint\": {\n", - " \"url\": None,\n", - " \"model_id\": None\n", - " }\n", - " }\n", - " }\n", - "}\n" - ] - }, - { - "cell_type": "markdown", - "metadata": 
{}, - "source": [ - "### 🔍 Model evaluation loop and configuration\n", - "\n", - "This section compares multiple models (A/B testing) on the triage note classification task **across each complexity level** (Simple, Moderate, Complex).\n", - "\n", - "The models evaluated are:\n", - " - **Qwen3-8B** (`Qwen/Qwen3-8B`)\n", - " - **Nemotron Nano 9B v2** (`nvidia/nvidia-nemotron-nano-9b-v2`)\n", - "\n", - "For *each* complexity level, the accuracy score for each model is printed, allowing for side-by-side evaluation of how each model performs at every complexity." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import time\n", - "import copy\n", - "import pandas as pd\n", - "\n", - "# This code assumes EVALUATOR_CONFIG is available in the notebook scope\n", - "\n", - "MODEL_SPECS = [\n", - " {\n", - " \"name\": \"Qwen3-8B\",\n", - " \"model_id\": \"Qwen/Qwen3-8B\",\n", - " \"url\": \"https://your-model-endpoint-1/v1/completions\" # <-- Placeholder URL\n", - " },\n", - " {\n", - " \"name\": \"Nemotron Nano 9B v2\",\n", - " \"model_id\": \"nvidia/nvidia-nemotron-nano-9b-v2\",\n", - " \"url\": \"https://your-model-endpoint-2/v1/completions\" # <-- Placeholder URL\n", - " }\n", - "]\n", - "\n", - "COMPLEXITIES = [\"simple\", \"moderate\", \"complex\"]\n", - "\n", - "def run_evaluation(client, namespace, evaluator_config, model_spec, complexity, files_url_dict):\n", - " \"\"\"\n", - " Populates the evaluator_config, filling in the files_url and endpoint, then runs evaluation.\n", - " Returns accuracy for the given model+complexity.\n", - " \"\"\"\n", - " # Work with a deepcopy of the config for isolation\n", - " config = copy.deepcopy(evaluator_config)\n", - " # Set the dataset URL for the current complexity\n", - " config['eval_config']['tasks']['triage_classification']['dataset']['files_url'] = files_url_dict[complexity]\n", - " # Set the API endpoint and model_id for this model\n", - " 
config['target_config']['model']['api_endpoint']['url'] = model_spec['url']\n", - " config['target_config']['model']['api_endpoint']['model_id'] = model_spec['model_id']\n", - "\n", - " # Submit evaluation job\n", - " job = client.evaluation.jobs.create(\n", - " namespace=namespace,\n", - " # here, pass through the two parts\n", - " target=config['target_config'],\n", - " config=config['eval_config']\n", - " )\n", - " print(f\"Submitted evaluation job for model '{model_spec['name']}' on complexity '{complexity.capitalize()}' (job id: {job.id})\")\n", - " # Wait until complete\n", - " while True:\n", - " time.sleep(3)\n", - " progress = client.evaluation.jobs.status(job.id).progress\n", - " if progress >= 100: break\n", - " if progress % 20 == 0: print(f\" ⏳ Job {job.id} is {progress}% done\")\n", - " print(f\" ✔️ Job done for model '{model_spec['name']}' on complexity '{complexity.capitalize()}'\")\n", - "\n", - " # Fetch results and extract accuracy\n", - " results = client.evaluation.jobs.results(job.id)\n", - " accuracy_value = results.tasks['triage_classification'].metrics['accuracy'].scores['string-check'].value\n", - " return accuracy_value\n", - "\n", - "results_dict = {model_spec['name']: {} for model_spec in MODEL_SPECS}\n", - "\n", - "print(\"Starting evaluation jobs (per model, per complexity)...\")\n", - "for complexity in COMPLEXITIES:\n", - " for spec in MODEL_SPECS:\n", - " accuracy = run_evaluation(client, NAMESPACE, EVALUATOR_CONFIG, spec, complexity, files_url_dict)\n", - " results_dict[spec['name']][complexity.capitalize()] = 100 * accuracy # Store as percentage\n", - " print(f\" --> DONE: {spec['name']}, {complexity.capitalize()} (Accuracy: {100*accuracy:.2f}%)\\n\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 📊 Visualize Model Accuracies\n", - "The table below summarizes the accuracy (%) of each model for each complexity level." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "df_results = pd.DataFrame(results_dict).T\n", - "df_results = df_results[[c.capitalize() for c in COMPLEXITIES]]\n", - "\n", - "print(\"\\nModel Accuracy Table (%):\")\n", - "display(df_results.style.format(\"{:.2f}\"))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "nemo-data-designer-v25.11", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.5" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Emergency Triage LLM Pipeline 🏥\n", + "\n", + "**Problem:** Accurately labeling Emergency Severity Index (ESI) levels from nurse triage notes is crucial for clinical research and model development. Real-world data is highly sensitive and access is often limited; human annotations are costly and slow, making it hard to build large, diverse datasets for robust evaluation.\n", + "\n", + "**Solution:** Synthetic data provides a scalable, privacy-preserving approach. 
We simulate realistic triage notes and ESI labels with 🎨 NeMo Data Designer, score quality with LLM-as-a-Judge, then evaluate model accuracy with 📊 NeMo Evaluator—no patient data, minimal annotation burden.\n", + "\n", + "- **Use case:** Predict ESI levels from synthetic nurse triage notes using LLMs\n", + "- **Goal:** Evaluate model accuracy and note quality/complexity across clinical scenarios\n", + "- **Pipeline:** Part 1 (Data Designer) ➔ `triage_dataset.jsonl` ➔ Part 2 (Evaluator) ➔ accuracy by ESI level & complexity\n", + "\n", + "**Workflow:**\n", + "- 🎨 **Part 1 — NeMo Data Designer:** Generate triage notes, score with LLM-as-a-Judge, filter & save to `triage_dataset.jsonl`\n", + "- 📊 **Part 2 — NeMo Evaluator:** Load dataset → run inference (NMP Gateway) → score → compare accuracy by ESI level and complexity\n", + "\n", + "💡 **Tip:** Run cells in order. You can re-run Part 1 preview/generation to explore different scenarios and complexity settings.\n", + "\n", + "![Pipeline Diagram](triage_pipeline_diagram.png)" + ] }, - "nbformat": 4, - "nbformat_minor": 2 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ⚙️ Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from data_designer.config import (\n", + " DataDesignerConfigBuilder,\n", + " ModelConfig,\n", + " ChatCompletionInferenceParams,\n", + " SamplerColumnConfig,\n", + " SamplerType,\n", + " CategorySamplerParams,\n", + " SubcategorySamplerParams,\n", + " PersonFromFakerSamplerParams,\n", + " LLMTextColumnConfig,\n", + " LLMJudgeColumnConfig,\n", + " Score,\n", + ")\n", + "from nemo_platform import NeMoPlatform\n", + "\n", + "BASE_URL = \"http://localhost:8080\"\n", + "client = NeMoPlatform(base_url=BASE_URL, workspace=\"default\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Part 1 — 🎨 NeMo Data Designer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 🤖 Model configuration"
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_PROVIDER = \"nvidiabuild\"\n", + "\n", + "# 🤖 Generator: Nano; Judge: Super\n", + "MODEL_ID_GENERATOR = \"nvidia/nemotron-3-nano-30b-a3b\"\n", + "MODEL_ID_JUDGE = \"nvidia/nemotron-3-super-120b-a12b\"\n", + "\n", + "MODEL_ALIAS_GENERATOR = \"content_generator\"\n", + "MODEL_ALIAS_JUDGE = \"judge\"\n", + "\n", + "model_configs = [\n", + " ModelConfig(\n", + " provider=MODEL_PROVIDER,\n", + " alias=MODEL_ALIAS_GENERATOR,\n", + " model=MODEL_ID_GENERATOR,\n", + " inference_parameters=ChatCompletionInferenceParams(\n", + " max_tokens=8000, temperature=0.7, top_p=0.95,\n", + " )\n", + " ),\n", + " ModelConfig(\n", + " provider=MODEL_PROVIDER,\n", + " alias=MODEL_ALIAS_JUDGE,\n", + " model=MODEL_ID_JUDGE,\n", + " inference_parameters=ChatCompletionInferenceParams(\n", + " max_tokens=4096, temperature=0.1, top_p=0.95,\n", + " )\n", + " )\n", + "]\n", + "\n", + "config_builder = DataDesignerConfigBuilder(model_configs=model_configs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 🎲 Sampler columns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ESI_LEVELS = [\n", + " \"ESI 1: Resuscitation\",\n", + " \"ESI 2: Emergency\",\n", + " \"ESI 3: Urgent\",\n", + " \"ESI 4: Less Urgent\",\n", + " \"ESI 5: Non-urgent\",\n", + "]\n", + "\n", + "config_builder.add_column(\n", + " SamplerColumnConfig(\n", + " name=\"record_id\",\n", + " sampler_type=SamplerType.UUID,\n", + " params={\"short_form\": True, \"uppercase\": True}\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " SamplerColumnConfig(\n", + " name=\"esi_level_description\",\n", + " sampler_type=SamplerType.CATEGORY,\n", + " params=CategorySamplerParams(values=ESI_LEVELS),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " SamplerColumnConfig(\n", + " 
name=\"clinical_scenario\",\n", + " sampler_type=SamplerType.SUBCATEGORY,\n", + " params=SubcategorySamplerParams(\n", + " category=\"esi_level_description\",\n", + " values={\n", + " ESI_LEVELS[0]: [\n", + " \"Cardiac arrest\",\n", + " \"Unresponsive with no pulse\",\n", + " \"Severe respiratory distress\",\n", + " \"Major trauma with signs of shock\",\n", + " \"Suspected narcotic overdose with shallow respirations\",\n", + " ],\n", + " ESI_LEVELS[1]: [\n", + " \"Crushing substernal chest pain radiating to the left arm\",\n", + " \"Sudden onset of facial droop and arm weakness\",\n", + " \"New onset confusion in an elderly patient\",\n", + " \"Active suicidal ideation with a plan\",\n", + " \"High-speed motor vehicle accident\",\n", + " \"Severe abdominal pain in a patient with a history of aortic aneurysm\",\n", + " ],\n", + " ESI_LEVELS[2]: [\n", + " \"Abdominal pain with fever and nausea\",\n", + " \"High fever with a productive cough and history of COPD\",\n", + " \"Displaced fracture with visible deformity\",\n", + " \"Asthma attack, responsive to initial treatment\",\n", + " \"Vaginal bleeding in a pregnant patient\",\n", + " \"Head injury with brief loss of consciousness\",\n", + " ],\n", + " ESI_LEVELS[3]: [\n", + " \"Simple laceration requiring sutures\",\n", + " \"Twisted ankle, unable to bear weight\",\n", + " \"Sore throat with fever\",\n", + " \"Symptoms of a urinary tract infection\",\n", + " \"Painful ear with fever\",\n", + " ],\n", + " ESI_LEVELS[4]: [\n", + " \"Request for a prescription refill\",\n", + " \"Suture removal\",\n", + " \"Minor rash present for several days\",\n", + " \"Common cold symptoms\",\n", + " \"Follow-up for a minor wound check\",\n", + " ],\n", + " },\n", + " ),\n", + " )\n", + ")\n", + "\n", + "config_builder.add_column(\n", + " SamplerColumnConfig(\n", + " name=\"patient\",\n", + " sampler_type=SamplerType.PERSON_FROM_FAKER,\n", + " params=PersonFromFakerSamplerParams(age_range=[18, 70]),\n", + " )\n", + ")\n",
+ "\n", + "config_builder.add_column(\n", + " SamplerColumnConfig(\n", + " name=\"writing_style\",\n", + " sampler_type=SamplerType.CATEGORY,\n", + " params=CategorySamplerParams(values=[\"Draft\", \"Adequate\", \"Polished\"]),\n", + " )\n", + ")\n", + "\n", + "# 📐 Target complexity (bias Moderate/Complex for balanced judge distribution)\n", + "TARGET_COMPLEXITIES = [\"Simple\", \"Moderate\", \"Complex\"]\n", + "COMPLEXITY_WEIGHTS = [1, 4, 4] # fewer Simple, more Moderate/Complex\n", + "config_builder.add_column(\n", + " SamplerColumnConfig(\n", + " name=\"target_complexity\",\n", + " sampler_type=SamplerType.CATEGORY,\n", + " params=CategorySamplerParams(values=TARGET_COMPLEXITIES, weights=COMPLEXITY_WEIGHTS),\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 📝 LLM-generated columns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config_builder.add_column(\n", + " LLMTextColumnConfig(\n", + " name=\"content\",\n", + " prompt=(\n", + " \"You are an experienced triage nurse in a busy Emergency Department writing a draft note. \"\n", + " \"Write a realistic, concise triage note in a telegraphic style using common medical abbreviations. \"\n", + " \"The note is for a {{ patient.age }} y/o {{ 'M' if patient.sex == 'Male' else 'F' }}. \"\n", + " \"Triage classification: '{{ esi_level_description }}'. \"\n", + " \"Reason for visit: '{{ clinical_scenario }}'. \"\n", + " \"Desired writing style: '{{ writing_style }}'. \"\n", + " \"Never state the ESI level or use 'ESI 1'/'ESI 2' etc. in the note. \"\n", + " \"Target complexity for inferring ESI level from the note: '{{ target_complexity }}'. \"\n", + " \"Match this complexity: \"\n", + " \"- Simple: Include indicators that need some clinical inference; present but not immediately obvious. 
\"\n", + " \"- Moderate: Include subtle or conflicting information so that real clinical reasoning is needed to decide the ESI level. \"\n", + " \"- Complex: Note should be deliberately ambiguous or borderline—could reasonably support more than one ESI level. Include red herrings, conflicting cues, or omit key information so that definitive classification requires expert judgment and weighing of competing factors. \"\n", + " \"Structure the note with 'CC:' and 'HPI:'. \"\n", + " \"Adjust the style and level of clinical detail based on the 'writing_style': \"\n", + " \"- Draft: Use minimal structure, brief statements, and omit some details; clinical indicators may be less clear. \"\n", + " \"- Adequate: Use complete sentences, include all relevant clinical indicators, but avoid excessive detail. \"\n", + " \"- Polished: Be thorough, precise, and clear; include nuanced or subtle signs and show strong clinical reasoning. \"\n", + " \"Also, adjust level of detail based on urgency (ESI 1 is always brief). 
\"\n", + " \"Respond with ONLY the note text, starting with 'CC:'.\"\n", + " ),\n", + " model_alias=MODEL_ALIAS_GENERATOR,\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ⚖️ LLM-as-a-Judge" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "clinical_coherence_rubric = Score(\n", + " name=\"clinical coherence\",\n", + " description=\"Evaluates how well the clinical details in the triage note align with the assigned ESI level and scenario.\",\n", + " options={\n", + " \"5\": \"Note is perfectly aligned with the ESI level and scenario; details are clinically plausible and specific.\",\n", + " \"4\": \"Note is well-aligned, with only minor details that might be slightly inconsistent.\",\n", + " \"3\": \"Note is generally consistent, but some key clinical indicators are missing or don't fully match the ESI level.\",\n", + " \"2\": \"Note shows significant inconsistency between the clinical details and the assigned ESI level.\",\n", + " \"1\": \"Note is clinically incoherent and does not reflect the assigned ESI level or scenario at all.\"\n", + " }\n", + ")\n", + "\n", + "esi_level_complexity_rubric = Score(\n", + " name=\"esi level complexity\",\n", + " description=\"Evaluates how difficult it is to infer the correct ESI level from the note.\",\n", + " options={\n", + " \"Simple\": \"Note includes indicators that need some clinical inference; present but not immediately obvious.\",\n", + " \"Moderate\": \"Note contains subtle or conflicting information, requiring clinical reasoning to decide the ESI level.\",\n", + " \"Complex\": \"Note is deliberately ambiguous or borderline; could reasonably support more than one ESI level. 
Contains red herrings, conflicting cues, or omits key information; definitive classification requires expert judgment and weighing of competing factors.\"\n", + " }\n", + ")\n", + "\n", + "EVAL_TRIAGE_NOTE_PROMPT = \"\"\"\\\n", + "You are an expert ER physician responsible for quality control. Your task is to evaluate a synthetic triage note for its realism and complexity.\n", + "\n", + "**Triage Situation:**\n", + "- ESI Level: '{{ esi_level_description }}'\n", + "- Clinical Scenario: '{{ clinical_scenario }}'\n", + "- Desired Writing Style: '{{ writing_style }}'\n", + "- Patient: {{ patient.age }}-year-old {{ patient.sex }}\n", + "- Intended complexity: '{{ target_complexity }}' (score based on what the note actually conveys; if the note achieves its intended complexity, say so).\n", + "\n", + "**Generated Triage Note:**\n", + "\\\"{{ content }}\\\"\n", + "\n", + "Take a deep breath and carefully evaluate the \\\"Generated Triage Note\\\". Assess its clinical coherence with the situation and how well it matches the desired complexity. The goal is to create a challenging dataset, so higher complexity scores are desirable.\n", + "\"\"\"\n", + "\n", + "config_builder.add_column(\n", + " LLMJudgeColumnConfig(\n", + " name=\"triage_note_quality\",\n", + " model_alias=MODEL_ALIAS_JUDGE,\n", + " prompt=EVAL_TRIAGE_NOTE_PROMPT,\n", + " scores=[clinical_coherence_rubric, esi_level_complexity_rubric],\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 👀 Preview\n", + "\n", + "Quick preview to validate the pipeline (runs synchronously, up to 10 records)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preview = client.data_designer.preview(config_builder, num_records=5)\n", + "\n", + "if preview.dataset is not None and len(preview.dataset) > 0:\n", + " preview.display_sample_record()\n", + " # Use preview dataset if skipping the batch job below\n", + " dataset = preview.dataset\n", + "else:\n", + " print(\"Preview returned no data. Check the logs above for errors.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preview.dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 📦 Generate full dataset\n", + "\n", + "Scale up with batch `create` (requires Docker socket and jobs controller). For local quickstart without Docker socket, use the preview dataset from above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "job = client.data_designer.create(config_builder, num_records=100, wait_until_done=True)\n", + "\n", + "results = job.download_artifacts()\n", + "dataset = results.load_dataset()\n", + "print(f\"Generated dataset shape: {dataset.shape}\")\n", + "dataset.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 🔍 Filter & save\n", + "\n", + "Filter by clinical coherence, extract complexity levels, and save to `triage_dataset.jsonl`.\n", + "\n", + "Use `preview.dataset` if running preview only, or `dataset` from the batch job above." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ast\n", + "\n", + "def filter_by_scores(df, min_coherence=1, samples_per_complexity=100):\n", + " # Judge coherence scores range 1-5, so the default min_coherence=1 keeps every record;\n", + " # raise it (e.g. to 4) to keep only highly coherent notes.\n", + " indices = []\n", + " for idx, k in enumerate(df['triage_note_quality']):\n", + " if isinstance(k, str):\n", + " try:\n", + " k_dict = ast.literal_eval(k)\n", + " except Exception:\n", + " continue\n", + " else:\n", + " k_dict = k\n", + " try:\n", + " coherence_score = int(k_dict['clinical coherence']['score'])\n", + " if coherence_score >= min_coherence:\n", + " indices.append(idx)\n", + " except Exception:\n", + " continue\n", + " filtered_df = df.iloc[indices]\n", + " filtered_df = filtered_df[[\"esi_level_description\", \"content\", \"triage_note_quality\"]].copy()\n", + " filtered_df['esi_level_complexity'] = filtered_df['triage_note_quality'].apply(\n", + " lambda k: (ast.literal_eval(k) if isinstance(k, str) else k).get('esi level complexity', {}).get('score')\n", + " )\n", + " filtered_df.drop(columns=['triage_note_quality'], inplace=True)\n", + " percent_filtered = 100 * len(filtered_df) / len(df) if len(df) > 0 else 0\n", + " print(f\"Filtered {len(filtered_df)} out of {len(df)} records ({percent_filtered:.1f}%)\")\n", + " sampled_df = (\n", + " filtered_df\n", + " .groupby('esi_level_complexity', group_keys=False)\n", + " .apply(lambda x: x.sample(min(len(x), samples_per_complexity), random_state=42))\n", + " .reset_index(drop=True)\n", + " )\n", + " print(f\"Sampled {len(sampled_df)} records total, {samples_per_complexity} (or fewer) per complexity level.\")\n", + " return sampled_df\n", + "\n", + "filtered_df = filter_by_scores(dataset, samples_per_complexity=100)\n", + "filtered_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 💾 Save to disk — input for Part 2\n", + "OUTPUT_FILE = \"triage_dataset.jsonl\"\n", + "filtered_df.to_json(OUTPUT_FILE,
orient=\"records\", lines=True)\n", + "print(f\"Saved {len(filtered_df)} records to {OUTPUT_FILE}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Part 2 — 📊 NeMo Evaluator\n", + "\n", + "Evaluate model predictions on filtered triage notes by checking whether the ground-truth ESI level description is contained in the model's output.\n", + "\n", + "- Dataset: JSONL file (\"triage_dataset.jsonl\") loaded from disk (HF format compatible)\n", + "- Task: Model receives a triage note and must generate only the ESI level description as output\n", + "- Metric: A prediction is correct if the true \"esi_level_description\" string is found within the model's output" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 📂 Load dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "INPUT_FILE = \"triage_dataset.jsonl\"\n", + "df = pd.read_json(INPUT_FILE, lines=True)\n", + "print(f\"Loaded {len(df)} records from {INPUT_FILE}\")\n", + "print(f\"\\nESI distribution:\\n{df['esi_level_description'].value_counts().to_string()}\")\n", + "print(f\"\\nComplexity distribution:\\n{df['esi_level_complexity'].value_counts().to_string()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ⚙️ Evaluator config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ESI_LEVELS = [\n", + " \"ESI 1: Resuscitation\",\n", + " \"ESI 2: Emergency\",\n", + " \"ESI 3: Urgent\",\n", + " \"ESI 4: Less Urgent\",\n", + " \"ESI 5: Non-urgent\",\n", + "]\n", + "\n", + "MODELS = {\n", + " \"GPT-OSS 20B\": \"openai-gpt-oss-20b\",\n", + " \"GPT-OSS 120B\": \"openai-gpt-oss-120b\",\n", + "}\n", + "\n", + "ESI_SYSTEM_PROMPT = (\n", + " \"You are an expert ER triage nurse.
Classify the following triage note \"\n", + " \"into one of the five Emergency Severity Index (ESI) levels: \"\n", + " f\"{', '.join([repr(level) for level in ESI_LEVELS])}. \"\n", + " \"Analyze patient acuity, resource needs, and risk of rapid deterioration. \"\n", + " \"Respond with ONLY the ESI level description, exactly matching one of the listed levels.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 🚀 Run inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def classify_batch(client, model_id, df):\n", + " \"\"\"Classify all triage notes with a given model via NMP gateway.\"\"\"\n", + " predictions = []\n", + " for i, (_, row) in enumerate(df.iterrows()):\n", + " response = client.inference.gateway.openai.post(\n", + " trailing_uri=\"v1/chat/completions\",\n", + " body={\n", + " \"model\": model_id,\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": ESI_SYSTEM_PROMPT},\n", + " {\"role\": \"user\", \"content\": f\"Triage Note: {row['content']}\"},\n", + " ],\n", + " \"max_tokens\": 512,\n", + " \"temperature\": 0.1,\n", + " },\n", + " )\n", + " try:\n", + " output = response[\"choices\"][0][\"message\"][\"content\"] or \"\"\n", + " except (KeyError, IndexError, TypeError):\n", + " output = \"\"\n", + " predictions.append(output)\n", + " if (i + 1) % 10 == 0 or i == len(df) - 1:\n", + " print(f\"\\r Progress: {i+1}/{len(df)}\", end=\"\", flush=True)\n", + " print()\n", + " return predictions\n", + "\n", + "all_predictions = {}\n", + "for label, model_id in MODELS.items():\n", + " print(f\"Running: {label} ({model_id})\")\n", + " all_predictions[label] = classify_batch(client, model_id, df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ✅ Score with NMP string-check metric" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from 
nemo_platform.types.evaluation import StringCheckMetric\n", + "\n", + "ACCURACY_METRIC = StringCheckMetric(\n", + " type=\"string-check\",\n", + " left_template=\"{{output_text}}\",\n", + " operation=\"contains\",\n", + " right_template=\"{{esi_level_description}}\",\n", + " description=\"Model output contains the correct ESI level\",\n", + ")\n", + "\n", + "def score_predictions(client, df, predictions, batch_size=10):\n", + " \"\"\"Score predictions using NMP inline metric-evaluate. Returns list of bools.\"\"\"\n", + " eval_rows = [\n", + " {\"esi_level_description\": gt, \"output_text\": pred}\n", + " for gt, pred in zip(df[\"esi_level_description\"], predictions)\n", + " ]\n", + " all_scores = []\n", + " for start in range(0, len(eval_rows), batch_size):\n", + " batch = eval_rows[start : start + batch_size]\n", + " result = client.evaluation.metrics.evaluate(\n", + " dataset={\"rows\": batch},\n", + " metric=ACCURACY_METRIC,\n", + " )\n", + " all_scores.extend(result.row_scores)\n", + " return [\n", + " (rs.scores.get(\"string-check\", 0) == 1.0 if rs.scores else False)\n", + " for rs in all_scores\n", + " ]\n", + "\n", + "model_scores = {}\n", + "for label, preds in all_predictions.items():\n", + " correct = score_predictions(client, df, preds)\n", + " acc = sum(correct) / len(correct) * 100\n", + " model_scores[label] = correct\n", + " print(f\"{label}: {acc:.1f}% ({sum(correct)}/{len(correct)})\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 📊 Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "from matplotlib.patches import Patch\n", + "\n", + "results = df[[\"esi_level_description\", \"esi_level_complexity\"]].copy()\n", + "for label, correct in model_scores.items():\n", + " results[label] = correct\n", + "\n", + "model_names = list(model_scores.keys())\n", 
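 + "\n", + "# Sanity check (editor's addition, assumes the scoring above ran to completion):\n", + "# each model's boolean score list must align 1:1 with the dataset rows, otherwise the\n", + "# groupby accuracies below would silently misattribute scores to the wrong notes.\n", + "for _label, _correct in model_scores.items():\n", + "    assert len(_correct) == len(results), f\"{_label}: {len(_correct)} scores vs {len(results)} rows\"\n",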
+ "\n", + "print(\"=== Accuracy by ESI Level (%) ===\")\n", + "by_level = (results.groupby(\"esi_level_description\")[model_names].mean() * 100).round(1)\n", + "display(by_level)\n", + "\n", + "print(\"\\n=== Accuracy by Complexity (%) ===\")\n", + "by_complexity = (results.groupby(\"esi_level_complexity\")[model_names].mean() * 100).round(1)\n", + "display(by_complexity)\n", + "\n", + "results.to_json(\"triage_eval_results.jsonl\", orient=\"records\", lines=True)\n", + "print(f\"\\nSaved to triage_eval_results.jsonl\")\n", + "\n", + "def plot_accuracy_by_esi_level(by_level, model_names, LEVEL_ORDER, model_palette=None):\n", + " \"\"\"Plot model accuracy by ESI level as a grouped bar chart.\"\"\"\n", + " sns.set(style=\"whitegrid\")\n", + " by_level_ordered = by_level.reindex(LEVEL_ORDER)\n", + " data_plot = by_level_ordered.reset_index().melt(\n", + " id_vars=\"esi_level_description\", var_name=\"Model\", value_name=\"Accuracy\"\n", + " )\n", + " if model_palette is None:\n", + " # fallback palette if None provided\n", + " palette_colors = [\"#76b900\", \"#255c2e\"] + [\"#b5b5b5\"] * (len(model_names)-2)\n", + " model_palette = {m: c for m, c in zip(model_names, palette_colors)}\n", + " fig, ax = plt.subplots(figsize=(7, 3))\n", + " for i, esi_level in enumerate(LEVEL_ORDER):\n", + " row = data_plot[data_plot[\"esi_level_description\"] == esi_level]\n", + " y_base = i * (len(model_names) + 1)\n", + " for j, model in enumerate(model_names):\n", + " acc = row[row[\"Model\"] == model][\"Accuracy\"].values\n", + " if len(acc) == 0:\n", + " continue\n", + " color = model_palette.get(model, \"#b5b5b5\")\n", + " ax.barh(\n", + " y_base + j, acc[0],\n", + " color=color, height=0.8, edgecolor=\"none\",\n", + " label=model if (i == 0) else None,\n", + " )\n", + " ax.set_yticks([(i * (len(model_names) + 1)) + 0.5 for i in range(len(LEVEL_ORDER))])\n", + " ax.set_yticklabels(LEVEL_ORDER)\n", + " ax.invert_yaxis()\n", + " ax.set_xlim(left=0, right=105)\n", + " 
ax.set_xlabel(\"Accuracy (%)\")\n", + " ax.set_title(\"Model Accuracy by ESI Level\", fontsize=12, pad=10, loc=\"left\", weight=\"bold\")\n", + " sns.despine(left=True, bottom=True)\n", + " ax.xaxis.grid(True, linestyle=\"--\", linewidth=0.8, alpha=0.2)\n", + " handles = [Patch(facecolor=model_palette[m], label=m) for m in model_names]\n", + " ax.legend(handles=handles, title=\"Model\", loc=\"upper left\", bbox_to_anchor=(1, 1), borderaxespad=0.0)\n", + " plt.tight_layout(rect=[0, 0, 0.83, 1])\n", + " plt.show()\n", + "\n", + "# Define LEVEL_ORDER and model_palette as before\n", + "LEVEL_ORDER = ESI_LEVELS\n", + "model_palette = {model_names[0]: \"#76b900\", model_names[1]: \"#255c2e\"}\n", + "\n", + "# Call the plotting function\n", + "plot_accuracy_by_esi_level(by_level, model_names, LEVEL_ORDER, model_palette)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/nemo/data-flywheel/e2e-llm-evaluation/requirements.txt b/nemo/data-flywheel/e2e-llm-evaluation/requirements.txt new file mode 100644 index 000000000..5736c3e2d --- /dev/null +++ b/nemo/data-flywheel/e2e-llm-evaluation/requirements.txt @@ -0,0 +1,4 @@ +python-json-logger>=3,<4 +matplotlib +ipykernel +seaborn diff --git a/nemo/data-flywheel/e2e-llm-evaluation/triage_pipeline_diagram.png b/nemo/data-flywheel/e2e-llm-evaluation/triage_pipeline_diagram.png new file mode 100644 index 000000000..b79d59d3e Binary files /dev/null and b/nemo/data-flywheel/e2e-llm-evaluation/triage_pipeline_diagram.png differ