Skip to content

Latest commit

 

History

History
1121 lines (745 loc) · 42.2 KB

File metadata and controls

1121 lines (745 loc) · 42.2 KB

Chat-Haruhi-Suzumiya\kyon_generator\dataset.py 详细设计文档

该代码通过Hugging Face Hub加载两个角色对话数据集(Chat_Suzumiya_Fusion和Chat_Suzumiya_Fusion_B),使用concatenate_datasets函数合并后上传到新的Hub仓库,实现多数据集的融合与共享。

整体流程

graph TD
A[开始] --> B[登录Hugging Face Hub]
B --> C[从Hub加载数据集A: silk-road/Chat_Suzumiya_Fusion]
C --> D[从Hub加载数据集B: silk-road/Chat_Suzumiya_Fusion_B]
D --> E[提取train子集]
E --> F[使用concatenate_datasets合并数据集]
F --> G[上传合并后的数据集到Hub: silk-road/Chat-Haruhi-Fusion-A_B]
G --> H[结束]
Loading

类结构

模块级
├── read_jsonl_file (全局函数)
├── collate_fn (全局函数)
├── CharacterDataset (自定义Dataset类)
│   ├── __init__
│   ├── __len__
│   ├── __getitem__
│   ├── join_with_limit
│   ├── read_file_content
│   ├── read_jsonl_and_convert_to_tensor
│   ├── getSystemPrompt
│   └── getMemory

全局变量及字段

HF_TOKEN

HuggingFace API token for authentication

类型:str

CharacterDataset.data

List of JSON objects loaded from dialogue JSONL files

类型:List[dict]

CharacterDataset.character_path

Path to the character directory containing character-specific files

类型:str

CharacterDataset.memory_number

Number of top-k memory entries to retrieve based on similarity

类型:int

CharacterDataset.memory_path

Relative path to the memory embedding JSONL file (jsonl/title_text_embed.jsonl)

类型:str

CharacterDataset.system_prompt_name

Filename of the system prompt file (system_prompt.txt)

类型:str

CharacterDataset.system_prompt

Content of the character-specific system prompt

类型:str

CharacterDataset.memory_embed

Tensor containing pre-computed embeddings for character memories

类型:torch.Tensor

CharacterDataset.memory_text

List of text content corresponding to memory embeddings

类型:List[str]

CharacterDataset.memory_length

Maximum character length for concatenated retrieved memories

类型:int

全局函数及方法

read_jsonl_file

该函数是一个用于读取JSONL(JSON Lines)文件的工具函数,它逐行读取文件内容,解析每行的JSON数据,并将所有有效的JSON对象存储在列表中返回,同时具备基本的错误处理能力以跳过格式不正确的行。

参数:

  • file_pathstr,要读取的JSONL文件的路径

返回值:List[dict],包含文件中所有成功解析的JSON对象的列表

流程图

flowchart TD
    A[开始] --> B[初始化空列表 data]
    B --> C[打开文件 file_path]
    C --> D[读取下一行]
    D --> E{是否还有更多行?}
    E -->|否| F[返回 data 列表]
    E -->|是| G[去除行首尾空白]
    G --> H{行是否为空?}
    H -->|是| D
    H -->|否| I[尝试解析JSON]
    I --> J{解析是否成功?}
    J -->|是| K[将JSON对象添加到data列表]
    K --> D
    J -->|否| L[跳过该行]
    L --> D
    F --> M[结束]
Loading

带注释源码

def read_jsonl_file(file_path):
    """
    读取JSONL文件并解析其中的JSON对象
    
    参数:
        file_path: 要读取的JSONL文件路径
        
    返回值:
        包含文件中所有JSON对象的列表
    """
    # 初始化一个空列表用于存储解析后的JSON数据
    data = []
    
    # 以只读模式打开文件,使用UTF-8编码
    with open(file_path, 'r', encoding='utf-8') as f:
        # 遍历文件的每一行
        for line in f:
            # 去除行首尾的空白字符
            line = line.strip()
            
            # 只处理非空行
            if line:
                try:
                    # 尝试将行解析为JSON格式
                    json_data = json.loads(line)
                    # 解析成功,将JSON对象添加到列表中
                    data.append(json_data)
                except json.JSONDecodeError:
                    # 解析失败(JSON格式不正确),跳过该行
                    # 可以选择打印错误信息用于调试
                    # print(f"Failed to parse JSON: {line}")
                    continue
    
    # 返回包含所有JSON对象的列表
    return data

关键组件信息

名称 一句话描述
json 模块 Python标准库,用于JSON数据的解析和编码
jsonlines 模块 后续代码中使用的JSONL文件读取库,可逐行高效读取
CharacterDataset 自定义PyTorch数据集,用于加载角色对话数据

潜在的技术债务或优化空间

  1. 缺少文件存在性检查:函数未检查文件路径是否存在或是否为有效文件,可能导致FileNotFoundError
  2. 编码硬编码:UTF-8编码被硬编码,无法处理其他编码格式的JSONL文件
  3. 错误处理不够细致:仅捕获JSONDecodeError,对于其他可能的IO错误没有处理
  4. 大文件处理效率:对于非常大的JSONL文件,一次性加载所有数据到内存中可能导致内存问题,应考虑使用生成器模式或流式处理
  5. 返回空列表的歧义:当文件不存在或所有行都解析失败时,都返回空列表,调用者无法区分具体情况

