该代码通过Hugging Face Hub加载两个角色对话数据集(Chat_Suzumiya_Fusion和Chat_Suzumiya_Fusion_B),使用concatenate_datasets函数合并后上传到新的Hub仓库,实现多数据集的融合与共享。
graph TD
A[开始] --> B[登录Hugging Face Hub]
B --> C[从Hub加载数据集A: silk-road/Chat_Suzumiya_Fusion]
C --> D[从Hub加载数据集B: silk-road/Chat_Suzumiya_Fusion_B]
D --> E[提取train子集]
E --> F[使用concatenate_datasets合并数据集]
F --> G[上传合并后的数据集到Hub: silk-road/Chat-Haruhi-Fusion-A_B]
G --> H[结束]
模块级
├── read_jsonl_file (全局函数)
├── collate_fn (全局函数)
├── CharacterDataset (自定义Dataset类)
│ ├── __init__
│ ├── __len__
│ ├── __getitem__
│ ├── join_with_limit
│ ├── read_file_content
│ ├── read_jsonl_and_convert_to_tensor
│ ├── getSystemPrompt
│ └── getMemoryHuggingFace API token for authentication
类型:str
List of JSON objects loaded from dialogue JSONL files
类型:List[dict]
Path to the character directory containing character-specific files
类型:str
Number of top-k memory entries to retrieve based on similarity
类型:int
Relative path to the memory embedding JSONL file (jsonl/title_text_embed.jsonl)
类型:str
Filename of the system prompt file (system_prompt.txt)
类型:str
Content of the character-specific system prompt
类型:str
Tensor containing pre-computed embeddings for character memories
类型:torch.Tensor
List of text content corresponding to memory embeddings
类型:List[str]
Maximum character length for concatenated retrieved memories
类型:int
该函数是一个用于读取JSONL(JSON Lines)文件的工具函数,它逐行读取文件内容,解析每行的JSON数据,并将所有有效的JSON对象存储在列表中返回,同时具备基本的错误处理能力以跳过格式不正确的行。
参数:
file_path:str,要读取的JSONL文件的路径
返回值:List[dict],包含文件中所有成功解析的JSON对象的列表
flowchart TD
A[开始] --> B[初始化空列表 data]
B --> C[打开文件 file_path]
C --> D[读取下一行]
D --> E{是否还有更多行?}
E -->|否| F[返回 data 列表]
E -->|是| G[去除行首尾空白]
G --> H{行是否为空?}
H -->|是| D
H -->|否| I[尝试解析JSON]
I --> J{解析是否成功?}
J -->|是| K[将JSON对象添加到data列表]
K --> D
J -->|否| L[跳过该行]
L --> D
F --> M[结束]
def read_jsonl_file(file_path):
"""
读取JSONL文件并解析其中的JSON对象
参数:
file_path: 要读取的JSONL文件路径
返回值:
包含文件中所有JSON对象的列表
"""
# 初始化一个空列表用于存储解析后的JSON数据
data = []
# 以只读模式打开文件,使用UTF-8编码
with open(file_path, 'r', encoding='utf-8') as f:
# 遍历文件的每一行
for line in f:
# 去除行首尾的空白字符
line = line.strip()
# 只处理非空行
if line:
try:
# 尝试将行解析为JSON格式
json_data = json.loads(line)
# 解析成功,将JSON对象添加到列表中
data.append(json_data)
except json.JSONDecodeError:
# 解析失败(JSON格式不正确),跳过该行
# 可以选择打印错误信息用于调试
# print(f"Failed to parse JSON: {line}")
continue
# 返回包含所有JSON对象的列表
return data| 名称 | 一句话描述 |
|---|---|
json 模块 |
Python标准库,用于JSON数据的解析和编码 |
jsonlines 模块 |
后续代码中使用的JSONL文件读取库,可逐行高效读取 |
CharacterDataset 类 |
自定义PyTorch数据集,用于加载角色对话数据 |
- 缺少文件存在性检查:函数未检查文件路径是否存在或是否为有效文件,可能导致
FileNotFoundError - 编码硬编码:UTF-8编码被硬编码,无法处理其他编码格式的JSONL文件
- 错误处理不够细致:仅捕获
JSONDecodeError,对于其他可能的IO错误没有处理 - 大文件处理效率:对于非常大的JSONL文件,一次性加载所有数据到内存中可能导致内存问题,应考虑使用生成器模式或流式处理
- 返回空列表的歧义:当文件不存在或所有行都解析失败时,都返回空列表,调用者无法区分具体情况
- 设计目标:提供一个简单可靠的工具函数,用于将JSONL文件转换为Python列表数据结构
- 约束:仅处理标准的JSONL格式(每行一个独立的JSON对象)
- 当前仅捕获
json.JSONDecodeError异常,跳过格式错误的行 - 建议增强:可添加文件读取异常的处理,如
FileNotFoundError、PermissionError等 - 建议增强:可提供错误回调机制,让调用者决定如何处理解析失败的行
- 依赖:Python标准库
json和builtins.open - 接口契约:调用方需传入有效的文件路径,返回值为列表类型
该函数是PyTorch DataLoader的回调函数,负责将数据集__getitem__返回的单个样本整理成一个批次。它从批次中提取input和answer字段,分别组成输入列表和目标列表,供后续模型训练使用。
参数:
batch:List[Dict],从DataLoader获取的批次数据,每个元素是包含input和answer字段的字典
返回值:Tuple[List[str], List[str]],返回两个字符串列表——第一个是输入列表,第二个是目标答案列表
flowchart TD
A[接收batch参数] --> B{检查batch是否为空}
B -->|是| C[返回空列表元组]
B -->|否| D[遍历batch中的每个sample]
D --> E[提取sample的input字段]
D --> F[提取sample的answer字段]
E --> G[构建inputs列表]
F --> H[构建targets列表]
G --> I[返回inputs和targets元组]
H --> I
def collate_fn(batch):
"""
DataLoader的整理函数,将单个样本整理成批次
参数:
batch: DataLoader返回的样本列表,每个样本是包含
'input'和'answer'字段的字典
返回:
(inputs, targets): 输入列表和目标答案列表的元组
"""
# 从每个样本中提取input字段,构建输入列表
inputs = [sample["input"] for sample in batch]
# 从每个样本中提取answer字段,构建目标列表
targets = [sample["answer"] for sample in batch]
# 注释: 以下是使用torch.stack的原始实现,但需要tensor类型
# 当前实现直接返回列表,兼容字符串类型的数据
# batch_inputs = torch.stack(inputs)
# batch_targets = torch.stack(targets)
# 返回输入和目标列表元组
return inputs, targets该函数是 Hugging Face Hub 库提供的身份验证函数,用于将本地的 Hugging Face 访问令牌(Token)进行登录认证,以便后续能够访问私有模型、数据集或其他需要认证的资源。
参数:
token:str,Hugging Face 的访问令牌(Personal Access Token),用于身份验证add_to_git_credential:Optional[bool],可选参数,是否将令牌添加到 git 凭证中,默认为 Falsequiet:Optional[bool],可选参数,是否静默执行,默认为 False
返回值:None,该函数无返回值,执行成功后直接完成登录状态
graph TD
A[开始 login] --> B{检查 token 是否有效}
B -->|无效| C[抛出异常: 无效的 token]
B -->|有效| D{add_to_git_credential 为 True?}
D -->|是| E[将 token 添加到 git credential]
D -->|否| F{quiet 为 True?}
F -->|是| G[静默完成登录]
F -->|否| H[输出登录成功信息]
E --> I[完成登录状态]
G --> I
H --> I
I[结束 login]
# 从 huggingface_hub 库导入 login 函数
from huggingface_hub import login
# 定义 Hugging Face 访问令牌
# 该令牌用于身份验证,可从 https://huggingface.co/settings/tokens 获取
HF_TOKEN = "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"
# 调用 login 函数进行身份验证
# 参数 token: 传入访问令牌进行认证
# 作用: 将本地环境与 Hugging Face 账号绑定,后续可访问私有模型/数据集
login(token=HF_TOKEN)注意:该 login 函数来源于外部库 huggingface_hub,并非在本项目代码中定义。以上信息基于 Hugging Face Hub 官方文档和常见用法。如需查看完整源码实现,请参考 huggingface_hub 官方仓库。
经过分析,代码中 load_dataset 是从 Hugging Face datasets 库导入的外部函数,并非在该代码文件中定义。以下是在代码中的使用情况:
从 Hugging Face datasets 库导入的函数,用于从 Hub 加载数据集。
参数:
path:str,Hugging Face Hub 上的数据集名称或本地路径
返回值:datasets.DatasetDict,返回包含数据集的字典,通常包含 'train'、'test' 等拆分
graph TD
A[调用 load_dataset] --> B{检查本地缓存}
B -->|本地无缓存| C[从 Hugging Face Hub 下载数据集]
B -->|本地有缓存| D[从本地加载数据集]
C --> E[返回 DatasetDict 对象]
D --> E
# 从 Hugging Face datasets 库导入 load_dataset 函数
# 这是一个外部依赖函数,非本代码文件定义
from datasets import load_dataset, Dataset
# 加载两个预训练数据集
# 参数:数据集在 Hugging Face Hub 上的路径
train_dataset_dict_A = load_dataset('silk-road/Chat_Suzumiya_Fusion')
train_dataset_dict_B = load_dataset('silk-road/Chat_Suzumiya_Fusion_B')
# 从 DatasetDict 中提取 'train' 拆分
train_dataset_A = train_dataset_dict_A['train']
train_dataset_B = train_dataset_dict_B['train']
# 合并两个数据集
train_dataset = concatenate_datasets([train_dataset_A, train_dataset_B])
# 将合并后的数据集推送到 Hub
train_dataset.push_to_hub("silk-road/Chat-Haruhi-Fusion-A_B")注意:代码中未定义
load_dataset函数,而是从datasets库导入的外部函数。如需了解该函数的完整文档,建议参考 Hugging Face datasets 库官方文档。
该函数用于将多个 Hugging Face Dataset 对象沿第一个轴(行方向)拼接成一个数据集,是 Hugging Face datasets 库提供的工具函数。
参数:
datasets:List[Dataset],要拼接的数据集列表
返回值:Dataset,拼接后的数据集对象
flowchart TD
A[开始] --> B[加载数据集A silk-road/Chat_Suzumiya_Fusion]
B --> C[加载数据集B silk-road/Chat_Suzumiya_Fusion_B]
C --> D[提取train数据集]
D --> E[调用concatenate_datasets合并]
E --> F[推送合并后的数据集到Hub]
F --> G[结束]
# 从HuggingFace datasets库导入concatenate_datasets函数
from datasets import load_dataset, concatenate_datasets
# 加载两个预训练数据集
# 数据集A: 包含角色对话数据
train_dataset_dict_A = load_dataset('silk-road/Chat_Suzumiya_Fusion')
# 数据集B: 包含另一个角色对话数据集合
train_dataset_dict_B = load_dataset('silk-road/Chat_Suzumiya_Fusion_B')
# 从DatasetDict中提取训练集
train_dataset_A = train_dataset_dict_A['train']
train_dataset_B = train_dataset_dict_B['train']
# 使用concatenate_datasets函数沿行方向拼接两个数据集
# 输入: [dataset_A, dataset_B] - 需要拼接的数据集列表
# 输出: 合并后的单一数据集对象
train_dataset = concatenate_datasets([train_dataset_A, train_dataset_B])
# 将合并后的数据集推送到HuggingFace Hub
train_dataset.push_to_hub("silk-road/Chat-Haruhi-Fusion-A_B")该函数是从 utils 模块导入的,但在当前提供的代码文件中仅作为导入语句出现,没有实际的实现代码。根据代码上下文推断,该函数可能用于获取文本的嵌入向量(embedding),但在当前代码中未被调用。
参数:
- 无法从当前代码中确定参数信息(函数未在此文件中定义)
返回值:
- 无法从当前代码中确定返回值信息(函数未在此文件中定义)
graph TD
A[开始] --> B[调用 utils.get_embedding]
B --> C[函数定义未知]
C --> D[结束]
# 当前代码文件中仅包含导入语句
from utils import get_embedding, download_models
# get_embedding 函数的实际实现在 utils 模块中
# 但该模块未在当前代码文件中提供
# 根据函数名推测,该函数可能用于:
# 1. 接收文本输入
# 2. 使用预训练模型(如 BERT, Sentence-BERT 等)生成文本嵌入向量
# 3. 返回嵌入向量供后续相似度计算或检索使用注意:当前提供的代码文件中,get_embedding 函数仅作为导入语句出现,没有实际的函数定义。该函数的具体参数、返回值和实现逻辑需要查看 utils 模块的源代码。
该函数为从 utils 模块导入的外部函数,用于下载所需的预训练模型。根据代码上下文推断,该函数应负责下载 Hugging Face 上的预训练语言模型(如 LLM、embedding 模型等),以支持后续的推理或微调任务。
注意:该函数的完整实现未在当前代码文件中定义,而是通过 from utils import download_models 导入。以下信息基于代码上下文推断得出。
参数:
- (无法从当前代码中确定具体参数,推测可能包含模型名称、存储路径等参数)
返回值:
- (无法从当前代码中确定具体返回值,推测可能返回下载的模型对象或模型路径)
由于该函数未在当前代码中实现,无法绘制详细的流程图。基于函数名称和项目上下文,推测其流程如下:
flowchart TD
A[开始下载模型] --> B{检查模型是否已存在}
B -->|是| C[返回本地模型路径]
B -->|否| D[从Hugging Face Hub下载模型]
D --> E{下载是否成功}
E -->|成功| F[保存到指定目录]
E -->|失败| G[抛出异常或重试]
F --> H[返回模型路径/对象]
# 该函数定义在 utils 模块中,当前代码文件仅导入该函数
from utils import get_embedding, download_models
# 函数调用示例(在当前文件中未实际调用)
# download_models(...)
# 推测该函数可能的使用方式:
# - 参数:model_name, save_path, tokenizer 等
# - 返回:model, tokenizer 或模型路径字符串- utils 模块:包含
get_embedding和download_models两个工具函数 - huggingface_hub:用于模型下载和数据集访问
- 代码组织问题:
download_models函数的实现未被包含在当前代码中,导致无法进行详细分析 - 硬编码凭证:HuggingFace Token 直接以字符串形式硬编码在代码中(
HF_TOKEN),存在安全风险 - 缺乏错误处理:未看到对模型下载失败情况的处理逻辑
若需要完整的函数分析,建议提供 utils.py 文件的内容,或在项目中查看该函数的实际实现代码。
这是 CharacterDataset 类的构造函数,用于初始化角色对话数据集。该方法接收训练数据、角色路径、记忆数量和记忆长度等参数,加载系统提示词和角色的记忆嵌入向量,为后续的数据检索和模型训练做好准备。
参数:
self:隐式参数,表示实例本身json_data:List[Dict],从 JSONL 文件读取的对话数据列表,每个元素包含 query、answer、chat_history、embedding、source 等字段character_path:str,角色文件夹的路径,用于定位系统提示词和记忆向量文件memory_number:int,从记忆库中检索的最相似记忆数量,用于构建上下文memory_length:int,检索记忆的最大字符长度限制,用于控制上下文长度
返回值:None,构造函数无返回值
graph TD
A[开始 __init__] --> B[接收参数: json_data, character_path, memory_number, memory_length]
B --> C[设置 self.data = json_data]
C --> D[设置 self.character_path = character_path]
D --> E[设置 self.memory_number = memory_number]
E --> F[设置 self.memory_path = 'jsonl/title_text_embed.jsonl']
F --> G[设置 self.system_prompt_name = 'system_prompt.txt']
G --> H[调用 getSystemPrompt 读取系统提示词]
H --> I[设置 self.system_prompt]
I --> J[调用 read_jsonl_and_convert_to_tensor 加载记忆嵌入]
J --> K[设置 self.memory_embed 和 self.memory_text]
K --> L[设置 self.memory_length = memory_length]
L --> M[结束 __init__]
def __init__(self, json_data, character_path, memory_number, memory_length):
"""
初始化 CharacterDataset 实例
参数:
json_data: 从 JSONL 文件解析的对话数据列表
character_path: 角色相关文件的根目录路径
memory_number: 每次检索返回的记忆条数
memory_length: 记忆文本的最大字符长度
"""
pass # 占位符,可能是之前预留的代码位置
# 存储原始对话数据
self.data = json_data
# 存储角色路径,后续用于加载系统提示词和记忆文件
self.character_path = character_path
# 存储要检索的记忆数量,用于 getMemory 方法
self.memory_number = memory_number
# 记忆文件的相对路径(相对于 character_path)
self.memory_path = "jsonl/title_text_embed.jsonl"
# 系统提示词文件名
self.system_prompt_name = "system_prompt.txt"
# 读取并存储系统提示词内容
# 系统提示词定义了角色的行为方式和对话风格
self.system_prompt = self.getSystemPrompt()
# 加载角色记忆的嵌入向量和原始文本
# memory_embed: 用于计算相似度的向量矩阵
# memory_text: 记忆的原始文本内容
self.memory_embed, self.memory_text = self.read_jsonl_and_convert_to_tensor(
os.path.join(self.character_path, self.memory_path)
)
# 存储记忆文本的最大长度限制
# 用于在 join_with_limit 方法中截断过长的记忆
self.memory_length = memory_length该方法为 PyTorch Dataset 类的标准长度方法,返回数据集中包含的样本总数,供 DataLoader 等工具在迭代时获取数据集大小。
参数:无(self 为隐式参数,表示类实例本身)
返回值:int,返回数据集 self.data 中包含的样本数量
flowchart TD
A[开始 __len__] --> B[返回 self.data 的长度]
B --> C[结束]
def __len__(self):
"""
返回数据集的样本数量
Returns:
int: 数据集包含的样本数量
"""
return len(self.data) # 返回存储在 self.data 中的数据列表的长度该方法是PyTorch数据集类的核心接口,根据给定索引从数据集中获取样本,并通过检索增强生成模型输入。它首先根据查询的embedding向量从记忆库中检索最相关的记忆,然后拼接系统提示、检索记忆、聊天历史和当前查询,形成完整的模型输入上下文。
参数:
idx:int,要获取的样本索引,用于从数据集中定位特定的训练样本
返回值:dict,包含预处理后训练数据的字典,包含query、system_prompt、retrieved_memory、chat_history、answer、embedding、source和input等字段
flowchart TD
A[开始 __getitem__] --> B[根据 idx 从 self.data 获取样本]
B --> C[将样本的 embedding 转换为 PyTorch Tensor]
C --> D[调用 getMemory 检索相关记忆]
D --> E[从 memory_text 中提取 top_k 对应的记忆文本]
E --> F[使用 join_with_limit 限制记忆长度]
F --> G[拼接 chat_history 为字符串]
G --> H[拼接 system_prompt + retrieved_memory + chat_history + query]
H --> I[构建返回字典 data]
I --> J[返回 data 字典]
def __getitem__(self, idx):
# 根据索引从数据列表中获取原始样本
sample = self.data[idx]
# 将样本中的 embedding 转换为 PyTorch 张量
query_embed = torch.tensor(sample["embedding"])
# 根据查询向量从记忆库中检索最相似的记忆
top_k = self.getMemory(query_embed)
# 从记忆文本列表中提取 top_k 索引对应的记忆内容
# 记忆文本格式为 "key|||value",这里取 value 部分
retrieved_memory = [self.memory_text[i].split("|||")[1] for i in top_k]
# 使用 join_with_limit 方法限制记忆的总长度
retrieved_memory = self.join_with_limit(retrieved_memory, self.memory_length)
# 获取聊天历史并使用 "###" 作为分隔符拼接
chat_history = sample["chat_history"]
chat_history = '###'.join(chat_history)
chat_history += '###'
# 获取当前查询
query = sample["query"]
# 拼接完整的输入:系统提示 + 检索记忆 + 聊天历史 + 查询
input = self.system_prompt + "###" + retrieved_memory + chat_history + query
# 构建返回的数据字典,包含模型训练所需的全部信息
data = {
"query": sample["query"], # 原始查询文本
"system_prompt": self.system_prompt, # 系统提示词
"retrieved_memory": retrieved_memory, # 检索到的记忆内容
"chat_history": chat_history, # 历史对话记录
"answer": sample["answer"], # 期望的模型回答
"embedding": query_embed, # 查询的向量表示
"source": sample["source"], # 数据来源标识
"input": input # 完整的模型输入
}
return data该方法用于将字符串列表中的元素逐个连接到一个结果字符串中,直到达到指定的最大长度限制为止。如果加入当前元素会导致超过最大长度,则停止添加剩余元素。
参数:
self:CharacterDataset,类的实例本身items:List[str],需要连接的字符串列表max_length:int,连接后的最大长度限制separator:str,连接元素时使用的分隔符,默认为"###"
返回值:str,连接后的字符串,当总长度不超过 max_length 时返回所有items的连接结果,否则返回满足长度限制的前缀部分
flowchart TD
A[开始 join_with_limit] --> B[初始化 result = ""]
B --> C{遍历 items 中的每个 item}
C --> D{检查加入当前item是否超过max_length}
D -->|是| E{result是否为空}
D -->|否| F[将separator加到result]
F --> G[将item加到result]
G --> C
E -->|否| H[直接添加item到result]
H --> G
E -->|是| I[break 退出循环]
I --> J[返回 result]
C -->|遍历完成| J
def join_with_limit(self, items, max_length, separator="###"):
"""
将字符串列表中的元素逐个连接,直到达到最大长度限制
参数:
items: 需要连接的字符串列表
max_length: 最大长度限制
separator: 分隔符,默认为"###"
返回:
连接后的字符串,不超过max_length
"""
result = "" # 初始化结果字符串为空
for item in items: # 遍历每个元素
# 如果加入当前元素不会导致超过上限,就将其添加到结果字符串中
if len(result) + len(item) + len(separator) <= max_length:
if result: # 如果结果非空,先添加分隔符
result += separator
result += item # 添加当前元素
else:
break # 如果已经超过上限,就停止添加新元素
return result # 返回连接后的结果该方法用于读取指定路径的文本文件内容,并返回完整的字符串形式。
参数:
file_path:str,需要读取的文件的路径(包含文件名)
返回值:str,返回文件的全部内容
flowchart TD
A[开始] --> B[打开文件<br/>file_path with 'r' mode<br/>encoding='utf-8']
B --> C[读取文件全部内容<br/>f.read()]
C --> D[返回文件内容<br/>content]
D --> E[结束]
def read_file_content(self, file_path):
"""
读取指定路径的文件内容
参数:
file_path: str, 需要读取的文件的路径
返回:
str, 文件的完整内容
"""
# 使用 utf-8 编码打开文件,'r' 模式表示只读
with open(file_path, 'r', encoding='utf-8') as f:
# 读取文件的全部内容并存储到 content 变量
content = f.read()
# 返回读取到的文件内容
return content该方法用于读取指定路径的 JSONL 格式记忆文件,解析其中的嵌入向量数据和文本数据,将嵌入向量列表转换为 PyTorch 张量,并与原始文本数据一起返回。
参数:
file_path:str,要读取的 JSONL 文件路径,该文件包含记忆的嵌入向量和对应文本
返回值:tuple,返回一个元组,包含 torch.Tensor 类型的嵌入向量张量和 list 类型的文本数据列表
flowchart TD
A[开始] --> B[初始化空列表 embed_list 和 data_list]
B --> C[使用 jsonlines 打开文件]
C --> D{遍历 reader 中的每一项}
D -->|是| E[获取 item 的值并添加到 embed_list]
E --> F[获取 item 的键并添加到 data_list]
F --> D
D -->|否| G[将 embed_list 转换为 torch.Tensor]
G --> H[返回 data_tensor 和 data_list]
H --> I[结束]
def read_jsonl_and_convert_to_tensor(self, file_path):
"""
读取 JSONL 文件并转换为 PyTorch 张量
:param file_path: str, JSONL 文件路径
:return: tuple, (torch.Tensor, list) 嵌入向量张量和文本列表
"""
# 初始化用于存储嵌入向量的列表
embed_list = []
# 初始化用于存储文本数据的列表
data_list = []
# 使用 jsonlines 库打开文件进行逐行读取
with jsonlines.open(file_path) as reader:
# 遍历文件中的每一项
for item in reader:
# 获取字典中的第一个值(嵌入向量)并添加到列表
embed_list.append(next(iter(item.values())))
# 获取字典中的第一个键(文本数据)并添加到列表
data_list.append(next(iter(item.keys())))
# 将嵌入向量列表转换为 PyTorch 张量
data_tensor = torch.tensor(embed_list)
# 返回张量形式的嵌入向量和文本列表
return data_tensor, data_list该方法用于获取角色扮演系统的系统提示词。它通过拼接角色路径和系统提示文件名,然后读取该文件的内容并返回。
参数:
- 该方法无显式参数(隐式参数
self为类的实例)
返回值:str,返回系统提示文件的内容字符串
flowchart TD
A[开始 getSystemPrompt] --> B[构建文件路径]
B --> C{调用 read_file_content}
C --> D{打开并读取文件}
D --> E[返回文件内容字符串]
E --> F[结束]
def getSystemPrompt(self):
"""
获取系统提示词内容
该方法通过读取角色目录下预定义的 system_prompt.txt 文件来获取
角色的系统提示词,用于在对话开始时设定角色的行为特征和背景
Returns:
str: 系统提示词文件的内容
"""
# 拼接完整的文件路径:角色目录路径 + 系统提示文件名
file_path = os.path.join(self.character_path, self.system_prompt_name)
# 调用文件读取方法获取文件内容
file_content = self.read_file_content(file_path)
# 返回读取到的系统提示词内容
return file_content该方法通过计算输入向量与预存记忆嵌入之间的相似度,利用矩阵乘积找出最相似的记忆条目,并返回其索引,用于检索最相关的记忆内容。
参数:
vector:torch.Tensor,输入查询的嵌入向量,用于与记忆库中的嵌入向量进行相似度计算
返回值:torch.Tensor,返回最相似的记忆条目的索引,索引数量由 self.memory_number 决定
flowchart TD
A[开始 getMemory] --> B[输入: vector 查询向量]
B --> C[执行矩阵乘法: similarity = torch.matmul self.memory_embed, vector]
C --> D[使用 torch.topk 找出最高相似度的 top_indices]
D --> E[返回 top_indices]
E --> F[结束]
subgraph 内部状态
C -.->|存储在| G[similarity: 相似度张量]
D -.->|存储在| H[top_indices: 最高相似度的索引]
end
def getMemory(self, vector):
"""
根据输入的查询向量,从记忆库中检索最相似的记忆索引
参数:
vector: torch.Tensor,查询向量,通常是当前对话或问题的embedding
返回:
torch.Tensor,包含最相似记忆的索引
"""
pass # 占位符,无实际作用
# 计算输入向量与所有记忆嵌入的相似度(点积)
# self.memory_embed: 所有记忆的嵌入向量 [num_memories, embed_dim]
# vector: 查询向量 [embed_dim] 或 [batch, embed_dim]
similarity = torch.matmul(self.memory_embed, vector)
# 找出最大的几个元素的索引
# self.memory_number: 要检索的记忆数量
# 使用 torch.topk 获取相似度最高的 top_memory_number 个记忆的索引
top_indices = torch.topk(similarity, self.memory_number).indices
return top_indices # 返回最相似记忆的索引列表
# print("Top indices:", top_indices) # 调试打印(已注释)该代码实现了一个基于角色记忆检索的对话数据集处理系统,支持从多个JSONL文件加载角色对话数据,通过嵌入向量相似度检索相关记忆,并将处理后的数据合并推送至HuggingFace Hub。
代码中通过torch.topk对记忆嵌入向量进行相似度检索,返回top_k个最相似的记忆索引,实现高效的向量相似度匹配。CharacterDataset类在初始化时预加载所有记忆嵌入,在__getitem__时仅按需检索相关记忆,实现惰性加载模式。
虽然代码中未直接使用量化,但通过torch.tensor直接处理嵌入向量,支持浮点型张量运算,为后续量化模型推理提供数据兼容性基础。
当前实现未包含量化策略,所有嵌入和特征均以float32存储,可后续通过动态量化或静态量化优化推理性能。
代码构建了完整的数据管道:read_jsonl_file读取原始数据 -> CharacterDataset处理转换 -> DataLoader批处理 -> concatenate_datasets合并多数据集 -> push_to_hub推送至云端。
getMemory方法通过矩阵乘法计算查询向量与记忆库的相似度,使用torch.topk获取最相关的记忆条目,支持自定义检索数量配置。
CharacterDataset通过getSystemPrompt方法从文件系统动态加载角色专属的系统提示词,实现角色性格和行为模式的定制化配置。
join_with_limit方法实现了带长度约束的记忆文本拼接,确保输入长度不超过指定阈值,避免模型输入溢出。
- 硬编码的敏感信息:HF_TOKEN直接明文写在代码中,存在严重的安全风险
- 硬编码的路径和配置:memory_path="jsonl/title_text_embed.jsonl"、system_prompt_name="system_prompt.txt"、分隔符"|||"等均硬编码在类中,降低了灵活性
- 大量未清理的注释代码:代码中存在大量被注释掉的旧逻辑和测试代码(如file_names列表、数据处理代码块),严重影响代码可读性和维护性
- 未使用或冗余的导入:导入了login但实际未使用;导入了ConcatDataset、nn等但未使用;被注释掉的load_dataset等导入仍然保留
- 代码风格不一致:混合使用下划线命名(snake_case)和驼峰命名(camelCase),如join_with_limit vs read_jsonl_and_convert_to_tensor
- 缺乏类型提示:整个代码没有任何类型注解(type hints),降低了代码可维护性和IDE支持
- 不完整的函数实现:getMemory方法中开头有pass语句但后面还有代码,语法冗余且风格不佳
- 文件操作缺乏异常处理:read_file_content、read_jsonl_and_convert_to_tensor、getSystemPrompt等文件读写操作没有try-except保护
- 资源未正确管理:read_jsonl_and_convert_to_tensor使用jsonlines.open()但未使用with语句,可能导致资源泄漏
- 内存效率问题:read_jsonl_and_convert_to_tensor一次性将所有embedding加载到内存,对于大规模数据集可能导致OOM
- 数据加载效率低下:collate_fn返回嵌套列表而非 tensors,训练时需要额外转换;embedding在__getitem__中每次重新转换为tensor,未做缓存
- 缺失的数据校验:__init__和__getitem__中未对json_data的关键字段进行校验,可能导致运行时KeyError
- 魔法数字:memory_number=8、memory_length=2000等参数直接传入,缺乏配置说明
- 安全改进:将HF_TOKEN移除代码,使用环境变量或配置中心管理敏感信息
- 配置外部化:将路径、分隔符等配置抽取到配置文件或环境变量,支持不同角色的数据源
- 代码清理:删除所有注释掉的无效代码块,或移至独立的废弃代码目录
- 精简导入:移除未使用的导入语句,保持导入列表整洁
- 统一命名规范:全项目统一使用snake_case命名风格
- 添加类型提示:为所有函数参数、返回值添加类型注解,提升代码可读性
- 移除冗余语法:删除getMemory方法开头的pass语句
- 增强异常处理:为所有文件IO操作添加try-except,优雅处理文件不存在、格式错误等异常
- 使用上下文管理器:确保所有文件操作使用with语句管理资源
- 优化内存使用:考虑使用内存映射或分批加载embedding;将embedding在__init__中预先转换为tensor并缓存
- 优化数据加载:collate_fn中直接返回tensor;使用pin_memory加速GPU数据传输
- 数据校验:在__init__中校验json_data必需字段,在__getitem__中添加默认值或错误处理
- 配置参数化:将magic number改为具名常量或配置参数,增加代码可读性
本代码旨在构建一个用于训练角色扮演对话模型的PyTorch数据集处理流水线。核心目标是将多个角色的对话数据(JSONL格式)进行处理,生成包含系统提示、角色记忆、聊天历史和查询的完整训练样本。主要约束包括:1) 内存嵌入向量维度必须与预训练模型一致;2) 检索的记忆数量(memory_number)默认为8条;3) 记忆文本最大长度(memory_length)限制为2000字符;4) 数据来源为HuggingFace Hub上的两个数据集(silk-road/Chat_Suzumiya_Fusion 和 silk-road/Chat_Suzumiya_Fusion_B)。
代码中的错误处理主要包括:1) JSON解析错误处理,在read_jsonl_file函数中使用try-except捕获json.JSONDecodeError并跳过无效行;2) 文件读取错误处理,read_file_content和read_jsonl_and_convert_to_tensor函数中使用with open确保文件正确关闭;3) 嵌入向量维度不匹配时,getMemory函数中的矩阵乘法运算会自动处理,但需要在实际使用前验证维度一致性。当前代码缺少对以下异常的处理:文件路径不存在时的FileNotFoundError、数据集为空时的空数据集检查、内存嵌入为空时的除零错误、HuggingFace Token无效时的认证失败处理。
整体数据流如下:1) 首先通过read_jsonl_file读取原始JSONL对话数据;2) 创建CharacterDataset实例,加载角色系统提示和记忆嵌入;3) 在__getitem__中,对每个样本进行以下处理:根据查询嵌入检索top-k记忆片段→拼接记忆文本→格式化聊天历史→组装完整输入文本;4) 使用DataLoader进行批处理,通过collate_fn整理输入输出;5) 最后将处理后的数据集上传至HuggingFace Hub。状态转换包括:JSONL文件→JSON对象列表→Dataset对象→DataLoader迭代器→字典格式的训练样本。
主要外部依赖包括:1) torch - 张量运算和神经网络基础;2) json/jsonlines - JSON格式解析;3) transformers - AutoTokenizer和AutoModel用于嵌入计算;4) huggingface_hub - 数据集上传和登录认证;5) datasets - HuggingFace数据集库;6) os - 文件系统操作。接口契约方面:read_jsonl_file接收文件路径字符串返回Python列表;CharacterDataset接收json_data列表、character_path字符串、memory_number整数、memory_length整数;collate_fn接收batch列表返回(inputs, targets)元组;get_embedding和download_models函数来自utils模块需预先定义。
核心配置参数包括:HF_TOKEN - HuggingFace访问令牌用于数据集上传;batch_size - 训练批次大小(代码中设为1);memory_number - 每次检索的记忆数量(CharacterDataset初始化时传入8);memory_length - 记忆文本最大字符数(初始化时传入2000);memory_path - 记忆文件相对路径(固定为"jsonl/title_text_embed.jsonl");system_prompt_name - 系统提示文件名(固定为"system_prompt.txt")。数据集配置包括:train_dataset_dict_A和train_dataset_dict_B分别从Hub加载两个原始数据集,最终通过concatenate_datasets合并为统一训练集。
当前实现存在以下性能瓶颈:1) 每次__getitem__调用都会执行torch.matmul进行相似度计算,可预先计算所有查询与记忆的相似度矩阵;2) 字符串拼接操作(join_with_limit)使用循环拼接而非列表+join,效率较低;3) 读取jsonlines时逐行处理,可使用批量读取;4) 未使用多进程加载数据,DataLoader的num_workers参数未设置;5) 嵌入向量存储在内存中,大规模数据时需考虑内存优化。优化建议:使用Faiss或Annoy等近似最近邻库加速检索;将相似度矩阵预计算并缓存;使用__init__中的批量加载替代逐行读取;启用DataLoader的pin_memory和num_workers参数。
代码涉及的安全问题包括:1) HF_TOKEN硬编码在代码中,应使用环境变量或配置文件管理敏感信息;2) 文件路径拼接使用os.path.join但未验证路径有效性,存在路径遍历风险;3) 读取JSONL文件时未限制单行最大长度,可能遭受恶意大文件攻击;4) 上传数据集至Hub前未进行敏感信息扫描。建议措施:将Token移至环境变量;添加路径验证逻辑;设置单行读取长度上限;添加数据脱敏流程。
建议添加以下测试用例:1) 单元测试 - 测试read_jsonl_file对空文件、格式错误文件的处理;测试CharacterDataset的__len__返回值正确性;测试getMemory返回索引在有效范围内;测试join_with_limit的边界条件。2) 集成测试 - 测试完整的数据加载流程;测试与DataLoader的兼容性;测试数据集拼接正确性。3) 性能测试 - 测试大规模数据加载时间;测试内存占用情况。4) 回归测试 - 验证数据集上传Hub后的可加载性。
部署时需注意:1) 环境配置 - 确保Python版本≥3.8,PyTorch版本与CUDA版本兼容;2) 依赖管理 - 建议使用requirements.txt或pyproject.toml管理依赖版本;3) 资源预估 - 需预估内存使用量(原始数据+嵌入+中间变量);4) 路径配置 - character_path、jsonl_file_path等路径需根据部署环境调整;5) 数据隐私 - 上传Hub的数据集需确认不包含敏感信息;6) 监控告警 - 建议添加数据加载失败和异常处理的日志记录。