其它项目

设计目标与约束

  • 设计目标:提供一个简单可靠的工具函数,用于将JSONL文件转换为Python列表数据结构
  • 约束:仅处理标准的JSONL格式(每行一个独立的JSON对象)

错误处理与异常设计

  • 当前仅捕获json.JSONDecodeError异常,跳过格式错误的行
  • 建议增强:可添加文件读取异常的处理,如FileNotFoundErrorPermissionError
  • 建议增强:可提供错误回调机制,让调用者决定如何处理解析失败的行

外部依赖与接口契约

  • 依赖:Python标准库jsonbuiltins.open
  • 接口契约:调用方需传入有效的文件路径,返回值为列表类型

collate_fn

该函数是PyTorch DataLoader的回调函数,负责将数据集__getitem__返回的单个样本整理成一个批次。它从批次中提取inputanswer字段,分别组成输入列表和目标列表,供后续模型训练使用。

参数:

  • batchList[Dict],从DataLoader获取的批次数据,每个元素是包含inputanswer字段的字典

返回值:Tuple[List[str], List[str]],返回两个字符串列表——第一个是输入列表,第二个是目标答案列表

流程图

flowchart TD
    A[接收batch参数] --> B{检查batch是否为空}
    B -->|是| C[返回空列表元组]
    B -->|否| D[遍历batch中的每个sample]
    D --> E[提取sample的input字段]
    D --> F[提取sample的answer字段]
    E --> G[构建inputs列表]
    F --> H[构建targets列表]
    G --> I[返回inputs和targets元组]
    H --> I
Loading

带注释源码

def collate_fn(batch):
    """
    DataLoader的整理函数,将单个样本整理成批次
    
    参数:
        batch: DataLoader返回的样本列表,每个样本是包含
               'input'和'answer'字段的字典
    
    返回:
        (inputs, targets): 输入列表和目标答案列表的元组
    """
    # 从每个样本中提取input字段,构建输入列表
    inputs = [sample["input"] for sample in batch]
    
    # 从每个样本中提取answer字段,构建目标列表
    targets = [sample["answer"] for sample in batch]

    # 注释: 以下是使用torch.stack的原始实现,但需要tensor类型
    # 当前实现直接返回列表,兼容字符串类型的数据
    # batch_inputs = torch.stack(inputs)
    # batch_targets = torch.stack(targets)

    # 返回输入和目标列表元组
    return inputs, targets

login

该函数是 Hugging Face Hub 库提供的身份验证函数,用于将本地的 Hugging Face 访问令牌(Token)进行登录认证,以便后续能够访问私有模型、数据集或其他需要认证的资源。

参数:

  • tokenstr,Hugging Face 的访问令牌(Personal Access Token),用于身份验证
  • add_to_git_credentialOptional[bool],可选参数,是否将令牌添加到 git 凭证中,默认为 False
  • quietOptional[bool],可选参数,是否静默执行,默认为 False

返回值:None,该函数无返回值,执行成功后直接完成登录状态

流程图

graph TD
    A[开始 login] --> B{检查 token 是否有效}
    B -->|无效| C[抛出异常: 无效的 token]
    B -->|有效| D{add_to_git_credential 为 True?}
    D -->|是| E[将 token 添加到 git credential]
    D -->|否| F{quiet 为 True?}
    F -->|是| G[静默完成登录]
    F -->|否| H[输出登录成功信息]
    E --> I[完成登录状态]
    G --> I
    H --> I
    I[结束 login]
Loading

带注释源码

# 从 huggingface_hub 库导入 login 函数
from huggingface_hub import login

# 定义 Hugging Face 访问令牌
# 该令牌用于身份验证,可从 https://huggingface.co/settings/tokens 获取
HF_TOKEN = "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"

# 调用 login 函数进行身份验证
# 参数 token: 传入访问令牌进行认证
# 作用: 将本地环境与 Hugging Face 账号绑定,后续可访问私有模型/数据集
login(token=HF_TOKEN)

注意:该 login 函数来源于外部库 huggingface_hub,并非在本项目代码中定义。以上信息基于 Hugging Face Hub 官方文档和常见用法。如需查看完整源码实现,请参考 huggingface_hub 官方仓库

load_dataset 函数提取结果

经过分析,代码中 load_dataset 是从 Hugging Face datasets 库导入的外部函数,并非在该代码文件中定义。以下是在代码中的使用情况:

load_dataset

从 Hugging Face datasets 库导入的函数,用于从 Hub 加载数据集。

参数:

  • pathstr,Hugging Face Hub 上的数据集名称或本地路径

返回值:datasets.DatasetDict,返回包含数据集的字典,通常包含 'train'、'test' 等拆分

流程图

graph TD
    A[调用 load_dataset] --> B{检查本地缓存}
    B -->|本地无缓存| C[从 Hugging Face Hub 下载数据集]
    B -->|本地有缓存| D[从本地加载数据集]
    C --> E[返回 DatasetDict 对象]
    D --> E
Loading

带注释源码

# 从 Hugging Face datasets 库导入 load_dataset 函数
# 这是一个外部依赖函数,非本代码文件定义
from datasets import load_dataset, Dataset

# 加载两个预训练数据集
# 参数:数据集在 Hugging Face Hub 上的路径
train_dataset_dict_A = load_dataset('silk-road/Chat_Suzumiya_Fusion')
train_dataset_dict_B = load_dataset('silk-road/Chat_Suzumiya_Fusion_B')

# 从 DatasetDict 中提取 'train' 拆分
train_dataset_A = train_dataset_dict_A['train']
train_dataset_B = train_dataset_dict_B['train']

# 合并两个数据集
train_dataset = concatenate_datasets([train_dataset_A, train_dataset_B])

# 将合并后的数据集推送到 Hub
train_dataset.push_to_hub("silk-road/Chat-Haruhi-Fusion-A_B")

注意:代码中未定义 load_dataset 函数,而是从 datasets 库导入的外部函数。如需了解该函数的完整文档,建议参考 Hugging Face datasets 库官方文档

concatenate_datasets

该函数用于将多个 Hugging Face Dataset 对象沿第一个轴(行方向)拼接成一个数据集,是 Hugging Face datasets 库提供的工具函数。

参数:

  • datasetsList[Dataset],要拼接的数据集列表

返回值:Dataset,拼接后的数据集对象

流程图

flowchart TD
    A[开始] --> B[加载数据集A silk-road/Chat_Suzumiya_Fusion]
    B --> C[加载数据集B silk-road/Chat_Suzumiya_Fusion_B]
    C --> D[提取train数据集]
    D --> E[调用concatenate_datasets合并]
    E --> F[推送合并后的数据集到Hub]
    F --> G[结束]
Loading

带注释源码

# 从HuggingFace datasets库导入concatenate_datasets函数
from datasets import load_dataset, concatenate_datasets

# 加载两个预训练数据集
# 数据集A: 包含角色对话数据
train_dataset_dict_A = load_dataset('silk-road/Chat_Suzumiya_Fusion')
# 数据集B: 包含另一个角色对话数据集合
train_dataset_dict_B = load_dataset('silk-road/Chat_Suzumiya_Fusion_B')

# 从DatasetDict中提取训练集
train_dataset_A = train_dataset_dict_A['train']
train_dataset_B = train_dataset_dict_B['train']

# 使用concatenate_datasets函数沿行方向拼接两个数据集
# 输入: [dataset_A, dataset_B] - 需要拼接的数据集列表
# 输出: 合并后的单一数据集对象
train_dataset = concatenate_datasets([train_dataset_A, train_dataset_B])

# 将合并后的数据集推送到HuggingFace Hub
train_dataset.push_to_hub("silk-road/Chat-Haruhi-Fusion-A_B")

get_embedding

该函数是从 utils 模块导入的,但在当前提供的代码文件中仅作为导入语句出现,没有实际的实现代码。根据代码上下文推断,该函数可能用于获取文本的嵌入向量(embedding),但在当前代码中未被调用。

参数:

  • 无法从当前代码中确定参数信息(函数未在此文件中定义)

返回值:

  • 无法从当前代码中确定返回值信息(函数未在此文件中定义)

流程图

graph TD
    A[开始] --> B[调用 utils.get_embedding]
    B --> C[函数定义未知]
    C --> D[结束]
Loading

带注释源码

# 当前代码文件中仅包含导入语句
from utils import get_embedding, download_models

# get_embedding 函数的实际实现在 utils 模块中
# 但该模块未在当前代码文件中提供
# 根据函数名推测,该函数可能用于:
# 1. 接收文本输入
# 2. 使用预训练模型(如 BERT, Sentence-BERT 等)生成文本嵌入向量
# 3. 返回嵌入向量供后续相似度计算或检索使用

注意:当前提供的代码文件中,get_embedding 函数仅作为导入语句出现,没有实际的函数定义。该函数的具体参数、返回值和实现逻辑需要查看 utils 模块的源代码。

函数提取结果

download_models

该函数为从 utils 模块导入的外部函数,用于下载所需的预训练模型。根据代码上下文推断,该函数应负责下载 Hugging Face 上的预训练语言模型(如 LLM、embedding 模型等),以支持后续的推理或微调任务。

注意:该函数的完整实现未在当前代码文件中定义,而是通过 from utils import download_models 导入。以下信息基于代码上下文推断得出。


参数:

  • (无法从当前代码中确定具体参数,推测可能包含模型名称、存储路径等参数)

返回值:

  • (无法从当前代码中确定具体返回值,推测可能返回下载的模型对象或模型路径)

流程图

由于该函数未在当前代码中实现,无法绘制详细的流程图。基于函数名称和项目上下文,推测其流程如下:

flowchart TD
    A[开始下载模型] --> B{检查模型是否已存在}
    B -->|是| C[返回本地模型路径]
    B -->|否| D[从Hugging Face Hub下载模型]
    D --> E{下载是否成功}
    E -->|成功| F[保存到指定目录]
    E -->|失败| G[抛出异常或重试]
    F --> H[返回模型路径/对象]
Loading

带注释源码

# 该函数定义在 utils 模块中,当前代码文件仅导入该函数
from utils import get_embedding, download_models

# 函数调用示例(在当前文件中未实际调用)
# download_models(...)  
# 推测该函数可能的使用方式:
# - 参数:model_name, save_path, tokenizer 等
# - 返回:model, tokenizer 或模型路径字符串

补充说明

外部依赖

  • utils 模块:包含 get_embeddingdownload_models 两个工具函数
  • huggingface_hub:用于模型下载和数据集访问

技术债务/优化空间

  1. 代码组织问题download_models 函数的实现未被包含在当前代码中,导致无法进行详细分析
  2. 硬编码凭证:HuggingFace Token 直接以字符串形式硬编码在代码中(HF_TOKEN),存在安全风险
  3. 缺乏错误处理:未看到对模型下载失败情况的处理逻辑

建议

若需要完整的函数分析,建议提供 utils.py 文件的内容,或在项目中查看该函数的实际实现代码。

CharacterDataset.__init__

这是 CharacterDataset 类的构造函数,用于初始化角色对话数据集。该方法接收训练数据、角色路径、记忆数量和记忆长度等参数,加载系统提示词和角色的记忆嵌入向量,为后续的数据检索和模型训练做好准备。

参数:

  • self:隐式参数,表示实例本身
  • json_dataList[Dict],从 JSONL 文件读取的对话数据列表,每个元素包含 query、answer、chat_history、embedding、source 等字段
  • character_pathstr,角色文件夹的路径,用于定位系统提示词和记忆向量文件
  • memory_numberint,从记忆库中检索的最相似记忆数量,用于构建上下文
  • memory_lengthint,检索记忆的最大字符长度限制,用于控制上下文长度

返回值:None,构造函数无返回值

流程图

graph TD
    A[开始 __init__] --> B[接收参数: json_data, character_path, memory_number, memory_length]
    B --> C[设置 self.data = json_data]
    C --> D[设置 self.character_path = character_path]
    D --> E[设置 self.memory_number = memory_number]
    E --> F[设置 self.memory_path = 'jsonl/title_text_embed.jsonl']
    F --> G[设置 self.system_prompt_name = 'system_prompt.txt']
    G --> H[调用 getSystemPrompt 读取系统提示词]
    H --> I[设置 self.system_prompt]
    I --> J[调用 read_jsonl_and_convert_to_tensor 加载记忆嵌入]
    J --> K[设置 self.memory_embed 和 self.memory_text]
    K --> L[设置 self.memory_length = memory_length]
    L --> M[结束 __init__]
Loading

带注释源码

def __init__(self, json_data, character_path, memory_number, memory_length):
    """
    初始化 CharacterDataset 实例
    
    参数:
        json_data: 从 JSONL 文件解析的对话数据列表
        character_path: 角色相关文件的根目录路径
        memory_number: 每次检索返回的记忆条数
        memory_length: 记忆文本的最大字符长度
    """
    pass  # 占位符,可能是之前预留的代码位置
    
    # 存储原始对话数据
    self.data = json_data
    
    # 存储角色路径,后续用于加载系统提示词和记忆文件
    self.character_path = character_path
    
    # 存储要检索的记忆数量,用于 getMemory 方法
    self.memory_number = memory_number
    
    # 记忆文件的相对路径(相对于 character_path)
    self.memory_path = "jsonl/title_text_embed.jsonl"
    
    # 系统提示词文件名
    self.system_prompt_name = "system_prompt.txt"
    
    # 读取并存储系统提示词内容
    # 系统提示词定义了角色的行为方式和对话风格
    self.system_prompt = self.getSystemPrompt()
    
    # 加载角色记忆的嵌入向量和原始文本
    # memory_embed: 用于计算相似度的向量矩阵
    # memory_text: 记忆的原始文本内容
    self.memory_embed, self.memory_text = self.read_jsonl_and_convert_to_tensor(
        os.path.join(self.character_path, self.memory_path)
    )
    
    # 存储记忆文本的最大长度限制
    # 用于在 join_with_limit 方法中截断过长的记忆
    self.memory_length = memory_length

CharacterDataset.__len__

该方法为 PyTorch Dataset 类的标准长度方法,返回数据集中包含的样本总数,供 DataLoader 等工具在迭代时获取数据集大小。

参数:无(self 为隐式参数,表示类实例本身)

返回值:int,返回数据集 self.data 中包含的样本数量

流程图

flowchart TD
    A[开始 __len__] --> B[返回 self.data 的长度]
    B --> C[结束]
Loading

带注释源码

def __len__(self):
    """
    返回数据集的样本数量
    
    Returns:
        int: 数据集包含的样本数量
    """
    return len(self.data)  # 返回存储在 self.data 中的数据列表的长度

CharacterDataset.__getitem__

该方法是PyTorch数据集类的核心接口,根据给定索引从数据集中获取样本,并通过检索增强生成模型输入。它首先根据查询的embedding向量从记忆库中检索最相关的记忆,然后拼接系统提示、检索记忆、聊天历史和当前查询,形成完整的模型输入上下文。

参数:

  • idxint,要获取的样本索引,用于从数据集中定位特定的训练样本

返回值:dict,包含预处理后训练数据的字典,包含query、system_prompt、retrieved_memory、chat_history、answer、embedding、source和input等字段

流程图

flowchart TD
    A[开始 __getitem__] --> B[根据 idx 从 self.data 获取样本]
    B --> C[将样本的 embedding 转换为 PyTorch Tensor]
    C --> D[调用 getMemory 检索相关记忆]
    D --> E[从 memory_text 中提取 top_k 对应的记忆文本]
    E --> F[使用 join_with_limit 限制记忆长度]
    F --> G[拼接 chat_history 为字符串]
    G --> H[拼接 system_prompt + retrieved_memory + chat_history + query]
    H --> I[构建返回字典 data]
    I --> J[返回 data 字典]
Loading

带注释源码

def __getitem__(self, idx):
    # 根据索引从数据列表中获取原始样本
    sample = self.data[idx]
    
    # 将样本中的 embedding 转换为 PyTorch 张量
    query_embed = torch.tensor(sample["embedding"])
    
    # 根据查询向量从记忆库中检索最相似的记忆
    top_k = self.getMemory(query_embed)
    
    # 从记忆文本列表中提取 top_k 索引对应的记忆内容
    # 记忆文本格式为 "key|||value",这里取 value 部分
    retrieved_memory = [self.memory_text[i].split("|||")[1] for i in top_k]
    
    # 使用 join_with_limit 方法限制记忆的总长度
    retrieved_memory = self.join_with_limit(retrieved_memory, self.memory_length)
    
    # 获取聊天历史并使用 "###" 作为分隔符拼接
    chat_history = sample["chat_history"]
    chat_history = '###'.join(chat_history)
    chat_history += '###'
    
    # 获取当前查询
    query = sample["query"]
    
    # 拼接完整的输入:系统提示 + 检索记忆 + 聊天历史 + 查询
    input = self.system_prompt + "###" + retrieved_memory + chat_history + query
    
    # 构建返回的数据字典,包含模型训练所需的全部信息
    data = {
        "query": sample["query"],                    # 原始查询文本
        "system_prompt": self.system_prompt,        # 系统提示词
        "retrieved_memory": retrieved_memory,        # 检索到的记忆内容
        "chat_history": chat_history,                # 历史对话记录
        "answer": sample["answer"],                 # 期望的模型回答
        "embedding": query_embed,                    # 查询的向量表示
        "source": sample["source"],                  # 数据来源标识
        "input": input                               # 完整的模型输入
    }
    return data

CharacterDataset.join_with_limit

该方法用于将字符串列表中的元素逐个连接到一个结果字符串中,直到达到指定的最大长度限制为止。如果加入当前元素会导致超过最大长度,则停止添加剩余元素。

参数:

  • self:CharacterDataset,类的实例本身
  • itemsList[str],需要连接的字符串列表
  • max_lengthint,连接后的最大长度限制
  • separatorstr,连接元素时使用的分隔符,默认为"###"

返回值:str,连接后的字符串,当总长度不超过 max_length 时返回所有items的连接结果,否则返回满足长度限制的前缀部分

流程图

flowchart TD
    A[开始 join_with_limit] --> B[初始化 result = ""]
    B --> C{遍历 items 中的每个 item}
    C --> D{检查加入当前item是否超过max_length}
    D -->|是| E{result是否为空}
    D -->|否| F[将separator加到result]
    F --> G[将item加到result]
    G --> C
    E -->|否| H[直接添加item到result]
    H --> G
    E -->|是| I[break 退出循环]
    I --> J[返回 result]
    C -->|遍历完成| J
Loading

带注释源码

def join_with_limit(self, items, max_length, separator="###"):
    """
    将字符串列表中的元素逐个连接,直到达到最大长度限制
    
    参数:
        items: 需要连接的字符串列表
        max_length: 最大长度限制
        separator: 分隔符,默认为"###"
    
    返回:
        连接后的字符串,不超过max_length
    """
    result = ""  # 初始化结果字符串为空
    for item in items:  # 遍历每个元素
        # 如果加入当前元素不会导致超过上限,就将其添加到结果字符串中
        if len(result) + len(item) + len(separator) <= max_length:
            if result:  # 如果结果非空,先添加分隔符
                result += separator
            result += item  # 添加当前元素
        else:
            break  # 如果已经超过上限,就停止添加新元素
    return result  # 返回连接后的结果

CharacterDataset.read_file_content

该方法用于读取指定路径的文本文件内容,并返回完整的字符串形式。

参数:

  • file_pathstr,需要读取的文件的路径(包含文件名)

返回值:str,返回文件的全部内容

流程图

flowchart TD
    A[开始] --> B[打开文件<br/>file_path with 'r' mode<br/>encoding='utf-8']
    B --> C[读取文件全部内容<br/>f.read()]
    C --> D[返回文件内容<br/>content]
    D --> E[结束]
Loading

带注释源码

def read_file_content(self, file_path):
    """
    读取指定路径的文件内容
    
    参数:
        file_path: str, 需要读取的文件的路径
    
    返回:
        str, 文件的完整内容
    """
    # 使用 utf-8 编码打开文件,'r' 模式表示只读
    with open(file_path, 'r', encoding='utf-8') as f:
        # 读取文件的全部内容并存储到 content 变量
        content = f.read()
    # 返回读取到的文件内容
    return content

CharacterDataset.read_jsonl_and_convert_to_tensor

该方法用于读取指定路径的 JSONL 格式记忆文件,解析其中的嵌入向量数据和文本数据,将嵌入向量列表转换为 PyTorch 张量,并与原始文本数据一起返回。

参数:

  • file_pathstr,要读取的 JSONL 文件路径,该文件包含记忆的嵌入向量和对应文本

返回值:tuple,返回一个元组,包含 torch.Tensor 类型的嵌入向量张量和 list 类型的文本数据列表

流程图

flowchart TD
    A[开始] --> B[初始化空列表 embed_list 和 data_list]
    B --> C[使用 jsonlines 打开文件]
    C --> D{遍历 reader 中的每一项}
    D -->|是| E[获取 item 的值并添加到 embed_list]
    E --> F[获取 item 的键并添加到 data_list]
    F --> D
    D -->|否| G[将 embed_list 转换为 torch.Tensor]
    G --> H[返回 data_tensor 和 data_list]
    H --> I[结束]
Loading

带注释源码

def read_jsonl_and_convert_to_tensor(self, file_path):
    """
    读取 JSONL 文件并转换为 PyTorch 张量
    :param file_path: str, JSONL 文件路径
    :return: tuple, (torch.Tensor, list) 嵌入向量张量和文本列表
    """
    # 初始化用于存储嵌入向量的列表
    embed_list = []
    # 初始化用于存储文本数据的列表
    data_list = []
    
    # 使用 jsonlines 库打开文件进行逐行读取
    with jsonlines.open(file_path) as reader:
        # 遍历文件中的每一项
        for item in reader:
            # 获取字典中的第一个值(嵌入向量)并添加到列表
            embed_list.append(next(iter(item.values())))
            # 获取字典中的第一个键(文本数据)并添加到列表
            data_list.append(next(iter(item.keys())))
    
    # 将嵌入向量列表转换为 PyTorch 张量
    data_tensor = torch.tensor(embed_list)
    
    # 返回张量形式的嵌入向量和文本列表
    return data_tensor, data_list

CharacterDataset.getSystemPrompt

该方法用于获取角色扮演系统的系统提示词。它通过拼接角色路径和系统提示文件名,然后读取该文件的内容并返回。

参数:

  • 该方法无显式参数(隐式参数 self 为类的实例)

返回值:str,返回系统提示文件的内容字符串

流程图

flowchart TD
    A[开始 getSystemPrompt] --> B[构建文件路径]
    B --> C{调用 read_file_content}
    C --> D{打开并读取文件}
    D --> E[返回文件内容字符串]
    E --> F[结束]
Loading

带注释源码

def getSystemPrompt(self):
    """
    获取系统提示词内容
    
    该方法通过读取角色目录下预定义的 system_prompt.txt 文件来获取
    角色的系统提示词,用于在对话开始时设定角色的行为特征和背景
    
    Returns:
        str: 系统提示词文件的内容
    """
    # 拼接完整的文件路径:角色目录路径 + 系统提示文件名
    file_path = os.path.join(self.character_path, self.system_prompt_name)
    
    # 调用文件读取方法获取文件内容
    file_content = self.read_file_content(file_path)
    
    # 返回读取到的系统提示词内容
    return file_content

CharacterDataset.getMemory

该方法通过计算输入向量与预存记忆嵌入之间的相似度,利用矩阵乘积找出最相似的记忆条目,并返回其索引,用于检索最相关的记忆内容。

参数:

  • vectortorch.Tensor,输入查询的嵌入向量,用于与记忆库中的嵌入向量进行相似度计算

返回值:torch.Tensor,返回最相似的记忆条目的索引,索引数量由 self.memory_number 决定

流程图

flowchart TD
    A[开始 getMemory] --> B[输入: vector 查询向量]
    B --> C[执行矩阵乘法: similarity = torch.matmul self.memory_embed, vector]
    C --> D[使用 torch.topk 找出最高相似度的 top_indices]
    D --> E[返回 top_indices]
    E --> F[结束]
    
    subgraph 内部状态
        C -.->|存储在| G[similarity: 相似度张量]
        D -.->|存储在| H[top_indices: 最高相似度的索引]
    end
Loading

带注释源码

def getMemory(self, vector):
    """
    根据输入的查询向量,从记忆库中检索最相似的记忆索引
    
    参数:
        vector: torch.Tensor,查询向量,通常是当前对话或问题的embedding
    
    返回:
        torch.Tensor,包含最相似记忆的索引
    """
    pass  # 占位符,无实际作用
    # 计算输入向量与所有记忆嵌入的相似度(点积)
    # self.memory_embed: 所有记忆的嵌入向量 [num_memories, embed_dim]
    # vector: 查询向量 [embed_dim] 或 [batch, embed_dim]
    similarity = torch.matmul(self.memory_embed, vector)

    # 找出最大的几个元素的索引
    # self.memory_number: 要检索的记忆数量
    # 使用 torch.topk 获取相似度最高的 top_memory_number 个记忆的索引
    top_indices = torch.topk(similarity, self.memory_number).indices

    return top_indices  # 返回最相似记忆的索引列表
    # print("Top indices:", top_indices)  # 调试打印(已注释)

关键组件

角色对话数据集加载与处理

该代码实现了一个基于角色记忆检索的对话数据集处理系统,支持从多个JSONL文件加载角色对话数据,通过嵌入向量相似度检索相关记忆,并将处理后的数据合并推送至HuggingFace Hub。

张量索引与惰性加载

代码中通过torch.topk对记忆嵌入向量进行相似度检索,返回top_k个最相似的记忆索引,实现高效的向量相似度匹配。CharacterDataset类在初始化时预加载所有记忆嵌入,在__getitem__时仅按需检索相关记忆,实现惰性加载模式。

反量化支持

虽然代码中未直接使用量化,但通过torch.tensor直接处理嵌入向量,支持浮点型张量运算,为后续量化模型推理提供数据兼容性基础。

量化策略

当前实现未包含量化策略,所有嵌入和特征均以float32存储,可后续通过动态量化或静态量化优化推理性能。

数据流程管道

代码构建了完整的数据管道:read_jsonl_file读取原始数据 -> CharacterDataset处理转换 -> DataLoader批处理 -> concatenate_datasets合并多数据集 -> push_to_hub推送至云端。

记忆检索机制

getMemory方法通过矩阵乘法计算查询向量与记忆库的相似度,使用torch.topk获取最相关的记忆条目,支持自定义检索数量配置。

系统提示词管理

CharacterDataset通过getSystemPrompt方法从文件系统动态加载角色专属的系统提示词,实现角色性格和行为模式的定制化配置。

记忆文本拼接与长度控制

join_with_limit方法实现了带长度约束的记忆文本拼接,确保输入长度不超过指定阈值,避免模型输入溢出。

问题及建议

已知问题

  • 硬编码的敏感信息:HF_TOKEN直接明文写在代码中,存在严重的安全风险
  • 硬编码的路径和配置:memory_path="jsonl/title_text_embed.jsonl"、system_prompt_name="system_prompt.txt"、分隔符"|||"等均硬编码在类中,降低了灵活性
  • 大量未清理的注释代码:代码中存在大量被注释掉的旧逻辑和测试代码(如file_names列表、数据处理代码块),严重影响代码可读性和维护性
  • 未使用或冗余的导入:导入了login但实际未使用;导入了ConcatDataset、nn等但未使用;被注释掉的load_dataset等导入仍然保留
  • 代码风格不一致:混合使用下划线命名(snake_case)和驼峰命名(camelCase),如join_with_limit vs read_jsonl_and_convert_to_tensor
  • 缺乏类型提示:整个代码没有任何类型注解(type hints),降低了代码可维护性和IDE支持
  • 不完整的函数实现:getMemory方法中开头有pass语句但后面还有代码,语法冗余且风格不佳
  • 文件操作缺乏异常处理:read_file_content、read_jsonl_and_convert_to_tensor、getSystemPrompt等文件读写操作没有try-except保护
  • 资源未正确管理:read_jsonl_and_convert_to_tensor使用jsonlines.open()但未使用with语句,可能导致资源泄漏
  • 内存效率问题:read_jsonl_and_convert_to_tensor一次性将所有embedding加载到内存,对于大规模数据集可能导致OOM
  • 数据加载效率低下:collate_fn返回嵌套列表而非 tensors,训练时需要额外转换;embedding在__getitem__中每次重新转换为tensor,未做缓存
  • 缺失的数据校验:__init__和__getitem__中未对json_data的关键字段进行校验,可能导致运行时KeyError
  • 魔法数字:memory_number=8、memory_length=2000等参数直接传入,缺乏配置说明

优化建议

  • 安全改进:将HF_TOKEN移除代码,使用环境变量或配置中心管理敏感信息
  • 配置外部化:将路径、分隔符等配置抽取到配置文件或环境变量,支持不同角色的数据源
  • 代码清理:删除所有注释掉的无效代码块,或移至独立的废弃代码目录
  • 精简导入:移除未使用的导入语句,保持导入列表整洁
  • 统一命名规范:全项目统一使用snake_case命名风格
  • 添加类型提示:为所有函数参数、返回值添加类型注解,提升代码可读性
  • 移除冗余语法:删除getMemory方法开头的pass语句
  • 增强异常处理:为所有文件IO操作添加try-except,优雅处理文件不存在、格式错误等异常
  • 使用上下文管理器:确保所有文件操作使用with语句管理资源
  • 优化内存使用:考虑使用内存映射或分批加载embedding;将embedding在__init__中预先转换为tensor并缓存
  • 优化数据加载:collate_fn中直接返回tensor;使用pin_memory加速GPU数据传输
  • 数据校验:在__init__中校验json_data必需字段,在__getitem__中添加默认值或错误处理
  • 配置参数化:将magic number改为具名常量或配置参数,增加代码可读性

其它

设计目标与约束

本代码旨在构建一个用于训练角色扮演对话模型的PyTorch数据集处理流水线。核心目标是将多个角色的对话数据(JSONL格式)进行处理,生成包含系统提示、角色记忆、聊天历史和查询的完整训练样本。主要约束包括:1) 内存嵌入向量维度必须与预训练模型一致;2) 检索的记忆数量(memory_number)默认为8条;3) 记忆文本最大长度(memory_length)限制为2000字符;4) 数据来源为HuggingFace Hub上的两个数据集(silk-road/Chat_Suzumiya_Fusion 和 silk-road/Chat_Suzumiya_Fusion_B)。

错误处理与异常设计

代码中的错误处理主要包括:1) JSON解析错误处理,在read_jsonl_file函数中使用try-except捕获json.JSONDecodeError并跳过无效行;2) 文件读取错误处理,read_file_content和read_jsonl_and_convert_to_tensor函数中使用with open确保文件正确关闭;3) 嵌入向量维度不匹配时,getMemory函数中的矩阵乘法运算会自动处理,但需要在实际使用前验证维度一致性。当前代码缺少对以下异常的处理:文件路径不存在时的FileNotFoundError、数据集为空时的空数据集检查、内存嵌入为空时的除零错误、HuggingFace Token无效时的认证失败处理。

数据流与状态机

整体数据流如下:1) 首先通过read_jsonl_file读取原始JSONL对话数据;2) 创建CharacterDataset实例,加载角色系统提示和记忆嵌入;3) 在__getitem__中,对每个样本进行以下处理:根据查询嵌入检索top-k记忆片段→拼接记忆文本→格式化聊天历史→组装完整输入文本;4) 使用DataLoader进行批处理,通过collate_fn整理输入输出;5) 最后将处理后的数据集上传至HuggingFace Hub。状态转换包括:JSONL文件→JSON对象列表→Dataset对象→DataLoader迭代器→字典格式的训练样本。

外部依赖与接口契约

主要外部依赖包括:1) torch - 张量运算和神经网络基础;2) json/jsonlines - JSON格式解析;3) transformers - AutoTokenizer和AutoModel用于嵌入计算;4) huggingface_hub - 数据集上传和登录认证;5) datasets - HuggingFace数据集库;6) os - 文件系统操作。接口契约方面:read_jsonl_file接收文件路径字符串返回Python列表;CharacterDataset接收json_data列表、character_path字符串、memory_number整数、memory_length整数;collate_fn接收batch列表返回(inputs, targets)元组;get_embedding和download_models函数来自utils模块需预先定义。

配置与参数说明

核心配置参数包括:HF_TOKEN - HuggingFace访问令牌用于数据集上传;batch_size - 训练批次大小(代码中设为1);memory_number - 每次检索的记忆数量(CharacterDataset初始化时传入8);memory_length - 记忆文本最大字符数(初始化时传入2000);memory_path - 记忆文件相对路径(固定为"jsonl/title_text_embed.jsonl");system_prompt_name - 系统提示文件名(固定为"system_prompt.txt")。数据集配置包括:train_dataset_dict_A和train_dataset_dict_B分别从Hub加载两个原始数据集,最终通过concatenate_datasets合并为统一训练集。

性能考虑与优化空间

当前实现存在以下性能瓶颈:1) 每次__getitem__调用都会执行torch.matmul进行相似度计算,可预先计算所有查询与记忆的相似度矩阵;2) 字符串拼接操作(join_with_limit)使用循环拼接而非列表+join,效率较低;3) 读取jsonlines时逐行处理,可使用批量读取;4) 未使用多进程加载数据,DataLoader的num_workers参数未设置;5) 嵌入向量存储在内存中,大规模数据时需考虑内存优化。优化建议:使用Faiss或Annoy等近似最近邻库加速检索;将相似度矩阵预计算并缓存;使用__init__中的批量加载替代逐行读取;启用DataLoader的pin_memory和num_workers参数。

安全考虑

代码涉及的安全问题包括:1) HF_TOKEN硬编码在代码中,应使用环境变量或配置文件管理敏感信息;2) 文件路径拼接使用os.path.join但未验证路径有效性,存在路径遍历风险;3) 读取JSONL文件时未限制单行最大长度,可能遭受恶意大文件攻击;4) 上传数据集至Hub前未进行敏感信息扫描。建议措施:将Token移至环境变量;添加路径验证逻辑;设置单行读取长度上限;添加数据脱敏流程。

测试策略建议

建议添加以下测试用例:1) 单元测试 - 测试read_jsonl_file对空文件、格式错误文件的处理;测试CharacterDataset的__len__返回值正确性;测试getMemory返回索引在有效范围内;测试join_with_limit的边界条件。2) 集成测试 - 测试完整的数据加载流程;测试与DataLoader的兼容性;测试数据集拼接正确性。3) 性能测试 - 测试大规模数据加载时间;测试内存占用情况。4) 回归测试 - 验证数据集上传Hub后的可加载性。

部署注意事项

部署时需注意:1) 环境配置 - 确保Python版本≥3.8,PyTorch版本与CUDA版本兼容;2) 依赖管理 - 建议使用requirements.txt或pyproject.toml管理依赖版本;3) 资源预估 - 需预估内存使用量(原始数据+嵌入+中间变量);4) 路径配置 - character_path、jsonl_file_path等路径需根据部署环境调整;5) 数据隐私 - 上传Hub的数据集需确认不包含敏感信息;6) 监控告警 - 建议添加数据加载失败和异常处理的日志记录。