diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3a7a4ad --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +# 忽略特定文件 +screenshot.jpg +list.error +tasks.json + +# 忽略特定目录 +__pycache__/ +data/ +data_log/ +weights/ \ No newline at end of file diff --git a/MobiFlow/CHANGELOG.md b/MobiFlow/CHANGELOG.md new file mode 100644 index 0000000..46a2c4f --- /dev/null +++ b/MobiFlow/CHANGELOG.md @@ -0,0 +1,59 @@ +# 更新日志 + +## 最新改进(2025年8月) + +### 新增功能 + +#### 1. 动态筛选条件检查器 (`dynamic_match`) +- **功能描述**:支持根据任务描述动态提取筛选条件并验证相应操作 +- **应用场景**:特别适用于电商类任务、选票类,如"销量最高"、"价格最低"等条件的验证 +- **配置示例**: + ```yaml + dynamic_match: + extract_from: task_description + condition_patterns: + sales_highest: + trigger_keywords: ["销量最高", "销量最多", "销量"] + verify_keywords: ["销量", "最高", "最多"] + llm_prompt: "该步骤是否执行了按销量排序的操作?" + ``` + +#### 2. 完善的escalate策略 +- **优先检查顺序**:`text` → `regex` → `action` → `ui/icons` → `ocr` → `llm` + +#### 3. 增强的轨迹数据支持 +- **目录格式**:支持包含截图、XML、动作记录的完整轨迹目录 +- **多模态验证**:LLM验证时支持同时传入当前帧和下一帧的截图 +- **上下文丰富**:每帧包含任务描述、推理过程、动作信息等完整上下文 + +### 测试改进 + + +#### 测试覆盖 +- **淘宝trace 1、2**: 淘宝搜索商品 +- **淘宝trace 3**: 价格最低的Type-C数据线筛选验证 +- **淘宝trace 4**: 销量最高的苹果充电线筛选验证 +- **淘宝trace 5**: 销量最高的Type-C数据线筛选验证 +- **淘宝trace 6 7 8** + +### 框架改进 + +#### 通用性提升 +- **避免硬编码**:动态检查器不包含特定业务逻辑,完全依赖配置驱动 +- **模式扩展**:支持任意数量的条件模式,易于扩展新的筛选类型 +- **灵活验证**:支持多种验证字段组合,适应不同数据格式 + +#### 配置增强 +- **后备策略**:当基础匹配失败时自动使用LLM后备验证 + +### 兼容性 +- **向后兼容**:所有原有功能保持完全兼容 +- **渐进采用**:新功能可以逐步集成,不影响现有配置 +- **扩展友好**:框架设计支持未来的功能扩展 + +## 2025念8月14日 + +- **OCR引擎**: 加入了OCR引擎,支持对图片文字处理 +- **xml辅助**: 在ocr不可用时,若对应的trace中有xml文件,可直接提取xml中所有的文本,用于关键词匹配 +- **条件模式映射**:支持复杂的条件-验证关键词映射 +- **待增加**: 下一步将增加图标元素的辅助识别(ui),实现更准确、迅速的本地关键步骤识别 \ No newline at end of file diff --git a/MobiFlow/README.md b/MobiFlow/README.md new file mode 100644 index 0000000..efebefb --- /dev/null +++ b/MobiFlow/README.md @@ -0,0 +1,458 @@ +# MobiFlow: 基于dag的移动代理基准测试框架 + +一个离线验证框架:在收集到某次任务执行的完整轨迹(关键帧/事件)后,读取任务的 DAG 配置,检查是否存在满足依赖与顺序约束的"满足路径",从而判断该次执行是否达到任务目标。 + +## 
功能特性 + +### 核心验证能力 +- **多层级条件检查**:支持文本匹配、正则表达式、UI状态、XML解析、图标检测、OCR识别、LLM推理等多种判定方法 +- **图标检测识别**:基于OpenCV模板匹配的图标检测,支持多尺度匹配和相似度阈值控制,快速识别UI界面元素 +- **逐级升级策略**:当前一种判定方法不可用时,自动升级到更高级的判定方式 +- **多路径DAG验证**:基于有向无环图验证任务节点间的依赖关系,支持AND和OR两种语义的路径分支 +- **动态条件匹配**:支持根据任务描述动态提取条件并进行相应验证 + +### 高级验证功能 +- **双语义依赖系统**: + - `deps` 字段:AND语义,所有前置节点必须完成 + - `next` 字段:OR语义,支持多分支路径选择 +- **路径感知验证**:智能帧分配机制,避免跨路径的帧冲突 +- **约束检查**:自动检测deps/next配置冲突并发出警告 +- **路径分析日志**:执行前显示所有可能的成功路径,便于调试和配置验证 +- **界面文字识别**:识别当前界面文字内容,通过文章元素匹配关键节点 +- **智能图标识别**:`icons_match` 检查器支持基于模板匹配的图标检测,提供快速的UI状态验证能力 +- **多模态LLM验证**:支持结合截图和上下文信息进行LLM推理验证 +- **灵活配置系统**:通过YAML配置文件灵活定义各种验证条件和模式 + +## 目录结构 + +- `avdag/` 核心库 + - `conditions.py` 条件检查器与注册表(含动态匹配检查器、图标检测检查器) + - `dag.py` DAG 定义、拓扑排序、路径分析和约束检查 + - `verifier.py` 核心验证逻辑(路径感知的帧分配和多路径验证) + - `loader.py` 从 YAML/JSON 读取任务配置(支持deps和next字段) + - `trace_loader.py` 从目录结构读取轨迹数据 + - `types.py` 基本类型(Frame/Result/VerifierOptions 等,NodeSpec支持next字段) +- `tools/` 辅助工具 + - `Icon_detection/` 图标检测工具包 + - `icon_detector.py` 基于OpenCV的多尺度模板匹配核心检测器 + - `icon_detection.py` 高级图标检测服务接口 + - `config.py` 图标检测配置管理 + - `test_*.py` 图标检测功能测试脚本 +- `task_configs/icons/` 图标资源库 + - `weixin/` 微信应用图标模板 + - `bilibili/` B站应用图标模板 + - `xiecheng/` 携程应用图标模板 +- `task_configs/` 任务配置与图标资源(本目录中的可用示例配置位于此目录) + - `task_configs/*.json` JSON 格式的任务配置示例 + - `task_configs/icons/` 图标资源模板 +- `docs/` 使用说明与设计文档(包含多路径、检查器模式与OCR/LLM 使用说明) +- `tests/` 单元测试 + - `test_dependency_validation.py` 依赖约束检查测试 + - `test_next_paths.py` next节点OR语义测试 + - `test_path_analysis.py` 路径分析日志测试 +- 测试脚本 + - `test_dynamic_filter.py` 动态筛选条件验证测试 + - `test_filter_verification.py` 筛选功能验证测试 + - `test_image_verification.py` 图像验证测试 + +## 安装依赖 + +安装项目根目录的相关库即可 + +```bash +pip install -r requirements.txt +``` + +可选额外安装一个OCR辅助工具,配合Paddle进行检测: + +```bash +# 安装Tesseract OCR +sudo apt-get install tesseract-ocr + +# 安装中文语言包 +sudo apt-get install tesseract-ocr-chi-sim + +# 检查是否正确安装 +tesseract --version +``` + +### 图标检测功能额外依赖 + +```bash +pip install opencv-python 
numpy +``` + +## 快速开始 + +1) 查看本目录内的示例配置:`task_configs/` 中包含若干任务配置及图标资源。 + +2) 使用最小演示脚本运行(示例使用 `task_configs` 中的配置): + +```bash +python -m avdag.verifier task_configs/taobao.json trace_folder/ +``` + +输出包括: +- 路径分析日志(所有可能的成功路径) +- 是否成功(存在一条满足约束的满足路径) +- 被满足的节点与对应匹配到的帧索引 +- 一个按时间排序的满足序列(可视为线性化的 trace) + +示例输出: +``` +[INFO] === DAG 路径分析 === +[INFO] 发现 2 条可能的成功路径: + 路径 1: activate_search -> input_keyword -> results_page -> open_profile -> follow_author + 路径 2: activate_search -> input_keyword -> results_page -> follow_author +[INFO] === 路径分析结束 === +``` + +## 配置格式(YAML/JSON) + +可以使用`MobiFlow/auto_rules`中自动工具由LLM分析生成任务配置,也可手动按照如下规则编写。 + +### 基础配置示例 + +```yaml +task_id: shop_search +nodes: + - id: open_app + name: 打开购物 App + condition: + type: text_match + params: + any: ["打开了淘宝", "商城首页"] + - id: search_page + deps: [open_app] # AND语义:必须等待open_app完成 + condition: + type: ui_flag + params: + key: screen + equals: search + - id: search_keyword + deps: [search_page] + condition: + type: regex_match + params: + pattern: ".*iPhone 15.*" + - id: result_list + deps: [search_keyword] + condition: + type: text_match + params: + any: ["结果", "共", "商品"] +# 成功条件(可选)。若省略,则默认任一"汇点"节点被满足即成功。 +success: + any_of: [result_list] +``` + +### 多路径配置示例 + +```yaml +task_id: bilibili_search_follow +nodes: + - id: activate_search + name: 激活搜索功能 + condition: + type: text_match + params: + any: ["搜索", "search"] + next: [input_keyword] # OR语义:可以进入input_keyword + + - id: input_keyword + name: 输入搜索关键词 + condition: + type: text_match + params: + any: ["关键词", "搜索词"] + next: [results_page] + + - id: results_page + name: 搜索结果页面 + condition: + type: text_match + params: + any: ["搜索结果", "结果页"] + next: [follow_author, open_profile] # OR语义:两条可选路径 + + - id: open_profile + name: 打开用户主页 + condition: + type: text_match + params: + any: ["用户主页", "个人页面"] + next: [follow_author] + + - id: follow_author + name: 关注作者 + condition: + type: text_match + params: + any: ["关注", "已关注"] + +success: + any_of: [follow_author] +``` 
+ + +### 图标检测配置示例 + +```yaml +task_id: wechat_send_message +app_id: com.tencent.mm +description: 在微信中给指定联系人或群聊发送消息 +nodes: + - id: find_contact_entry + name: 查找联系人或群聊 + condition: + type: escalate + params: + icons: + all: ["icon_001_通讯录", "icon_002_微信", "icon_000_我"] # 必须检测到所有图标 + ocr: + all: ["微信", "通讯录", "发现", "我"] + llm: + prompt: 当前页面是否为微信主界面、通讯录或搜索页面? + expected_true: true + next: [send_message_success] + + - id: send_message_success + name: 成功发送消息 + condition: + type: juxtaposition # 要求图标和OCR都成功 + params: + icons: + any: ["icon_001_回车", "icon_002_发送"] # 匹配任意发送相关图标 + threshold: 0.85 # 自定义相似度阈值 + ocr: + all: ["发送"] + +success: + any_of: [send_message_success] +``` + +### 配置说明 + +- `nodes[].deps`:该节点的前置依赖(AND 关系)- 所有依赖节点必须完成 +- `nodes[].next`:该节点的后继节点(OR 关系)- 任一后继节点可以执行 +- `condition`:由 `type` 指定检查器,`params` 为该检查器参数 +- `success`: + - `any_of: [node_id...]` 任一节点满足即判成功 + - `all_of: [node_id...]` 列表中全部节点满足才判成功 + - 若不配置,默认检查"汇点"节点(无出边)中是否存在满足的节点 + +#### 依赖系统语义 + +- **deps(AND语义)**:严格的前置依赖,所有listed节点必须先完成 +- **next(OR语义)**:灵活的后继选择,可以进入任一listed节点 +- **约束检查**:当节点同时定义deps和作为其他节点的next目标时,系统会发出警告,deps优先 + +#### 路径分析 + +验证执行前会自动分析并输出所有可能的成功路径: +``` +[INFO] 发现 2 条可能的成功路径: + 路径 1: activate_search -> input_keyword -> results_page -> open_profile -> follow_author + 路径 2: activate_search -> input_keyword -> results_page -> follow_author +``` + +### 支持的检查器类型 + +#### 基础检查器 +- `text_match`: 文本匹配 +- `regex_match`: 正则表达式匹配 +- `ui_flag`: UI状态检查 +- `xml_text_match`: XML内容匹配 +- `action_match`: 动作类型匹配 + +#### 图像识别检查器 +- `icons_match`: 图标检测匹配,基于OpenCV模板匹配技术快速识别UI界面图标 + +#### 高级检查器 +- `escalate`: 按策略升级顺序尝试多种检查方法 +- `juxtaposition`: 并列检查器,要求所有配置的检查器都必须通过 +- `dynamic_match`: 动态条件匹配,根据任务描述提取条件并验证相应操作 + +## 轨迹数据(Frames) + +框架支持两种轨迹数据格式: + +### 1. JSON格式(简单) + +把执行过程中的关键帧/事件整理为按时间排序的数组: + +```json +{ + "timestamp": 1723456789.123, + "text": "打开了淘宝,进入搜索页", + "ui": {"screen": "search"}, + "payload": {"extra": "自由扩展"} +} +``` + +### 2. 
目录格式(移动端自动化) + +支持包含截图、XML、动作记录的目录结构: + +``` +trace_folder/ +├── 1.jpg # 截图 +├── 1.xml # UI布局信息 +├── 2.jpg +├── 2.xml +├── ... +├── actions.json # 动作序列 +└── react.json # 推理记录 +``` + +每帧包含的字段: +- `image`: 截图文件路径 +- `screenshot`: 截图数据(numpy数组或文件路径,用于图标检测) +- `xml_text`: UI布局的XML文本 +- `reasoning`: 推理过程描述 +- `action`: 执行的动作信息 +- `task_description`: 任务描述(用于动态匹配) +- `text`: 综合文本信息(用于简单匹配) +- `app_id`: 应用包名(用于图标路径解析) + +内置检查器会在上述字段中查找信息,你也可以注册自定义检查器(见下)。 + +## 自定义检查器 + +```python +from avdag.conditions import register_condition, ConditionChecker + +@register_condition("my_checker") +class MyChecker(ConditionChecker): + def check(self, frame: dict, params: dict, options=None) -> bool: + # 读取 frame / params 做任意判断 + return frame.get("payload", {}).get("flag") == params.get("flag") +``` + +注册后即可在配置中使用: + +```yaml +condition: + type: my_checker + params: + flag: true +``` + +## 运行测试 + +请使用下面的通用方式运行测试或直接运行工具自带的测试脚本: + +### 自定义验证选项 + +```python +from avdag.verifier import make_llm_options, verify_task_folder + +# 配置LLM验证 +opts = make_llm_options( + api_key="your-api-key", + base_url="https://api.openai.com/v1", + model="gpt-4o", + force_llm=True # 强制使用LLM验证 +) + +# 运行验证 +result = verify_task_folder("task.yaml", "trace_folder", opts) +``` + +## 设计说明 + +### 核心算法 +- **多路径DAG验证**:支持AND(deps)和OR(next)两种语义的路径分支 +- **路径感知验证**:智能帧分配,避免跨路径的帧冲突,确保验证准确性 +- **约束检查**:检查顺序严格:某节点匹配到的帧索引必须晚于其所有依赖节点 +- **验证流程**: + - 路径分析:输出所有可能的成功路径 + - 候选收集:为每个节点收集"候选帧索引集合"(基于可达性) + - 拓扑验证:计算每个节点的"最小可行索引" + - 成功判定:若 `success` 中的任一/全部目标节点存在可行索引,则判成功 + - 结果输出:给出按匹配索引排序的线性化满足序列 + +### 支持功能特性 +- **多语义依赖系统**: + - deps:AND语义,严格的前置依赖关系 + - next:OR语义,灵活的分支路径选择 +- **路径感知帧分配**:基于动态可达性分析的智能帧分配机制 +- **约束冲突检测**:自动检测deps/next配置冲突并发出警告 +- **路径分析日志**:执行前输出所有可能路径,便于调试和配置验证 +- **智能图标检测**: + - 基于OpenCV模板匹配的多尺度图标检测 + - 支持any/all匹配模式和自定义相似度阈值 + - 智能路径解析,根据应用ID自动查找图标资源 + - 非极大值抑制去重,提高检测准确性 +- **动态条件匹配**:`dynamic_match` 检查器支持根据任务描述动态提取条件并验证相应操作 +- **多模态验证**:支持结合截图、XML、推理文本、图标检测进行LLM验证 +- **升级策略**:`escalate` 检查器支持从简单到复杂的逐级验证策略(text → 
regex → ui → action → dynamic_match → icons → ocr → llm) +- **灵活配置**:通过YAML配置文件可以灵活定义各种复杂的验证条件 + +### 适用场景 +- **移动端多路径任务**:特别适合B站、淘宝等存在多种操作路径的应用验证 +- **UI界面状态检测**:通过图标检测快速识别应用界面状态,如微信主界面、聊天窗口等 +- **复杂分支逻辑**:支持用户可以选择不同操作路径的任务验证 +- **条件筛选验证**:支持根据任务要求动态判断是否执行了正确的筛选操作 +- **多模态验证链**:结合图标检测、OCR识别、LLM推理的逐级验证策略 +- **配置调试**:通过路径分析日志快速定位配置问题 +- **人工复核标记**:当自动验证不确定时,自动标记需要人工复核 + +## 图标检测功能详细说明 + +### 技术原理 +图标检测功能基于OpenCV模板匹配技术,采用多尺度检测和相似度阈值控制,能够在移动应用截图中快速准确地识别UI图标元素。 + +### 核心特性 +- **多尺度模板匹配**:支持0.5x到2.0x的缩放范围,适应不同分辨率的设备 +- **智能相似度控制**:可配置相似度阈值,平衡检测精度和召回率 +- **非极大值抑制**:自动去除重复检测结果,提高检测准确性 +- **路径智能解析**:根据应用包名自动查找对应图标资源 +- **批量检测优化**:支持同时检测多个图标,提高验证效率 + +### 配置参数说明 + +#### icons检查器参数 +```yaml +icons: + any: ["icon_001_通讯录", "icon_002_微信"] # 匹配任意一个图标 + all: ["icon_001_回车", "icon_002_发送"] # 必须匹配所有图标 + threshold: 0.85 # 可选:自定义相似度阈值 +``` + +#### 图标资源组织 +``` +task_configs/icons/ +├── weixin/ # 微信应用图标 +│ ├── icon_001_通讯录.jpg +│ ├── icon_002_微信.jpg +│ └── icon_000_我.jpg +├── bilibili/ # B站应用图标 +└── taobao/ # 淘宝应用图标 +``` + +### 升级策略中的位置 +在escalate检查器中,图标检测位于OCR之前,LLM之后: +``` +text → regex → action → icons → ocr → llm +``` +这样的设计确保: +1. 优先使用快速的文本和UI检查 +2. 图标检测提供视觉验证能力 +3. OCR处理复杂文本识别 +4. LLM作为最终的智能判断 + +### 性能优化 +- **图标缓存机制**:已加载的图标模板会被缓存,避免重复读取 +- **早期终止**:escalate模式下图标检测成功即返回,无需后续检查 +- **尺寸预检查**:避免处理过大的缩放模板,提高检测速度 + +### 使用建议 +1. **图标质量**:使用清晰、特征明显的图标模板 +2. **阈值调优**:根据实际效果调整相似度阈值,通常0.8-0.9为佳 +3. **命名规范**:采用统一的图标命名规则,便于管理和配置 +4. 
**组合使用**:结合其他检查器使用,提高验证的可靠性 + +--- + +若你需要适配真实移动端各类设备采集(OCR、UI dump、强 LLM 审核回调等),可将其加工为上述 `frames` 数组再进行离线验证;也可通过自定义检查器接入更复杂的判断逻辑。 diff --git a/MobiFlow/auto_rules/README.md b/MobiFlow/auto_rules/README.md new file mode 100644 index 0000000..bb72de4 --- /dev/null +++ b/MobiFlow/auto_rules/README.md @@ -0,0 +1,106 @@ +# Auto Rules 自动化任务配置生成 + +基于 LLM 的任务验证配置自动生成系统。从用户操作轨迹中提取任务描述,结合模版通过大语言模型生成验证配置文件。 + +## 功能特性 + +- **自动提取**: 从目录结构中提取actions.json中的任务描述 +- **模版结合**: 结合现有的示例模版(orders/modes) +- **统一提示词**: 集中管理LLM提示词,避免重复 +- **LLM生成**: 使用大语言模型生成验证配置 +- **简化流程**: 一键从任务描述到完整配置文件 + +## 安装依赖 + +```bash +pip install PyYAML openai +``` + +## 快速开始 + +### 基础使用 + +```bash +# 使用orders模版生成配置 +python main.py /path/to/taobao_test YOUR_API_KEY + +# 指定输出文件 +python main.py /path/to/taobao_test YOUR_API_KEY --output-file config.yaml + +# 使用modes模版 +python main.py /path/to/taobao_test YOUR_API_KEY --template-type modes +``` + +ordes和modes只是最基础的模版,可以自定义后作为LLM的参考使用。 + +### 试运行模式 + +```bash +# 仅提取任务描述,不调用LLM +python main.py /path/to/taobao_test dummy_key --dry-run + +# 实际使用 +python main.py ../data/taobao/type2 --output-file ./test_output.yaml +``` + +## 命令行参数 + +### 必需参数 + +- `target_dir`: 包含actions.json文件的目标目录路径 +- `api_key`: OpenAI API密钥 + +### 可选参数 + +- `--output-file, -o`: 输出配置文件路径 +- `--template-type, -t`: 模版类型,可选 `orders` 或 `modes`(默认:orders) +- `--app-name`: 应用名称(默认:从任务描述中提取) +- `--dry-run`: 仅提取任务描述,不调用LLM + +## 输入数据格式 + +系统会自动搜索目标目录及其子目录中的 `actions.json` 文件,提取其中的 `task_description` 字段。 + +### actions.json 示例 + +```json +{ + "app_name": "淘宝", + "task_description": "在淘宝搜一下苹果数据线,然后挑一款合适的", + "action_count": 5, + "actions": [...] +} +``` + +## 输出配置示例 + +生成的配置文件将基于选择的模版类型,包含完整的任务验证节点和条件检查。 + +```yaml +task_id: taobao_search_and_select +app_id: com.taobao.app +description: 在淘宝搜索商品并选择的任务验证配置 + +nodes: + - id: launch_app + name: 启动淘宝应用 + condition: + type: escalate + params: + ocr: + any: ["淘宝", "启动"] + llm: + prompt: "该步是否成功启动了淘宝应用?" + expected_true: true + next: [search_entry] + # ... 
更多节点 + +success: + any_of: [complete_action] +``` + +## 注意事项 + +1. **API费用**: LLM调用会产生费用,建议先用--dry-run测试 +2. **网络连接**: 需要稳定的网络连接访问LLM API +3. **输出验证**: 建议人工审核生成的配置文件 diff --git a/MobiFlow/auto_rules/__init__.py b/MobiFlow/auto_rules/__init__.py new file mode 100644 index 0000000..6b9aecc --- /dev/null +++ b/MobiFlow/auto_rules/__init__.py @@ -0,0 +1,14 @@ +""" +Auto Rules - 自动化任务配置生成系统 + +该模块用于分析任务轨迹数据,自动生成验证配置文件。 + +主要功能: +1. 提取和分析task_description +2. 汇总同类任务描述 +3. 基于示例模版生成新的验证配置 +4. 支持多种任务类型的自动化配置生成 +""" + +__version__ = "1.0.0" +__author__ = "Auto Verify System" diff --git a/MobiFlow/auto_rules/llm_generator.py b/MobiFlow/auto_rules/llm_generator.py new file mode 100644 index 0000000..6290d33 --- /dev/null +++ b/MobiFlow/auto_rules/llm_generator.py @@ -0,0 +1,144 @@ +import openai +import yaml +import logging +from typing import Dict, Any, Optional + +from prompts import SYSTEM_PROMPT + +logger = logging.getLogger(__name__) + + +class LLMConfigGenerator: + """基于LLM的配置生成器""" + + def __init__(self, + api_key: str = None, + base_url: str = "https://api.openai.com/v1", + model: str = "gpt-4"): + """ + 初始化LLM配置生成器 + + Args: + api_key: OpenAI API密钥 + base_url: API基础URL + model: 使用的模型名称 + """ + self.api_key = api_key + self.base_url = base_url + self.model = model + + # 创建OpenAI客户端 + if api_key: + self.client = openai.OpenAI( + api_key=api_key, + base_url=base_url + ) + else: + self.client = None + + def generate_config_from_prompt(self, prompt: str) -> Optional[Dict[str, Any]]: + """ + 根据完整的提示词生成配置 + + Args: + prompt: 完整的LLM提示词 + + Returns: + 生成的配置字典或原始文本 + """ + if not self.client: + logger.warning("未提供API密钥,无法生成配置") + return None + + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[ + { + "role": "system", + "content": SYSTEM_PROMPT + }, + { + "role": "user", + "content": prompt + } + ], + temperature=0.3, + max_tokens=4000, + timeout=40, # 40秒超时 + ) + + response_text = response.choices[0].message.content + + # 尝试解析为YAML + config 
= self._parse_llm_response(response_text) + + if config: + logger.info("LLM成功生成了配置") + return config + else: + # 如果解析失败,返回原始文本 + logger.warning("LLM生成的内容无法解析为YAML,返回原始文本") + return response_text + + except Exception as e: + logger.error(f"LLM生成配置失败: {e}") + return None + + def _parse_llm_response(self, response_text: str) -> Optional[Dict[str, Any]]: + """解析LLM响应""" + try: + # 提取YAML部分 + if "```yaml" in response_text: + yaml_start = response_text.find("```yaml") + 7 + yaml_end = response_text.find("```", yaml_start) + yaml_text = response_text[yaml_start:yaml_end].strip() + elif "```" in response_text: + yaml_start = response_text.find("```") + 3 + yaml_end = response_text.find("```", yaml_start) + yaml_text = response_text[yaml_start:yaml_end].strip() + else: + yaml_text = response_text.strip() + + # 解析YAML + config = yaml.safe_load(yaml_text) + + # 简单验证配置格式 + if isinstance(config, dict) and 'task_id' in config and 'nodes' in config: + # 格式化配置以确保数组字段使用方括号格式 + formatted_config = self._format_config_arrays(config) + return formatted_config + else: + logger.warning("LLM生成的配置格式无效") + return None + + except Exception as e: + logger.error(f"解析LLM响应失败: {e}") + return None + + def _format_config_arrays(self, config: Dict[str, Any]) -> Dict[str, Any]: + """ + 格式化配置中的数组字段,确保使用方括号格式 + + Args: + config: 原始配置字典 + + Returns: + 格式化后的配置字典 + """ + def format_recursive(obj): + if isinstance(obj, dict): + formatted = {} + for key, value in obj.items(): + # 特定字段需要格式化为数组 + if key in ['deps', 'next', 'all', 'any', 'any_of'] and isinstance(value, list): + formatted[key] = value + else: + formatted[key] = format_recursive(value) + return formatted + elif isinstance(obj, list): + return [format_recursive(item) for item in obj] + else: + return obj + + return format_recursive(config) diff --git a/MobiFlow/auto_rules/main.py b/MobiFlow/auto_rules/main.py new file mode 100644 index 0000000..f68af1e --- /dev/null +++ b/MobiFlow/auto_rules/main.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +""" 
+Auto Rules 主程序 +基于 LLM 的任务配置自动生成系统 + +使用方法: + python main.py [目标目录] [API密钥] [选项] + +示例: + # 分析目录并生成配置 + python main.py /path/to/taobao_test YOUR_API_KEY --output-file config.yaml + + # 仅提取任务描述不调用LLM + python main.py /path/to/taobao_test dummy --dry-run +""" + +import argparse +import logging +import sys +import json +import yaml +from pathlib import Path + +from task_extractor import TaskDescriptionExtractor +from llm_generator import LLMConfigGenerator +from prompts import generate_user_prompt + +# 添加上级目录到路径以导入llmconfig +sys.path.append(str(Path(__file__).parent.parent)) +import llmconfig + + +def format_yaml_with_brackets(config: dict) -> str: + """ + 格式化YAML输出,确保特定数组字段使用方括号格式 + + Args: + config: 配置字典 + + Returns: + 格式化后的YAML字符串 + """ + import re + + # 首先使用标准YAML格式化 + yaml_str = yaml.dump(config, default_flow_style=False, + allow_unicode=True, sort_keys=False, indent=2) + + # 逐行处理,修复数组格式 + lines = yaml_str.split('\n') + result_lines = [] + i = 0 + + while i < len(lines): + line = lines[i] + + # 检查是否是需要格式化的数组字段 + if re.match(r'(\s+)(deps|next|any_of):\s*$', line): + # 处理单项数组 + indent = re.match(r'(\s+)', line).group(1) if re.match(r'(\s+)', line) else '' + field_name = re.search(r'(deps|next|any_of):', line).group(1) + + # 查找下一行的数组项 + if i + 1 < len(lines) and re.match(r'\s*-\s*(.+)', lines[i + 1]): + item = re.match(r'\s*-\s*(.+)', lines[i + 1]).group(1) + result_lines.append(f'{indent}{field_name}: [{item}]') + i += 2 # 跳过数组项行 + continue + + elif re.match(r'(\s+)(all|any):\s*$', line): + # 处理多项数组 + indent = re.match(r'(\s+)', line).group(1) if re.match(r'(\s+)', line) else '' + field_name = re.search(r'(all|any):', line).group(1) + + # 收集所有数组项 + items = [] + j = i + 1 + while j < len(lines) and re.match(r'\s*-\s*(.+)', lines[j]): + item = re.match(r'\s*-\s*(.+)', lines[j]).group(1) + # 确保引号正确 + if not (item.startswith('"') and item.endswith('"')): + item = f'"{item}"' + items.append(item) + j += 1 + + if items: + formatted_items = ', '.join(items) + 
result_lines.append(f'{indent}{field_name}: [{formatted_items}]') + i = j # 跳过所有数组项行 + continue + + # 普通行直接添加 + result_lines.append(line) + i += 1 + + return '\n'.join(result_lines) + + +def load_template_yaml(template_type: str) -> str: + """加载模版 YAML 文件内容""" + base_dir = Path(__file__).parent.parent / "task_rules" + + if template_type == "orders": + template_file = base_dir / "example_checker_ordes.yaml" + elif template_type == "modes": + template_file = base_dir / "example_checker_modes.yaml" + else: + raise ValueError(f"不支持的模版类型: {template_type}") + + if not template_file.exists(): + raise FileNotFoundError(f"模版文件不存在: {template_file}") + + with open(template_file, 'r', encoding='utf-8') as f: + return f.read() + + +def generate_llm_prompt(task_descriptions: list, template_yaml: str, app_name: str = "unknown") -> str: + """生成发送给 LLM 的提示词(使用共享的提示词模块)""" + return generate_user_prompt(task_descriptions, template_yaml, app_name) + + +def main(): + """主程序入口""" + parser = argparse.ArgumentParser( + description="基于 LLM 的任务配置自动生成系统", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__ + ) + + # 必需参数 + parser.add_argument("target_dir", + help="包含actions.json文件的目标目录路径") + + # parser.add_argument("api_key", + # help="OpenAI API密钥") + + # 可选参数 + parser.add_argument("--output-file", "-o", + help="输出配置文件路径(默认:基于目录名生成)") + + parser.add_argument("--template-type", "-t", + choices=["orders", "modes"], + default="orders", + help="使用的模版类型 (默认: orders)") + + parser.add_argument("--app-name", + help="应用名称(默认从任务描述中提取)") + + parser.add_argument("--dry-run", + action="store_true", + help="仅提取和显示任务描述,不调用LLM") + + args = parser.parse_args() + + # 设置日志 + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') + logger = logging.getLogger(__name__) + + try: + # 验证参数 + target_dir = Path(args.target_dir) + if not target_dir.exists(): + logger.error(f"目标目录不存在: {args.target_dir}") + sys.exit(1) + + # 步骤1: 提取任务描述 + logger.info(f"从 {args.target_dir} 
提取任务描述...") + extractor = TaskDescriptionExtractor() + task_data = extractor.extract_from_directory(args.target_dir) + + if not task_data: + logger.warning("未找到任何有效的actions.json文件") + sys.exit(0) + + # 提取task_description列表 + task_descriptions = [task['task_description'] for task in task_data if task['task_description']] + + if not task_descriptions: + logger.warning("未找到任何有效的task_description") + sys.exit(0) + + logger.info(f"成功提取 {len(task_descriptions)} 个任务描述") + + # 显示提取的任务描述 + print(f"\n提取到的任务描述:") + for i, desc in enumerate(task_descriptions, 1): + print(f"{i}. {desc}") + + # 确定应用名称 + app_name = args.app_name + if not app_name: + app_names = set(task.get('app_name', '') for task in task_data if task.get('app_name')) + if app_names: + app_name = list(app_names)[0] + else: + app_name = target_dir.name + + logger.info(f"使用应用名称: {app_name}") + + if args.dry_run: + logger.info("试运行模式,不调用LLM生成配置") + return + + # 步骤2: 加载模版 + logger.info(f"加载 {args.template_type} 模版...") + template_yaml = load_template_yaml(args.template_type) + + # 步骤3: 生成LLM提示词 + logger.info("生成LLM提示词...") + prompt = generate_llm_prompt(task_descriptions, template_yaml, app_name) + + # 步骤4: 调用LLM生成配置 + logger.info("调用LLM生成配置...") + + api_key = llmconfig.API_KEY + base_url = llmconfig.BASE_URL + model = llmconfig.MODEL + llm_generator = LLMConfigGenerator(api_key, base_url,model) + + generated_config = llm_generator.generate_config_from_prompt(prompt) + + if not generated_config: + logger.error("LLM生成配置失败") + sys.exit(1) + + # 步骤5: 保存配置文件 + if args.output_file: + output_path = Path(args.output_file) + else: + # 基于目录名生成输出文件名 + safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in target_dir.name) + output_path = Path(f"{safe_name}_{args.template_type}_config.yaml") + + # 确保输出目录存在 + output_path.parent.mkdir(parents=True, exist_ok=True) + + logger.info(f"保存配置到: {output_path}") + + with open(output_path, 'w', encoding='utf-8') as f: + if isinstance(generated_config, dict): + # 
使用自定义格式化保存YAML,确保数组使用方括号格式 + yaml_content = format_yaml_with_brackets(generated_config) + f.write(yaml_content) + else: + f.write(generated_config) + + logger.info("任务配置生成完成!") + print(f"\n✓ 配置文件已生成: {output_path}") + + except KeyboardInterrupt: + logger.info("用户中断程序") + sys.exit(0) + except Exception as e: + logger.error(f"程序执行失败: {e}") + sys.exit(1) + + +def print_analysis_summary(analysis, group_name): + """打印分析结果摘要""" + print(f"\n=== {group_name} 分析结果 ===") + print(f"总任务数: {analysis['total_tasks']}") + print(f"应用名称: {', '.join(analysis['app_names'])}") + + if analysis['common_actions']: + print("常见动作:") + for action in analysis['common_actions'][:5]: # 显示前5个 + print(f" - {action['action']}: {action['count']}次 ({action['frequency']:.1%})") + + if analysis['task_patterns']: + print("识别的任务模式:") + for pattern in analysis['task_patterns']: + print(f" - {pattern['name']}: 置信度 {pattern['confidence']:.1%}") + + complexity = analysis.get('complexity_analysis', {}) + if complexity: + level = complexity.get('complexity_level', 'unknown') + avg_actions = complexity.get('avg_actions', 0) + print(f"复杂度: {level} (平均 {avg_actions:.1f} 步操作)") + + print() + + +if __name__ == "__main__": + main() diff --git a/MobiFlow/auto_rules/prompts.py b/MobiFlow/auto_rules/prompts.py new file mode 100644 index 0000000..c112e02 --- /dev/null +++ b/MobiFlow/auto_rules/prompts.py @@ -0,0 +1,125 @@ +""" +LLM提示词模块 +统一管理所有LLM相关的提示词内容 +""" + +import json +from typing import List + + +# 系统提示词 +SYSTEM_PROMPT = """你是一个专业的任务验证配置生成专家,擅长根据用户任务描述和参考模版生成精确的YAML验证配置文件,用于检测判定任务的关键节点。 + +## 满足以下要求: + +1. **优化OCR关键词**: 根据实际任务描述,为每个节点的OCR检查添加更准确和全面的关键词列表,关键词不和具体任务的具体内容相关,而是针对任务类型和常见操作的通用关键词。 +2. **改进LLM提示词**: 让LLM验证提示更具体、更符合实际场景,且能准确判断任务节点是否完成。 +3. **优化节点路径**: 根据任务模式调整节点之间的依赖关系和路径选择,节点和路径针对该类任务的共性进行优化。 +4. **调整检查器类型**: 根据验证需求选择合适的escalate或juxtaposition类型(严格但耗时)。 +5. 
**针对关键节点**: 配置文件应重点关注任务中的关键节点,确保这些节点的验证准确有效,最终完成状态能代表整个任务完成,不一定要覆盖所有步骤 + +## 配置规则约束: +- 只能使用type: escalate 或 type: juxtaposition +- params只能包含ocr和llm字段 +- ocr可以使用any或all,使用[]格式 +- llm必须包含prompt和expected_true字段 +- 保持deps (AND语义) 和 next (OR语义) 的正确使用,且使用[,]格式 + +请直接返回优化后的完整YAML配置,不要包含其他解释文字。""" + + +# 配置要求模板 +CONFIG_REQUIREMENTS = """## 配置要求: + +1. **基本信息**: + - 生成合适的task_id、app_id、task_type和description + - 基于实际的应用名称和任务类型 + +2. **节点设计**: + - 根据任务流程设计合理的节点序列 + - 使用deps (AND语义) 表示强制依赖 + - 使用next (OR语义) 表示可选路径 + - 确保节点间的逻辑关系正确 + +3. **条件检查**: + - type只能是escalate或juxtaposition + - params只能包含ocr和llm字段 + - ocr使用any表示任意匹配,all表示全部匹配 + - llm包含prompt和expected_true字段 + - 根据任务特点选择合适的关键词 + +4. **成功条件**: + - 设置合理的success条件 + - 使用any_of或all_of""" + + +# 配置示例格式 +CONFIG_EXAMPLE = """## 参考配置示例格式: + +```yaml +task_id: example_task +app_id: com.example.app +task_type: demo +description: 示例任务描述 + +nodes: + - id: step1 + name: 第一步 + next: [step2] + condition: + type: escalate + params: + ocr: + any: ["关键词1", "关键词2"] + llm: + prompt: "该步是否完成了XXX操作?" + expected_true: true + + + - id: step2 + name: 第二步 + condition: + type: juxtaposition + params: + ocr: + all: ["确认", "完成"] + llm: + prompt: "该步是否显示了完成状态?" 
+ expected_true: true + +success: + any_of: [step2] +```""" + + +def generate_user_prompt(task_descriptions: List[str], template_yaml: str, app_name: str = "unknown") -> str: + """ + 生成用户提示词 + + Args: + task_descriptions: 任务描述列表 + template_yaml: 参考模版YAML内容 + app_name: 应用名称 + + Returns: + 完整的用户提示词 + """ + prompt = f"""请根据以下任务描述和参考模版,生成一个完整的任务验证DAG配置: + +## 任务描述列表 +{json.dumps(task_descriptions, ensure_ascii=False, indent=2)} + +## 参考模版配置 +以下是参考模版的结构和格式: + +```yaml +{template_yaml} +``` + +{CONFIG_REQUIREMENTS} + +{CONFIG_EXAMPLE} + +请直接返回完整的YAML配置,不要包含其他解释文字。""" + + return prompt.strip() diff --git a/MobiFlow/auto_rules/run_all.sh b/MobiFlow/auto_rules/run_all.sh new file mode 100644 index 0000000..9a5f0dd --- /dev/null +++ b/MobiFlow/auto_rules/run_all.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env zsh +# run_all.sh 放在项目根目录里直接 ./run_all.sh 即可 + +set -euo pipefail + +DATA_ROOT="../data" +RULE_ROOT="../task_rules" + +# 可选:自动创建 task_rules 目录 +# mkdir -p "$RULE_ROOT" + +# 遍历 data 下的一级目录(排除 .7z 文件) +for app_dir in "$DATA_ROOT"/*(/); do + app_name="${app_dir:t}" # bilibili、cloudmusic … + mkdir -p "$RULE_ROOT/$app_name" # 确保输出目录存在 + + # 遍历该目录下的 type* 子目录 + for type_dir in "$app_dir"/type*(/); do + type_name="${type_dir:t}" # type1、type2 … + out_file="$RULE_ROOT/$app_name/${app_name}-${type_name}.yaml" + + echo "→ Processing $type_dir → $out_file" + python main.py "$type_dir" --output-file "$out_file" + done +done + +echo "All done!" 
\ No newline at end of file diff --git a/MobiFlow/auto_rules/task_extractor.py b/MobiFlow/auto_rules/task_extractor.py new file mode 100644 index 0000000..cbfdbe0 --- /dev/null +++ b/MobiFlow/auto_rules/task_extractor.py @@ -0,0 +1,142 @@ +import json +import os +from pathlib import Path +from typing import List, Dict, Any, Optional +import logging + +logger = logging.getLogger(__name__) + + +class TaskDescriptionExtractor: + """任务描述提取器""" + + def __init__(self): + self.supported_files = ['actions.json'] + + def extract_from_directory(self, directory_path: str) -> List[Dict[str, Any]]: + """ + 从指定目录及其子目录中提取所有任务描述 + + Args: + directory_path: 目标目录路径 + + Returns: + 包含任务描述信息的字典列表 + """ + task_descriptions = [] + directory = Path(directory_path) + + if not directory.exists(): + logger.error(f"目录不存在: {directory_path}") + return task_descriptions + + # 遍历目录及子目录 + for root, dirs, files in os.walk(directory): + for file in files: + if file in self.supported_files: + file_path = Path(root) / file + task_info = self._extract_from_file(file_path) + if task_info: + task_info['source_path'] = str(file_path) + task_info['source_dir'] = str(Path(root)) + task_descriptions.append(task_info) + + logger.info(f"从 {directory_path} 提取到 {len(task_descriptions)} 个任务描述") + return task_descriptions + + def _extract_from_file(self, file_path: Path) -> Optional[Dict[str, Any]]: + """ + 从单个文件中提取任务描述 + + Args: + file_path: 文件路径 + + Returns: + 任务信息字典或None + """ + try: + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + # 提取关键信息 + task_info = { + 'task_description': data.get('task_description', ''), + 'old_task_description': data.get('old_task_description', ''), + 'app_name': data.get('app_name', ''), + 'task_type': data.get('task_type'), + 'action_count': data.get('action_count', 0), + 'actions': data.get('actions', []) + } + + # 验证必要字段 + if not task_info['task_description']: + logger.warning(f"文件 {file_path} 中缺少task_description字段") + return None + + return task_info + 
+ except json.JSONDecodeError as e: + logger.error(f"解析JSON文件失败 {file_path}: {e}") + return None + except Exception as e: + logger.error(f"读取文件失败 {file_path}: {e}") + return None + + def group_by_app(self, task_descriptions: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: + """ + 按应用名称分组任务描述 + + Args: + task_descriptions: 任务描述列表 + + Returns: + 按app_name分组的字典 + """ + grouped = {} + for task in task_descriptions: + app_name = task.get('app_name', 'unknown') + if app_name not in grouped: + grouped[app_name] = [] + grouped[app_name].append(task) + + return grouped + + def group_by_task_type(self, task_descriptions: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: + """ + 按任务类型分组任务描述 + + Args: + task_descriptions: 任务描述列表 + + Returns: + 按task_type分组的字典 + """ + grouped = {} + for task in task_descriptions: + task_type = task.get('task_type', 'general') + if task_type not in grouped: + grouped[task_type] = [] + grouped[task_type].append(task) + + return grouped + + +if __name__ == "__main__": + # 测试代码 + logging.basicConfig(level=logging.INFO) + + extractor = TaskDescriptionExtractor() + + # 测试提取 + test_dir = "./data" + task_descriptions = extractor.extract_from_directory(test_dir) + + print(f"提取到 {len(task_descriptions)} 个任务描述:") + for task in task_descriptions: + print(f"- {task['app_name']}: {task['task_description']}") + + # 测试分组 + grouped_by_app = extractor.group_by_app(task_descriptions) + print(f"\n按应用分组:") + for app, tasks in grouped_by_app.items(): + print(f"- {app}: {len(tasks)} 个任务") diff --git a/MobiFlow/auto_rules/test_output_examples.yaml b/MobiFlow/auto_rules/test_output_examples.yaml new file mode 100644 index 0000000..cd8c1a1 --- /dev/null +++ b/MobiFlow/auto_rules/test_output_examples.yaml @@ -0,0 +1,72 @@ +task_id: taobao_search_and_add_to_cart +app_id: com.taobao.taobao +task_type: shopping +description: 在淘宝应用中,根据指令搜索指定商品并将其成功加入购物车的任务验证配置。 +nodes: +- id: launch_app + name: 启动淘宝应用 + condition: + type: escalate + params: + ocr: + any: 
["淘宝", "首页", "推荐", "消息", "购物车", "我的淘宝"] + llm: + prompt: 当前界面是否为淘宝应用的首页或主界面? + expected_true: true + next: [search_entry] +- id: search_entry + name: 进入搜索界面 + condition: + type: escalate + params: + ocr: + any: ["搜索", "发现", "扫一扫", "拍照搜", "search", "🔍"] + llm: + prompt: 用户是否已经点击了顶部的搜索框或通过其他方式进入了搜索输入界面? + expected_true: true + next: [input_search_keyword] +- id: input_search_keyword + name: 输入搜索关键词 + condition: + type: escalate + params: + ocr: + any: ["搜索", "取消", "清空", "键盘", "输入法"] + llm: + prompt: 用户是否在搜索框中输入了与任务描述(task_description)中指定的商品相关的关键词? + expected_true: true + next: [view_search_results] +- id: view_search_results + name: 查看搜索结果列表 + condition: + type: escalate + params: + ocr: + any: ["综合", "销量", "筛选", "店铺", "价格", "¥", "广告"] + llm: + prompt: 当前界面是否展示了与搜索关键词相关的商品结果列表? + expected_true: true + next: [select_product] +- id: select_product + name: 选择商品进入详情页 + condition: + type: escalate + params: + ocr: + any: ["商品详情", "宝贝详情", "评价", "参数", "客服", "店铺", "购物车", "加入购物车", "立即购买"] + llm: + prompt: 用户是否从商品列表中选择了一个目标商品并成功进入了其详情页面? 
+ expected_true: true + next: [add_to_cart] +- id: add_to_cart + name: 将商品加入购物车 + condition: + type: juxtaposition + params: + ocr: + any: ["加入购物车", "添加成功", "已加入", "规格", "颜色", "尺码", "数量", "确定"] + llm: + prompt: 根据当前界面信息判断,用户是否成功将任务描述(task_description)中指定的商品加入了购物车?请注意检查是否有确认提示或规格选择步骤。 + expected_true: true +success: + any_of: [add_to_cart] diff --git a/MobiFlow/avdag/__init__.py b/MobiFlow/avdag/__init__.py new file mode 100644 index 0000000..0729834 --- /dev/null +++ b/MobiFlow/avdag/__init__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +# 空包,用于暴露公共 API +from .types import Frame, VerifyResult +from .loader import load_task +from .verifier import verify_task diff --git a/MobiFlow/avdag/conditions.py b/MobiFlow/avdag/conditions.py new file mode 100644 index 0000000..7bb70d4 --- /dev/null +++ b/MobiFlow/avdag/conditions.py @@ -0,0 +1,816 @@ +from __future__ import annotations +from typing import Dict, Callable, Any, List, Optional +import re +import numpy as np + +from .types import Frame, VerifierOptions +from .logger import get_condition_logger + +class ConditionChecker: + def check(self, frame: Frame, params: Dict[str, Any], options: Optional[VerifierOptions] = None) -> bool: + raise NotImplementedError + +_REGISTRY: Dict[str, ConditionChecker] = {} + +def register_condition(name: str): + def _wrap(cls): + _REGISTRY[name] = cls() + return cls + return _wrap + + +def get_checker(name: str) -> ConditionChecker: + if name not in _REGISTRY: + raise KeyError(f"Unknown condition type: {name}") + return _REGISTRY[name] + + +@register_condition("text_match") +class TextMatch(ConditionChecker): + def check(self, frame: Frame, params: Dict[str, Any], options: Optional[VerifierOptions] = None) -> bool: + text = str(frame.get("text", "")) + any_words: List[str] = params.get("any", []) + all_words: List[str] = params.get("all", []) + if any_words and not any(w in text for w in any_words): + return False + if all_words and not all(w in text for w in all_words): + 
return False
+        return bool(any_words or all_words)
+
+
+@register_condition("regex_match")
+class RegexMatch(ConditionChecker):
+    """Regex checker: search frame["text"] for params["pattern"]."""
+
+    def check(self, frame: Frame, params: Dict[str, Any], options: Optional[VerifierOptions] = None) -> bool:
+        text = str(frame.get("text", ""))
+        pattern = params.get("pattern")
+        # No pattern configured -> nothing to match, fail closed.
+        if not pattern:
+            return False
+        flags = 0
+        # Optional case-insensitive matching.
+        if params.get("ignore_case"):
+            flags |= re.IGNORECASE
+        return re.search(pattern, text, flags) is not None
+
+
+@register_condition("ui_flag")
+class UIFlag(ConditionChecker):
+    """UI-state checker: compare frame["ui"][key] against an expected value."""
+
+    def check(self, frame: Frame, params: Dict[str, Any], options: Optional[VerifierOptions] = None) -> bool:
+        ui = frame.get("ui", {}) or {}
+        key = params.get("key")
+        if key is None:
+            return False
+        value = ui.get(key)
+        # "equals": exact comparison; "in": membership in a candidate list.
+        if "equals" in params:
+            return value == params["equals"]
+        if "in" in params and isinstance(params["in"], list):
+            return value in params["in"]
+        # Neither comparator configured: mere presence of the key counts.
+        return value is not None
+
+
+@register_condition("xml_text_match")
+class XmlTextMatch(ConditionChecker):
+    """Keyword checker over the frame's raw XML dump (any/all semantics)."""
+
+    def check(self, frame: Frame, params: Dict[str, Any], options: Optional[VerifierOptions] = None) -> bool:
+        # Accept either "xml_text" or "xml" as the source field.
+        xml_text = str((frame.get("xml_text") or frame.get("xml") or ""))
+        any_words: List[str] = params.get("any", [])
+        all_words: List[str] = params.get("all", [])
+        if any_words and not any(w in xml_text for w in any_words):
+            return False
+        if all_words and not all(w in xml_text for w in all_words):
+            return False
+        # Only succeed when at least one keyword list was configured.
+        return bool(any_words or all_words)
+
+
+@register_condition("escalate")
+class EscalateChecker(ConditionChecker):
+    """Try sub-checkers in escalation order; stop as soon as one returns True.
+
+    Executes strictly in escalation_order:
+    ["text", "regex", "ui", "action", "dynamic_match", "ocr", "llm"]
+
+    params may contain any of the sub-condition configs:
+    - text: text-match parameters
+    - regex: regular-expression match parameters
+    - ui: UI-state check parameters
+    - action: action-match parameters
+    - dynamic_match: dynamic-match parameters
+    - ocr: OCR recognition parameters
+    - llm: LLM verification parameters
+
+    Each configured checker is tried in escalation_order; the first one that
+    returns True short-circuits the whole check to True.
+    """
+
+    def __init__(self):
+        self._logger = get_condition_logger()
+
+    def check(self, frame: 
Frame, params: Dict[str, Any], options: Optional[VerifierOptions] = None) -> bool: + options = options or VerifierOptions() + order = options.escalation_order + + # 如果强制LLM验证且配置了LLM和LLM条件,则只检查LLM + if (options.force_llm_verification and + options.llm is not None and + params.get("llm") is not None): + self._logger.debug(f"强制LLM验证模式,frame索引: {frame.get('_index', '?')}") + ctx = { + "frame": frame, + "params": params["llm"], + "options": options, # 传递options给LLM函数 + } + res = options.llm(ctx) + self._logger.debug(f"LLM验证结果: {res}") + return res is True + + # 严格按照 escalation_order 顺序执行检查器 + # self._logger.debug(f"升级顺序: {order}") + self._logger.debug(f"配置的检查器: {list(params.keys())}") + + for checker_name in order: + # 检查当前检查器是否在params中配置 + if params.get(checker_name) is not None: + self._logger.debug(f"尝试检查器: {checker_name}") + + try: + result = False + + if checker_name == "text": + result = get_checker("text_match").check(frame, params["text"], options) + + elif checker_name == "regex": + result = get_checker("regex_match").check(frame, params["regex"], options) + + elif checker_name == "ui": + result = get_checker("ui_flag").check(frame, params["ui"], options) + + elif checker_name == "action": + # 处理 action 的两种配置方式 + action_params = params["action"] + if isinstance(action_params, dict) and action_params.get("type") == "action_match": + result = get_checker("action_match").check(frame, action_params.get("params") or {}, options) + else: + result = get_checker("action_match").check(frame, action_params, options) + + # elif checker_name == "xml": + # result = get_checker("xml_text_match").check(frame, params["xml"], options) + + elif checker_name == "dynamic_match": + result = get_checker("dynamic_match").check(frame, params["dynamic_match"], options) + + elif checker_name == "icons": + # 使用专门的图标检查器 + self._logger.debug(f"调用图标检查器,frame索引: {frame.get('_index', '未知')}") + result = get_checker("icons_match").check(frame, params["icons"], options) + 
self._logger.debug(f"图标检查结果: {result}") + # 如果图标检查失败且未配置LLM,则直接返回结果 + if not result and options.llm is None: + self._logger.debug(f"图标检查失败,未配置LLM,frame索引: {frame.get('_index', '未知')}") + return False + return result + elif checker_name == "ocr" and options.ocr is not None: + # 使用专门的OCR检查器 + self._logger.debug(f"调用OCR检查器,frame索引: {frame.get('_index', '未知')}") + result = get_checker("ocr_match").check(frame, params["ocr"], options) + self._logger.debug(f"OCR检查结果: {result}") + # TODO: 当前暂时避免ocr检测任务不满足时,总是调用llm检测 + # 若注释,则不管OCR一旦检测为不满足,都会继续尝试LLM验证 + # return result + + elif checker_name == "llm" and options.llm is not None: + ctx = { + "frame": frame, + "params": params["llm"], + "options": options, # 传递options给LLM函数 + } + llm_result = options.llm(ctx) + result = llm_result is True + + self._logger.debug(f"{checker_name} 检查结果: {result}") + + # 如果当前检查器返回True,立即返回True(escalate的核心逻辑) + if result: + self._logger.debug(f"{checker_name} 检查成功,立即返回True") + return True + + except Exception as e: + self._logger.warning(f"{checker_name} 检查器执行失败: {e}") + continue + else: + if checker_name in ["ocr", "icons"]: + self._logger.debug(f"跳过未配置的检查器: {checker_name}") + + # 所有配置的检查器都失败 + self._logger.debug("所有检查器都失败,返回False") + return False + + +@register_condition("juxtaposition") +class JuxtapositionChecker(ConditionChecker): + """并列检查器:要求所有配置的检查器都必须通过且结果一致。 + + params 可包含多个子条件配置: + - text / regex / ui / xml: 与对应基础检查器兼容的参数 + - action: 动作匹配参数 + - dynamic_match: 动态匹配参数 + - ocr: OCR识别参数 + - llm: LLM验证参数 + + 所有配置的检查器都必须返回 True,才认为该节点验证成功。 + """ + + def __init__(self): + self._logger = get_condition_logger() + + def check(self, frame: Frame, params: Dict[str, Any], options: Optional[VerifierOptions] = None) -> bool: + options = options or VerifierOptions() + + # 收集所有配置的检查器及其结果 + configured_checkers = [] + results = [] + + # 1) text 检查 + if params.get("text") is not None: + configured_checkers.append("text_match") + result = get_checker("text_match").check(frame, params["text"], options) + 
results.append(result) + self._logger.debug(f"text_match 结果: {result}") + if not result: + self._logger.debug("text_match 检查失败,跳过后续检查") + return False + + # 2) regex 检查 + if params.get("regex") is not None: + configured_checkers.append("regex_match") + result = get_checker("regex_match").check(frame, params["regex"], options) + results.append(result) + self._logger.debug(f"regex_match 结果: {result}") + if not result: + self._logger.debug("regex_match 检查失败,跳过后续检查") + return False + + # 3) ui 检查 + if params.get("ui") is not None: + configured_checkers.append("ui_flag") + result = get_checker("ui_flag").check(frame, params["ui"], options) + results.append(result) + self._logger.debug(f"ui_flag 结果: {result}") + if not result: + self._logger.debug("ui_flag 检查失败,跳过后续检查") + return False + + # 4) action 检查 + if params.get("action") is not None: + configured_checkers.append("action_match") + # 处理嵌套的action配置 + action_params = params["action"] + if isinstance(action_params, dict) and action_params.get("type") == "action_match": + result = get_checker("action_match").check(frame, action_params.get("params") or {}, options) + else: + result = get_checker("action_match").check(frame, action_params, options) + results.append(result) + self._logger.debug(f"action_match 结果: {result}") + if not result: + self._logger.debug("action_match 检查失败,跳过后续检查") + return False + + # 5) xml 检查 + if params.get("xml") is not None: + configured_checkers.append("xml_text_match") + result = get_checker("xml_text_match").check(frame, params["xml"], options) + results.append(result) + self._logger.debug(f"xml_text_match 结果: {result}") + if not result: + self._logger.debug("xml_text_match 检查失败,跳过后续检查") + return False + + # 6) dynamic_match 检查 + if params.get("dynamic_match") is not None: + configured_checkers.append("dynamic_match") + result = get_checker("dynamic_match").check(frame, params["dynamic_match"], options) + results.append(result) + self._logger.debug(f"dynamic_match 结果: {result}") + if not 
result: + self._logger.debug("dynamic_match 检查失败,跳过后续检查") + return False + + # 7) icons 检查 + if params.get("icons") is not None: + configured_checkers.append("icons") + # 使用专门的图标检查器 + icons_result = get_checker("icons_match").check(frame, params["icons"], options) + results.append(icons_result) + self._logger.debug(f"图标检测最终结果: {icons_result}") + if not icons_result: + self._logger.debug("图标检测失败,跳过后续检查") + return False + + # 8) ocr 检查(需要 options.ocr 支持) + if params.get("ocr") is not None and options.ocr is not None: + configured_checkers.append("ocr") + # 使用专门的OCR检查器 + ocr_result = get_checker("ocr_match").check(frame, params["ocr"], options) + results.append(ocr_result) + self._logger.debug(f"OCR最终结果: {ocr_result}") + if not ocr_result: + self._logger.debug("OCR 检查失败,跳过后续检查") + return False + + # 9) llm 检查(需要 options.llm 支持) + if params.get("llm") is not None and options.llm is not None: + configured_checkers.append("llm") + ctx = { + "frame": frame, + "params": params["llm"], + "options": options, # 传递options给LLM函数 + } + llm_result = options.llm(ctx) + results.append(llm_result is True) + self._logger.debug(f"llm 结果: {llm_result is True}") + + # 检查是否至少配置了一个检查器 + if not configured_checkers: + self._logger.warning("没有配置任何检查器") + return False + + # 所有配置的检查器都必须返回 True + final_result = all(results) + self._logger.debug(f"配置的检查器: {configured_checkers}") + self._logger.debug(f"各检查器结果: {results}") + self._logger.debug(f"最终结果: {final_result}") + + return final_result + + +@register_condition("ocr_match") +class OCRMatch(ConditionChecker): + """OCR匹配检查器,使用增强的文本处理功能""" + + def __init__(self): + self._logger = get_condition_logger() + + def check(self, frame: Frame, params: Dict[str, Any], options: Optional[VerifierOptions] = None) -> bool: + self._logger.debug("=====开始OCR匹配检查=====") + self._logger.debug(f"frame索引: {frame.get('_index', '未知')}") + self._logger.debug(f"检查params: {params}") + self._logger.debug(f"options存在: {options is not None}") + 
self._logger.debug(f"options.ocr存在: {options.ocr is not None if options else False}") + + # 初始化结果记录 + matched_keywords = [] + unmatched_keywords = [] + checker_result = "" + + if not options or not options.ocr: + self._logger.warning("OCR选项不可用,返回False") + checker_result = "OCR选项不可用" + # 在frame中记录检查结果 + frame['_last_ocr_result'] = { + 'success': False, + 'reason': checker_result, + 'matched_keywords': matched_keywords, + 'unmatched_keywords': unmatched_keywords + } + return False + + # 获取OCR文本 + ocr_text = options.ocr(frame) or "" + if not ocr_text.strip(): + checker_result = "OCR识别文本为空" + frame['_last_ocr_result'] = { + 'success': False, + 'reason': checker_result, + 'matched_keywords': matched_keywords, + 'unmatched_keywords': unmatched_keywords + } + return False + + # self._logger.debug(f"OCR原始返回: {ocr_text[:100]}...") + self._logger.debug(f"OCR原始返回: {ocr_text}") + + # 检查是否有缓存的处理结果 + processed_result = frame.get('_ocr_processed') or frame.get('_xml_processed') + + # 文本包含匹配 - any条件 + if "any" in params: + any_keywords = params["any"] + matched_any = [] + + # 方式1:原始OCR文本检查 + for w in any_keywords: + if w in ocr_text: + matched_any.append(w) + + if matched_any: + self._logger.debug(f"any匹配(原始): {matched_any}") + matched_keywords.extend(matched_any) + checker_result = f"OCR识别成功,匹配关键词: {matched_any}" + frame['_last_ocr_result'] = { + 'success': True, + 'reason': checker_result, + 'matched_keywords': matched_keywords, + 'unmatched_keywords': unmatched_keywords + } + return True + + # 方式2:智能匹配 + if processed_result: + try: + from .ocr_processor import get_ocr_processor + processor = get_ocr_processor() + for keyword in any_keywords: + if processor.smart_text_contains(processed_result, keyword): + self._logger.debug(f"any匹配(智能): {keyword}") + matched_keywords.append(keyword) + checker_result = f"OCR智能匹配成功,匹配关键词: {keyword}" + frame['_last_ocr_result'] = { + 'success': True, + 'reason': checker_result, + 'matched_keywords': matched_keywords, + 'unmatched_keywords': 
unmatched_keywords + } + return True + except ImportError: + self._logger.warning("OCRProcessor不可用,使用基础匹配") + + # 记录所有未匹配的any关键词 + unmatched_keywords.extend([w for w in any_keywords if w not in matched_any]) + + # 文本包含匹配 - all条件 + if "all" in params: + all_keywords = params["all"] + matched_all = [] + + # 方式1:原始OCR文本检查 + for w in all_keywords: + if w in ocr_text: + matched_all.append(w) + + if len(matched_all) == len(all_keywords): + self._logger.debug(f"all匹配(原始): {all_keywords}") + matched_keywords.extend(matched_all) + checker_result = f"OCR识别成功,匹配所有关键词: {all_keywords}" + frame['_last_ocr_result'] = { + 'success': True, + 'reason': checker_result, + 'matched_keywords': matched_keywords, + 'unmatched_keywords': unmatched_keywords + } + return True + # 方式2:智能匹配 + elif processed_result: + try: + from .ocr_processor import get_ocr_processor + processor = get_ocr_processor() + smart_matched = [] + for keyword in all_keywords: + if processor.smart_text_contains(processed_result, keyword): + smart_matched.append(keyword) + + if len(smart_matched) == len(all_keywords): + self._logger.debug(f"all匹配(智能): {all_keywords}") + matched_keywords.extend(smart_matched) + checker_result = f"OCR智能匹配成功,匹配所有关键词: {all_keywords}" + frame['_last_ocr_result'] = { + 'success': True, + 'reason': checker_result, + 'matched_keywords': matched_keywords, + 'unmatched_keywords': unmatched_keywords + } + return True + else: + # 记录智能匹配下未匹配的关键词 + unmatched_keywords.extend([w for w in all_keywords if w not in smart_matched]) + matched_keywords.extend(smart_matched) + except ImportError: + self._logger.warning("OCRProcessor不可用,使用基础匹配") + # 记录原始匹配下未匹配的关键词 + unmatched_keywords.extend([w for w in all_keywords if w not in matched_all]) + matched_keywords.extend(matched_all) + else: + # 记录原始匹配下未匹配的关键词 + unmatched_keywords.extend([w for w in all_keywords if w not in matched_all]) + matched_keywords.extend(matched_all) + + # 正则匹配 + if "pattern" in params: + pattern = params["pattern"] + flags = 
re.IGNORECASE if params.get("ignore_case") else 0 + + # 方式1:对原始文本应用正则 + if re.search(pattern, ocr_text, flags): + self._logger.debug(f"正则匹配(原始): {pattern}") + checker_result = f"OCR正则匹配成功,模式: {pattern}" + frame['_last_ocr_result'] = { + 'success': True, + 'reason': checker_result, + 'matched_keywords': matched_keywords, + 'unmatched_keywords': unmatched_keywords + } + return True + # 方式2:对处理后的文本格式应用正则 + elif processed_result: + for text_format in [processed_result.cleaned, processed_result.no_spaces, ' '.join(processed_result.words)]: + if text_format and re.search(pattern, text_format, flags): + self._logger.debug(f"正则匹配(处理): {pattern} -> {text_format[:50]}...") + checker_result = f"OCR智能正则匹配成功,模式: {pattern}" + frame['_last_ocr_result'] = { + 'success': True, + 'reason': checker_result, + 'matched_keywords': matched_keywords, + 'unmatched_keywords': unmatched_keywords + } + return True + + # 正则匹配失败,记录模式 + unmatched_keywords.append(f"pattern: {pattern}") + + # 构建失败原因 + if unmatched_keywords: + checker_result = f"OCR识别失败,未匹配关键词: {unmatched_keywords}" + self._logger.debug(f"未匹配的关键词: {unmatched_keywords}") + else: + checker_result = "OCR识别失败,无匹配条件" + + if processed_result: + self._logger.debug(f"check keywords: any: {params.get('any', [])} / all: {params.get('all', [])}") + # self._logger.debug(f"处理文本格式 - 清理: {processed_result.cleaned[:50]}...") + # self._logger.debug(f"处理文本格式 - 无空格: {processed_result.no_spaces[:50]}...") + self._logger.debug(f"处理文本格式 - 清理: {processed_result.cleaned}") + self._logger.debug(f"处理文本格式 - 无空格: {processed_result.no_spaces}") + self._logger.debug(f"处理文本格式 - 词语数: {len(processed_result.words)}") + + # 记录失败结果 + frame['_last_ocr_result'] = { + 'success': False, + 'reason': checker_result, + 'matched_keywords': matched_keywords, + 'unmatched_keywords': unmatched_keywords + } + + return False + +__all__ = ["ConditionChecker", "register_condition", "get_checker"] + + +@register_condition("action_match") +class ActionMatch(ConditionChecker): + def 
check(self, frame: Frame, params: Dict[str, Any], options: Optional[VerifierOptions] = None) -> bool: + act = frame.get("action") or {} + if not isinstance(act, dict): + return False + t = params.get("type") + if t and act.get("type") != t: + return False + contains: Dict[str, Any] = params.get("contains") or {} + for k, v in contains.items(): + if act.get(k) != v: + return False + return True if (t or contains) else False + + +@register_condition("dynamic_match") +class DynamicMatchChecker(ConditionChecker): + """基于动态配置的通用匹配检查器,支持从任务描述中提取关键信息进行匹配""" + + def check(self, frame: Frame, params: Dict[str, Any], options: Optional[VerifierOptions] = None) -> bool: + """ + 动态匹配检查器,支持多种匹配策略: + + params 支持的配置: + - extract_from: 指定从哪个字段提取信息 (task_description, reasoning等) + - condition_patterns: 条件模式映射,每个模式包含匹配关键词和对应的验证关键词 + - verification_fields: 验证字段列表,指定在哪些字段中查找验证关键词 + - fallback_llm: 当基础匹配不确定时是否使用LLM验证 + """ + extract_from = params.get("extract_from", "task_description") + condition_patterns = params.get("condition_patterns", {}) + verification_fields = params.get("verification_fields", ["reasoning", "text"]) + + # 提取源文本 + source_text = frame.get(extract_from, "").lower() + if not source_text: + return False + + # 找到匹配的条件模式 + matched_condition = None + for condition_name, pattern_config in condition_patterns.items(): + trigger_keywords = pattern_config.get("trigger_keywords", []) + if any(keyword.lower() in source_text for keyword in trigger_keywords): + matched_condition = condition_name + break + + if not matched_condition: + return False + + # 获取对应的验证关键词 + pattern_config = condition_patterns[matched_condition] + verify_keywords = pattern_config.get("verify_keywords", []) + + # 在验证字段中查找关键词 + for field in verification_fields: + field_text = frame.get(field, "").lower() + if any(keyword.lower() in field_text for keyword in verify_keywords): + return True + + # 如果基础匹配失败,使用LLM作为后备验证 + if params.get("fallback_llm") and options and options.llm: + llm_prompt = 
pattern_config.get("llm_prompt") or f"该步骤是否执行了与'{matched_condition}'相关的操作?" + ctx = { + "frame": frame, + "params": { + "prompt": llm_prompt, + "expected_true": True + }, + "options": options, # 传递options给LLM函数 + } + return options.llm(ctx) is True + + return False + + +@register_condition("icons_match") +class IconsMatch(ConditionChecker): + """图标匹配检查器,使用图像模板匹配检测图标是否存在""" + + def __init__(self): + self._logger = get_condition_logger() + self._detection_service = None + + def _get_detection_service(self): + """延迟导入图标检测服务""" + if self._detection_service is None: + try: + # 延迟导入避免循环依赖 + import sys + from pathlib import Path + project_root = Path(__file__).parent.parent + sys.path.insert(0, str(project_root)) + + from tools.Icon_detection import get_icon_detection_service + self._detection_service = get_icon_detection_service() + self._logger.debug("图标检测服务初始化成功") + except Exception as e: + self._logger.error(f"初始化图标检测服务失败: {e}") + self._detection_service = None + return self._detection_service + + def _extract_image_from_frame(self, frame: Frame) -> Optional[np.ndarray]: + """从frame中提取图像数据""" + # 检查frame中可能的图像字段,当前是使用frame(字典)中的img存储图像文件的完整路径 + image_fields = ['img', 'screenshot', 'image', 'frame_image', 'screen'] + + for field in image_fields: + if field in frame and frame[field] is not None: + image_data = frame[field] + + # 如果是文件路径 + if isinstance(image_data, str): + try: + import cv2 + img = cv2.imread(image_data) + if img is not None: + self._logger.debug(f"从路径加载图像: {image_data}") + return img + except Exception as e: + self._logger.warning(f"从路径加载图像失败 {image_data}: {e}") + continue + + # 如果是numpy数组 + elif isinstance(image_data, np.ndarray): + self._logger.debug(f"从字段 {field} 获取图像数据,形状: {image_data.shape}") + return image_data + + # 如果是字节数据 + elif isinstance(image_data, (bytes, bytearray)): + try: + import cv2 + nparr = np.frombuffer(image_data, np.uint8) + img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + if img is not None: + self._logger.debug(f"从字节数据解码图像成功") + 
return img + except Exception as e: + self._logger.warning(f"从字节数据解码图像失败: {e}") + continue + + self._logger.warning("未在frame中找到有效的图像数据") + return None + + def check(self, frame: Frame, params: Dict[str, Any], options: Optional[VerifierOptions] = None) -> bool: + self._logger.debug("=====开始图标匹配检查=====") + self._logger.debug(f"frame索引: {frame.get('_index', '未知')}") + self._logger.debug(f"检查params: {params}") + + # 初始化结果记录 + matched_icons = [] + unmatched_icons = [] + checker_result = "" + + # 获取图标检测服务 + detection_service = self._get_detection_service() + if detection_service is None: + checker_result = "图标检测服务不可用" + frame['_last_icons_result'] = { + 'success': False, + 'reason': checker_result, + 'matched_icons': matched_icons, + 'unmatched_icons': unmatched_icons + } + self._logger.error(checker_result) + return False + + # 提取图像数据 + image = self._extract_image_from_frame(frame) + if image is None: + checker_result = "无法从frame中提取图像数据" + frame['_last_icons_result'] = { + 'success': False, + 'reason': checker_result, + 'matched_icons': matched_icons, + 'unmatched_icons': unmatched_icons + } + self._logger.error(checker_result) + return False + + # 获取应用ID(用于确定图标搜索路径) + app_id = frame.get('app_id') or frame.get('package_name') or frame.get('app_name') + + # 获取相似度阈值 + threshold = params.get('threshold') + + # 处理any条件 + if "any" in params: + any_icons = params["any"] + if not isinstance(any_icons, list): + any_icons = [any_icons] + + self._logger.debug(f"检查any图标: {any_icons}") + + result = detection_service.detect_icons( + image, + any_icons, + app_id, + threshold, + match_mode='any' + ) + + if result['success']: + matched_icons.extend(result['matched_icons']) + self._logger.debug(f"any图标匹配成功: {result['matched_icons']}") + # 记录成功结果 + frame['_last_icons_result'] = { + 'success': True, + 'reason': f"成功匹配图标: {result['matched_icons']}", + 'matched_icons': result['matched_icons'], + 'unmatched_icons': result['unmatched_icons'], + 'details': result['details'] + } + return True + 
else: + unmatched_icons.extend(result['unmatched_icons']) + self._logger.debug(f"any图标匹配失败: {result['unmatched_icons']}") + + # 处理all条件 + if "all" in params: + all_icons = params["all"] + if not isinstance(all_icons, list): + all_icons = [all_icons] + + self._logger.debug(f"检查all图标: {all_icons}") + + result = detection_service.detect_icons( + image, + all_icons, + app_id, + threshold, + match_mode='all' + ) + + if result['success']: + matched_icons.extend(result['matched_icons']) + self._logger.debug(f"all图标匹配成功: {result['matched_icons']}") + # 记录成功结果 + frame['_last_icons_result'] = { + 'success': True, + 'reason': f"成功匹配所有图标: {result['matched_icons']}", + 'matched_icons': result['matched_icons'], + 'unmatched_icons': result['unmatched_icons'], + 'details': result['details'] + } + return True + else: + unmatched_icons.extend(result['unmatched_icons']) + self._logger.debug(f"all图标匹配失败: {result['unmatched_icons']}") + + # 构建失败原因 + if unmatched_icons: + checker_result = f"图标检测失败,未匹配图标: {unmatched_icons}" + else: + checker_result = "图标检测失败,无匹配条件" + + # 记录失败结果 + frame['_last_icons_result'] = { + 'success': False, + 'reason': checker_result, + 'matched_icons': matched_icons, + 'unmatched_icons': unmatched_icons + } + + self._logger.debug(checker_result) + return False diff --git a/MobiFlow/avdag/dag.py b/MobiFlow/avdag/dag.py new file mode 100644 index 0000000..b3b8347 --- /dev/null +++ b/MobiFlow/avdag/dag.py @@ -0,0 +1,187 @@ +from __future__ import annotations +from typing import Dict, List, Set, Tuple +from collections import defaultdict, deque + +from .types import NodeSpec, TaskSpec + +class DAG: + def __init__(self, nodes: List[NodeSpec]): + self.nodes: Dict[str, NodeSpec] = {n.id: n for n in nodes} + # 统一的邻接关系(用于拓扑排序与无环校验) + self.children: Dict[str, List[str]] = defaultdict(list) + self.parents: Dict[str, List[str]] = defaultdict(list) + # 来源细分:便于验证阶段分别按 AND/OR 语义处理 + self.parents_from_deps: Dict[str, List[str]] = defaultdict(list) + self.parents_from_next: 
Dict[str, List[str]] = defaultdict(list) + + # 执行约束检查 + self._validate_dependencies_consistency(nodes) + + for n in nodes: + for p in (n.deps or []): + self.children[p].append(n.id) + self.parents[n.id].append(p) + self.parents_from_deps[n.id].append(p) + # 根据 next 定义增加边:n -> succ + for succ in (n.next or []): + self.children[n.id].append(succ) + self.parents[succ].append(n.id) + self.parents_from_next[succ].append(n.id) + # 验证无环 + self._assert_acyclic() + + def _assert_acyclic(self): + indeg: Dict[str, int] = {nid: 0 for nid in self.nodes} + for nid, ps in self.parents.items(): + indeg[nid] = len(ps) + q = deque([nid for nid, d in indeg.items() if d == 0]) + seen = 0 + while q: + cur = q.popleft() + seen += 1 + for ch in self.children.get(cur, []): + indeg[ch] -= 1 + if indeg[ch] == 0: + q.append(ch) + if seen != len(self.nodes): + raise ValueError("Graph contains a cycle") + + def _validate_dependencies_consistency(self, nodes: List[NodeSpec]): + """验证 deps 和 next 定义的一致性,避免混淆和冗余""" + import warnings + + # 构建 next 关系映射 + next_targets: Dict[str, List[str]] = defaultdict(list) + for node in nodes: + for succ in (node.next or []): + next_targets[succ].append(node.id) + + issues = [] + + for node in nodes: + node_id = node.id + has_deps = bool(node.deps) + has_next_parents = bool(next_targets.get(node_id)) + + # 检查1: 同时定义了 deps 和通过 next 被其他节点指向 + if has_deps and has_next_parents: + next_parents = next_targets[node_id] + deps_set = set(node.deps or []) + next_set = set(next_parents) + + # 如果 deps 和 next 路径完全重复,这是冗余的 + if deps_set == next_set: + issues.append(f"节点 '{node_id}': deps {list(deps_set)} 与 next 路径来源 {list(next_set)} 完全重复,建议只使用 deps") + # 如果 deps 和 next 路径部分重叠,可能导致混淆 + elif deps_set & next_set: + overlap = deps_set & next_set + issues.append(f"节点 '{node_id}': deps {list(deps_set)} 与 next 路径来源 {list(next_set)} 存在重叠 {list(overlap)},可能导致语义混淆") + # 如果完全不重叠,提示优先级 + else: + issues.append(f"节点 '{node_id}': 同时定义了 deps {list(deps_set)} 和 next 路径来源 
{list(next_set)},将优先使用 deps (AND 语义)") + + # 检查2: 节点同时定义了 deps 和 next(虽然技术上可行,但可能混淆) + if has_deps and node.next: + issues.append(f"节点 '{node_id}': 同时定义了 deps {node.deps} 和 next {node.next},next 声明仅用于构建图结构") + + # 输出警告信息 + if issues: + warning_msg = "DAG 依赖定义一致性警告:\n" + "\n".join(f" - {issue}" for issue in issues) + warnings.warn(warning_msg, UserWarning, stacklevel=3) + + # 检查3: 验证所有引用的节点都存在 + node_ids = {n.id for n in nodes} + for node in nodes: + # 检查 deps 引用 + for dep in (node.deps or []): + if dep not in node_ids: + raise ValueError(f"节点 '{node.id}' 的 deps 引用了不存在的节点 '{dep}'") + # 检查 next 引用 + for succ in (node.next or []): + if succ not in node_ids: + raise ValueError(f"节点 '{node.id}' 的 next 引用了不存在的节点 '{succ}'") + + def topo_order(self) -> List[str]: + indeg: Dict[str, int] = {nid: 0 for nid in self.nodes} + for nid, ps in self.parents.items(): + indeg[nid] = len(ps) + q = deque([nid for nid, d in indeg.items() if d == 0]) + order: List[str] = [] + while q: + cur = q.popleft() + order.append(cur) + for ch in self.children.get(cur, []): + indeg[ch] -= 1 + if indeg[ch] == 0: + q.append(ch) + return order + + def sinks(self) -> List[str]: + return [nid for nid in self.nodes if len(self.children.get(nid, [])) == 0] + + def get_all_paths_to_targets(self, target_nodes: List[str]) -> List[List[str]]: + """获取从根节点到目标节点的所有可能路径""" + all_paths = [] + + # 找到所有根节点(无父节点的节点) + root_nodes = [nid for nid in self.nodes if len(self.parents.get(nid, [])) == 0] + + def dfs_paths(current: str, path: List[str], visited: set): + if current in visited: + return # 避免环路 + + visited.add(current) + path.append(current) + + # 如果当前节点是目标节点之一,记录路径 + if current in target_nodes: + all_paths.append(path.copy()) + + # 继续向子节点探索 + for child in self.children.get(current, []): + dfs_paths(child, path, visited.copy()) + + path.pop() + + # 从每个根节点开始探索 + for root in root_nodes: + dfs_paths(root, [], set()) + + return all_paths + + def log_possible_paths(self, success_nodes: List[str], logger): + 
"""输出配置中存在的可能路径到日志""" + logger.info("=== DAG 路径分析 ===") + + # 输出节点依赖关系概览 + logger.debug("节点依赖关系:") + for nid in self.topo_order(): + node = self.nodes[nid] + deps_info = f"deps={node.deps}" if node.deps else "deps=None" + next_info = f"next={node.next}" if node.next else "next=None" + logger.debug(f" {nid}: {deps_info}, {next_info}") + + # 输出父子关系 + logger.debug("父子关系:") + for nid in self.nodes: + deps_parents = self.parents_from_deps.get(nid, []) + next_parents = self.parents_from_next.get(nid, []) + if deps_parents: + logger.debug(f" {nid} <- {deps_parents} (deps, AND语义)") + if next_parents: + logger.debug(f" {nid} <- {next_parents} (next, OR语义)") + + # 获取并输出所有可能路径 + all_paths = self.get_all_paths_to_targets(success_nodes) + + if all_paths: + logger.info(f"发现 {len(all_paths)} 条可能的成功路径:") + for i, path in enumerate(all_paths, 1): + path_str = " -> ".join(path) + logger.info(f" 路径 {i}: {path_str}") + else: + logger.info("未发现任何可能的成功路径") + + logger.info("=== 路径分析结束 ===\n") + +__all__ = ["DAG"] diff --git a/MobiFlow/avdag/loader.py b/MobiFlow/avdag/loader.py new file mode 100644 index 0000000..c072fef --- /dev/null +++ b/MobiFlow/avdag/loader.py @@ -0,0 +1,37 @@ +from __future__ import annotations +import json +from typing import Any, Dict + +import yaml + +from .types import ConditionSpec, NodeSpec, SuccessSpec, TaskSpec + + +def _parse_node(d: Dict[str, Any]) -> NodeSpec: + cond = d.get("condition") + condition = None + if cond: + condition = ConditionSpec(type=cond.get("type"), params=cond.get("params", {})) + return NodeSpec( + id=d["id"], + name=d.get("name"), + deps=d.get("deps"), + next=d.get("next"), + condition=condition, + score=d.get("score", 10), # 默认分数为10分 + ) + + +def load_task(path: str) -> TaskSpec: + if path.endswith(".json"): + with open(path, "r", encoding="utf-8") as f: + raw = json.load(f) + else: + with open(path, "r", encoding="utf-8") as f: + raw = yaml.safe_load(f) + nodes = [_parse_node(n) for n in raw.get("nodes", [])] + succ_raw = 
raw.get("success") or {} + success = SuccessSpec(any_of=succ_raw.get("any_of"), all_of=succ_raw.get("all_of")) if succ_raw else None + return TaskSpec(task_id=raw.get("task_id", "task"), nodes=nodes, success=success) + +__all__ = ["load_task"] diff --git a/MobiFlow/avdag/logger.py b/MobiFlow/avdag/logger.py new file mode 100644 index 0000000..74ddd60 --- /dev/null +++ b/MobiFlow/avdag/logger.py @@ -0,0 +1,483 @@ +""" +统一的日志系统 - 为 avdag 框架和相关工具提供灵活的日志输出配置 + +支持多种日志级别和输出方式: +- CRITICAL: 关键错误,会导致程序无法继续执行 +- ERROR: 普通错误,不影响程序继续运行 +- WARNING: 警告信息 +- INFO: 一般信息,默认显示 +- DEBUG: 调试信息,详细的执行过程 +- TRACE: 最详细的跟踪信息 + +配置方式: +1. 环境变量: AVDAG_LOG_LEVEL=DEBUG +2. 代码配置: set_log_level('DEBUG') +3. 配置文件: 通过 configure_logging() 函数加载 + +使用方式: +```python +from avdag.logger import get_logger + +logger = get_logger(__name__) +logger.info("这是一般信息") +logger.debug("这是调试信息") +logger.error("这是错误信息") +``` +""" + +import os +import sys +import json +import logging +from typing import Optional, Dict, Any, Union, List +from enum import Enum +from pathlib import Path + +class LogLevel(Enum): + """日志级别枚举""" + CRITICAL = 50 + ERROR = 40 + WARNING = 30 + INFO = 20 + DEBUG = 10 + TRACE = 5 + + @classmethod + def from_string(cls, level_str: str) -> 'LogLevel': + """从字符串获取日志级别""" + level_map = { + 'CRITICAL': cls.CRITICAL, + 'FATAL': cls.CRITICAL, + 'ERROR': cls.ERROR, + 'WARNING': cls.WARNING, + 'WARN': cls.WARNING, + 'INFO': cls.INFO, + 'DEBUG': cls.DEBUG, + 'TRACE': cls.TRACE, + } + return level_map.get(level_str.upper(), cls.INFO) + + +class AVDAGLogger: + """AVDAG 专用日志器""" + + def __init__(self, name: str): + self.name = name + self._logger = logging.getLogger(f"avdag.{name}") + + # 添加 TRACE 级别支持 + if not hasattr(logging, 'TRACE'): + logging.addLevelName(LogLevel.TRACE.value, 'TRACE') + def trace(self, msg, *args, **kwargs): + if self.isEnabledFor(LogLevel.TRACE.value): + self._log(LogLevel.TRACE.value, msg, args, **kwargs) + logging.Logger.trace = trace + + def critical(self, msg: str, *args, 
**kwargs): + """记录关键错误信息""" + self._logger.critical(msg, *args, **kwargs) + + def error(self, msg: str, *args, **kwargs): + """记录错误信息""" + self._logger.error(msg, *args, **kwargs) + + def warning(self, msg: str, *args, **kwargs): + """记录警告信息""" + self._logger.warning(msg, *args, **kwargs) + + def info(self, msg: str, *args, **kwargs): + """记录一般信息""" + self._logger.info(msg, *args, **kwargs) + + def debug(self, msg: str, *args, **kwargs): + """记录调试信息""" + self._logger.debug(msg, *args, **kwargs) + + def trace(self, msg: str, *args, **kwargs): + """记录最详细的跟踪信息""" + if hasattr(self._logger, 'trace'): + self._logger.trace(msg, *args, **kwargs) + else: + self._logger.log(LogLevel.TRACE.value, msg, *args, **kwargs) + + def is_enabled_for(self, level: Union[str, LogLevel]) -> bool: + """检查是否启用了指定级别的日志""" + if isinstance(level, str): + level = LogLevel.from_string(level) + return self._logger.isEnabledFor(level.value) + + +class ColoredFormatter(logging.Formatter): + """带颜色的日志格式化器""" + + # ANSI 颜色代码 + COLORS = { + 'CRITICAL': '\033[95m', # 紫色 + 'ERROR': '\033[91m', # 红色 + 'WARNING': '\033[93m', # 黄色 + 'INFO': '\033[92m', # 绿色 + 'DEBUG': '\033[96m', # 青色 + 'TRACE': '\033[90m', # 灰色 + } + RESET = '\033[0m' + + def _supports_color(self) -> bool: + """检查是否支持颜色输出(跨平台)""" + # 检查是否为TTY + if not sys.stderr.isatty(): + return False + + # 检查环境变量 + if os.getenv('NO_COLOR'): + return False + + if os.getenv('FORCE_COLOR'): + return True + + # 平台特定检查 + if sys.platform == 'win32': + # Windows: 检查是否支持ANSI转义序列 + try: + # Windows 10及以上版本通常支持ANSI颜色 + import platform + version = platform.version() + # Windows 10的版本号通常是10.0.x + if version.startswith('10.0.') or version.startswith('11.'): + return True + except: + pass + + # 检查TERM环境变量 + term = os.getenv('TERM', '').lower() + if 'color' in term or 'ansi' in term: + return True + + # 默认Windows支持(现代终端如Windows Terminal、VS Code等) + return True + + else: + # Unix/Linux: 检查TERM环境变量 + term = os.getenv('TERM', '').lower() + if term in ['dumb', 
'unknown']: + return False + + # 大多数Unix终端支持颜色 + return 'color' in term or 'ansi' in term or 'xterm' in term or term in [ + 'screen', 'tmux', 'rxvt', 'konsole', 'gnome-terminal' + ] + + def __init__(self, use_colors: bool = True, show_time: bool = True, show_module: bool = True): + # 增强的颜色支持检测 + self.use_colors = use_colors and self._supports_color() + self.show_time = show_time + self.show_module = show_module + + # 构建格式字符串 + fmt_parts = [] + if show_time: + fmt_parts.append('%(asctime)s') + fmt_parts.append('[%(levelname)s]') + if show_module: + fmt_parts.append('%(name)s') + fmt_parts.append('%(message)s') + + fmt = ' '.join(fmt_parts) + super().__init__(fmt, datefmt='%H:%M:%S') + + def format(self, record): + # 创建record的副本,避免修改原始record影响其他处理器 + if self.use_colors: + # 复制record以避免修改原始对象 + import copy + record_copy = copy.copy(record) + levelname = record_copy.levelname + if levelname in self.COLORS: + record_copy.levelname = f"{self.COLORS[levelname]}{levelname}{self.RESET}" + return super().format(record_copy) + else: + return super().format(record) + + +class LoggingConfig: + """日志配置管理""" + + def __init__(self): + self._configured = False + self._loggers: Dict[str, AVDAGLogger] = {} + self._default_level = LogLevel.INFO + self._handlers: List[logging.Handler] = [] + + def configure(self, + level: Union[str, LogLevel] = LogLevel.DEBUG, + use_colors: bool = True, + show_time: bool = True, + show_module: bool = True, + output_file: Optional[str] = None, + config_file: Optional[str] = None) -> None: + """配置日志系统 + + Args: + level: 日志级别 + use_colors: 是否使用颜色输出 + show_time: 是否显示时间 + show_module: 是否显示模块名 + output_file: 输出到文件(可选) + config_file: 从配置文件加载(可选) + """ + + # 从配置文件加载设置 + if config_file and Path(config_file).exists(): + with open(config_file, 'r', encoding='utf-8') as f: + config = json.load(f) + level = level or config.get('level', 'INFO') + use_colors = config.get('use_colors', use_colors) + show_time = config.get('show_time', show_time) + show_module = 
config.get('show_module', show_module) + output_file = output_file or config.get('output_file') + + # 从环境变量获取级别 + if level is None: + level = os.getenv('AVDAG_LOG_LEVEL', 'INFO') + + if isinstance(level, str): + level = LogLevel.from_string(level) + + self._default_level = level + + # 配置根日志器 + root_logger = logging.getLogger('avdag') + root_logger.setLevel(level.value) + + # 移除现有处理器(更彻底的清理) + for handler in self._handlers: + root_logger.removeHandler(handler) + self._handlers.clear() + + # 清除根日志器上的所有处理器(防止重复配置导致的问题) + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + + # 总是创建控制台处理器 + console_handler = logging.StreamHandler(sys.stderr) + console_formatter = ColoredFormatter(use_colors, show_time, show_module) + console_handler.setFormatter(console_formatter) + root_logger.addHandler(console_handler) + self._handlers.append(console_handler) + + # 如果指定了输出文件,同时创建文件处理器(不使用颜色) + if output_file: + try: + # 规范化文件路径(跨平台兼容) + output_path = Path(output_file).resolve() + # 确保目录存在 + output_path.parent.mkdir(parents=True, exist_ok=True) + + # 创建文件处理器,强制使用UTF-8编码 + file_handler = logging.FileHandler( + str(output_path), + encoding='utf-8', + mode='a' # 追加模式,避免覆盖现有日志 + ) + + # 构建文件格式字符串(与控制台格式一致,但不包含颜色) + fmt_parts = [] + if show_time: + fmt_parts.append('%(asctime)s') + fmt_parts.append('[%(levelname)s]') + if show_module: + fmt_parts.append('%(name)s') + fmt_parts.append('%(message)s') + fmt = ' '.join(fmt_parts) + + file_formatter = logging.Formatter(fmt, datefmt='%H:%M:%S') + file_handler.setFormatter(file_formatter) + root_logger.addHandler(file_handler) + self._handlers.append(file_handler) + + except Exception as e: + # 如果文件创建失败,打印警告但继续执行(只使用控制台输出) + print(f"警告: 无法创建日志文件 {output_file}: {e}", file=sys.stderr) + + # 防止传播到根日志器(避免重复输出) + root_logger.propagate = False + + self._configured = True + + def get_logger(self, name: str) -> AVDAGLogger: + """获取或创建日志器""" + if name not in self._loggers: + self._loggers[name] = AVDAGLogger(name) + return 
self._loggers[name] + + def set_level(self, level: Union[str, LogLevel]) -> None: + """设置全局日志级别""" + if isinstance(level, str): + level = LogLevel.from_string(level) + + self._default_level = level + root_logger = logging.getLogger('avdag') + root_logger.setLevel(level.value) + + def get_level(self) -> LogLevel: + """获取当前日志级别""" + return self._default_level + + def is_configured(self) -> bool: + """检查日志系统是否已配置""" + return self._configured + + +# 全局配置实例 +_config = LoggingConfig() + +def configure_logging(**kwargs) -> None: + """配置日志系统的便捷函数""" + _config.configure(**kwargs) + +def get_logger(name: str) -> AVDAGLogger: + """获取日志器的便捷函数""" + if not _config.is_configured(): + # 自动配置:使用默认设置 + configure_logging() + return _config.get_logger(name) + +def set_log_level(level: Union[str, LogLevel]) -> None: + """设置日志级别的便捷函数""" + _config.set_level(level) + +def get_log_level() -> LogLevel: + """获取当前日志级别""" + return _config.get_level() + +def is_debug_enabled() -> bool: + """检查是否启用调试模式""" + return _config.get_level().value <= LogLevel.DEBUG.value + +def is_trace_enabled() -> bool: + """检查是否启用跟踪模式""" + return _config.get_level().value <= LogLevel.TRACE.value + + +# 兼容性函数:保持与现有代码的兼容性 +def debug_print(msg: str, category: str = "DEBUG") -> None: + """兼容性函数:替代原有的 print 调用""" + logger = get_logger(category.lower()) + logger.debug(msg) + +def info_print(msg: str, category: str = "INFO") -> None: + """兼容性函数:替代原有的 print 调用""" + logger = get_logger(category.lower()) + logger.info(msg) + +def error_print(msg: str, category: str = "ERROR") -> None: + """兼容性函数:替代原有的 print 调用""" + logger = get_logger(category.lower()) + logger.error(msg) + +def warning_print(msg: str, category: str = "WARNING") -> None: + """兼容性函数:替代原有的 print 调用""" + logger = get_logger(category.lower()) + logger.warning(msg) + + +# 预定义的常用日志器 +def get_verifier_logger() -> AVDAGLogger: + """获取验证器日志器""" + return get_logger("verifier") + +def get_ocr_logger() -> AVDAGLogger: + """获取OCR处理日志器""" + return get_logger("ocr") + +def 
get_llm_logger() -> AVDAGLogger: + """获取LLM调用日志器""" + return get_logger("llm") + +def get_frame_logger() -> AVDAGLogger: + """获取帧处理日志器""" + return get_logger("frame") + +def get_condition_logger() -> AVDAGLogger: + """获取条件检查日志器""" + return get_logger("condition") + + +# 模块级别的便捷日志器 +logger = get_logger(__name__) + +def test_logging_compatibility(): + """测试日志系统的跨平台兼容性""" + import tempfile + + print("=== 日志系统兼容性测试 ===") + + # 测试平台信息 + print(f"平台: {sys.platform}") + print(f"TTY支持: {sys.stderr.isatty()}") + print(f"编码: {sys.stderr.encoding}") + + # 测试颜色支持 + formatter = ColoredFormatter() + print(f"颜色支持: {formatter.use_colors}") + + # 测试环境变量 + print(f"TERM: {os.getenv('TERM', 'N/A')}") + print(f"NO_COLOR: {os.getenv('NO_COLOR', 'N/A')}") + print(f"FORCE_COLOR: {os.getenv('FORCE_COLOR', 'N/A')}") + + # 测试文件输出 + with tempfile.NamedTemporaryFile(mode='w', suffix='.log', delete=False) as tmp: + temp_log_file = tmp.name + + try: + # 配置日志系统 + configure_logging( + level='DEBUG', + use_colors=True, + show_time=True, + show_module=True, + output_file=temp_log_file + ) + + # 测试日志输出 + test_logger = get_logger('test') + test_logger.debug('测试DEBUG') + test_logger.info('测试INFO') + test_logger.warning('测试WARNING') + test_logger.error('测试ERROR') + + # 读取文件内容 + with open(temp_log_file, 'r', encoding='utf-8') as f: + file_content = f.read() + + print(f"\n文件日志内容预览:") + for i, line in enumerate(file_content.splitlines()[:2], 1): + print(f" {i}: {line}") + + # 检查是否包含颜色代码 + has_color_codes = any(code in file_content for code in ['\033[', '[96m', '[92m']) + print(f"文件包含颜色代码: {'是' if has_color_codes else '否'}") + + print("✅ 兼容性测试完成") + + except Exception as e: + print(f"❌ 测试失败: {e}") + finally: + # 清理临时文件 + try: + os.unlink(temp_log_file) + except: + pass + +# 模块级别的便捷日志器 +logger = get_logger(__name__) + +# 自动配置检查 +if not _config.is_configured(): + # 检查是否有环境变量或配置文件 + config_file = os.getenv('AVDAG_LOG_CONFIG') + if config_file and Path(config_file).exists(): + 
configure_logging(config_file=config_file) + else: + # 使用默认配置 + configure_logging() diff --git a/MobiFlow/avdag/ocr_processor.py b/MobiFlow/avdag/ocr_processor.py new file mode 100644 index 0000000..819bf53 --- /dev/null +++ b/MobiFlow/avdag/ocr_processor.py @@ -0,0 +1,745 @@ +""" +OCR处理模块 - 为验证系统提供图像文字识别功能 + +此模块封装了app_trajectory_analyzer的OCR引擎,提供统一的文字识别接口。 +支持PaddleOCR和Tesseract两种引擎,可根据需要选择或自动切换。 +""" + +from __future__ import annotations +import os +import sys +import re +import xml.etree.ElementTree as ET +from typing import Dict, List, Optional, Any, Union, Tuple +from dataclasses import dataclass +from PIL import Image, ImageOps, ImageFilter, ImageEnhance +import threading + +from .logger import get_ocr_logger + +# 动态添加app_trajectory_analyzer路径 +def _add_ocr_path(): + """添加OCR引擎路径到Python搜索路径""" + current_dir = os.path.dirname(__file__) + tools_dir = os.path.join(current_dir, "..", "tools", "app_trajectory_analyzer", "src") + if os.path.exists(tools_dir) and tools_dir not in sys.path: + sys.path.insert(0, tools_dir) + +# 尝试导入OCR引擎 +_add_ocr_path() + +try: + from analyzer.ocr_engine import OCREngine + _ocr_available = True +except ImportError: + _ocr_available = False + + +@dataclass +class ProcessedText: + """OCR处理后的文本结果""" + original: str # 原始OCR文本 + cleaned: str # 清理后的文本(移除特殊符号) + no_spaces: str # 无空格版本(用于连续匹配) + words: List[str] # 分词结果 + chars: List[str] # 字符列表 + + +class OCRProcessor: + """OCR处理器,提供图像文字识别和文本处理功能""" + + def __init__(self, use_paddle: bool = True, lang: str = "chi_sim+eng"): + """初始化OCR处理器""" + self.use_paddle = use_paddle + self.lang = lang + self._engine: Optional[Any] = None + self._engine_paddle: Optional[Any] = None + self._engine_tess: Optional[Any] = None + self._available = _ocr_available + self._lock = threading.Lock() + self._cache_words: Dict[str, List[str]] = {} + self._logger = get_ocr_logger() + + if self._available: + try: + paddle_success = False + tesseract_success = False + + # 尝试初始化 PaddleOCR + try: + self._engine_paddle = 
OCREngine(use_paddle=True, lang=self.lang) + paddle_success = True + self._logger.info("PaddleOCR引擎初始化成功") + except Exception as e: + self._logger.warning(f"PaddleOCR引擎初始化失败: {e}") + self._engine_paddle = None + + # 尝试初始化 Tesseract + try: + self._engine_tess = OCREngine(use_paddle=False, lang=self.lang) + tesseract_success = True + self._logger.info("Tesseract引擎初始化成功") + except Exception as e: + self._logger.warning(f"Tesseract引擎初始化失败: {e}") + self._engine_tess = None + + # 设置主引擎:优先使用PaddleOCR,如果失败则使用Tesseract + if self.use_paddle and paddle_success: + self._engine = self._engine_paddle + self._logger.info("使用PaddleOCR作为主引擎") + elif tesseract_success: + self._engine = self._engine_tess + self._logger.info("使用Tesseract作为主引擎") + elif paddle_success: + self._engine = self._engine_paddle + self._logger.info("回退到PaddleOCR作为主引擎") + else: + self._logger.error("所有OCR引擎初始化失败") + self._available = False + self._engine = None + + except Exception as e: + self._logger.error(f"初始化OCR引擎失败: {e}") + self._available = False + + def is_available(self) -> bool: + """检查OCR功能是否可用""" + return self._available and (self._engine_paddle is not None or self._engine_tess is not None) + + def process_text(self, raw_text: str) -> ProcessedText: + """将原始文本标准化并生成多视图便于匹配""" + if not raw_text or not raw_text.strip(): + return ProcessedText(original='', cleaned='', no_spaces='', words=[], chars=[]) + + # 正规化:全半角、大小写、常见混淆字符 + def to_half_width(s: str) -> str: + res = [] + for ch in s: + code = ord(ch) + if code == 0x3000: + code = 32 + elif 0xFF01 <= code <= 0xFF5E: + code -= 0xFEE0 + res.append(chr(code)) + return ''.join(res) + + def normalize_confusions(s: str) -> str: + mapping = { + 'I': 'I', 'L': 'L', 'O': 'O', 'S': 'S', 'B': 'B', + '0': '0', '1': '1', '2': '2', '5': '5', '6': '6', '8': '8', '9': '9', + # 常见 OCR 易混: + 'O': '0', 'o': '0', 'l': '1', 'I': '1', '丨': '1', '|': '1', + 'Z': '2', 'S': '5', 'B': '8', + } + return ''.join(mapping.get(c, c) for c in s) + + # 1. 
清理文本:保留中文、字母、数字、空格 + raw_text_norm = normalize_confusions(to_half_width(raw_text)).casefold() + cleaned = re.sub(r'[^\u4e00-\u9fff\w\s]', ' ', raw_text_norm) + cleaned = re.sub(r'\s+', ' ', cleaned).strip() + + # 2. 无空格版本:用于连续匹配 + no_spaces = re.sub(r'\s+', '', cleaned) + + # 3. 提取词语:按空格分割 + words = [w.strip() for w in cleaned.split() if w.strip()] + + # 4. 提取字符 + chars = list(no_spaces) + return ProcessedText(original=raw_text, cleaned=cleaned, no_spaces=no_spaces, words=words, chars=chars) + + def smart_text_contains(self, processed_text: ProcessedText, keyword: str) -> bool: + """ + 智能文本匹配,支持多种匹配策略 + + Args: + processed_text: 处理后的文本 + keyword: 要搜索的关键词 + + Returns: + bool: 是否匹配 + """ + if not keyword or not processed_text: + return False + + def to_half_width(s: str) -> str: + res = [] + for ch in s: + code = ord(ch) + if code == 0x3000: + code = 32 + elif 0xFF01 <= code <= 0xFF5E: + code -= 0xFEE0 + res.append(chr(code)) + return ''.join(res) + + def normalize_confusions(s: str) -> str: + mapping = { + 'O': 'O', 'o': 'o', 'I': 'I', 'L': 'L', 'S': 'S', 'B': 'B', + '0': '0', '1': '1', '2': '2', '5': '5', '6': '6', '8': '8', '9': '9', + 'O': '0', 'o': '0', 'l': '1', 'I': '1', '丨': '1', '|': '1', 'Z': '2', 'S': '5', 'B': '8', + } + return ''.join(mapping.get(c, c) for c in s) + + # 处理关键词,应用相同的正规化 + keyword_norm = normalize_confusions(to_half_width(keyword)).casefold() + keyword_clean = re.sub(r'[^\u4e00-\u9fff\w\s]', ' ', keyword_norm) + keyword_clean = re.sub(r'\s+', ' ', keyword_clean).strip() + keyword_no_spaces = re.sub(r'\s+', '', keyword_clean) + + # 匹配策略 1: 精确匹配(带空格) + if keyword_clean in processed_text.cleaned: + return True + + # 匹配策略 2: 连续匹配(无空格) + if keyword_no_spaces in processed_text.no_spaces: + return True + + # 匹配策略 3: 分词匹配 + keyword_words = [w.strip() for w in keyword_clean.split() if w.strip()] + if keyword_words and all(any(kw in word for word in processed_text.words) for kw in keyword_words): + return True + + # 匹配策略 4: 模糊匹配(80%相似度) + try: + from 
difflib import SequenceMatcher + similarity = SequenceMatcher(None, keyword_no_spaces, processed_text.no_spaces).ratio() + if similarity >= 0.8: + return True + except ImportError: + pass + + return False + + def extract_text_from_image(self, image_path: str, enable_hybrid: bool = True) -> Tuple[str, Optional[str]]: + """ + 从图像中提取文字 + + Args: + image_path: 图像文件路径 + enable_hybrid: 是否启用混合识别(Paddle + Tesseract) + + Returns: + Tuple[str, Optional[str]]: (主识别结果, 备用识别结果) + """ + if not self.is_available(): + self._logger.warning("OCR引擎不可用") + return "", None + + if not os.path.exists(image_path): + self._logger.error(f"图像文件不存在: {image_path}") + return "", None + + try: + # 验证图像文件 + with Image.open(image_path) as img: + img.verify() + except Exception as e: + self._logger.error(f"打开图片失败: {image_path}: {e}") + return "", None + + primary_text = "" + secondary_text = None + + # 使用主引擎识别 + if self._engine: + try: + result = self._engine.run(image_path) + primary_text = result.get_text() if result else "" + self._logger.debug(f"主引擎识别结果: {len(primary_text)} 字符") + except Exception as e: + self._logger.error(f"主引擎识别失败: {e}") + + # 混合识别:使用备用引擎 + if enable_hybrid and self._engine_paddle and self._engine_tess: + backup_engine = self._engine_tess if self._engine == self._engine_paddle else self._engine_paddle + try: + result = backup_engine.run(image_path) + secondary_text = result.get_text() if result else "" + self._logger.debug(f"备用引擎识别结果: {len(secondary_text)} 字符") + + # 如果主引擎失败但备用引擎成功,使用备用结果 + if not primary_text and secondary_text: + primary_text = secondary_text + self._logger.info("使用备用引擎结果作为主结果") + + except Exception as e: + self._logger.warning(f"备用引擎识别失败: {e}") + + return primary_text, secondary_text + + def recognize_image(self, image_path: str) -> Optional[ProcessedText]: + """ + 识别图像并返回处理后的文本结果(兼容性方法) + + Args: + image_path: 图像文件路径 + + Returns: + ProcessedText: 处理后的文本结果,失败时返回None + """ + text, _ = self.extract_text_from_image(image_path) + if text: + return 
self.process_text(text) + return None + + def get_word_list(self, image_path: str) -> List[str]: + """ + 从图像中获取词语列表(兼容性方法) + + Args: + image_path: 图像文件路径 + + Returns: + List[str]: 词语列表 + """ + text, backup_text = self.extract_text_from_image(image_path) + words = [] + + if text: + processed = self.process_text(text) + words.extend(processed.words) + if processed.cleaned: + words.append(processed.cleaned) + if processed.no_spaces: + words.append(processed.no_spaces) + + if backup_text and backup_text != text: + processed_backup = self.process_text(backup_text) + words.extend(processed_backup.words) + if processed_backup.cleaned: + words.append(processed_backup.cleaned) + if processed_backup.no_spaces: + words.append(processed_backup.no_spaces) + + return list(set(words)) + + def extract_xml_text(self, xml_content: str) -> str: + """从XML内容中提取可视文本""" + if not xml_content: + return "" + + try: + root = ET.fromstring(xml_content) + except ET.ParseError as e: + self._logger.warning(f"XML解析失败: {e}") + return "" + + texts = [] + + def extract_text_recursive(element): + """递归提取元素文本""" + # 获取元素的text属性 + text_attr = element.get('text', '').strip() + if text_attr: + texts.append(text_attr) + + # 获取元素的content-desc属性(Android无障碍描述) + content_desc = element.get('content-desc', '').strip() + if content_desc and content_desc != text_attr: + texts.append(content_desc) + + # 递归处理子元素 + for child in element: + extract_text_recursive(child) + + extract_text_recursive(root) + return ' '.join(texts) + + def process_frame_text(self, frame: Dict[str, Any]) -> ProcessedText: + """ + 处理帧中的所有文本信息 + + Args: + frame: 包含图像、XML等信息的帧数据 + + Returns: + ProcessedText: 处理后的综合文本 + """ + all_texts = [] + + # 1. 直接文本信息 + if 'text' in frame and frame['text']: + all_texts.append(str(frame['text'])) + + # 2. XML文本提取 + if 'xml_text' in frame and frame['xml_text']: + xml_text = self.extract_xml_text(frame['xml_text']) + if xml_text: + all_texts.append(xml_text) + + # 3. 
OCR文本提取 + if 'image' in frame and frame['image'] and os.path.exists(frame['image']): + ocr_text, _ = self.extract_text_from_image(frame['image']) + if ocr_text: + all_texts.append(ocr_text) + + # 4. 任务描述和推理文本 + for field in ['task_description', 'reasoning', 'action']: + if field in frame and frame[field]: + all_texts.append(str(frame[field])) + + # 合并所有文本 + combined_text = ' '.join(all_texts) + return self.process_text(combined_text) + + def match_keyword_in_frame(self, frame: Dict[str, Any], keyword: str, enable_ocr: bool = True) -> bool: + """ + 在帧中搜索关键词 + + Args: + frame: 帧数据 + keyword: 要搜索的关键词 + enable_ocr: 是否启用OCR识别 + + Returns: + bool: 是否找到关键词 + """ + if not keyword: + return False + + # 首先在现有文本字段中搜索 + text_fields = ['text', 'task_description', 'reasoning'] + for field in text_fields: + if field in frame and frame[field]: + processed = self.process_text(str(frame[field])) + if self.smart_text_contains(processed, keyword): + self._logger.debug(f"在字段 {field} 中找到关键词: {keyword}") + return True + + # 在XML文本中搜索 + if 'xml_text' in frame and frame['xml_text']: + xml_text = self.extract_xml_text(frame['xml_text']) + if xml_text: + processed = self.process_text(xml_text) + if self.smart_text_contains(processed, keyword): + self._logger.debug(f"在XML文本中找到关键词: {keyword}") + return True + + # 使用OCR在图像中搜索 + if enable_ocr and 'image' in frame and frame['image']: + if os.path.exists(frame['image']): + ocr_text, backup_text = self.extract_text_from_image(frame['image']) + + # 在主OCR结果中搜索 + if ocr_text: + processed = self.process_text(ocr_text) + if self.smart_text_contains(processed, keyword): + self._logger.debug(f"在OCR文本中找到关键词: {keyword}") + return True + + # 在备用OCR结果中搜索 + if backup_text and backup_text != ocr_text: + processed = self.process_text(backup_text) + if self.smart_text_contains(processed, keyword): + self._logger.debug(f"在备用OCR文本中找到关键词: {keyword}") + return True + + return False + + def get_text_similarity(self, text1: str, text2: str) -> float: + """计算两个文本的相似度""" 
+ if not text1 or not text2: + return 0.0 + + processed1 = self.process_text(text1) + processed2 = self.process_text(text2) + + try: + from difflib import SequenceMatcher + return SequenceMatcher(None, processed1.no_spaces, processed2.no_spaces).ratio() + except ImportError: + # 简单的字符重叠度计算 + chars1 = set(processed1.no_spaces) + chars2 = set(processed2.no_spaces) + if not chars1 and not chars2: + return 1.0 + if not chars1 or not chars2: + return 0.0 + return len(chars1 & chars2) / len(chars1 | chars2) + + +# 全局OCR处理器实例(单例模式) +_global_ocr_processor: Optional[OCRProcessor] = None +_ocr_lock = threading.Lock() + +def get_ocr_processor(use_paddle: bool = True, lang: str = "chi_sim+eng") -> OCRProcessor: + """获取全局OCR处理器实例""" + global _global_ocr_processor + + with _ocr_lock: + if _global_ocr_processor is None: + _global_ocr_processor = OCRProcessor(use_paddle=use_paddle, lang=lang) + + return _global_ocr_processor + + +def get_global_ocr_processor() -> OCRProcessor: + """获取全局OCR处理器实例(兼容性函数)""" + return get_ocr_processor() + + +def extract_text_from_xml(xml_content: str) -> ProcessedText: + """ + 从XML内容中提取文字(兼容性函数) + + Args: + xml_content: XML文件内容字符串 + + Returns: + ProcessedText: 处理后的文本结果 + """ + processor = get_ocr_processor() + xml_text = processor.extract_xml_text(xml_content) + return processor.process_text(xml_text) + + +def extract_text_from_xml_simple(xml_content: str) -> ProcessedText: + """ + 从XML内容中使用正则表达式简单提取文字(备用方案) + + Args: + xml_content: XML文件内容字符串 + + Returns: + ProcessedText: 处理后的文本结果 + """ + if not xml_content or not xml_content.strip(): + return ProcessedText( + original='', + cleaned='', + no_spaces='', + words=[], + chars=[] + ) + + text_contents = [] + + # 提取text属性 + text_pattern = r'text="([^"]*[a-zA-Z\u4e00-\u9fff]+[^"]*)"' + text_matches = re.findall(text_pattern, xml_content) + text_contents.extend([t.strip() for t in text_matches if t.strip()]) + + # 提取content-desc属性 + desc_pattern = r'content-desc="([^"]*[a-zA-Z\u4e00-\u9fff]+[^"]*)"' + 
desc_matches = re.findall(desc_pattern, xml_content) + text_contents.extend([d.strip() for d in desc_matches if d.strip()]) + + # 提取hint属性 + hint_pattern = r'hint="([^"]*[a-zA-Z\u4e00-\u9fff]+[^"]*)"' + hint_matches = re.findall(hint_pattern, xml_content) + text_contents.extend([h.strip() for h in hint_matches if h.strip()]) + + # 去重 + unique_texts = list(set([t for t in text_contents if t])) + combined_text = ' '.join(unique_texts) + + # 使用OCRProcessor的文本处理功能 + processor = get_ocr_processor() + return processor.process_text(combined_text) + + +def create_frame_ocr_function(processor: OCRProcessor) -> callable: + """ + 创建适用于Frame的OCR函数,用于集成到VerifierOptions中 + + Args: + processor: OCR处理器实例 + + Returns: + callable: 接受Frame参数的OCR函数 + """ + def frame_ocr(frame: Dict[str, Any]) -> Optional[str]: + """ + 从Frame中提取并识别图像文字 + + Args: + frame: 包含图像路径的Frame字典 + + Returns: + str: 识别的文字,失败时返回XML文本或None + """ + # 获取图像路径 + image_path = frame.get("image") + if not image_path or not os.path.exists(image_path): + # 退化到改进的XML提取 + xml_text = frame.get("xml_text", "") + if xml_text: + xml_processed = extract_text_from_xml(xml_text) + if xml_processed.cleaned: + frame['_xml_processed'] = xml_processed + processor._logger.info(f"图像不可用,使用改进XML提取: {xml_processed.cleaned[:100]}...") + processor._logger.debug(f"XML提取词语数: {len(xml_processed.words)}") + return f"{xml_processed.cleaned} {xml_processed.no_spaces} {' '.join(xml_processed.words)}" + else: + processor._logger.warning("图像不可用且XML提取失败") + return xml_text[:200] if xml_text else None + return None + + # 使用OCR识别 + ocr_text, backup_text = processor.extract_text_from_image(image_path) + xml_text = frame.get("xml_text", "") + merged_parts: List[str] = [] + + if ocr_text: + # 处理OCR文本 + processed = processor.process_text(ocr_text) + frame['_ocr_processed'] = processed + merged_parts.extend([processed.cleaned, processed.no_spaces] + processed.words) + processor._logger.debug(f"识别图像 {os.path.basename(image_path)} -> 词语数: 
{len(processed.words)}") + + # 融合XML文本 + if xml_text: + xml_processed = extract_text_from_xml(xml_text) + if xml_processed.cleaned: + frame['_xml_processed'] = xml_processed + merged_parts.extend([xml_processed.cleaned, xml_processed.no_spaces] + xml_processed.words) + processor._logger.debug(f"融合XML文本 -> 词语数: {len(xml_processed.words)}") + + if merged_parts: + return ' '.join(list(dict.fromkeys([p for p in merged_parts if p]))) + + # 全部失败 + processor._logger.warning(f"图像 {os.path.basename(image_path)} 未识别到文字且无可用XML") + return None + + return frame_ocr + + +def create_frame_texts_function(processor: OCRProcessor) -> callable: + """ + 创建适用于Frame的文本列表提取函数 + + Args: + processor: OCR处理器实例 + + Returns: + callable: 接受Frame参数,返回文本列表的函数 + """ + def frame_texts(frame: Dict[str, Any]) -> List[str]: + """ + 从Frame中提取文本列表,优先使用OCR识别 + + Args: + frame: 包含图像路径的Frame字典 + + Returns: + List[str]: 文本列表 + """ + # 获取图像路径 + image_path = frame.get("image") + if not image_path or not os.path.exists(image_path): + # 退化到改进的XML提取 + xml_text = frame.get("xml_text", "") + if xml_text: + xml_processed = extract_text_from_xml(xml_text) + if xml_processed.words: + frame['_xml_processed'] = xml_processed + # 返回多种格式的文本 + result_texts = xml_processed.words.copy() + if xml_processed.cleaned: + result_texts.append(xml_processed.cleaned) + if xml_processed.no_spaces: + result_texts.append(xml_processed.no_spaces) + processor._logger.info(f"图像不可用,使用改进XML提取,得到 {len(result_texts)} 个文本片段") + return list(set(result_texts)) + else: + # 降级到传统方式 + texts = frame.get("xml_texts", []) + if not texts and xml_text: + texts = [xml_text] + processor._logger.warning("XML解析失败,使用原始XML文本") + return texts + else: + texts = frame.get("xml_texts", []) + processor._logger.warning("无图像也无XML文本") + return texts + + # 使用OCR获取文本 + ocr_text, backup_text = processor.extract_text_from_image(image_path) + words = [] + + if ocr_text: + processed = processor.process_text(ocr_text) + words.extend(processed.words) + if processed.cleaned: 
+ words.append(processed.cleaned) + if processed.no_spaces: + words.append(processed.no_spaces) + + # 融合 XML + xml_text = frame.get("xml_text", "") + if xml_text: + xml_processed = extract_text_from_xml(xml_text) + if xml_processed.words: + frame['_xml_processed'] = xml_processed + words.extend(xml_processed.words) + if xml_processed.cleaned: + words.append(xml_processed.cleaned) + if xml_processed.no_spaces: + words.append(xml_processed.no_spaces) + + if words: + processor._logger.debug(f"从图像 {os.path.basename(image_path)} (含XML融合) 提取 {len(set(words))} 个文本片段") + return list(set(words)) + + # 兜底 + texts = frame.get("xml_texts", []) + if not texts and xml_text: + texts = [xml_text] + if texts: + processor._logger.info(f"使用兜底XML文本 {len(texts)} 段") + else: + processor._logger.warning("无可用文本") + return texts + + return frame_texts + + +def create_standard_ocr_functions() -> tuple[callable, callable]: + """ + 创建标准的OCR函数对,用于快速集成 + + Returns: + tuple: (frame_ocr_function, frame_texts_function) + """ + processor = get_global_ocr_processor() + return ( + create_frame_ocr_function(processor), + create_frame_texts_function(processor) + ) + + +# 便捷函数 +def extract_text_from_image(image_path: str) -> str: + """便捷函数:从图像提取文本""" + processor = get_ocr_processor() + text, _ = processor.extract_text_from_image(image_path) + return text + +def match_text_in_frame(frame: Dict[str, Any], keyword: str) -> bool: + """便捷函数:在帧中匹配文本""" + processor = get_ocr_processor() + return processor.match_keyword_in_frame(frame, keyword) + +def process_frame_text(frame: Dict[str, Any]) -> ProcessedText: + """便捷函数:处理帧文本""" + processor = get_ocr_processor() + return processor.process_frame_text(frame) + + +# 导出的公共接口 +__all__ = [ + "ProcessedText", + "OCRProcessor", + "create_frame_ocr_function", + "create_frame_texts_function", + "get_global_ocr_processor", + "get_ocr_processor", + "create_standard_ocr_functions", + "extract_text_from_xml", + "extract_text_from_xml_simple", + "extract_text_from_image", + 
"match_text_in_frame", + "process_frame_text" +] diff --git a/MobiFlow/avdag/trace_loader.py b/MobiFlow/avdag/trace_loader.py new file mode 100644 index 0000000..037e9d8 --- /dev/null +++ b/MobiFlow/avdag/trace_loader.py @@ -0,0 +1,132 @@ +from __future__ import annotations +import json +import os +from typing import Any, Dict, List +import re + +from .types import Frame + + +def _read_file(path: str) -> str: + try: + with open(path, "r", encoding="utf-8") as f: + return f.read() + except FileNotFoundError: + return "" + + +def load_frames_from_dir(folder: str) -> List[Frame]: + """将包含 images/xml/actions/react 的目录转换为帧序列。 + + 规则: + - 帧索引从 1 开始(按文件名 N.xml/N.jpg 推断),依次排序。 + - 每帧字段: + - image: `/.jpg`(存在时) + - xml_text: `/.xml` 的原始文本(存在时) + - reasoning: react.json[i-1].reasoning(存在时) + - action: actions.json.actions[i-1](存在时) + - text: 合并 reasoning + action 文本,用于简易匹配 + """ + frames: List[Frame] = [] + # 添加一个空的帧,作为起始点 + frames.append({ + "image": None, + "xml_text": "", + "reasoning": None, + "react_action": None, + "action": None, + "text": "", + "ui": {}, + "task_description": "", + "app_name": "" + }) + if not os.path.isdir(folder): + raise FileNotFoundError(folder) + + actions_path = os.path.join(folder, "actions.json") + react_path = os.path.join(folder, "react.json") + actions = [] + reacts = [] + act_meta: Dict[str, Any] = {} + if os.path.exists(actions_path): + with open(actions_path, "r", encoding="utf-8") as f: + act_meta = json.load(f) or {} + actions = (act_meta or {}).get("actions", []) + if os.path.exists(react_path): + with open(react_path, "r", encoding="utf-8") as f: + reacts = json.load(f) or [] + + # 找到所有形如 N.xml 或 N.jpg 的索引 + indices: List[int] = [] + for name in os.listdir(folder): + if name.endswith(".xml"): + try: + idx = int(os.path.splitext(name)[0]) + indices.append(idx) + except ValueError: + pass + elif name.endswith(".jpg"): + try: + idx = int(os.path.splitext(name)[0]) + indices.append(idx) + except ValueError: + pass + indices = 
sorted(sorted(set(indices))) + + for i in indices: + fr: Frame = {} + xml_path = os.path.join(folder, f"{i}.xml") + img_path = os.path.join(folder, f"{i}.jpg") + fr["image"] = img_path if os.path.exists(img_path) else None + fr["xml_text"] = _read_file(xml_path) if os.path.exists(xml_path) else "" + # 从 xml 中提取包名等元信息,可以按需要补充,构建更完善的UI上下文 + # 例如:fr["ui"] = {"package": "com.example.app"} 等 + # 这里仅示例提取 package 名称 + ui: Dict[str, Any] = {} + if fr["xml_text"]: + m = re.search(r'package="([^"]+)"', fr["xml_text"]) # 简单提取第一个 package + if m: + ui["package"] = m.group(1) + if ui: + fr["ui"] = ui + + r = reacts[i - 1] if 0 <= (i - 1) < len(reacts) else None + a = actions[i - 1] if 0 <= (i - 1) < len(actions) else None + fr["reasoning"] = r.get("reasoning") if isinstance(r, dict) else None + fr["react_action"] = r.get("action") if isinstance(r, dict) else None + fr["action"] = a if isinstance(a, dict) else None + # 添加顶层任务元信息,便于 LLM 判断相关性 + if act_meta: + fr["task_description"] = act_meta.get("task_description") or act_meta.get("old_task_description") + fr["app_name"] = act_meta.get("app_name") + + # 组装便于简单文本匹配的 text 字段 + pieces: List[str] = [] + if fr.get("reasoning"): + pieces.append(str(fr["reasoning"])) + if a: + if a.get("type"): + pieces.append(str(a["type"])) + if a.get("text"): + pieces.append(str(a["text"])) + if r and isinstance(r, dict): + params = r.get("parameters") or {} + for v in params.values(): + try: + pieces.append(str(v)) + except Exception: + pass + fr["text"] = " \n".join(pieces) + + frames.append(fr) + + # 增加邻接上下文引用(只读) + for idx, fr in enumerate(frames): + fr["_index"] = idx + fr["_prev"] = frames[idx - 1] if idx > 0 else None + fr["_next"] = frames[idx + 1] if idx + 1 < len(frames) else None + + return frames + + +__all__ = ["load_frames_from_dir"] diff --git a/MobiFlow/avdag/types.py b/MobiFlow/avdag/types.py new file mode 100644 index 0000000..bec4958 --- /dev/null +++ b/MobiFlow/avdag/types.py @@ -0,0 +1,103 @@ +from __future__ import annotations 
+from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Callable + +Frame = Dict[str, Any] # 简化:每帧就是一个字典 + +@dataclass +class ConditionSpec: + type: str + params: Dict[str, Any] + +@dataclass +class NodeSpec: + id: str + name: Optional[str] = None + deps: Optional[List[str]] = None + # next: 声明后继节点(可选)。用于定义可选路径(OR语义)。 + # 注意:当同一节点同时定义了 deps 与 next 指向它的父节点时,deps 仍然采用 AND 语义优先; + # 未定义 deps 时,来自 next 的父子边在验证阶段按 OR 语义处理。 + next: Optional[List[str]] = None + condition: Optional[ConditionSpec] = None + # score: 节点分数,默认为10分 + score: int = 10 + +@dataclass +class SuccessSpec: + any_of: Optional[List[str]] = None + all_of: Optional[List[str]] = None + +@dataclass +class TaskSpec: + task_id: str + nodes: List[NodeSpec] + success: Optional[SuccessSpec] = None + +@dataclass +class NodeMatch: + node_id: str + frame_index: int + +@dataclass +class VerifyResult: + ok: bool + matched: List[NodeMatch] + reason: Optional[str] = None + # 新增:判定过程日志与人工复核标记(兼容旧测试) + logs: List["DecisionLog"] = field(default_factory=list) + manual_review_needed: bool = False + # 新增:任务总分(成功匹配的节点分数之和) + total_score: int = 0 + + +@dataclass +class DecisionLog: + frame_index: int + node_id: str + strategy: str + decision: str # hit | miss | inconclusive + details: Optional[str] = None + # 新增字段:记录检查器的详细结果 + checker_type: Optional[str] = None # ocr | llm | text | regex | ui | action | dynamic_match + checker_result: Optional[str] = None # 检查器的详细结果或原因 + matched_keywords: Optional[List[str]] = None # 匹配成功的关键词 + unmatched_keywords: Optional[List[str]] = None # 未匹配的关键词 + + +@dataclass +class VerifierOptions: + """可注入的判定能力与策略顺序。 + + - ocr: 函数接受 Frame,返回识别出的文本(字符串)或 None 表示不支持/不确定。 + - llm: 函数接受上下文字典,返回 True/False/None(None 表示不确定)。 + - escalation_order: 策略升级顺序,默认 [text, regex, ui, action, dynamic_match, icons, ocr, llm] + - log_decisions: 是否记录详细日志。 + - force_llm_verification: 是否强制使用LLM验证,即使其他策略已经匹配。 + - prevent_frame_backtrack: 是否防止帧回退(默认True),一旦某个帧被OCR/LLM使用,之前的帧也标记为已使用。 + - 
ocr_frame_exclusive: OCR验证时是否独占使用帧(默认True),防止同一帧被多个OCR节点重复使用。 + - llm_frame_exclusive: LLM验证时是否独占使用帧(默认True),防止同一帧被多个LLM节点重复使用。 + """ + ocr: Optional[Callable[[Frame], Optional[str]]] = None + llm: Optional[Callable[[Dict[str, Any]], Optional[bool]]] = None + escalation_order: List[str] = field(default_factory=lambda: [ + "text", "regex", "ui", "action", "icons", "ocr", "llm" + ]) + log_decisions: bool = True + force_llm_verification: bool = False + prevent_frame_backtrack: bool = True + ocr_frame_exclusive: bool = True + llm_frame_exclusive: bool = True + max_llm_retries: int = 3 # LLM请求的最大重试次数 + llm_retry_delay: float = 1.0 # LLM重试之间的延迟(秒) + +__all__ = [ + "Frame", + "ConditionSpec", + "NodeSpec", + "SuccessSpec", + "TaskSpec", + "NodeMatch", + "VerifyResult", + "DecisionLog", + "VerifierOptions", +] diff --git a/MobiFlow/avdag/verifier.py b/MobiFlow/avdag/verifier.py new file mode 100644 index 0000000..eba0e78 --- /dev/null +++ b/MobiFlow/avdag/verifier.py @@ -0,0 +1,882 @@ +from __future__ import annotations +import json +import os +import base64 +import time +import re +from typing import Dict, List, Tuple, Optional + +from .types import Frame, NodeMatch, TaskSpec, VerifyResult, VerifierOptions, DecisionLog +from .dag import DAG +from .conditions import get_checker +from .loader import load_task +from .trace_loader import load_frames_from_dir +from .logger import get_logger, get_frame_logger, get_verifier_logger, get_llm_logger + + +def _collect_candidates_path_aware(frames: List[Frame], task: TaskSpec, dag: DAG, options: Optional[VerifierOptions], logs: List[DecisionLog]) -> Dict[str, List[int]]: + """路径感知的候选帧收集:按拓扑顺序检查节点,支持路径剪枝和动态帧分配。 + + - 按拓扑顺序逐节点检查; + - 路径剪枝:只为当前可达路径上的节点分配帧资源; + - 动态帧分配:根据已成功的路径动态决定下一个要检查的节点; + - 帧独占优化:针对每个路径分支独立管理帧使用,避免跨路径冲突。 + """ + cand: Dict[str, List[int]] = {} + # 路径级别的帧使用追踪:key为路径标识,value为该路径已使用的帧集合 + path_used_frames: Dict[str, set] = {} + # 节点可达性:记录每个节点当前是否可达(基于其依赖/前驱的满足情况) + node_reachable: Dict[str, bool] = {} + # 
节点的最小可行帧索引:用于路径约束 + min_frame_idx: Dict[str, int] = {} + # 节点实际匹配的帧索引:用于正确计算后继节点的时序约束 + matched_frame_idx: Dict[str, int] = {} + + frame_logger = get_frame_logger() + frame_logger.debug("开始路径感知的候选帧收集") + + # 按拓扑顺序处理节点 + topo_order = dag.topo_order() + + # 初始化:根节点(无依赖)总是可达 + for nid in dag.nodes: + node = dag.nodes[nid] + deps = node.deps or [] + or_parents = dag.parents_from_next.get(nid, []) + + if not deps and not or_parents: + # 根节点 + node_reachable[nid] = True + min_frame_idx[nid] = 0 + else: + node_reachable[nid] = False + min_frame_idx[nid] = 0 + + for nid in topo_order: + node = dag.nodes[nid] + + # 路径剪枝:只处理可达节点 + if not node_reachable.get(nid, False): + cand[nid] = [] + frame_logger.debug(f"节点 {nid}: 不可达,跳过检查") + continue + + hits: List[int] = [] + if not node.condition: + cand[nid] = hits + # 无条件节点视为总是成功,更新后继节点可达性 + _update_successor_reachability(nid, 0, dag, node_reachable, min_frame_idx, matched_frame_idx, frame_logger) + continue + + checker = get_checker(node.condition.type) + params = node.condition.params or {} + + # 判断当前节点是否需要OCR或LLM验证 + needs_ocr = _node_needs_ocr(node, params) + needs_llm = _node_needs_llm(node, params) + needs_frame_exclusive = needs_ocr or needs_llm + + # 确定该节点的搜索起始帧(基于路径约束) + start_frame = min_frame_idx.get(nid, 0) + + # 获取当前路径上已使用的帧(仅针对帧独占节点) + current_path_used = set() + if needs_frame_exclusive: + # 通过回溯当前节点的路径来确定已使用的帧 + current_path_used = _get_path_used_frames(nid, dag, cand, needs_frame_exclusive) + + frame_logger.debug(f"节点 {nid}: 需要独占={needs_frame_exclusive}, 起始帧={start_frame}, 路径已使用帧={sorted(current_path_used)}") + + for i in range(start_frame, len(frames)): + # 对于需要独占帧的节点,跳过当前路径上已使用的帧 + if needs_frame_exclusive and i in current_path_used: + frame_logger.trace(f"节点 {nid}: 跳过路径已使用帧 {i}") + continue + + fr = frames[i] + try: + ok = checker.check(fr, params, options) + except Exception as e: + ok = False + if options and options.log_decisions: + logs.append(DecisionLog( + frame_index=i, + node_id=nid, + 
strategy=node.condition.type, + decision="inconclusive", + details=str(e), + checker_type="exception", + checker_result=str(e) + )) + frame_logger.error(f"节点 {nid}: 帧 {i} 检查异常 - {e}") + else: + if options and options.log_decisions: + # 获取检查器的详细结果 + checker_result = "" + matched_keywords = [] + unmatched_keywords = [] + checker_type = node.condition.type + + # 从frame中获取OCR或LLM的详细结果 + if "_last_ocr_result" in fr: + ocr_result = fr["_last_ocr_result"] + checker_type = "ocr" + checker_result = ocr_result.get("reason", "") + matched_keywords = ocr_result.get("matched_keywords", []) + unmatched_keywords = ocr_result.get("unmatched_keywords", []) + elif "_last_llm_result" in fr: + llm_result = fr["_last_llm_result"] + checker_type = "llm" + checker_result = llm_result.get("reason", "") + + logs.append(DecisionLog( + frame_index=i, + node_id=nid, + strategy=node.condition.type, + decision=("hit" if ok else "miss"), + checker_type=checker_type, + checker_result=checker_result, + matched_keywords=matched_keywords if matched_keywords else None, + unmatched_keywords=unmatched_keywords if unmatched_keywords else None + )) + frame_logger.trace(f"节点 {nid}: 帧 {i} 检查结果 {'成功' if ok else '失败'}") + + if ok: + hits.append(i) + frame_logger.debug(f"节点 {nid} 在帧 {i} 匹配成功") + + # 动态更新后继节点的可达性 + _update_successor_reachability(nid, i, dag, node_reachable, min_frame_idx, matched_frame_idx, frame_logger) + + # 对于帧独占节点,找到匹配后立即停止(早停机制) + if needs_frame_exclusive: + frame_logger.debug(f"节点 {nid}: 找到匹配帧 {i},启用早停机制") + break + else: + # 对于不需要独占帧的节点,继续搜索更多匹配 + frame_logger.trace(f"节点 {nid}: 帧 {i} 匹配成功,继续搜索更多候选帧") + + frame_logger.debug(f"节点 {nid}: 找到 {len(hits)} 个候选帧 {hits}") + cand[nid] = hits + + return cand + + +def _get_path_used_frames(node_id: str, dag: DAG, cand: Dict[str, List[int]], is_exclusive: bool) -> set: + """获取到达当前节点的路径上已使用的帧集合(用于帧独占逻辑) + + 重要:当一个帧被使用时,该帧之前的所有帧也被标记为已使用,确保时序的线性关系 + """ + used_frames = set() + + # 回溯路径上的前驱节点 + def backtrack_path(nid: str, visited: set): + if nid in 
visited: + return + visited.add(nid) + + node = dag.nodes.get(nid) + if not node: + return + + # 检查该节点是否使用了帧独占 + if node.condition: + params = node.condition.params or {} + node_needs_exclusive = _node_needs_ocr(node, params) or _node_needs_llm(node, params) + + # 如果该节点需要帧独占且有匹配的帧,将其及之前的所有帧加入已使用集合 + if node_needs_exclusive and cand.get(nid): + for matched_frame in cand[nid]: + # 当某个帧被使用时,该帧及之前的所有帧都被标记为已使用 + for frame_idx in range(matched_frame + 1): + used_frames.add(frame_idx) + + # 递归检查依赖节点 + for dep in (node.deps or []): + backtrack_path(dep, visited) + + # 递归检查next路径的父节点 + for parent in dag.parents_from_next.get(nid, []): + backtrack_path(parent, visited) + + # 从当前节点开始回溯(不包含当前节点本身) + visited = {node_id} + node = dag.nodes.get(node_id) + if node: + for dep in (node.deps or []): + backtrack_path(dep, visited) + for parent in dag.parents_from_next.get(node_id, []): + backtrack_path(parent, visited) + + return used_frames + + +def _update_successor_reachability(node_id: str, frame_idx: int, dag: DAG, node_reachable: Dict[str, bool], min_frame_idx: Dict[str, int], matched_frame_idx: Dict[str, int], logger): + """动态更新后继节点的可达性""" + # 记录当前节点的实际匹配帧索引 + matched_frame_idx[node_id] = frame_idx + + # 更新通过deps依赖当前节点的后继节点 + for child_id in dag.children.get(node_id, []): + child_node = dag.nodes[child_id] + + # 检查deps依赖 + if child_node.deps and node_id in child_node.deps: + # 检查该子节点的所有deps依赖是否都已满足 + all_deps_satisfied = True + max_dep_frame = -1 + + for dep in child_node.deps: + if not node_reachable.get(dep, False): + all_deps_satisfied = False + break + # 使用实际匹配的帧索引而不是最小帧索引 + dep_matched_frame = matched_frame_idx.get(dep, -1) + if dep_matched_frame >= 0: + max_dep_frame = max(max_dep_frame, dep_matched_frame) + + if all_deps_satisfied and max_dep_frame >= 0: + node_reachable[child_id] = True + min_frame_idx[child_id] = max_dep_frame + 1 # 严格时序:后继必须发生在依赖之后 + logger.debug(f"节点 {child_id}: 通过deps依赖变为可达,最小帧索引={min_frame_idx[child_id]} (基于依赖节点 {node_id} 的匹配帧 {frame_idx})") + + # 
检查next路径依赖(OR语义) + elif not child_node.deps and node_id in dag.parents_from_next.get(child_id, []): + node_reachable[child_id] = True + min_frame_idx[child_id] = frame_idx + 1 # 后继节点必须在当前节点之后 + logger.debug(f"节点 {child_id}: 通过next路径变为可达,最小帧索引={min_frame_idx[child_id]} (基于父节点 {node_id} 的匹配帧 {frame_idx})") + + +def _collect_candidates(frames: List[Frame], task: TaskSpec, options: Optional[VerifierOptions], logs: List[DecisionLog]) -> Dict[str, List[int]]: + """兼容性包装:调用路径感知的候选帧收集函数""" + dag = DAG(task.nodes) + return _collect_candidates_path_aware(frames, task, dag, options, logs) + + +def _node_needs_ocr(node, params: Dict[str, any]) -> bool: + """判断节点是否需要OCR验证""" + if node.condition and node.condition.type in ("escalate", "juxtaposition"): + checker_params = params or {} + return "ocr" in checker_params + return False + + +def _node_needs_llm(node, params: Dict[str, any]) -> bool: + """判断节点是否需要LLM验证""" + if node.condition and node.condition.type in ("escalate", "juxtaposition"): + checker_params = params or {} + return "llm" in checker_params + return False + + +def _min_feasible_index(cands: List[int], min_required: int) -> Optional[int]: + """在 cands 中找出第一个 >= min_required 的索引。若无则返回 None。""" + # 二分也可,这里线性即可(帧通常不多) + for x in cands: + if x >= min_required: + return x + return None + + +def _calculate_total_score(matched_nodes: List[str], task: TaskSpec) -> int: + """计算匹配节点的总分数""" + total_score = 0 + node_dict = {node.id: node for node in task.nodes} + + for node_id in matched_nodes: + if node_id in node_dict: + total_score += node_dict[node_id].score + + return total_score + + +def verify(frames: List[Frame], task: TaskSpec, options: Optional[VerifierOptions] = None) -> VerifyResult: + """给定帧与任务,判断是否存在符合拓扑依赖的满足路径。 + + 算法:拓扑序 DP,记录每个节点的最小可行帧索引(满足依赖且该节点匹配)。 + """ + dag = DAG(task.nodes) + + # 确定成功节点集合(用于路径分析) + succ_nodes: List[str] + if task.success: + if task.success.any_of: + succ_nodes = task.success.any_of + elif task.success.all_of: + succ_nodes = 
task.success.all_of + else: + # 空 success 定义:视作 sinks + succ_nodes = dag.sinks() + else: + succ_nodes = dag.sinks() + + # 输出可能的路径到日志 + verifier_logger = get_verifier_logger() + dag.log_possible_paths(succ_nodes, verifier_logger) + + logs: List[DecisionLog] = [] + cands = _collect_candidates(frames, task, options, logs) + + topo = dag.topo_order() + min_idx: Dict[str, Optional[int]] = {nid: None for nid in dag.nodes} + prev: Dict[str, Optional[str]] = {nid: None for nid in dag.nodes} # 用于回溯路径(记录选择的父节点) + + for nid in topo: + node = dag.nodes[nid] + deps = node.deps or [] + # 依赖与路径父节点: + # - 若声明了 deps,则采用 AND 语义(保持兼容); + # - 否则,若存在由 next 形成的父节点集合,则采用 OR 语义(任一父成功即可)。 + dep_idx = 0 + chosen_parent: Optional[str] = None + if deps: + latest_dep = -1 + latest_src = None + for d in deps: + if min_idx[d] is None: + latest_dep = None + break + if min_idx[d] > latest_dep: + latest_dep = min_idx[d] # type: ignore + latest_src = d + if latest_dep is None: + min_idx[nid] = None + continue + dep_idx = latest_dep + chosen_parent = latest_src + else: + # OR 语义父节点(来自 next) + or_parents = dag.parents_from_next.get(nid, []) + if or_parents: + # 取最早完成的父节点作为起点 + available = [(p, min_idx[p]) for p in or_parents if min_idx.get(p) is not None] + if not available: + min_idx[nid] = None + continue + chosen_parent, parent_idx = min(available, key=lambda x: x[1]) # type: ignore + dep_idx = int(parent_idx) # type: ignore + else: + # 无依赖与无路径父节点,视作根节点 + dep_idx = 0 + chosen_parent = None + + # 从候选中找第一个满足顺序的帧 + hit = _min_feasible_index(cands.get(nid, []), dep_idx) + if hit is not None: + min_idx[nid] = hit + # 记录采用的父节点(用于路径回溯) + prev[nid] = chosen_parent + + # 决定成功节点集合 + succ_nodes: List[str] + if task.success: + if task.success.any_of: + succ_nodes = task.success.any_of + ok = any(min_idx.get(n) is not None for n in succ_nodes) + elif task.success.all_of: + succ_nodes = task.success.all_of + ok = all(min_idx.get(n) is not None for n in succ_nodes) + else: + # 空 success 定义:视作 sinks + succ_nodes = 
dag.sinks() + ok = any(min_idx.get(n) is not None for n in succ_nodes) + else: + succ_nodes = dag.sinks() + ok = any(min_idx.get(n) is not None for n in succ_nodes) + + if not ok: + # 即使最终成功条件不满足,也回溯已经匹配的节点 + matched: List[Tuple[int, str]] = [] # (frame_idx, node_id) + + # 找出所有成功匹配的节点(不论是否达到最终成功条件) + successful_nodes = [] + for nid in dag.nodes: + if min_idx.get(nid) is not None: + matched.append((min_idx[nid], nid)) # type: ignore + successful_nodes.append(nid) + + verifier_logger = get_verifier_logger() + verifier_logger.debug(f"成功匹配的节点: {successful_nodes}") + verifier_logger.debug(f"成功节点要求: {succ_nodes}") + verifier_logger.debug(f"min_idx状态: {min_idx}") + + matched.sort(key=lambda x: x[0]) + + # 计算已匹配节点的总分 + matched_node_ids = [nid for idx, nid in matched] + total_score = _calculate_total_score(matched_node_ids, task) + + # 构建详细的失败原因,包含最后一个检查的节点信息 + detailed_reason = "no feasible success path" + if logs: + # 找到最后一次检查的日志 + last_log = logs[-1] + if last_log.checker_result: + detailed_reason += f" (最后检查节点 {last_log.node_id}: {last_log.checker_result})" + elif last_log.unmatched_keywords: + detailed_reason += f" (最后检查节点 {last_log.node_id}: 未匹配关键词 {last_log.unmatched_keywords})" + + # 若存在某些节点完全无法判断(无命中且存在 escalation/高阶策略未配置),标记人工复核 + manual = any( + (n.condition and n.condition.type in ("escalate",)) for n in task.nodes + ) and (options is None or (options.llm is None and options.ocr is None)) + + return VerifyResult( + ok=False, + matched=[NodeMatch(node_id=nid, frame_index=idx) for idx, nid in matched], + reason=detailed_reason, + logs=logs, + manual_review_needed=manual, + total_score=total_score + ) + + # 回溯出一条可行路径: + # 若 any_of,则取 min_idx 最小的那个;若 all_of,则回溯每个并合并。 + matched: List[Tuple[int, str]] = [] # (frame_idx, node_id) + + def backtrack(start: str): + chain = [] + cur = start + while cur is not None and min_idx.get(cur) is not None: + chain.append((min_idx[cur], cur)) # type: ignore + cur = prev.get(cur) + # 依赖方向回溯得到自底向上链条,反转 + chain.reverse() + return chain 
+ + if task.success and task.success.all_of: + added = set() + for n in succ_nodes: + if min_idx.get(n) is not None: + for item in backtrack(n): + if item[1] not in added: + matched.append(item) + added.add(item[1]) + else: + # any_of 或默认 sinks:选择最早完成的一个成功节点 + candidate_succ = [n for n in succ_nodes if min_idx.get(n) is not None] + target = min(candidate_succ, key=lambda n: min_idx[n]) + matched = backtrack(target) + + matched.sort(key=lambda x: x[0]) + + # 计算匹配节点的总分 + matched_node_ids = [nid for idx, nid in matched] + total_score = _calculate_total_score(matched_node_ids, task) + + # 构建成功的详细原因,包含最后一个成功节点的信息 + detailed_reason = None + if logs: + # 找到最后一次成功的检查日志 + success_logs = [log for log in logs if log.decision == "hit"] + if success_logs: + last_success_log = success_logs[-1] + if last_success_log.checker_result: + detailed_reason = f"任务验证成功 (最后成功节点 {last_success_log.node_id}: {last_success_log.checker_result})" + elif last_success_log.matched_keywords: + detailed_reason = f"任务验证成功 (最后成功节点 {last_success_log.node_id}: 匹配关键词 {last_success_log.matched_keywords})" + + return VerifyResult( + ok=True, + matched=[NodeMatch(node_id=nid, frame_index=idx) for idx, nid in matched], + reason=detailed_reason, + logs=logs, + manual_review_needed=False, + total_score=total_score + ) + + +def verify_task(task_path: str, trace_path: str) -> VerifyResult: + task = load_task(task_path) + with open(trace_path, "r", encoding="utf-8") as f: + frames = json.load(f) + assert isinstance(frames, list) + return verify(frames, task) + + +def verify_task_folder(task_path: str, trace_folder: str, options: Optional[VerifierOptions] = None) -> VerifyResult: + task = load_task(task_path) + frames = load_frames_from_dir(trace_folder) + return verify(frames, task, options) + + +def make_llm_options(api_key: str, base_url: str, model: str = "google/gemini-2.5-flash", force_llm: bool = False, max_retries: int = 3, retry_delay: float = 1.0) -> VerifierOptions: + """构造带 LLM 回调的 VerifierOptions,使用 
LangChain OpenAI 兼容接口。 + + 注意:不在库内硬编码 key;由调用方传入。 + + Args: + api_key: API密钥 + base_url: API基础URL + model: 模型名称 + force_llm: 是否强制使用LLM验证 + max_retries: LLM请求的最大重试次数 + retry_delay: 重试间隔(秒) + """ + try: + from langchain_openai import ChatOpenAI # type: ignore + except Exception: # pragma: no cover - 可选依赖 + def _llm(_ctx): + return None + return VerifierOptions(llm=_llm, force_llm_verification=force_llm) + + # 创建 LangChain LLM 客户端 + client = ChatOpenAI( + model=model, + api_key=api_key, + base_url=base_url, + temperature=0.2, + max_tokens=3000, + timeout=40 + ) + + def _llm(ctx: Dict[str, any]) -> Optional[bool]: # type: ignore + params = (ctx.get("params") or {}) + prompt = params.get("prompt") or "请判断该步骤是否达成预期。" + frame = ctx.get("frame") or {} + reasoning = frame.get("reasoning") or "" + action = frame.get("action") or {} + task_desc = frame.get("task_description") or "" + + # 获取当前帧和下一帧的图片 + current_image = frame.get("image") + next_frame = frame.get("_next") or {} + next_image = next_frame.get("image") + + # TODO: 必要时考虑增加上一帧图片 + prev_frame = frame.get("_prev") or {} + prev_image = prev_frame.get("image") + + # 构建图片内容列表 + image_contents = [] + current_and_next = False + prev_and_current = False + + if current_image and next_image: + current_and_next = True + if current_image and os.path.exists(current_image): + with open(current_image, "rb") as f: + import base64 + current_image_b64 = base64.b64encode(f.read()).decode() + image_contents.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{current_image_b64}" + } + }) + + if next_image and os.path.exists(next_image): + with open(next_image, "rb") as f: + import base64 + next_image_b64 = base64.b64encode(f.read()).decode() + image_contents.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{next_image_b64}" + } + }) + elif prev_image and current_image: + if prev_image and os.path.exists(prev_image): + prev_and_current = True + with open(prev_image, "rb") 
as f: + import base64 + prev_image_b64 = base64.b64encode(f.read()).decode() + image_contents.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{prev_image_b64}" + } + }) + if current_image and os.path.exists(current_image): + with open(current_image, "rb") as f: + import base64 + current_image_b64 = base64.b64encode(f.read()).decode() + image_contents.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{current_image_b64}" + } + }) + else: + llm_logger.error("状态截图均不可用,无法进行判断。") + return None + + # sys = "你是一个针对移动端操作的自动化验证助手,请主要基于提供的截图和上下文辅助谨慎判断任务是否正确完成,并返回JSON格式的结果。" + sys = """ + 作为移动端操作自动化验证专家,您需要严格基于视觉证据分析截图状态,结合上下文判断任务节点是否达成。 + 核心原则: + 1. 截图是主要判定依据,文本信息仅作辅助参考(可能包含错误) + 2. 需识别界面元素的视觉变化及状态转移 + 3. 对模糊场景保持保守判断 + 4. 如无必要,不要展开冗长推理 + """ + + # 构建文本内容 + if current_and_next: + text_content = ( + "## 关键节点验证任务\n" + f"**全局任务task_description**: {task_desc}\n" + f"**当前节点要求**: {prompt}\n\n" + "移动端执行参考(注意:可能包含错误)\n" + f"- 操作意图: {reasoning}\n" + f"- 执行动作: {action}\n" + "## 分析指令\n" + "1. **视觉对比分析**:\n" + " - 当前帧截图:识别准备状态是否满足任务节点要求(按钮/文本/图标等)\n" + " - 下一帧截图:检测操作后界面变化(状态变更/新元素/提示消失等)是否满足任务节点要求\n" + " - 对比两帧差异\n" + "2. **容错机制**:\n" + " - 若文本描述与截图状态冲突 → 以截图为准\n" + " - 若未观测到满足节点要求的状态或变化 → 返回no\n" + "3. **判定标准**:\n" + " ❌ 以下情况视为失败:\n" + " - 关键元素未变化(如按钮仍可点击)\n" + " - 出现错误提示/异常状态\n" + " - 界面变化不符合任务逻辑链条\n\n" + "## 输出要求\n" + "请严格按照以下JSON格式返回结果:\n" + '{"result": "yes", "reason": "简要说明判断原因"} 或 {"result": "no", "reason": "简要说明失败原因"}' + ) + elif prev_and_current: + text_content = ( + "## 关键节点验证任务\n" + f"**全局任务**: {task_desc}\n" + f"**当前节点要求**: {prompt}\n\n" + "移动端执行参考(注意:可能包含错误)\n" + f"- 操作意图: {reasoning}\n" + f"- 执行动作: {action}\n" + "## 分析指令\n" + "1. **视觉对比分析**:\n" + " - 上一帧截图:上一步操作前识别准备状态内容状态(按钮/文本/图标等)\n" + " - 当前帧截图:检测上步操作后、当前界面变化(状态变更/新元素/提示消失等)\n" + " - 对比两帧差异是否满足任务节点要求\n" + "2. **容错机制**:\n" + " - 若文本描述与截图状态冲突 → 以截图为准\n" + " - 若未观测到满足节点要求的状态或变化 → 返回no\n" + "3. 
**判定标准**:\n" + " ❌ 以下情况视为失败:\n" + " - 关键元素未变化(如按钮仍可点击)\n" + " - 出现错误提示/异常状态\n" + " - 界面变化不符合任务逻辑链条\n\n" + "## 输出要求\n" + "请严格按照以下JSON格式返回结果:\n" + '{"result": "yes", "reason": "简要说明判断原因"} 或 {"result": "no", "reason": "简要说明失败原因"}' + ) + ## 精简版 + # f"当前节点判断任务: {prompt}\n" + # f"总任务描述: {task_desc}\n" + # f"移动端推理(仅参考): {reasoning}\n" + # f"移动端动作(仅参考): {action}\n" + # "请主要基于提供的前一帧和当前帧状态截图,判断该步骤是否达成预期。\n" + # "“推理”和“动作”由移动端操作提供,仅参考,不一定正确,需实际按截图进一步分析。\n" + # "请严格按照以下JSON格式返回结果:\n" + # '{"result": "yes", "reason": "简要说明判断原因"} 或 {"result": "no", "reason": "简要说明失败原因"}' + ### + + # 构建消息内容 + message_content = [{"type": "text", "text": text_content}] + message_content.extend(image_contents) + + # LangChain 消息格式 + from langchain_core.messages import SystemMessage, HumanMessage + + messages = [ + SystemMessage(content=sys), + HumanMessage(content=message_content) + ] + + llm_logger = get_llm_logger() + llm_logger.debug(f"prompt with {len(image_contents)} images: {text_content}") + + # 重试配置 + max_retries = 3 # 默认值 + retry_delay = 1.0 # 默认值(秒) + + # 从上下文中获取options配置(如果有的话) + if ctx and "options" in ctx: + options = ctx["options"] + if hasattr(options, 'max_llm_retries'): + max_retries = options.max_llm_retries + if hasattr(options, 'llm_retry_delay'): + retry_delay = options.llm_retry_delay + + for attempt in range(max_retries): + response_text = None + try: + # 使用 LangChain 客户端调用 + resp = client.invoke(messages) + + # 检查响应是否有效 + if not resp or not resp.content: + if attempt < max_retries - 1: + llm_logger.warning(f"received empty or invalid response from LLM (attempt {attempt + 1}/{max_retries}), retrying...") + import time + time.sleep(retry_delay) + continue + else: + llm_logger.error(f"received empty or invalid response from LLM after {max_retries} attempts") + return None + + response_text = resp.content.strip() + + # 检查响应内容是否为空 + if not response_text: + if attempt < max_retries - 1: + llm_logger.warning(f"received empty response content (attempt {attempt + 1}/{max_retries}), 
retrying...") + import time + time.sleep(retry_delay) + continue + else: + llm_logger.error(f"received empty response content after {max_retries} attempts") + return None + + llm_logger.debug(f"raw response (attempt {attempt + 1}): {response_text}") + + # 尝试解析JSON响应 + try: + import json + import re + + def extract_json_from_text(text): + """从文本中提取JSON内容,处理包含```json标记的情况""" + # 移除可能的markdown代码块标记 + cleaned_text = text.strip() + + # 尝试匹配 ```json ... ``` 格式 + json_block_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', cleaned_text, re.DOTALL) + if json_block_match: + return json_block_match.group(1).strip() + + # 尝试匹配纯JSON格式(查找第一个完整的JSON对象) + json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', cleaned_text, re.DOTALL) + if json_match: + return json_match.group(0).strip() + + # 如果都没找到,返回原文本 + return cleaned_text + + # 先尝试直接解析 + try: + result_json = json.loads(response_text) + except json.JSONDecodeError: + # 如果直接解析失败,尝试提取JSON内容 + extracted_json = extract_json_from_text(response_text) + llm_logger.debug(f"extracted json: {extracted_json}") + result_json = json.loads(extracted_json) + + result = result_json.get("result", "").lower() + reason = result_json.get("reason", "") + + llm_logger.debug(f"parsed result: {result}, reason: {reason}") + + # 在frame中记录LLM检查结果 + llm_result_record = { + 'success': result == "yes", + 'reason': reason + } + frame['_last_llm_result'] = llm_result_record + + if result == "yes": + return True + elif result == "no": + return False + elif result == "": + # 结果为空,可能需要重试 + if attempt < max_retries - 1: + llm_logger.warning(f"received empty result field (attempt {attempt + 1}/{max_retries}), retrying...") + import time + time.sleep(retry_delay) + continue + else: + llm_logger.warning(f"unexpected empty result value after {max_retries} attempts") + frame['_last_llm_result'] = { + 'success': False, + 'reason': 'LLM返回空结果' + } + return None + else: + # 结果值不是预期的,但不重试,直接返回None + llm_logger.warning(f"unexpected result value: {result}") + 
frame['_last_llm_result'] = { + 'success': False, + 'reason': f'LLM返回异常结果: {result}' + } + return None + + except (json.JSONDecodeError, AttributeError, TypeError) as e: + # JSON解析失败,尝试从文本中提取结果 + llm_logger.warning(f"JSON parsing failed ({str(e)}), trying text extraction from: {response_text}") + text_lower = response_text.lower() + if "yes" in text_lower and "no" not in text_lower: + llm_logger.debug("extracted result: yes (from text)") + frame['_last_llm_result'] = { + 'success': True, + 'reason': 'LLM返回yes (从文本提取)' + } + return True + elif "no" in text_lower and "yes" not in text_lower: + llm_logger.debug("extracted result: no (from text)") + frame['_last_llm_result'] = { + 'success': False, + 'reason': 'LLM返回no (从文本提取)' + } + return False + else: + # 如果文本提取也失败,且还有重试机会,则重试 + if attempt < max_retries - 1: + llm_logger.warning(f"unable to extract clear result from response (attempt {attempt + 1}/{max_retries}), retrying...") + import time + time.sleep(retry_delay) + continue + else: + llm_logger.warning(f"unable to extract clear result from response after {max_retries} attempts: {response_text}") + frame['_last_llm_result'] = { + 'success': False, + 'reason': 'LLM响应无法解析' + } + return None + + except Exception as e: + if attempt < max_retries - 1: + llm_logger.warning(f"LLM call failed (attempt {attempt + 1}/{max_retries}): {e}, retrying...") + import time + time.sleep(retry_delay) + continue + else: + llm_logger.error(f"LLM call failed after {max_retries} attempts: {e}") + frame['_last_llm_result'] = { + 'success': False, + 'reason': f'LLM调用失败: {str(e)}' + } + return None + + # 如果到这里说明所有重试都失败了 + llm_logger.error("All retry attempts exhausted") + frame['_last_llm_result'] = { + 'success': False, + 'reason': 'LLM所有重试尝试均失败' + } + return None + + return VerifierOptions( + llm=_llm, + force_llm_verification=force_llm, + max_llm_retries=max_retries, + llm_retry_delay=retry_delay + ) + + +if __name__ == "__main__": + import argparse + + parser = 
argparse.ArgumentParser(description="Verify LLM mobile agent task by DAG") + parser.add_argument("task", help="path to task yaml/json") + parser.add_argument("trace", help="path to trace json frames") + args = parser.parse_args() + + res = verify_task(args.task, args.trace) + print(json.dumps({ + "ok": res.ok, + "matched": [{"node": m.node_id, "frame": m.frame_index} for m in res.matched], + "reason": res.reason, + }, ensure_ascii=False, indent=2)) diff --git a/MobiFlow/docs/CHECKER_MODES.md b/MobiFlow/docs/CHECKER_MODES.md new file mode 100644 index 0000000..37052b5 --- /dev/null +++ b/MobiFlow/docs/CHECKER_MODES.md @@ -0,0 +1,208 @@ +# 检查器模式详细说明 + +## 概述 + +验证框架现在支持两种主要的检查器组合模式: + +1. **escalate**: 升级模式 - 按顺序尝试,任意一个成功即返回成功 +2. **juxtaposition**: 并列模式 - 所有配置的检查器都必须成功 + +## 1. Escalate 模式(升级模式) + +### 工作原理 +- 按照 `escalation_order` 的顺序依次尝试配置的检查器 +- **任意一个检查器返回 True,立即返回 True**(短路求值) +- 适用于"宽松验证"场景,有多种验证方式但只需要其中一种成功 + +### 默认升级顺序 +```python +escalation_order: ["text", "regex", "action", "icons", "ocr", "llm"] +``` + +### 配置示例 +```yaml +condition: + type: escalate + params: + # 第1优先级:简单文本匹配 + text: + any: ["搜索", "查找"] + + # 第2优先级:正则表达式 + regex: + pattern: ".*(搜索|查询).*" + ignore_case: true + + # 第2优先级:动作验证 + action: + type: click + + # 第4优先级: icons图标识别 + icons: + any: ["购物车"] + # 第5优先级:OCR识别 + ocr: + any: ["搜索"] + + # 第6优先级:LLM验证 + llm: + prompt: "该步是否执行了搜索?" +``` + +### 执行流程 +1. 检查 `text` 配置,如果匹配成功 → 立即返回 True +2. 如果 `text` 失败,检查 `regex`,如果成功 → 立即返回 True +3. 如果 `regex` 失败,检查 `ui`,如果成功 → 立即返回 True +4. 依此类推... +5. 当只配置某种策略如ocr或llm时,则只使用该方式 +6. 如果所有配置的检查器都失败 → 返回 False + +### 适用场景 +- 搜索操作验证(可能通过文本、UI状态或LLM任一方式确认) +- 页面跳转验证(可能通过多种方式检测) +- 灵活的操作确认 + +--- + +## 2. 
Juxtaposition 模式(并列模式) + +### 工作原理 +- **所有配置的检查器都必须返回 True**,才认为验证成功 +- 任意一个检查器失败,整个验证失败 +- 适用于"严格验证"场景,需要多重确认 + +### 配置示例 +```yaml +condition: + type: juxtaposition + params: + # 必须满足:文本包含关键词 + text: + any: ["确认", "提交"] + + # 必须满足:是点击动作 + action: + type: click + + # 必须满足:UI状态正确 + ui: + key: screen + equals: confirm_page + + # 必须满足:OCR识别成功 + ocr: + all: ["确认", "提交"] + + # 必须满足:LLM确认 + llm: + prompt: "该步是否点击了确认按钮?" +``` + +### 执行流程 +1. 执行所有配置的检查器 +2. 收集所有检查器的结果 +3. 只有当**所有结果都是 True** 时,才返回 True +4. 任意一个结果为 False,返回 False + +### 适用场景 +- 关键操作的严格验证(如支付、确认订单) +- 需要多重确认的敏感操作 +- 复合条件验证(必须同时满足多个条件) + +--- + +## 3. 支持的检查器类型 + +两种模式都支持以下检查器: +优先推荐使用icons、ocr和llm检测,仅基于trace的完整截图就能够判断。 + +| 检查器 | 参数格式 | 说明 | +|--------|----------|------| +| `text` | `{any: [...], all: [...]}` | 文本包含匹配 | +| `regex` | `{pattern: "...", ignore_case: bool}` | 正则表达式匹配 | +| `ui` | `{key: "...", equals: "...", in: [...]}` | UI状态检查 | +| `action` | `{type: "...", contains: {...}}` | 动作类型验证 | +| `xml` | `{any: [...], all: [...]}` | XML文本匹配 | +| `ocr` | `{any: [...], all: [...], pattern: "..."}` | OCR图像识别 | +| `llm` | `{prompt: "...", expected_true: bool}` | LLM智能验证 | + +--- + + +## 4. 自定义升级顺序 + +可以通过 `VerifierOptions` 自定义 escalate 的执行顺序: + +```python +from avdag.types import VerifierOptions + +# 自定义顺序:优先UI检查,然后文本,最后LLM +custom_options = VerifierOptions( + escalation_order=["ui", "ocr", "llm"] +) +``` + +--- + +## 6. 实际使用建议 + +### 使用 Escalate 的场景 +- 操作验证有多种可能的确认方式 +- 需要从简单到复杂的渐进式验证 +- 性能要求较高,希望早期匹配成功 + +### 使用 Juxtaposition 的场景 +- 关键操作需要多重确认 +- 必须同时满足多个严格条件 +- 需要确保验证的可靠性和准确性 + +### 组合使用 +在复杂任务中,可以在不同节点使用不同模式: + +```yaml +nodes: + - id: search_input + condition: + type: escalate # 搜索输入验证相对宽松 + params: + text: {any: ["搜索"]} + action: {type: input} + + - id: payment_confirm + condition: + type: juxtaposition # 支付确认必须严格验证 + params: + text: {all: ["支付", "确认"]} + action: {type: click} + ocr: {all: ["支付", "确认"]} + llm: {prompt: "是否点击了支付确认?"} +``` + +--- + +## 7. 
调试输出 + +两种模式都提供详细的调试输出: + +### Escalate 模式输出 +``` +[Escalate] 升级顺序: ['text', 'regex', 'ui', 'action', 'dynamic_match', 'ocr', 'llm'] +[Escalate] 配置的检查器: ['text', 'ui', 'llm'] +[Escalate] 尝试检查器: text +[Escalate] text 检查结果: False +[Escalate] 尝试检查器: ui +[Escalate] ui 检查结果: True +[Escalate] ui 检查成功,立即返回True +``` + +### Juxtaposition 模式输出 +``` +[Juxtaposition] text_match 结果: True +[Juxtaposition] ui_flag 结果: True +[Juxtaposition] action_match 结果: False +[Juxtaposition] 配置的检查器: ['text_match', 'ui_flag', 'action_match'] +[Juxtaposition] 各检查器结果: [True, True, False] +[Juxtaposition] 最终结果: False +``` + +这种详细的输出有助于理解验证过程和调试配置问题。 diff --git "a/MobiFlow/docs/OCR\350\257\206\345\210\253\346\224\271\350\277\233.md" "b/MobiFlow/docs/OCR\350\257\206\345\210\253\346\224\271\350\277\233.md" new file mode 100644 index 0000000..22cec2d --- /dev/null +++ "b/MobiFlow/docs/OCR\350\257\206\345\210\253\346\224\271\350\277\233.md" @@ -0,0 +1,83 @@ +# 增强版OCR识别辅助 + +把“识别不稳定导致漏匹配”的链路补强: + +实现多引擎+多预处理+滑窗补充+文本多视图匹配,并把 OCR 与 XML 文本融合,外加缓存与排序稳定化,尽量消除漏报与波动。 + +## 目标清单 +- 提高识别召回,尽可能覆盖图中真实文字元素。 +- 降低多次识别结果不一致的波动。 +- 匹配阶段更稳健,减少因轻微噪声/断词/空格导致的漏判。 +- 在 OCR 失败时自动回落到 XML 文本。 + +## 已做改进(核心方案) +- 多引擎联合识别 + - 同时准备 PaddleOCR 与 Tesseract,两套引擎都跑,合并结果(引擎差异互补)。 +- 多预处理变体 + - 对同一图像生成多种视图:原图、放大1.5x/2.0x、灰度、锐化、增强对比度、二值化;对每个视图分别做 OCR 并合并。 +- 轻量滑窗补充 + - 若全局结果词数偏少,自动启用2×2重叠的滑窗裁剪识别,补充小字或局部区域漏识别。 +- 文本标准化与多视图匹配 + - 统一全角转半角、大小写折叠、常见易混字符归一(O/0、l/I/1、S/5、B/8、中文竖线等)。 + - 生成多视图:cleaned(清洗)、no_spaces(去空格)、words(分词)、chars(逐字序列),并支持: + - 连续(no_spaces)匹配 + - 子词包含匹配 + - 中文逐字按序匹配 + - RapidFuzz 部分相似度匹配(可选,80分阈值) +- OCR + XML 融合、逐级降级 + - 若图像不可用或 OCR 为空,则提取 XML 文本(text/content-desc/hint + 资源 ID 可读化)。 + - 在 frame 层合并 OCR 与 XML 两侧结果用于匹配,尽量不丢任何候选文本。 +- 结果缓存与稳定排序 + - 基于“路径+mtime+size”的键缓存识别的词列表,多次调用返回稳定集合。 + - 最终词集合按字典序稳定排序,减少“总词数不一致”的波动。 +- 运行日志更可诊断 + - 打印合并前后词数、样例片段,便于定位具体漏项。 + +## 关键文件改动 +- ocr_processor.py + - 新增:多引擎、多预处理、滑窗补充、文本标准化与智能匹配逻辑、缓存。 + - create_frame_ocr_function / create_frame_texts_function 现会融合 XML 与 OCR 的文本,并把多视图结果回填进 
frame(_ocr_processed/_xml_processed)。
+  - 提供智能匹配 smart_text_contains,支持空格/断词鲁棒与模糊匹配。
+- 条件器无需改动即可使用,因为 OCRMatch 已读取 frame 的处理结果并做智能匹配;我保留了兼容路径并输出调试信息。
+
+## 如何使用
+- 标准用法保持不变:通过 create_standard_ocr_functions() 注入 OCR。
+
+
+## 效果
+- 召回:多引擎+多预处理+滑窗让原先“有字但没识别”的情况显著减少。
+- 稳定:缓存+稳定排序消除多次调用总词数波动。
+- 匹配:无空格/分词/逐字/模糊等多策略配合,显著降低因 OCR 断词、空格、轻微错字导致的漏报。
+
+## 质量门(快速)
+- 语法/运行:已执行增强匹配小测,所有断言通过。
+- 依赖:Pillow 已在工程中使用;RapidFuzz 属可选(缺失时自动降级到基础匹配)。
+
+## 后续可选增强
+- 利用 XML 坐标/控件层级对感兴趣区域做 ROI 识别(更精确的滑窗,成本更低)。
+- 小模型语言归一(中英文数词混排的词典纠错)。
+- 引入得分投票:多引擎/多视图为每个词累积分数,匹配时按词频/置信度加权。
+- 参数开关:通过环境变量控制滑窗启用阈值与模糊匹配阈值,便于性能—召回折中。
+
+## 变更
+- “尽可能确保图中有的文字元素可匹配”:Done(多引擎、多预处理、滑窗、融合)。
+- “降低多次识别结果不一致”:Done(缓存+稳定排序)。
+- “匹配更鲁棒、避免漏报”:Done(标准化+多策略匹配+XML回退/融合)。
+
+## 使用时的注意事项
+
+当前的页面关键词OCR辅助识别,已经具备页面内元素识别的较高准确性。但由于页面的排版设计、元素UI样式各异等因素,例如对于“我的淘宝”,“马上抢红包”等关键词,不一定保证识别结果是连续的短句。例如,对于下图:
+
+
+![alt text](./images/识别示例.png)
+
+识别结果可能为:
+```
+y999 400 m0u mnes 8er ease dem ee5 a 00 d0 视频 消息 购物 车 我 的 三 e g6 eg 一 5u0men 淘
+```
+因此,匹配时选取**针对词语**的组合,如
+```
+all:["我的","淘宝"]
+all:["马上","抢","红包"]
+```
+等拆分模式匹配,能够显著提高匹配成功率。
diff --git a/MobiFlow/docs/UNIVERSAL_TEST_RUNNER.md b/MobiFlow/docs/UNIVERSAL_TEST_RUNNER.md
new file mode 100644
index 0000000..d2ba334
--- /dev/null
+++ b/MobiFlow/docs/UNIVERSAL_TEST_RUNNER.md
@@ -0,0 +1,315 @@
+# 通用任务测试执行器
+
+## 概述
+
+通用任务测试执行器是一个灵活的测试框架,可以通过 JSON 配置文件轻松配置和执行不同任务的自动化测试。
+
+## 核心特性
+
+- **灵活配置**: 通过 JSON 配置文件定义任务类型、规则文件、数据目录等
+- **多种测试模式**: 支持测试所有类型、指定类型、指定 trace 等
+- **详细日志**: 自动记录测试过程中的所有输出到指定文件
+- **结果汇总**: 生成详细的测试报告,包括成功率、匹配节点、失败原因等
+- **易于扩展**: 新增任务只需添加对应的配置文件
+
+## 使用方法
+
+### 1. 基本用法
+
+```bash
+# 测试所有类型
+python universal_test_runner.py task_configs/taobao.json
+
+# 测试指定类型
+python universal_test_runner.py task_configs/taobao.json type3
+
+# 测试指定类型的指定trace
+python universal_test_runner.py task_configs/taobao.json type3:150
+
+# 测试指定的trace编号
+python universal_test_runner.py task_configs/taobao.json 150,151,152
+```
+
+### 2. 
配置文件示例 + +#### 淘宝任务配置 (task_configs/taobao.json) + +```json +{ + "task_name": "taobao", + "description": "淘宝任务测试配置", + + "rules_base_dir": "task_rules/taobao", + "data_base_dir": "data", + + "task_types": { + "3": { + "name": "加购物车任务", + "rule_file": "type3-taobao_add_cart-new.yaml", + "data_traces": [150, 151, 152, 153, 154, 155], + "description": "添加商品到购物车" + }, + "4": { + "name": "排序/筛选任务", + "rule_file": "type4-taobao_add_cart.yaml", + "data_traces": [120, 121, 122, 123, 124, 125], + "description": "商品排序和筛选功能" + } + }, + + "test_options": { + "enable_ocr": true, + "enable_llm": true, + "force_llm": false, + "ocr_frame_exclusive": true, + "llm_frame_exclusive": true, + "prevent_frame_backtrack": true + }, + + "logging": { + "level": "DEBUG", + "use_colors": true, + "show_time": true, + "show_module": true, + "output_file": "test-{task_name}-{timestamp}.log" + }, + + "output": { + "summary_file": "test-{task_name}-summary-{timestamp}.txt", + "detailed_results_file": "test-{task_name}-detailed-{timestamp}.json" + } +} +``` + +#### 小红书任务配置 (task_configs/xiaohongshu.json) + +```json +{ + "task_name": "xiaohongshu", + "description": "小红书任务测试配置", + + "rules_base_dir": "task_rules/xiaohongshu", + "data_base_dir": "data/xiaohongshu", + + "task_types": { + "2": { + "name": "type2任务", + "rule_file": "xiaohongshu-type2.yaml", + "data_traces": "type2", + "description": "小红书type2功能测试" + }, + "3": { + "name": "type3任务", + "rule_file": "xiaohongshu-type3.yaml", + "data_traces": "type3", + "description": "小红书type3功能测试" + } + }, + + "test_options": { + "enable_ocr": true, + "enable_llm": true, + "force_llm": false, + "ocr_frame_exclusive": true, + "llm_frame_exclusive": true, + "prevent_frame_backtrack": true + }, + + "logging": { + "level": "DEBUG", + "use_colors": true, + "show_time": true, + "show_module": true, + "output_file": "test-{task_name}-{timestamp}.log" + }, + + "output": { + "summary_file": "test-{task_name}-summary-{timestamp}.txt", + "detailed_results_file": 
"test-{task_name}-detailed-{timestamp}.json" + } +} +``` + +## 配置文件说明 + +### 基本配置 + +- `task_name`: 任务名称,用于生成日志和结果文件名 +- `description`: 任务描述 +- `rules_base_dir`: 规则文件基础目录 +- `data_base_dir`: 数据文件基础目录 + +### 任务类型配置 + +每个任务类型包含: +- `name`: 类型显示名称 +- `rule_file`: 对应的规则文件名 +- `data_traces`: 测试数据配置,支持多种格式: + - **数字列表**: `[150, 151, 152]` - 明确指定trace编号 + - **字符串**: `"type2"` - 指定type目录名 + - **空列表**: `[]` - 自动发现可用traces + - **不配置**: 完全不包含`data_traces`字段 - 自动发现可用traces +- `description`: 类型描述 + +#### data_traces 配置详解 + +1. **明确指定trace编号列表** (最高优先级) + ```json + "data_traces": [150, 151, 152, 153, 154] + ``` + 系统会测试编号为 150、151、152、153、154 的trace。 + +2. **指定type目录** + ```json + "data_traces": "type2" + ``` + 系统会在 `data_base_dir/type2` 目录下查找测试数据。 + +3. **自动发现模式** + ```json + "data_traces": [] + ``` + 或者完全不配置 `data_traces` 字段,系统会自动扫描 `data_base_dir` 目录: + - 优先查找 `type{task_type}` 格式的目录 + - 然后查找数字编号的目录 + - 最后包含其他匹配的目录 + +4. **优先级规则** + - 如果配置了 `data_traces` 且不为空,优先使用配置的值 + - 如果配置为空或未配置,则启用自动发现模式 + - 如果配置的路径不存在,会回退到自动发现模式 + +### 测试选项 + +- `enable_ocr`: 启用 OCR 功能 +- `enable_llm`: 启用 LLM 功能 +- `force_llm`: 强制使用 LLM +- `ocr_frame_exclusive`: OCR 帧独占模式 +- `llm_frame_exclusive`: LLM 帧独占模式 +- `prevent_frame_backtrack`: 防止帧回退 + +### 日志配置 + +- `level`: 日志级别 (DEBUG, INFO, WARNING, ERROR) +- `use_colors`: 使用彩色输出 +- `show_time`: 显示时间戳 +- `show_module`: 显示模块名 +- `output_file`: 日志文件名模板 + +### 输出配置 + +- `summary_file`: 汇总文件名模板 +- `detailed_results_file`: 详细结果文件名模板 + +## 输出文件说明 + +### 日志文件 + +记录测试过程中的所有详细信息,包括: +- 系统初始化信息 +- 每个测试用例的执行过程 +- 验证结果和匹配详情 +- 错误信息和警告 + +### 汇总文件 (.txt) + +人类可读的测试结果汇总,包括: +- 测试基本信息 +- 总体成功率 +- 分类型结果统计 +- 每个测试用例的详细结果 + +### 详细结果文件 (.json) + +机器可读的详细结果数据,包括: +- 完整的测试配置 +- 每个测试用例的详细结果 +- 执行时间统计 +- 错误信息 + +## 新增任务步骤 + +1. **创建配置文件**: 在 `task_configs/` 目录下创建新的 JSON 配置文件 +2. **配置规则目录**: 在 `task_rules/` 下创建对应的规则文件目录 +3. **准备测试数据**: 在 `data/` 下准备对应的测试数据 +4. 
**运行测试**: 使用新配置文件运行测试 + +## 示例输出 + +### 控制台输出 + +``` +=== 通用任务测试执行器 === +任务名称: taobao +任务描述: 淘宝任务测试配置 +日志文件: test-taobao-20240822_143025.log +汇总文件: test-taobao-summary-20240822_143025.txt +详细结果: test-taobao-detailed-20240822_143025.json + +--- 测试 150 [加购物车任务] --- +规则文件: type3-taobao_add_cart-new.yaml +数据路径: /path/to/data/150 +验证结果: ✓ 成功 +匹配节点: ['search', 'add_to_cart'] +任务得分: 100.0分 +执行时间: 2.34秒 + +--- 类型 3 汇总 --- +trace 150: ✓ | score: 100.0 | nodes: ['search', 'add_to_cart'] | reason: +trace 151: ✗ | score: 60.0 | nodes: ['search'] | reason: 未找到加购物车操作 +成功率: 1/2 (50.0%) +``` + +### 汇总文件示例 + +``` +任务测试汇总报告 +============================================================ +任务名称: taobao +测试时间: 2024-08-22 14:30:25 +配置文件: task_configs/taobao.json +总测试数: 10 +总成功数: 8 +总成功率: 80.0% +总执行时间: 25.67秒 + +分类型结果: +---------------------------------------- +类型 3 (加购物车任务): + 测试数: 5 + 成功数: 4 + 成功率: 80.0% + +类型 4 (排序/筛选任务): + 测试数: 5 + 成功数: 4 + 成功率: 80.0% +``` + +## 扩展和定制 + +### 自定义验证选项 + +可以在配置文件的 `test_options` 中添加更多验证选项,然后在 `UniversalTestRunner._create_verifier_options()` 方法中处理。 + +### 自定义输出格式 + +可以修改 `save_results()` 方法来支持更多输出格式,如 CSV、XML 等。 + +### 添加新的测试模式 + +可以在 `main()` 函数中添加更多参数解析逻辑来支持新的测试模式。 + +## 常见问题 + +### Q: 如何添加新任务? +A: 创建新的配置文件,配置规则目录和数据目录,然后运行测试。 + +### Q: 如何修改日志级别? +A: 在配置文件的 `logging.level` 中设置,支持 DEBUG、INFO、WARNING、ERROR。 + +### Q: 如何禁用 LLM 或 OCR? +A: 在配置文件的 `test_options` 中设置 `enable_llm` 或 `enable_ocr` 为 false。 + +### Q: 测试结果保存在哪里? 
+A: 根据配置文件中的 `output` 部分设置,默认保存在当前目录下,文件名包含时间戳。 diff --git "a/MobiFlow/docs/images/\350\257\206\345\210\253\347\244\272\344\276\213.png" "b/MobiFlow/docs/images/\350\257\206\345\210\253\347\244\272\344\276\213.png" new file mode 100644 index 0000000..5d187e7 Binary files /dev/null and "b/MobiFlow/docs/images/\350\257\206\345\210\253\347\244\272\344\276\213.png" differ diff --git "a/MobiFlow/docs/\346\227\245\345\277\227\347\263\273\347\273\237\344\275\277\347\224\250.md" "b/MobiFlow/docs/\346\227\245\345\277\227\347\263\273\347\273\237\344\275\277\347\224\250.md" new file mode 100644 index 0000000..8897ffb --- /dev/null +++ "b/MobiFlow/docs/\346\227\245\345\277\227\347\263\273\347\273\237\344\275\277\347\224\250.md" @@ -0,0 +1,244 @@ +# 日志系统使用 + +## 概述 + +实现了统一的日志系统,用于替代原有的 `print` 语句,提供灵活的日志级别控制和输出格式配置。 + +## 日志级别 + +支持以下日志级别(从低到高): + +1. **TRACE** (5) - 最详细的跟踪信息,用于深度调试 +2. **DEBUG** (10) - 调试信息,开发时使用 +3. **INFO** (20) - 一般信息,默认显示(默认级别) +4. **WARNING** (30) - 警告信息,需要注意但不影响正常运行 +5. **ERROR** (40) - 错误信息,出现问题但程序可以继续 +6. **CRITICAL** (50) - 关键错误,可能导致程序无法继续 + +## 配置方式 + +### 1. 环境变量配置 + +设置环境变量 `AVDAG_LOG_LEVEL` 来控制日志级别: + +```bash +export AVDAG_LOG_LEVEL=DEBUG +python your_script.py +``` + +### 2. 代码配置 + +在代码中直接配置: + +```python +from avdag.logger import configure_logging, set_log_level + +# 方式1:完整配置 +configure_logging( + level="DEBUG", # 日志级别 + use_colors=True, # 使用颜色输出 + show_time=True, # 显示时间 + show_module=True, # 显示模块名 + output_file="debug.log" # 输出到文件(可选) +) + +# 方式2:仅设置级别 +set_log_level("DEBUG") +``` + +### 3. 
配置文件 + +创建 `logging_config.json` 文件: + +```json +{ + "level": "DEBUG", + "use_colors": true, + "show_time": true, + "show_module": true, + "output_file": "logs/debug.log" +} +``` + +然后在代码中加载: + +```python +from avdag.logger import configure_logging + +configure_logging(config_file="logging_config.json") +``` + +或者通过环境变量指定配置文件: + +```bash +export AVDAG_LOG_CONFIG=logging_config.json +python your_script.py +``` + +## 使用方式 + +### 基本使用 + +```python +from avdag.logger import get_logger + +logger = get_logger(__name__) + +logger.trace("最详细的跟踪信息") +logger.debug("调试信息") +logger.info("一般信息") +logger.warning("警告信息") +logger.error("错误信息") +logger.critical("关键错误") +``` + +### 预定义日志器 + +项目提供了一些预定义的日志器: + +```python +from avdag.logger import ( + get_verifier_logger, # 验证器日志 + get_ocr_logger, # OCR处理日志 + get_llm_logger, # LLM调用日志 + get_frame_logger, # 帧处理日志 + get_condition_logger # 条件检查日志 +) + +verifier_logger = get_verifier_logger() +verifier_logger.info("验证开始") + +ocr_logger = get_ocr_logger() +ocr_logger.debug("OCR识别中...") +``` + +### 兼容性函数 + +为了与现有代码兼容,提供了一些便捷函数: + +```python +from avdag.logger import debug_print, info_print, error_print, warning_print + +debug_print("这是调试信息", "VERIFIER") +info_print("这是一般信息", "OCR") +error_print("这是错误信息", "LLM") +warning_print("这是警告信息", "FRAME") +``` + +## 输出格式 + +### 控制台输出(带颜色) + +``` +14:30:25 [INFO] avdag.verifier 验证任务开始 +14:30:25 [DEBUG] avdag.frame 节点 search_box: 找到 2 个候选帧 [1, 3] +14:30:26 [ERROR] avdag.ocr PaddleOCR识别失败: 模型文件不存在 +14:30:26 [WARNING] avdag.llm LLM调用超时,使用备用策略 +``` + +### 文件输出 + +``` +2024-08-17 14:30:25 [INFO] avdag.verifier 验证任务开始 +2024-08-17 14:30:25 [DEBUG] avdag.frame 节点 search_box: 找到 2 个候选帧 [1, 3] +2024-08-17 14:30:26 [ERROR] avdag.ocr PaddleOCR识别失败: 模型文件不存在 +2024-08-17 14:30:26 [WARNING] avdag.llm LLM调用超时,使用备用策略 +``` + +## 实际应用示例 + +### 开发阶段 + +```bash +# 显示所有调试信息 +export AVDAG_LOG_LEVEL=DEBUG +python -m avdag.verifier examples/tasks/shop_search.yaml examples/traces/sample_shop_trace.json +``` + +### 生产环境 + 
+```bash +# 只显示重要信息和错误 +export AVDAG_LOG_LEVEL=WARNING +python your_production_script.py +``` + +### 问题调试 + +```bash +# 最详细的跟踪信息,输出到文件 +export AVDAG_LOG_LEVEL=TRACE +python your_script.py 2>&1 | tee debug.log +``` + +或者在代码中: + +```python +from avdag.logger import configure_logging + +configure_logging( + level="TRACE", + output_file="debug.log", + show_time=True, + show_module=True +) +``` + +### OCR调试 + +```bash +# 专门调试OCR问题 +export AVDAG_LOG_LEVEL=DEBUG +python examples/run_taobao_verify.py +``` + +### LLM调用调试 + +```python +from avdag.logger import get_llm_logger, set_log_level + +set_log_level("DEBUG") +llm_logger = get_llm_logger() + +# 在LLM调用前后记录详细信息 +llm_logger.debug(f"准备调用LLM,prompt长度: {len(prompt)}") +result = llm_call(prompt) +llm_logger.debug(f"LLM返回结果: {result}") +``` + +## 最佳实践 + +1. **模块化日志器**:每个模块使用独立的日志器 + ```python + logger = get_logger(__name__) + ``` + +2. **合适的日志级别**: + - `TRACE`: 最详细的执行流程 + - `DEBUG`: 中间变量值、状态变化 + - `INFO`: 重要的业务逻辑节点 + - `WARNING`: 异常情况但不影响继续执行 + - `ERROR`: 错误情况 + - `CRITICAL`: 严重错误 + +3. **性能考虑**:使用级别检查避免不必要的字符串格式化 + ```python + if logger.is_enabled_for("DEBUG"): + logger.debug(f"复杂的调试信息: {expensive_operation()}") + ``` + +4. 
**结构化日志**:使用一致的格式便于分析 + ```python + logger.info(f"任务 {task_id} 完成,耗时 {duration:.2f}s,成功率 {success_rate:.1%}") + ``` + +## 迁移指南 + +原有的 `print` 语句已经自动替换为相应的日志调用: + +- `print(f"[OCR] ...")` → `ocr_logger.info(...)` +- `print(f"[Frame Search] ...")` → `frame_logger.debug(...)` +- `print(f"[LLM] ...")` → `llm_logger.debug(...)` +- `print(f"DEBUG: ...")` → `verifier_logger.debug(...)` + +如果需要自定义日志行为,可以根据需要调整日志级别和格式。 diff --git "a/MobiFlow/docs/\351\252\214\350\257\201\346\241\206\346\236\266.md" "b/MobiFlow/docs/\351\252\214\350\257\201\346\241\206\346\236\266.md" new file mode 100644 index 0000000..3132a30 --- /dev/null +++ "b/MobiFlow/docs/\351\252\214\350\257\201\346\241\206\346\236\266.md" @@ -0,0 +1,265 @@ +## 验证工作流完整分析 +以下是当前框架中完整的验证工作流: + +### **框架架构概览** + +这是一个**基于DAG(有向无环图)的移动应用任务验证框架**,专门用于验证LLM智能体在移动端执行复杂任务的完成情况。 + +### **核心组件与职责** + +#### 1. **数据层** (types.py) + +- **Frame**: 表示执行轨迹中的一个时间帧,包含截图、XML、推理、动作等信息 +- **TaskSpec**: 任务配置规范,定义节点、依赖关系和成功条件 +- **VerifierOptions**: 验证选项,配置OCR、LLM等能力 +- **VerifyResult**: 验证结果,包含成功状态、匹配路径、人工复核标记等 + +#### 2. **任务加载层** (loader.py + trace_loader.py) + +- **任务配置加载**: 从YAML/JSON文件加载任务DAG定义 +- **轨迹数据加载**: 从目录结构提取多模态数据(图片+XML+动作+推理) +- **数据增强**: 自动组装文本字段,添加邻接上下文引用 + +#### 3. **条件检查层** (conditions.py) + +- **基础检查器**: `text_match`、`regex_match`、`ui_flag`、`xml_text_match`、`action_match` +- **高级检查器**: + - `escalate`: 多策略升级验证(text→regex→ui→xml→ocr→llm) + - `dynamic_match`: 动态条件匹配,从任务描述提取条件并验证操作 +- **注册机制**: 支持自定义检查器扩展 + +#### 4. **DAG计算层** (dag.py) + +- **依赖关系管理**: 构建节点间的父子关系图 +- **拓扑排序**: 确保依赖顺序的正确性 +- **环检测**: 验证DAG的有效性 + +#### 5. **核心验证层** (verifier.py) + +verifier: + + - 若节点声明了 deps,则仍按 AND 语义(取所有依赖的最晚命中帧作为起点) + - 若未声明 deps 且存在 next 来源的父节点,则按 OR 语义(任一父可行即可,取最早完成的父节点作为起点) 回溯时记录选择的父节点,恢复具体路径 + + +- **候选帧收集**: 为每个节点收集满足条件的帧索引 +- **动态规划算法**: 基于依赖约束计算每个节点的最小可行索引 +- **路径回溯**: 构建最优满足路径 +- **帧使用优化**: 支持帧独占和防回退机制 + +#### 6. 
**多模态处理层** (ocr_processor.py) + +- **OCR文字识别**: 集成app_trajectory_analyzer的OCR引擎 +- **智能文本处理**: 多格式文本预处理和匹配策略 +- **XML文本提取**: 从Android UI XML中提取所有可见文本 +- **降级策略**: OCR失败时自动降级到XML文本提取 + +### **验证工作流详细流程** + +#### **阶段1: 数据预处理** + +```python +# 1. 加载任务配置 +task = load_task("task.yaml") # 解析DAG节点和依赖关系 + +# 2. 加载执行轨迹 +frames = load_frames_from_dir("trace_folder/") # 提取图片、XML、动作、推理数据 + +# 3. 数据增强 +# - 为每帧添加索引和上下文引用 +# - 组装综合文本字段 +# - 提取UI元信息(包名等) +``` + +#### **阶段2: DAG依赖分析** + +```python +# 1. 构建DAG图 +dag = DAG(task.nodes) # 建立节点依赖关系 + +# 2. 拓扑排序 +topo_order = dag.topo_order() # 确定验证顺序 + +# 3. 环检测 +dag._assert_acyclic() # 确保DAG有效性 +``` + +#### **阶段3: 帧匹配与候选收集** + +```python +def _collect_candidates(frames, task, options): + cand = {} + used_frames = set() # 已使用帧集合 + + for node in task.nodes: + hits = [] + checker = get_checker(node.condition.type) + + # 帧独占优化:避免OCR/LLM重复使用同一帧 + for i, frame in enumerate(frames): + if needs_exclusive_access(node) and i in used_frames: + continue + + # 多策略验证 + if checker.check(frame, params, options): + hits.append(i) + if needs_exclusive_access(node): + used_frames.add(i) + break # 早停机制 + + cand[node.id] = hits + return cand +``` + +#### **阶段4: 条件检查策略** + +**4.1 Escalate策略升级流程**: + +```yaml +escalation_order: [text, regex, ui, action, dynamic_match, ocr, llm] +``` + +**4.2 Dynamic Match动态条件检查**: + +```python +# 从任务描述提取条件 +if "价格最低" in task_description: + # 在推理文本中验证是否执行了价格排序 + verify_in_reasoning("价格", "低到高", "便宜") +``` + +**4.3 OCR多模态处理**: + +```python +def frame_ocr(frame): + # 1. 优先使用OCR识别图片 + if image_exists: + text = ocr_engine.recognize(image_path) + if text: return processed_text + + # 2. 降级到XML文本提取 + xml_text = extract_text_from_xml(xml_content) + return xml_text +``` + +#### **阶段5: 动态规划求解** + +```python +def verify(frames, task, options): + # 1. 收集候选帧 + candidates = _collect_candidates(frames, task, options) + + # 2. 
DP计算最小可行索引 + min_idx = {} + for node_id in topo_order: + deps = get_dependencies(node_id) + + # 依赖约束:必须晚于所有依赖节点 + min_required = max(min_idx[dep] for dep in deps) if deps else 0 + + # 找到第一个满足顺序约束的候选帧 + feasible = _min_feasible_index(candidates[node_id], min_required) + min_idx[node_id] = feasible + + # 3. 检查成功条件 + success_satisfied = check_success_condition(min_idx, task.success) + + # 4. 路径回溯 + if success_satisfied: + path = backtrack_optimal_path(min_idx, task.success) + return VerifyResult(ok=True, matched=path) + else: + return VerifyResult(ok=False, manual_review_needed=True) +``` + +#### **阶段6: 结果生成与路径回溯** + +```python +# 1. 成功路径回溯 +if task.success.any_of: + # 选择最早完成的成功节点 + target = min(success_nodes, key=lambda n: min_idx[n]) + path = backtrack(target) + +elif task.success.all_of: + # 回溯所有必需节点的路径 + paths = [backtrack(node) for node in success_nodes] + path = merge_unique_paths(paths) + +# 2. 构建验证结果 +return VerifyResult( + ok=True, + matched=[NodeMatch(node_id=nid, frame_index=idx) for idx, nid in path], + logs=decision_logs, + manual_review_needed=False +) +``` + +### **关键创新特性** + +#### 1. **智能帧管理机制** + +- **帧独占模式**: OCR/LLM验证时避免帧重复使用 +- **防回退机制**: 线性流程中防止回退到已使用的帧 +- **早停优化**: 找到匹配后立即停止,提高效率 + +#### 2. **多模态验证能力** + +- **OCR集成**: 支持PaddleOCR和Tesseract双引擎 +- **LLM推理**: 结合截图和上下文的多模态验证 +- **降级策略**: OCR失败时自动降级到XML文本提取 + +#### 3. **动态条件匹配** + +- **任务感知**: 从任务描述自动提取验证条件 +- **模式映射**: 支持复杂的条件模式配置 +- **灵活验证**: 在多个字段中查找验证关键词 + +#### 4. 
**策略升级机制** + +- **渐进式验证**: 从简单到复杂的多层级检查 +- **智能降级**: 高级策略失败时自动降级 +- **配置驱动**: 支持自定义升级顺序 + +### 📊 **验证结果示例** + +```python +VerifyResult( + ok=True, # 验证成功 + matched=[ + NodeMatch(node_id="open_app_home", frame_index=0), + NodeMatch(node_id="activate_search", frame_index=1), + NodeMatch(node_id="input_keyword", frame_index=2), + NodeMatch(node_id="submit_search", frame_index=3), + NodeMatch(node_id="apply_filter_condition", frame_index=4), + NodeMatch(node_id="add_to_cart", frame_index=6) + ], + logs=[...], # 详细决策日志 + manual_review_needed=False # 无需人工复核 +) +``` + +### 🔧 **扩展能力** + +#### 1. **自定义检查器** + +```python +@register_condition("custom_checker") +class CustomChecker(ConditionChecker): + def check(self, frame, params, options): + # 自定义验证逻辑 + return custom_logic(frame, params) +``` + +#### 2. **配置驱动的任务定义** + +- 支持复杂的依赖关系定义 +- 灵活的成功条件配置 +- 可扩展的条件参数 + +#### 3. **多种数据源支持** + +- 目录结构的复杂轨迹 +- 自定义数据格式扩展 + +这个框架特别适合验证包含复杂筛选条件app操作任务场景,能够智能地从任务描述中提取条件并验证是否正确执行了相应操作。 diff --git a/MobiFlow/example_checker_modes.yaml b/MobiFlow/example_checker_modes.yaml new file mode 100644 index 0000000..e27f9a0 --- /dev/null +++ b/MobiFlow/example_checker_modes.yaml @@ -0,0 +1,92 @@ +task_id: example_juxtaposition_demo +app_id: com.example.app +task_type: demo +description: 演示escalate和juxtaposition检查器的使用 + +nodes: + # 示例1: 使用escalate模式 - 按顺序尝试,任意一个成功即可 + - id: search_with_escalate + name: 使用escalate模式进行搜索验证 + condition: + type: escalate + params: + action: + type: click + # 如果动作检查失败,使用OCR识别 + icons: + all: ["购物车", "搜索"] + # 如果图标检查失败,检查动作类型 + ocr: + any: ["搜索", "查询"] + # 最后使用LLM验证 + llm: + prompt: "该步是否执行了搜索操作?" 
+          expected_true: true
+
+  # 示例2: 使用juxtaposition模式 - 所有配置的检查器都必须成功
+  - id: confirm_with_juxtaposition
+    deps: [search_with_escalate]
+    name: 使用juxtaposition模式进行严格验证
+    condition:
+      type: juxtaposition
+      params:
+        # 必须同时满足:是点击动作
+        action:
+          type: click
+        # 必须同时满足:图标检测到确认/提交元素
+        icons:
+          any: ["确认", "提交"]
+        # 必须同时满足:OCR识别到确认按钮
+        ocr:
+          all: ["确认", "提交"]
+        # 必须同时满足:LLM确认操作正确
+        llm:
+          prompt: "该步是否点击了确认/提交按钮?"
+          expected_true: true
+
+  # 示例3: 复杂的escalate配置
+  - id: complex_escalate_example
+    deps: [confirm_with_juxtaposition]
+    name: escalate 备用策略
+    condition:
+      type: escalate
+      params:
+        # 第1级:基础文本匹配
+        text:
+          all: ["购物车", "商品"]
+        # 第2级:正则表达式匹配
+        regex:
+          pattern: "(购物车|cart).*(添加|加入)"
+          ignore_case: true
+        # 第3级:动作验证
+        action:
+          type: click
+        # 第4级:图标检测
+        icons:
+          all: ["购物车", "已添加"]
+        # 第5级:OCR图像识别
+        ocr:
+          any: ["购物车", "加入购物车", "立即购买"]
+          pattern: "(购物车|cart)"
+        # 第6级:LLM最终验证
+        llm:
+          prompt: "基于截图和上下文,该步是否成功将商品加入购物车?"
+          expected_true: true
+
+  # 示例4: 只使用高级检查器的juxtaposition
+  - id: advanced_juxtaposition
+    deps: [complex_escalate_example]
+    name: 高级检查器的juxtaposition组合
+    condition:
+      type: juxtaposition
+      params:
+        # 必须满足:OCR识别到相关元素
+        ocr:
+          all: ["支付", "确认订单"]
+        # 必须满足:LLM确认操作
+        llm:
+          prompt: "该步是否完成了购买/支付操作?" 
+ expected_true: true + +success: + any_of: [advanced_juxtaposition] diff --git a/MobiFlow/logging_config.json b/MobiFlow/logging_config.json new file mode 100644 index 0000000..49efea2 --- /dev/null +++ b/MobiFlow/logging_config.json @@ -0,0 +1,7 @@ +{ + "level": "DEBUG", + "use_colors": true, + "show_time": true, + "show_module": true, + "output_file": "./output_file/debug.log" +} diff --git a/MobiFlow/pyproject.toml b/MobiFlow/pyproject.toml new file mode 100644 index 0000000..81260ed --- /dev/null +++ b/MobiFlow/pyproject.toml @@ -0,0 +1,18 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "avdag" +version = "0.0.1" +description = "A minimal DAG-based verifier for mobile LLM agent task traces" +readme = "README.md" +requires-python = ">=3.9" +dependencies = [ + "PyYAML>=6.0.1", +] +authors = [{ name = "Auto Verify" }] + +[tool.setuptools.packages.find] +where = ["."] +include = ["avdag*"] diff --git a/MobiFlow/task_configs/bilibili.json b/MobiFlow/task_configs/bilibili.json new file mode 100644 index 0000000..4d8eedc --- /dev/null +++ b/MobiFlow/task_configs/bilibili.json @@ -0,0 +1,77 @@ +{ + "task_name": "bilibili_auto", + "description": "Bilibili任务测试配置 - 使用自动发现功能", + + "rules_base_dir": "task_rules/bilibili", + "data_base_dir": "data/bilibili", + + "task_types": { + "type1": { + "name": "type1任务", + "rule_file": "bilibili-type1.yaml", + "description": "Bilibili type1功能测试 - 明确指定type1目录" + }, + "type2": { + "name": "type2任务", + "rule_file": "bilibili-type2.yaml", + "description": "Bilibili type2功能测试 - 明确指定type2目录" + }, + "type3": { + "name": "type3任务", + "rule_file": "bilibili-type3.yaml", + "description": "Bilibili type3功能测试 - 明确指定type3目录" + }, + "type4": { + "name": "type4任务", + "rule_file": "bilibili-type4.yaml", + "description": "Bilibili type4功能测试 - 明确指定type4目录" + }, + "type5": { + "name": "type5任务", + "rule_file": "bilibili-type5.yaml", + "description": "Bilibili type5功能测试 - 明确指定type5目录" + 
},
+        "type6": {
+            "name": "type6任务",
+            "rule_file": "bilibili-type6.yaml",
+            "description": "Bilibili type6功能测试 - 明确指定type6目录"
+        },
+        "type7": {
+            "name": "type7任务",
+            "rule_file": "bilibili-type7.yaml",
+            "description": "Bilibili type7功能测试 - 明确指定type7目录"
+        },
+        "type8": {
+            "name": "type8任务",
+            "rule_file": "bilibili-type8.yaml",
+            "description": "Bilibili type8功能测试 - 明确指定type8目录"
+        },
+        "type9": {
+            "name": "type9任务",
+            "rule_file": "bilibili-type9.yaml",
+            "description": "Bilibili type9功能测试 - 明确指定type9目录"
+        }
+    },
+
+    "test_options": {
+        "enable_ocr": true,
+        "enable_llm": true,
+        "force_llm": false,
+        "ocr_frame_exclusive": true,
+        "llm_frame_exclusive": true,
+        "prevent_frame_backtrack": true
+    },
+
+    "logging": {
+        "level": "DEBUG",
+        "use_colors": true,
+        "show_time": true,
+        "show_module": true,
+        "output_file": "test-{task_name}-{timestamp}.log"
+    },
+
+    "output": {
+        "summary_file": "test-{task_name}-summary-{timestamp}.txt",
+        "detailed_results_file": "test-{task_name}-detailed-{timestamp}.json"
+    }
+}
diff --git a/MobiFlow/task_configs/cloudmusic.json b/MobiFlow/task_configs/cloudmusic.json
new file mode 100644
index 0000000..5856e87
--- /dev/null
+++ b/MobiFlow/task_configs/cloudmusic.json
@@ -0,0 +1,43 @@
+{
+    "task_name": "cloudmusic",
+    "description": "网易云音乐任务测试配置",
+
+    "rules_base_dir": "task_rules/cloudmusic",
+    "data_base_dir": "data/cloudmusic",
+
+    "task_types": {
+        "type1": {
+            "name": "音乐播放任务",
+            "rule_file": "cloudmusic-type1.yaml",
+            "data_traces": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+            "description": "音乐播放相关功能测试"
+        },
+        "type2": {
+            "name": "音乐搜索任务",
+            "rule_file": "cloudmusic-type2.yaml",
+            "description": "音乐搜索功能测试"
+        }
+    },
+
+    "test_options": {
+        "enable_ocr": true,
+        "enable_llm": true,
+        "force_llm": false,
+        "ocr_frame_exclusive": true,
+        "llm_frame_exclusive": true,
+        "prevent_frame_backtrack": true
+    },
+
+    "logging": {
+        "level": "DEBUG",
+        "use_colors": true,
+        "show_time": true,
+        "show_module": true,
+        "output_file": 
"test-{task_name}-{timestamp}.log" + }, + + "output": { + "summary_file": "test-{task_name}-summary-{timestamp}.txt", + "detailed_results_file": "test-{task_name}-detailed-{timestamp}.json" + } +} diff --git a/MobiFlow/task_configs/ele.json b/MobiFlow/task_configs/ele.json new file mode 100644 index 0000000..c832ff1 --- /dev/null +++ b/MobiFlow/task_configs/ele.json @@ -0,0 +1,47 @@ +{ + "task_name": "ele", + "description": "饿了么任务测试配置", + + "rules_base_dir": "task_rules/ele", + "data_base_dir": "data/ele", + + "task_types": { + "type3": { + "name": "外卖订餐任务", + "rule_file": "ele-type3.yaml", + "description": "外卖订餐功能测试" + }, + "type4": { + "name": "商家筛选任务", + "rule_file": "ele-type4.yaml", + "description": "商家筛选和排序功能测试" + }, + "type5": { + "name": "菜品搜索任务", + "rule_file": "ele-type5.yaml", + "description": "菜品搜索功能测试" + } + }, + + "test_options": { + "enable_ocr": true, + "enable_llm": true, + "force_llm": false, + "ocr_frame_exclusive": true, + "llm_frame_exclusive": true, + "prevent_frame_backtrack": true + }, + + "logging": { + "level": "DEBUG", + "use_colors": true, + "show_time": true, + "show_module": true, + "output_file": "test-{task_name}-{timestamp}.log" + }, + + "output": { + "summary_file": "test-{task_name}-summary-{timestamp}.txt", + "detailed_results_file": "test-{task_name}-detailed-{timestamp}.json" + } +} diff --git a/MobiFlow/task_configs/feizhu.json b/MobiFlow/task_configs/feizhu.json new file mode 100644 index 0000000..127c87e --- /dev/null +++ b/MobiFlow/task_configs/feizhu.json @@ -0,0 +1,73 @@ +{ + "task_name": "feizhu", + "description": "飞猪旅行任务测试配置", + + "rules_base_dir": "task_rules/feizhu", + "data_base_dir": "data/feizhu", + + "task_types": { + "type1": { + "name": "机票预订任务", + "rule_file": "feizhu-type1.yaml", + "data_traces": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 22], + "description": "机票搜索和预订功能测试" + }, + "type2": { + "name": "酒店预订任务", + "rule_file": "feizhu-type2.yaml", + "description": "酒店搜索和预订功能测试" + }, + "type3": { + "name": "火车票预订任务", + 
"rule_file": "feizhu-type3.yaml", + "description": "火车票搜索和预订功能测试" + }, + "type4": { + "name": "旅游产品搜索任务", + "rule_file": "feizhu-type4.yaml", + "description": "旅游产品搜索功能测试" + }, + "type5": { + "name": "筛选排序任务", + "rule_file": "feizhu-type5.yaml", + "description": "商品筛选和排序功能测试" + }, + "type6": { + "name": "用户中心任务", + "rule_file": "feizhu-type6.yaml", + "description": "用户中心相关功能测试" + }, + "type7": { + "name": "支付流程任务", + "rule_file": "feizhu-type7.yaml", + "description": "支付流程功能测试" + }, + "type8": { + "name": "综合功能任务", + "rule_file": "feizhu-type8.yaml", + "description": "综合功能测试" + } + }, + + "test_options": { + "enable_ocr": true, + "enable_llm": true, + "force_llm": false, + "ocr_frame_exclusive": true, + "llm_frame_exclusive": true, + "prevent_frame_backtrack": true + }, + + "logging": { + "level": "DEBUG", + "use_colors": true, + "show_time": true, + "show_module": true, + "output_file": "test-{task_name}-{timestamp}.log" + }, + + "output": { + "summary_file": "test-{task_name}-summary-{timestamp}.txt", + "detailed_results_file": "test-{task_name}-detailed-{timestamp}.json" + } +} diff --git a/MobiFlow/task_configs/gaode.json b/MobiFlow/task_configs/gaode.json new file mode 100644 index 0000000..636699b --- /dev/null +++ b/MobiFlow/task_configs/gaode.json @@ -0,0 +1,47 @@ +{ + "task_name": "gaode", + "description": "高德地图任务测试配置", + + "rules_base_dir": "task_rules/gaode", + "data_base_dir": "data/gaode", + + "task_types": { + "type1": { + "name": "地图导航任务", + "rule_file": "gaode-type1.yaml", + "description": "地图导航功能测试" + }, + "type2": { + "name": "地点搜索任务", + "rule_file": "gaode-type2.yaml", + "description": "地点搜索功能测试" + }, + "type3": { + "name": "路线规划任务", + "rule_file": "gaode-type3.yaml", + "description": "路线规划功能测试" + } + }, + + "test_options": { + "enable_ocr": true, + "enable_llm": true, + "force_llm": false, + "ocr_frame_exclusive": true, + "llm_frame_exclusive": true, + "prevent_frame_backtrack": true + }, + + "logging": { + "level": "DEBUG", + "use_colors": 
true, + "show_time": true, + "show_module": true, + "output_file": "test-{task_name}-{timestamp}.log" + }, + + "output": { + "summary_file": "test-{task_name}-summary-{timestamp}.txt", + "detailed_results_file": "test-{task_name}-detailed-{timestamp}.json" + } +} diff --git "a/MobiFlow/task_configs/icons/bilibili/icon_000_\345\217\221\345\274\271\345\271\225.jpg" "b/MobiFlow/task_configs/icons/bilibili/icon_000_\345\217\221\345\274\271\345\271\225.jpg" new file mode 100644 index 0000000..907c1ae Binary files /dev/null and "b/MobiFlow/task_configs/icons/bilibili/icon_000_\345\217\221\345\274\271\345\271\225.jpg" differ diff --git "a/MobiFlow/task_configs/icons/bilibili/icon_000_\351\246\226\351\241\265.jpg" "b/MobiFlow/task_configs/icons/bilibili/icon_000_\351\246\226\351\241\265.jpg" new file mode 100644 index 0000000..f890bd7 Binary files /dev/null and "b/MobiFlow/task_configs/icons/bilibili/icon_000_\351\246\226\351\241\265.jpg" differ diff --git "a/MobiFlow/task_configs/icons/bilibili/icon_001_\346\212\225\345\270\201.jpg" "b/MobiFlow/task_configs/icons/bilibili/icon_001_\346\212\225\345\270\201.jpg" new file mode 100644 index 0000000..b2e9ea0 Binary files /dev/null and "b/MobiFlow/task_configs/icons/bilibili/icon_001_\346\212\225\345\270\201.jpg" differ diff --git "a/MobiFlow/task_configs/icons/bilibili/icon_001_\346\220\234\347\264\242.jpg" "b/MobiFlow/task_configs/icons/bilibili/icon_001_\346\220\234\347\264\242.jpg" new file mode 100644 index 0000000..5e6a50a Binary files /dev/null and "b/MobiFlow/task_configs/icons/bilibili/icon_001_\346\220\234\347\264\242.jpg" differ diff --git "a/MobiFlow/task_configs/icons/bilibili/icon_001_\350\241\250\346\203\205.jpg" "b/MobiFlow/task_configs/icons/bilibili/icon_001_\350\241\250\346\203\205.jpg" new file mode 100644 index 0000000..c605b14 Binary files /dev/null and "b/MobiFlow/task_configs/icons/bilibili/icon_001_\350\241\250\346\203\205.jpg" differ diff --git 
"a/MobiFlow/task_configs/icons/bilibili/icon_002_AI\346\220\234\347\264\242.jpg" "b/MobiFlow/task_configs/icons/bilibili/icon_002_AI\346\220\234\347\264\242.jpg" new file mode 100644 index 0000000..b7d0f9e Binary files /dev/null and "b/MobiFlow/task_configs/icons/bilibili/icon_002_AI\346\220\234\347\264\242.jpg" differ diff --git "a/MobiFlow/task_configs/icons/bilibili/icon_002_\346\224\266\350\227\217.jpg" "b/MobiFlow/task_configs/icons/bilibili/icon_002_\346\224\266\350\227\217.jpg" new file mode 100644 index 0000000..9e21dae Binary files /dev/null and "b/MobiFlow/task_configs/icons/bilibili/icon_002_\346\224\266\350\227\217.jpg" differ diff --git "a/MobiFlow/task_configs/icons/bilibili/icon_003_\345\217\221\345\270\203.jpg" "b/MobiFlow/task_configs/icons/bilibili/icon_003_\345\217\221\345\270\203.jpg" new file mode 100644 index 0000000..2aceb84 Binary files /dev/null and "b/MobiFlow/task_configs/icons/bilibili/icon_003_\345\217\221\345\270\203.jpg" differ diff --git "a/MobiFlow/task_configs/icons/bilibili/icon_003_\346\270\205\351\231\244.jpg" "b/MobiFlow/task_configs/icons/bilibili/icon_003_\346\270\205\351\231\244.jpg" new file mode 100644 index 0000000..6aa480f Binary files /dev/null and "b/MobiFlow/task_configs/icons/bilibili/icon_003_\346\270\205\351\231\244.jpg" differ diff --git "a/MobiFlow/task_configs/icons/bilibili/icon_003_\347\255\233\351\200\211.jpg" "b/MobiFlow/task_configs/icons/bilibili/icon_003_\347\255\233\351\200\211.jpg" new file mode 100644 index 0000000..c28518f Binary files /dev/null and "b/MobiFlow/task_configs/icons/bilibili/icon_003_\347\255\233\351\200\211.jpg" differ diff --git "a/MobiFlow/task_configs/icons/bilibili/icon_004_\350\277\224\345\233\236.jpg" "b/MobiFlow/task_configs/icons/bilibili/icon_004_\350\277\224\345\233\236.jpg" new file mode 100644 index 0000000..a693767 Binary files /dev/null and "b/MobiFlow/task_configs/icons/bilibili/icon_004_\350\277\224\345\233\236.jpg" differ diff --git 
"a/MobiFlow/task_configs/icons/weixin/icon_000_\346\210\221.jpg" "b/MobiFlow/task_configs/icons/weixin/icon_000_\346\210\221.jpg" new file mode 100644 index 0000000..c46d938 Binary files /dev/null and "b/MobiFlow/task_configs/icons/weixin/icon_000_\346\210\221.jpg" differ diff --git "a/MobiFlow/task_configs/icons/weixin/icon_000_\350\241\250\346\203\205.jpg" "b/MobiFlow/task_configs/icons/weixin/icon_000_\350\241\250\346\203\205.jpg" new file mode 100644 index 0000000..285788f Binary files /dev/null and "b/MobiFlow/task_configs/icons/weixin/icon_000_\350\241\250\346\203\205.jpg" differ diff --git "a/MobiFlow/task_configs/icons/weixin/icon_001_\345\233\236\350\275\246.jpg" "b/MobiFlow/task_configs/icons/weixin/icon_001_\345\233\236\350\275\246.jpg" new file mode 100644 index 0000000..99f23c0 Binary files /dev/null and "b/MobiFlow/task_configs/icons/weixin/icon_001_\345\233\236\350\275\246.jpg" differ diff --git "a/MobiFlow/task_configs/icons/weixin/icon_001_\346\267\273\345\212\240.jpg" "b/MobiFlow/task_configs/icons/weixin/icon_001_\346\267\273\345\212\240.jpg" new file mode 100644 index 0000000..b0d3cf2 Binary files /dev/null and "b/MobiFlow/task_configs/icons/weixin/icon_001_\346\267\273\345\212\240.jpg" differ diff --git "a/MobiFlow/task_configs/icons/weixin/icon_001_\351\200\232\350\256\257\345\275\225.jpg" "b/MobiFlow/task_configs/icons/weixin/icon_001_\351\200\232\350\256\257\345\275\225.jpg" new file mode 100644 index 0000000..ebc0d6e Binary files /dev/null and "b/MobiFlow/task_configs/icons/weixin/icon_001_\351\200\232\350\256\257\345\275\225.jpg" differ diff --git "a/MobiFlow/task_configs/icons/weixin/icon_002_\345\217\221\351\200\201.jpg" "b/MobiFlow/task_configs/icons/weixin/icon_002_\345\217\221\351\200\201.jpg" new file mode 100644 index 0000000..3b29919 Binary files /dev/null and "b/MobiFlow/task_configs/icons/weixin/icon_002_\345\217\221\351\200\201.jpg" differ diff --git "a/MobiFlow/task_configs/icons/weixin/icon_002_\345\276\256\344\277\241.jpg" 
"b/MobiFlow/task_configs/icons/weixin/icon_002_\345\276\256\344\277\241.jpg" new file mode 100644 index 0000000..332828b Binary files /dev/null and "b/MobiFlow/task_configs/icons/weixin/icon_002_\345\276\256\344\277\241.jpg" differ diff --git "a/MobiFlow/task_configs/icons/weixin/icon_002_\350\257\255\351\237\263.jpg" "b/MobiFlow/task_configs/icons/weixin/icon_002_\350\257\255\351\237\263.jpg" new file mode 100644 index 0000000..ced9eb4 Binary files /dev/null and "b/MobiFlow/task_configs/icons/weixin/icon_002_\350\257\255\351\237\263.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_000_\346\267\273\345\212\240.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_000_\346\267\273\345\212\240.jpg" new file mode 100644 index 0000000..a453997 Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_000_\346\267\273\345\212\240.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_000_\350\241\214\347\250\213.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_000_\350\241\214\347\250\213.jpg" new file mode 100644 index 0000000..2d68917 Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_000_\350\241\214\347\250\213.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_000_\351\253\230\347\272\247\347\255\233\351\200\211.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_000_\351\253\230\347\272\247\347\255\233\351\200\211.jpg" new file mode 100644 index 0000000..c23755f Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_000_\351\253\230\347\272\247\347\255\233\351\200\211.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_001_\346\234\272\347\245\250.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_001_\346\234\272\347\245\250.jpg" new file mode 100644 index 0000000..4352657 Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_001_\346\234\272\347\245\250.jpg" differ diff --git 
"a/MobiFlow/task_configs/icons/xiecheng/icon_001_\350\241\214\347\250\213.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_001_\350\241\214\347\250\213.jpg" new file mode 100644 index 0000000..5adef45 Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_001_\350\241\214\347\250\213.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_002_\345\210\206\344\272\253.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_002_\345\210\206\344\272\253.jpg" new file mode 100644 index 0000000..3b9d91a Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_002_\345\210\206\344\272\253.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_002_\346\212\242\347\245\250.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_002_\346\212\242\347\245\250.jpg" new file mode 100644 index 0000000..e1c7040 Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_002_\346\212\242\347\245\250.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_002_\346\234\272\347\245\250.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_002_\346\234\272\347\245\250.jpg" new file mode 100644 index 0000000..1e881a0 Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_002_\346\234\272\347\245\250.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_003_\345\207\272\345\217\221\346\234\200\346\227\251.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_003_\345\207\272\345\217\221\346\234\200\346\227\251.jpg" new file mode 100644 index 0000000..b79fe52 Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_003_\345\207\272\345\217\221\346\234\200\346\227\251.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_003_\346\210\221\347\232\204.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_003_\346\210\221\347\232\204.jpg" new file mode 100644 index 0000000..22acf42 Binary files /dev/null and 
"b/MobiFlow/task_configs/icons/xiecheng/icon_003_\346\210\221\347\232\204.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_003_\346\227\245\346\234\237.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_003_\346\227\245\346\234\237.jpg" new file mode 100644 index 0000000..681f526 Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_003_\346\227\245\346\234\237.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_003_\350\200\227\346\227\266\346\234\200\347\237\255.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_003_\350\200\227\346\227\266\346\234\200\347\237\255.jpg" new file mode 100644 index 0000000..e21330a Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_003_\350\200\227\346\227\266\346\234\200\347\237\255.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_004_\344\270\252\344\272\272\344\270\255\345\277\203.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_004_\344\270\252\344\272\272\344\270\255\345\277\203.jpg" new file mode 100644 index 0000000..efea8ba Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_004_\344\270\252\344\272\272\344\270\255\345\277\203.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_004_\344\273\267\346\240\274\346\234\200\344\275\216.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_004_\344\273\267\346\240\274\346\234\200\344\275\216.jpg" new file mode 100644 index 0000000..a7f93fd Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_004_\344\273\267\346\240\274\346\234\200\344\275\216.jpg" differ diff --git "a/MobiFlow/task_configs/icons/xiecheng/icon_004_\346\212\242\347\245\250.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_004_\346\212\242\347\245\250.jpg" new file mode 100644 index 0000000..fdeceda Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_004_\346\212\242\347\245\250.jpg" differ diff --git 
"a/MobiFlow/task_configs/icons/xiecheng/icon_004_\351\246\226\351\241\265.jpg" "b/MobiFlow/task_configs/icons/xiecheng/icon_004_\351\246\226\351\241\265.jpg" new file mode 100644 index 0000000..6f0c8e2 Binary files /dev/null and "b/MobiFlow/task_configs/icons/xiecheng/icon_004_\351\246\226\351\241\265.jpg" differ diff --git "a/MobiFlow/task_configs/icons/\346\220\272\347\250\213/\346\234\272\347\245\250.png" "b/MobiFlow/task_configs/icons/\346\220\272\347\250\213/\346\234\272\347\245\250.png" new file mode 100644 index 0000000..c47666d Binary files /dev/null and "b/MobiFlow/task_configs/icons/\346\220\272\347\250\213/\346\234\272\347\245\250.png" differ diff --git "a/MobiFlow/task_configs/icons/\346\220\272\347\250\213/\347\201\253\350\275\246\347\245\250.png" "b/MobiFlow/task_configs/icons/\346\220\272\347\250\213/\347\201\253\350\275\246\347\245\250.png" new file mode 100644 index 0000000..f0a3717 Binary files /dev/null and "b/MobiFlow/task_configs/icons/\346\220\272\347\250\213/\347\201\253\350\275\246\347\245\250.png" differ diff --git "a/MobiFlow/task_configs/icons/\346\220\272\347\250\213/\351\205\222\345\272\227.png" "b/MobiFlow/task_configs/icons/\346\220\272\347\250\213/\351\205\222\345\272\227.png" new file mode 100644 index 0000000..b405d46 Binary files /dev/null and "b/MobiFlow/task_configs/icons/\346\220\272\347\250\213/\351\205\222\345\272\227.png" differ diff --git "a/MobiFlow/task_configs/icons/\346\220\272\347\250\213/\351\246\226\351\241\265.png" "b/MobiFlow/task_configs/icons/\346\220\272\347\250\213/\351\246\226\351\241\265.png" new file mode 100644 index 0000000..5d04523 Binary files /dev/null and "b/MobiFlow/task_configs/icons/\346\220\272\347\250\213/\351\246\226\351\241\265.png" differ diff --git a/MobiFlow/task_configs/taobao.json b/MobiFlow/task_configs/taobao.json new file mode 100644 index 0000000..8ff00b1 --- /dev/null +++ b/MobiFlow/task_configs/taobao.json @@ -0,0 +1,42 @@ +{ + "task_name": "taobao", + "description": "淘宝任务测试配置", + + 
"rules_base_dir": "task_rules/taobao", + "data_base_dir": "data/taobao", + + "task_types": { + "type3": { + "name": "加购物车任务", + "rule_file": "type3-taobao_add_cart-new.yaml", + "description": "添加商品到购物车" + }, + "type4": { + "name": "排序/筛选任务", + "rule_file": "type4-taobao_add_cart.yaml", + "description": "商品排序和筛选功能" + } + }, + + "test_options": { + "enable_ocr": true, + "enable_llm": true, + "force_llm": false, + "ocr_frame_exclusive": true, + "llm_frame_exclusive": true, + "prevent_frame_backtrack": true + }, + + "logging": { + "level": "DEBUG", + "use_colors": true, + "show_time": true, + "show_module": true, + "output_file": "test-{task_name}-{timestamp}.log" + }, + + "output": { + "summary_file": "test-{task_name}-summary-{timestamp}.txt", + "detailed_results_file": "test-{task_name}-detailed-{timestamp}.json" + } +} diff --git a/MobiFlow/task_configs/task_config_template.json b/MobiFlow/task_configs/task_config_template.json new file mode 100644 index 0000000..36268e7 --- /dev/null +++ b/MobiFlow/task_configs/task_config_template.json @@ -0,0 +1,61 @@ +{ + "task_name": "示例任务配置", + "description": "任务配置模板,用于定义测试参数", + + "rules_base_dir": "task_rules/taobao", + "data_base_dir": "data/taobao", + + "task_types": { + "type1": { + "name": "搜索任务", + "rule_file": "type1-taobao-search.yaml", + "data_traces": [1, 2, 3, 4, 5], + "description": "基本搜索功能测试 - 明确指定trace编号列表" + }, + "type2": { + "name": "搜索+详情任务", + "rule_file": "type2-taobao-search-open-add_cart.yaml", + "data_traces": "type2", + "description": "搜索并查看商品详情 - 指定type目录" + }, + "type3": { + "name": "加购物车任务", + "rule_file": "type3-taobao_add_cart-new.yaml", + "data_traces": [150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196], + "description": "添加商品到购物车 - 明确指定trace编号列表" + }, + "type4": { + "name": "排序/筛选任务", + "rule_file": 
"type4-taobao_add_cart.yaml", + "data_traces": [], + "description": "商品排序和筛选功能 - 空配置,自动发现traces" + }, + "type5": { + "name": "自动发现任务", + "rule_file": "type5-auto-discover.yaml", + "description": "自动发现功能演示 - 不配置data_traces字段,完全自动发现" + } + }, + + "test_options": { + "enable_ocr": true, + "enable_llm": true, + "force_llm": false, + "ocr_frame_exclusive": true, + "llm_frame_exclusive": true, + "prevent_frame_backtrack": true + }, + + "logging": { + "level": "DEBUG", + "use_colors": true, + "show_time": true, + "show_module": true, + "output_file": "test-{task_name}-{timestamp}.log" + }, + + "output": { + "summary_file": "test-{task_name}-summary-{timestamp}.txt", + "detailed_results_file": "test-{task_name}-detailed-{timestamp}.json" + } +} diff --git a/MobiFlow/task_configs/weixin.json b/MobiFlow/task_configs/weixin.json new file mode 100644 index 0000000..051c2e3 --- /dev/null +++ b/MobiFlow/task_configs/weixin.json @@ -0,0 +1,42 @@ +{ + "task_name": "weixin", + "description": "微信任务测试配置", + + "rules_base_dir": "task_rules/weixin", + "data_base_dir": "data/weixin", + + "task_types": { + "type1": { + "name": "type1任务", + "rule_file": "weixin-type1.yaml", + "description": "微信type1功能测试 - 明确指定type1目录" + }, + "type2": { + "name": "type2任务", + "rule_file": "weixin-type2.yaml", + "description": "微信type2功能测试 - 明确指定type2目录" + } + }, + + "test_options": { + "enable_ocr": true, + "enable_llm": true, + "force_llm": false, + "ocr_frame_exclusive": true, + "llm_frame_exclusive": true, + "prevent_frame_backtrack": true + }, + + "logging": { + "level": "DEBUG", + "use_colors": true, + "show_time": true, + "show_module": true, + "output_file": "test-{task_name}-{timestamp}.log" + }, + + "output": { + "summary_file": "test-{task_name}-summary-{timestamp}.txt", + "detailed_results_file": "test-{task_name}-detailed-{timestamp}.json" + } +} diff --git a/MobiFlow/task_configs/xiaohongshu.json b/MobiFlow/task_configs/xiaohongshu.json new file mode 100644 index 0000000..cf57900 --- /dev/null 
+++ b/MobiFlow/task_configs/xiaohongshu.json @@ -0,0 +1,62 @@ +{ + "task_name": "xiaohongshu", + "description": "小红书任务测试配置", + + "rules_base_dir": "task_rules/xiaohongshu", + "data_base_dir": "data/xiaohongshu", + + "task_types": { + "2": { + "name": "type2任务", + "rule_file": "xiaohongshu-type2.yaml", + "data_traces": "type2", + "description": "小红书type2功能测试" + }, + "3": { + "name": "type3任务", + "rule_file": "xiaohongshu-type3.yaml", + "data_traces": "type3", + "description": "小红书type3功能测试" + }, + "4": { + "name": "type4任务", + "rule_file": "xiaohongshu-type4.yaml", + "data_traces": "type4", + "description": "小红书type4功能测试" + }, + "5": { + "name": "type5任务", + "rule_file": "xiaohongshu-type5.yaml", + "data_traces": "type5", + "description": "小红书type5功能测试" + }, + "6": { + "name": "type6任务", + "rule_file": "xiaohongshu-type6.yaml", + "data_traces": "type6", + "description": "小红书type6功能测试" + } + }, + + "test_options": { + "enable_ocr": true, + "enable_llm": true, + "force_llm": false, + "ocr_frame_exclusive": true, + "llm_frame_exclusive": true, + "prevent_frame_backtrack": true + }, + + "logging": { + "level": "DEBUG", + "use_colors": true, + "show_time": true, + "show_module": true, + "output_file": "test-{task_name}-{timestamp}.log" + }, + + "output": { + "summary_file": "test-{task_name}-summary-{timestamp}.txt", + "detailed_results_file": "test-{task_name}-detailed-{timestamp}.json" + } +} diff --git a/MobiFlow/task_configs/xiaohongshu_auto.json b/MobiFlow/task_configs/xiaohongshu_auto.json new file mode 100644 index 0000000..d3293d7 --- /dev/null +++ b/MobiFlow/task_configs/xiaohongshu_auto.json @@ -0,0 +1,47 @@ +{ + "task_name": "xiaohongshu_auto", + "description": "小红书任务测试配置 - 使用自动发现功能", + + "rules_base_dir": "task_rules/xiaohongshu", + "data_base_dir": "data/xiaohongshu", + + "task_types": { + "type4": { + "name": "type4任务", + "rule_file": "xiaohongshu-type4.yaml", + "description": "小红书type4功能测试 - 明确指定type4目录" + }, + "type5": { + "name": "type5任务", + 
"rule_file": "xiaohongshu-type5.yaml", + "description": "小红书type5功能测试 - 明确指定type5目录" + }, + "type6": { + "name": "type6任务", + "rule_file": "xiaohongshu-type6.yaml", + "description": "小红书type6功能测试 - 明确指定type6目录" + } + }, + + "test_options": { + "enable_ocr": true, + "enable_llm": true, + "force_llm": false, + "ocr_frame_exclusive": true, + "llm_frame_exclusive": true, + "prevent_frame_backtrack": true + }, + + "logging": { + "level": "DEBUG", + "use_colors": true, + "show_time": true, + "show_module": true, + "output_file": "test-{task_name}-{timestamp}.log" + }, + + "output": { + "summary_file": "test-{task_name}-summary-{timestamp}.log", + "detailed_results_file": "test-{task_name}-detailed-{timestamp}.json" + } +} diff --git a/MobiFlow/task_configs/xiechen.json b/MobiFlow/task_configs/xiechen.json new file mode 100644 index 0000000..f25f128 --- /dev/null +++ b/MobiFlow/task_configs/xiechen.json @@ -0,0 +1,57 @@ +{ + "task_name": "xiechen", + "description": "携程任务测试配置", + + "rules_base_dir": "task_rules/xiechen", + "data_base_dir": "data/xiechen", + + "task_types": { + "type1": { + "name": "机票预订任务", + "rule_file": "xiechen-type1.yaml", + "description": "机票搜索和预订功能测试" + }, + "type2": { + "name": "酒店预订任务", + "rule_file": "xiechen-type2.yaml", + "description": "酒店搜索和预订功能测试" + }, + "type3": { + "name": "火车票预订任务", + "rule_file": "xiechen-type3.yaml", + "description": "火车票搜索和预订功能测试" + }, + "type4": { + "name": "旅游产品任务", + "rule_file": "xiechen-type4.yaml", + "description": "旅游产品搜索和预订功能测试" + }, + "type5": { + "name": "综合服务任务", + "rule_file": "xiechen-type5.yaml", + "description": "综合旅行服务功能测试" + } + }, + + "test_options": { + "enable_ocr": true, + "enable_llm": true, + "force_llm": false, + "ocr_frame_exclusive": true, + "llm_frame_exclusive": true, + "prevent_frame_backtrack": true + }, + + "logging": { + "level": "DEBUG", + "use_colors": true, + "show_time": true, + "show_module": true, + "output_file": "test-{task_name}-{timestamp}.log" + }, + + "output": { + 
"summary_file": "test-{task_name}-summary-{timestamp}.txt", + "detailed_results_file": "test-{task_name}-detailed-{timestamp}.json" + } +} diff --git a/MobiFlow/task_rules/bilibili/bilibili-type1.yaml b/MobiFlow/task_rules/bilibili/bilibili-type1.yaml new file mode 100644 index 0000000..9e80fba --- /dev/null +++ b/MobiFlow/task_rules/bilibili/bilibili-type1.yaml @@ -0,0 +1,44 @@ +task_id: bilibili_search_content +app_id: tv.danmaku.bili +task_type: search +description: 在B站(bilibili)应用中,根据用户指令搜索指定内容并验证搜索结果页面的展示。 +nodes: +# - id: launch_bilibili +# name: 启动B站应用 +# condition: +# type: escalate +# params: +# ocr: +# all: ["直播", "推荐", "热门", "首页", "动态", "我的", "频道"] +# llm: +# prompt: 请判断当前界面是否为B站(bilibili)的应用首页或主界面? +# expected_true: true +# next: [enter_search_page] + +- id: enter_search_page + name: 进入搜索页面 + score: 40 + condition: + type: escalate + params: + ocr: + all: ["搜索", "搜索历史", "热搜"] + llm: + prompt: 请判断当前界面是否为B站的搜索输入界面,或者用户是否已点击了首页的搜索图标? + expected_true: true + next: [view_search_results] + +- id: view_search_results + name: 查看搜索结果 + score: 60 + condition: + type: juxtaposition + params: + ocr: + all: ["综合", "番剧", "直播", "用户"] + llm: + prompt: 请判断当前界面是否成功展示了与任务描述`{task_description}`高度相关的视频、用户或内容列表? + expected_true: true + +success: + any_of: [view_search_results] diff --git a/MobiFlow/task_rules/bilibili/bilibili-type2.yaml b/MobiFlow/task_rules/bilibili/bilibili-type2.yaml new file mode 100644 index 0000000..1f0b548 --- /dev/null +++ b/MobiFlow/task_rules/bilibili/bilibili-type2.yaml @@ -0,0 +1,88 @@ +task_id: bilibili_play_video +app_id: tv.danmaku.bili +task_type: video_playback +description: 在B站应用中,通过搜索或直接点击,找到并播放用户指定的视频。 +nodes: +# - id: launch_bilibili +# name: 启动B站应用 +# condition: +# type: escalate +# params: +# ocr: +# all: ["首页", "推荐", "动态", "会员购", "我的", "bilibili"] +# llm: +# prompt: 屏幕截图是否显示B站(bilibili)应用已成功打开并处于其主界面(如推荐、首页等)? 
+# expected_true: true +# next: [enter_search_page] + +- id: enter_search_page + name: 进入搜索页面 + score: 10 + condition: + type: escalate + params: + ocr: + all: ["搜索", "搜索历史", "热搜"] + llm: + prompt: 请判断当前界面是否为B站的搜索输入界面? + expected_true: true + next: [execute_search] + +- id: execute_search + name: 输入关键词并执行搜索 + score: 10 + condition: + type: escalate + params: + # ocr: + # all: ["搜索", "搜索历史"] + llm: + prompt: 请判断用户是否在搜索框中输入了与任务描述`{task_description}`相关的关键词并执行了搜索操作? + expected_true: true + next: [view_search_results] + +- id: view_search_results + name: 查看搜索结果 + score: 10 + condition: + type: juxtaposition + params: + ocr: + all: ["综合", "番剧", "直播", "用户"] + llm: + prompt: 请判断当前界面是否成功展示了与任务描述`{task_description}`高度相关的视频、用户或内容列表? + expected_true: true + next: [click_video_from_feed_1, click_video_from_feed_2] + +- id: click_video_from_feed_1 + name: 从推荐流点击视频 + score: 20 + condition: + type: escalate + params: + ocr: + all: ["简介", "评论", "弹幕", "关注"] + +- id: click_video_from_feed_2 + name: 从推荐流点击视频 + score: 20 + condition: + type: escalate + llm: + prompt: 屏幕截图是否显示用户直接从推荐、首页或动态信息流中点击了一个视频进行播放? + expected_true: true + + # next: [video_is_playing] + +# - id: video_is_playing +# name: 视频正在播放 +# condition: +# type: juxtaposition +# params: +# ocr: +# all: ["弹幕", "点赞", "投币", "收藏", "分享", "全屏", "暂停", "倍速", "发弹幕", "弹幕列表"] +# llm: +# prompt: 屏幕截图是否明确显示一个视频正在播放(能看到播放器控件、进度条、弹幕等),并且视频内容与任务描述(task_description)高度相关? 
+# expected_true: true +success: + any_of: [click_video_from_feed_1, click_video_from_feed_2] diff --git a/MobiFlow/task_rules/bilibili/bilibili-type3.yaml b/MobiFlow/task_rules/bilibili/bilibili-type3.yaml new file mode 100644 index 0000000..75a889d --- /dev/null +++ b/MobiFlow/task_rules/bilibili/bilibili-type3.yaml @@ -0,0 +1,33 @@ +task_id: bilibili_search_up +app_id: tv.danmaku.bili +task_type: search +description: 在B站应用中根据指令搜索指定UP主 +nodes: + +- id: enter_search_page + name: 进入搜索页面 + score: 40 + condition: + type: escalate + params: + ocr: + all: ["搜索", "搜索历史", "热搜"] + llm: + prompt: 请判断当前界面是否为B站的搜索输入界面,或者用户是否已点击了首页的搜索图标? + expected_true: true + next: [view_search_results] + +- id: view_search_results + name: 查看搜索结果 + score: 60 + condition: + type: juxtaposition + params: + ocr: + all: ["综合", "番剧", "直播", "用户", "粉丝" ,"关注"] + llm: + prompt: 请判断当前界面是否成功展示了与任务描述`{task_description}`高度相关的用户UP主? + expected_true: true + +success: + any_of: [view_search_results] \ No newline at end of file diff --git a/MobiFlow/task_rules/bilibili/bilibili-type4.yaml b/MobiFlow/task_rules/bilibili/bilibili-type4.yaml new file mode 100644 index 0000000..f512122 --- /dev/null +++ b/MobiFlow/task_rules/bilibili/bilibili-type4.yaml @@ -0,0 +1,46 @@ +task_id: bilibili_view_creator_homepage +app_id: tv.danmaku.bili +task_type: search_and_view +description: 在B站应用中,通过搜索或其他方式,进入指定UP主的个人主页。 +nodes: +- id: enter_search_page + name: 进入搜索页面 + score: 30 + condition: + type: escalate + params: + ocr: + all: ["搜索", "搜索历史", "热搜"] + llm: + prompt: 请判断当前界面是否为B站的搜索输入界面,或者用户是否已点击了首页的搜索图标? + expected_true: true + next: [view_search_results] + +- id: view_search_results + name: 查看搜索结果 + score: 30 + condition: + type: juxtaposition + params: + ocr: + all: ["综合", "番剧", "直播", "用户", "关注"] + llm: + prompt: 请判断当前界面是否成功展示了与任务描述`{task_description}`高度相关的用户UP主? 
+ expected_true: true + next: [enter_creator_homepage] + +- id: enter_creator_homepage + name: 进入UP主个人主页 + score: 40 + condition: + type: juxtaposition + params: + ocr: + all: ["主页", "动态", "投稿"] + any: ["关注", "粉丝"] + llm: + prompt: 当前界面是否为任务要求进入的目标UP主的个人主页?页面上应清晰展示该UP主的头像、昵称、粉丝数,并包含'主页'、'动态'、'投稿'等分区。 + expected_true: true + +success: + any_of: [enter_creator_homepage] diff --git a/MobiFlow/task_rules/bilibili/bilibili-type5.yaml b/MobiFlow/task_rules/bilibili/bilibili-type5.yaml new file mode 100644 index 0000000..f21cc90 --- /dev/null +++ b/MobiFlow/task_rules/bilibili/bilibili-type5.yaml @@ -0,0 +1,48 @@ +task_id: bilibili_search_video_by_creator +app_id: tv.danmaku.bili +task_type: search +description: 在哔哩哔哩(Bilibili)应用中,根据指定的UP主和视频名称搜索并找到目标视频。 +nodes: +- id: enter_search_page + name: 进入搜索页面 + score: 30 + condition: + type: escalate + params: + ocr: + all: ["搜索", "搜索历史", "热搜"] + llm: + prompt: 请判断当前界面是否为B站的搜索输入界面,或者用户是否已点击了首页的搜索图标? + expected_true: true + next: [view_search_results] + +- id: view_search_results + name: 查看搜索结果 + score: 30 + condition: + type: juxtaposition + params: + ocr: + all: ["综合", "番剧", "直播", "用户"] + llm: + prompt: 请判断当前界面是否成功展示了与任务描述`{task_description}`高度相关的视频、用户或内容列表? + expected_true: true + next: [search_video_by_creator] + + +# 可以考虑增加UP主个人主页的检查 + +- id: search_video_by_creator + name: 搜索指定UP主的视频 + score: 40 + condition: + type: escalate + params: + ocr: + all: ["搜索", "的视频", "动态"] + llm: + prompt: 请判断用户是否在搜索框中输入了与任务描述`{task_description}`相关的视频标题,并执行了搜索操作? 
+ expected_true: true + +success: + any_of: [search_video_by_creator] diff --git a/MobiFlow/task_rules/bilibili/bilibili-type6.yaml b/MobiFlow/task_rules/bilibili/bilibili-type6.yaml new file mode 100644 index 0000000..2d9383f --- /dev/null +++ b/MobiFlow/task_rules/bilibili/bilibili-type6.yaml @@ -0,0 +1,62 @@ +task_id: bilibili_search_and_play_video +app_id: tv.danmaku.bili +task_type: video_playback +description: 该任务验证在Bilibili应用中,根据指定的UP主和视频/系列名称搜索并成功播放视频的流程。 +nodes: +- id: enter_search_page + name: 进入搜索页面 + score: 20 + condition: + type: escalate + params: + ocr: + all: ["搜索", "搜索历史", "热搜"] + llm: + prompt: 请判断当前界面是否为B站的搜索输入界面,或者用户是否已点击了首页的搜索图标? + expected_true: true + next: [view_search_results] + +- id: view_search_results + name: 查看搜索结果 + score: 20 + condition: + type: juxtaposition + params: + ocr: + all: ["综合", "番剧", "直播", "用户"] + llm: + prompt: 请判断当前界面是否成功展示了与任务描述`{task_description}`高度相关的视频、用户或内容列表? + expected_true: true + next: [search_video_by_creator] + + +# 可以考虑增加UP主个人主页的检查 + +- id: search_video_by_creator + name: 搜索指定UP主的视频 + score: 20 + condition: + type: escalate + params: + ocr: + all: ["搜索", "的视频", "动态"] + llm: + prompt: 请判断用户是否在搜索框中输入了与任务描述`{task_description}`相关的视频标题,并执行了搜索操作? 
+ expected_true: true + next: [play_video] + +# 如果要测试最终状态,可以直接注释前面的路径节点 +- id: play_video + name: 播放目标视频 + score: 20 + condition: + type: juxtaposition + params: + ocr: + all: ["简介", "评论", "弹幕", "关注"] + llm: + prompt: 请判断当前界面是否正在播放视频?要求与任务描述`{task_description}`相关的视频。 + expected_true: true + +success: + any_of: [play_video] diff --git a/MobiFlow/task_rules/bilibili/bilibili-type7.yaml b/MobiFlow/task_rules/bilibili/bilibili-type7.yaml new file mode 100644 index 0000000..12be35f --- /dev/null +++ b/MobiFlow/task_rules/bilibili/bilibili-type7.yaml @@ -0,0 +1,43 @@ +task_id: bilibili_follow_creator +app_id: tv.danmaku.bili +task_type: follow_action +description: 在B站(bilibili)应用中,通过搜索找到并成功关注一个指定的用户(UP主或官方账号)。 +nodes: +# - id: launch_app +# name: 启动B站应用 +# condition: +# type: escalate +# params: +# ocr: +# any: ["推荐", "首页", "动态", "我的", "热门", "会员购"] +# llm: +# prompt: 当前界面是否为B站应用的主界面或首页? +# expected_true: true +# next: [search_creator] + +- id: execute_search + name: 输入关键词并执行搜索 + score: 10 + condition: + type: juxtaposition + params: + ocr: + all: ["搜索", "搜索历史", "热搜"] + llm: + prompt: 请判断用户是否使用了与任务描述`{task_description}`相关的关键词搜索? + expected_true: true + next: [confirm_follow] + +- id: confirm_follow + name: 确认关注成功 + condition: + type: juxtaposition + params: + ocr: + any: ["已关注", "关注"] + all: ["粉丝"] + llm: + prompt: 界面上是否(如按钮文字变为'已关注')证明已经成功关注了任务描述中的目标创作者? + expected_true: true +success: + any_of: [confirm_follow] diff --git a/MobiFlow/task_rules/bilibili/bilibili-type8.yaml b/MobiFlow/task_rules/bilibili/bilibili-type8.yaml new file mode 100644 index 0000000..bd47957 --- /dev/null +++ b/MobiFlow/task_rules/bilibili/bilibili-type8.yaml @@ -0,0 +1,90 @@ +task_id: bilibili_video_comment +app_id: tv.danmaku.bili +task_type: social_interaction +description: 在Bilibili应用中搜索指定视频并发表评论 +nodes: +# - id: launch_app +# name: 启动应用 +# condition: +# type: escalate +# params: +# ocr: +# any: ["推荐", "首页", "动态", "我的", "大会员"] +# llm: +# prompt: 当前界面是否为Bilibili应用的主界面? 
+# expected_true: true +# next: [enter_search] +# - id: enter_search +# name: 进入搜索界面 +# condition: +# type: escalate +# params: +# ocr: +# any: ["搜索", "🔍", "排行榜", "游戏中心"] +# llm: +# prompt: 用户是否已经点击搜索图标,进入了搜索输入界面? +# expected_true: true +# next: [input_keyword_and_search] +# - id: input_keyword_and_search +# name: 输入关键词并搜索 +# condition: +# type: escalate +# params: +# ocr: +# any: ["搜索历史", "搜索发现", "综合", "番剧", "视频", "用户"] +# llm: +# prompt: 用户是否根据任务要求输入了视频关键词并执行了搜索,界面上是否展示了搜索结果列表? +# expected_true: true +# next: [select_video_from_results] + +# 前面可以增加更多的节点,限制路径 + +- id: select_video_detail + name: 进入视频播放详情页 + condition: + type: juxtaposition + params: + ocr: + all: ["简介", "评论", "点我发弹幕"] + llm: + prompt: 当前是否成功进入了目标视频的播放页面? + expected_true: true + next: [enter_comment_section] + +- id: enter_comment_section + name: 进入评论区或点击评论框 + condition: + type: escalate + params: + ocr: + all: ["简介","评论","热门评论"] + llm: + prompt: 是否已经进入任务描述`{task_description}`的视频评论页面,并在页面内看到评论框? + expected_true: true + next: [input_and_publish_comment] + +- id: input_and_publish_comment + name: 输入并发表评论 + condition: + type: juxtaposition + params: + ocr: + all: ["发布", "转到动态"] + llm: + prompt: 用户是否在输入框中输入了与任务描述一致的评论内容,等待发送? + expected_true: true + # next: [verify_comment_published] + +# - id: verify_comment_published +# name: 验证评论发布成功 +# condition: +# type: juxtaposition +# params: +# ocr: +# any: ["发布成功", "评论发送成功", "审核", "我的评论", "删除", "撤回"] +# llm: +# prompt: 界面上是否出现了评论发布成功的提示,或者在评论区能看到用户刚刚发布的、与任务描述内容一致的评论? 
+# expected_true: true + +success: + any_of: [input_and_publish_comment] diff --git a/MobiFlow/task_rules/bilibili/bilibili-type9.yaml b/MobiFlow/task_rules/bilibili/bilibili-type9.yaml new file mode 100644 index 0000000..853dc5b --- /dev/null +++ b/MobiFlow/task_rules/bilibili/bilibili-type9.yaml @@ -0,0 +1,42 @@ +task_id: bilibili_video_comment +app_id: tv.danmaku.bili +task_type: content_interaction +description: 在Bilibili应用中,搜索指定UP主或视频,并成功发表一条指定内容的评论。 +nodes: +- id: select_video_detail + name: 进入视频播放详情页 + condition: + type: juxtaposition + params: + ocr: + all: ["简介", "评论", "点我发弹幕"] + llm: + prompt: 当前是否成功进入了指定UP主对应视频的播放页面? + expected_true: true + next: [enter_comment_section] + +- id: enter_comment_section + name: 进入评论区或点击评论框 + condition: + type: escalate + params: + ocr: + all: ["简介","评论","热门评论"] + llm: + prompt: 是否已经进入任务描述`{task_description}`的视频评论页面,并在页面内看到评论框? + expected_true: true + next: [input_and_publish_comment] + +- id: input_and_publish_comment + name: 输入并发表评论 + condition: + type: juxtaposition + params: + ocr: + all: ["发布", "转到动态"] + llm: + prompt: 用户是否在输入框中输入了与任务描述一致的评论内容,等待发送? + expected_true: true + +success: + any_of: [input_and_publish_comment] \ No newline at end of file diff --git a/MobiFlow/task_rules/cloudmusic/cloudmusic-type1.yaml b/MobiFlow/task_rules/cloudmusic/cloudmusic-type1.yaml new file mode 100644 index 0000000..41e98a6 --- /dev/null +++ b/MobiFlow/task_rules/cloudmusic/cloudmusic-type1.yaml @@ -0,0 +1,30 @@ +task_id: general_search_validation +app_id: any +task_type: search +description: 通用搜索任务验证配置,适用于在应用内或浏览器中执行搜索操作并验证搜索结果是否正确展示。 +nodes: +- id: initiate_search + name: 进入搜索或输入关键词 + condition: + type: escalate + params: + ocr: + all: ["歌手", "搜索历史", "猜你喜欢", "热搜"] + llm: + prompt: 该步是否进入了搜索界面或激活了搜索框以准备输入关键词? 
+ expected_true: true + next: [view_search_results] + +- id: view_search_results + name: 查看搜索结果 + condition: + type: juxtaposition + params: + ocr: + all: ["综合", "单曲","歌单","专辑"] + llm: + prompt: 该步是否成功展示了与任务描述中实体(如歌曲、歌手、短语等)相关的搜索结果列表? + expected_true: true + +success: + any_of: [view_search_results] diff --git a/MobiFlow/task_rules/cloudmusic/cloudmusic-type2.yaml b/MobiFlow/task_rules/cloudmusic/cloudmusic-type2.yaml new file mode 100644 index 0000000..cf77e43 --- /dev/null +++ b/MobiFlow/task_rules/cloudmusic/cloudmusic-type2.yaml @@ -0,0 +1,30 @@ +task_id: play_music +app_id: com.music.app +task_type: music +description: 验证在音乐应用中搜索并播放指定歌曲或歌手音乐的任务流程。 +nodes: +- id: view_search_results + name: 查看搜索结果 + condition: + type: escalate + params: + ocr: + all: ["综合", "单曲", "歌单","专辑"] + llm: + prompt: 该步是否成功展示了与任务描述中实体(如歌曲、歌手、短语等)相关的搜索结果列表? + expected_true: true + next: [playback_interface] + + +- id: playback_interface + name: 进入播放界面 + condition: + type: escalate + params: + # ocr: + # any: ["播放中", "暂停", "下一首", "歌词", "评论", "分享", "'||'", "▶", "'0:'", "'1:'", "'2:'", "'3:'", "'4:'", "'5:'"] + llm: + prompt: 该步是否成功进入了音乐播放界面,并且界面显示正在播放或已加载任务所指定的音乐? + expected_true: true +success: + any_of: [playback_interface] diff --git a/MobiFlow/task_rules/ele/ele-type3.yaml b/MobiFlow/task_rules/ele/ele-type3.yaml new file mode 100644 index 0000000..dc54dc4 --- /dev/null +++ b/MobiFlow/task_rules/ele/ele-type3.yaml @@ -0,0 +1,83 @@ +task_id: eleme_order_food +app_id: me.ele +task_type: shopping +description: 在饿了么应用中根据用户指令完成外卖下单任务,包括搜索店铺、选择商品。 +nodes: +- id: enter_store_page + name: 进入店铺页面 + condition: + type: escalate + params: + ocr: + all: ["综合排序","销量优先","速度优先"] + llm: + prompt: 根据任务描述`{{task_description}}`,该步骤是否展示了目标商品或商家店铺结果列表? + expected_true: true + next: [item_detail] + + +- id: item_detail + name: 商品详情 + deps: [enter_store_page] + condition: + type: juxtaposition + params: + ocr: + all: ["详情","评价","商品详情","商品描述"] + llm: + prompt: 当前节点是否正确选择了任务描述`{{task_description}}`中指定商品,进入商品详情页? 
+ expected_true: true + next: [add_to_cart] + +# - id: select_item_and_specs +# name: 选择商品及规格 +# deps: [enter_store_page] +# condition: +# type: juxtaposition +# params: +# ocr: +# all: ["选好了"] +# llm: +# prompt: 该步骤是否正确选择了任务描述`{{task_description}}`中指定的商品和规格(如辣度、甜度、温度、加料等)? +# expected_true: true +# next: [add_to_cart] + +- id: add_to_cart + name: 加入购物车 + deps: [item_detail] + condition: + type: escalate + params: + ocr: + all: ["去结算"] + llm: + prompt: 该步骤是否成功将指定商品加入了购物车? + expected_true: true + # next: [go_to_checkout] + +# - id: go_to_checkout +# name: 进入结算页面 +# condition: +# type: juxtaposition +# params: +# ocr: +# all: ["确认订单", "地址"] +# any: ["提交订单", "在线支付", "餐具", "备注", "合计", "实付"] +# llm: +# prompt: 该步骤是否已进入订单确认或结算页面,且页面中包含了任务描述`{{task_description}}`中指定的商品信息? +# expected_true: true +# next: [submit_order] + +# - id: submit_order +# name: 提交订单并支付 +# deps: [go_to_checkout] +# condition: +# type: escalate +# params: +# ocr: +# any: ["支付成功", "下单成功", "订单已送出", "等待商家接单", "查看订单", "订单详情"] +# llm: +# prompt: 该步骤是否成功提交了订单并完成了支付,或者进入了订单成功的确认页面? +# expected_true: true +success: + any_of: [add_to_cart] diff --git a/MobiFlow/task_rules/ele/ele-type4.yaml b/MobiFlow/task_rules/ele/ele-type4.yaml new file mode 100644 index 0000000..a3bca52 --- /dev/null +++ b/MobiFlow/task_rules/ele/ele-type4.yaml @@ -0,0 +1,83 @@ +task_id: eleme_order_food +app_id: me.ele +task_type: shopping +description: 在饿了么应用中根据用户指令完成外卖下单任务,包括搜索店铺、选择商品、自定义规格。 +nodes: +- id: enter_store_page + name: 进入店铺页面 + condition: + type: escalate + params: + ocr: + all: ["综合排序","销量优先","速度优先"] + llm: + prompt: 根据任务描述`{{task_description}}`,该步骤是否展示了目标商品或商家店铺结果列表? + expected_true: true + next: [item_detail] + + +- id: item_detail + name: 商品详情 + deps: [enter_store_page] + condition: + type: juxtaposition + params: + ocr: + all: ["详情","评价","商品详情","商品描述"] + llm: + prompt: 当前节点是否正确选择了任务描述`{{task_description}}`中指定商品,进入商品详情页? 
+ expected_true: true + next: [select_item_and_specs] + +- id: select_item_and_specs + name: 选择商品及规格 + deps: [enter_store_page] + condition: + type: juxtaposition + params: + ocr: + all: ["选好了"] + llm: + prompt: 该步骤是否正确选择了任务描述`{{task_description}}`中指定的商品和规格(如辣度、甜度、温度、加料等)? + expected_true: true + next: [add_to_cart] + +- id: add_to_cart + name: 加入购物车 + deps: [select_item_and_specs] + condition: + type: escalate + params: + ocr: + all: ["去结算"] + llm: + prompt: 该步骤是否成功将指定商品加入了购物车? + expected_true: true + # next: [go_to_checkout] + +# - id: go_to_checkout +# name: 进入结算页面 +# condition: +# type: juxtaposition +# params: +# ocr: +# all: ["确认订单", "地址"] +# any: ["提交订单", "在线支付", "餐具", "备注", "合计", "实付"] +# llm: +# prompt: 该步骤是否已进入订单确认或结算页面,且页面中包含了任务描述`{{task_description}}`中指定的商品信息? +# expected_true: true +# next: [submit_order] + +# - id: submit_order +# name: 提交订单并支付 +# deps: [go_to_checkout] +# condition: +# type: escalate +# params: +# ocr: +# any: ["支付成功", "下单成功", "订单已送出", "等待商家接单", "查看订单", "订单详情"] +# llm: +# prompt: 该步骤是否成功提交了订单并完成了支付,或者进入了订单成功的确认页面? +# expected_true: true +success: + any_of: [add_to_cart] diff --git a/MobiFlow/task_rules/ele/ele-type5.yaml b/MobiFlow/task_rules/ele/ele-type5.yaml new file mode 100644 index 0000000..49c4aa2 --- /dev/null +++ b/MobiFlow/task_rules/ele/ele-type5.yaml @@ -0,0 +1,83 @@ +task_id: waimai_order_food +app_id: me.ele +task_type: shopping +description: 在饿了么应用中根据用户指令完成外卖下单任务,包括搜索、选择商品、设置规格、设置数量等。 +nodes: +- id: enter_store_page + name: 进入店铺页面 + condition: + type: escalate + params: + ocr: + all: ["综合排序","销量优先","速度优先"] + llm: + prompt: 根据任务描述`{{task_description}}`,该步骤是否展示了目标商品或商家店铺结果列表? + expected_true: true + next: [item_detail] + + +- id: item_detail + name: 商品详情 + deps: [enter_store_page] + condition: + type: juxtaposition + params: + ocr: + all: ["详情","评价","商品详情","商品描述"] + llm: + prompt: 当前节点是否正确选择了任务描述`{{task_description}}`中指定商品,进入商品详情页? 
+ expected_true: true + next: [select_item_and_specs] + +- id: select_item_and_specs + name: 选择商品及规格 + deps: [enter_store_page] + condition: + type: juxtaposition + params: + ocr: + all: ["选好了"] + llm: + prompt: 该步骤是否正确选择了任务描述`{{task_description}}`中指定的商品和规格(如辣度、甜度、温度、加料等)? + expected_true: true + next: [add_to_cart] + +- id: add_to_cart + name: 加入购物车 + deps: [select_item_and_specs] + condition: + type: juxtaposition + params: + ocr: + all: ["去结算"] + llm: + prompt: 该步骤是否成功将指定数量的商品加入了购物车? + expected_true: true + # next: [go_to_checkout] + +# - id: go_to_checkout +# name: 进入结算页面 +# condition: +# type: juxtaposition +# params: +# ocr: +# all: ["确认订单", "地址"] +# any: ["提交订单", "在线支付", "餐具", "备注", "合计", "实付"] +# llm: +# prompt: 该步骤是否已进入订单确认或结算页面,且页面中包含了任务描述`{{task_description}}`中指定的商品信息? +# expected_true: true +# next: [submit_order] + +# - id: submit_order +# name: 提交订单并支付 +# deps: [go_to_checkout] +# condition: +# type: escalate +# params: +# ocr: +# any: ["支付成功", "下单成功", "订单已送出", "等待商家接单", "查看订单", "订单详情"] +# llm: +# prompt: 该步骤是否成功提交了订单并完成了支付,或者进入了订单成功的确认页面? +# expected_true: true +success: + any_of: [add_to_cart] diff --git a/MobiFlow/task_rules/example_checker_modes.yaml b/MobiFlow/task_rules/example_checker_modes.yaml new file mode 100644 index 0000000..821707d --- /dev/null +++ b/MobiFlow/task_rules/example_checker_modes.yaml @@ -0,0 +1,77 @@ +task_id: example_juxtaposition_demo +app_id: com.example.app +task_type: demo +description: 演示escalate和juxtaposition检查器的使用 + +#### +# 配置说明: +# 示例1: 使用escalate模式 - 按顺序尝试,任意一个成功即可 +# 示例2: 使用juxtaposition模式 - 所有配置的检查器都必须成功 +# ocr支持使用any和all来控制匹配逻辑 +# 1. 使用any表示任意一个匹配即可 +# 2. 使用all表示必须全部匹配 +#### + +nodes: + # 示例1: 使用escalate模式 - 按顺序尝试,任意一个成功即可 + - id: search_with_escalate + name: 使用escalate模式进行搜索验证 + condition: + type: escalate + params: + # 任意内容匹配到即可 + ocr: + any: ["搜索", "查询"] + # 最后使用LLM验证 + llm: + prompt: "该步是否执行了搜索操作?" 
+ expected_true: true + + # 示例2: 使用juxtaposition模式 - 所有配置的检查器都必须成功 + - id: confirm_with_juxtaposition + deps: [search_with_escalate] + name: 使用juxtaposition模式进行严格验证 + condition: + type: juxtaposition + params: + # 必须同时满足:检测到所有关键词 + ocr: + all: ["确认", "提交"] + # 必须同时满足:LLM确认操作正确 + llm: + prompt: "该步是否点击了确认/提交按钮?" + expected_true: true + + # 示例3: 复杂的escalate配置,展示完整的升级链 + - id: complex_escalate_example + deps: [confirm_with_juxtaposition] + name: 复杂的escalate升级链示例 + condition: + type: escalate + params: + # 第6级:OCR图像识别 + ocr: + any: ["购物车", "加入购物车", "立即购买"] + pattern: "(购物车|cart)" + # 第7级:LLM最终验证 + llm: + prompt: "基于截图和上下文,该步是否成功将商品加入购物车?" + expected_true: true + + # 示例4: 只使用高级检查器的juxtaposition + - id: advanced_juxtaposition + deps: [complex_escalate_example] + name: 高级检查器的juxtaposition组合 + condition: + type: juxtaposition + params: + # 必须满足:OCR识别到相关元素 + ocr: + all: ["支付", "确认订单"] + # 必须满足:LLM确认操作 + llm: + prompt: "该步是否完成了购买/支付操作?" + expected_true: true + +success: + any_of: [advanced_juxtaposition] diff --git a/MobiFlow/task_rules/example_checker_ordes.yaml b/MobiFlow/task_rules/example_checker_ordes.yaml new file mode 100644 index 0000000..3d0eef6 --- /dev/null +++ b/MobiFlow/task_rules/example_checker_ordes.yaml @@ -0,0 +1,197 @@ +task_id: example_order_demo +app_id: com.example.app +task_type: demo +description: 演示多路径任务验证配置的完整示例,展示deps和next的使用 + +#### +# 配置说明: +# deps: AND语义 - 严格的前置依赖,所有listed节点必须先完成 +# next: OR语义 - 灵活的后继选择,可以进入任一listed节点 +# 本示例展示一个电商购物任务的多种路径配置 +#### + +nodes: + # 第1步:启动应用(入口节点) + - id: launch_app + name: 启动购物应用 + condition: + type: escalate + params: + ocr: + any: ["打开应用", "启动", "进入首页", "主页面"] + llm: + prompt: "该步是否成功启动了购物应用?" + expected_true: true + next: [search_entry, browse_category] # OR语义:可以选择搜索或浏览分类 + + # 路径A:搜索入口分支 + - id: search_entry + name: 进入搜索功能 + condition: + type: escalate + params: + ocr: + any: ["搜索", "search", "🔍"] + llm: + prompt: "该步是否点击了搜索功能?" 
+ expected_true: true + next: [input_search_keyword] + + # 路径B:分类浏览分支 + - id: browse_category + name: 浏览商品分类 + condition: + type: escalate + params: + ocr: + any: ["分类", "类目", "商品分类", "category"] + llm: + prompt: "该步是否进入了商品分类页面?" + expected_true: true + next: [select_category] + + # 搜索路径继续 + - id: input_search_keyword + name: 输入搜索关键词 + condition: + type: escalate + params: + ocr: + any: ["输入", "搜索框", "关键词"] + llm: + prompt: "该步是否输入了任务task_description所需物品的搜索关键词?" + expected_true: true + next: [search_results] + + # 分类路径继续 + - id: select_category + name: 选择商品分类 + condition: + type: escalate + params: + ocr: + any: ["选择", "点击", "分类"] + llm: + prompt: "该步是否正确选择了商品分类?" + expected_true: true + next: [category_results] + + # 搜索结果页面 + - id: search_results + name: 查看搜索结果 + condition: + type: escalate + params: + ocr: + any: ["搜索结果", "共", "件商品", "结果"] + llm: + prompt: "该步是否显示了搜索结果?" + expected_true: true + next: [apply_filter, select_product] # OR语义:可以筛选或直接选择商品 + + # 分类结果页面 + - id: category_results + name: 查看分类商品列表 + condition: + type: escalate + params: + ocr: + any: ["商品列表", "分类商品", "商品"] + llm: + prompt: "该步是否显示了分类商品列表?" + expected_true: true + next: [apply_filter, select_product] # OR语义:同样可以筛选或直接选择 + + # 可选步骤:应用筛选条件 + - id: apply_filter + name: 应用筛选条件 + condition: + type: escalate + params: + ocr: + any: ["筛选", "排序", "销量", "价格", "评分", "filter"] + llm: + prompt: "该步是否按照任务task_description要求执行了相应的筛选或排序操作?" + expected_true: true + next: [select_product] + + # 选择商品(两条路径汇聚点) + - id: select_product + name: 选择目标商品 + condition: + type: escalate + params: + ocr: + any: ["选择", "点击", "商品", "详情"] + llm: + prompt: "该步是否选择了目标商品?" + expected_true: true + next: [add_to_cart, buy_now] # OR语义:可以加购物车或直接购买 + + # 路径C:加入购物车分支 + - id: add_to_cart + name: 加入购物车 + condition: + type: juxtaposition + params: + ocr: + any: ["加入购物车", "购物车", "cart"] + llm: + prompt: "该步是否将任务task_description所要求商品加入了购物车?" 
+ expected_true: true + next: [go_to_cart] + + # 路径D:立即购买分支 + - id: buy_now + name: 立即购买 + condition: + type: juxtaposition + params: + ocr: + any: ["立即购买", "马上买", "buy now"] + llm: + prompt: "该步是否点击了立即购买?" + expected_true: true + next: [checkout] + + # 购物车路径继续 + - id: go_to_cart + name: 进入购物车 + condition: + type: escalate + params: + ocr: + any: ["购物车", "cart", "已加入"] + llm: + prompt: "该步是否进入了购物车页面?" + expected_true: true + next: [checkout] + + # 最终结算(两条路径最终汇聚) + - id: checkout + name: 结算/支付 + deps: [] # 无强制依赖,通过next路径到达即可 + condition: + type: escalate + params: + ocr: + any: ["结算", "支付", "确认订单", "checkout", "pay"] + llm: + prompt: "该步是否完成了结算或支付操作?" + expected_true: true + + # 可选:订单确认 + - id: order_confirmation + deps: [checkout] # AND语义:必须等待checkout完成 + name: 订单确认 + condition: + type: escalate + params: + ocr: + any: ["订单成功", "支付成功", "下单成功", "确认"] + llm: + prompt: "该步是否显示了订单确认信息?" + expected_true: true + +success: + any_of: [checkout, order_confirmation] # 完成结算或看到订单确认即为成功 diff --git a/MobiFlow/task_rules/feizhu/feizhu-type1.yaml b/MobiFlow/task_rules/feizhu/feizhu-type1.yaml new file mode 100644 index 0000000..fc65075 --- /dev/null +++ b/MobiFlow/task_rules/feizhu/feizhu-type1.yaml @@ -0,0 +1,50 @@ +task_id: fliggy_query_hotel_price +app_id: com.taobao.trip +task_type: query +description: 在飞猪应用中,通过搜索功能查询指定酒店的价格列表。 +nodes: +- id: launch_app + name: 启动飞猪应用 + condition: + type: escalate + params: + ocr: + any: ["首页", "我的", "行程", "消息", "酒店", "机票", "火车票"] + llm: + prompt: 请判断当前界面是否为飞猪应用的首页或主界面? + expected_true: true + next: [enter_hotel_module] +- id: enter_hotel_module + name: 进入酒店搜索页 + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿", "目的地", "入住", "离店", "关键词", "搜索酒店"] + llm: + prompt: 请判断当前界面是否为酒店搜索功能的主页面,包含目的地、入住日期、离店日期和关键词搜索框等元素? 
+ expected_true: true + next: [input_hotel_name] +- id: input_hotel_name + name: 输入酒店名称并搜索 + condition: + type: escalate + params: + ocr: + any: ["搜索", "请输入", "取消", "历史记录", "热门搜索"] + llm: + prompt: 请判断用户是否已经点击搜索框并准备输入或已经输入了酒店名称,或者已经点击了搜索按钮? + expected_true: true + next: [view_search_results] +- id: view_search_results + name: 查看酒店价格列表 + condition: + type: juxtaposition + params: + ocr: + any: ["¥", "起", "预订", "价格", "评分", "筛选", "订", "元"] + llm: + prompt: 请判断当前界面是否成功展示了酒店的搜索结果列表?列表中应包含多个酒店选项,并明确显示价格信息(如'¥'符号或'元')和'预订'等操作按钮。 + expected_true: true +success: + any_of: [view_search_results] diff --git a/MobiFlow/task_rules/feizhu/feizhu-type2.yaml b/MobiFlow/task_rules/feizhu/feizhu-type2.yaml new file mode 100644 index 0000000..4684bf9 --- /dev/null +++ b/MobiFlow/task_rules/feizhu/feizhu-type2.yaml @@ -0,0 +1,61 @@ +task_id: fliggy_query_hotel_price +app_id: com.taobao.trip +task_type: query +description: 在飞猪应用中,根据用户指定的任意地标,查询并展示附近的酒店价格列表。 +nodes: +- id: launch_app + name: 启动飞猪应用 + condition: + type: escalate + params: + ocr: + any: ["首页", "我的", "消息", "酒店", "机票", "火车票"] + llm: + prompt: 请判断当前界面是否为飞猪应用的首页或主界面? + expected_true: true + next: [navigate_to_hotel_search] +- id: navigate_to_hotel_search + name: 进入酒店搜索功能 + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿", "客栈"] + llm: + prompt: 请判断用户是否已经点击并进入了酒店搜索功能的主界面? + expected_true: true + next: [input_destination] +- id: input_destination + name: 输入目的地或关键字 + condition: + type: escalate + params: + ocr: + any: ["目的地", "位置", "关键字", "想住哪儿", "城市/地标/酒店名"] + llm: + prompt: 请判断用户是否正在输入或已经输入了任务描述中要求的目的地(如地标、城市等)? + expected_true: true + next: [initiate_search] +- id: initiate_search + name: 设置日期并开始搜索 + condition: + type: escalate + params: + ocr: + any: ["入住", "离店", "选择日期", "查询", "搜索", "查找酒店"] + llm: + prompt: 请判断用户是否已经完成了目的地输入,并点击了'查询'或'搜索'按钮来查找酒店? 
+ expected_true: true + next: [view_hotel_list] +- id: view_hotel_list + name: 查看酒店价格列表 + condition: + type: juxtaposition + params: + ocr: + any: ["价格", "筛选", "¥", "评分", "每晚", "综合排序", "酒店列表"] + llm: + prompt: 请判断当前界面是否成功展示了符合任务要求的酒店列表,列表中应包含多个酒店名称、价格(带有¥符号)和评分等信息? + expected_true: true +success: + any_of: [view_hotel_list] diff --git a/MobiFlow/task_rules/feizhu/feizhu-type3.yaml b/MobiFlow/task_rules/feizhu/feizhu-type3.yaml new file mode 100644 index 0000000..1f46b9e --- /dev/null +++ b/MobiFlow/task_rules/feizhu/feizhu-type3.yaml @@ -0,0 +1,29 @@ +task_id: fliggy_query_hotel_price +app_id: com.taobao.trip +task_type: query +description: 在飞猪应用中查询指定城市和品牌的酒店及其价格 +nodes: +- id: start_search + name: 进入酒店搜索界面 + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿", "目的地", "关键词", "住/离店", "搜索"] + llm: + prompt: 当前界面是否为酒店搜索界面?用户可以在此界面输入城市、酒店名称或选择入住日期。 + expected_true: true + next: [view_search_results] +- id: view_search_results + name: 查看酒店搜索结果列表 + deps: [start_search] + condition: + type: juxtaposition + params: + ocr: + any: ["¥", "起", "价格", "评分", "筛选", "订", "每晚", "详情"] + llm: + prompt: 当前界面是否成功展示了符合用户查询意图(城市和酒店品牌)的酒店列表,并且列表中清晰地显示了酒店的价格信息(如'¥'符号或'xx元起')? + expected_true: true +success: + any_of: [view_search_results] diff --git a/MobiFlow/task_rules/feizhu/feizhu-type4.yaml b/MobiFlow/task_rules/feizhu/feizhu-type4.yaml new file mode 100644 index 0000000..a012aee --- /dev/null +++ b/MobiFlow/task_rules/feizhu/feizhu-type4.yaml @@ -0,0 +1,54 @@ +task_id: fliggy_query_hotel_price +app_id: com.taobao.trip +task_type: travel_query +description: 在飞猪应用中,根据用户指定的城市和地标,查询附近的酒店及其价格。 +nodes: +- id: launch_app + name: 启动飞猪应用 + condition: + type: escalate + params: + ocr: + any: ["酒店", "机票", "火车票", "我的", "首页", "去哪玩"] + llm: + prompt: 当前屏幕是否为飞猪应用的首页或主界面? 
+ expected_true: true + next: [enter_hotel_module] +- id: enter_hotel_module + name: 进入酒店搜索模块 + deps: [launch_app] + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿", "目的地", "入住", "离店", "关键词", "位置", "品牌"] + llm: + prompt: 用户是否已经进入了酒店搜索功能的界面,该界面包含目的地输入、日期选择等元素? + expected_true: true + next: [input_destination_and_dates] +- id: input_destination_and_dates + name: 输入目的地并选择日期 + deps: [enter_hotel_module] + condition: + type: escalate + params: + ocr: + any: ["选择目的地", "城市/地标", "入住", "离店", "选择日期", "日历", "确定", "完成"] + llm: + prompt: 用户是否正在输入目的地或在日历界面上选择入住和离店日期? + expected_true: true + next: [view_hotel_results] +- id: view_hotel_results + name: 查看酒店搜索结果列表 + deps: [input_destination_and_dates] + condition: + type: juxtaposition + params: + ocr: + all: ["价格", "筛选"] + any: ["¥", "起", "每晚", "评分", "综合排序", "查看详情", "订"] + llm: + prompt: 当前屏幕是否成功展示了符合任务要求的酒店列表,并且清晰地显示了酒店名称、价格(包含'¥'符号)等关键信息? + expected_true: true +success: + any_of: [view_hotel_results] diff --git a/MobiFlow/task_rules/feizhu/feizhu-type5.yaml b/MobiFlow/task_rules/feizhu/feizhu-type5.yaml new file mode 100644 index 0000000..b6a43a2 --- /dev/null +++ b/MobiFlow/task_rules/feizhu/feizhu-type5.yaml @@ -0,0 +1,72 @@ +task_id: fliggy_hotel_price_query +app_id: com.taobao.trip +task_type: query +description: 在飞猪应用中根据指定城市、日期和酒店/地标查询酒店价格 +nodes: +- id: launch_app + name: 启动飞猪应用 + condition: + type: escalate + params: + ocr: + any: ["首页", "我的", "行程", "消息", "酒店", "机票"] + llm: + prompt: 当前页面是否为飞猪应用的首页或主界面? + expected_true: true + next: [navigate_to_hotel] +- id: navigate_to_hotel + name: 进入酒店搜索 + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿·客栈", "目的地/酒店/关键词", "搜酒店"] + llm: + prompt: 当前页面是否为酒店搜索的入口页面? + expected_true: true + next: [input_destination] +- id: input_destination + name: 输入目的地或酒店名 + condition: + type: escalate + params: + ocr: + any: ["目的地", "酒店", "位置", "关键词", "输入"] + llm: + prompt: 用户是否在当前页面输入了任务描述中指定的目的地、酒店或地标信息? 
+ expected_true: true + next: [select_dates] +- id: select_dates + name: 选择入住和离店日期 + condition: + type: juxtaposition + params: + ocr: + all: ["入住", "离店", "日历", "确定"] + llm: + prompt: 用户是否在当前日历页面上,成功选择了任务描述中指定的入住和离店日期? + expected_true: true + next: [confirm_search] +- id: confirm_search + name: 点击查询按钮 + condition: + type: escalate + params: + ocr: + any: ["查询", "搜索", "查找酒店", "搜酒店", "完成"] + llm: + prompt: 用户是否点击了查询或搜索按钮以查找符合条件的酒店? + expected_true: true + next: [view_results] +- id: view_results + name: 查看酒店价格列表 + condition: + type: juxtaposition + params: + ocr: + any: ["价格", "筛选", "排序", "推荐", "¥", "起", "每晚", "综合推荐"] + llm: + prompt: 当前页面是否成功展示了符合任务描述中地点和日期要求的酒店列表及其价格信息? + expected_true: true +success: + any_of: [view_results] diff --git a/MobiFlow/task_rules/feizhu/feizhu-type6.yaml b/MobiFlow/task_rules/feizhu/feizhu-type6.yaml new file mode 100644 index 0000000..8c8f7da --- /dev/null +++ b/MobiFlow/task_rules/feizhu/feizhu-type6.yaml @@ -0,0 +1,77 @@ +task_id: fliggy_hotel_search +app_id: com.taobao.trip +task_type: travel +description: 在飞猪应用中,根据指定的城市、地点、入住日期和晚数,查询酒店并查看结果列表。 +nodes: +- id: launch_app + name: 启动飞猪应用 + condition: + type: escalate + params: + ocr: + any: ["酒店", "机票", "火车票", "旅行", "我的", "首页"] + llm: + prompt: 当前界面是否为飞猪应用首页或主功能界面? + expected_true: true + next: [enter_hotel_search] +- id: enter_hotel_search + name: 进入酒店搜索 + deps: [launch_app] + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿", "客栈", "住宿", "目的地"] + llm: + prompt: 当前操作是否进入了酒店搜索的入口界面,准备输入目的地和日期? + expected_true: true + next: [input_destination] +- id: input_destination + name: 输入目的地或关键词 + deps: [enter_hotel_search] + condition: + type: escalate + params: + ocr: + any: ["目的地", "位置", "关键词", "酒店名", "城市", "输入"] + llm: + prompt: 用户是否在当前界面输入或选择了任务描述中指定的城市和具体位置/酒店品牌(如'武汉大学'、'亚朵酒店')? 
+ expected_true: true + next: [select_dates] +- id: select_dates + name: 选择入住和离店日期 + deps: [input_destination] + condition: + type: juxtaposition + params: + ocr: + any: ["入住", "离店", "日期", "日历", "共", "晚", "确定", "完成"] + llm: + prompt: 用户是否在当前界面通过日历等方式,选择了符合任务描述要求的入住日期和住宿晚数(例如:三天后入住,住1晚)? + expected_true: true + next: [click_search] +- id: click_search + name: 点击查询按钮 + deps: [select_dates] + condition: + type: escalate + params: + ocr: + any: ["查询", "搜索", "查找酒店", "搜酒店"] + llm: + prompt: 用户在设置好目的地和日期后,是否点击了“查询”或“搜索”按钮来查找酒店? + expected_true: true + next: [view_results] +- id: view_results + name: 查看酒店列表结果 + deps: [click_search] + condition: + type: juxtaposition + params: + ocr: + any: ["筛选", "排序", "价格", "评分", "¥", "起", "每晚", "酒店列表", "综合推荐"] + llm: + prompt: 当前界面是否成功展示了符合查询条件的酒店列表,并且能清晰地看到多个酒店的名称、价格、评分等核心信息? + expected_true: true +success: + any_of: [view_results] diff --git a/MobiFlow/task_rules/feizhu/feizhu-type7.yaml b/MobiFlow/task_rules/feizhu/feizhu-type7.yaml new file mode 100644 index 0000000..f32a04b --- /dev/null +++ b/MobiFlow/task_rules/feizhu/feizhu-type7.yaml @@ -0,0 +1,85 @@ +task_id: fliggy_select_hotel_room +app_id: com.taobao.trip +task_type: booking +description: 在飞猪应用中,根据特定条件(如价格、评价、距离、房型)筛选并选择酒店房间。 +nodes: +- id: launch_app + name: 启动飞猪应用 + condition: + type: escalate + params: + ocr: + any: ["飞猪", "首页", "我的", "酒店", "机票"] + llm: + prompt: 当前界面是否为飞猪应用的首页或主界面? + expected_true: true + next: [enter_hotel_search] +- id: enter_hotel_search + name: 进入酒店搜索 + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿", "目的地", "入住", "离店", "搜索酒店"] + llm: + prompt: 当前界面是否为酒店搜索页面,可以输入目的地和日期? + expected_true: true + next: [view_hotel_list] +- id: view_hotel_list + name: 查看酒店列表 + condition: + type: escalate + params: + ocr: + any: ["家酒店", "筛选", "排序", "地图", "综合推荐"] + llm: + prompt: 当前界面是否展示了酒店搜索结果列表? 
+        expected_true: true
+  next: [apply_filter, select_hotel]
+
+- id: apply_filter
+  name: 应用筛选或排序
+  condition:
+    type: escalate
+    params:
+      ocr:
+        any: ["筛选", "排序", "价格", "距离", "评分", "好评优先", "价格优先", "距离优先", "确定"]
+      llm:
+        prompt: 用户是否根据任务描述(例如'价格最低'、'评价最好'或'距离最近')执行了相应的筛选或排序操作?
+        expected_true: true
+  next: [select_hotel]
+- id: select_hotel
+  name: 选择酒店
+  condition:
+    type: escalate
+    params:
+      ocr:
+        any: ["酒店详情", "房型", "设施", "评价", "预订"]
+      llm:
+        prompt: 用户是否从列表中选择了一家酒店并进入了其详情页面?
+        expected_true: true
+  next: [select_room]
+- id: select_room
+  name: 选择目标房型
+  condition:
+    type: juxtaposition
+    params:
+      ocr:
+        any: ["预订", "订", "立即预订", "订这间", "选择"]
+      llm:
+        prompt: 用户是否准确选择了任务描述中要求的房型(如'大床房'或'双床房'),并且该选择符合任务的排序/筛选要求(如'价格最低'、'评价最好'或'距离最近')?
+        expected_true: true
+  next: [confirm_booking]
+- id: confirm_booking
+  name: 确认订单
+  condition:
+    type: escalate
+    params:
+      ocr:
+        any: ["提交订单", "确认订单", "去支付", "订单详情"]
+      llm:
+        prompt: 当前界面是否为填写预订信息或确认订单的页面?
+        expected_true: true
+success:
+  any_of: [select_room, confirm_booking]
+
diff --git a/MobiFlow/task_rules/feizhu/feizhu-type8.yaml b/MobiFlow/task_rules/feizhu/feizhu-type8.yaml
new file mode 100644
index 0000000..d5767ea
--- /dev/null
+++ b/MobiFlow/task_rules/feizhu/feizhu-type8.yaml
@@ -0,0 +1,83 @@
+task_id: hotel_booking_task
+app_id: com.taobao.trip
+task_type: booking
+description: 验证用户在飞猪等旅行应用中,根据指定要求(地点、日期、房型)成功预订酒店的关键流程。
+nodes:
+- id: launch_app
+  name: 启动应用并进入首页
+  condition:
+    type: escalate
+    params:
+      ocr:
+        any: ["首页", "我的", "酒店", "机票", "火车票", "去哪玩"]
+      llm:
+        prompt: 屏幕是否显示了旅行应用(如飞猪)的首页或主功能界面?
+        expected_true: true
+  next: [enter_hotel_search]
+- id: enter_hotel_search
+  name: 进入酒店搜索模块
+  condition:
+    type: escalate
+    params:
+      ocr:
+        any: ["酒店", "民宿", "客栈", "Hotel"]
+      llm:
+        prompt: 用户是否已经点击并进入了酒店搜索功能的主界面?
+ expected_true: true + next: [input_search_criteria] +- id: input_search_criteria + name: 输入或确认搜索条件 + condition: + type: escalate + params: + ocr: + any: ["目的地", "我的位置", "入住", "离店", "关键词", "搜索", "查找酒店"] + llm: + prompt: 用户是否在当前页面输入或确认了任务描述中的核心搜索信息,如目的地、入住/离店日期或酒店名称/位置关键词? + expected_true: true + next: [view_hotel_list] +- id: view_hotel_list + name: 查看酒店列表 + condition: + type: escalate + params: + ocr: + any: ["综合排序", "筛选", "价格", "推荐", "评分", "列表", "地图"] + llm: + prompt: 屏幕是否展示了符合搜索条件的酒店列表,以供用户选择? + expected_true: true + next: [select_room_to_book] +- id: select_room_to_book + name: 选择房型并预订 + condition: + type: escalate + params: + ocr: + any: ["预订", "订", "立即预订", "订这间", "选择房型", "查看详情"] + llm: + prompt: 用户是否已经从酒店详情页或列表页中选择了符合任务要求的房型(如大床房、双床房)并点击了预订按钮? + expected_true: true + next: [fill_booking_details] +- id: fill_booking_details + name: 填写预订信息 + condition: + type: juxtaposition + params: + ocr: + any: ["订单填写", "入住人", "联系手机", "确认信息", "费用明细", "到店付"] + llm: + prompt: 屏幕是否跳转到了预订信息填写页面,要求用户输入住客姓名、联系方式等信息? + expected_true: true + next: [confirm_booking] +- id: confirm_booking + name: 提交订单或支付 + condition: + type: juxtaposition + params: + ocr: + any: ["提交订单", "去支付", "确认", "订单详情", "在线付", "信用住"] + llm: + prompt: 用户是否已到达最终的订单确认或支付页面? + expected_true: true +success: + any_of: [confirm_booking] diff --git a/MobiFlow/task_rules/gaode/gaode-type1.yaml b/MobiFlow/task_rules/gaode/gaode-type1.yaml new file mode 100644 index 0000000..ade37ff --- /dev/null +++ b/MobiFlow/task_rules/gaode/gaode-type1.yaml @@ -0,0 +1,73 @@ +task_id: map_navigation_generic +app_id: com.autonavi.minimap +task_type: navigation +description: 通用地图导航任务验证配置,覆盖从搜索目的地到开始导航的关键步骤。 +nodes: +# - id: launch_map +# name: 启动地图应用 +# condition: +# type: escalate +# params: +# ocr: +# any: ["我的位置", "搜索", "路线", "附近", "去哪"] +# llm: +# prompt: 该步是否成功启动了地图应用并进入了主界面? 
+#         expected_true: true
+#   next: [input_destination]
+# - id: input_destination
+#   name: 搜索目的地
+#   condition:
+#     type: escalate
+#     params:
+#       ocr:
+#         any: ["搜索", "路线", "去这里", "详情", "公里"]
+#       llm:
+#         prompt: 该步是否已经输入或选择了任务描述中的目的地,并展示了相关的地点信息或路线规划入口?
+#         expected_true: true
+#   next: [plan_route]
+# - id: plan_route
+#   name: 查看路线规划
+#   condition:
+#     type: escalate
+#     params:
+#       ocr:
+#         any: ["路线", "方案", "时间", "公里", "打车", "公交", "驾车", "步行", "开始导航"]
+#       llm:
+#         prompt: 该步是否展示了从当前位置到目标地点的多种路线规划方案?
+#         expected_true: true
+#   next: [start_navigation]
+# - id: start_navigation
+#   name: 开始导航
+#   condition:
+#     type: juxtaposition
+#     params:
+#       ocr:
+#         any: ["开始导航", "出发", "导航中", "全程", "剩余"]
+#       llm:
+#         prompt: 该步是否点击了'开始导航'或类似按钮,准备进入或已经进入实时导航界面?
+#         expected_true: true
+#   next: [navigation_active]
+- id: result_detail
+  name: 目的地结果
+  condition:
+    type: juxtaposition
+    params:
+      ocr:
+        all: ["周边", "分享", "打车", "导航", "路线"]
+      llm:
+        prompt: 是否已经正确查到到目的地结果,展示了相关的地点信息或路线规划入口?
+        expected_true: true
+  next: [navigation_active]
+
+- id: navigation_active
+  name: 导航进行中
+  condition:
+    type: escalate
+    params:
+      ocr:
+        any: ["全览", "km", "继续导航"]
+      llm:
+        prompt: 该步是否已成功进入实时导航模式,界面上显示了行进路线、方向指引和距离等信息?或提供导航方式选择?
+        expected_true: true
+success:
+  any_of: [navigation_active]
diff --git a/MobiFlow/task_rules/gaode/gaode-type2.yaml b/MobiFlow/task_rules/gaode/gaode-type2.yaml
new file mode 100644
index 0000000..d3113bc
--- /dev/null
+++ b/MobiFlow/task_rules/gaode/gaode-type2.yaml
@@ -0,0 +1,42 @@
+task_id: ride_hailing_gaode
+app_id: com.autonavi.minimap
+task_type: ride_hailing
+description: 验证在高德地图中完成打车到指定目的地的任务流程
+nodes:
+- id: enter_ride_hailing_mode
+  name: 进入打车功能
+  condition:
+    type: escalate
+    params:
+      ocr:
+        all: ["你要去哪", "预约", "代叫", "接送机"]
+      llm:
+        prompt: 当前界面是否为打车功能主界面,通常会显示地图、出发地和目的地输入框?
+ expected_true: true + next: [input_destination] + +- id: input_destination + name: 输正确目的地 + condition: + type: juxtaposition + params: + ocr: + all: ["我的位置","收藏夹","地图选点"] + llm: + prompt: 用户是否已在目的地输入框中输入或确认了任务描述(task_description)中指定的地点? + expected_true: true + next: [confirm_ride_request] + +- id: confirm_ride_request + name: 等待确认车型并呼叫 + condition: + type: escalate + params: + ocr: + all: ["立即打车","预估"] + llm: + prompt: 界面是否已展示了路线、预估价格和可用的车型,并出现了'呼叫'或'确认'等按钮,让用户可以发起叫车请求? + expected_true: true + +success: + any_of: [confirm_ride_request] diff --git a/MobiFlow/task_rules/gaode/gaode-type3.yaml b/MobiFlow/task_rules/gaode/gaode-type3.yaml new file mode 100644 index 0000000..2de4493 --- /dev/null +++ b/MobiFlow/task_rules/gaode/gaode-type3.yaml @@ -0,0 +1,55 @@ +task_id: ride_hailing_amap +app_id: com.autonavi.minimap +task_type: ride_hailing +description: 使用高德地图完成打车任务,从指定起点到终点呼叫网约车。 +nodes: +- id: enter_ride_hailing_mode + name: 进入打车功能 + condition: + type: escalate + params: + ocr: + all: ["你要去哪", "预约", "代叫", "接送机"] + llm: + prompt: 当前界面是否为打车功能主界面,通常会显示地图、出发地和目的地输入框? + expected_true: true + next: [input_departure] + +- id: input_departure + name: 输正确出发地 + condition: + type: juxtaposition + params: + ocr: + # all: ["我的位置","收藏夹","地图选点"] + all: ["你要去哪", "预约", "代叫", "接送机"] + llm: + prompt: 用户是否已在出发地输入框中输入或确认了任务描述(task_description)中指定的起点位置? + expected_true: true + next: [input_destination] + +- id: input_destination + name: 输正确目的地 + condition: + type: juxtaposition + params: + ocr: + all: ["我的位置","收藏夹","地图选点","途经点"] + llm: + prompt: 用户是否已在目的地输入框中输入或确认了任务描述(task_description)中指定的地点? + expected_true: true + next: [confirm_ride_request] + +- id: confirm_ride_request + name: 等待确认车型并呼叫 + condition: + type: escalate + params: + ocr: + all: ["立即打车","预估"] + llm: + prompt: 界面是否已展示了路线、预估价格和可用的车型,并出现了'呼叫'或'确认'等按钮,让用户可以发起叫车请求? 
+ expected_true: true + +success: + any_of: [confirm_ride_request] diff --git a/MobiFlow/task_rules/taobao/type1-taobao-search.yaml b/MobiFlow/task_rules/taobao/type1-taobao-search.yaml new file mode 100644 index 0000000..7e1c61b --- /dev/null +++ b/MobiFlow/task_rules/taobao/type1-taobao-search.yaml @@ -0,0 +1,101 @@ +task_id: type1-taobao-search +app_id: com.taobao.taobao +task_type: type3 +description: 搜索<选择条件><商品描述> + +nodes: + # - id: open_app_home + # name: 打开App首页(存在搜索入口) + # condition: + # type: escalate + # params: + # ocr: + # all: ["视频", "消息", "搜索"] + # pattern: "(搜索|查询)" + + - id: activate_search + # deps: [open_app_home] + next: [input_keyword] + name: 激活搜索输入 + condition: + type: escalate + params: + ocr: + all: ["搜索", "猜你想搜", "历史搜索"] + llm: + prompt: "该步是否在激活搜索输入框?" + expected_true: true + + - id: input_keyword + # deps: [activate_search] + next: [apply_filter_condition,results_page] + name: 输入搜索关键词 + condition: + type: escalate + params: + llm: + prompt: "该步是否向搜索框输入了查询物品关键词?" + expected_true: true + + - id: apply_filter_condition + # deps: [input_keyword] + next: [results_page] + name: 应用过滤条件 + condition: + type: escalate + params: + # dynamic_match: + # extract_from: task_description + # condition_patterns: + # price_lowest: + # trigger_keywords: ["价格最低", "最便宜", "价格从低到高"] + # verify_keywords: ["价格", "低到高", "便宜", "最低"] + # llm_prompt: "该步骤是否执行了按价格从低到高排序的操作?" + # price_highest: + # trigger_keywords: ["价格最高", "最贵", "价格从高到低"] + # verify_keywords: ["价格", "高到低", "贵", "最高"] + # llm_prompt: "该步骤是否执行了按价格从高到低排序的操作?" + # sales_highest: + # trigger_keywords: ["销量最高", "销量最多", "销量"] + # verify_keywords: ["销量", "最高", "最多"] + # llm_prompt: "该步骤是否执行了按销量从高到低排序的操作?" + # sales_lowest: + # trigger_keywords: ["销量最低", "销量最少"] + # verify_keywords: ["销量", "最低", "最少"] + # llm_prompt: "该步骤是否执行了按销量从低到高排序的操作?" 
+ # verification_fields: ["reasoning", "text"] + # fallback_llm: true + llm: + prompt: "该步是否按照任务要求执行了相应的筛选或排序操作(如按价格排序、按销量排序等)?请结合任务描述和当前操作判断。若任务task_description没有排序、筛选要求则检查均返回False" + # prompt: "该步是否按照任务要求执行了相应的筛选或排序操作(如按价格排序、按销量排序等)?请结合任务描述和当前操作判断。" + expected_true: true + + - id: results_page + # deps: [apply_filter_condition] + next: [item_detail] + name: 显示搜索结果列表 + condition: + type: escalate + params: + # 后续考虑增加图标匹配 + # ui: + # any: ["拍照"] + ocr: + all: ["销量", "天猫", "店铺"] + llm: + prompt: "该步是否正确呈现了满足任务task_description要求的完整的搜索结果列表?" + expected_true: true + + - id: item_detail + name: 显示商品详情 + condition: + type: escalate + params: + ocr: + all: ["店铺", "客服", "收藏"] + llm: + prompt: "该步是否正确呈现了满足任务task_description要求的商品详情页?" + expected_true: true + +success: + any_of: [item_detail] diff --git a/MobiFlow/task_rules/taobao/type2-taobao-search-open-add_cart.yaml b/MobiFlow/task_rules/taobao/type2-taobao-search-open-add_cart.yaml new file mode 100644 index 0000000..dabd7b4 --- /dev/null +++ b/MobiFlow/task_rules/taobao/type2-taobao-search-open-add_cart.yaml @@ -0,0 +1,73 @@ +task_id: type2-taobao-search-open-detail +app_id: com.taobao.taobao +task_type: type2 +description: 将<商品描述>加入购物车 + +nodes: + - id: activate_search + name: 激活搜索输入 + condition: + type: escalate + params: + ocr: + all: ["搜索", "猜你想搜", "历史搜索"] + llm: + prompt: "该步是否在激活淘宝的搜索输入框?" + expected_true: true + + - id: input_keyword + deps: [activate_search] + name: 输入搜索关键词 + condition: + type: escalate + params: + llm: + prompt: "该步是否向搜索框输入了查询商品的关键词?(限定任务task_description要求所需商品相关)" + expected_true: true + + - id: results_page + deps: [input_keyword] + name: 显示搜索结果列表 + condition: + type: escalate + params: + ocr: + all: ["销量", "天猫", "店铺"] + llm: + prompt: "该步是否呈现了搜索结果列表页面?" + expected_true: true + + - id: open_detail + deps: [results_page] + name: 进入某个商品详情 + condition: + type: escalate + params: + ocr: + all: ["店铺", "客服", "收藏"] + llm: + prompt: "是否点击某个满足task_description要求的搜索结果并进入了商品详情页?" 
+ expected_true: true + + - id: confirm_add_to_cart + deps: [open_detail] + name: 确认加入购物车 + condition: + type: escalate + params: + # 仅ocr判断 + ocr: + all: ["购物车", "购买", "加购成功"] + + - id: add_to_cart + deps: [open_detail] + name: 加入购物车 + condition: + type: escalate + params: + llm: + prompt: "该步是否将任务task_description要求物品加入购物车?" + expected_true: true + +success: + any_of: [add_to_cart, confirm_add_to_cart] diff --git a/MobiFlow/task_rules/taobao/type3-taobao_add_cart-new.yaml b/MobiFlow/task_rules/taobao/type3-taobao_add_cart-new.yaml new file mode 100644 index 0000000..e291113 --- /dev/null +++ b/MobiFlow/task_rules/taobao/type3-taobao_add_cart-new.yaml @@ -0,0 +1,180 @@ +task_id: taobao_generic_search +app_id: com.taobao.taobao +task_type: type3 +description: 在淘宝App中将<选择条件>的<商品描述>加入购物车 + +nodes: + # - id: open_app_home + # name: 打开App首页(存在搜索入口) + # condition: + # type: escalate + # params: + # # ui: + # # key: package + # # equals: com.taobao.taobao + # ocr: + # all: ["视频", "消息", "搜索"] + # pattern: "(搜索|查询)" + + - id: activate_search + # deps: [open_app_home] + next: [input_keyword] + name: 激活搜索输入 + score: 10 + condition: + type: escalate + params: + # action: + # type: action_match + # params: + # type: click + # ui: + # any: ["拍照"] + ocr: + all: ["搜索", "猜你想搜", "历史搜索"] + llm: + prompt: "该步是否在激活搜索输入框?" + expected_true: true + + - id: input_keyword + # deps: [activate_search] + next: [results_page] + name: 输入搜索关键词 + score: 10 + condition: + type: escalate + params: + # action: + # type: action_match + # params: + # type: input + # ocr: + # any: ["取消"] + # 在每一次llm调用中,都会将当前的task_description传入 + llm: + prompt: "该节点是否向搜索框输入了task_description所需查询物品关键词?" + expected_true: true + + - id: results_page + # deps: [input_keyword] + next: [apply_filter_condition_first] + name: 显示搜索结果列表 + score: 10 + condition: + type: escalate + params: + # ui: + # any: ["拍照"] + ocr: + all: ["销量", "天猫", "店铺"] + llm: + prompt: "该步是否呈现了搜索结果列表?" 
+ expected_true: true + +# 先进行选择,再进入商品界面 + + - id: apply_filter_condition_first + # deps: [results_page] + next: [item_detail_second] + name: 按照任务要求应用筛选条件 + score: 10 + condition: + type: escalate + params: + llm: + # prompt: "该步是否按照任务task_description期望正确执行了选择筛选、排序等操作(如按价格、销量、评价、店铺等)?请结合任务task_description描述和当前操作判断。" + # prompt: "该步是否按照任务要求执行了相应的筛选、排序等操作(如按价格排序、按销量排序、选择天猫旗舰店等)?请结合任务task_description描述和当前操作判断。" + prompt: "该节点是否按照任务task_description期望正确执行了选择筛选、排序等操作(如按价格、销量、评价、店铺等)?请结合任务task_description描述和当前状态、操作判断。" + expected_true: true + + + - id: item_detail_second + # deps: [apply_filter_condition] + next: [confirm_add_to_cart, add_to_cart] + name: 进入详情页 + score: 15 + condition: + type: escalate + params: + ocr: + all: ["收藏", "店铺", "客服"] + llm: + prompt: "是否已进入某个结果的详情页?" + expected_true: true + +# # 先进入商品页面,再进行筛选 +# - id: item_detail_first +# # deps: [apply_filter_condition] +# next: [apply_filter_condition_second] +# name: 进入详情页 +# score: 15 +# condition: +# type: escalate +# params: +# ocr: +# all: ["收藏", "店铺", "客服"] +# llm: +# prompt: "是否已进入某个结果的详情页?" +# expected_true: true + +# - id: apply_filter_condition_second +# # deps: [item_detail_first] +# next: [add_to_cart] +# name: 按照任务要求应用筛选条件 +# score: 10 +# condition: +# type: escalate +# params: +# llm: +# prompt: "该步是否按照任务task_description期望正确执行了选择筛选、排序等操作(如按价格、销量、评价、店铺等)?请结合任务task_description描述和当前操作判断。" +# # prompt: "该步是否按照任务要求执行了相应的筛选、排序等操作(如按价格排序、按销量排序、选择天猫旗舰店等)?请结合任务task_description描述和当前操作判断。" +# expected_true: true + +############ +# 可以考虑增加两个节点,判断不同的加入购物车方式 +# 1. 在商品详情页加入购物车 +# 2. 在搜索结果页加入购物车 +# 3. 模型判断和ocr识别判断 +########### +# - id: add_to_cart +# # deps: [item_detail] +# name: 加入购物车 +# score: 25 +# condition: +# type: escalate +# params: +# # action: +# # type: action_match +# # params: +# # type: click +# ocr: +# all: ["购物车", "购买", "加购成功"] +# llm: +# prompt: "该步是否将任务task_description期望的物品加入购物车?" 
+# expected_true: true + +# success: +# # any_of: [results_page, item_detail] +# any_of: [add_to_cart] + + - id: confirm_add_to_cart + # deps: [apply_filter_condition_second] + name: 确认加入购物车 + condition: + type: escalate + params: + # 仅ocr判断 + ocr: + all: ["购物车", "购买", "加购成功"] + + - id: add_to_cart + # deps: [apply_filter_condition_second] + name: 加入购物车 + condition: + type: escalate + params: + llm: + prompt: "该节点是否将任务task_description要求的物品加入购物车?" + expected_true: true +success: + any_of: [confirm_add_to_cart, add_to_cart] \ No newline at end of file diff --git a/MobiFlow/task_rules/taobao/type3-taobao_add_cart.yaml b/MobiFlow/task_rules/taobao/type3-taobao_add_cart.yaml new file mode 100644 index 0000000..02487af --- /dev/null +++ b/MobiFlow/task_rules/taobao/type3-taobao_add_cart.yaml @@ -0,0 +1,148 @@ +task_id: taobao_generic_search +app_id: com.taobao.taobao +task_type: type3 +description: 在淘宝App中将<选择条件>的<商品描述>加入购物车 + +nodes: + - id: activate_search + # deps: [open_app_home] + name: 激活搜索输入 + score: 10 + condition: + type: escalate + params: + # action: + # type: action_match + # params: + # type: click + # ui: + # any: ["拍照"] + ocr: + all: ["搜索", "猜你想搜", "历史搜索"] + llm: + prompt: "该步是否在激活搜索输入框?" + expected_true: true + + - id: input_keyword + deps: [activate_search] + name: 输入搜索关键词 + score: 10 + condition: + type: escalate + params: + # action: + # type: action_match + # params: + # type: input + # ocr: + # any: ["取消"] + # 在每一次llm调用中,都会将当前的task_description传入 + llm: + prompt: "该步是否向搜索框输入了查询物品关键词?" + expected_true: true + + - id: results_page + deps: [input_keyword] + name: 显示搜索结果列表 + score: 10 + condition: + type: escalate + params: + # ui: + # any: ["拍照"] + ocr: + all: ["销量", "天猫", "店铺"] + llm: + prompt: "该步是否呈现了搜索结果列表?" 
+ expected_true: true + + - id: apply_filter_condition + deps: [results_page] + name: 按照任务要求应用筛选条件 + score: 10 + condition: + type: escalate + params: + # action: + # type: click + # dynamic_match: + # extract_from: task_description + # condition_patterns: + # price_lowest: + # trigger_keywords: ["价格最低", "最便宜", "价格从低到高"] + # verify_keywords: ["价格", "低到高", "便宜", "最低"] + # llm_prompt: "该步骤是否执行了按价格从低到高排序的操作?" + # price_highest: + # trigger_keywords: ["价格最高", "最贵", "价格从高到低"] + # verify_keywords: ["价格", "高到低", "贵", "最高"] + # llm_prompt: "该步骤是否执行了按价格从高到低排序的操作?" + # sales_highest: + # trigger_keywords: ["销量最高", "销量最多", "销量"] + # verify_keywords: ["销量", "最高", "最多"] + # llm_prompt: "该步骤是否执行了按销量从高到低排序的操作?" + # sales_lowest: + # trigger_keywords: ["销量最低", "销量最少"] + # verify_keywords: ["销量", "最低", "最少"] + # llm_prompt: "该步骤是否执行了按销量从低到高排序的操作?" + # verification_fields: ["reasoning", "text"] + # fallback_llm: true + # ocr: + # all: ["价格"] + # pattern: "(高|低)" + llm: + prompt: "该步是否按照任务task_description期望正确执行了选择筛选、排序等操作(如按价格、销量、评价、店铺等)?请结合任务task_description描述和当前操作判断。" + # prompt: "该步是否按照任务要求执行了相应的筛选、排序等操作(如按价格排序、按销量排序、选择天猫旗舰店等)?请结合任务task_description描述和当前操作判断。" + expected_true: true + + + # - id: select_item + # deps: [apply_filter_condition] + # name: 选择某个结果项 + # score: 15 + # condition: + # type: escalate + # params: + # # action: + # # type: action_match + # # params: + # # type: click + # # ocr: + # # any: ["天猫", "销量"] + # llm: + # prompt: "是否点击了一个符合任务task_description要求条件的搜索结果进入详情?" + # expected_true: true + + - id: item_detail + deps: [apply_filter_condition] + name: 进入详情页 + score: 15 + condition: + type: escalate + params: + ocr: + all: ["收藏", "店铺", "客服"] + llm: + prompt: "是否已进入某个结果的详情页?" 
+ expected_true: true + + - id: confirm_add_to_cart + deps: [item_detail, apply_filter_condition] + name: 确认加入购物车 + condition: + type: escalate + params: + # 仅ocr判断 + ocr: + all: ["购物车", "购买", "加购成功"] + + - id: add_to_cart + deps: [item_detail, apply_filter_condition] + name: 加入购物车 + condition: + type: escalate + params: + llm: + prompt: "该步是否将任务task_description要求的物品加入购物车?" + expected_true: true +success: + any_of: [confirm_add_to_cart, add_to_cart] diff --git a/MobiFlow/task_rules/taobao/type4-taobao_add_cart.yaml b/MobiFlow/task_rules/taobao/type4-taobao_add_cart.yaml new file mode 100644 index 0000000..799385b --- /dev/null +++ b/MobiFlow/task_rules/taobao/type4-taobao_add_cart.yaml @@ -0,0 +1,123 @@ +task_id: taobao_generic_search +app_id: com.taobao.taobao +task_type: type4 +description: 在淘宝App中将<参数规格>的<商品描述>加入购物车 + +nodes: + + - id: activate_search + name: 激活搜索输入 + condition: + type: escalate + params: + ocr: + all: ["搜索", "猜你想搜", "历史搜索"] + llm: + prompt: "该节点是否在激活搜索输入框?" + expected_true: true + + - id: input_keyword + deps: [activate_search] + name: 输入搜索关键词 + condition: + type: escalate + params: + # action: + # type: action_match + # params: + # type: input + # ocr: + # any: ["取消"] + # 在每一次llm调用中,都会将当前的task_description传入 + llm: + prompt: "该节点是否向搜索框输入任务task_description要求的查询物品关键词?" + expected_true: true + + + - id: results_page + deps: [input_keyword] + name: 显示搜索结果列表 + condition: + type: escalate + params: + ocr: + all: ["销量", "天猫", "店铺"] + # all: ["销量", "天猫", "店铺", "上新"] + llm: + prompt: "该节点是否呈现了任务task_description要求的搜索结果列表?" + expected_true: true + + - id: select_item + deps: [results_page] + name: 选择某个结果项 + condition: + type: escalate + params: + # action: + # type: action_match + # params: + # type: click + # ocr: + # any: ["天猫", "销量"] + llm: + prompt: "是否选择了一个符合任务task_description要求的搜索结果详情?" 
+ expected_true: true + + # 可能检查的内容和上一步重叠,导致使用了相同的frame + # - id: item_detail + # deps: [select_item] + # name: 进入详情页 + # condition: + # type: escalate + # params: + # ocr: + # all: ["收藏", "店铺", "客服"] + # llm: + # prompt: "是否已进入某个结果的详情页?" + # expected_true: true + + - id: select_parameters + deps: [select_item] + name: 选择商品参数 + condition: + type: escalate + params: + # ocr: + # any: ["颜色", "尺寸", "尺码", "规格", "型号"] + llm: + prompt: "该节点是否正确选择了任务task_description要求的商品的各项参数规格?(如颜色/尺码/型号等)**未正确选择**或者**该商品无任务要求的参数规格**均视为错误。" + expected_true: true + + # - id: add_to_cart + # deps: [select_parameters] + # name: 加入购物车 + # condition: + # type: escalate + # params: + # ocr: + # all: ["购物车", "购买", "加购成功"] + # llm: + # prompt: "该步是否将任务task_description要求的物品加入购物车?" + # expected_true: true + + - id: confirm_add_to_cart + deps: [select_parameters] + name: 确认加入购物车 + condition: + type: escalate + params: + # 仅ocr判断 + ocr: + all: ["购物车", "购买", "加购成功"] + + - id: add_to_cart + deps: [select_parameters] + name: 加入购物车 + condition: + type: escalate + params: + llm: + prompt: "该节点是否将任务task_description要求的物品加入购物车?" + expected_true: true +success: + any_of: [add_to_cart, confirm_add_to_cart] diff --git a/MobiFlow/task_rules/taobao/type5-taobao-select-specification.yaml b/MobiFlow/task_rules/taobao/type5-taobao-select-specification.yaml new file mode 100644 index 0000000..32732be --- /dev/null +++ b/MobiFlow/task_rules/taobao/type5-taobao-select-specification.yaml @@ -0,0 +1,105 @@ +task_id: type5-taobao-select-specification +app_id: com.taobao.taobao +task_type: type5 +description: 将<选择条件><参数规格>的<商品描述>加入购物车 + +nodes: + # 当前省略之前的主页检查、激活搜索框检查和输入关键词检查 + + + - id: results_page + name: 显示搜索结果列表 + condition: + type: escalate + params: + ocr: + all: ["销量", "天猫", "店铺"] + llm: + prompt: "该步是否呈现了完整、正确的搜索结果列表?" 
+ expected_true: true + + - id: apply_filter_condition + deps: [results_page] + name: 按照任务要求应用筛选条件 + condition: + type: escalate + params: + # dynamic_match: + # extract_from: task_description + # condition_patterns: + # price_lowest: + # trigger_keywords: ["价格最低", "最便宜", "价格从低到高"] + # verify_keywords: ["价格", "低到高", "便宜", "最低"] + # llm_prompt: "该步骤是否执行了按价格从低到高排序的操作?" + # price_highest: + # trigger_keywords: ["价格最高", "最贵", "价格从高到低"] + # verify_keywords: ["价格", "高到低", "贵", "最高"] + # llm_prompt: "该步骤是否执行了按价格从高到低排序的操作?" + # sales_highest: + # trigger_keywords: ["销量最高", "销量最多", "销量"] + # verify_keywords: ["销量", "最高", "最多"] + # llm_prompt: "该步骤是否执行了按销量从高到低排序的操作?" + # sales_lowest: + # trigger_keywords: ["销量最低", "销量最少"] + # verify_keywords: ["销量", "最低", "最少"] + # llm_prompt: "该步骤是否执行了按销量从低到高排序的操作?" + # verification_fields: ["reasoning", "text"] + # fallback_llm: true + # ocr: + # all: ["价格"] + # pattern: "(高|低)" + llm: + prompt: "该步是否按照任务task_description要求正确执行了相应的筛选或排序操作(如按价格、销量、评价、店铺等)?请结合任务task_description描述和当前实际操作判断。" + expected_true: true + + - id: item_detail + deps: [apply_filter_condition] + name: 进入商品详情页 + condition: + type: escalate + params: + ocr: + all: ["收藏", "店铺", "客服"] + # all: ["收藏", "店铺", "客服", "购物车", "购买"] + + llm: + prompt: "当前是否已进入淘宝商品详情页?" + expected_true: true + + # 与select_spec存在冗余,可能使用相同帧。可针对默认参数已经等于期望值,或者没有指定值的情况,考虑合并 + # - id: open_spec_panel + # deps: [item_detail] + # name: 打开规格选择面板 + # condition: + # type: escalate + # params: + # ocr: + # any: ["规格", "颜色", "尺码", "型号"] + # llm: + # prompt: "该步是否打开了规格选择面板或展开了规格区域?" + # expected_true: true + + - id: select_spec + deps: [item_detail] + name: 选择规格项 + condition: + type: escalate + params: + llm: + prompt: "该步是否正确选择了任务task_description要求的商品的各项参数规格?(如颜色/尺码/型号等)**未正确选择**或者**该商品无任务要求的参数规格**均视为错误。" + expected_true: true + + - id: add_to_cart + deps: [select_spec] + name: 加入购物车 + condition: + type: escalate + params: + ocr: + all: ["购物车", "购买", "加购成功"] + llm: + prompt: "该步是否将任务task_description要求物品加入购物车?" 
+ expected_true: true + +success: + any_of: [add_to_cart] diff --git a/MobiFlow/task_rules/weixin/weixin-type1.yaml b/MobiFlow/task_rules/weixin/weixin-type1.yaml new file mode 100644 index 0000000..add1686 --- /dev/null +++ b/MobiFlow/task_rules/weixin/weixin-type1.yaml @@ -0,0 +1,46 @@ +task_id: wechat_send_message +app_id: com.tencent.mm +task_type: communication +description: 在微信中给指定联系人或群聊发送消息 +nodes: +- id: find_contact_entry + name: 查找联系人或群聊 + condition: + type: escalate + params: + icons: + all: ["icon_001_通讯录","icon_002_微信","icon_000_我"] + ocr: + all: ["微信", "通讯录","发现","我"] + llm: + prompt: 当前页面是否为微信主界面、通讯录或搜索页面,表明用户正准备查找联系人或群聊? + expected_true: true + next: [send_message_success] + +# - id: enter_chat_window +# name: 进入聊天窗口 +# condition: +# type: juxtaposition +# params: +# ocr: +# any: ["发送", "按住 说话", "切换为键盘", "视频通话", "语音输入", "+"] +# llm: +# prompt: 当前页面是否为与任务指令中指定联系人或群聊的聊天窗口?请检查页面顶部的名称是否与任务目标一致,并且底部有文字输入框。 +# expected_true: true +# next: [send_message_success] + +- id: send_message_success + name: 成功发送消息 + condition: + type: juxtaposition + params: + icons: + any: ["icon_001_回车","icon_002_发送"] + ocr: + all: ["发送"] + llm: + prompt: 当前页面是否为与任务指令中指定联系人或群聊的聊天窗口?请检查页面顶部的对象名称是否与任务目标一致,文字输入框中内容是否与任务要求一致。 + expected_true: true + +success: + any_of: [send_message_success] diff --git a/MobiFlow/task_rules/weixin/weixin-type2.yaml b/MobiFlow/task_rules/weixin/weixin-type2.yaml new file mode 100644 index 0000000..589c2f6 --- /dev/null +++ b/MobiFlow/task_rules/weixin/weixin-type2.yaml @@ -0,0 +1,87 @@ +task_id: wechat_make_call +app_id: com.tencent.mm +task_type: communication +description: 在微信中搜索并选择联系人,然后发起语音或视频通话。 +nodes: +# - id: find_contact +# name: 查找或选择联系人 +# condition: +# type: escalate +# params: +# ocr: +# any: ["搜索", "通讯录", "聊天", "发起群聊"] +# llm: +# prompt: 该步骤是否进入了可以查找或选择联系人的界面(如聊天列表、通讯录或搜索页)? 
+#         expected_true: true
+#   next: [chat_page]
+# - id: chat_page
+#   name: 进入聊天或详情页
+#   condition:
+#     type: escalate
+#     params:
+#       ocr:
+#         any: ["发消息", "+", "音视频通话", "设置"]
+#       llm:
+#         prompt: 该步骤是否成功进入了与目标联系人的聊天或联系人详情界面?
+#         expected_true: true
+#   next: [initiate_call_menu]
+# - id: initiate_call_menu
+#   name: 打开通话选项菜单
+#   condition:
+#     type: escalate
+#     params:
+#       ocr:
+#         any: ["视频通话", "语音通话", "位置", "红包", "转账"]
+#       llm:
+#         prompt: 该步骤是否打开了包含'视频通话'和'语音通话'选项的功能菜单或页面?
+#         expected_true: true
+#   next: [start_voice_call]
+#     - start_video_call
+# - id: start_voice_call
+#   name: 发起语音通话
+#   condition:
+#     type: juxtaposition
+#     params:
+#       ocr:
+#         any: ["语音通话", "等待对方接受", "邀请", "取消", "免提", "扬声器"]
+#       llm:
+#         prompt: 该步骤是否成功发起了语音通话,界面是否显示正在呼叫或等待对方接听的状态?
+#         expected_true: true
+# - id: start_video_call
+#   name: 发起视频通话
+#   condition:
+#     type: juxtaposition
+#     params:
+#       ocr:
+#         any: ["视频通话", "等待对方接受", "邀请", "取消", "切换到语音", "转换摄像头"]
+#       llm:
+#         prompt: 该步骤是否成功发起了视频通话,界面是否显示正在呼叫或等待对方接听的状态?
+#         expected_true: true
+
+
+- id: initiate_call_menu
+  name: 打开通话选项菜单
+  condition:
+    type: juxtaposition
+    params:
+      ocr:
+        all: ["相册", "拍摄", "视频通话", "位置", "红包", "礼物"]
+      llm:
+        prompt: 该步骤是否在**任务要求的联系人**界面否打开了包含'视频通话'和'语音通话'选项的功能菜单或页面?
+        expected_true: true
+  next: [start_video_call]
+
+- id: start_video_call
+  name: 发起音视频通话
+  condition:
+    type: escalate
+    params:
+      ocr:
+        any: ["视频通话", "语音通话"]
+      llm:
+        # prompt: 该步骤是否成功发起了视频通话,界面是否显示正在呼叫或等待对方接听的状态?
+        prompt: 该步骤是否在指定联系人聊天界面打开了视频通话或语音通话菜单?
+ expected_true: true + +success: + any_of: [start_video_call] diff --git a/MobiFlow/task_rules/weixin/weixin-type3.yaml b/MobiFlow/task_rules/weixin/weixin-type3.yaml new file mode 100644 index 0000000..6af1e42 --- /dev/null +++ b/MobiFlow/task_rules/weixin/weixin-type3.yaml @@ -0,0 +1,44 @@ +task_id: wechat_view_moments +app_id: com.tencent.mm +task_type: social_browsing +description: 验证在微信中打开朋友圈信息流或特定联系人朋友圈的任务流程。 +nodes: +# - id: start_wechat +# name: 启动微信或进入主界面 +# condition: +# type: escalate +# params: +# ocr: +# any: ["微信", "通讯录", "发现", "我", "Chats", "Contacts", "Discover"] +# llm: +# prompt: 该步是否成功打开微信并进入主界面(如聊天、通讯录或发现页)? +# expected_true: true +# next: [view_general_moments] + +- id: view_contact_info + name: 进入朋友个人页 + condition: + type: escalate + params: + ocr: + any: ["朋友资料", "发消息", "音视频通话"] + llm: + prompt: 当前页面是否为微信的朋友主页,即显示该朋友的个人信息和朋友圈入口? + expected_true: true + next: [view_specific_moments] + + +- id: view_specific_moments + name: 查看特定联系人朋友圈 + condition: + type: escalate + params: + # ocr: + # all: ["朋友圈"] + # any: ["封面", "相册", "动态"] + llm: + prompt: 当前页面是否为任务指定联系人的个人朋友圈页面,即只显示该联系人自己发布的动态? + expected_true: true + +success: + any_of: [view_specific_moments] diff --git a/MobiFlow/task_rules/weixin/weixin-type4.yaml b/MobiFlow/task_rules/weixin/weixin-type4.yaml new file mode 100644 index 0000000..9df6b92 --- /dev/null +++ b/MobiFlow/task_rules/weixin/weixin-type4.yaml @@ -0,0 +1,43 @@ +task_id: wechat_search_chat_history +app_id: com.tencent.mm +task_type: search +description: 在微信中查找指定联系人、群聊或关键词的聊天记录。 +nodes: +- id: launch_wechat + name: 启动微信并进入联系人界面 + condition: + type: escalate + params: + ocr: + all: ["聊天信息", "查找聊天记录"] + llm: + prompt: 该步是否成功进入了与任务指定联系人的聊天信息界面,展示有查找聊天记录的入口? + expected_true: true + next: [enter_search_page] + +- id: enter_search_page + name: 进入搜索页面 + condition: + type: escalate + params: + ocr: + all: ["取消", "搜索指定内容", "日期", "图片与视频", "文件"] + llm: + prompt: 该步是否成功进入了全局搜索页面? 
+ expected_true: true + next: [view_search_results] + +- id: view_search_results + name: 输入关键词并查看搜索结果 + condition: + type: escalate + params: + ocr: + all: ["全部","图片", "文件", "链接"] + llm: + prompt: 该步是否在搜索框中输入任务要求的关键词,并且下方展示了搜索结果列表? + expected_true: true + + +success: + any_of: [view_search_results] diff --git a/MobiFlow/task_rules/weixin/weixin-type5.yaml b/MobiFlow/task_rules/weixin/weixin-type5.yaml new file mode 100644 index 0000000..f10df87 --- /dev/null +++ b/MobiFlow/task_rules/weixin/weixin-type5.yaml @@ -0,0 +1,43 @@ +task_id: wechat_search_chat_history +app_id: com.tencent.mm +task_type: search +description: 在微信中查找指定联系人、群聊中指定关键词的聊天记录。 +nodes: +- id: launch_wechat + name: 启动微信并进入联系人界面 + condition: + type: escalate + params: + ocr: + all: ["聊天信息", "查找聊天记录"] + llm: + prompt: 该步是否成功进入了与任务指定联系人的聊天信息界面,展示有查找聊天记录的入口? + expected_true: true + next: [enter_search_page] + +- id: enter_search_page + name: 进入搜索页面 + condition: + type: escalate + params: + ocr: + all: ["取消", "搜索指定内容", "日期", "图片与视频", "文件"] + llm: + prompt: 该步是否成功进入了全局搜索页面? + expected_true: true + next: [view_search_results] + +- id: view_search_results + name: 输入关键词并查看搜索结果 + condition: + type: escalate + params: + # ocr: + # any: ["联系人", "群聊", "聊天记录", "最常使用", "Contacts", "Group Chats", "Chat History"] + llm: + prompt: 该步是否在搜索框中输入了关键词,并且下方展示了搜索结果列表? + expected_true: true + + +success: + any_of: [enter_search_page, view_search_results] diff --git a/MobiFlow/task_rules/weixin/weixin-type6.yaml b/MobiFlow/task_rules/weixin/weixin-type6.yaml new file mode 100644 index 0000000..719a994 --- /dev/null +++ b/MobiFlow/task_rules/weixin/weixin-type6.yaml @@ -0,0 +1,45 @@ +task_id: wechat_send_group_message_with_at +app_id: com.tencent.mm +task_type: communication +description: 在微信群聊中发送消息,并可能@指定成员 +nodes: +- id: launch_specific + name: 启动@指定人的窗口 + score: 30 + condition: + type: escalate + params: + ocr: + all: ["选择提醒的人", "多选"] + any: ["所有人", "all"] + llm: + prompt: 该步是否成功进入了与任务指定群聊的聊天信息界面,并弹出了选择提醒的人窗口? 
+ expected_true: true + next: [select_someone] + +- id: select_someone + name: 选择提醒的人 + score: 30 + condition: + type: juxtaposition + params: + ocr: + all: ["选择提醒的人", "多选"] + llm: + prompt: 该步是否成功找到需要提醒的群成员,准备或已经将他添加到提醒列表? + expected_true: true + next: [send_message] + +- id: send_message + name: 发送消息并验证 + score: 30 + condition: + type: juxtaposition + params: + ocr: + all: ["发送", "q", "w", "e"] + llm: + prompt: 请判断消息框中是否输入了任务指令(task_description)中要求的完整消息内容?请重点核对是否提到了正确的接收群组、正确的@对象以及完全一致的消息文本。 + expected_true: true +success: + any_of: [send_message] diff --git a/MobiFlow/task_rules/weixin/weixin-type7.yaml b/MobiFlow/task_rules/weixin/weixin-type7.yaml new file mode 100644 index 0000000..7914d38 --- /dev/null +++ b/MobiFlow/task_rules/weixin/weixin-type7.yaml @@ -0,0 +1,45 @@ +task_id: open_wechat_miniprogram +app_id: com.tencent.mm +task_type: miniprogram_launch +description: 验证在微信应用内成功打开指定的小程序。此配置通过检查微信的初始状态和目标小程序成功加载后的最终状态来确认任务完成。 +nodes: +- id: start_in_wechat + name: 任务起始于微信 + score: 30 + condition: + type: escalate + params: + ocr: + all: ["微信", "通讯录", "我"] + llm: + prompt: 当前屏幕是否为微信应用的主界面、搜索页或小程序列表页? + expected_true: true + next: [search] + +- id: search + name: 顶部小程序搜索框 + score: 30 + condition: + type: escalate + params: + ocr: + all: ["发现小程序", "搜索"] + llm: + prompt: 当前屏幕是否为微信应用搜索页或小程序列表页? 
+        expected_true: true
+  next: [miniprogram_opened]
+
+
+- id: miniprogram_opened
+  name: 目标小程序已打开
+  score: 40
+  condition:
+    type: juxtaposition
+    params:
+      ocr:
+        all: ["搜索", "q", "w", "e"]
+      llm:
+        prompt: 请仔细核对任务描述`task_description`。当前节点是否已经成功展示或打开用户指定的小程序?列表中展示该小程序的名称或其标志性的主界面。
+        expected_true: true
+success:
+  any_of: [miniprogram_opened]
diff --git a/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type1.yaml b/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type1.yaml
new file mode 100644
index 0000000..cd06ce0
--- /dev/null
+++ b/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type1.yaml
@@ -0,0 +1,65 @@
+task_id: xiaohongshu_follow_blogger
+app_id: com.xingin.xhs
+task_type: social_interaction
+description: 在小红书应用中,通过搜索找到指定博主并完成关注操作。
+nodes:
+# - id: launch_app
+#   name: 启动小红书
+#   condition:
+#     type: escalate
+#     params:
+#       ocr:
+#         any: ["首页", "推荐", "发现", "我", "消息", "购物"]
+#       llm:
+#         prompt: 当前界面是否为小红书应用的主界面或首页?
+#         expected_true: true
+#   next: [enter_search_page]
+- id: enter_search_page
+  name: 进入搜索页面
+  condition:
+    type: escalate
+    params:
+      ocr:
+        all: ["搜索", "历史记录", "猜你想搜"]
+      llm:
+        prompt: 用户是否已经进入了搜索输入界面?
+        expected_true: true
+  next: [execute_search]
+
+- id: execute_search
+  name: 执行搜索并查看结果
+  condition:
+    type: escalate
+    params:
+      ocr:
+        all: [ "用户", "商品", "地点"]
+      llm:
+        prompt: 界面是否展示了与任务描述中博主相关的搜索结果列表,并包含'用户'筛选标签?
+        expected_true: true
+  # next: [view_blogger_profile, confirm_follow_status]
+  next: [confirm_follow_status]
+
+# - id: view_blogger_profile
+#   name: 进入博主主页
+#   condition:
+#     type: escalate
+#     params:
+#       ocr:
+#         any: ["关注", "粉丝", "获赞与收藏", "笔记", "瞬间", "主页"]
+#       llm:
+#         prompt: 当前界面是否为博主的个人主页,且包含一个可点击的'关注'按钮?
+#         expected_true: true
+#   next: [confirm_follow_status]
+
+- id: confirm_follow_status
+  name: 确认关注成功
+  condition:
+    type: escalate
+    params:
+      ocr:
+        all: ["已关注", "粉丝"]
+      llm:
+        prompt: 界面是否明确显示用户已成功关注该博主(例如,按钮状态变为'已关注'或'私信')?
+ expected_true: true +success: + any_of: [confirm_follow_status] diff --git a/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type2.yaml b/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type2.yaml new file mode 100644 index 0000000..77ef8f3 --- /dev/null +++ b/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type2.yaml @@ -0,0 +1,54 @@ +task_id: xiaohongshu_find_blogger_homepage +app_id: com.xingin.xhs +task_type: search_and_navigate +description: 在小红书应用中,通过搜索找到指定博主并进入其个人主页。 +nodes: +# - id: launch_app +# name: 启动小红书 +# condition: +# type: escalate +# params: +# ocr: +# any: ["首页", "推荐", "关注", "发现", "购物"] +# llm: +# prompt: 当前界面是否为小红书的首页或主推荐流页面? +# expected_true: true +# next: [initiate_search] + +- id: enter_search_page + name: 进入搜索页面 + condition: + type: escalate + params: + ocr: + all: ["搜索", "历史记录", "猜你想搜"] + llm: + prompt: 用户是否已经进入了搜索输入界面? + expected_true: true + next: [view_blogger_homepage] + +# - id: view_search_results +# name: 查看搜索结果 +# condition: +# type: escalate +# params: +# ocr: +# any: ["综合", "用户", "笔记", "商品", "筛选", "粉丝"] +# llm: +# prompt: 界面是否展示了关于任务描述中博主名称的搜索结果列表,其中应包含用户、笔记等分类? 
+# expected_true: true +# next: [view_blogger_homepage] + +- id: view_blogger_homepage + name: 进入博主个人主页 + condition: + type: escalate + params: + ocr: + all: ["关注", "粉丝", "获赞与收藏", "笔记", "私信"] + llm: + prompt: 当前界面是否为任务描述中指定博主的个人主页?请检查页面上是否有该博主的头像、昵称、关注/粉丝数、以及笔记列表等关键元素。 + expected_true: true + +success: + any_of: [view_blogger_homepage] diff --git a/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type3.yaml b/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type3.yaml new file mode 100644 index 0000000..979774c --- /dev/null +++ b/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type3.yaml @@ -0,0 +1,64 @@ +task_id: xiaohongshu_find_blogger_and_first_post +app_id: com.xingin.xhs +task_type: search_and_browse +description: 在小红书应用中,根据任务描述搜索指定博主,并打开其发布的第一篇笔记。 +nodes: +# - id: launch_app +# name: 启动小红书 +# condition: +# type: escalate +# params: +# ocr: +# any: ["首页", "推荐", "关注", "消息", "我", "商城"] +# llm: +# prompt: 当前界面是否为小红书应用的主界面或首页? +# expected_true: true +# next: [enter_search] +- id: enter_search_page + name: 进入搜索页面 + condition: + type: escalate + params: + ocr: + all: ["搜索", "历史记录", "猜你想搜"] + llm: + prompt: 用户是否已经进入了搜索输入界面? + expected_true: true + next: [view_search_results] + +- id: view_search_results + name: 查看搜索结果 + condition: + type: escalate + params: + ocr: + all: [ "全部", "用户", "商品", "地点"] + llm: + prompt: 当前界面是否展示了包含用户、笔记等分类的搜索结果列表? + expected_true: true + next: [select_user_profile, view_first_post] + +- id: select_user_profile + name: 进入博主主页 + condition: + type: escalate + params: + ocr: + any: ["关注", "粉丝", "获赞与收藏", "笔记", "私信"] + llm: + prompt: 用户是否已成功从搜索结果中点击并进入了任务描述中指定博主的个人主页? + expected_true: true + next: [view_first_post] + +- id: view_first_post + name: 打开第一篇笔记 + condition: + type: escalate + params: + # ocr: + # any: ["评论", "点赞", "收藏", "分享", "关注", "作者"] + llm: + prompt: 用户是否已成功打开一篇笔记的详情页?请结合任务描述判断,这是否是目标博主发布的最近一篇笔记? 
+ expected_true: true +success: + any_of: [view_first_post] diff --git a/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type4.yaml b/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type4.yaml new file mode 100644 index 0000000..dbd1355 --- /dev/null +++ b/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type4.yaml @@ -0,0 +1,65 @@ +task_id: xiaohongshu_search_content +app_id: com.xingin.xhs +task_type: search +description: 在小红书应用中,根据指定的主题搜索并查看相关内容或攻略。 +nodes: +# - id: launch_app +# name: 启动小红书应用 +# condition: +# type: escalate +# params: +# ocr: +# any: ["首页", "推荐", "关注", "商城", "我"] +# llm: +# prompt: 当前界面是否为小红书应用的首页或主界面? +# expected_true: true +# next: [enter_search_page] +- id: enter_search_page + name: 进入搜索页面 + condition: + type: escalate + params: + ocr: + all: ["搜索", "历史记录", "猜你想搜"] + llm: + prompt: 用户是否已经进入了搜索输入界面? + expected_true: true + next: [view_search_results] + +# - id: input_keyword_and_search +# name: 输入关键词并执行搜索 +# condition: +# type: escalate +# params: +# ocr: +# any: ["搜索", "确认", "搜索历史", "猜你想搜"] +# llm: +# prompt: 用户是否在搜索框中输入了与任务描述(如'自驾游'、'租房技巧'等)相关的关键词,并点击了搜索按钮? +# expected_true: true +# next: [view_search_results] + +- id: view_search_results + name: 浏览搜索结果列表 + condition: + type: escalate + params: + ocr: + all: ["全部", "用户", "商品", "地点"] + llm: + prompt: 当前界面是否展示了与搜索关键词相关的图文或视频笔记列表? + expected_true: true + next: [view_content_details] + +- id: view_content_details + name: 查看内容详情 + condition: + type: escalate + params: + ocr: + any: ["关注", "说点什么"] + llm: + prompt: 用户是否已经从搜索结果列表中点击并进入了某一篇具体的笔记(图文或视频)的详情页面进行查看? 
+ expected_true: true + +success: + any_of: [view_content_details, view_search_results] diff --git a/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type5.yaml b/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type5.yaml new file mode 100644 index 0000000..bc17a47 --- /dev/null +++ b/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type5.yaml @@ -0,0 +1,42 @@ +task_id: xiaohongshu_search_and_view_first_result +app_id: com.xingin.xhs +task_type: search_and_browse +description: 在小红书应用中,根据指定关键词进行搜索,并点击查看第一个搜索结果的详情内容。 +nodes: +- id: enter_search_page + name: 进入搜索页面 + condition: + type: escalate + params: + ocr: + all: ["搜索", "历史记录", "猜你想搜"] + llm: + prompt: 用户是否已经进入了搜索输入界面? + expected_true: true + next: [execute_search] + +- id: execute_search + name: 查看搜索结果列表 + condition: + type: escalate + params: + ocr: + all: [ "用户", "商品", "地点", "问一问"] + llm: + prompt: 当前界面是否展示了与任务描述中关键词相关的图文或视频笔记列表? + expected_true: true + next: [view_first_result_detail] + +- id: view_first_result_detail + name: 查看第一个搜索结果详情 + condition: + type: escalate + params: + # ocr: + # any: ["关注", "评论", "点赞", "收藏", "分享", "作者", "相关推荐"] + llm: + prompt: 当前界面是否为一篇笔记详情页,展示正文内容、图片/视频、以及评论、点赞、收藏等互动按钮?这是否是用户点击搜索结果列表中的第一项后进入的页面? + expected_true: true + +success: + any_of: [view_first_result_detail] diff --git a/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type6.yaml b/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type6.yaml new file mode 100644 index 0000000..aa89c85 --- /dev/null +++ b/MobiFlow/task_rules/xiaohongshu/xiaohongshu-type6.yaml @@ -0,0 +1,90 @@ +task_id: xiaohongshu_find_blogger_content +app_id: com.xingin.xhs +task_type: search_and_browse +description: 在小红书应用中,通过搜索或进入主页的方式,查找指定博主的特定内容并进行浏览。 + +nodes: + # - id: launch_app + # name: 启动小红书 + # condition: + # type: escalate + # params: + # ocr: + # any: ["首页", "推荐", "关注", "消息", "我", "商城"] + # llm: + # prompt: "当前界面是否为小红书应用的主界面(如推荐、首页)?" 
+ # expected_true: true + # next: [enter_search_page] + + - id: enter_search_page + name: 进入搜索页面 + condition: + type: escalate + params: + ocr: + all: ["搜索", "历史记录", "猜你想搜"] + llm: + prompt: 用户是否已经进入了搜索输入界面? + expected_true: true + next: [enter_blogger_profile] + + # - id: input_search_keyword + # name: 输入搜索关键词并执行搜索 + # condition: + # type: escalate + # params: + # ocr: + # any: ["综合", "视频", "用户", "商品", "笔记", "搜索"] + # llm: + # prompt: "用户是否在搜索框中输入了任务描述中指定的博主名称或相关关键词,并看到了搜索结果页面?" + # expected_true: true + # next: [select_content_from_search_results, enter_blogger_profile] + + - id: enter_blogger_profile + name: 进入博主主页 + condition: + type: escalate + params: + ocr: + all: ["关注", "粉丝", "获赞与收藏", "笔记","私信"] + llm: + prompt: "用户是否从搜索结果中点击并成功进入了任务描述中指定的博主个人主页?" + expected_true: true + next: [use_profile_search] + # next: [browse_and_select_content_from_profile, use_profile_search] + + - id: use_profile_search + name: 在主页内搜索 + condition: + type: escalate + params: + ocr: + all: ["搜索","笔记"] + # 搜搜ta的商品/笔记 + llm: + prompt: "用户是否在博主主页内使用了搜索功能,并输入了任务描述中指定的内容关键词(如'新疆', '读书日记', '好物分享')?" 
+ expected_true: true + next: [view_first_result_detail] + + # - id: select_content_from_search_results + # name: 从综合搜索结果中选择内容 + # condition: + # type: juxtaposition + # params: + # ocr: + # any: ["评论", "点赞", "收藏", "分享", "笔记详情", "相关推荐"] + # llm: + # prompt: "用户是否从全局搜索结果中,点击并打开了任务描述中 + - id: view_first_result_detail + name: 查看第一个搜索结果详情 + condition: + type: escalate + params: + # ocr: + # any: ["关注", "评论", "点赞", "收藏", "分享", "作者", "相关推荐"] + llm: + prompt: 当前界面是否为一篇笔记详情页,展示正文内容、图片/视频、以及评论、点赞、收藏等互动按钮?这是否是用户点击搜索结果列表中的第一项后进入的页面?(若搜索显示没有对应搜索结果则返回true) + expected_true: true + +success: + any_of: [view_first_result_detail] \ No newline at end of file diff --git a/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type1.yaml b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type1.yaml new file mode 100644 index 0000000..67ea0d5 --- /dev/null +++ b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type1.yaml @@ -0,0 +1,52 @@ +task_id: fliggy_query_hotel_price +app_id: com.taobao.trip +task_type: query +description: 在携程应用中,通过搜索功能查询指定酒店的价格列表。 +nodes: +# - id: launch_app +# name: 启动携程应用 +# condition: +# type: escalate +# params: +# ocr: +# any: ["首页", "我的", "行程", "消息", "酒店", "机票", "火车票"] +# llm: +# prompt: 请判断当前界面是否为携程应用的首页或主界面? +# expected_true: true +# next: [enter_hotel_module] +- id: enter_hotel_module + name: 进入酒店搜索页 + condition: + type: escalate + params: + ocr: + all: ["国内", "查询", "海外"] + llm: + prompt: 请判断当前界面是否为酒店搜索功能的主页面,包含目的地、入住日期、离店日期和关键词搜索框等元素? + expected_true: true + next: [search_result] + +- id: search_result + name: 查看酒店搜索结果 + condition: + type: escalate + params: + ocr: + any: ["位置", "距离", "价格", "筛选", "星级"] + llm: + prompt: 请判断当前是否正确搜索了和目标酒店相关的关键词,并展示正确搜索结果? 
+ expected_true: true + # next: [view_search_results] + +# - id: view_search_results +# name: 查看酒店价格列表 +# condition: +# type: juxtaposition +# params: +# ocr: +# any: ["¥", "起", "预订", "价格", "评分", "筛选", "订", "元"] +# llm: +# prompt: 请判断当前界面是否成功展示了酒店的搜索结果列表?列表中应包含多个酒店选项,并明确显示价格信息(如'¥'符号或'元')和'预订'等操作按钮。 +# expected_true: true +success: + any_of: [search_result] diff --git a/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type2.yaml b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type2.yaml new file mode 100644 index 0000000..1d664f2 --- /dev/null +++ b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type2.yaml @@ -0,0 +1,53 @@ +task_id: fliggy_query_hotel_price +app_id: com.taobao.trip +task_type: query +description: 在携程应用中,根据用户指定的任意地标,查询并展示附近的酒店价格列表。 +nodes: +# - id: launch_app +# name: 启动携程应用 +# condition: +# type: escalate +# params: +# ocr: +# any: ["首页", "我的", "行程", "消息", "酒店", "机票", "火车票"] +# llm: +# prompt: 请判断当前界面是否为携程应用的首页或主界面? +# expected_true: true +# next: [enter_hotel_module] +- id: enter_hotel_module + name: 进入酒店搜索页 + condition: + type: escalate + params: + ocr: + all: ["国内", "查询", "海外"] + llm: + prompt: 请判断当前界面是否为酒店搜索功能的主页面,包含目的地、入住日期、离店日期和关键词搜索框等元素? + expected_true: true + next: [search_result] + +- id: search_result + name: 查看酒店搜索结果 + condition: + type: escalate + params: + ocr: + any: ["位置", "距离", "价格", "筛选", "星级"] + llm: + prompt: 请判断当前是否正确搜索了和目标酒店相关的关键词,并展示正确搜索结果? 
+ expected_true: true + next: [view_search_results] + +- id: view_search_results + name: 查看酒店详情 + condition: + type: juxtaposition + params: + ocr: + all: ["酒店", "设施", "政策", "问酒店"] + llm: + prompt: 请判断当前界面是否成功展示任务描述`{{task_description}}`目标的酒店?展示了正确的房型。 + expected_true: true + +success: + any_of: [view_search_results] diff --git a/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type3.yaml b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type3.yaml new file mode 100644 index 0000000..fc7c4ec --- /dev/null +++ b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type3.yaml @@ -0,0 +1,29 @@ +task_id: fliggy_query_hotel_price +app_id: com.taobao.trip +task_type: query +description: 在携程应用中查询指定城市和品牌的酒店及其价格 +nodes: +- id: start_search + name: 进入酒店搜索界面 + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿", "目的地", "关键词", "住/离店", "搜索"] + llm: + prompt: 当前界面是否为酒店搜索界面?用户可以在此界面输入城市、酒店名称或选择入住日期。 + expected_true: true + next: [view_search_results] +- id: view_search_results + name: 查看酒店搜索结果列表 + deps: [start_search] + condition: + type: juxtaposition + params: + ocr: + any: ["¥", "起", "价格", "评分", "筛选", "订", "每晚", "详情"] + llm: + prompt: 当前界面是否成功展示了符合用户查询意图(城市和酒店品牌)的酒店列表,并且列表中清晰地显示了酒店的价格信息(如'¥'符号或'xx元起')? + expected_true: true +success: + any_of: [view_search_results] diff --git a/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type4.yaml b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type4.yaml new file mode 100644 index 0000000..f3de99e --- /dev/null +++ b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type4.yaml @@ -0,0 +1,54 @@ +task_id: fliggy_query_hotel_price +app_id: com.taobao.trip +task_type: travel_query +description: 在携程应用中,根据用户指定的城市和地标,查询附近的酒店及其价格。 +nodes: +- id: launch_app + name: 启动携程应用 + condition: + type: escalate + params: + ocr: + any: ["酒店", "机票", "火车票", "我的", "首页", "去哪玩"] + llm: + prompt: 当前屏幕是否为携程应用的首页或主界面? 
+ expected_true: true + next: [enter_hotel_module] +- id: enter_hotel_module + name: 进入酒店搜索模块 + deps: [launch_app] + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿", "目的地", "入住", "离店", "关键词", "位置", "品牌"] + llm: + prompt: 用户是否已经进入了酒店搜索功能的界面,该界面包含目的地输入、日期选择等元素? + expected_true: true + next: [input_destination_and_dates] +- id: input_destination_and_dates + name: 输入目的地并选择日期 + deps: [enter_hotel_module] + condition: + type: escalate + params: + ocr: + any: ["选择目的地", "城市/地标", "入住", "离店", "选择日期", "日历", "确定", "完成"] + llm: + prompt: 用户是否正在输入目的地或在日历界面上选择入住和离店日期? + expected_true: true + next: [view_hotel_results] +- id: view_hotel_results + name: 查看酒店搜索结果列表 + deps: [input_destination_and_dates] + condition: + type: juxtaposition + params: + ocr: + all: ["价格", "筛选"] + any: ["¥", "起", "每晚", "评分", "综合排序", "查看详情", "订"] + llm: + prompt: 当前屏幕是否成功展示了符合任务要求的酒店列表,并且清晰地显示了酒店名称、价格(包含'¥'符号)等关键信息? + expected_true: true +success: + any_of: [view_hotel_results] diff --git a/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type5.yaml b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type5.yaml new file mode 100644 index 0000000..44f765b --- /dev/null +++ b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type5.yaml @@ -0,0 +1,72 @@ +task_id: fliggy_hotel_price_query +app_id: com.taobao.trip +task_type: query +description: 在携程应用中根据指定城市、日期和酒店/地标查询酒店价格 +nodes: +- id: launch_app + name: 启动携程应用 + condition: + type: escalate + params: + ocr: + any: ["首页", "我的", "行程", "消息", "酒店", "机票"] + llm: + prompt: 当前页面是否为携程应用的首页或主界面? + expected_true: true + next: [navigate_to_hotel] +- id: navigate_to_hotel + name: 进入酒店搜索 + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿·客栈", "目的地/酒店/关键词", "搜酒店"] + llm: + prompt: 当前页面是否为酒店搜索的入口页面? + expected_true: true + next: [input_destination] +- id: input_destination + name: 输入目的地或酒店名 + condition: + type: escalate + params: + ocr: + any: ["目的地", "酒店", "位置", "关键词", "输入"] + llm: + prompt: 用户是否在当前页面输入了任务描述中指定的目的地、酒店或地标信息? 
+ expected_true: true + next: [select_dates] +- id: select_dates + name: 选择入住和离店日期 + condition: + type: juxtaposition + params: + ocr: + all: ["入住", "离店", "日历", "确定"] + llm: + prompt: 用户是否在当前日历页面上,成功选择了任务描述中指定的入住和离店日期? + expected_true: true + next: [confirm_search] +- id: confirm_search + name: 点击查询按钮 + condition: + type: escalate + params: + ocr: + any: ["查询", "搜索", "查找酒店", "搜酒店", "完成"] + llm: + prompt: 用户是否点击了查询或搜索按钮以查找符合条件的酒店? + expected_true: true + next: [view_results] +- id: view_results + name: 查看酒店价格列表 + condition: + type: juxtaposition + params: + ocr: + any: ["价格", "筛选", "排序", "推荐", "¥", "起", "每晚", "综合推荐"] + llm: + prompt: 当前页面是否成功展示了符合任务描述中地点和日期要求的酒店列表及其价格信息? + expected_true: true +success: + any_of: [view_results] diff --git a/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type6.yaml b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type6.yaml new file mode 100644 index 0000000..661259b --- /dev/null +++ b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type6.yaml @@ -0,0 +1,77 @@ +task_id: fliggy_hotel_search +app_id: com.taobao.trip +task_type: travel +description: 在携程应用中,根据指定的城市、地点、入住日期和晚数,查询酒店并查看结果列表。 +nodes: +- id: launch_app + name: 启动携程应用 + condition: + type: escalate + params: + ocr: + any: ["酒店", "机票", "火车票", "旅行", "我的", "首页"] + llm: + prompt: 当前界面是否为携程应用首页或主功能界面? + expected_true: true + next: [enter_hotel_search] +- id: enter_hotel_search + name: 进入酒店搜索 + deps: [launch_app] + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿", "客栈", "住宿", "目的地"] + llm: + prompt: 当前操作是否进入了酒店搜索的入口界面,准备输入目的地和日期? + expected_true: true + next: [input_destination] +- id: input_destination + name: 输入目的地或关键词 + deps: [enter_hotel_search] + condition: + type: escalate + params: + ocr: + any: ["目的地", "位置", "关键词", "酒店名", "城市", "输入"] + llm: + prompt: 用户是否在当前界面输入或选择了任务描述中指定的城市和具体位置/酒店品牌(如'武汉大学'、'亚朵酒店')? 
+ expected_true: true + next: [select_dates] +- id: select_dates + name: 选择入住和离店日期 + deps: [input_destination] + condition: + type: juxtaposition + params: + ocr: + any: ["入住", "离店", "日期", "日历", "共", "晚", "确定", "完成"] + llm: + prompt: 用户是否在当前界面通过日历等方式,选择了符合任务描述要求的入住日期和住宿晚数(例如:三天后入住,住1晚)? + expected_true: true + next: [click_search] +- id: click_search + name: 点击查询按钮 + deps: [select_dates] + condition: + type: escalate + params: + ocr: + any: ["查询", "搜索", "查找酒店", "搜酒店"] + llm: + prompt: 用户在设置好目的地和日期后,是否点击了“查询”或“搜索”按钮来查找酒店? + expected_true: true + next: [view_results] +- id: view_results + name: 查看酒店列表结果 + deps: [click_search] + condition: + type: juxtaposition + params: + ocr: + any: ["筛选", "排序", "价格", "评分", "¥", "起", "每晚", "酒店列表", "综合推荐"] + llm: + prompt: 当前界面是否成功展示了符合查询条件的酒店列表,并且能清晰地看到多个酒店的名称、价格、评分等核心信息? + expected_true: true +success: + any_of: [view_results] diff --git a/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type7.yaml b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type7.yaml new file mode 100644 index 0000000..74159b9 --- /dev/null +++ b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type7.yaml @@ -0,0 +1,85 @@ +task_id: fliggy_select_hotel_room +app_id: com.taobao.trip +task_type: booking +description: 在携程应用中,根据特定条件(如价格、评价、距离、房型)筛选并选择酒店房间。 +nodes: +- id: launch_app + name: 启动携程应用 + condition: + type: escalate + params: + ocr: + any: ["携程", "首页", "我的", "酒店", "机票"] + llm: + prompt: 当前界面是否为携程应用的首页或主界面? + expected_true: true + next: [enter_hotel_search] +- id: enter_hotel_search + name: 进入酒店搜索 + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿", "目的地", "入住", "离店", "搜索酒店"] + llm: + prompt: 当前界面是否为酒店搜索页面,可以输入目的地和日期? + expected_true: true + next: [view_hotel_list] +- id: view_hotel_list + name: 查看酒店列表 + condition: + type: escalate + params: + ocr: + any: ["家酒店", "筛选", "排序", "地图", "综合推荐"] + llm: + prompt: 当前界面是否展示了酒店搜索结果列表? 
+ expected_true: true + next: [apply_filter] + - select_hotel +- id: apply_filter + name: 应用筛选或排序 + condition: + type: escalate + params: + ocr: + any: ["筛选", "排序", "价格", "距离", "评分", "好评优先", "价格优先", "距离优先", "确定"] + llm: + prompt: 用户是否根据任务描述(例如'价格最低'、'评价最好'或'距离最近')执行了相应的筛选或排序操作? + expected_true: true + next: [select_hotel] +- id: select_hotel + name: 选择酒店 + condition: + type: escalate + params: + ocr: + any: ["酒店详情", "房型", "设施", "评价", "预订"] + llm: + prompt: 用户是否从列表中选择了一家酒店并进入了其详情页面? + expected_true: true + next: [select_room] +- id: select_room + name: 选择目标房型 + condition: + type: juxtaposition + params: + ocr: + any: ["预订", "订", "立即预订", "订这间", "选择"] + llm: + prompt: 用户是否准确选择了任务描述中要求的房型(如'大床房'或'双床房'),并且该选择符合任务的排序/筛选要求(如'价格最低'、'评价最好'或'距离最近')? + expected_true: true + next: [confirm_booking] +- id: confirm_booking + name: 确认订单 + condition: + type: escalate + params: + ocr: + any: ["提交订单", "确认订单", "去支付", "订单详情"] + llm: + prompt: 当前界面是否为填写预订信息或确认订单的页面? + expected_true: true +success: + any_of: [select_room] + - confirm_booking diff --git a/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type8.yaml b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type8.yaml new file mode 100644 index 0000000..ac70380 --- /dev/null +++ b/MobiFlow/task_rules/xiechen-jiudian/xiecheng_jd-type8.yaml @@ -0,0 +1,83 @@ +task_id: hotel_booking_task +app_id: com.taobao.trip +task_type: booking +description: 验证用户在携程等旅行应用中,根据指定要求(地点、日期、房型)成功预订酒店的关键流程。 +nodes: +- id: launch_app + name: 启动应用并进入首页 + condition: + type: escalate + params: + ocr: + any: ["首页", "我的", "酒店", "机票", "火车票", "去哪玩"] + llm: + prompt: 屏幕是否显示了旅行应用(如携程)的首页或主功能界面? + expected_true: true + next: [enter_hotel_search] +- id: enter_hotel_search + name: 进入酒店搜索模块 + condition: + type: escalate + params: + ocr: + any: ["酒店", "民宿", "客栈", "Hotel"] + llm: + prompt: 用户是否已经点击并进入了酒店搜索功能的主界面? 
+ expected_true: true + next: [input_search_criteria] +- id: input_search_criteria + name: 输入或确认搜索条件 + condition: + type: escalate + params: + ocr: + any: ["目的地", "我的位置", "入住", "离店", "关键词", "搜索", "查找酒店"] + llm: + prompt: 用户是否在当前页面输入或确认了任务描述中的核心搜索信息,如目的地、入住/离店日期或酒店名称/位置关键词? + expected_true: true + next: [view_hotel_list] +- id: view_hotel_list + name: 查看酒店列表 + condition: + type: escalate + params: + ocr: + any: ["综合排序", "筛选", "价格", "推荐", "评分", "列表", "地图"] + llm: + prompt: 屏幕是否展示了符合搜索条件的酒店列表,以供用户选择? + expected_true: true + next: [select_room_to_book] +- id: select_room_to_book + name: 选择房型并预订 + condition: + type: escalate + params: + ocr: + any: ["预订", "订", "立即预订", "订这间", "选择房型", "查看详情"] + llm: + prompt: 用户是否已经从酒店详情页或列表页中选择了符合任务要求的房型(如大床房、双床房)并点击了预订按钮? + expected_true: true + next: [fill_booking_details] +- id: fill_booking_details + name: 填写预订信息 + condition: + type: juxtaposition + params: + ocr: + any: ["订单填写", "入住人", "联系手机", "确认信息", "费用明细", "到店付"] + llm: + prompt: 屏幕是否跳转到了预订信息填写页面,要求用户输入住客姓名、联系方式等信息? + expected_true: true + next: [confirm_booking] +- id: confirm_booking + name: 提交订单或支付 + condition: + type: juxtaposition + params: + ocr: + any: ["提交订单", "去支付", "确认", "订单详情", "在线付", "信用住"] + llm: + prompt: 用户是否已到达最终的订单确认或支付页面? + expected_true: true +success: + any_of: [confirm_booking] diff --git a/MobiFlow/task_rules/xiechen/xiechen-type1.yaml b/MobiFlow/task_rules/xiechen/xiechen-type1.yaml new file mode 100644 index 0000000..fcdb6a2 --- /dev/null +++ b/MobiFlow/task_rules/xiechen/xiechen-type1.yaml @@ -0,0 +1,68 @@ +task_id: ctrip_ticket_search +app_id: ctrip.android.view +task_type: search +description: 在携程应用中根据用户指令查询机票或火车票,并验证是否成功展示结果列表。 +nodes: +# - id: launch_app +# name: 启动携程应用 +# condition: +# type: escalate +# params: +# ocr: +# any: ["机票", "酒店", "火车票", "旅游", "首页", "我的"] +# llm: +# prompt: 当前页面是否为携程应用的首页或主界面? 
+# expected_true: true +# next: [navigate_to_flight] + +- id: navigate_to_flight + name: 进入机票/火车票查询页面 + condition: + type: escalate + params: + ocr: + all: ["单程", "往返", "查询"] + llm: + prompt: 当前页面是否为任务要求的票(机票或火车票)查询界面? + # 请结合任务描述(task_description)判断,页面上是否已自动填充或展示了正确的出发地和目的地? + expected_true: true + next: [view_results] + +# - id: navigate_to_train +# name: 进入火车票查询页面 +# condition: +# type: escalate +# params: +# ocr: +# any: ["火车票查询", "出发站", "到达站", "学生票", "只看高铁动车", "查询车票"] +# llm: +# prompt: 当前页面是否为火车票/高铁票查询界面?请结合任务描述(task_description)判断,页面上是否已自动填充或展示了正确的出发地和目的地? +# expected_true: true +# next: [view_results] + +# 直接检查出发地、目的地、日期 +- id: check_departure_arrival + name: 检查出发地、目的地 + condition: + type: juxtaposition + params: + ocr: + all: ["单程", "往返", "查询"] + llm: + prompt: 当前页面是否展示了任务`{task_description}`要求的正确的出发地、目的地?任一错误则判定为错误 + expected_true: true + next: [view_results] + +- id: view_results + name: 查看查询结果列表 + condition: + type: escalate + params: + ocr: + any: ["直飞优先", "时间排序", "价格排序", "经济舱", "商务", "一等", "二等", "出发最早","耗时最短"] + llm: + prompt: 当前页面是否成功展示了从任务描述(task_description)指定的出发地到目的地的机票或火车票的班次列表?列表中应包含价格、时间等关键信息。 + expected_true: true + +success: + any_of: [view_results] diff --git a/MobiFlow/task_rules/xiechen/xiechen-type2.yaml b/MobiFlow/task_rules/xiechen/xiechen-type2.yaml new file mode 100644 index 0000000..d73d592 --- /dev/null +++ b/MobiFlow/task_rules/xiechen/xiechen-type2.yaml @@ -0,0 +1,68 @@ +task_id: ctrip_ticket_search +app_id: ctrip.android.view +task_type: search +description: 在携程应用中根据用户指令查询指定日期的机票或火车票,并验证是否成功展示结果列表。 +nodes: +# - id: launch_app +# name: 启动携程应用 +# condition: +# type: escalate +# params: +# ocr: +# any: ["机票", "酒店", "火车票", "旅游", "首页", "我的"] +# llm: +# prompt: 当前页面是否为携程应用的首页或主界面? +# expected_true: true +# next: [navigate_to_flight] + +- id: navigate_to_flight + name: 进入机票/火车票查询页面 + condition: + type: escalate + params: + ocr: + all: ["单程", "往返", "查询"] + llm: + prompt: 当前页面是否为任务要求的票(机票或火车票)查询界面? 
+ # 请结合任务描述(task_description)判断,页面上是否已自动填充或展示了正确的出发地和目的地? + expected_true: true + next: [view_results] + +# - id: navigate_to_train +# name: 进入火车票查询页面 +# condition: +# type: escalate +# params: +# ocr: +# any: ["火车票查询", "出发站", "到达站", "学生票", "只看高铁动车", "查询车票"] +# llm: +# prompt: 当前页面是否为火车票/高铁票查询界面?请结合任务描述(task_description)判断,页面上是否已自动填充或展示了正确的出发地和目的地? +# expected_true: true +# next: [view_results] + +# 直接检查出发地、目的地、日期 +- id: check_departure_arrival + name: 检查出发地、目的地 + condition: + type: juxtaposition + params: + ocr: + all: ["单程", "往返", "查询"] + llm: + prompt: 当前页面是否展示了任务`{task_description}`要求的正确的出发地、目的地和时间?任一错误则判定为错误 + expected_true: true + next: [view_results] + +- id: view_results + name: 查看查询结果列表 + condition: + type: escalate + params: + ocr: + any: ["直飞优先", "时间排序", "价格排序", "经济舱", "商务", "一等", "二等", "出发最早","耗时最短"] + llm: + prompt: 当前页面是否成功展示了从任务描述(task_description)指定的出发地到目的地、指定日期的机票或火车票的班次列表?列表中应包含价格、时间等关键信息。 + expected_true: true + +success: + any_of: [view_results] diff --git a/MobiFlow/task_rules/xiechen/xiechen-type3.yaml b/MobiFlow/task_rules/xiechen/xiechen-type3.yaml new file mode 100644 index 0000000..6acd82a --- /dev/null +++ b/MobiFlow/task_rules/xiechen/xiechen-type3.yaml @@ -0,0 +1,57 @@ +task_id: ctrip_ticket_search +app_id: ctrip.android.view +task_type: search +description: 在携程应用中根据用户指令查询指定日期的机票或火车票,并验证是否成功展示结果列表。 +nodes: +# - id: launch_app +# name: 启动携程应用 +# condition: +# type: escalate +# params: +# ocr: +# any: ["机票", "酒店", "火车票", "旅游", "首页", "我的"] +# llm: +# prompt: 当前页面是否为携程应用的首页或主界面? +# expected_true: true +# next: [navigate_to_flight] + +- id: navigate_to_flight + name: 进入机票/火车票查询页面 + condition: + type: escalate + params: + ocr: + all: ["单程", "往返", "查询"] + llm: + prompt: 当前页面是否为任务要求的票机票/火车票查询界面? + # 请结合任务描述(task_description)判断,页面上是否已自动填充或展示了正确的出发地和目的地? 
+ expected_true: true + next: [check_departure_arrival] + + +# 直接检查出发地、目的地、日期 +- id: check_departure_arrival + name: 检查出发地、目的地 + condition: + type: juxtaposition + params: + ocr: + all: ["单程", "往返", "查询"] + llm: + prompt: 当前页面是否展示了任务`{task_description}`要求的正确的出发地、目的地和时间?任一错误则判定为错误 + expected_true: true + next: [view_results] + +- id: view_results + name: 查看查询结果列表 + condition: + type: escalate + params: + ocr: + any: ["直飞优先", "时间排序", "价格排序", "经济舱", "商务", "一等", "二等", "出发最早","耗时最短"] + llm: + prompt: 当前页面是否成功展示了从任务描述(task_description)指定的**出发时间**范围机票或火车票的班次列表?如果图中没有满足时间条件的班次,判定为错误。 + expected_true: true + +success: + any_of: [view_results] diff --git a/MobiFlow/task_rules/xiechen/xiechen-type4.yaml b/MobiFlow/task_rules/xiechen/xiechen-type4.yaml new file mode 100644 index 0000000..734c0ac --- /dev/null +++ b/MobiFlow/task_rules/xiechen/xiechen-type4.yaml @@ -0,0 +1,57 @@ +task_id: ctrip_ticket_search +app_id: ctrip.android.view +task_type: search +description: 在携程应用中根据用户指令查询指定日期的机票或火车票,并验证是否成功展示结果列表。 +nodes: +# - id: launch_app +# name: 启动携程应用 +# condition: +# type: escalate +# params: +# ocr: +# any: ["机票", "酒店", "火车票", "旅游", "首页", "我的"] +# llm: +# prompt: 当前页面是否为携程应用的首页或主界面? +# expected_true: true +# next: [navigate_to_flight] + +- id: navigate_to_flight + name: 进入机票/火车票查询页面 + condition: + type: escalate + params: + ocr: + all: ["单程", "往返", "查询"] + llm: + prompt: 当前页面是否为任务要求的票机票/火车票查询界面? + # 请结合任务描述(task_description)判断,页面上是否已自动填充或展示了正确的出发地和目的地? 
+ expected_true: true + next: [check_departure_arrival] + + +# 直接检查出发地、目的地、日期 +- id: check_departure_arrival + name: 检查出发地、目的地 + condition: + type: juxtaposition + params: + ocr: + all: ["单程", "往返", "查询"] + llm: + prompt: 当前页面是否展示了任务`{task_description}`要求的正确的出发地、目的地和时间?任一错误则判定为错误 + expected_true: true + next: [view_results] + +- id: view_results + name: 查看查询结果列表 + condition: + type: escalate + params: + ocr: + any: ["直飞优先", "时间排序", "价格排序", "经济舱", "商务", "一等", "二等", "出发最早","耗时最短"] + llm: + prompt: 当前页面是否成功展示了任务描述(task_description)指定的**到达时间**范围的机票或火车票的班次列表?如果图中没有满足时间条件的班次,判定为错误。 + expected_true: true + +success: + any_of: [view_results] diff --git a/MobiFlow/task_rules/xiechen/xiechen-type5.yaml b/MobiFlow/task_rules/xiechen/xiechen-type5.yaml new file mode 100644 index 0000000..d73d592 --- /dev/null +++ b/MobiFlow/task_rules/xiechen/xiechen-type5.yaml @@ -0,0 +1,68 @@ +task_id: ctrip_ticket_search +app_id: ctrip.android.view +task_type: search +description: 在携程应用中根据用户指令查询指定日期的机票或火车票,并验证是否成功展示结果列表。 +nodes: +# - id: launch_app +# name: 启动携程应用 +# condition: +# type: escalate +# params: +# ocr: +# any: ["机票", "酒店", "火车票", "旅游", "首页", "我的"] +# llm: +# prompt: 当前页面是否为携程应用的首页或主界面? +# expected_true: true +# next: [navigate_to_flight] + +- id: navigate_to_flight + name: 进入机票/火车票查询页面 + condition: + type: escalate + params: + ocr: + all: ["单程", "往返", "查询"] + llm: + prompt: 当前页面是否为任务要求的票(机票或火车票)查询界面? + # 请结合任务描述(task_description)判断,页面上是否已自动填充或展示了正确的出发地和目的地? + expected_true: true + next: [view_results] + +# - id: navigate_to_train +# name: 进入火车票查询页面 +# condition: +# type: escalate +# params: +# ocr: +# any: ["火车票查询", "出发站", "到达站", "学生票", "只看高铁动车", "查询车票"] +# llm: +# prompt: 当前页面是否为火车票/高铁票查询界面?请结合任务描述(task_description)判断,页面上是否已自动填充或展示了正确的出发地和目的地? 
+# expected_true: true +# next: [view_results] + +# 直接检查出发地、目的地、日期 +- id: check_departure_arrival + name: 检查出发地、目的地 + condition: + type: juxtaposition + params: + ocr: + all: ["单程", "往返", "查询"] + llm: + prompt: 当前页面是否展示了任务`{task_description}`要求的正确的出发地、目的地和时间?任一错误则判定为错误 + expected_true: true + next: [view_results] + +- id: view_results + name: 查看查询结果列表 + condition: + type: escalate + params: + ocr: + any: ["直飞优先", "时间排序", "价格排序", "经济舱", "商务", "一等", "二等", "出发最早","耗时最短"] + llm: + prompt: 当前页面是否成功展示了从任务描述(task_description)指定的出发地到目的地、指定日期的机票或火车票的班次列表?列表中应包含价格、时间等关键信息。 + expected_true: true + +success: + any_of: [view_results] diff --git a/MobiFlow/test_model_connectivity.py b/MobiFlow/test_model_connectivity.py new file mode 100644 index 0000000..6c2360a --- /dev/null +++ b/MobiFlow/test_model_connectivity.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +""" +模型连通性测试程序 +用于测试配置中的API、MODEL、BASE_URL的可用性和连通性 +""" + +import sys +import time +import requests +import openai +from typing import Dict, Any, Optional, Tuple +import logging + +# 导入配置 +try: + from llmconfig import API_KEY, BASE_URL, MODEL +except ImportError: + print("错误: 无法导入llmconfig配置文件") + sys.exit(1) + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +class ModelConnectivityTester: + """模型连通性测试器""" + + def __init__(self, api_key: str, base_url: str, model: str): + """ + 初始化测试器 + + Args: + api_key: API密钥 + base_url: API基础URL + model: 模型名称 + """ + self.api_key = api_key + self.base_url = base_url + self.model = model + self.client = None + + # 验证配置 + self._validate_config() + + # 创建OpenAI客户端 + if self.api_key and self.base_url: + try: + self.client = openai.OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + except Exception as e: + logger.error(f"创建OpenAI客户端失败: {e}") + + def _validate_config(self) -> None: + """验证配置完整性""" + logger.info("=== 配置验证 ===") + + if not self.api_key: + logger.error("❌ API_KEY 
未配置") + else: + logger.info(f"✅ API_KEY: {self.api_key[:10]}...{self.api_key[-4:]}") + + if not self.base_url: + logger.error("❌ BASE_URL 未配置") + else: + logger.info(f"✅ BASE_URL: {self.base_url}") + + if not self.model: + logger.error("❌ MODEL 未配置") + else: + logger.info(f"✅ MODEL: {self.model}") + + if not all([self.api_key, self.base_url, self.model]): + logger.error("配置不完整,无法进行测试") + sys.exit(1) + + def test_basic_connectivity(self) -> Tuple[bool, str]: + """ + 测试基础连通性(网络连接) + + Returns: + (是否成功, 详细信息) + """ + logger.info("=== 基础连通性测试 ===") + + try: + # 移除/v1后缀进行基础连接测试 + test_url = self.base_url.rstrip('/v1').rstrip('/') + + response = requests.get( + test_url, + timeout=10, + headers={'User-Agent': 'ModelConnectivityTester/1.0'} + ) + + logger.info(f"✅ 网络连接正常 (状态码: {response.status_code})") + return True, f"连接成功,状态码: {response.status_code}" + + except requests.exceptions.ConnectionError as e: + error_msg = f"连接失败: {str(e)}" + logger.error(f"❌ {error_msg}") + return False, error_msg + except requests.exceptions.Timeout as e: + error_msg = f"连接超时: {str(e)}" + logger.error(f"❌ {error_msg}") + return False, error_msg + except Exception as e: + error_msg = f"网络测试异常: {str(e)}" + logger.error(f"❌ {error_msg}") + return False, error_msg + + def test_api_endpoint(self) -> Tuple[bool, str]: + """ + 测试API端点可用性 + + Returns: + (是否成功, 详细信息) + """ + logger.info("=== API端点测试 ===") + + try: + # 测试模型列表端点 + models_url = f"{self.base_url.rstrip('/')}/models" + + response = requests.get( + models_url, + headers={ + 'Authorization': f'Bearer {self.api_key}', + 'Content-Type': 'application/json' + }, + timeout=15 + ) + + if response.status_code == 200: + data = response.json() + if 'data' in data: + model_count = len(data['data']) + logger.info(f"✅ API端点可用,发现 {model_count} 个模型") + + # 检查目标模型是否可用 + available_models = [model['id'] for model in data['data']] + if self.model in available_models: + logger.info(f"✅ 目标模型 '{self.model}' 可用") + return True, f"API端点可用,目标模型可用" + else: + 
logger.warning(f"⚠️ 目标模型 '{self.model}' 不在可用模型列表中") + logger.info(f"可用模型: {', '.join(available_models[:5])}...") + return True, f"API端点可用,但目标模型可能不可用" + else: + logger.info("✅ API端点响应正常,但格式可能不同") + return True, "API端点可用" + else: + error_msg = f"API端点返回错误状态码: {response.status_code}" + logger.error(f"❌ {error_msg}") + return False, error_msg + + except Exception as e: + error_msg = f"API端点测试失败: {str(e)}" + logger.error(f"❌ {error_msg}") + return False, error_msg + + def test_model_inference(self) -> Tuple[bool, str]: + """ + 测试模型推理能力 + + Returns: + (是否成功, 详细信息) + """ + logger.info("=== 模型推理测试 ===") + + if not self.client: + error_msg = "OpenAI客户端未初始化" + logger.error(f"❌ {error_msg}") + return False, error_msg + + try: + # 发送简单的测试请求 + test_message = "请回复'测试成功'四个字" + + start_time = time.time() + response = self.client.chat.completions.create( + model=self.model, + messages=[ + { + "role": "user", + "content": test_message + } + ], + max_tokens=50, + temperature=0.1, + timeout=30 + ) + + response_time = time.time() - start_time + + if response and response.choices: + response_text = response.choices[0].message.content + logger.info(f"✅ 模型推理成功 (响应时间: {response_time:.2f}秒)") + logger.info(f"模型响应: {response_text}") + + return True, f"推理成功,响应时间: {response_time:.2f}秒" + else: + error_msg = "模型返回空响应" + logger.error(f"❌ {error_msg}") + return False, error_msg + + except openai.AuthenticationError as e: + error_msg = f"认证失败: {str(e)}" + logger.error(f"❌ {error_msg}") + return False, error_msg + except openai.APIError as e: + error_msg = f"API错误: {str(e)}" + logger.error(f"❌ {error_msg}") + return False, error_msg + except Exception as e: + error_msg = f"模型推理测试失败: {str(e)}" + logger.error(f"❌ {error_msg}") + return False, error_msg + + def run_comprehensive_test(self) -> Dict[str, Any]: + """ + 运行综合测试 + + Returns: + 测试结果字典 + """ + logger.info("开始模型连通性综合测试...") + logger.info("=" * 50) + + results = { + 'config_valid': True, + 'basic_connectivity': False, + 'api_endpoint': False, + 
'model_inference': False, + 'overall_status': 'FAILED', + 'details': {} + } + + # 1. 基础连通性测试 + success, detail = self.test_basic_connectivity() + results['basic_connectivity'] = success + results['details']['basic_connectivity'] = detail + + if not success: + logger.error("基础连通性测试失败,停止后续测试") + return results + + # 2. API端点测试 + success, detail = self.test_api_endpoint() + results['api_endpoint'] = success + results['details']['api_endpoint'] = detail + + if not success: + logger.warning("API端点测试失败,但继续进行模型推理测试") + + # 3. 模型推理测试 + success, detail = self.test_model_inference() + results['model_inference'] = success + results['details']['model_inference'] = detail + + # 4. 总体评估 + if results['model_inference']: + results['overall_status'] = 'SUCCESS' + elif results['api_endpoint']: + results['overall_status'] = 'PARTIAL' + else: + results['overall_status'] = 'FAILED' + + return results + + def print_summary(self, results: Dict[str, Any]) -> None: + """打印测试结果摘要""" + logger.info("=" * 50) + logger.info("=== 测试结果摘要 ===") + + status_map = { + 'SUCCESS': '✅ 全部通过', + 'PARTIAL': '⚠️ 部分通过', + 'FAILED': '❌ 测试失败' + } + + logger.info(f"总体状态: {status_map.get(results['overall_status'], '未知')}") + logger.info(f"配置验证: {'✅' if results['config_valid'] else '❌'}") + logger.info(f"基础连通性: {'✅' if results['basic_connectivity'] else '❌'}") + logger.info(f"API端点: {'✅' if results['api_endpoint'] else '❌'}") + logger.info(f"模型推理: {'✅' if results['model_inference'] else '❌'}") + + logger.info("\n详细信息:") + for test_name, detail in results['details'].items(): + logger.info(f" {test_name}: {detail}") + + logger.info("=" * 50) + + +def main(): + """主函数""" + print("模型连通性测试程序") + print("=" * 50) + + # 创建测试器 + tester = ModelConnectivityTester( + api_key=API_KEY, + base_url=BASE_URL, + model=MODEL + ) + + # 运行测试 + results = tester.run_comprehensive_test() + + # 打印摘要 + tester.print_summary(results) + + # 设置退出码 + if results['overall_status'] == 'SUCCESS': + sys.exit(0) + elif results['overall_status'] == 
'PARTIAL': + sys.exit(1) + else: + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/MobiFlow/tools/Icon_detection/README.md b/MobiFlow/tools/Icon_detection/README.md new file mode 100644 index 0000000..5bb691e --- /dev/null +++ b/MobiFlow/tools/Icon_detection/README.md @@ -0,0 +1,222 @@ +# 图标检测工具 + +## 概述 + +这是一个基于OpenCV模板匹配的图标检测工具,用于在手机应用截图中检测UI图标的存在。该工具已完全集成到验证框架的条件检查系统中。 + +## 功能特性 + +- ✅ **多尺度模板匹配**:支持不同尺寸的图标检测 +- ✅ **相似度阈值控制**:可调节检测精度 +- ✅ **批量检测**:支持同时检测多个图标 +- ✅ **路径智能解析**:根据应用ID自动查找图标文件 +- ✅ **非极大值抑制**:去除重复检测结果 +- ✅ **条件系统集成**:与escalate和juxtaposition检查器无缝集成 +- ✅ **灵活配置**:支持any/all匹配模式 + +## 安装和配置 + +### 依赖项 + +```bash +pip install opencv-python numpy +``` + +### 图标资源 + +图标文件应放置在以下目录结构中: + +``` +task_configs/icons/ +├── weixin/ +│ ├── icon_001_通讯录.jpg +│ ├── icon_002_微信.jpg +│ └── icon_000_我.jpg +├── bilibili/ +│ └── ... +└── xiecheng/ + └── ... +``` + +## 在YAML配置中使用 + +### 基本配置 + +```yaml +# escalate模式:按优先级依次尝试,任意一个成功即返回True +condition: + type: escalate + params: + icons: + any: ["icon_001_通讯录", "icon_002_微信"] # 匹配任意一个图标 + ocr: + all: ["微信", "通讯录"] + llm: + prompt: "当前页面是否为微信主界面?" 
+ expected_true: true + +# juxtaposition模式:要求所有检查器都成功 +condition: + type: juxtaposition + params: + icons: + all: ["icon_001_回车", "icon_002_发送"] # 必须匹配所有图标 + ocr: + any: ["发送"] +``` + +### 高级配置 + +```yaml +condition: + type: escalate + params: + icons: + any: ["icon_001_通讯录", "icon_002_微信"] + threshold: 0.85 # 自定义相似度阈值(可选) +``` + +## 匹配模式 + +### any模式 +- 列表中任意一个图标匹配成功即认为条件满足 +- 适用于多个可能的界面状态 + +### all模式 +- 要求列表中所有图标都必须匹配成功 +- 适用于确认特定界面元素的完整性 + +## API使用 + +### 直接使用图标检测服务 + +```python +from tools.Icon_detection import get_icon_detection_service, detect_icons_simple + +# 获取服务实例 +service = get_icon_detection_service() + +# 检测单个图标 +result = detect_icons_simple( + image_array, # numpy数组或文件路径 + ["icon_001_通讯录"], + app_id="com.tencent.mm" +) + +# 获取详细结果 +detailed_result = service.detect_icons( + image_array, + ["icon_001_通讯录", "icon_002_微信"], + app_id="com.tencent.mm", + match_mode='any' +) +``` + +### 在条件检查器中使用 + +```python +from avdag.conditions import get_checker + +icons_checker = get_checker("icons_match") + +frame = { + 'screenshot': image_array, # 或文件路径 + 'app_id': 'com.tencent.mm' +} + +params = { + "any": ["icon_001_通讯录", "icon_002_微信"] +} + +result = icons_checker.check(frame, params, options) +``` + +## 配置参数 + +### 全局配置 + +```python +from tools.Icon_detection import IconDetectionConfig, set_default_config + +config = IconDetectionConfig( + default_threshold=0.8, # 默认相似度阈值 + scale_range=(0.5, 2.0), # 缩放范围 + scale_step=0.1, # 缩放步长 + nms_threshold=0.3 # 非极大值抑制阈值 +) + +set_default_config(config) +``` + +### 条件参数 + +- `any`: 图标名称列表,匹配任意一个 +- `all`: 图标名称列表,必须匹配所有 +- `threshold`: 相似度阈值(可选),覆盖默认值 + +## 测试 + +运行测试脚本验证功能: + +```bash +# 基础功能测试 +python tools/Icon_detection/test_icon_detection.py + +# 集成测试 +python tools/Icon_detection/test_integration.py +``` + +## 工作原理 + +1. **模板加载**:从配置路径加载图标模板文件 +2. **多尺度匹配**:对模板进行多种尺寸缩放,与目标图像匹配 +3. **相似度计算**:使用OpenCV的TM_CCOEFF_NORMED方法计算相似度 +4. **结果筛选**:根据阈值过滤结果,应用非极大值抑制去重 +5. 
**模式判断**:根据any/all模式决定最终结果 + +## 扩展性 + +该工具设计为模块化架构,可以轻松扩展: + +- **替换检测算法**:可以替换为SIFT、ORB或深度学习检测器 +- **增加图标类型**:支持添加新的应用图标资源 +- **自定义路径解析**:可以自定义图标文件查找规则 +- **结果后处理**:可以添加自定义的结果过滤和排序逻辑 + +## 性能优化 + +- **图标缓存**:已加载的图标模板会被缓存,避免重复读取 +- **早期终止**:escalate模式下,一旦匹配成功立即返回 +- **尺寸预检查**:避免处理过大的缩放模板 +- **并行处理**:支持批量检测多个图标 + +## 故障排除 + +### 常见问题 + +1. **图标检测失败** + - 检查图标文件是否存在于正确路径 + - 调整相似度阈值(降低threshold值) + - 确认图像质量和分辨率 + +2. **路径解析错误** + - 验证app_id与目录名称的映射 + - 检查图标文件扩展名是否支持(png、jpg、jpeg、bmp) + +3. **性能问题** + - 减少缩放范围或增大缩放步长 + - 使用更高的相似度阈值 + - 清空图标模板缓存 + +### 调试日志 + +启用DEBUG级别日志查看详细执行信息: + +```python +import logging +logging.getLogger('avdag.condition').setLevel(logging.DEBUG) +``` + +## 示例应用 + +参考 `task_rules/weixin/weixin-type1.yaml` 查看完整的配置示例。 diff --git a/MobiFlow/tools/Icon_detection/__init__.py b/MobiFlow/tools/Icon_detection/__init__.py new file mode 100644 index 0000000..14cbaeb --- /dev/null +++ b/MobiFlow/tools/Icon_detection/__init__.py @@ -0,0 +1,33 @@ +""" +图标检测工具包 +提供基于OpenCV模板匹配的图标检测功能 +""" + +from .icon_detector import IconDetector, IconPathResolver +from .config import IconDetectionConfig, get_default_config, set_default_config +from .icon_detection import ( + IconDetectionService, + get_icon_detection_service, + detect_icons_simple, + detect_single_icon +) + +__version__ = "1.0.0" +__all__ = [ + # 核心检测器 + 'IconDetector', + 'IconPathResolver', + + # 配置管理 + 'IconDetectionConfig', + 'get_default_config', + 'set_default_config', + + # 服务接口 + 'IconDetectionService', + 'get_icon_detection_service', + + # 简化接口 + 'detect_icons_simple', + 'detect_single_icon', +] diff --git a/MobiFlow/tools/Icon_detection/config.py b/MobiFlow/tools/Icon_detection/config.py new file mode 100644 index 0000000..8f496a7 --- /dev/null +++ b/MobiFlow/tools/Icon_detection/config.py @@ -0,0 +1,125 @@ +""" +图标检测配置管理器 +""" + +import os +from typing import Dict, List, Optional, Any +from pathlib import Path +import logging + +logger = logging.getLogger(__name__) + + +class IconDetectionConfig: 
+ """图标检测配置类""" + + def __init__(self, + icon_base_paths: Optional[List[str]] = None, + default_threshold: float = 0.8, + scale_range: tuple = (0.5, 2.0), + scale_step: float = 0.1, + nms_threshold: float = 0.3): + """ + 初始化图标检测配置 + + Args: + icon_base_paths: 图标基础搜索路径列表 + default_threshold: 默认相似度阈值 + scale_range: 缩放范围 + scale_step: 缩放步长 + nms_threshold: 非极大值抑制阈值 + """ + # 设置默认图标路径 + if icon_base_paths is None: + # 尝试从环境变量或项目结构推断 + project_root = self._find_project_root() + icon_base_paths = [ + os.path.join(project_root, 'task_configs', 'icons'), + ] + + self.icon_base_paths = [Path(p) for p in icon_base_paths if os.path.exists(p)] + self.default_threshold = default_threshold + self.scale_range = scale_range + self.scale_step = scale_step + self.nms_threshold = nms_threshold + + # 验证配置 + self._validate_config() + + def _find_project_root(self) -> str: + """查找项目根目录""" + current_dir = Path(__file__).parent + + # 向上查找,直到找到包含特定标识文件的目录 + markers = ['pyproject.toml', 'requirements.txt', '.git'] + + for _ in range(10): # 最多向上查找10级 + for marker in markers: + if (current_dir / marker).exists(): + return str(current_dir) + current_dir = current_dir.parent + + # 如果找不到,返回当前目录的上两级(假设在tools/Icon_detection中) + return str(Path(__file__).parent.parent.parent) + + def _validate_config(self): + """验证配置有效性""" + if not self.icon_base_paths: + logger.warning("未找到有效的图标路径") + + if not (0.0 <= self.default_threshold <= 1.0): + raise ValueError(f"默认阈值必须在0-1之间,当前值: {self.default_threshold}") + + if self.scale_range[0] <= 0 or self.scale_range[1] <= self.scale_range[0]: + raise ValueError(f"无效的缩放范围: {self.scale_range}") + + if self.scale_step <= 0: + raise ValueError(f"缩放步长必须大于0,当前值: {self.scale_step}") + + def get_icon_paths(self) -> List[str]: + """获取所有图标搜索路径""" + return [str(p) for p in self.icon_base_paths] + + def add_icon_path(self, path: str): + """添加新的图标搜索路径""" + path_obj = Path(path) + if path_obj.exists() and path_obj not in self.icon_base_paths: + 
self.icon_base_paths.append(path_obj) + logger.info(f"添加图标搜索路径: {path}") + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式""" + return { + 'icon_base_paths': [str(p) for p in self.icon_base_paths], + 'default_threshold': self.default_threshold, + 'scale_range': self.scale_range, + 'scale_step': self.scale_step, + 'nms_threshold': self.nms_threshold + } + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> 'IconDetectionConfig': + """从字典创建配置实例""" + return cls( + icon_base_paths=config_dict.get('icon_base_paths'), + default_threshold=config_dict.get('default_threshold', 0.8), + scale_range=tuple(config_dict.get('scale_range', (0.5, 2.0))), + scale_step=config_dict.get('scale_step', 0.1), + nms_threshold=config_dict.get('nms_threshold', 0.3) + ) + + +# 全局默认配置实例 +_default_config = None + +def get_default_config() -> IconDetectionConfig: + """获取默认配置实例""" + global _default_config + if _default_config is None: + _default_config = IconDetectionConfig() + return _default_config + +def set_default_config(config: IconDetectionConfig): + """设置默认配置""" + global _default_config + _default_config = config diff --git a/MobiFlow/tools/Icon_detection/icon_detection.py b/MobiFlow/tools/Icon_detection/icon_detection.py new file mode 100644 index 0000000..60bf9f2 --- /dev/null +++ b/MobiFlow/tools/Icon_detection/icon_detection.py @@ -0,0 +1,236 @@ +""" +图标检测主接口模块 +提供简单易用的图标检测接口 +""" + +import cv2 +import numpy as np +import logging +from typing import List, Dict, Optional, Union, Any +import os + +from .icon_detector import IconDetector, IconPathResolver +from .config import IconDetectionConfig, get_default_config + +logger = logging.getLogger(__name__) + + +class IconDetectionService: + """图标检测服务类,提供高级接口""" + + def __init__(self, config: Optional[IconDetectionConfig] = None): + """ + 初始化图标检测服务 + + Args: + config: 图标检测配置,为None时使用默认配置 + """ + self.config = config or get_default_config() + self.detector = IconDetector( + default_threshold=self.config.default_threshold, 
+ scale_range=self.config.scale_range, + scale_step=self.config.scale_step + ) + self.path_resolver = IconPathResolver(self.config.get_icon_paths()) + + def detect_icons(self, + image: Union[np.ndarray, str], + icon_names: List[str], + app_id: Optional[str] = None, + threshold: Optional[float] = None, + match_mode: str = 'any') -> Dict[str, Any]: + """ + 检测图像中的图标 + + Args: + image: 目标图像(numpy数组或文件路径) + icon_names: 要检测的图标名称列表 + app_id: 应用ID,用于确定图标搜索路径 + threshold: 相似度阈值 + match_mode: 匹配模式,'any'表示匹配任意一个,'all'表示必须匹配所有 + + Returns: + 检测结果字典,包含成功状态、匹配的图标、详细结果等 + """ + logger.debug(f"开始图标检测,图标列表: {icon_names}, 匹配模式: {match_mode}") + + result = { + 'success': False, + 'matched_icons': [], + 'unmatched_icons': [], + 'details': {}, + 'total_matches': 0, + 'match_mode': match_mode + } + + # 预处理图像 + processed_image = self._preprocess_image(image) + if processed_image is None: + result['error'] = '无法处理输入图像' + return result + + # 逐个检测图标 + for icon_name in icon_names: + # 解析图标路径 + icon_path = self.path_resolver.resolve_icon_path(icon_name, app_id) + if icon_path is None: + result['unmatched_icons'].append(icon_name) + result['details'][icon_name] = { + 'found': False, + 'error': '图标文件未找到', + 'matches': [] + } + continue + + # 执行检测 + matches = self.detector.detect_icon( + processed_image, + icon_path, + threshold + ) + + # 记录结果 + is_found = len(matches) > 0 + result['details'][icon_name] = { + 'found': is_found, + 'icon_path': icon_path, + 'matches': matches, + 'match_count': len(matches) + } + + if is_found: + result['matched_icons'].append(icon_name) + result['total_matches'] += len(matches) + logger.debug(f"图标 {icon_name} 检测到 {len(matches)} 个匹配") + else: + result['unmatched_icons'].append(icon_name) + logger.debug(f"图标 {icon_name} 未检测到") + + # 根据匹配模式判断成功状态 + if match_mode == 'any': + result['success'] = len(result['matched_icons']) > 0 + elif match_mode == 'all': + result['success'] = len(result['unmatched_icons']) == 0 + else: + logger.warning(f"未知的匹配模式: {match_mode}") + 
result['success'] = False + + logger.debug(f"图标检测完成,成功: {result['success']}, " + f"匹配: {len(result['matched_icons'])}, " + f"未匹配: {len(result['unmatched_icons'])}") + + return result + + def _preprocess_image(self, image: Union[np.ndarray, str]) -> Optional[np.ndarray]: + """ + 预处理图像 + + Args: + image: 输入图像 + + Returns: + 处理后的灰度图像,失败返回None + """ + try: + if isinstance(image, str): + if not os.path.exists(image): + logger.error(f"图像文件不存在: {image}") + return None + img = cv2.imread(image) + if img is None: + logger.error(f"无法读取图像文件: {image}") + return None + else: + img = image.copy() + + # 转换为灰度图 + if len(img.shape) == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + return img + + except Exception as e: + logger.error(f"图像预处理失败: {e}") + return None + + def get_available_icons(self, app_id: Optional[str] = None) -> List[str]: + """ + 获取可用的图标列表 + + Args: + app_id: 应用ID + + Returns: + 可用图标名称列表 + """ + return self.path_resolver.list_available_icons(app_id) + + def validate_icons(self, icon_names: List[str], app_id: Optional[str] = None) -> Dict[str, bool]: + """ + 验证图标是否存在 + + Args: + icon_names: 图标名称列表 + app_id: 应用ID + + Returns: + 图标名称到存在状态的映射 + """ + result = {} + for icon_name in icon_names: + icon_path = self.path_resolver.resolve_icon_path(icon_name, app_id) + result[icon_name] = icon_path is not None + return result + + +# 全局服务实例 +_default_service = None + +def get_icon_detection_service(config: Optional[IconDetectionConfig] = None) -> IconDetectionService: + """获取图标检测服务实例""" + global _default_service + if _default_service is None or config is not None: + _default_service = IconDetectionService(config) + return _default_service + + +def detect_icons_simple(image: Union[np.ndarray, str], + icon_names: List[str], + app_id: Optional[str] = None, + threshold: Optional[float] = None, + match_mode: str = 'any') -> bool: + """ + 简化的图标检测接口 + + Args: + image: 目标图像 + icon_names: 图标名称列表 + app_id: 应用ID + threshold: 相似度阈值 + match_mode: 匹配模式 ('any' 或 'all') + + Returns: + 
检测是否成功 + """ + service = get_icon_detection_service() + result = service.detect_icons(image, icon_names, app_id, threshold, match_mode) + return result['success'] + + +def detect_single_icon(image: Union[np.ndarray, str], + icon_name: str, + app_id: Optional[str] = None, + threshold: Optional[float] = None) -> bool: + """ + 检测单个图标 + + Args: + image: 目标图像 + icon_name: 图标名称 + app_id: 应用ID + threshold: 相似度阈值 + + Returns: + 是否检测到图标 + """ + return detect_icons_simple(image, [icon_name], app_id, threshold, 'any') diff --git a/MobiFlow/tools/Icon_detection/icon_detector.py b/MobiFlow/tools/Icon_detection/icon_detector.py new file mode 100644 index 0000000..053de05 --- /dev/null +++ b/MobiFlow/tools/Icon_detection/icon_detector.py @@ -0,0 +1,355 @@ +""" +图标检测器模块 +使用OpenCV模板匹配实现图标检测,支持多尺度匹配和相似度阈值 +""" + +import cv2 +import numpy as np +import logging +from typing import List, Dict, Tuple, Optional, Union +import os +from pathlib import Path + +logger = logging.getLogger(__name__) + + +class IconDetector: + """ + 基于OpenCV模板匹配的图标检测器 + 支持多尺度匹配和可配置的相似度阈值 + """ + + def __init__(self, + default_threshold: float = 0.8, + scale_range: Tuple[float, float] = (0.5, 2.0), + scale_step: float = 0.1, + method: int = cv2.TM_CCOEFF_NORMED): + """ + 初始化图标检测器 + + Args: + default_threshold: 默认相似度阈值 + scale_range: 缩放范围 (min_scale, max_scale) + scale_step: 缩放步长 + method: OpenCV模板匹配方法 + """ + self.default_threshold = default_threshold + self.scale_range = scale_range + self.scale_step = scale_step + self.method = method + self._icon_cache = {} # 缓存加载的图标模板 + + def load_icon_template(self, icon_path: str) -> Optional[np.ndarray]: + """ + 加载图标模板 + + Args: + icon_path: 图标文件路径 + + Returns: + 图标模板的numpy数组,加载失败返回None + """ + if icon_path in self._icon_cache: + return self._icon_cache[icon_path] + + if not os.path.exists(icon_path): + logger.warning(f"图标文件不存在: {icon_path}") + return None + + try: + # 读取图标并转为灰度图 + template = cv2.imread(icon_path, cv2.IMREAD_GRAYSCALE) + if template is None: + 
logger.warning(f"无法读取图标文件: {icon_path}") + return None + + # 缓存模板 + self._icon_cache[icon_path] = template + logger.debug(f"成功加载图标模板: {icon_path}, 尺寸: {template.shape}") + return template + + except Exception as e: + logger.error(f"加载图标模板失败 {icon_path}: {e}") + return None + + def match_template_multiscale(self, + image: np.ndarray, + template: np.ndarray, + threshold: float) -> List[Dict]: + """ + 多尺度模板匹配 + + Args: + image: 目标图像(灰度图) + template: 模板图像(灰度图) + threshold: 相似度阈值 + + Returns: + 匹配结果列表,每个结果包含位置、尺度、相似度等信息 + """ + matches = [] + h, w = template.shape + + # 生成缩放比例列表 + scales = np.arange(self.scale_range[0], self.scale_range[1] + self.scale_step, self.scale_step) + + for scale in scales: + # 缩放模板 + scaled_w = int(w * scale) + scaled_h = int(h * scale) + + # 如果缩放后的模板大于图像,跳过 + if scaled_w > image.shape[1] or scaled_h > image.shape[0]: + continue + + scaled_template = cv2.resize(template, (scaled_w, scaled_h)) + + # 模板匹配 + result = cv2.matchTemplate(image, scaled_template, self.method) + + # 查找满足阈值的匹配点 + locations = np.where(result >= threshold) + + for pt in zip(*locations[::-1]): # 转换为(x, y)格式 + similarity = result[pt[1], pt[0]] + matches.append({ + 'position': pt, + 'scale': scale, + 'similarity': float(similarity), + 'bbox': (pt[0], pt[1], scaled_w, scaled_h), # (x, y, w, h) + 'template_size': (w, h) + }) + + # 按相似度排序 + matches.sort(key=lambda x: x['similarity'], reverse=True) + return matches + + def non_maximum_suppression(self, matches: List[Dict], overlap_threshold: float = 0.3) -> List[Dict]: + """ + 非极大值抑制,去除重叠的检测框 + + Args: + matches: 匹配结果列表 + overlap_threshold: 重叠阈值 + + Returns: + 去重后的匹配结果 + """ + if not matches: + return [] + + # 按相似度排序 + matches = sorted(matches, key=lambda x: x['similarity'], reverse=True) + + selected = [] + + for current in matches: + x1, y1, w1, h1 = current['bbox'] + + # 检查与已选择的框是否重叠 + is_overlap = False + for selected_match in selected: + x2, y2, w2, h2 = selected_match['bbox'] + + # 计算重叠区域 + overlap_x1 = max(x1, x2) + 
overlap_y1 = max(y1, y2) + overlap_x2 = min(x1 + w1, x2 + w2) + overlap_y2 = min(y1 + h1, y2 + h2) + + if overlap_x1 < overlap_x2 and overlap_y1 < overlap_y2: + overlap_area = (overlap_x2 - overlap_x1) * (overlap_y2 - overlap_y1) + area1 = w1 * h1 + area2 = w2 * h2 + union_area = area1 + area2 - overlap_area + + iou = overlap_area / union_area if union_area > 0 else 0 + + if iou > overlap_threshold: + is_overlap = True + break + + if not is_overlap: + selected.append(current) + + return selected + + def detect_icon(self, + image: Union[np.ndarray, str], + icon_path: str, + threshold: Optional[float] = None, + apply_nms: bool = True) -> List[Dict]: + """ + 在图像中检测指定图标 + + Args: + image: 目标图像(numpy数组或文件路径) + icon_path: 图标模板路径 + threshold: 相似度阈值,None时使用默认值 + apply_nms: 是否应用非极大值抑制 + + Returns: + 检测结果列表 + """ + # 处理输入图像 + if isinstance(image, str): + if not os.path.exists(image): + logger.error(f"图像文件不存在: {image}") + return [] + target_image = cv2.imread(image, cv2.IMREAD_GRAYSCALE) + if target_image is None: + logger.error(f"无法读取图像文件: {image}") + return [] + else: + target_image = image + if len(target_image.shape) == 3: + target_image = cv2.cvtColor(target_image, cv2.COLOR_BGR2GRAY) + + # 加载图标模板 + template = self.load_icon_template(icon_path) + if template is None: + return [] + + # 使用指定阈值或默认阈值 + use_threshold = threshold if threshold is not None else self.default_threshold + + # 执行多尺度匹配 + matches = self.match_template_multiscale(target_image, template, use_threshold) + + # 应用非极大值抑制 + if apply_nms and matches: + matches = self.non_maximum_suppression(matches) + + logger.debug(f"图标检测完成,找到 {len(matches)} 个匹配") + return matches + + def detect_icons_batch(self, + image: Union[np.ndarray, str], + icon_paths: List[str], + threshold: Optional[float] = None) -> Dict[str, List[Dict]]: + """ + 批量检测多个图标 + + Args: + image: 目标图像 + icon_paths: 图标路径列表 + threshold: 相似度阈值 + + Returns: + 每个图标的检测结果字典 + """ + results = {} + + for icon_path in icon_paths: + icon_name = 
os.path.basename(icon_path) + results[icon_name] = self.detect_icon(image, icon_path, threshold) + + return results + + def clear_cache(self): + """清空图标模板缓存""" + self._icon_cache.clear() + logger.debug("图标模板缓存已清空") + + +class IconPathResolver: + """图标路径解析器,负责根据配置查找图标文件""" + + def __init__(self, base_paths: List[str]): + """ + 初始化路径解析器 + + Args: + base_paths: 图标搜索基础路径列表 + """ + self.base_paths = [Path(p) for p in base_paths] + + def resolve_icon_path(self, icon_name: str, app_id: Optional[str] = None) -> Optional[str]: + """ + 解析图标路径 + + Args: + icon_name: 图标名称 + app_id: 应用ID,用于确定子目录 + + Returns: + 图标文件的完整路径,找不到返回None + """ + # 常见的图标文件扩展名 + extensions = ['.png', '.jpg', '.jpeg', '.bmp'] + + # 搜索路径优先级 + search_paths = [] + + # 1. 如果有app_id,优先在对应目录下搜索 + if app_id: + app_name = self._extract_app_name(app_id) + for base_path in self.base_paths: + search_paths.append(base_path / app_name) + + # 2. 在所有基础路径下搜索 + search_paths.extend(self.base_paths) + + # 在每个搜索路径下查找图标文件 + for search_path in search_paths: + if not search_path.exists(): + continue + + for ext in extensions: + # 尝试直接匹配 + icon_path = search_path / f"{icon_name}{ext}" + if icon_path.exists(): + return str(icon_path) + + # 尝试递归搜索 + for icon_file in search_path.rglob(f"{icon_name}{ext}"): + return str(icon_file) + + logger.warning(f"未找到图标文件: {icon_name} (app_id: {app_id})") + return None + + def _extract_app_name(self, app_id: str) -> str: + """从app_id提取应用名称""" + # 例如: com.tencent.mm -> weixin 或 mm + if 'tencent.mm' in app_id: + return 'weixin' + elif 'bilibili' in app_id: + return 'bilibili' + elif 'xiecheng' in app_id or 'ctrip' in app_id: + return 'xiecheng' + else: + return app_id + # 默认取最后一段 + # return app_id.split('.')[-1] + + def list_available_icons(self, app_id: Optional[str] = None) -> List[str]: + """ + 列出可用的图标 + + Args: + app_id: 应用ID,用于筛选特定应用的图标 + + Returns: + 可用图标名称列表 + """ + icons = set() + + search_paths = [] + if app_id: + app_name = self._extract_app_name(app_id) + for base_path in 
self.base_paths: + app_path = base_path / app_name + if app_path.exists(): + search_paths.append(app_path) + + if not search_paths: + search_paths = [p for p in self.base_paths if p.exists()] + + for search_path in search_paths: + for icon_file in search_path.rglob('*'): + if icon_file.is_file() and icon_file.suffix.lower() in ['.png', '.jpg', '.jpeg', '.bmp']: + # 移除扩展名作为图标名称 + icon_name = icon_file.stem + icons.add(icon_name) + + return sorted(list(icons)) diff --git a/MobiFlow/tools/app_trajectory_analyzer/.gitignore b/MobiFlow/tools/app_trajectory_analyzer/.gitignore new file mode 100644 index 0000000..bd041a2 --- /dev/null +++ b/MobiFlow/tools/app_trajectory_analyzer/.gitignore @@ -0,0 +1,2 @@ +src/model/ +src/**/__pycache__/ \ No newline at end of file diff --git a/MobiFlow/tools/app_trajectory_analyzer/download_model.py b/MobiFlow/tools/app_trajectory_analyzer/download_model.py new file mode 100644 index 0000000..ff0798b --- /dev/null +++ b/MobiFlow/tools/app_trajectory_analyzer/download_model.py @@ -0,0 +1,10 @@ +import os +os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" + +from huggingface_hub import snapshot_download + +snapshot_download( + repo_id="google/owlvit-base-patch32", + local_dir="./owlvit-base-patch32", + local_dir_use_symlinks=False +) \ No newline at end of file diff --git a/MobiFlow/tools/app_trajectory_analyzer/requirements.txt b/MobiFlow/tools/app_trajectory_analyzer/requirements.txt new file mode 100644 index 0000000..acbff2f --- /dev/null +++ b/MobiFlow/tools/app_trajectory_analyzer/requirements.txt @@ -0,0 +1,13 @@ +opencv-python>=4.8.0 +numpy>=1.24.0 +pillow>=10.0.0 +pytesseract>=0.3.10 +scikit-image>=0.22.0 +rapidfuzz>=3.9.1 +pyyaml>=6.0.0 +# Optional, for better OCR +paddleocr>=2.7.0.3 +# Optional, for open-vocabulary icon detection (OWL-ViT) +transformers>=4.41.0 +# torch is environment-specific; prefer installing separately (CPU): +# torch>=2.2.0 ; extra-index-url https://download.pytorch.org/whl/cpu diff --git 
a/MobiFlow/tools/app_trajectory_analyzer/src/analyzer/__init__.py b/MobiFlow/tools/app_trajectory_analyzer/src/analyzer/__init__.py
new file mode 100644
index 0000000..c5310f4
--- /dev/null
+++ b/MobiFlow/tools/app_trajectory_analyzer/src/analyzer/__init__.py
@@ -0,0 +1,5 @@
+__all__ = [
+    "ocr_engine",
+    "vision",
+    "rules",
+]
diff --git a/MobiFlow/tools/app_trajectory_analyzer/src/analyzer/ocr_engine.py b/MobiFlow/tools/app_trajectory_analyzer/src/analyzer/ocr_engine.py
new file mode 100644
index 0000000..9fc81e1
--- /dev/null
+++ b/MobiFlow/tools/app_trajectory_analyzer/src/analyzer/ocr_engine.py
@@ -0,0 +1,462 @@
+from __future__ import annotations
+import os
+from dataclasses import dataclass
+from typing import List, Tuple, Optional, Any
+
+from PIL import Image
+
+# Try to import the shared logging system (if available).
+try:
+    import sys
+    from pathlib import Path
+    # Add the avdag package root to sys.path so its logger can be imported.
+    current_dir = Path(__file__).resolve()
+    avdag_path = current_dir.parent.parent.parent.parent / "avdag"
+    if avdag_path.exists():
+        sys.path.insert(0, str(avdag_path.parent))
+        from avdag.logger import get_ocr_logger
+        _has_logger = True
+    else:
+        _has_logger = False
+except ImportError:
+    _has_logger = False
+
+def _get_logger():
+    """Return the OCR logger, or None when logging is unavailable."""
+    if _has_logger:
+        try:
+            return get_ocr_logger()
+        except Exception:
+            pass
+    return None
+
+def _log_info(msg: str):
+    """Log an info-level message (falls back to print when no logger)."""
+    logger = _get_logger()
+    if logger:
+        logger.info(msg)
+    else:
+        print(f"[OCR] {msg}")
+
+def _log_warning(msg: str):
+    """Log a warning-level message (falls back to print when no logger)."""
+    logger = _get_logger()
+    if logger:
+        logger.warning(msg)
+    else:
+        print(f"[OCR] {msg}")
+
+def _log_error(msg: str):
+    """Log an error-level message (falls back to print when no logger)."""
+    logger = _get_logger()
+    if logger:
+        logger.error(msg)
+    else:
+        print(f"[OCR] {msg}")
+
+def _log_debug(msg: str):
+    """Log a debug-level message (falls back to print when no logger)."""
+    logger = _get_logger()
+    if logger:
+        logger.debug(msg)
+    else:
+        print(f"[OCR] {msg}")
+
+try:
+    from paddleocr import PaddleOCR  # type: ignore
+    import paddle
+
+    _has_paddle = True
+except Exception:  # pragma: no cover
+    PaddleOCR = 
None + _has_paddle = False + +try: + import pytesseract # type: ignore + _has_tesseract = True + + # 检测Tesseract是否正确安装 + def _check_tesseract_installation(): + try: + # 尝试获取Tesseract版本信息 + version = pytesseract.get_tesseract_version() + _log_info(f"检测到Tesseract版本: {version}") + return True + except Exception as e: + _log_error(f"Tesseract未正确安装或配置: {e}") + # 尝试自动配置Tesseract路径(Windows) + if os.name == 'nt': # Windows + possible_paths = [ + r"C:\Program Files\Tesseract-OCR\tesseract.exe", + r"C:\Program Files (x86)\Tesseract-OCR\tesseract.exe", + r"D:\Program Files\Tesseract-OCR\tesseract.exe", + r"D:\Program Files (x86)\Tesseract-OCR\tesseract.exe" + ] + for path in possible_paths: + if os.path.exists(path): + pytesseract.pytesseract.tesseract_cmd = path + _log_info(f"设置Tesseract路径: {path}") + try: + version = pytesseract.get_tesseract_version() + _log_info(f"Tesseract配置成功,版本: {version}") + return True + except Exception: + continue + return False + + _has_tesseract = _check_tesseract_installation() + +except Exception: # pragma: no cover + pytesseract = None # type: ignore + _has_tesseract = False + +# 全局PaddleOCR实例缓存,避免重复初始化和下载模型 +_global_paddle_instance = None + + +@dataclass +class OCRWord: + """OCR识别的单个词语结果""" + text: str # 识别的文字内容 + bbox: Tuple[int, int, int, int] # 边界框坐标 (x1, y1, x2, y2) + conf: float # 置信度分数 + + +@dataclass +class OCRResult: + """OCR识别的完整结果""" + words: List[OCRWord] # 所有识别出的词语列表 + + def get_text(self) -> str: + """获取所有文字内容的拼接字符串""" + return " ".join([w.text for w in self.words]) + + def find(self, keyword: str, fuzzy: bool = True) -> bool: + """ + 在OCR结果中查找关键词 + + Args: + keyword: 要查找的关键词 + fuzzy: 是否使用模糊匹配 + + Returns: + 是否找到关键词 + """ + text = self.get_text() + if not fuzzy: + return keyword in text + try: + from rapidfuzz import fuzz # type: ignore + return fuzz.partial_ratio(keyword, text) >= 80 + except Exception: + return keyword in text + + +class OCREngine: + """OCR文字识别引擎,支持Tesseract和PaddleOCR""" + + def __init__(self, lang: str = 
"chi_sim+eng", use_paddle: Optional[bool] = None): + """ + 初始化OCR引擎 + + Args: + lang: Tesseract的语言设置,默认中英文混合 + use_paddle: 是否使用PaddleOCR,None表示自动选择 + """ + self.lang = lang + if use_paddle is None: + use_paddle = _has_paddle + self.use_paddle = use_paddle + # self.use_paddle = False # 强制关闭PaddleOCR,避免版本兼容问题 + self._paddle: Optional[Any] = None + + # 使用全局单例来避免重复初始化PaddleOCR + if self.use_paddle and _has_paddle: + self._paddle = self._get_paddle_instance() + + def _get_paddle_instance(self) -> Optional[Any]: + """获取全局PaddleOCR实例,如果不存在则创建""" + global _global_paddle_instance + if _global_paddle_instance is None: + try: + # 判断 Paddle 是否编译了 CUDA + use_gpu = paddle.device.is_compiled_with_cuda() + device = "gpu" if use_gpu else "cpu" + paddle.set_device(device) + + _log_info(f"正在初始化PaddleOCR实例(设备: {device.upper()},首次使用需要下载模型)...") + + # 注意:3.1.1版本不能同时设置 use_angle_cls 和 use_textline_orientation + _global_paddle_instance = PaddleOCR( + lang="ch", # 中文模型 + use_textline_orientation=True # 推荐启用角度分类 + ) + _log_info(f"PaddleOCR初始化成功(使用{device.upper()})") + + except Exception as e: + _log_error(f"PaddleOCR初始化失败: {e}") + try: + _log_info("尝试使用默认参数初始化PaddleOCR(CPU)...") + paddle.set_device("cpu") + _global_paddle_instance = PaddleOCR(lang="ch") + _log_info("PaddleOCR默认参数初始化成功(CPU)") + except Exception as e2: + _log_error(f"PaddleOCR默认参数初始化也失败: {e2}") + _global_paddle_instance = None + return _global_paddle_instance + + def _to_pil(self, img: Any) -> Image.Image: + """将输入图像转换为PIL图像对象""" + if isinstance(img, str): + return Image.open(img).convert("RGB") + # 可选的numpy数组支持 + try: + import numpy as np # type: ignore + if isinstance(img, np.ndarray): + from PIL import Image as _Image + return _Image.fromarray(img) + except Exception: + pass + if isinstance(img, Image.Image): + return img + raise TypeError("不支持的图像类型") + + def _resize_image_if_needed(self, img: Any, max_side: int = 4000) -> Any: + """ + 如果图像尺寸超过最大边长限制,则缩放图像 + + Args: + img: 输入图像(可以是路径、PIL图像或numpy数组) + max_side: 最大边长限制 + + 
Returns: + 处理后的图像(保持原始类型) + """ + # 如果是字符串路径,先转换为PIL图像进行尺寸检查 + if isinstance(img, str): + try: + pil_img = Image.open(img).convert("RGB") + w, h = pil_img.size + if max(w, h) <= max_side: + return img # 尺寸合适,返回原始路径 + + # 需要缩放,计算新尺寸 + scale = max_side / max(w, h) + new_w = int(w * scale) + new_h = int(h * scale) + + resized_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS) + _log_debug(f"图像尺寸从 {w}x{h} 缩放到 {new_w}x{new_h}") + + # 转换为numpy数组返回(PaddleOCR支持) + import numpy as np + return np.array(resized_img) + except Exception as e: + _log_error(f"图像缩放失败: {e}") + return img + + # 如果是PIL图像 + if isinstance(img, Image.Image): + w, h = img.size + if max(w, h) <= max_side: + return img + + scale = max_side / max(w, h) + new_w = int(w * scale) + new_h = int(h * scale) + resized_img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + _log_debug(f"图像尺寸从 {w}x{h} 缩放到 {new_w}x{new_h}") + return resized_img + + # 如果是numpy数组 + try: + import numpy as np + if isinstance(img, np.ndarray): + h, w = img.shape[:2] + if max(w, h) <= max_side: + return img + + scale = max_side / max(w, h) + new_w = int(w * scale) + new_h = int(h * scale) + + # 转换为PIL进行缩放,然后转回numpy + pil_img = Image.fromarray(img) + resized_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS) + _log_debug(f"图像尺寸从 {w}x{h} 缩放到 {new_w}x{new_h}") + return np.array(resized_img) + except Exception: + pass + + return img + + def _enhance_image_for_tesseract(self, img: Image.Image) -> Image.Image: + """ + 为Tesseract优化图像质量 + + Args: + img: PIL图像 + + Returns: + 增强后的PIL图像 + """ + try: + # 转换为灰度图 + if img.mode != 'L': + img = img.convert('L') + + # 增加对比度 + from PIL import ImageEnhance + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(1.5) + + # 锐化 + from PIL import ImageFilter + img = img.filter(ImageFilter.SHARPEN) + + # 如果图像太小,放大一些(Tesseract对小字识别较差) + w, h = img.size + if min(w, h) < 100: + scale = 200 / min(w, h) + new_w = int(w * scale) + new_h = int(h * scale) + img = img.resize((new_w, 
new_h), Image.Resampling.LANCZOS) + _log_debug(f"为Tesseract放大小图像: {w}x{h} -> {new_w}x{new_h}") + + return img + except Exception as e: + _log_error(f"图像增强失败: {e}") + return img + + def run(self, img: Any) -> OCRResult: + """ + 运行OCR识别 + + Args: + img: 输入图像,可以是文件路径、PIL图像或numpy数组 + + Returns: + OCR识别结果 + """ + # 1) PaddleOCR路径 - 只支持str路径和numpy数组 + if self._paddle is not None: + try: + import numpy as np # type: ignore + _log_debug("尝试使用PaddleOCR识别") + # 预处理图像尺寸 + processed_img = self._resize_image_if_needed(img, max_side=4000) + + # 准备PaddleOCR的输入:str或numpy数组 + paddle_input = None + if isinstance(processed_img, str): + paddle_input = processed_img + else: + # 转换为numpy数组 + if isinstance(processed_img, np.ndarray): + paddle_input = processed_img + else: + pil = self._to_pil(processed_img) + paddle_input = np.array(pil) + + # 尝试新的predict API + try: + results = self._paddle.predict(paddle_input) + if not results or len(results) == 0: + _log_warning("PaddleOCR predict返回空结果") + return OCRResult(words=[]) + + result_data = results[0] + if not isinstance(result_data, dict): + _log_warning("PaddleOCR predict返回格式异常") + return OCRResult(words=[]) + + texts = result_data.get("rec_texts", []) + scores = result_data.get("rec_scores", []) + bboxes = result_data.get("det_polygons", []) + + words: List[OCRWord] = [] + for i, (text, score) in enumerate(zip(texts, scores)): + if i < len(bboxes): + box = bboxes[i] + x1 = int(min(p[0] for p in box)) + y1 = int(min(p[1] for p in box)) + x2 = int(max(p[0] for p in box)) + y2 = int(max(p[1] for p in box)) + else: + x1, y1, x2, y2 = 0, 0, 100, 20 + words.append(OCRWord(text=text, bbox=(x1, y1, x2, y2), conf=float(score))) + return OCRResult(words=words) + except (AttributeError, KeyError): + # 回退到旧的ocr API + res = self._paddle.ocr(paddle_input, cls=True) + words: List[OCRWord] = [] + if res and res[0]: + for line in res[0]: + box = line[0] + x1 = int(min(p[0] for p in box)) + y1 = int(min(p[1] for p in box)) + x2 = int(max(p[0] for p in 
box)) + y2 = int(max(p[1] for p in box)) + text = line[1][0] + conf = float(line[1][1]) if line[1][1] is not None else 0.0 + words.append(OCRWord(text=text, bbox=(x1, y1, x2, y2), conf=conf)) + return OCRResult(words=words) + except Exception as e: + _log_error(f"PaddleOCR识别失败: {e}") + _log_warning("该类型图片被处理后,不支持PaddleOCR") + pass + + # 2) Tesseract路径 - 需要PIL图像 + if _has_tesseract and pytesseract is not None: + try: + _log_debug(f"尝试使用Tesseract识别,语言设置: {self.lang}") + # 预处理图像尺寸 + processed_img = self._resize_image_if_needed(img, max_side=4000) + pil = self._to_pil(processed_img) + + # 增强图像质量以提高识别率 + enhanced_img = self._enhance_image_for_tesseract(pil) + + # 使用image_to_data获取详细信息 + data = pytesseract.image_to_data(enhanced_img, lang=self.lang, output_type=pytesseract.Output.DICT) + words: List[OCRWord] = [] + + if not data or not data.get("text"): + _log_warning("Tesseract未识别到任何文字") + return OCRResult(words=[]) + + n = len(data.get("text", [])) + recognized_count = 0 + for i in range(n): + txt = (data["text"][i] or "").strip() + if not txt: + continue + + try: + conf = float(data.get("conf", [0])[i]) + except Exception: + conf = 0.0 + + # 过滤置信度过低的结果 + if conf < 30: # Tesseract置信度阈值 + continue + + try: + x = int(data.get("left", [0])[i]) + y = int(data.get("top", [0])[i]) + w = int(data.get("width", [0])[i]) + h = int(data.get("height", [0])[i]) + except Exception: + x, y, w, h = 0, 0, 100, 20 + + words.append(OCRWord(text=txt, bbox=(x, y, x + w, y + h), conf=conf)) + recognized_count += 1 + + _log_debug(f"Tesseract识别成功,识别到 {recognized_count} 个文字片段") + return OCRResult(words=words) + except Exception as e: + _log_error(f"Tesseract识别失败: {e}") + pass + + # 3) 回退空结果 + return OCRResult(words=[]) diff --git a/MobiFlow/tools/app_trajectory_analyzer/src/analyzer/rules.py b/MobiFlow/tools/app_trajectory_analyzer/src/analyzer/rules.py new file mode 100644 index 0000000..1ada2d8 --- /dev/null +++ b/MobiFlow/tools/app_trajectory_analyzer/src/analyzer/rules.py @@ -0,0 +1,95 @@ 
+from __future__ import annotations
+from dataclasses import dataclass
+from typing import List, Dict, Optional
+import os, glob
+
+import yaml  # type: ignore
+
+
+@dataclass
+class StepRule:
+    """Rule describing a single expected step in a task."""
+    name: str  # step name
+    must_have_keywords: List[str]  # keywords that must all be present
+    any_of_keywords: Optional[List[str]] = None  # at least one of these keywords
+    forbidden_keywords: Optional[List[str]] = None  # keywords that must NOT appear
+    # icon constraints
+    all_of_icons: Optional[List[str]] = None  # icons that must all be present
+    any_of_icons: Optional[List[str]] = None  # at least one of these icons
+
+
+@dataclass
+class TaskRule:
+    """Rule describing a complete task as an ordered list of steps."""
+    task_name: str  # task identifier
+    steps: List[StepRule]  # ordered step list
+    min_actions: int = 0  # minimum required number of actions
+
+
+def load_tasks_from_dir(rules_dir: str) -> Dict[str, TaskRule]:
+    """
+    Load all task rules from a directory of YAML files.
+
+    Args:
+        rules_dir: path of the directory containing rule files
+
+    Returns:
+        Mapping from task name to its TaskRule object
+    """
+    tasks: Dict[str, TaskRule] = {}
+    for path in glob.glob(os.path.join(rules_dir, "*.y*ml")):
+        try:
+            with open(path, "r", encoding="utf-8") as f:
+                data = yaml.safe_load(f)
+            if not data:
+                continue
+
+            # Allow a single task object by wrapping it in a list.
+            if isinstance(data, dict) and "task_name" in data:
+                data = [data]
+            if not isinstance(data, list):
+                continue
+
+            # Parse each task definition.
+            for td in data:
+                name = td.get("task_name")
+                steps_raw = td.get("steps", [])
+                min_actions = int(td.get("min_actions", 0))
+                steps: List[StepRule] = []
+
+                for s in steps_raw:
+                    steps.append(StepRule(
+                        name=s.get("name"),
+                        must_have_keywords=list(s.get("must_have_keywords", [])),
+                        any_of_keywords=list(s.get("any_of_keywords", [])) if s.get("any_of_keywords") else None,
+                        forbidden_keywords=list(s.get("forbidden_keywords", [])) if s.get("forbidden_keywords") else None,
+                        all_of_icons=list(s.get("all_of_icons", [])) if s.get("all_of_icons") else None,
+                        any_of_icons=list(s.get("any_of_icons", [])) if s.get("any_of_icons") else None,
+                    ))
+
+                task = TaskRule(task_name=name, steps=steps, min_actions=min_actions)
+                tasks[name] = task
+        except Exception:
+            # Skip files that cannot be parsed.
+            continue
+    return tasks
+
+
+def 
load_task_by_name(rules_dir: str, task_name: str) -> TaskRule:
+    """
+    Load one task rule by its task name.
+
+    Args:
+        rules_dir: path of the directory containing rule files
+        task_name: name of the task to load
+
+    Returns:
+        The matching TaskRule object
+
+    Raises:
+        ValueError: if the task is not found in the rules directory
+    """
+    tasks = load_tasks_from_dir(rules_dir)
+    if task_name not in tasks:
+        raise ValueError(f"任务 '{task_name}' 在规则目录中未找到: {rules_dir}")
+    return tasks[task_name]
diff --git a/MobiFlow/tools/app_trajectory_analyzer/src/analyzer/vision.py b/MobiFlow/tools/app_trajectory_analyzer/src/analyzer/vision.py
new file mode 100644
index 0000000..5f7887f
--- /dev/null
+++ b/MobiFlow/tools/app_trajectory_analyzer/src/analyzer/vision.py
@@ -0,0 +1,451 @@
+from __future__ import annotations
+import os
+from dataclasses import dataclass
+from typing import List, Tuple, Optional, Dict, Any, Iterable, Set
+
+# Try to import the shared logging system.
+try:
+    import sys
+    from pathlib import Path
+    current_dir = Path(__file__).resolve()
+    avdag_path = current_dir.parent.parent.parent.parent / "avdag"
+    if avdag_path.exists():
+        sys.path.insert(0, str(avdag_path.parent))
+        from avdag.logger import get_logger
+        _logger = get_logger("vision")
+        _has_logger = True
+    else:
+        _has_logger = False
+except ImportError:
+    _has_logger = False
+
+def _log_debug(msg: str):
+    # Debug log helper; falls back to print when no logger is available.
+    if _has_logger:
+        _logger.debug(msg)
+    else:
+        print(f"[Vision] {msg}")
+
+try:
+    import cv2  # type: ignore
+    _HAS_CV2 = True
+except Exception:  # pragma: no cover
+    cv2 = None  # type: ignore
+    _HAS_CV2 = False
+try:
+    import numpy as np  # type: ignore
+    _HAS_NP = True
+except Exception:  # pragma: no cover
+    np = None  # type: ignore
+    _HAS_NP = False
+from PIL import Image, ImageChops, ImageStat
+
+from .ocr_engine import OCREngine, OCRResult
+
+
+@dataclass
+class Detection:
+    """Single detection result."""
+    name: str  # name of the detected target
+    bbox: Tuple[int, int, int, int]  # bounding box (x1, y1, x2, y2)
+    score: float  # detection confidence score
+
+
+def imread(path: str):
+    """
+    Read an image file.
+
+    Args:
+        path: image file path
+
+    Returns:
+        Image object (OpenCV array or PIL image, depending on availability)
+
+    Raises:
+        FileNotFoundError: if the file does not exist
+    """
+    if _HAS_CV2:
+        img = 
cv2.imread(path, cv2.IMREAD_COLOR) + if img is None: + raise FileNotFoundError(path) + return img + # 回退到PIL + pil = Image.open(path).convert("RGB") + return pil + + +def to_pil(img: Any) -> Image.Image: + """ + 将图像转换为PIL格式 + + Args: + img: 输入图像(OpenCV、PIL或numpy数组) + + Returns: + PIL图像对象 + + Raises: + TypeError: 不支持的图像类型 + """ + if isinstance(img, Image.Image): + return img + if _HAS_CV2: + return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + # 如果已经是numpy数组(RGB),尝试直接转换 + try: + import numpy as _np # type: ignore + if isinstance(img, _np.ndarray): + return Image.fromarray(img) + except Exception: + pass + raise TypeError("to_pil不支持的图像类型") + + +def ocr_on_image(img_or_path: Any, engine: OCREngine) -> OCRResult: + """ + 对图像执行OCR识别 + + Args: + img_or_path: 图像或图像路径 + engine: OCR引擎实例 + + Returns: + OCR识别结果 + """ + # 直接传递给OCR引擎,让引擎内部处理格式转换 + return engine.run(img_or_path) + + +def find_keywords(ocr: OCRResult, keywords: Iterable[str], fuzzy: bool = True) -> Set[str]: + """ + 在OCR结果中查找关键词 + + Args: + ocr: OCR识别结果 + keywords: 要查找的关键词列表 + fuzzy: 是否使用模糊匹配 + + Returns: + 找到的关键词集合 + """ + found: Set[str] = set() + text = ocr.get_text() + for k in keywords: + if not k: + continue + if ocr.find(k, fuzzy=fuzzy): + found.add(k) + return found + + +def template_match(img: Any, template: Any, threshold: float = 0.85) -> List[Detection]: + """ + 模板匹配检测 + + Args: + img: 目标图像 + template: 模板图像 + threshold: 匹配阈值 + + Returns: + 检测结果列表 + """ + if not _HAS_CV2 or not _HAS_NP: + return [] + ih, iw = img.shape[:2] + th, tw = template.shape[:2] + if ih < th or iw < tw: + return [] + res = cv2.matchTemplate(img, template, cv2.TM_CCOEFF_NORMED) + ys, xs = np.where(res >= threshold) + dets: List[Detection] = [] + for (x, y) in zip(xs, ys): + dets.append(Detection(name="template", bbox=(int(x), int(y), int(x+tw), int(y+th)), score=float(res[y, x]))) + + # 非极大值抑制(简单版本) + dets_sorted = sorted(dets, key=lambda d: d.score, reverse=True) + kept: List[Detection] = [] + + def iou(a: 
Tuple[int,int,int,int], b: Tuple[int,int,int,int]) -> float: + """计算两个边界框的交并比""" + ax1, ay1, ax2, ay2 = a + bx1, by1, bx2, by2 = b + inter_x1, inter_y1 = max(ax1, bx1), max(ay1, by1) + inter_x2, inter_y2 = min(ax2, bx2), min(ay2, by2) + if inter_x2 <= inter_x1 or inter_y2 <= inter_y1: + return 0.0 + inter = (inter_x2 - inter_x1) * (inter_y2 - inter_y1) + area_a = (ax2 - ax1) * (ay2 - ay1) + area_b = (bx2 - bx1) * (by2 - by1) + return inter / max(area_a + area_b - inter, 1e-6) + + for d in dets_sorted: + if all(iou(d.bbox, k.bbox) < 0.5 for k in kept): + kept.append(d) + return kept + + +def _load_icon_template(path: str): + """加载图标模板文件""" + if _HAS_CV2: + t = cv2.imread(path, cv2.IMREAD_COLOR) + return t + # 回退:作为PIL加载,在匹配时禁用模板匹配 + return Image.open(path).convert('RGB') + + +def build_icon_bank(icons_dir: str) -> Dict[str, Any]: + """ + 从目录中加载图标模板 + + Args: + icons_dir: 图标模板目录路径 + + Returns: + 图标模板字典,键为文件名(不含扩展名),值为模板图像 + + Note: + 接受 搜索.png, 购物车.jpg 等文件,键为文件名(不含扩展名) + """ + if not icons_dir or not os.path.isdir(icons_dir): + return {} + bank: Dict[str, Any] = {} + for fn in os.listdir(icons_dir): + name, ext = os.path.splitext(fn) + if ext.lower() not in {'.png', '.jpg', '.jpeg'}: + continue + try: + bank[name] = _load_icon_template(os.path.join(icons_dir, fn)) + except Exception: + continue + return bank + + +def detect_icons_in_image(img: Any, icon_bank: Dict[str, Any], threshold: float = 0.85) -> Set[str]: + """ + 在图像中检测图标 + + Args: + img: 目标图像 + icon_bank: 图标模板字典 + threshold: 检测阈值 + + Returns: + 检测到的图标名称集合 + """ + if not icon_bank: + return set() + present: Set[str] = set() + if not _HAS_CV2 or not _HAS_NP: + return present + base = img if not isinstance(img, Image.Image) else cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + for name, tmpl in icon_bank.items(): + _log_debug(f"检测图标: {name}") + if tmpl is None: + continue + # 检查PIL图像的有效性 + if isinstance(tmpl, Image.Image) and (tmpl.size[0] < 10 or tmpl.size[1] < 10): + continue # 跳过无效模板 + # 检查numpy数组的有效性 + if 
_HAS_NP and hasattr(tmpl, 'shape') and (tmpl.shape[0] < 10 or tmpl.shape[1] < 10): + continue # 跳过无效模板 + dets = template_match(base, tmpl, threshold=threshold) + if dets: + present.add(name) + return present + + +def red_badge_score(img: Any) -> float: + """ + 计算红色徽章分数的启发式算法(如购物车徽章) + + 专注于右上角区域检测红色/橙色圆点。 + 仅使用PIL操作,无需cv2/numpy。 + + Args: + img: 输入图像 + + Returns: + 归一化分数 [0,1] + """ + pil = to_pil(img) + w, h = pil.size + # 专注于右上角区域(状态栏+头部) + box = (int(w*0.7), int(0), w, int(h*0.2)) + region = pil.crop(box) + region = region.resize((200, 120)) + px = region.load() + red_count = 0 + total = region.size[0] * region.size[1] + for y in range(region.size[1]): + for x in range(region.size[0]): + r, g, b = px[x, y] + if r >= 180 and g <= 100 and b <= 100 and (r - max(g, b)) >= 50: + red_count += 1 + return red_count / max(total, 1) + + +def search_box_score(img: Any) -> float: + """ + 计算搜索框分数的启发式算法 + + 检测顶部区域的大型亮色水平条带(典型的搜索/文本输入框)。 + 仅使用PIL操作。 + + Args: + img: 输入图像 + + Returns: + 分数 [0,1] + """ + pil = to_pil(img).convert('L') + w, h = pil.size + # 专注于顶部中心区域,避免极端边缘 + left = int(w * 0.05) + right = int(w * 0.95) + top = int(h * 0.02) + bottom = int(h * 0.25) + region = pil.crop((left, top, right, bottom)) + region = region.resize((200, 120)) + px = region.load() + W, H = region.size + + # 预计算每行的亮度掩码 + bright_thresh = 220 + row_bright_frac = [] + for y in range(H): + bright = 0 + for x in range(W): + if px[x, y] >= bright_thresh: + bright += 1 + row_bright_frac.append(bright / max(W, 1)) + + # 找到高亮度覆盖率的最佳连续条带 + best = 0.0 + y = 0 + while y < H: + if row_bright_frac[y] < 0.55: # 当行相当亮时开始 + y += 1 + continue + y0 = y + while y < H and row_bright_frac[y] >= 0.5: + y += 1 + y1 = y # 排他性 + height = y1 - y0 + if 6 <= height <= 40: # 调整大小后的合理框高度 + avg_cov = sum(row_bright_frac[y0:y1]) / max(height, 1) + score = avg_cov * min(1.0, height / 25.0) + if score > best: + best = score + return float(best) + + +def ssim(a: Any, b: Any) -> float: + """ + 计算两幅图像的结构相似性指数 + + Args: + a, b: 
两幅待比较的图像 + + Returns: + SSIM值 [0,1],值越高表示越相似 + """ + # 优先使用skimage+cv2,回退到基于PIL的近似 + if _HAS_CV2 and _HAS_NP: + try: + from skimage.metrics import structural_similarity as ssim_fn # type: ignore + a_arr = a if isinstance(a, type(getattr(a, 'shape', None))) else a + b_arr = b if isinstance(b, type(getattr(b, 'shape', None))) else b + if isinstance(a, Image.Image): + a_arr = cv2.cvtColor(np.array(a), cv2.COLOR_RGB2GRAY) + else: + a_arr = cv2.cvtColor(a, cv2.COLOR_BGR2GRAY) + if isinstance(b, Image.Image): + b_arr = cv2.cvtColor(np.array(b), cv2.COLOR_RGB2GRAY) + else: + b_arr = cv2.cvtColor(b, cv2.COLOR_BGR2GRAY) + score, _ = ssim_fn(a_arr, b_arr, full=True) + return float(score) + except Exception: + pass + + # PIL回退:归一化平均绝对差异相似性 + a_pil = to_pil(a) + b_pil = to_pil(b) + a_pil = a_pil.convert('L').resize((300, 600)) + b_pil = b_pil.convert('L').resize((300, 600)) + diff = ImageChops.difference(a_pil, b_pil) + stat = ImageStat.Stat(diff) + mad = stat.mean[0] / 255.0 + return 1.0 - mad # 粗略近似 + + +def estimate_scroll_direction(prev: Any, curr: Any) -> Optional[str]: + """ + 估计滚动方向 + + Args: + prev: 前一帧图像 + curr: 当前帧图像 + + Returns: + 滚动方向:"UP"、"DOWN" 或 None(无明显滚动) + """ + if _HAS_CV2 and _HAS_NP: + try: + # 启发式:计算密集ORB匹配和中位数dy + orb = cv2.ORB_create(nfeatures=1000) + prev_gray = cv2.cvtColor(prev, cv2.COLOR_BGR2GRAY) if not isinstance(prev, Image.Image) else cv2.cvtColor(np.array(prev), cv2.COLOR_RGB2GRAY) + curr_gray = cv2.cvtColor(curr, cv2.COLOR_BGR2GRAY) if not isinstance(curr, Image.Image) else cv2.cvtColor(np.array(curr), cv2.COLOR_RGB2GRAY) + kp1, des1 = orb.detectAndCompute(prev_gray, None) + kp2, des2 = orb.detectAndCompute(curr_gray, None) + if des1 is None or des2 is None: + return None + bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True) + matches = bf.match(des1, des2) + if not matches: + return None + dys = [] + for m in matches: + y1 = kp1[m.queryIdx].pt[1] + y2 = kp2[m.trainIdx].pt[1] + dys.append(y2 - y1) + if not dys: + return None + from 
statistics import median + med_dy = float(median(dys)) + if abs(med_dy) < 2.0: + return None + return "UP" if med_dy < 0 else "DOWN" + except Exception: + return None + + # 回退:比较上下三分之一区域的平均强度差异 + try: + a = to_pil(prev).convert('L').resize((300, 600)) + b = to_pil(curr).convert('L').resize((300, 600)) + top_diff = ImageChops.difference(a.crop((0,0,300,200)), b.crop((0,0,300,200))) + bot_diff = ImageChops.difference(a.crop((0,400,300,600)), b.crop((0,400,300,600))) + top_mean = ImageStat.Stat(top_diff).mean[0] + bot_mean = ImageStat.Stat(bot_diff).mean[0] + if abs(top_mean - bot_mean) < 1.0: + return None + return "UP" if top_mean > bot_mean else "DOWN" + except Exception: + return None + + +# 键盘识别的常见标记 +KEYBOARD_TOKENS = {"空格", "拼音", "英文", "123", "ABC", "符"} + +def has_keyboard(ocr: OCRResult) -> bool: + """ + 检测OCR结果中是否包含键盘相关文字 + + Args: + ocr: OCR识别结果 + + Returns: + 是否检测到键盘 + """ + text = ocr.get_text() + return any(tok in text for tok in KEYBOARD_TOKENS) diff --git a/MobiFlow/universal_test_runner.py b/MobiFlow/universal_test_runner.py new file mode 100644 index 0000000..db45f9d --- /dev/null +++ b/MobiFlow/universal_test_runner.py @@ -0,0 +1,723 @@ +#!/usr/bin/env python3 +""" +通用任务测试执行入口 +支持通过 task.json 配置文件灵活配置不同任务的测试 + +使用方法: +- python universal_test_runner.py task_configs/taobao.json # 测试所有类型 +- python universal_test_runner.py task_configs/taobao.json type3 # 测试指定类型 +- python universal_test_runner.py task_configs/taobao.json type3:150 # 测试指定类型的指定trace +- python universal_test_runner.py task_configs/taobao.json 150,151,152 # 测试指定的trace编号 +""" + +import os +import sys +import json +import time +from pathlib import Path +from datetime import datetime +from typing import Dict, List, Union, Any, Optional +from dataclasses import dataclass, asdict + +try: + import llmconfig +except ImportError: + print("警告: 无法导入 llmconfig,将使用默认配置") + class MockLLMConfig: + API_KEY = "your key" + BASE_URL = "your url" + MODEL = "gemini-2.5-pro-preview-06-05" + llmconfig = 
MockLLMConfig() + +from avdag.verifier import verify_task_folder, VerifierOptions, make_llm_options +from avdag.ocr_processor import create_standard_ocr_functions +from avdag.logger import configure_logging + +@dataclass +class TestResult: + """测试结果数据结构""" + trace_id: Union[str, int] + task_type: str + task_name: str + success: bool + score: float + matched_nodes: List[str] + reason: str + manual_review_needed: bool + execution_time: float + error_message: str = "" + +@dataclass +class TestSummary: + """测试汇总数据结构""" + task_name: str + total_tests: int + success_count: int + success_rate: float + total_execution_time: float + results_by_type: Dict[str, Dict[str, Any]] + all_results: List[TestResult] + +class UniversalTestRunner: + """通用测试执行器""" + + def __init__(self, config_file: str): + """初始化测试执行器 + + Args: + config_file: 任务配置文件路径 + """ + self.config_file = config_file + self.config = self._load_config() + # 使用当前工作目录作为基础目录,而不是配置文件所在目录 + self.base_dir = os.getcwd() + self.start_time = datetime.now() + + # 生成带时间戳的文件名 + timestamp = self.start_time.strftime("%Y%m%d_%H%M%S") + task_name = self.config['task_name'] + + # 配置日志 + log_config = self.config.get('logging', {}) + log_file = log_config.get('output_file', 'test-{task_name}-{timestamp}.log') + log_file = log_file.format(task_name=task_name, timestamp=timestamp) + + configure_logging( + level=log_config.get('level', 'DEBUG'), + use_colors=log_config.get('use_colors', True), + show_time=log_config.get('show_time', True), + show_module=log_config.get('show_module', True), + output_file=log_file + ) + + # 配置输出文件 + output_config = self.config.get('output', {}) + self.summary_file = output_config.get('summary_file', 'test-{task_name}-summary-{timestamp}.txt') + self.summary_file = self.summary_file.format(task_name=task_name, timestamp=timestamp) + + self.detailed_file = output_config.get('detailed_results_file', 'test-{task_name}-detailed-{timestamp}.json') + self.detailed_file = 
self.detailed_file.format(task_name=task_name, timestamp=timestamp) + + # 创建验证选项 + self.opts = self._create_verifier_options() + + print(f"=== 通用任务测试执行器 ===") + print(f"任务名称: {self.config['task_name']}") + print(f"任务描述: {self.config.get('description', 'N/A')}") + print(f"日志文件: {log_file}") + print(f"汇总文件: {self.summary_file}") + print(f"详细结果: {self.detailed_file}") + + def _load_config(self) -> Dict[str, Any]: + """加载配置文件""" + try: + with open(self.config_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + raise RuntimeError(f"无法加载配置文件 {self.config_file}: {str(e)}") + + def _create_verifier_options(self) -> VerifierOptions: + """创建验证选项""" + test_opts = self.config.get('test_options', {}) + + # 获取LLM配置 + api_key = getattr(llmconfig, 'API_KEY', None) + base_url = getattr(llmconfig, 'BASE_URL', None) + model = getattr(llmconfig, 'MODEL', None) + + print("=== LLM 配置 ===") + print(f"API_KEY: {api_key}") + print(f"BASE_URL: {base_url}") + print(f"MODEL: {model}") + + if test_opts.get('enable_ocr', True): + # 创建标准OCR函数 + ocr_func, texts_func = create_standard_ocr_functions() + else: + ocr_func = None + + if test_opts.get('enable_llm', True) and api_key and base_url: + # 创建带LLM的选项 + opts = make_llm_options( + api_key=api_key, + base_url=base_url, + model=model, + force_llm=test_opts.get('force_llm', False) + ) + if ocr_func: + opts.ocr = ocr_func + print("[验证] 已启用OCR+LLM验证模式") + else: + # 仅OCR模式 + opts = VerifierOptions(ocr=ocr_func) + print("[验证] 已启用纯OCR验证模式") + if not api_key or not base_url: + print("[警告] LLM_API_KEY/LLM_BASE_URL 未设置,已退化到 OCR-only 验证") + + # 应用其他选项 + if test_opts.get('ocr_frame_exclusive'): + opts.ocr_frame_exclusive = True + if test_opts.get('llm_frame_exclusive'): + opts.llm_frame_exclusive = True + if test_opts.get('prevent_frame_backtrack'): + opts.prevent_frame_backtrack = True + + return opts + + def _get_rule_file_path(self, task_type: str) -> str: + """获取规则文件完整路径""" + rules_base = self.config['rules_base_dir'] + 
rule_file = self.config['task_types'][task_type]['rule_file'] + return os.path.join(self.base_dir, rules_base, rule_file) + + def _get_data_path(self, task_type: str, trace_id: Union[str, int]) -> str: + """获取数据路径""" + data_base = self.config['data_base_dir'] + + # 如果trace_id包含路径分隔符,说明是层级结构(如 "type2/11") + if isinstance(trace_id, str) and '/' in trace_id: + return os.path.join(self.base_dir, data_base, trace_id) + # 如果trace_id是字符串(如"type2"),直接使用作为文件夹名 + elif isinstance(trace_id, str): + return os.path.join(self.base_dir, data_base, trace_id) + else: + # 如果是数字,需要在task_type目录下查找 + return os.path.join(self.base_dir, data_base, task_type, str(trace_id)) + + def _auto_discover_traces(self, task_type: str) -> List[Union[str, int]]: + """自动发现指定类型的所有trace""" + data_base = self.config['data_base_dir'] + base_path = os.path.join(self.base_dir, data_base) + + discovered_traces = [] + + # 首先尝试查找 type* 目录格式 (如 type2, type3) + type_dir = os.path.join(base_path, task_type) + if os.path.exists(type_dir) and os.path.isdir(type_dir): + print(f"[自动发现] 找到 type 目录: {type_dir}") + + # 扫描 type 目录下的子目录 + sub_traces = [] + for item in os.listdir(type_dir): + item_path = os.path.join(type_dir, item) + if os.path.isdir(item_path): + try: + # 尝试转换为数字 + trace_num = int(item) + sub_traces.append(f"{task_type}/{trace_num}") + except ValueError: + # 如果不是数字,也包含字符串格式的目录 + sub_traces.append(f"{task_type}/{item}") + + if sub_traces: + # 对数字进行排序 + numeric_traces = [t for t in sub_traces if t.split('/')[-1].isdigit()] + string_traces = [t for t in sub_traces if not t.split('/')[-1].isdigit()] + + # 按最后一部分(trace编号)排序 + numeric_traces.sort(key=lambda x: int(x.split('/')[-1])) + string_traces.sort() + + discovered_traces = numeric_traces + string_traces + print(f"[自动发现] 在 {task_type} 目录下发现 {len(discovered_traces)} 个子目录: {[t.split('/')[-1] for t in discovered_traces]}") + return discovered_traces + else: + # 如果 type 目录下没有子目录,则将 type 目录本身作为 trace + discovered_traces.append(task_type) + print(f"[自动发现] 
{task_type} 目录下无子目录,使用目录本身作为 trace") + return discovered_traces + + # 然后尝试查找数字目录格式 (如 150, 151, 152...) + if os.path.exists(base_path): + for item in os.listdir(base_path): + item_path = os.path.join(base_path, item) + if os.path.isdir(item_path): + try: + # 尝试转换为数字 + trace_num = int(item) + discovered_traces.append(trace_num) + except ValueError: + # 如果不是数字,也包含字符串格式的目录 + if item.startswith('type') or item == task_type: + discovered_traces.append(item) + + # 排序结果 + numeric_traces = [t for t in discovered_traces if isinstance(t, int)] + string_traces = [t for t in discovered_traces if isinstance(t, str)] + numeric_traces.sort() + string_traces.sort() + + final_traces = string_traces + numeric_traces + + if final_traces: + print(f"[自动发现] 类型 {task_type} 发现 {len(final_traces)} 个 traces: {final_traces}") + else: + print(f"[自动发现] 类型 {task_type} 未发现任何可用的 traces") + + return final_traces + + def _get_traces_for_type(self, task_type: str) -> List[Union[str, int]]: + """获取指定类型的所有trace""" + task_config = self.config['task_types'][task_type] + data_traces = task_config.get('data_traces') + + # 如果没有配置 data_traces 或配置为空,则自动发现 + if not data_traces: + print(f"[配置] 类型 {task_type} 未配置 data_traces,启用自动发现模式") + return self._auto_discover_traces(task_type) + + # 如果配置了 data_traces,则优先使用配置的值 + print(f"[配置] 类型 {task_type} 使用配置的 data_traces: {data_traces}") + + if isinstance(data_traces, str): + # 如果是字符串,检查对应目录是否存在 + data_path = self._get_data_path(task_type, data_traces) + if os.path.exists(data_path): + # 检查是否是目录,如果是目录且有子目录,则扫描子目录 + if os.path.isdir(data_path): + sub_traces = [] + for item in os.listdir(data_path): + item_path = os.path.join(data_path, item) + if os.path.isdir(item_path): + try: + # 尝试转换为数字 + trace_num = int(item) + sub_traces.append(f"{data_traces}/{trace_num}") + except ValueError: + # 如果不是数字,也包含字符串格式的目录 + sub_traces.append(f"{data_traces}/{item}") + + if sub_traces: + # 对数字进行排序 + numeric_traces = [t for t in sub_traces if t.split('/')[-1].isdigit()] + string_traces = 
[t for t in sub_traces if not t.split('/')[-1].isdigit()] + + # 按最后一部分(trace编号)排序 + numeric_traces.sort(key=lambda x: int(x.split('/')[-1])) + string_traces.sort() + + discovered_traces = numeric_traces + string_traces + print(f"[配置] 类型 {task_type} 在目录 {data_traces} 下发现 {len(discovered_traces)} 个子目录: {[t.split('/')[-1] for t in discovered_traces]}") + return discovered_traces + else: + # 如果没有子目录,则使用目录本身 + return [data_traces] + else: + return [data_traces] + else: + print(f"警告: 配置的数据路径不存在: {data_path},尝试自动发现") + return self._auto_discover_traces(task_type) + elif isinstance(data_traces, list): + # 如果是列表,返回所有存在的trace + valid_traces = [] + for trace in data_traces: + data_path = self._get_data_path(task_type, trace) + if os.path.exists(data_path): + valid_traces.append(trace) + else: + print(f"警告: trace {trace} 数据路径不存在: {data_path}") + + # 如果配置的traces都不存在,则尝试自动发现 + if not valid_traces: + print(f"警告: 配置的所有 traces 都不存在,尝试自动发现") + return self._auto_discover_traces(task_type) + + return valid_traces + else: + print(f"警告: 不支持的数据traces格式: {type(data_traces)},尝试自动发现") + return self._auto_discover_traces(task_type) + + def test_single_trace(self, task_type: str, trace_id: Union[str, int]) -> TestResult: + """测试单个trace""" + task_config = self.config['task_types'][task_type] + task_name = task_config['name'] + + start_time = time.time() + + try: + # 获取文件路径 + rule_file = self._get_rule_file_path(task_type) + data_path = self._get_data_path(task_type, trace_id) + + # 检查文件是否存在 + if not os.path.exists(rule_file): + error_msg = f"规则文件不存在: {rule_file}" + print(f"❌ {trace_id}: {error_msg}") + return TestResult( + trace_id=trace_id, task_type=task_type, task_name=task_name, + success=False, score=0.0, matched_nodes=[], reason=error_msg, + manual_review_needed=False, execution_time=time.time() - start_time, + error_message=error_msg + ) + + if not os.path.exists(data_path): + error_msg = f"数据路径不存在: {data_path}" + print(f"❌ {trace_id}: {error_msg}") + return TestResult( + 
trace_id=trace_id, task_type=task_type, task_name=task_name, + success=False, score=0.0, matched_nodes=[], reason=error_msg, + manual_review_needed=False, execution_time=time.time() - start_time, + error_message=error_msg + ) + + print(f"\n--- 测试 {trace_id} [{task_name}] ---") + print(f"规则文件: {os.path.basename(rule_file)}") + print(f"数据路径: {data_path}") + + # 执行验证 + res = verify_task_folder(rule_file, data_path, self.opts) + + # 构建结果 + matched_nodes = [m.node_id for m in res.matched] if res.matched else [] + execution_time = time.time() - start_time + + result = TestResult( + trace_id=trace_id, + task_type=task_type, + task_name=task_name, + success=res.ok, + score=res.total_score, + matched_nodes=matched_nodes, + reason=res.reason or "", + manual_review_needed=res.manual_review_needed, + execution_time=execution_time + ) + + # 输出结果 + status = "✓ 成功" if res.ok else "✗ 失败" + print(f"验证结果: {status}") + print(f"匹配节点: {matched_nodes}") + print(f"任务得分: {res.total_score}分") + print(f"执行时间: {execution_time:.2f}秒") + + if res.reason: + print(f"详细原因: {res.reason}") + if res.manual_review_needed: + print("⚠️ 需要人工复核") + + return result + + except Exception as e: + error_msg = f"测试执行异常: {str(e)}" + print(f"❌ {trace_id}: {error_msg}") + + return TestResult( + trace_id=trace_id, task_type=task_type, task_name=task_name, + success=False, score=0.0, matched_nodes=[], reason="", + manual_review_needed=False, execution_time=time.time() - start_time, + error_message=error_msg + ) + + def test_by_type(self, task_type: str, specific_trace: Optional[Union[str, int]] = None) -> List[TestResult]: + """按类型测试""" + if task_type not in self.config['task_types']: + print(f"错误: 任务类型 '{task_type}' 未在配置中定义") + return [] + + task_config = self.config['task_types'][task_type] + task_name = task_config['name'] + + if specific_trace is not None: + # 测试指定的trace + traces = [specific_trace] + else: + # 获取该类型的所有trace + traces = self._get_traces_for_type(task_type) + + if not traces: + print(f"错误: 类型 
'{task_type}' 没有可用的trace数据") + return [] + + print(f"\n{'='*60}") + print(f"测试任务类型 {task_type} - {task_name}") + print(f"trace数量: {len(traces)}") + print(f"{'='*60}") + + results = [] + for trace_id in traces: + result = self.test_single_trace(task_type, trace_id) + results.append(result) + + return results + + def test_all_types(self) -> TestSummary: + """测试所有类型""" + print(f"\n{'='*80}") + print(f"开始测试任务: {self.config['task_name']}") + print(f"任务描述: {self.config.get('description', 'N/A')}") + print(f"{'='*80}") + + all_results = [] + results_by_type = {} + + for task_type in self.config['task_types']: + type_results = self.test_by_type(task_type) + all_results.extend(type_results) + + # 汇总该类型的结果 + success_count = sum(1 for r in type_results if r.success) + total_count = len(type_results) + success_rate = (success_count / total_count * 100) if total_count > 0 else 0 + + results_by_type[task_type] = { + 'task_name': self.config['task_types'][task_type]['name'], + 'total_tests': total_count, + 'success_count': success_count, + 'success_rate': success_rate, + 'results': type_results + } + + print(f"\n--- 类型 {task_type} 汇总 ---") + for result in type_results: + status = "✓" if result.success else "✗" + print(f"trace {result.trace_id}: {status} | score: {result.score} | nodes: {result.matched_nodes} | reason: {result.reason}") + print(f"成功率: {success_count}/{total_count} ({success_rate:.1f}%)") + + # 生成总汇总 + total_tests = len(all_results) + total_success = sum(1 for r in all_results if r.success) + total_success_rate = (total_success / total_tests * 100) if total_tests > 0 else 0 + total_execution_time = time.time() - self.start_time.timestamp() + + summary = TestSummary( + task_name=self.config['task_name'], + total_tests=total_tests, + success_count=total_success, + success_rate=total_success_rate, + total_execution_time=total_execution_time, + results_by_type=results_by_type, + all_results=all_results + ) + + return summary + + def save_results(self, summary: 
TestSummary): + """保存测试结果""" + # 保存详细结果(JSON格式) + detailed_data = { + 'summary': asdict(summary), + 'timestamp': self.start_time.isoformat(), + 'config_file': self.config_file, + 'config': self.config + } + + with open(self.detailed_file, 'w', encoding='utf-8') as f: + json.dump(detailed_data, f, ensure_ascii=False, indent=2, default=str) + + # 保存汇总结果(文本格式) + with open(self.summary_file, 'w', encoding='utf-8') as f: + f.write(f"任务测试汇总报告\n") + f.write(f"{'='*60}\n") + f.write(f"任务名称: {summary.task_name}\n") + f.write(f"测试时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"配置文件: {self.config_file}\n") + f.write(f"总测试数: {summary.total_tests}\n") + f.write(f"总成功数: {summary.success_count}\n") + f.write(f"总成功率: {summary.success_rate:.1f}%\n") + f.write(f"总执行时间: {summary.total_execution_time:.2f}秒\n") + f.write(f"\n") + + f.write(f"分类型结果:\n") + f.write(f"{'-'*40}\n") + for task_type, type_data in summary.results_by_type.items(): + f.write(f"类型 {task_type} ({type_data['task_name']}):\n") + f.write(f" 测试数: {type_data['total_tests']}\n") + f.write(f" 成功数: {type_data['success_count']}\n") + f.write(f" 成功率: {type_data['success_rate']:.1f}%\n") + f.write(f"\n") + + f.write(f"详细结果:\n") + f.write(f"{'-'*40}\n") + for result in summary.all_results: + status = "✓" if result.success else "✗" + f.write(f"trace {result.trace_id} [{result.task_type}]: {status}\n") + f.write(f" 得分: {result.score}\n") + f.write(f" 匹配节点: {result.matched_nodes}\n") + f.write(f" 原因: {result.reason}\n") + f.write(f" 执行时间: {result.execution_time:.2f}秒\n") + if result.error_message: + f.write(f" 错误: {result.error_message}\n") + f.write(f"\n") + + print(f"\n=== 结果已保存 ===") + print(f"汇总文件: {self.summary_file}") + print(f"详细文件: {self.detailed_file}") + + def print_final_summary(self, summary: TestSummary): + """打印最终汇总""" + print(f"\n{'='*80}") + print(f"任务测试完成: {summary.task_name}") + print(f"{'='*80}") + print(f"总测试数: {summary.total_tests}") + print(f"总成功数: {summary.success_count}") + 
print(f"总成功率: {summary.success_rate:.1f}%") + print(f"总执行时间: {summary.total_execution_time:.2f}秒") + print(f"") + + print("分类型结果:") + for task_type, type_data in summary.results_by_type.items(): + print(f" 类型 {task_type} ({type_data['task_name']}): {type_data['success_count']}/{type_data['total_tests']} ({type_data['success_rate']:.1f}%)") + + print(f"\n详细结果文件: {self.detailed_file}") + print(f"汇总结果文件: {self.summary_file}") + +def show_usage(config_dir: str = "task_configs"): + """显示使用说明""" + print(f""" +通用任务测试执行器使用说明: + +1. 测试所有类型: + python universal_test_runner.py {config_dir}/taobao.json + +2. 测试指定类型: + python universal_test_runner.py {config_dir}/taobao.json type3 + python universal_test_runner.py {config_dir}/xiaohongshu.json type2 + +3. 测试指定类型的指定trace: + python universal_test_runner.py {config_dir}/taobao.json type3:150 + +4. 测试指定trace编号: + python universal_test_runner.py {config_dir}/taobao.json 150,151,152 + +可用的配置文件:""") + + # 查找可用的配置文件 + if os.path.exists(config_dir): + for file in os.listdir(config_dir): + if file.endswith('.json'): + print(f" - {config_dir}/{file}") + + print(f""" +配置文件说明: +- 每个配置文件定义一个任务的测试参数 +- 包含规则文件目录、数据目录、任务类型映射等 +- 可通过修改配置文件来调整测试范围和参数 +""") + +def main(): + """主函数""" + if len(sys.argv) < 2: + show_usage() + return + + config_file = sys.argv[1] + + if not os.path.exists(config_file): + print(f"错误: 配置文件不存在: {config_file}") + show_usage() + return + + # 创建测试运行器 + runner = UniversalTestRunner(config_file) + + if len(sys.argv) == 2: + # 测试所有类型 + summary = runner.test_all_types() + else: + # 解析参数 + arg = sys.argv[2] + + if ':' in arg: + # 格式: type3:150 或 3:150 + task_type, trace_id = arg.split(':', 1) + # 只有当任务类型不在配置中时,才尝试去掉 type 前缀 + if task_type not in runner.config['task_types'] and task_type.startswith('type'): + numeric_type = task_type.replace('type', '') + if numeric_type in runner.config['task_types']: + task_type = numeric_type + + try: + trace_id = int(trace_id) + except ValueError: + pass # 保持字符串格式 + results = 
runner.test_by_type(task_type, trace_id) + + # 创建简单汇总 + total_success = sum(1 for r in results if r.success) + summary = TestSummary( + task_name=runner.config['task_name'], + total_tests=len(results), + success_count=total_success, + success_rate=(total_success / len(results) * 100) if results else 0, + total_execution_time=time.time() - runner.start_time.timestamp(), + results_by_type={task_type: { + 'task_name': runner.config['task_types'].get(task_type, {}).get('name', f'任务{task_type}'), + 'total_tests': len(results), + 'success_count': total_success, + 'success_rate': (total_success / len(results) * 100) if results else 0, + 'results': results + }}, + all_results=results + ) + + elif arg.startswith('type') or arg.isdigit(): + # 测试指定类型,支持 type3 或 3 的格式 + task_type = arg + # 只有当任务类型不在配置中时,才尝试去掉 type 前缀 + if task_type not in runner.config['task_types'] and task_type.startswith('type'): + numeric_type = task_type.replace('type', '') + if numeric_type in runner.config['task_types']: + task_type = numeric_type + + results = runner.test_by_type(task_type) + + # 创建简单汇总 + total_success = sum(1 for r in results if r.success) + summary = TestSummary( + task_name=runner.config['task_name'], + total_tests=len(results), + success_count=total_success, + success_rate=(total_success / len(results) * 100) if results else 0, + total_execution_time=time.time() - runner.start_time.timestamp(), + results_by_type={task_type: { + 'task_name': runner.config['task_types'].get(task_type, {}).get('name', f'任务{task_type}'), + 'total_tests': len(results), + 'success_count': total_success, + 'success_rate': (total_success / len(results) * 100) if results else 0, + 'results': results + }}, + all_results=results + ) + + else: + # 按trace编号测试 + try: + trace_nums = [int(x.strip()) for x in arg.split(",")] + all_results = [] + + # 找到每个trace对应的类型 + for trace_num in trace_nums: + found = False + for task_type, task_config in runner.config['task_types'].items(): + data_traces = 
task_config['data_traces'] + if isinstance(data_traces, list) and trace_num in data_traces: + result = runner.test_single_trace(task_type, trace_num) + all_results.append(result) + found = True + break + + if not found: + print(f"警告: trace编号{trace_num}未在配置中找到") + + # 创建汇总 + total_success = sum(1 for r in all_results if r.success) + summary = TestSummary( + task_name=runner.config['task_name'], + total_tests=len(all_results), + success_count=total_success, + success_rate=(total_success / len(all_results) * 100) if all_results else 0, + total_execution_time=time.time() - runner.start_time.timestamp(), + results_by_type={}, + all_results=all_results + ) + + except ValueError: + print(f"错误: 无效的trace编号格式: {arg}") + return + + # 保存和打印结果 + runner.save_results(summary) + runner.print_final_summary(summary) + +if __name__ == "__main__": + main() diff --git a/README.md b/README.md new file mode 100644 index 0000000..1fa8431 --- /dev/null +++ b/README.md @@ -0,0 +1,131 @@ +
+ + MobiAgent + +
+ +

+MobiAgent: Towards Universally Customizable Mobile Agents +

+ +

+| Paper | Huggingface | App | +

+ +--- + + **English** | [中文](README_zh.md) + +MobiAgent is a powerful mobile agent system including: + +* **An agent model family**: MobiMind +* **An agent acceleration framework**: AgentRR +* **An agent benchmark**: MobiFlow + +System Architecture: + +
+

+ +

+
+ +## News + +- `[2025.8.30]`🔥 We've open-sourced the MobiAgent! + +## Evaluation Results + +
+

+ + + +

+
+ +
+

+ +

+
+ +## Demo + +
+
+ +## Project Structure + +- `agent_rr/` - Agent Record & Replay framework +- `collect/` - Data collection, annotation, processing and export tools +- `runner/` - Agent executor that connects to phone via ADB, executes tasks, and records execution traces +- `MobiFlow/` - Agent evaluation benchmark based on milestone DAG +- `MobiAgent/` - MobiAgent Android app +- `deployment/` - Service deployment for MobiAgent mobile application + +## Quick Start + +### Use with MobiAgent APP + +If you would like to try MobiAgent directly with our APP, please download it in [Download Link](https://github.com/IPADS-SAI/MobiAgent/releases/tag/v1.0) and enjoy yourself! + +### Use with Python Scripts + +If you would like to try MobiAgent with python scripts which leverage Android Debug Bridge (ADB) to control your phone, please follow these steps: + +#### Environment Setup + +```bash +conda create -n MobiMind python=3.10 +conda activate MobiMind + +pip install -r requirements.txt + +# Download OmniParser model weights +for f in icon_detect/{train_args.yaml,model.pt,model.yaml} ; do huggingface-cli download microsoft/OmniParser-v2.0 "$f" --local-dir weights; done + +# If you need GPU acceleration for OCR, install paddlepaddle-gpu according to your CUDA version +# For details, refer to https://www.paddlepaddle.org.cn/install/quick, for example CUDA 11.8: +python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/ + +``` + +#### Mobile Device Setup + +- Download and install [ADBKeyboard](https://github.com/senzhk/ADBKeyBoard/blob/master/ADBKeyboard.apk) on your Android device +- Enable Developer Options on your Android device and allow USB debugging +- Connect your phone to the computer using a USB cable + +#### Model Deployment + +After downloading the `decider`, `grounder`, and `planner` models, use vLLM to deploy model inference services: + +```bash +vllm serve IPADS-SAI/MobiMind-Decider-7B --port +vllm serve IPADS-SAI/MobiMind-Grounder-3B 
--port +vllm serve Qwen/Qwen3-4B-Instruct --port +``` + + +#### Launch Agent Runner + +Write the list of tasks that you would like to test in `runner/mobiagent/task.json`, then launch agent runner: + +```bash +python -m runner.mobiagent.mobiagent --service_ip --decider_port --grounder_port --planner_port +``` + +Parameters: + +- `--service_ip`: Service IP (default: `localhost`) +- `--decider_port`: Decider service port (default: `8000`) +- `--grounder_port`: Grounder service port (default: `8001`) +- `--planner_port`: Planner service port (default: `8002`) + +The runner automatically controls the device and invoke agent models to complete the pre-defined tasks. + +## Detailed Sub-module Usage + +For detailed usage instructions, see the `README.md` files in each sub-module directory. diff --git a/README_zh.md b/README_zh.md new file mode 100644 index 0000000..ace4c2e --- /dev/null +++ b/README_zh.md @@ -0,0 +1,124 @@ +
+ + MobiAgent + +
+ +

+MobiAgent: Towards Universally Customizable Mobile Agents +

+ +

+| 论文 | Huggingface | App | +

+ +--- + +[English](README.md) | **中文** + +MobiAgent是一个强大的移动端智能体系统,包含: + +* **智能体模型家族:** MobiMind +* **智能体加速框架:** AgentRR +* **智能体评测基准:** MobiFlow + +系统架构: + +
+

+ +

+
+ +## 新闻 + +- `[2025.8.30]`🔥 我们开源了MobiAgent! + +## 评测结果 + +
+

+ + + +

+
+ +
+

+ +

+
+ +## 项目结构 + +- `agent_rr/` - Agent Record & Replay框架 +- `collect/` - 数据收集、标注、处理与导出工具 +- `runner/` - 智能体执行器,通过ADB连接手机、执行任务、并记录执行轨迹 +- `MobiFlow/` - 基于里程碑DAG的智能体评测基准 +- `MobiAgent/` - MobiAgent安卓App +- `deployment/` - MobiAgent移动端应用的服务部署方式 + +## 快速开始 + +### 通过 MobiAgent APP 使用 + +如果您想直接通过我们的 APP 体验 MobiAgent,请通过 [下载链接](https://github.com/IPADS-SAI/MobiAgent/releases/tag/v1.0) 进行下载,祝您使用愉快! + +### 使用 Python 脚本 + +如果您想通过 Python 脚本来使用 MobiAgent,并借助Android Debug Bridge (ADB) 来控制您的手机,请遵循以下步骤进行: + +#### 环境配置 + +```bash +conda create -n MobiMind python=3.10 +conda activate MobiMind + +pip install -r requirements.txt + +# 下载OmniParser模型权重 +for f in icon_detect/{train_args.yaml,model.pt,model.yaml} ; do huggingface-cli download microsoft/OmniParser-v2.0 "$f" --local-dir weights; done + +# 如果需要使用gpu加速ocr,需要根据cuda版本,手动安装paddlepaddle-gpu +# 详情参考 https://www.paddlepaddle.org.cn/install/quick,例如cuda 11.8版本: +python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/ + +``` + +#### 手机配置 + +- 在Android设备上下载并安装 [ADBKeyboard](https://github.com/senzhk/ADBKeyBoard/blob/master/ADBKeyboard.apk) +- 在Android设备上,开启开发者选项,并允许USB调试 +- 使用USB数据线连接手机和电脑 + +#### 模型部署 + +下载好 `decider`、`grounder` 和 `planner` 三个模型后,使用 vLLM 部署模型推理服务: + +```bash +vllm serve IPADS-SAI/MobiMind-Decider-7B --port +vllm serve IPADS-SAI/MobiMind-Grounder-3B --port +vllm serve Qwen/Qwen3-4B-Instruct --port +``` + +#### 启动Agent执行器 + +在 `runner/mobiagent/task.json` 中写入想要测试的任务列表,然后启动Agent执行器 + +```bash +python -m runner.mobiagent.mobiagent --service_ip <服务IP> --decider_port <决策服务端口> --grounder_port <定位服务端口> --planner_port <规划服务端口> +``` + +**参数说明** + +- `--service_ip`:服务IP(默认:`localhost`) +- `--decider_port`:决策服务端口(默认:`8000`) +- `--grounder_port`:定位服务端口(默认:`8001`) +- `--planner_port`:规划服务端口(默认:`8002`) + +执行器启动后,将会自动控制手机并调用Agent模型,完成列表中指定的任务。 + +## 子模块详细使用方式 + +详细使用方式见各子模块目录下的 `README.md` 文件。 diff --git a/agent_rr/README.md b/agent_rr/README.md new file mode 100644 index 
0000000..9d73f84 --- /dev/null +++ b/agent_rr/README.md @@ -0,0 +1,27 @@ +# AgentRR + +# Prepare Environment + +```bash +pip install -r requirements-agentrr.txt +``` + +# Train Latent Memory Model + +## Data Preparation + +Step 1: Prepare JSON task templates and store in `templates.json`. For example, see `train/train_data_example/wechat/templates.json`. + +Step 2: Create train/test dataset based on task templates using the following command: + +```bash +python -m train.prepare_data --task both --train_path --test_path +``` + +Step 3: Train Embedding and Reranker model with [ms-swift](https://github.com/modelscope/ms-swift), see official training example [SWIFT](https://github.com/QwenLM/Qwen3-Embedding/blob/main/docs/training/SWIFT.md). + +# Run Experiment + +```bash +python run_experiment.py --data_path --embedder_path --reranker_path +``` \ No newline at end of file diff --git a/agent_rr/action_cache/__init__.py b/agent_rr/action_cache/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agent_rr/action_cache/action.py b/agent_rr/action_cache/action.py new file mode 100644 index 0000000..7950598 --- /dev/null +++ b/agent_rr/action_cache/action.py @@ -0,0 +1,75 @@ +from typing import Dict, List +from PIL import Image +from skimage.metrics import structural_similarity as ssim +import numpy as np + +class UIElement: + def __init__(self, bbox: List[int], content: str = None, sub_img: Image.Image = None): + # bbox: [x1, y1, x2, y2], round(relative * 1000) format + self.bbox = bbox + # content: icon caption or ocr result + self.content = content + # sub_img: cropped image of the UI element + self.sub_img = sub_img + + def __eq__(self, other): + if self.sub_img is not None and other.sub_img is not None: + img1 = np.array(self.sub_img) + img2 = np.array(other.sub_img) + + if img1.shape != img2.shape: + img2_pil_resized = other.sub_img.resize(self.sub_img.size, Image.Resampling.LANCZOS) + img2 = np.array(img2_pil_resized) + + similarity = ssim(img1, img2, 
channel_axis=2, data_range=255) + return similarity > 0.9 + if self.content is not None and other.content is not None: + return self.content == other.content + return True + +class Action: + def __init__(self, name: str, param: Dict[str, str], extra: Dict[str, str] = None): + self.name = name + self.param = param + self.extra = extra + self.target_elem = None + + def extract_target_elem(self, screen, parser): + pass + + def __eq__(self, other): + return self.name == other.name and self.param == other.param + + def __str__(self): + return f"{self.name}({', '.join([f'{k}={v}' for k, v in self.param.items()])})" + +class GeneralAgentAction(Action): + def __init__(self, name: str, param: Dict[str, str], extra: Dict[str, str] = None): + super().__init__(name, param, extra) + + def extract_target_elem(self, screen, parser): + if self.target_elem is not None: + return + if screen is None: + return + if self.name not in ["click", "longclick"]: + return + bbox = self.param['bbox'] + target_element = self.param['target_element'] + sub_img = screen.crop((bbox[0], bbox[1], bbox[2], bbox[3])) + self.target_elem = UIElement(bbox, target_element, sub_img) + + def __eq__(self, other): + if not isinstance(other, GeneralAgentAction): + return False + if self.name != other.name: + return False + if self.name in ["click", "longclick"]: + box1 = self.param["bbox"] + box2 = other.param["bbox"] + center1 = ((box1[0] + box1[2]) / 2, (box1[1] + box1[3]) / 2) + center2 = ((box2[0] + box2[2]) / 2, (box2[1] + box2[3]) / 2) + center1_in_box2 = (box2[0] <= center1[0] <= box2[2] and box2[1] <= center1[1] <= box2[3]) + center2_in_box1 = (box1[0] <= center2[0] <= box1[2] and box1[1] <= center2[1] <= box1[3]) + return center1_in_box2 and center2_in_box1 + return self.param == other.param \ No newline at end of file diff --git a/agent_rr/action_cache/embedder.py b/agent_rr/action_cache/embedder.py new file mode 100644 index 0000000..26ffbaf --- /dev/null +++ b/agent_rr/action_cache/embedder.py @@ 
-0,0 +1,19 @@ +import torch +from sentence_transformers import SentenceTransformer + +class Qwen3Embedder: + def __init__(self, config): + path = config.get("path", "Qwen/Qwen3-Embedding-0.6B") + self.model = SentenceTransformer(path) + self.instruct_fmt = config.get("instruct_fmt", + # "Instruct: Given a phone-use task, retrieve similar tasks that shares at least **{n}** steps with the given task\nQuery:{query}") + "Instruct: Represent this phone-use task for level **{n}**\nQuery:{query}") + + @torch.no_grad() + def embed(self, tasks, steps=None): + if steps is None: + return self.model.encode(tasks, convert_to_tensor=True, normalize_embeddings=True) + if len(tasks) != len(steps): + raise ValueError("Tasks and steps must have the same length") + input_texts = [self.instruct_fmt.format(n=step, query=task) for task, step in zip(tasks, steps)] + return self.model.encode(input_texts, convert_to_tensor=True, normalize_embeddings=True) diff --git a/agent_rr/action_cache/reranker.py b/agent_rr/action_cache/reranker.py new file mode 100644 index 0000000..d4cf7f5 --- /dev/null +++ b/agent_rr/action_cache/reranker.py @@ -0,0 +1,51 @@ +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers.utils import is_torch_npu_available, is_torch_cuda_available + +class Qwen3Reranker: + def __init__(self, config): + path = config.get("path", "Qwen/Qwen3-Reranker-0.6B") + self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left') + device = "cpu" + if is_torch_cuda_available(): + device = "cuda:0" + elif is_torch_npu_available(): + device="npu:0" + self.model = AutoModelForCausalLM.from_pretrained(path, device_map=device).eval() + prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. 
Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n" + suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False) + self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False) + self.token_false_id = self.tokenizer.convert_tokens_to_ids("no") + self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes") + self.instruct_fmt = config.get("instruct_fmt", + ": Given a phone-use task, retrieve similar tasks that shares at least **{n}** steps with the given task\n: {query} \n: {document}") + + def rerank(self, query_tasks, document_task, step): + input_texts = [self.instruct_fmt.format(n=step, query=query, document=document_task) for query in query_tasks] + inputs = self.process_inputs(input_texts) + logits = self.compute_logits(inputs) + return logits + + def process_inputs(self, input_texts): + max_length = 8192 + inputs = self.tokenizer( + input_texts, padding=False, truncation='longest_first', + return_attention_mask=False, max_length=max_length - len(self.prefix_tokens) - len(self.suffix_tokens) + ) + for i, ele in enumerate(inputs['input_ids']): + inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens + inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length) + for key in inputs: + inputs[key] = inputs[key].to(self.model.device) + return inputs + + @torch.no_grad() + def compute_logits(self, inputs): + batch_scores = self.model(**inputs).logits[:, -1, :] + true_vector = batch_scores[:, self.token_true_id] + false_vector = batch_scores[:, self.token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp().tolist() + return scores diff --git a/agent_rr/action_cache/tree.py b/agent_rr/action_cache/tree.py new file mode 100644 index 0000000..2da2cf3 --- /dev/null +++ 
b/agent_rr/action_cache/tree.py @@ -0,0 +1,556 @@ +from enum import Enum +import torch +import time +try: + from omniparser.omniparser import Omniparser +except ImportError as e: + print("Import omniparser failed, some features may not work") + +from .reranker import Qwen3Reranker +from .embedder import Qwen3Embedder +from .action import Action, UIElement + +EMBEDDER_THRESHOLD = 0.8 +RERANKER_MIN_CONF = 0.75 + +class MatchMode(Enum): + EXACT = 1 + FUZZY = 2 + +class Task: + def __init__(self, description): + self.description = description + + def __eq__(self, other): + return self.description == other.description + + def __str__(self): + return self.description + + def __repr__(self): + return f"Task(description={self.description})" + +class ActionTreeEdge: + def __init__(self, action=None, tasks=[], to=None): + self.action = action + self.to = to + self.tasks = tasks + + def add_task(self, task): + self.tasks.append(task) + + def remove_task(self, task_idx): + self.tasks.pop(task_idx) + + def __str__(self): + return f"{self.action} {self.tasks}" + +class ActionTreeEdgeFuzzy(ActionTreeEdge): + def __init__(self, action=None, tasks=[], to=None, task_embeddings=[], keywords=[]): + l = len(tasks) + if l == 0: + raise ValueError("Tasks list is empty") + if l != task_embeddings.shape[0]: + raise ValueError("Tasks list length must match task_embeddings length") + if l != len(keywords): + raise ValueError("Tasks list length must match keywords length") + super().__init__(action, tasks, to) + self.task_embeddings = task_embeddings + self.keywords = keywords + + def add_task(self, task, task_embedding, keyword=""): + self.tasks.append(task) + self.task_embeddings = torch.cat([self.task_embeddings, task_embedding], dim=0) + self.keywords.append(keyword) + + def remove_task(self, task_idx): + self.tasks.pop(task_idx) + self.task_embeddings = torch.cat([self.task_embeddings[:task_idx], self.task_embeddings[task_idx+1:]], dim=0) + self.keywords.pop(task_idx) + + def 
reset_keyword(self, keyword): + for i, kw in enumerate(self.keywords): + if kw == keyword: + self.keywords[i] = "" + + def __str__(self): + return f"{super().__str__()} {self.keywords}" + +class SuperNode: + def __init__(self, nodes): + self.nodes = nodes + + def add_node(self, node): + self.nodes.append(node) + +class ShortCutCheckResult(Enum): + MATCH_INTERMEDIATE = 1 + MATCH_SECOND_LAST = 2 + MATCH_LAST = 3 + NOT_MATCH = 4 + +class ShortCutTemplate: + def __init__(self, action_names, last_action): + self.action_names = action_names + self.last_action = last_action + + def check(self, action, step): + if step >= len(self.action_names): + return ShortCutCheckResult.NOT_MATCH + if step == len(self.action_names): + if action == self.last_action: + return ShortCutCheckResult.MATCH_LAST + else: + if action.name == self.action_names[step]: + if step == len(self.action_names) - 1: + return ShortCutCheckResult.MATCH_SECOND_LAST + else: + return ShortCutCheckResult.MATCH_INTERMEDIATE + return ShortCutCheckResult.NOT_MATCH + +class ShortCut: + def __init__(self, split_node, template, supernode): + self.split_node = split_node + self.template = template + self.supernode = supernode + + def check(self, action, step): + return self.template.check(action, step) + +class ActionTreeNode: + def __init__(self, parent=None): + self.edges = [] + self.parent = parent + self.parent_edge_idx = None + self.screenshot = None + # if a node is a possible split node, pin it + self.split_pin = False + if parent is not None: + self.depth = parent.depth + 1 + else: + self.depth = 0 + + def add_child(self, action, task): + for e in self.edges: + # merge happens here + if e.action == action: + e.add_task(task) + return e.to + new_node = ActionTreeNode(self) + new_edge = ActionTreeEdge(action, [task], new_node) + new_node.parent_edge_idx = len(self.edges) + self.edges.append(new_edge) + return new_node + + def get_cached_action(self, task): + ret = [] + for e in self.edges: + for t in e.tasks: + 
if t == task: + ret.append((e.action, e.to)) + return ret + + def _remove_task(self, task): + next_node = None + for edge_idx, e in enumerate(self.edges): + for task_idx, t in enumerate(e.tasks): + if t == task: + e.remove_task(task_idx) + next_node = e.to + break + if len(e.tasks) == 0: + self.edges.pop(edge_idx) + return None + elif next_node is not None: + return next_node + return None + + def remove_task_trace(self, task): + node = self + while node is not None: + node = node._remove_task(task) + + def get_incoming_edge(self): + return self.parent.edges[self.parent_edge_idx] + + def get_incoming_action(self): + return self.get_incoming_edge().action + + def remove_child(self, child): + if child.parent is not self: + raise ValueError("Not a child of this node") + self.edges.pop(child.parent_edge_idx) + + def try_find_shortcuts(self): + # assume self is the split node, find the possible merged supernodes + + # TODO: use values from config + min_supernode_capacity = 2 + min_shortcut_len, max_shortcut_len = 2, 3 + + def _can_merge_to_supernode(nodes): + # check incoming edges + if len(nodes) < min_supernode_capacity: + return False + action = None + for n in nodes: + if action is None: + action = n.get_incoming_action() + elif action != n.get_incoming_action(): + return False + return True + + def _have_same_parent(nodes): + parent = nodes[0].parent + for n in nodes[1:]: + if n.parent is not parent: + return False + return True + + def _dfs(nodes, trace, supernodes, templates): + cur_len = len(trace) + if cur_len > 1 and _have_same_parent(nodes): + return + if cur_len > max_shortcut_len: + return + if cur_len >= min_shortcut_len and _can_merge_to_supernode(nodes): + # print(nodes) + supernodes.append(SuperNode(nodes)) + action_names = trace[:-1] + last_action = nodes[0].get_incoming_action() + templates.append(ShortCutTemplate(action_names, last_action)) + # greedy match for minimizing shortcut length + return + next_layer_nodes = [] + next_actions = [] + for n in 
nodes: + for e in n.edges: + next_layer_nodes.append(e.to) + next_actions.append(e.action) + # group next_layer_nodes by action + action_group = {} + for n, a in zip(next_layer_nodes, next_actions): + if a.name not in action_group: + action_group[a.name] = [] + action_group[a.name].append(n) + for action_name, group in action_group.items(): + next_trace = trace + [action_name] + _dfs(group, next_trace, supernodes, templates) + + supernodes = [] + templates = [] + trace = [] + _dfs([self], trace, supernodes, templates) + shortcuts = [ShortCut(self, t, s) for t, s in zip(templates, supernodes)] + return shortcuts + +class ActionTreeNodeFuzzy(ActionTreeNode): + def __init__(self, parent=None): + super().__init__(parent) + + def add_child(self, action, task, task_embedding): + keyword = self._extract_keyword(task, action) + for e in self.edges: + # merge happens here + if e.action == action: + e.add_task(task, task_embedding, keyword) + return e.to + new_node = ActionTreeNodeFuzzy(self) + new_edge = ActionTreeEdgeFuzzy(action, [task], new_node, task_embedding, [keyword]) + new_node.parent_edge_idx = len(self.edges) + self.edges.append(new_edge) + return new_node + + def _extract_keyword(self, task, action): + return "" + + def get_cached_action(self, task, step_embedding): + ret = [] + for e in self.edges: + hit = util.semantic_search(step_embedding, e.task_embeddings, top_k=1, score_function=util.dot_score)[0] + # print(hit) + score = hit[0]['score'] + if score < EMBEDDER_THRESHOLD: + continue + corpus_id = hit[0]['corpus_id'] + keyword = e.keywords[corpus_id] + if keyword not in task.description: + continue + hit_task = e.tasks[corpus_id] + print(hit_task, score) + ret.append((e.action, e.to, keyword, hit_task)) + return ret + + def reset_keyword(self, keyword): + for e in self.edges: + e.reset_keyword(keyword) + + +class ActionTree: + def __init__(self, + env, + agent, + action_class=Action, + done=lambda a: a.name == 'END', + mode: MatchMode = MatchMode.EXACT, + 
embedder_config=None, + reranker_config=None, + enable_ui_detection=False, + omniparser_config=None): + self.env = env + self.agent = agent + self.done = done + self.mode = mode + self.action_class = action_class + self.enable_ui_detection = enable_ui_detection + self.generate_only = False + self.shortcuts = [] + self.num_tasks_last_check = 0 + if mode == MatchMode.EXACT: + self.embedder = None + self.root = ActionTreeNode() + elif mode == MatchMode.FUZZY: + if embedder_config is None: + raise ValueError("embedder_config is required for fuzzy matching") + self.embedder = Qwen3Embedder(embedder_config) + self.root = ActionTreeNodeFuzzy() + + if reranker_config is not None: + self.reranker = Qwen3Reranker(reranker_config) + else: + self.reranker = None + else: + raise ValueError(f"Unknown mode: {mode}") + + if omniparser_config is not None: + self.enable_ui_detection = True + self.omniparser = Omniparser(omniparser_config) + else: + self.omniparser = None + + self.reset_counter() + + def reset_counter(self): + self.env_counter = 0.0 + self.inference_counter = 0.0 + self.detection_counter = 0.0 + self.embedding_counter = 0.0 + + def print_counter(self): + print(f"env_counter: {self.env_counter}, inference_counter: {self.inference_counter}, detection_counter: {self.detection_counter}, embedding_counter: {self.embedding_counter}") + + def clear(self): + self.shortcuts = [] + self.num_tasks_last_check = 0 + self.root = ActionTreeNode() if self.mode == MatchMode.EXACT else ActionTreeNodeFuzzy() + + def target_elem_changed(self, cur_screen, action): + if action.target_elem is None: + return False + if cur_screen is None: + return False + target_elem = action.target_elem + bbox = target_elem.bbox + x1, x2 = map(lambda x: x / 1000 * cur_screen.width, (bbox[0], bbox[2])) + y1, y2 = map(lambda x: x / 1000 * cur_screen.height, (bbox[1], bbox[3])) + cropped_screen = cur_screen.crop((x1, y1, x2, y2)) + if self.omniparser is None: + new_elem = UIElement(bbox, target_elem.content, 
cropped_screen) + return new_elem != target_elem + else: + parsed_elems = self.omniparser.parse(cropped_screen) + + for elem in parsed_elems: + if elem["content"] == target_elem.content: + return False + return True + + def get_num_tasks(self): + return sum([len(e.tasks) for e in self.root.edges]) + + def generate_shortcuts(self): + # periodically check if there are new shortcuts + # use bfs + queue = [self.root] + self.shortcuts = [] + while queue: + node = queue.pop(0) + for e in node.edges: + queue.append(e.to) + if node is self.root: + continue + shortcuts = node.try_find_shortcuts() + # last_action cannot be done action + shortcuts = [sc for sc in shortcuts if not self.done(sc.template.last_action)] + self.shortcuts.extend(shortcuts) + node.split_pin = shortcuts != [] + + def execute(self, task_description): + node = self.root + history = [] + task = Task(task_description) + if self.mode == MatchMode.FUZZY: + start_time = time.time() + num_precomute = 16 + step_embeddings = self.embedder.embed([task_description] * num_precomute, steps=range(1, num_precomute + 1)) + end_time = time.time() + self.embedding_counter += end_time - start_time + recompute_times = 0 + + tracking_shortcut = False + shortcut_action = None + + while True: + # candidate (action, next_node) pairs + action_nodes = [] + depth = node.depth + + start_time = time.time() + + if self.mode == MatchMode.EXACT: + if shortcut_action is not None: + shortcut_next_node = node.add_child(shortcut_action, task) + action_nodes = [(shortcut_action, shortcut_next_node)] + keywords = [shortcut_next_node.get_incoming_edge().keywords[-1]] + shortcut_action = None + else: + action_nodes = node.get_cached_action(task) + else: + if depth >= (recompute_times + 1) * num_precomute: + recompute_times += 1 + step_embeddings = self.embedder.embed( + [task_description] * num_precomute, + steps=range(recompute_times * num_precomute + 1, (recompute_times + 1) * num_precomute + 1) + ) + + step_embedding = 
step_embeddings[depth - recompute_times * num_precomute].unsqueeze(0) + + if shortcut_action is not None: + shortcut_next_node = node.add_child(shortcut_action, task, step_embedding) + action_nodes = [(shortcut_action, shortcut_next_node)] + keywords = [shortcut_next_node.get_incoming_edge().keywords[-1]] + shortcut_action = None + else: + action_node_keyword_tasks = node.get_cached_action(task, step_embedding) + hit_tasks = [t.description for a, n, kw, t in action_node_keyword_tasks] + if len(action_node_keyword_tasks) == 0: + print(f"No similar task found.") + else: + print(f"Found similar task: {hit_tasks}") + if self.reranker is not None and len(hit_tasks) > 0: + scores = self.reranker.rerank(query_tasks=hit_tasks, document_task=task_description, step=depth + 1) + indices = [i for i, score in enumerate(scores) if score > RERANKER_MIN_CONF] + action_node_keyword_tasks = [action_node_keyword_tasks[i] for i in indices] + if len(indices) != len(hit_tasks): + print(f"Reranker filtered tasks: {[hit_tasks[i] for i in range(len(hit_tasks)) if i not in indices]}") + action_nodes = [(a, n) for a, n, kw, t in action_node_keyword_tasks] + keywords = [kw for a, n, kw, t in action_node_keyword_tasks] + end_time = time.time() + self.embedding_counter += end_time - start_time + + if node.split_pin and not self.generate_only and not tracking_shortcut: + # start tracking possible shortcut + print("Start tracking shortcut") + tracking_shortcut = True + possible_shortcuts = [sc for sc in self.shortcuts if sc.split_node is node] + cur_step = 0 + + # check if the action needs to be generated by model, or we can use cached action + needs_generation = len(action_nodes) == 0 or self.generate_only + + start_time = time.time() + agent_input = self.env.get_agent_input(history, task_description) + end_time = time.time() + self.env_counter += end_time - start_time + + screenshot = agent_input.get("image", None) + # if UI changed, we need to generate the action + if not needs_generation: + 
if self.enable_ui_detection: + start_time = time.time() + for i, (a, n) in enumerate(action_nodes): + if not self.target_elem_changed(screenshot, a): + action = a + next_node = n + if self.mode == MatchMode.FUZZY: + keyword = keywords[i] + break + # the else block is executed if the for loop is not broken + else: + print("warning: target element changed") + needs_generation = True + + end_time = time.time() + self.detection_counter += end_time - start_time + else: + action, next_node = action_nodes[0] + if self.mode == MatchMode.FUZZY: + keyword = keywords[0] + + if needs_generation: + print("Cache miss") + start_time = time.time() + agent_output = self.agent.generate(agent_input) + end_time = time.time() + self.inference_counter += end_time - start_time + action = self.action_class(**agent_output) + # extract target element and store it in action + if self.enable_ui_detection: + start_time = time.time() + action.extract_target_elem(screenshot, self.omniparser) + end_time = time.time() + self.detection_counter += end_time - start_time + if self.mode == MatchMode.EXACT: + next_node = node.add_child(action, task) + else: + next_node = node.add_child(action, task, step_embedding) + else: + print("Cache hit") + edge = next_node.get_incoming_edge() + # only add similar task to the edge + if self.mode == MatchMode.FUZZY and task not in edge.tasks: + edge.add_task(task, step_embedding, keyword) + + if tracking_shortcut: + new_possible_shortcuts = [] + for i, sc in enumerate(possible_shortcuts): + check_result = sc.check(action, cur_step) + if check_result == ShortCutCheckResult.MATCH_SECOND_LAST: + # can use cached action in next iteration + tracking_shortcut = False + # add a child for next_node in advance + # in next iteration, cache hit is guaranteed + if needs_generation: + last_action = sc.template.last_action + shortcut_action = last_action + break + if check_result == ShortCutCheckResult.MATCH_INTERMEDIATE: + new_possible_shortcuts.append(sc) + else: + cur_step += 
1 + possible_shortcuts = new_possible_shortcuts + if len(possible_shortcuts) == 0: + tracking_shortcut = False + + # execute the action + if self.done(action): + break + history.append(action) + + start_time = time.time() + self.env.execute(action) + end_time = time.time() + self.env_counter += end_time - start_time + + node = next_node + + # periodically generate shortcuts + if not self.generate_only: + period = 1 + num_tasks = self.get_num_tasks() + if num_tasks - self.num_tasks_last_check >= period: + self.num_tasks_last_check = num_tasks + self.generate_shortcuts() + print(f"number of shortcuts: {len(self.shortcuts)}") + for sc in self.shortcuts: + print(f"split_node: {sc.split_node}, template: {sc.template.action_names}, last_action: {sc.template.last_action}, supernode size: {len(sc.supernode.nodes)}") diff --git a/agent_rr/agent/__init__.py b/agent_rr/agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agent_rr/agent/agent.py b/agent_rr/agent/agent.py new file mode 100644 index 0000000..518e063 --- /dev/null +++ b/agent_rr/agent/agent.py @@ -0,0 +1,89 @@ +from enum import Enum +import json +from openai import OpenAI +import io, base64 + +class Agent: + def __init__(self): + pass + + def generate(self, agent_input): + pass + + +class ReplayLevel(Enum): + ALL = 1 + REASONING = 2 + +class RemoteMultiLevelGeneralAgent(Agent): + def __init__(self, decider_url, grounder_url): + super().__init__() + self.decider_client = OpenAI(api_key="0", base_url=decider_url) + self.grounder_client = OpenAI(api_key="0", base_url=grounder_url) + + def generate(self, agent_input): + if "replay_level" in agent_input: + replay_level = agent_input["replay_level"] + else: + replay_level = ReplayLevel.ALL + + image = agent_input["image"] + buffered = io.BytesIO() + image.save(buffered, format="JPEG") + base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8') + query = agent_input["query"] + + action_dict = {} + if replay_level == ReplayLevel.ALL: + 
response = self.decider_client.chat.completions.create( + model="", + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}, + {"type": "text", "text": query}, + ], + } + ], + temperature=0 + ) + decider_response = response.choices[0].message.content + decider_json = json.loads(decider_response) + reasoning = decider_json["reasoning"] + action = decider_json["action"] + param = decider_json["parameters"] + action_dict["name"] = action + action_dict["parameters"] = param + action_dict["extra"] = {"reasoning": reasoning, "decider_raw_output": decider_response} + if action in ["click", "longclick"]: + target_element = param["target_element"] + query = ''' +Based on the screenshot, user's intent and the description of the target UI element, provide the bounding box of the element using **absolute coordinates**. +User's intent: {reasoning} +Target element's description: {description} +Your output should be a JSON object with the following format: +{{"bbox": [x1, y1, x2, y2]}}'''.format(reasoning=reasoning, description=target_element) + else: + return action_dict + + # do grounding + # case 1: a cache miss happened + # case 2: replaying cached reasoning + grounder_response = self.grounder_client.chat.completions.create( + model="", + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}, + {"type": "text", "text": query}, + ], + } + ], + temperature=0 + ) + grounder_response = grounder_response.choices[0].message.content + grounder_json = json.loads(grounder_response) + action_dict["parameters"]["bbox"] = grounder_json["bbox"] + return action_dict diff --git a/agent_rr/agent/env.py b/agent_rr/agent/env.py new file mode 100644 index 0000000..5275c54 --- /dev/null +++ b/agent_rr/agent/env.py @@ -0,0 +1,130 @@ +from PIL import Image +import base64 +import io +import requests +import time + +from .agent import 
ReplayLevel + +class Environment: + def __init__(self): + pass + + def get_agent_input(self, history, task_description): + pass + + def get_agent_input_speculative(self, history, task_description, draft_action): + pass + + def execute(self, action): + pass + + +def request_screenshot(url): + body = {"action": "screenshot", "param": {}} + response = requests.post(url, json=body) + if response.status_code == 200: + encoded_image = response.json()['data']['image'] + image_data = base64.b64decode(encoded_image) + return Image.open(io.BytesIO(image_data)).convert("RGB") + else: + return None + +class MultiLevelGeneralEnvironment(Environment): + def __init__(self, agent, replay_level=ReplayLevel.ALL): + super().__init__() + self.agent = agent + self.replay_level = replay_level + self.decider_prompt_fmt = ''' +You are a phone-use AI agent. Now your task is "{task}". +Your action history is: +{history} +Please provide the next action based on the screenshot and your action history. You should do careful reasoning before providing the action. +Your action space includes: +- Name: click, Parameters: target_element (a high-level description of the UI element to click). +- Name: swipe, Parameters: direction (one of UP, DOWN, LEFT, RIGHT). +- Name: input, Parameters: text (the text to input). +- Name: wait, Parameters: (no parameters, will wait for 1 second). +- Name: done, Parameters: (no parameters). +Your output should be a JSON object with the following format: +{{"reasoning": "Your reasoning here", "action": "The next action (one of click, input, swipe, wait, done)", "parameters": {{"param1": "value1", ...}}}}''' + if replay_level == ReplayLevel.REASONING: + self.grounder_prompt_fmt = ''' +Based on the screenshot, user's intent and the description of the target UI element, provide the bounding box of the element using **absolute coordinates**. 
+User's intent: {reasoning} +Target element's description: {description} +Your output should be a JSON object with the following format: +{{"bbox": [x1, y1, x2, y2]}}''' + + def get_screenshot(self): + pass + + def get_agent_input(self, history, task_description): + image = self.get_screenshot() + if len(history) == 0: + history_str = "(No history)" + else: + history_str = "\n".join(f"{idx}. {action.extra['decider_raw_output']}" for idx, action in enumerate(history, 1)) + query = self.decider_prompt_fmt.format(task=task_description, history=history_str) + return {"image": image, "query": query} + + def execute(self, action): + pass + +class RemoteMultiLevelGeneralEnvironment(MultiLevelGeneralEnvironment): + def __init__(self, agent, replay_level=ReplayLevel.ALL, url="http://localhost:8766/adb"): + super().__init__(agent, replay_level) + self.url = url + self.last_screenshot = None + self.factor = 0.5 + + def get_screenshot(self): + image = request_screenshot(self.url) + + if image is not None: + width, height = image.size + new_width = int(width * self.factor) + new_height = int(height * self.factor) + image = image.resize((new_width, new_height), Image.LANCZOS) + self.last_screenshot = image + return image + + def execute(self, action): + name = action.name + if name in ["click", "longclick"] and self.replay_level == ReplayLevel.REASONING: + query = self.grounder_prompt_fmt.format( + reasoning=action.extra['reasoning'], + description=action.param['target_element'] + ) + agent_input = { + "image": self.last_screenshot, + "query": query, + "replay_level": self.replay_level + } + agent_output = self.agent.generate(agent_input) + action = action.__class__(**agent_output) + if name in ["click", "longclick"]: + x1, y1, x2, y2 = action.param["bbox"] + x = int((x1 + x2) / 2 / self.factor) + y = int((y1 + y2) / 2 / self.factor) + param = {"x": x, "y": y} + elif name == "input": + param = {"text": action.param["text"]} + elif name == "scroll": + direction = 
class BoxAnnotator:
    """
    A class for drawing bounding boxes on an image using detections provided.

    Attributes:
        color (Union[Color, ColorPalette]): The color to draw the bounding box,
            can be a single color or a color palette
        thickness (int): The thickness of the bounding box lines, default is 2
        text_color (Color): The color of the text on the bounding box, default is white
        text_scale (float): The scale of the text on the bounding box, default is 0.5
        text_thickness (int): The thickness of the text on the bounding box,
            default is 1
        text_padding (int): The padding around the text on the bounding box,
            default is 5
        avoid_overlap (bool): when True, label placement is delegated to
            get_optimal_label_pos so labels avoid detection boxes / image edges

    NOTE(review): `text_color` is currently not used by `annotate` — the label
    color is picked per-box (black or white) from the background luminance.
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
        thickness: int = 3,  # 1 for seeclick 2 for mind2web and 3 for demo
        text_color: Color = Color.BLACK,
        text_scale: float = 0.5,  # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
        text_thickness: int = 2,  #1, # 2 for demo
        text_padding: int = 10,
        avoid_overlap: bool = True,
    ):
        self.color: Union[Color, ColorPalette] = color
        self.thickness: int = thickness
        self.text_color: Color = text_color
        self.text_scale: float = text_scale
        self.text_thickness: int = text_thickness
        self.text_padding: int = text_padding
        self.avoid_overlap: bool = avoid_overlap

    def annotate(
        self,
        scene: np.ndarray,
        detections: Detections,
        labels: Optional[List[str]] = None,
        skip_label: bool = False,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """
        Draws bounding boxes on the frame using the detections provided.

        Args:
            scene (np.ndarray): The image on which the bounding boxes will be drawn
            detections (Detections): The detections for which the
                bounding boxes will be drawn
            labels (Optional[List[str]]): An optional list of labels
                corresponding to each detection. If `labels` are not provided,
                corresponding `class_id` will be used as label.
            skip_label (bool): Is set to `True`, skips bounding box label annotation.
            image_size (Optional[Tuple[int, int]]): (width, height) of the image;
                used for overlap-aware label placement when avoid_overlap is True.
        Returns:
            np.ndarray: The image with the bounding boxes drawn on it

        Example:
            ```python
            import supervision as sv

            classes = ['person', ...]
            image = ...
            detections = sv.Detections(...)

            box_annotator = sv.BoxAnnotator()
            labels = [
                f"{classes[class_id]} {confidence:0.2f}"
                for _, _, confidence, class_id, _ in detections
            ]
            annotated_frame = box_annotator.annotate(
                scene=image.copy(),
                detections=detections,
                labels=labels
            )
            ```
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        for i in range(len(detections)):
            x1, y1, x2, y2 = detections.xyxy[i].astype(int)
            class_id = (
                detections.class_id[i] if detections.class_id is not None else None
            )
            # Palette color is indexed by class id when present, else by position.
            idx = class_id if class_id is not None else i
            color = (
                self.color.by_idx(idx)
                if isinstance(self.color, ColorPalette)
                else self.color
            )
            cv2.rectangle(
                img=scene,
                pt1=(x1, y1),
                pt2=(x2, y2),
                color=color.as_bgr(),
                thickness=self.thickness,
            )
            if skip_label:
                continue

            # Fall back to the class id when labels are missing or mismatched.
            text = (
                f"{class_id}"
                if (labels is None or len(detections) != len(labels))
                else labels[i]
            )

            text_width, text_height = cv2.getTextSize(
                text=text,
                fontFace=font,
                fontScale=self.text_scale,
                thickness=self.text_thickness,
            )[0]

            if not self.avoid_overlap:
                # Fixed placement: label background sits just above the box's
                # top-left corner.
                text_x = x1 + self.text_padding
                text_y = y1 - self.text_padding

                text_background_x1 = x1
                text_background_y1 = y1 - 2 * self.text_padding - text_height

                text_background_x2 = x1 + 2 * self.text_padding + text_width
                text_background_y2 = y1
                # text_x = x1 - self.text_padding - text_width
                # text_y = y1 + self.text_padding + text_height
                # text_background_x1 = x1 - 2 * self.text_padding - text_width
                # text_background_y1 = y1
                # text_background_x2 = x1
                # text_background_y2 = y1 + 2 * self.text_padding + text_height
            else:
                # Overlap-aware placement: try several candidate positions.
                text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)

            cv2.rectangle(
                img=scene,
                pt1=(text_background_x1, text_background_y1),
                pt2=(text_background_x2, text_background_y2),
                color=color.as_bgr(),
                thickness=cv2.FILLED,
            )
            # import pdb; pdb.set_trace()
            # Pick black or white text for contrast against the box color
            # (ITU-R BT.601 luminance weights).
            box_color = color.as_rgb()
            luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
            text_color = (0,0,0) if luminance > 160 else (255,255,255)
            cv2.putText(
                img=scene,
                text=text,
                org=(text_x, text_y),
                fontFace=font,
                fontScale=self.text_scale,
                # color=self.text_color.as_rgb(),
                color=text_color,
                thickness=self.text_thickness,
                lineType=cv2.LINE_AA,
            )
        return scene
def box_area(box):
    # Area of an [x1, y1, x2, y2] box; degenerate boxes yield <= 0.
    return (box[2] - box[0]) * (box[3] - box[1])

def intersection_area(box1, box2):
    # Overlap area of two [x1, y1, x2, y2] boxes (0 when disjoint).
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    return max(0, x2 - x1) * max(0, y2 - y1)

def IoU(box1, box2, return_max=True):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes.

    With return_max=True (default) the result is max(IoU, overlap/area1,
    overlap/area2), so a box fully nested inside another scores ~1 even
    though the plain IoU is small.

    Fix: a 1e-6 epsilon is added to the union — matching the sibling IoU
    helpers in utils.py — so two zero-area boxes return 0.0 instead of
    raising ZeroDivisionError.
    """
    intersection = intersection_area(box1, box2)
    union = box_area(box1) + box_area(box2) - intersection + 1e-6
    if box_area(box1) > 0 and box_area(box2) > 0:
        ratio1 = intersection / box_area(box1)
        ratio2 = intersection / box_area(box2)
    else:
        ratio1, ratio2 = 0, 0
    if return_max:
        return max(intersection / union, ratio1, ratio2)
    return intersection / union


def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
    """Pick a label placement that avoids detection boxes and the image border.

    Candidate positions are tried in order: 'top left', 'outer left',
    'outer right', 'top right'. A candidate is rejected when its text
    background overlaps any detection box (IoU > 0.3) or spills outside
    `image_size` (width, height). If every candidate overlaps, the last one
    ('top right') is returned, matching the original behavior
    (TODO: smarter fallback).

    Returns:
        (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2)
    """

    def overlaps(bg_x1, bg_y1, bg_x2, bg_y2):
        # Reject when the background collides with any detection box...
        for i in range(len(detections)):
            detection = detections.xyxy[i].astype(int)
            if IoU([bg_x1, bg_y1, bg_x2, bg_y2], detection) > 0.3:
                return True
        # ...or pokes outside the image bounds.
        return (bg_x1 < 0 or bg_x2 > image_size[0]
                or bg_y1 < 0 or bg_y2 > image_size[1])

    # Candidate placements, each (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2),
    # in the original priority order.
    candidates = [
        # top left
        (x1 + text_padding, y1 - text_padding,
         x1, y1 - 2 * text_padding - text_height,
         x1 + 2 * text_padding + text_width, y1),
        # outer left
        (x1 - text_padding - text_width, y1 + text_padding + text_height,
         x1 - 2 * text_padding - text_width, y1,
         x1, y1 + 2 * text_padding + text_height),
        # outer right
        (x2 + text_padding, y1 + text_padding + text_height,
         x2, y1,
         x2 + 2 * text_padding + text_width, y1 + 2 * text_padding + text_height),
        # top right
        (x2 - text_padding - text_width, y1 - text_padding,
         x2 - 2 * text_padding - text_width, y1 - 2 * text_padding - text_height,
         x2, y1),
    ]
    for candidate in candidates:
        if not overlaps(*candidate[2:]):
            return candidate
    # All four candidates overlap: fall back to the last one (top right),
    # as the original implementation did.
    return candidates[-1]
caption_model_processor=self.caption_model_processor, ocr_text=text, use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128) + + return parsed_content_list diff --git a/agent_rr/omniparser/utils.py b/agent_rr/omniparser/utils.py new file mode 100644 index 0000000..40f4927 --- /dev/null +++ b/agent_rr/omniparser/utils.py @@ -0,0 +1,505 @@ +# from ultralytics import YOLO +import io +import base64 +import time +from PIL import Image + +import cv2 +import numpy as np +# from matplotlib import pyplot as plt +from paddleocr import PaddleOCR +paddle_ocr = PaddleOCR( + lang='ch', # other lang also available + use_angle_cls=False, + use_gpu=False, # using cuda will conflict with pytorch in the same process + show_log=False, + max_batch_size=1024, + use_dilation=True, # improves accuracy + det_db_score_mode='slow', # improves accuracy + rec_batch_num=1024) +import time +import base64 + +import torch +from typing import Tuple, List, Union +from torchvision.ops import box_convert +from torchvision.transforms import ToPILImage +# import supervision as sv +import torchvision.transforms as T +from .box_annotator import BoxAnnotator + + +def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None): + if not device: + device = "cuda" if torch.cuda.is_available() else "cpu" + if model_name == "blip2": + from transformers import Blip2Processor, Blip2ForConditionalGeneration + processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + if device == 'cpu': + model = Blip2ForConditionalGeneration.from_pretrained( + model_name_or_path, device_map=None, torch_dtype=torch.float32 + ) + else: + model = Blip2ForConditionalGeneration.from_pretrained( + model_name_or_path, device_map=None, torch_dtype=torch.float16 + ).to(device) + elif model_name == "florence2": + from transformers import AutoProcessor, AutoModelForCausalLM + processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", 
trust_remote_code=True) + if device == 'cpu': + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True) + else: + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device) + return {'model': model.to(device), 'processor': processor} + + +def get_yolo_model(model_path): + from ultralytics import YOLO + # Load the model. + model = YOLO(model_path) + return model + + +@torch.inference_mode() +def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=128): + # Number of samples per batch, --> 128 roughly takes 4 GB of GPU memory for florence v2 model + to_pil = ToPILImage() + if starting_idx: + non_ocr_boxes = filtered_boxes[starting_idx:] + else: + non_ocr_boxes = filtered_boxes + croped_pil_image = [] + for i, coord in enumerate(non_ocr_boxes): + try: + xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1]) + ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0]) + cropped_image = image_source[ymin:ymax, xmin:xmax, :] + cropped_image = cv2.resize(cropped_image, (64, 64)) + croped_pil_image.append(to_pil(cropped_image)) + except: + continue + + model, processor = caption_model_processor['model'], caption_model_processor['processor'] + if not prompt: + if 'florence' in model.config.name_or_path: + prompt = "" + else: + prompt = "The image shows" + + generated_texts = [] + device = model.device + for i in range(0, len(croped_pil_image), batch_size): + start = time.time() + batch = croped_pil_image[i:i+batch_size] + t1 = time.time() + if model.device.type == 'cuda': + inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16) + else: + inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device) + if 
'florence' in model.config.name_or_path: + generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False) + else: + generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True, + generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True) + generated_text = [gen.strip() for gen in generated_text] + generated_texts.extend(generated_text) + + return generated_texts + + + +def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor): + to_pil = ToPILImage() + if ocr_bbox: + non_ocr_boxes = filtered_boxes[len(ocr_bbox):] + else: + non_ocr_boxes = filtered_boxes + croped_pil_image = [] + for i, coord in enumerate(non_ocr_boxes): + xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1]) + ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0]) + cropped_image = image_source[ymin:ymax, xmin:xmax, :] + croped_pil_image.append(to_pil(cropped_image)) + + model, processor = caption_model_processor['model'], caption_model_processor['processor'] + device = model.device + messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}] + prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + batch_size = 5 # Number of samples per batch + generated_texts = [] + + for i in range(0, len(croped_pil_image), batch_size): + images = croped_pil_image[i:i+batch_size] + image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images] + inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []} + texts = [prompt] * len(images) + for i, txt in enumerate(texts): + input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt") + 
def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
    """Non-maximum suppression that keeps the *smaller* of two overlapping boxes.

    Args:
        boxes: tensor of candidate boxes in xyxy format.
        iou_threshold: overlap threshold above which the larger box is dropped.
        ocr_bbox: optional list of OCR boxes; they are always kept, and a
            candidate overlapping an OCR box is dropped unless the candidate
            is (95%) contained in it.

    Returns:
        torch.tensor of the surviving boxes (OCR boxes first).
    """
    assert ocr_bbox is None or isinstance(ocr_bbox, List)

    def _area(b):
        return (b[2] - b[0]) * (b[3] - b[1])

    def _intersection(a, b):
        left, top = max(a[0], b[0]), max(a[1], b[1])
        right, bottom = min(a[2], b[2]), min(a[3], b[3])
        return max(0, right - left) * max(0, bottom - top)

    def _iou(a, b):
        # max of plain IoU and per-box coverage ratios; epsilon avoids /0
        inter = _intersection(a, b)
        union = _area(a) + _area(b) - inter + 1e-6
        if _area(a) > 0 and _area(b) > 0:
            ratios = (inter / _area(a), inter / _area(b))
        else:
            ratios = (0, 0)
        return max(inter / union, *ratios)

    def _is_inside(a, b):
        # "a inside b": at least 95% of a's area is covered by b
        return _intersection(a, b) / _area(a) > 0.95

    candidates = boxes.tolist()
    kept = list(ocr_bbox) if ocr_bbox else []
    for i, cur in enumerate(candidates):
        # Drop `cur` when a strictly smaller candidate overlaps it heavily.
        suppressed = any(
            i != j and _iou(cur, other) > iou_threshold and _area(cur) > _area(other)
            for j, other in enumerate(candidates)
        )
        if suppressed:
            continue
        # Drop `cur` when it clashes with an OCR box without being inside it.
        if ocr_bbox and any(
            _iou(cur, ocr) > iou_threshold and not _is_inside(cur, ocr)
            for ocr in ocr_bbox
        ):
            continue
        kept.append(cur)
    return torch.tensor(kept)
+ + ''' + assert ocr_bbox is None or isinstance(ocr_bbox, List) + + def box_area(box): + return (box[2] - box[0]) * (box[3] - box[1]) + + def intersection_area(box1, box2): + x1 = max(box1[0], box2[0]) + y1 = max(box1[1], box2[1]) + x2 = min(box1[2], box2[2]) + y2 = min(box1[3], box2[3]) + return max(0, x2 - x1) * max(0, y2 - y1) + + def IoU(box1, box2): + intersection = intersection_area(box1, box2) + union = box_area(box1) + box_area(box2) - intersection + 1e-6 + if box_area(box1) > 0 and box_area(box2) > 0: + ratio1 = intersection / box_area(box1) + ratio2 = intersection / box_area(box2) + else: + ratio1, ratio2 = 0, 0 + return max(intersection / union, ratio1, ratio2) + + def is_inside(box1, box2): + # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3] + intersection = intersection_area(box1, box2) + ratio1 = intersection / box_area(box1) + return ratio1 > 0.80 + + # boxes = boxes.tolist() + filtered_boxes = [] + if ocr_bbox: + filtered_boxes.extend(ocr_bbox) + # print('ocr_bbox!!!', ocr_bbox) + for i, box1_elem in enumerate(boxes): + box1 = box1_elem['bbox'] + is_valid_box = True + for j, box2_elem in enumerate(boxes): + # keep the smaller box + box2 = box2_elem['bbox'] + if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2): + is_valid_box = False + break + if is_valid_box: + if ocr_bbox: + # keep yolo boxes + prioritize ocr label + box_added = False + ocr_labels = '' + for box3_elem in ocr_bbox: + if not box_added: + box3 = box3_elem['bbox'] + if is_inside(box3, box1): # ocr inside icon + # box_added = True + # delete the box3_elem from ocr_bbox + try: + # gather all ocr labels + ocr_labels += box3_elem['content'] + ' ' + filtered_boxes.remove(box3_elem) + except: + continue + # break + elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box + box_added = True + break + else: 
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
             text_padding=5, text_thickness=2, thickness=3) -> Tuple[np.ndarray, dict]:
    """
    This function annotates an image with bounding boxes and labels.

    Parameters:
    image_source (np.ndarray): The source image to be annotated.
    boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, normalized to [0, 1]
    logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
    phrases (List[str]): A list of labels for each bounding box.
    text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web

    Returns:
    (annotated image, {phrase: xywh pixel coordinates} mapping)
    """
    # Lazy import: the module-level `import supervision as sv` is commented
    # out, which made this function raise NameError at call time. supervision
    # is already a project dependency (used by .box_annotator).
    import supervision as sv

    h, w, _ = image_source.shape
    # Scale normalized cxcywh boxes to pixel space, then convert formats.
    boxes = boxes * torch.Tensor([w, h, w, h])
    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
    detections = sv.Detections(xyxy=xyxy)

    # Bug fix: the original iterated `range(boxes.shape[0])`, so every label
    # was just its index ("0", "1", ...) and `phrases` was ignored.
    labels = [f"{phrase}" for phrase in phrases]

    box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
    annotated_frame = image_source.copy()
    annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))

    label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
    return annotated_frame, label_coordinates
def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, ocr_bbox=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=128):
    """Detect UI elements with YOLO, merge them with OCR boxes, caption icons.

    Args:
        image_source: file path (str) or PIL Image object.
        model: ultralytics YOLO detector.
        BOX_TRESHOLD: YOLO confidence threshold.
        ocr_bbox: optional list of OCR boxes in pixel xyxy format.
        caption_model_processor: {'model', 'processor'} pair for icon captioning.
        ocr_text: texts parallel to `ocr_bbox` (read-only here, so the mutable
            default is never mutated).
        use_local_semantics: when True, caption un-labeled icon crops.
        iou_threshold: overlap threshold used to de-duplicate boxes.
        prompt: optional captioning prompt override.
        scale_img / imgsz: YOLO inference sizing controls.
        batch_size: captioning batch size.

    Returns:
        List of element dicts: {'type', 'bbox' (normalized xyxy),
        'interactivity', 'content', 'source'}.
    """
    if isinstance(image_source, str):
        image_source = Image.open(image_source)
    image_source = image_source.convert("RGB") # for CLIP
    w, h = image_source.size
    if not imgsz:
        imgsz = (h, w)
    # print('image size:', w, h)
    xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
    # Normalize detector boxes to [0, 1].
    xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
    image_source = np.asarray(image_source)

    # Normalize OCR boxes to [0, 1] as well.
    if ocr_bbox:
        ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
        ocr_bbox=ocr_bbox.tolist()
    else:
        print('no ocr bbox!!!')
        # Bug fix: was `ocr_bbox = None`, which made the zip() below raise
        # TypeError whenever no OCR boxes were supplied.
        ocr_bbox = []

    ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
    xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
    # xyxy_elem = []
    filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
    # filtered_boxes = xyxy_elem

    # sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
    filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
    # get the index of the first 'content': None
    starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
    filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])

    # get parsed icon local semantics
    if use_local_semantics:
        caption_model = caption_model_processor['model']
        if 'phi3_v' in caption_model.config.model_type:
            parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
        else:
            parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
        ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
        icon_start = len(ocr_text)
        parsed_content_icon_ls = []
        # fill the filtered_boxes_elem None content with parsed_content_icon in order
        for i, box in enumerate(filtered_boxes_elem):
            if box['content'] is None:
                # Guard: captioning may skip crops that failed to resize, so
                # there can be fewer captions than un-captioned boxes; the old
                # unguarded pop(0) raised IndexError in that case.
                box['content'] = parsed_content_icon.pop(0) if parsed_content_icon else ''
        # NOTE(review): kept from the original — builds labels from leftover
        # captions, but the result is not used by the return value.
        for i, txt in enumerate(parsed_content_icon):
            parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
    else:
        ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]

    return filtered_boxes_elem
def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', easyocr_args=None, use_paddleocr=False):
    """Run OCR on an image and return (texts, bounding_boxes).

    Args:
        image_source: image path or a PIL image.
        display_img: if True, draw the detected boxes with matplotlib (debug
            aid); boxes are then returned in xywh tuples regardless of
            `output_bb_format` (kept from the original behavior).
        output_bb_format: 'xywh' or 'xyxy' (honored when display_img is False).
        easyocr_args: optional dict; only 'text_threshold' is read here and is
            applied as the PaddleOCR confidence cutoff.
        use_paddleocr: must currently be True — the EasyOCR path was removed.

    Returns:
        (list of recognized strings, list of box tuples); ([], []) when OCR
        finds nothing.
    """
    if isinstance(image_source, str):
        image_source = Image.open(image_source)
    if image_source.mode == 'RGBA':
        # Convert RGBA to RGB to avoid alpha channel issues
        image_source = image_source.convert('RGB')
    image_np = np.array(image_source)
    w, h = image_source.size
    if use_paddleocr:
        if easyocr_args is None:
            text_threshold = 0.5
        else:
            # Fix: use .get() — callers are not required to pass
            # 'text_threshold'; the old subscript raised KeyError.
            text_threshold = easyocr_args.get('text_threshold', 0.5)
        result = paddle_ocr.ocr(image_np, cls=False)[0]
        if result is None:
            return [], []
        coord = [item[0] for item in result if item[1][1] > text_threshold]
        text = [item[1][0] for item in result if item[1][1] > text_threshold]
    else:
        # Fix: the EasyOCR branch was commented out, which previously left
        # `coord`/`text` undefined and crashed below with a NameError.
        raise ValueError("check_ocr_box: only use_paddleocr=True is supported "
                         "(the EasyOCR path has been removed)")
    if display_img:
        # Debug-only path; import locally because matplotlib is not imported
        # at module level (its import is commented out at the top of the file).
        from matplotlib import pyplot as plt
        opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        bb = []
        for item in coord:
            x, y, a, b = get_xywh(item)
            bb.append((x, y, a, b))
            cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
        # matplotlib expects RGB
        plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB))
    else:
        if output_bb_format == 'xywh':
            bb = [get_xywh(item) for item in coord]
        elif output_bb_format == 'xyxy':
            bb = [get_xyxy(item) for item in coord]
        else:
            # Fix: an unknown format previously fell through and crashed with
            # a NameError on `bb`.
            raise ValueError(f"unsupported output_bb_format: {output_bb_format!r}")
    return text, bb
class MybenchTasks:
    """Loads benchmark task trajectories from a data directory.

    Walks `data_path` for folders containing a `templates.json`, gathers the
    per-app (task, trajectory) pairs via `get_app_task_trajectories`, and
    converts each whitespace-separated action string into an `Action` object.
    """

    def __init__(self, data_path):
        # app name -> list of (task, trajectory) pairs, merged across all
        # domain folders found under data_path
        self.app_task_trajectories = {}
        for root, _, files in os.walk(data_path):
            if 'templates.json' in files:
                domain_app_task_trajectories = get_app_task_trajectories(root)
                for app, tasks in domain_app_task_trajectories.items():
                    if app not in self.app_task_trajectories:
                        self.app_task_trajectories[app] = []
                    self.app_task_trajectories[app].extend(tasks)
        # Re-encode raw action strings ("name arg1 arg2 ...") as Action
        # objects whose params are indexed by position ("0", "1", ...).
        # NOTE(review): assumes actions are space-separated strings — confirm
        # against the template format produced by train.task_template.
        for app in self.app_task_trajectories.keys():
            old_task_trajectories = self.app_task_trajectories[app]
            new_task_trajectories = []
            for task, trajectory in old_task_trajectories:
                new_trajectory = []
                for action in trajectory:
                    fields = action.split(" ")
                    new_trajectory.append(Action(name=fields[0], param={str(i) : field for i, field in enumerate(fields[1:])}, extra={}))
                new_task_trajectories.append((task, new_trajectory))
            self.app_task_trajectories[app] = new_task_trajectories

    def get_app_task_trajectories(self):
        """Return the app -> [(task, trajectory), ...] mapping."""
        return self.app_task_trajectories
class MybenchEnvironment(Environment):
    """Replay environment that scores an agent's actions against ground truth.

    Tracks across a run: total executed steps (successful tasks only), number
    of tasks attempted, and number of tasks whose trajectory fully matched.
    """

    def __init__(self, agent: MybenchAgent):
        super().__init__()
        self.agent = agent
        self.cur_task = ""
        self.reset_cnt()
        self.reset_cur_task()

    def reset_cnt(self):
        """Zero the run-wide counters."""
        self.execute_cnt = 0
        self.total_task_cnt = 0
        self.correct_task_cnt = 0

    def reset_cur_task(self):
        """Start bookkeeping for a fresh task."""
        self.cur_execute_cnt = 0
        self.cur_success = True

    def print_cnt(self):
        """Dump the run-wide counters."""
        print(f"execute_cnt: {self.execute_cnt}, total_task_cnt: {self.total_task_cnt}, correct_task_cnt: {self.correct_task_cnt}")

    def get_agent_input(self, history, task_description):
        """Advance the agent's step pointer and build its input payload."""
        # The step pointer advances on every call, even mid-task.
        self.agent.task_step[task_description] += 1
        if task_description != self.cur_task:
            # First call for this task: reset per-task state and count it.
            self.reset_cur_task()
            self.total_task_cnt += 1
            self.cur_task = task_description
        return {"task": task_description, "history": history}

    def execute(self, action):
        """Compare `action` with the ground-truth step; record mismatches."""
        print(f"env executing: {action}")
        self.cur_execute_cnt += 1
        trajectory = self.agent.task_trajectory[self.cur_task]
        step = self.agent.task_step[self.cur_task]
        if step >= len(trajectory):
            # The agent should have emitted "done" before running past the end.
            print(f"incorrect: expected done action, actual action is {action}")
            self.cur_success = False
            return
        ground_truth = trajectory[step]
        if action != ground_truth:
            print(f"incorrect: {action} != {ground_truth} in step {step}")
            self.cur_success = False

    def check_done(self):
        """Finalize the current task: credit it only if every step matched."""
        step = self.agent.task_step[self.cur_task]
        self.cur_execute_cnt += 1  # the terminating "done" counts as a step
        if not self.cur_success:
            self.agent.reset_cur_task(account=False)
            return
        if step == len(self.agent.task_trajectory[self.cur_task]):
            # "done" arrived exactly at the end of the trajectory: success.
            self.execute_cnt += self.cur_execute_cnt
            self.correct_task_cnt += 1
        else:
            self.cur_success = False
            print("incorrect: done mismatch")
        self.agent.reset_cur_task(account=self.cur_success)
MybenchAgent(MybenchTasks(args.data_path)) + env = MybenchEnvironment(agent) + tree = ActionTree(env, agent, Action, done=lambda a: a.name == 'done', + mode=MatchMode.FUZZY, + embedder_config={ + "path": args.embedder_path + }, + reranker_config={ + "path": args.reranker_path + }) + + app_task_trajectories = agent.tasks.get_app_task_trajectories() + for app, task_trajectories in app_task_trajectories.items(): + print(f"Current app: {app}") + tree.clear() + for task, _ in task_trajectories: + print(f"Current task: {task}") + tree.execute(task) + env.check_done() + if not env.cur_success: + tree.root.remove_task_trace(Task(task)) + env.print_cnt() + agent.print_cnt() + input("Press enter to continue") + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--embedder_path', type=str, required=True) + parser.add_argument('--reranker_path', type=str, required=True) + parser.add_argument('--data_path', type=str, required=True) + args = parser.parse_args() + main(args) diff --git a/agent_rr/train/__init__.py b/agent_rr/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agent_rr/train/prepare_data.py b/agent_rr/train/prepare_data.py new file mode 100644 index 0000000..ded68cc --- /dev/null +++ b/agent_rr/train/prepare_data.py @@ -0,0 +1,248 @@ +import json, os, itertools +import random +from sklearn.model_selection import train_test_split +import re + +from .task_template import get_app_task_trajectories + +NUM_NEGATIVE = 10 +MAX_REPEAT_TIMES = 10 +EMBEDDING_QUERY_FORMAT = "Instruct: Given a phone-use task, retrieve similar tasks that shares at least **{n}** steps with the given task\nQuery:{query}" +EMBEDDING_INSTRUCT_FORMAT = "Instruct: Represent this phone-use task for level **{n}**\nQuery:{query}" + +def get_lshare(trajectory1, trajectory2): + k = 0 + l_share = 0 + while k < len(trajectory1) and k < len(trajectory2): + if trajectory1[k] == trajectory2[k]: + l_share += 1 + else: + break + k += 
1
+    return l_share
+
+def single_split_embedding(path):
+    """Build embedding triples (query, response, rejected_response) from every
+    domain under *path* (any directory containing a templates.json).
+
+    Two tasks form a positive pair at level n when their trajectories share a
+    prefix of at least n steps; negatives are same-app tasks (any-app at
+    n == 1), topped up with tasks sampled from other apps when empty.
+    """
+    task_trajectories = []
+    app_tasks = {}
+    for root, _, files in os.walk(path):
+        if 'templates.json' in files:
+            app_task_trajectories = get_app_task_trajectories(root)
+            app_tasks.update({app: list(set(task for task, _ in app_task_trajectories[app])) for app in app_task_trajectories})
+            task_trajectories.extend(list(set((task, tuple(trajectory)) for app in app_task_trajectories for task, trajectory in app_task_trajectories[app])))
+    entries = []
+    task_positives = {}
+    task_negatives = {}
+    task_app = {}
+    # Pairwise pass: classify every task pair as positive/negative per level n.
+    for i, j in itertools.combinations(range(len(task_trajectories)), 2):
+        task1, trajectory1 = task_trajectories[i]
+        task2, trajectory2 = task_trajectories[j]
+        # The app name is the second field of the first action (angle brackets stripped).
+        app1 = trajectory1[0].split(' ')[1].replace("<", "").replace(">", "")
+        app2 = trajectory2[0].split(' ')[1].replace("<", "").replace(">", "")
+        task_app[task1] = app1
+        task_app[task2] = app2
+        is_same_app = app1 == app2
+        l_share = get_lshare(trajectory1, trajectory2)
+        if task1 not in task_positives:
+            task_positives[task1] = {}
+            task_negatives[task1] = {}
+        if task2 not in task_positives:
+            task_positives[task2] = {}
+            task_negatives[task2] = {}
+        for n in range(1, len(trajectory1) + 1):
+            if n not in task_positives[task1]:
+                # A task is always its own positive.
+                task_positives[task1][n] = set([task1])
+            if n not in task_negatives[task1]:
+                task_negatives[task1][n] = set()
+            if l_share >= n:
+                task_positives[task1][n].add(task2)
+            elif n == 1 or is_same_app:
+                task_negatives[task1][n].add(task2)
+        for n in range(1, len(trajectory2) + 1):
+            if n not in task_positives[task2]:
+                task_positives[task2][n] = set([task2])
+            if n not in task_negatives[task2]:
+                task_negatives[task2][n] = set()
+            if l_share >= n:
+                task_positives[task2][n].add(task1)
+            elif n == 1 or is_same_app:
+                task_negatives[task2][n].add(task1)
+    for task in task_positives.keys():
+        positive = task_positives[task]
+        negative = task_negatives[task]
+        for n in positive.keys():
+            if len(positive[n]) == 1:
+                # Only the task itself: add polite-prefix paraphrases as positives.
+                positive[n] |= set([f"请{task}", f"请你{task}", f"请帮我{task}", f"帮我{task}", f"请你帮我{task}"])
+            if len(negative[n]) == 0 and n > 1:
+                # sample some tasks from other apps
+                # NOTE(review): raises ZeroDivisionError when there is only one app.
+                other_apps = [app for app in app_tasks if app != task_app[task]]
+                for app in other_apps:
+                    sample_num = (10 * NUM_NEGATIVE + len(negative[n]) - 1) // len(other_apps)
+                    sample_num = min(sample_num, len(app_tasks[app]))
+                    sampled_tasks = random.sample(app_tasks[app], sample_num)
+                    negative[n].update(set(sampled_tasks))
+            # print(f"Task: {task}, n: {n}, positives: {len(positive[n])}, negatives: {len(negative[n])}")
+            # Repeat each positive enough times that every negative can appear,
+            # with NUM_NEGATIVE rejected samples per entry.
+            repeat_times = (len(negative[n]) + len(positive[n]) - 1) // len(positive[n])
+            repeat_times = (repeat_times + NUM_NEGATIVE - 1) // NUM_NEGATIVE
+            repeat_times = max(1, repeat_times)
+            start = 0
+            rejected = list(negative[n])
+            random.shuffle(rejected)
+            old_num_entries = len(entries)
+            for positive_task in positive[n]:
+                # query = EMBEDDING_QUERY_FORMAT.format(n=n, query=task)
+                query = EMBEDDING_INSTRUCT_FORMAT.format(n=n, query=task)
+                # response = positive_task
+                response = EMBEDDING_INSTRUCT_FORMAT.format(n=n, query=positive_task)
+                for _ in range(repeat_times):
+                    if start >= len(rejected):
+                        start = 0
+                    end = start + NUM_NEGATIVE
+                    end = min(end, len(rejected))
+                    entries.append({
+                        "query": query,
+                        "response": response,
+                        # "rejected_response": [rejected[start:end]]
+                        "rejected_response": [EMBEDDING_INSTRUCT_FORMAT.format(n=n, query=t) for t in rejected[start:end]],
+                    })
+                    start = end
+            # print(f"Increment: {len(entries) - old_num_entries}")
+    return entries
+
+
+def balance_embedding(entries):
+    """Oversample under-represented levels so each level (the **n** embedded in
+    the query string) has at least half as many entries as the largest one."""
+    level_entries = {}
+    for entry in entries:
+        match = re.search(r'\*\*(\d+)\*\*', entry['query'])
+        if match:
+            n = int(match.group(1))
+            if n not in level_entries:
+                level_entries[n] = []
+            level_entries[n].append(entry)
+    max_len = max(len(level_entries[n]) for n in level_entries)
+    for n in level_entries:
+        if len(level_entries[n]) < max_len // 2:
+            multiplier = max_len // 2 // len(level_entries[n])
+            level_entries[n] = level_entries[n] * multiplier
+    ret = []
+    for n in level_entries:
+        ret.extend(level_entries[n])
+    return ret
+
+def embedding_main(train_path, test_path):
+    """Write embedder train/test JSONL files into the working directory.
+
+    With test_path=None, 10% of the training entries become the test split.
+    Balancing is applied to the training split only.
+    """
+    entries_train = single_split_embedding(train_path)
+    if test_path is None:
+        entries_train, entries_test = train_test_split(entries_train, test_size=0.1, random_state=42)
+    else:
+        entries_test = single_split_embedding(test_path)
+    entries_train = balance_embedding(entries_train)
+    # random.shuffle(entries_train)
+    # random.shuffle(entries_test)
+    with open('embedding_mybench_data.jsonl', 'w', encoding='utf-8') as f:
+        for entry in entries_train:
+            f.write(json.dumps(entry, ensure_ascii=False) + '\n')
+    with open('embedding_mybench_data_test.jsonl', 'w', encoding='utf-8') as f:
+        for entry in entries_test:
+            f.write(json.dumps(entry, ensure_ascii=False) + '\n')
+
+RERANKER_SYSTEM = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."
+POSITIVE_TOKEN = "yes" +NEGATIVE_TOKEN = "no" +RERANKER_INPUT_FORMAT = ": Given a phone-use task, retrieve similar tasks that shares at least **{n}** steps with the given task\n: {query} \n: {document}" +RERANKER_OUTPUT_FORMAT = "\n\n\n\n{token}" + +def single_app_reranker(task_trajectory_pairs): + entries = [] + max_len = max(len(trajectory) for _, trajectory in task_trajectory_pairs) + for i, j in itertools.combinations(range(len(task_trajectory_pairs)), 2): + task1, trajectory1 = task_trajectory_pairs[i] + task2, trajectory2 = task_trajectory_pairs[j] + l_share = get_lshare(trajectory1, trajectory2) + for n in range(1, max_len + 1): + token = POSITIVE_TOKEN if l_share >= n else NEGATIVE_TOKEN + input_text1 = RERANKER_INPUT_FORMAT.format(n=n, query=task1, document=task2) + input_text2 = RERANKER_INPUT_FORMAT.format(n=n, query=task2, document=task1) + output_text = RERANKER_OUTPUT_FORMAT.format(token=token) + entries.extend([ + { + "system": RERANKER_SYSTEM, + "input": input_text, + "output": output_text + } + for input_text in [input_text1, input_text2] + ]) + return entries + +def single_domain_reranker(domain_dir): + app_task_trajectories = get_app_task_trajectories(domain_dir) + entries = [] + for app in app_task_trajectories: + entries.extend(single_app_reranker(app_task_trajectories[app])) + return entries + +def cross_app_step1_reranker(path): + app_tasks = {} + task_app = {} + for root, _, files in os.walk(path): + if 'templates.json' in files: + app_task_trajectories = get_app_task_trajectories(root) + app_tasks.update({app: list(set(task for task, _ in app_task_trajectories[app])) for app in app_task_trajectories}) + task_app.update({task: app for app in app_task_trajectories for task, _ in app_task_trajectories[app]}) + entries = [] + for task in task_app: + app = task_app[task] + other_apps = [a for a in app_tasks if a != app] + for other_app in other_apps: + sample_num = (30 + len(app_tasks[app]) - 1) // len(other_apps) + sample_num = min(sample_num, 
len(app_tasks[other_app])) + sampled_tasks = random.sample(app_tasks[other_app], sample_num) + for sampled_task in sampled_tasks: + input_text = RERANKER_INPUT_FORMAT.format(n=1, query=task, document=sampled_task) + output_text = RERANKER_OUTPUT_FORMAT.format(token=NEGATIVE_TOKEN) + entries.append({ + "system": RERANKER_SYSTEM, + "input": input_text, + "output": output_text + }) + return entries + +def reranker_main(train_path, test_path): + entries_train = [] + entries_test = [] + for root, _, files in os.walk(train_path): + if 'templates.json' in files: + entries_train.extend(single_domain_reranker(root)) + entries_train.extend(cross_app_step1_reranker(train_path)) + if test_path is None: + entries_train, entries_test = train_test_split(entries_train, test_size=0.1, random_state=42) + else: + for root, _, files in os.walk(test_path): + if 'templates.json' in files: + entries_test.extend(single_domain_reranker(root)) + entries_test.extend(cross_app_step1_reranker(test_path)) + + with open('reranker_mybench_data.jsonl', 'w', encoding='utf-8') as f: + for entry in entries_train: + f.write(json.dumps(entry, ensure_ascii=False) + '\n') + with open('reranker_mybench_data_test.jsonl', 'w', encoding='utf-8') as f: + for entry in entries_test: + f.write(json.dumps(entry, ensure_ascii=False) + '\n') + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, choices=['embedding', 'reranker', 'both'], required=True, + help="Specify the task to generate data for.") + parser.add_argument('--train_path', type=str, default='train/train_data', + help="Path to the training data directory.") + parser.add_argument('--test_path', type=str, default='train/test_data', + help="Path to the test data directory.") + args = parser.parse_args() + if args.task == 'embedding': + embedding_main(args.train_path, args.test_path) + elif args.task == 'reranker': + reranker_main(args.train_path, args.test_path) + elif args.task == 
'both':
+        embedding_main(args.train_path, args.test_path)
+        reranker_main(args.train_path, args.test_path)
+    else:
+        print("Invalid task specified. Use 'embedding' or 'reranker' or 'both'.")
diff --git a/agent_rr/train/task_template.py b/agent_rr/train/task_template.py
new file mode 100644
index 0000000..2828d3e
--- /dev/null
+++ b/agent_rr/train/task_template.py
@@ -0,0 +1,88 @@
+import os, itertools, json
+
+def get_task_templates(raw_template):
+    """Expand every parenthesized option group in *raw_template*.
+
+    A group "(a|b|NULL)" contributes the alternatives "a", "b" and the empty
+    string ("NULL" is removed before splitting on '|'); the cartesian product
+    over all groups yields one concrete template per combination.
+    """
+    results = []
+    left_pos = 0
+    # Scan non-nested (...) groups from left to right.
+    while True:
+        left = raw_template.find('(', left_pos)
+        if left == -1:
+            break
+        right = raw_template.find(')', left + 1)
+        if right == -1:
+            break
+
+        content = raw_template[left + 1: right]
+        contents = content.replace("NULL", "").split('|')
+
+        results.append({
+            'left': left,
+            'right': right,
+            'contents': contents
+        })
+
+        left_pos = right + 1
+    combinations = itertools.product(*[result['contents'] for result in results])
+    task_templates = []
+    # Rebuild the template once per combination, splicing each choice in place.
+    for combination in combinations:
+        segments = []
+        last_left = 0
+        for i, content in enumerate(combination):
+            left = results[i]['left']
+            right = results[i]['right']
+            segments.append(raw_template[last_left:left])
+            segments.append(content)
+            last_left = right + 1
+        segments.append(raw_template[last_left:])
+        task_templates.append(''.join(segments))
+    return task_templates
+
+def get_trajectory(trajectory_template, fmt):
+    """Substitute fmt values for {key} and <key> placeholders in each action.
+
+    A <key> placeholder keeps angle brackets around the substituted value.
+    """
+    trajectory = []
+    for act in trajectory_template:
+        for k in fmt.keys():
+            if f"{{{k}}}" in act:
+                act = act.replace(f"{{{k}}}", fmt[k])
+            if f"<{k}>" in act:
+                act = act.replace(f"<{k}>", f"<{fmt[k]}>")
+        trajectory.append(act)
+    return trajectory
+
+def get_app_task_trajectories(domain_dir):
+    """Read templates.json in *domain_dir* and expand its templates into
+    concrete (task, trajectory) pairs grouped by app."""
+    with open(os.path.join(domain_dir, "templates.json"), encoding='utf-8') as f:
+        templates = json.load(f)
+    print(f"Domain: {domain_dir}")
+    task_trajectories = {}
+    for template in templates:
+        raw_task_template = template["task"]
+        trajectory_template = template["trajectory"]
+        task_templates = get_task_templates(raw_task_template)
+ candidates = template["candidates"] + dependency = template.get("dependency", "no") + keys = list(candidates.keys()) + + if dependency == "one-to-one": + combinations = zip(*[candidates[k] for k in keys]) + elif dependency == "no": + combinations = itertools.product(*[candidates[k] for k in keys]) + else: + print(f"Unknonw dependency type: {dependency}") + continue + + for combination in combinations: + fmt = {} + for i, k in enumerate(keys): + fmt[k] = combination[i] + trajectory = get_trajectory(trajectory_template, fmt) + for task_template in task_templates: + task = task_template.format(**fmt) + task_trajectories[task] = trajectory + # print(task_trajectories) + app_task_trajectories = {} + for task, trajectory in task_trajectories.items(): + app = trajectory[0].split(' ')[1] + app = app.replace("<", "").replace(">", "") + if app not in app_task_trajectories: + app_task_trajectories[app] = [] + app_task_trajectories[app].append((task, trajectory)) + + return app_task_trajectories \ No newline at end of file diff --git a/agent_rr/train/train_data_example/wechat/templates.json b/agent_rr/train/train_data_example/wechat/templates.json new file mode 100644 index 0000000..4f6f17d --- /dev/null +++ b/agent_rr/train/train_data_example/wechat/templates.json @@ -0,0 +1,72 @@ +[ + { + "task": "发一条微信朋友圈,内容为{content}", + "trajectory": [ + "click 微信", + "click 发现", + "click 朋友圈", + "longclick 相机图标", + "input {content}", + "click 发表" + ], + "candidates": { + "content": ["和朋友们聚会很开心", "今日早餐打卡", "努力工作的一天", "周末快乐", "工作加油"] + } + }, + { + "task": "给{user}发一条微信,内容为{content}", + "trajectory": [ + "click 微信", + "click ", + "click 输入框", + "input {content}", + "click 发送" + ], + "candidates": { + "user": ["alice", "bob", "charlie"], + "content": ["你好", "下午几点出来", "请查收邮件"] + }, + "dependency": "no" + }, + { + "task": "搜索并打开微信小程序{app}", + "trajectory": [ + "click 微信", + "click 放大镜图标", + "click 小程序选项", + "input {app}", + "click 搜索", + "click " + ], + "candidates": { + "app": ["美团外卖", 
"肯德基"] + } + }, + { + "task": "打开微信{service}服务", + "trajectory": [ + "click 微信", + "click 个人中心", + "click 服务", + "click " + ], + "candidates": { + "service": ["手机充值", "生活缴费", "酒店民宿"] + } + }, + { + "task": "搜索并关注微信公众号{account}", + "trajectory": [ + "click 微信", + "click 放大镜图标", + "click 公众号选项", + "input {account}", + "click 搜索", + "click ", + "click 关注公众号" + ], + "candidates": { + "account": ["机器之心", "人民日报", "物理竞赛", "豆瓣"] + } + } +] \ No newline at end of file diff --git a/assets/arch.png b/assets/arch.png new file mode 100644 index 0000000..031b9c4 Binary files /dev/null and b/assets/arch.png differ diff --git a/assets/arch_zh.png b/assets/arch_zh.png new file mode 100644 index 0000000..94333a2 Binary files /dev/null and b/assets/arch_zh.png differ diff --git a/assets/logo.png b/assets/logo.png new file mode 100644 index 0000000..7db3e45 Binary files /dev/null and b/assets/logo.png differ diff --git a/assets/result1.png b/assets/result1.png new file mode 100644 index 0000000..7c0cf8d Binary files /dev/null and b/assets/result1.png differ diff --git a/assets/result2.png b/assets/result2.png new file mode 100644 index 0000000..d34c154 Binary files /dev/null and b/assets/result2.png differ diff --git a/assets/result3.png b/assets/result3.png new file mode 100644 index 0000000..0aa7829 Binary files /dev/null and b/assets/result3.png differ diff --git a/assets/result_agentrr.png b/assets/result_agentrr.png new file mode 100644 index 0000000..d8309c6 Binary files /dev/null and b/assets/result_agentrr.png differ diff --git a/collect/README.md b/collect/README.md new file mode 100644 index 0000000..c890e6e --- /dev/null +++ b/collect/README.md @@ -0,0 +1,252 @@ +# 数据收集标注工具 + +## 数据收集 + +### 数据格式 + +通过人工/自动收集工具,收集每个action前的手机截图,并记录每个action的信息,并汇总到一个actions.json文件中。action格式如下: +``` +{{ + "app_name": str + "task_description": ["The description of the task list."], + "action_count": "The count of the actions.", + "actions": [ + {{ + "type": "The type of the action", + 
"parameters": "etc.", + }}, + {{ + "type": "click", + "position_x": "x-coordinate of click", + "position_y": "y-coordinate of click action", + "bounds": "the bound of the clicked element", + }}, + {{ + "type": "swipe", + "press_position_x": "x-coordinate of press", + "press_position_y": "y-coordinate of press", + "release_position_x": "x-coordinate of release", + "release_position_y": "y-coordinate of release", + "direction": "The direction of the user's swipe gesture. UP: swipe finger upward to scroll content up and reveal content below. DOWN: swipe finger downward to scroll content down and reveal content above. LEFT: swipe finger leftward to scroll content left. RIGHT: swipe finger rightward to scroll content right." + }}, + {{ + "type": "input", + "text": "The text to input", + }}, + {{ + "type": "done" + }}, + {{ + "type": "wait" + }}, + ] +}} +``` + +### 手动数据收集 + +**启动服务器** +```bash +python -m collect.manual.server +``` +启动成功后,访问 http://localhost:9000 进入Web操作界面。 + +**操作步骤** + +1. **开始收集**:在Web界面点击 **开始收集** 按钮 + +2. **配置应用信息**:在弹出的 **应用信息配置** 窗口中填写: +- **应用名称**:如 "饿了么"、"微信"、"淘宝" 等 +- **任务类型**:如 “tpye1”、 “tpye2” 等,具体参考收集任务文档 + +3. **输入任务描述** + - 在 **任务描述** 窗口中详细描述当前要执行的具体任务 + - 确保描述清晰明确,便于后续数据分析和模型训练 + +4. **执行操作** + - 在Web界面的手机截图上进行以下操作: + - **点击操作**:直接点击截图上的目标位置 + - **滑动操作**:按住鼠标左键拖拽到目标位置后松开(注意保持在屏幕范围内) + - **文本输入**:点击 **文本输入** 按钮,在弹出框中输入文本内容 + +5. **保存数据** + - 完成一个任务序列后,根据需要选择: + - **下一条数据**:继续收集同类型任务的更多数据样本 + - **结束收集**:完成当前收集会话并保存所有数据 + - **删除任务**:丢弃当前数据(用于处理错误操作或无效数据) + +**数据存储格式** + +收集的数据自动保存到 `collect/manual/data/` 目录,按以下层级结构组织: + +``` +data/ +├── <应用名称>/ +│ ├── <任务类型>/ +│ │ ├── 1/ +│ │ │ ├── 1.jpg # 第1个操作前的截图 +│ │ │ ├── 2.jpg # 第2个操作前的截图 +│ │ │ ├── ... +│ │ │ └── actions.json # 操作记录和任务信息 +│ │ ├── 2/ +│ │ │ └── ... # 第2条数据 +│ │ └── ... 
+│ └── <其他任务类型>/ +└── <其他应用名称>/ +``` + +每个数据样本包含: +- **截图序列**:记录每个操作步骤前的界面状态 +- **actions.json**:包含完整的操作序列、任务描述和应用信息 + +### 自动数据收集 +先在 `collect/auto/task.json` 写入需要完成的任务列表,格式为字符串数组: +```json +[ + "在淘宝搜索iPhone手机", + "在微信给张三发消息说你好", + "在b站关注up主李四" +] +``` + +运行自动数据收集程序: +```bash +python -m collect.auto.server --model <模型名称> --api_key --base_url [--max_steps <最大步数>] +``` + +**必需参数:** +- `--model`:LLM模型名称 +- `--api_key`:API密钥 +- `--base_url`:API基础URL + +**可选参数:** +- `--max_steps`:每个任务的最大执行步数,默认为 15 + +**工作流程:** +1. 程序读取 `task.json` 中的任务列表 +2. 对每个任务: + - AI智能体根据任务描述自动选择并启动相应的应用 + - 自动执行操作序列(点击、滑动、输入等) + - 每步操作前自动截图并记录操作信息 + - 达到最大步数或任务完成时停止 +3. 自动保存数据到指定目录 + +**存储数据格式:** +- 原始日志数据存储在 `collect/auto/data_log/` +- 转换后的标准格式数据存储在 `collect/auto/data/` +- 数据结构与手动收集保持一致,包含截图序列和 `actions.json` 文件 + +## 数据标注 + +数据标注模块将原始的操作数据转换为带有视觉标注的数据,为通用AI模型提供更丰富的上下文信息,使得其能够提供更加准确的reasoning。 + +### 视觉标注格式 + +**操作标注** +- 用户每个时间步的操作以 **红色字体** 标注在对应截图的顶部 +- 辅助信息同时在截图中进行可视化标注: + - **点击操作**:在操作位置标注 **红色圆圈** + - **滑动操作**:用 **红色箭头** 标示从起始位置到结束位置的方向 + +**数据生成** +系统将标注后的截图序列和任务描述发送给大模型,生成 `react.json` 文件,包含推理过程和操作决策: + +```json +[ + { + "reasoning": "选择此操作类型的推理过程和原因", + "function": { + "name": "click", + "parameters": { + "target_element": "点击目标的高级语义描述" + } + } + }, + { + "reasoning": "滑动操作的推理过程", + "function": { + "name": "swipe", + "parameters": { + "direction": "UP, DOWN, LEFT, RIGHT" + } + } + }, + { + "reasoning": "文本输入的推理过程", + "function": { + "name": "input", + "parameters": { + "text": "要输入的文本内容" + } + } + }, + { + "reasoning": "任务完成的判断依据", + "function": { + "name": "done", + "parameters": {} + } + }, + { + "reasoning": "等待操作的原因说明", + "function": { + "name": "wait", + "parameters": {} + } + } +] +``` + +### 自动标注执行 + +**启动命令** +```bash +python -m collect.annotate --data_path <数据路径> --model <模型名称> --api_key --base_url +``` + +**参数说明** +- `--data_path`:原始轨迹数据存储路径(可选,默认为当前目录下的 `data` 目录) +- `--model`:大语言模型名称(必需) +- `--api_key`:模型服务API密钥(必需) +- `--base_url`:模型服务基础URL(必需) + +**处理流程** +1. 
读取原始数据目录中的截图序列和 `actions.json` 文件 +2. 根据操作信息在截图上添加视觉标注 +3. 将标注后的数据发送给大模型进行推理分析 +4. 生成包含推理过程的 `react.json` 文件 +5. 保存完整的标注数据集,用于后续模型训练 + +**数据存储格式** + +收集的数据自动保存到对应目录,最小的子目录有如下结构: +``` +dir/ +├── 1.jpg # 第1个操作前的截图 +├── 2.jpg # 第2个操作前的截图 +├── ... +└── actions.json # 操作记录和任务信息 +└── react.json # 标注数据 +``` + +## 数据构建 + +数据构建模块将标注后的数据转换为适合模型训练的格式,支持监督微调(SFT)数据集的生成。 + +### 启动命令 + +```bash +python -m collect.construct_sft --data_path <原始数据路径> --ss_data_path <单步数据路径> --unexpected_img_path <意外图片路径> --out_path <输出路径> [--factor <缩放因子>] [--train_ratio <训练比例>] +``` + +### 参数说明 + +**必需参数** +- `--data_path`:原始轨迹数据存储路径(默认:`data`) +- `--ss_data_path`:单步数据存储路径(默认:`ss_data`) +- `--unexpected_img_path`:意外图片数据路径(默认:`unexpected_img`) +- `--out_path`:训练数据集输出路径(默认:`output`) + +**可选参数** +- `--factor`:图片缩放因子,用于减小图片尺寸(默认:`0.5`) +- `--train_ratio`:训练集与验证集的划分比例(默认:`0.9`) \ No newline at end of file diff --git a/collect/annotate.py b/collect/annotate.py new file mode 100644 index 0000000..1027d6e --- /dev/null +++ b/collect/annotate.py @@ -0,0 +1,403 @@ +from PIL import Image, ImageDraw, ImageFont +import textwrap +import cv2 +import numpy as np +import os + +import argparse +from argparse import Namespace + +from langchain_openai import ChatOpenAI +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +import base64, re +import json + +from utils.parse_omni import extract_all_bounds, find_clicked_element + +from utils.load_md_prompt import load_prompt + +model = None + +direction_mapping = { + "向上滑动": "UP", + "向下滑动": "DOWN", + "向左滑动": "LEFT", + "向右滑动": "RIGHT", + "从下往上滑动": "UP", + "从上往下滑动": "DOWN", + "从右往左滑动": "LEFT", + "从左往右滑动": "RIGHT", + "向上滚动": "UP", + "向下滚动": "DOWN", + "向左滚动": "LEFT", + "向右滚动": "RIGHT", + "从下往上滚动": "UP", + "从上往下滚动": "DOWN", + "从右往左滚动": "LEFT", + "从左往右滚动": "RIGHT", +} + +# 前者是actions.json 后者是react.json 对应的解析内容 +def compare_actions(actions, reacts): + if (len(actions) != len(reacts)): + raise Exception(f"[Action and React length mismatch] actions: 
{len(actions)}, reacts: {len(reacts)}")
+
+    for i, (action, react) in enumerate(zip(actions, reacts)):
+        # Compare action types (case-insensitive).
+        action_type = action.get("type", "").lower()
+        react_type = react.get("function").get("name", "").lower()
+
+        if action_type != react_type:
+            raise Exception(f"[type mismatch] Action {i+1}: action type {action_type},react type {react_type}")
+
+        reasoning = react["reasoning"]
+
+        # Abandoned for now: if the reasoning mentions scroll/swipe, force the type to swipe.
+        # for desc, expected_direction in direction_mapping.items():
+        #     if desc in reasoning:y
+        #         if react_type != "swipe":
+        #             raise Exception(f"[Swipe action is expected] action {i+1} action: {action}, react: {react}, reasoning: {reasoning}")
+        #         break
+
+        if(action_type == "swipe"):
+            action_direction = action["direction"].upper() if "direction" in action else None
+
+            if(react_type == "swipe"):
+                # The field inside parameters may not be "direction" but e.g. "target".
+                if "parameters" not in react["function"] or "direction" not in react["function"]["parameters"]:
+                    raise Exception(f"[Swipe action missing parameters] React {i+1}: {react}")
+
+                react_direction = react["function"]["parameters"]["direction"]
+
+                if(action_direction != react_direction):
+                    raise Exception(f"[direction mismatch] Action {i+1}: action_direction: {action_direction}, react_direction: {react_direction}")
+
+                # The reasoning text must mention a direction consistent with the action.
+                flag = False
+                for desc, expected_direction in direction_mapping.items():
+                    if desc in reasoning:
+                        if react_direction == expected_direction:
+                            flag = True
+                            break
+                        else:
+                            raise Exception(f"[Swipe reasoning direction mismatch] Action {i+1}: action_direction: {action_direction}, react: {react}")
+                if not flag:
+                    raise Exception(f"[Swipe reasoning hasn't direction description] Action {i+1}: action_direction: {action_direction}, react: {react}")
+
+change_task_description_prompt = load_prompt("change_task_description.md")
+def change_task_description(app_name, original_task):
+    """Paraphrase *original_task* into 6 variants via the LLM.
+
+    When app_name is given, the first 3 variants omit the app name and the
+    last 3 include it.  Retries up to 3 times; raises after the final parse
+    failure inside the try-block.
+    """
+    count = 6
+    max_attempts = 3
+    for attempt in range(max_attempts):
+        try:
+            prompt = ChatPromptTemplate.from_messages([
+                ("system", "{sys_prompt}"),
+                ("user", "{user_message}")
+            ])
+            chain = prompt | model
+            response = chain.invoke({
+                "sys_prompt": change_task_description_prompt.replace("{app_name}", app_name if app_name else "").replace("{original_task}", original_task).replace("{count}", str(count)),
+                "user_message": f"请将任务'{original_task}'改写成{count}个版本" + (f"(前3条不带应用名称,后3条带应用名称)" if app_name else "(都不带应用名称)")
+            })
+
+            # Extract the JSON content.
+            pattern = re.compile(r"```json\n(.*?)\n```", re.DOTALL)
+            match = pattern.search(response.content)
+
+            if match is None:
+                # No fenced json block found: try parsing the whole response directly.
+                try:
+                    data = json.loads(response.content.strip())
+                    if isinstance(data, list) and len(data) == count:
+                        return data
+                except:
+                    pass
+                print(f"[Generate Task] Attempt {attempt + 1} failed, no valid JSON found in response.")
+                # NOTE(review): on the final attempt this path ends the loop
+                # without raising, so the function implicitly returns None.
+                continue
+
+            json_str = match.group(1)
+            data = json.loads(json_str)
+
+            if not isinstance(data, list):
+                raise Exception("Response is not a list")
+
+            if len(data) != count:
+                raise Exception(f"Expected {count} tasks, got {len(data)}")
+
+            # Validate that every element is a non-empty string.
+            for i, task in enumerate(data):
+                if not isinstance(task, str) or not task.strip():
+                    raise Exception(f"Task {i+1} is not a valid string")
+
+            return data
+
+        except Exception as e:
+            print(f"[Generate Task] Attempt {attempt + 1} failed, error: {str(e)}")
+            if attempt == max_attempts - 1:
+                raise Exception(f"Failed to generate tasks after {max_attempts} attempts")
+            continue
+
+
+def add_action_index(actions):
+    """Add an action_index field to every dict element of actions."""
+    for i, action in enumerate(actions):
+        if isinstance(action, dict):
+            action['action_index'] = i + 1
+    return actions
+
+def add_bounds_to_action(root, actions):
+    """Backfill missing/None click bounds from elements parsed out of the
+    step's screenshot (utils.parse_omni)."""
+    for i, action in enumerate(actions):
+        flag = False
+
+        if action["type"] == "click":
+            if not "bounds" in action:
+                print(f"[Add Bounds] {root} Action {i + 1} no bounds, adding bounds")
+                flag = True
+            elif action["bounds"] is None:
+                print(f"[Add Bounds] {root} Action {i + 1} bounds is None, adding bounds")
+                flag = True
+
+        # bounds = action.get("bounds", None)
+        # if bounds is not None:
+        #     x1, y1, x2, y2 = bounds
+        #     # if x1 < 50 and x2 > 950:
+        #     if x1 < 100 and x2 > 950:
+        #         print(f"[Add Bounds] {root} Action {i + 1} bounds is Special")
+        #         flag = True
+
+        # flag = action["type"] == "click"
+
+        if flag:
+            img_path = os.path.join(root, f"{i + 1}.jpg")
+            if not os.path.exists(img_path):
+                raise Exception(f"[Add Bounds] Image not found: {img_path}")
+
+            # actions[i]["bounds"] = None
+            # bounds_list = extract_all_bounds(img_path)
+            # actions[i]["bounds"] = find_clicked_element(bounds_list, action["position_x"], action["position_y"])
+
+            bounds_list = extract_all_bounds(img_path)
+            if "bounds" in action and action["bounds"]:
+                bounds_list.append(action["bounds"])
+            actions[i]["bounds"] = find_clicked_element(bounds_list, action["position_x"], action["position_y"])
+
+    return actions
+
+def visual_prompt(root, actions):
+    """Annotate every step screenshot: action text in red at the top, a red
+    circle for clicks, a red arrow for swipes; saves i_highlighted.jpg (and
+    i_bounds.jpg when click/long-press bounds are available)."""
+    print(f"[Visual Prompt] {root} begin")
+    for file_name in os.listdir(root):
+        # Remove stale annotated images from previous runs.
+        if file_name.endswith('_highlighted.jpg') or file_name.endswith('_bounds.jpg'):
+            file_path = os.path.join(root, file_name)
+            os.remove(file_path)
+
+    jpg_files = [f for f in os.listdir(root) if f.endswith('.jpg')]
+
+    # One screenshot per action; a trailing non-"done" trace has one extra frame.
+    if actions[-1]["type"] == "done":
+        if(len(jpg_files)!= len(actions)):
+            raise Exception(f"[Visual Prompt] {root} has {len(jpg_files)} images, but {len(actions)} actions with done")
+    else:
+        if(len(jpg_files)!= len(actions) + 1):
+            raise Exception(f"[Visual Prompt] {root} has {len(jpg_files)} images, but {len(actions)} actions without done")
+
+    for i, action in enumerate(actions):
+        img_path = os.path.join(root, f"{i + 1}.jpg")
+        save_path = os.path.join(root, f"{i + 1}_highlighted.jpg")
+
+        if not os.path.exists(img_path):
+            raise Exception(f"[Visual Prompt] Image not found: {img_path}")
+
+        img = Image.open(img_path)
+        draw = ImageDraw.Draw(img)
+        font = ImageFont.truetype("msyh.ttf", 40)
+        if action["type"] == "click":
+            text = f"CLICK [{int(action['position_x'])}, {int(action['position_y'])}]"
+        elif action["type"] == "input":
+            text = f"INPUT {action['text']}"
+        elif action["type"] == "swipe":
+            text = f"SWIPE [{int(action['press_position_x'])}, {int(action['press_position_y'])}] to [{int(action['release_position_x'])}, {int(action['release_position_y'])}]"
+        elif action["type"] == "done":
+            text = f"DONE"
+        elif action["type"] == "long_press":
+            text = f"LONG PRESS [{int(action['position_x'])}, {int(action['position_y'])}]"
+        elif action["type"] == "open_app":
+            text = f"OPEN APP {action['app_name']}"
+        else:
+            raise Exception(f"[Visual Prompt] Unknown action type: {action['type']}")
+
+        text_width, text_height = draw.textbbox((0, 0), text, font=font)[2:]
+        draw.text((img.width / 2 - text_width / 2, 0), text, fill="red", font=font)
+        img.save(save_path)
+
+        # Draw the element bounding box.
+        bounds_path = os.path.join(root, f"{i + 1}_bounds.jpg")
+        img_bounds = Image.open(save_path)
+        draw_bounds = ImageDraw.Draw(img_bounds)
+        if action["type"] == "click" or action["type"] == "long_press":
+            if "bounds" in action and action["bounds"]:
+                draw_bounds.rectangle(action["bounds"], outline='red', width=5)
+                img_bounds.save(bounds_path)
+        # Draw the click point / swipe arrow.
+        with open(save_path, 'rb') as f:
+            image_data = f.read()
+        nparr = np.frombuffer(image_data, np.uint8)
+        cv2image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+        if action["type"] == "click":
+            x = int(action['position_x'])
+            y = int(action['position_y'])
+            cv2.circle(cv2image, (x, y), 50, (0, 0, 255), 10)
+        elif action["type"] == "swipe":
+            x1 = int(action['press_position_x'])
+            y1 = int(action['press_position_y'])
+            x2 = int(action['release_position_x'])
+            y2 = int(action['release_position_y'])
+            cv2.arrowedLine(cv2image, (x1, y1), (x2, y2), (0, 0, 255), 5)
+        success, encoded_img = cv2.imencode('.jpg', cv2image)
+        if success:
+            with open(save_path, 'wb') as f:
+                f.write(encoded_img.tobytes())
+    print(f"[Visual Prompt] done")
+
+def 
auto_annotate(root, chain, task_description, actions):
+    """Send the *_highlighted.jpg sequence to the LLM, parse its JSON list of
+    per-step reasoning, validate it against *actions*, and save react.json."""
+    print(f"[Reasoning] root: \"{root}\" task: \"{task_description}\"")
+
+    files = os.listdir(root)
+    image_data = []
+    # Collect the annotated screenshots in step order, base64-encoded.
+    highlighted = [file for file in files if file.endswith("_highlighted.jpg")]
+    highlighted.sort(key=lambda f: int(f.replace("_highlighted.jpg", "")))
+    for file in highlighted:
+        img_path = os.path.join(root, file)
+        with open(img_path, "rb") as f:
+            image_data.append(base64.b64encode(f.read()).decode("utf-8"))
+
+    max_attempts = 3
+    for attempt in range(0, max_attempts):
+        response = chain.invoke({
+            "goal": task_description,
+            "screenshot_count": len(image_data),
+            "messages": [
+                (
+                    "user",
+                    [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image}"}} for image in image_data]
+                )
+            ]
+        })
+
+        # Greedy match (cf. the non-greedy version in change_task_description).
+        pattern = re.compile(r"```json\n(.*)\n```", re.DOTALL)
+        match = pattern.search(response.content)
+        if match is None:
+            print(f"[Reasoning] Attempt {attempt + 1} failed, no JSON found in response.")
+            continue
+
+        try:
+            json_str = match.group(1)
+            data = json.loads(json_str)
+            reasoning_count = len(data)
+            if(len(image_data) != reasoning_count):
+                raise Exception(f"[Invalid reasoning count]")
+            react_json = os.path.join(root, "temp.json")
+            with open(react_json, "w", encoding="UTF-8") as f:
+                json.dump(data, f, ensure_ascii=False, indent=4)
+
+            compare_actions(actions, data)
+
+        except Exception as e:
+            print(f"[Reasoning] Attempt {attempt + 1} failed, error: {str(e)}")
+            continue
+        break
+
+    # NOTE(review): if every attempt fails, `match` here is None (or stale from
+    # a failed parse) and match.group(1) raises AttributeError — consider
+    # raising an explicit error after the loop.  This block also re-parses and
+    # re-validates what the loop already checked (after writing temp.json).
+    json_str = match.group(1)
+    data = json.loads(json_str)
+    reasoning_count = len(data)
+    if(len(image_data) != reasoning_count):
+        raise Exception(f"[Invalid reasoning count]")
+
+    # # add an action_index field to each react.json entry
+    # for i, item in enumerate(data):
+    #     if isinstance(item, dict):
+    #         item['action_index'] = i + 1
+
+    compare_actions(actions, data)
+
+    react_json = os.path.join(root, "react.json")
+    with open(react_json, "w", encoding="UTF-8") as f:
+        json.dump(data, f, ensure_ascii=False, indent=4)
+
+    print(f"[Reasoning] finished, saved to {react_json}")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Auto annotation of GUI data')
+    parser.add_argument('--data_path', type=str, default='data', help='root directory containing the data (default: data)')
+    parser.add_argument('--model', type=str, required=True, help='name of the annotation model')
+    parser.add_argument('--api_key', type=str, required=True, help='API key of the annotation model')
+    parser.add_argument('--base_url', type=str, required=True, help='base URL of the annotation model')
+
+    args = parser.parse_args()
+
+    model = ChatOpenAI(
+        model=args.model,
+        api_key=args.api_key,
+        base_url=args.base_url,
+    )
+
+    from utils.load_md_prompt import load_prompt
+    sys_prompt = load_prompt("annotation_en_general.md")
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", sys_prompt),
+            MessagesPlaceholder(variable_name='messages')
+        ]
+    )
+
+    chain = prompt | model
+
+    for root, dirs, files in os.walk(args.data_path):
+        # Sort subdirectories numerically so traces are processed in order.
+        dirs.sort(key=lambda x: int(x) if x.isdigit() else float('inf'))
+
+        try:
+            actions_json = os.path.join(root, "actions.json")
+            if not os.path.exists(actions_json):
+                raise Exception("No actions.json")
+
+            # Skip traces already annotated or previously failed.
+            react_json = os.path.join(root, "react.json")
+            if os.path.exists(react_json):
+                continue
+            parse_error = os.path.join(root, "parse.error")
+            if os.path.exists(parse_error):
+                continue
+
+            with open(actions_json, 'r', encoding='utf-8') as file:
+                data = json.load(file)
+                if "task_description" not in data:
+                    raise Exception("No task_description in actions.json")
+                task_description = data.get("task_description")
+                actions = data.get("actions")
+
+            # Do not enable this lightly; OCR-based bounds are risky.
+            # actions = add_bounds_to_action(root, actions)
+            # data["actions"] = actions
+            # with open(actions_json, 'w', encoding='utf-8') as file:
+            #     json.dump(data, file, ensure_ascii=False, indent=4)
+
+            visual_prompt(root, actions)
+            auto_annotate(root, chain, 
task_description, actions) + + app_name = data.get("app_name") + if(isinstance(task_description, str)): + new_tasks = change_task_description(app_name, task_description) + all_tasks = [task_description] + new_tasks + data["task_description"] = all_tasks + + with open(actions_json, 'w', encoding='utf-8') as file: + json.dump(data, file, ensure_ascii=False, indent=4) + + print(f"[Increase Task] finished, saved to {actions_json}") + + + except Exception as e: + with open(f"{root}/parse.error", 'w', encoding='utf-8', errors='ignore') as file: + file.write(f"{str(e)}\n") + with open(f"{args.data_path}/list.error", 'a', encoding='utf-8', errors='ignore') as file: + file.write(f"root: \"{root}\" {str(e)}\n") + diff --git a/collect/augment_config.json b/collect/augment_config.json new file mode 100644 index 0000000..a0ad42d --- /dev/null +++ b/collect/augment_config.json @@ -0,0 +1,15 @@ +[ + { + "dir": [ + "function", + "name" + ], + "pattern": "swipe", + "multiplier": { + "reason": 5, + "reason_no_history": 5, + "grounder": 1, + "other": 1 + } + } +] \ No newline at end of file diff --git a/collect/auto/draw_bounds.py b/collect/auto/draw_bounds.py new file mode 100644 index 0000000..ace74d6 --- /dev/null +++ b/collect/auto/draw_bounds.py @@ -0,0 +1,109 @@ +from PIL import Image, ImageDraw, ImageFont +import os + +# from parse_xml import extract_all_bounds +from utils.parse_omni import extract_all_bounds + +def check_text_overlap(text_rect1, text_rect2): + """检查两个文本矩形是否重叠""" + x1, y1, x2, y2 = text_rect1 + x3, y3, x4, y4 = text_rect2 + + # 如果一个矩形在另一个的左侧、右侧、上方或下方,则不重叠 + if x2 < x3 or x4 < x1 or y2 < y3 or y4 < y1: + return False + return True + +def assign_bounds_to_layers(folder_path, screenshot_path, bounds_list): + """使用贪心算法将bounds分配到不同的图层,避免文本重叠""" + image = Image.open(screenshot_path) + draw = ImageDraw.Draw(image) + font = ImageFont.truetype("arial.ttf", 40) + + layers = [] # 每个元素是一个包含(index, bounds, text_rect)的列表 + + for index, bounds in enumerate(bounds_list): + 
left, top, right, bottom = bounds + + text = str(index) + bbox = draw.textbbox((0, 0), text, font=font) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + # text_x = right - text_width + text_x = left + text_y = top + + text_rect = (text_x, text_y, text_x + text_width + 5, text_y + text_height + 15) + + # 寻找可以容纳当前bounds的图层 + placed = False + for layer in layers: + can_place = True + for _, existing_bounds, existing_text_rect in layer: + # if check_text_overlap(text_rect, existing_text_rect): + if check_text_overlap(bounds, existing_bounds) or check_text_overlap(text_rect, existing_bounds) or check_text_overlap(bounds, existing_text_rect) or check_text_overlap(text_rect, existing_text_rect): + can_place = False + break + + if can_place: + layer.append((index, bounds, text_rect)) + placed = True + break + + # 如果没有找到合适的图层,创建新图层 + if not placed: + layers.append([(index, bounds, text_rect)]) + + for index, layer in enumerate(layers, 1): + output_path = os.path.join(folder_path, f"layer_{index}.jpg") + draw_bounds_on_screenshot(screenshot_path, layer, output_path) + + return len(layers) + +def draw_bounds_on_screenshot(screenshot_path, layer, output_path): + """在截图上绘制所有bounds并保存""" + try: + image = Image.open(screenshot_path) + draw = ImageDraw.Draw(image) + font = ImageFont.truetype("arial.ttf", 40) + + # 用红色绘制所有bounds并标记索引 + for index, bounds, text_rect in layer: + left, top, right, bottom = bounds + draw.rectangle([left, top, right, bottom], outline='red', width=5) + + text = str(index) + text_x, text_y, _, _ = text_rect + + draw.rectangle(text_rect, fill='red', outline='red', width=1) + draw.text((text_x, text_y), text, fill='white', font=font) + + image.save(output_path) + # print(f"已保存标注结果到: {output_path}") + return True + + except Exception as e: + print(f"绘制bounds时出错: {str(e)}") + return False + +def process_folder(folder_path, need_clickable=False): + """处理单个文件夹""" + hierarchy_path = os.path.join(folder_path, 'hierarchy.xml') + screenshot_path 
= os.path.join(folder_path, 'screenshot.jpg') + + try: + # 读取hierarchy.xml + # with open(hierarchy_path, 'r', encoding='utf-8') as f: + # hierarchy_xml = f.read() + + # 提取所有bounds + # bounds_list = extract_all_bounds(hierarchy_xml, need_clickable) + + bounds_list = extract_all_bounds(screenshot_path) + # print(f"在 {folder_path} 中找到 {len(bounds_list)} 个bounds") + + return assign_bounds_to_layers(folder_path, screenshot_path, bounds_list), bounds_list + + except Exception as e: + print(f"处理文件夹 {folder_path} 时出错: {str(e)}") + return 0 diff --git a/collect/auto/server.py b/collect/auto/server.py new file mode 100644 index 0000000..afb5982 --- /dev/null +++ b/collect/auto/server.py @@ -0,0 +1,496 @@ +import uiautomator2 as u2 +import time +import os +import shutil +import base64 +from PIL import Image +import io +import json +import re +import logging +import sys +import json +from datetime import datetime +from openai import OpenAI +import argparse + +from collect.auto.draw_bounds import process_folder + +device = None # 设备连接对象 +hierarchy = None # 层次结构数据 +data_index = 1 # 数据索引 + +operation_history = [] # 操作历史记录 +logger = None # 日志记录器 + +# 全局配置变量,将由命令行参数设置 +model = None +api_key = None +base_url = None +max_steps = 15 +client = None + +# action_dir 是存储的目录 +def get_current_hierarchy_and_screenshot(action_dir, sleep_time = 0): + global hierarchy + time.sleep(sleep_time) + + if os.path.exists(action_dir): + shutil.rmtree(action_dir) + os.makedirs(action_dir) + + # if not os.path.exists(action_dir): + # os.makedirs(action_dir) + + screenshot_path = os.path.join(action_dir, "screenshot.jpg") + hierarchy_path = os.path.join(action_dir, "hierarchy.xml") + + device.screenshot(screenshot_path) + hierarchy = device.dump_hierarchy() + with open(hierarchy_path, "w", encoding="utf-8") as f: + f.write(hierarchy) + + logger.info(f"操作完成,已重新截图和获取层次结构") + +# 将路径 img_path 截图保存为base64编码的字符串 +def get_screenshot(img_path, factor=0.4): + img = Image.open(img_path) + img = 
img.resize((int(img.width * factor), int(img.height * factor)), Image.Resampling.LANCZOS) + buffered = io.BytesIO() + img.save(buffered, format="JPEG") + screenshot = base64.b64encode(buffered.getvalue()).decode("utf-8") + return screenshot + +def handle_click(x, y): + """处理点击操作""" + device.click(x, y) + operation_record = { + "type": "click", + "timestamp": __import__('datetime').datetime.now().isoformat(), + "position": {"x": x, "y": y}, + # "clicked_elements": clicked_elements + } + operation_history.append(operation_record) + +def handle_input(text): + current_ime = device.current_ime() + device.shell(['settings', 'put', 'secure', 'default_input_method', 'com.android.adbkeyboard/.AdbIME']) + time.sleep(1) + charsb64 = base64.b64encode(text.encode('utf-8')).decode('utf-8') + device.shell(['am', 'broadcast', '-a', 'ADB_INPUT_B64', '--es', 'msg', charsb64]) + time.sleep(1) + device.shell(['settings', 'put', 'secure', 'default_input_method', current_ime]) + # time.sleep(1) + # device.shell(['am', 'broadcast', '-a', 'ADB_INPUT_METHOD_HIDE']) + + operation_record = { + "type": "input", + "timestamp": __import__('datetime').datetime.now().isoformat(), + "text": text, + } + operation_history.append(operation_record) + # get_current_hierarchy_and_screenshot(1.5) + +def handle_swipe(direction): + # device.swipe(action.startX, action.startY, action.endX, action.endY, duration=0.1) + device.swipe_ext(direction=direction, duration=0.1) + operation_record = { + "type": "swipe", + "timestamp": __import__('datetime').datetime.now().isoformat(), + "direction": direction + } + operation_history.append(operation_record) + # get_current_hierarchy_and_screenshot(1.5) + + + +from utils.load_md_prompt import load_prompt +app_selection_prompt_template = load_prompt("planner.md") +decider_prompt_template = load_prompt("auto_decider.md") + +def get_app_package_name(task_description): + """根据任务描述获取需要启动的app包名""" + app_selection_prompt = 
app_selection_prompt_template.format(task_description=task_description) + while True: + response_str = client.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": app_selection_prompt + } + ] + ).choices[0].message.content + + logger.info(f"应用选择响应: \n{response_str}") + + # 解析JSON响应 + pattern = re.compile(r"```json\n(.*)\n```", re.DOTALL) + match = pattern.search(response_str) + if match: + break + + response = json.loads(match.group(1)) + package_name = response.get("package_name") + reasoning = response.get("reasoning") + + logger.info(f"选择应用原因: {reasoning}") + logger.info(f"选择的包名: {package_name}") + + return package_name + +def do_task(task_description, data_dir): + global logger + + logger.info(f"开始执行任务: {task_description}") + + # 根据任务描述获取需要启动的应用包名 + package_name = get_app_package_name(task_description) + logger.info(f"选择启动应用: {package_name}") + + device.app_start(package_name, stop=True) + time.sleep(3) + action_history = [] + reasoning_history = [] + screenshots = [] + while True: + logger.info('=' * 50) + logger.info('=' * 50) + action_count = len(action_history) # 已有的操作数量 + action_index = action_count + 1 # 接下来的操作索引 + action_dir = os.path.join(data_dir, str(action_index)) + get_current_hierarchy_and_screenshot(action_dir) + + if(action_count > max_steps): + logger.info(f"任务步骤超过上限({max_steps}),停止执行") + return + + if action_count == 0: + history = "(No history)" + else: + # history = "\n".join(f"{idx}. {h}" for idx, h in enumerate(action_history, 1)) + history = "\n".join(f"{idx}. 
{h}" for idx, h in enumerate(reasoning_history, 1)) + + # 截图拉框 + # layer_count, bounds_list = process_folder(action_dir, need_clickable=True) + layer_count, bounds_list = process_folder(action_dir) + logger.info(f"已处理 {action_dir},共绘制 {layer_count} 个图层") + + # decider_prompt + decider_prompt = decider_prompt_template.format( + task_description = task_description, + history = history, + layer_count = layer_count + ) + message_content = [ + {"type": "text", "text": decider_prompt} + ] + + # 屏幕截图 + screenshot_path = os.path.join(action_dir, "screenshot.jpg") + screenshot = get_screenshot(screenshot_path, factor=1.0) + message_content.append({ + "type": "text", + "text": f"\n屏幕截图:" + }) + message_content.append({ + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{screenshot}"} + }) + + # 遍历所有标注图层 + for idx in range(1, layer_count + 1): + screenshot_path = os.path.join(action_dir, f"layer_{idx}.jpg") + screenshot = get_screenshot(screenshot_path) + message_content.append({ + "type": "text", + "text": f"\n第{idx}张标注图层:" + }) + message_content.append({ + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{screenshot}"} + }) + + decider_response_str = client.chat.completions.create( + model= model, + messages=[ + { + "role": "user", + "content": message_content + } + ] + ).choices[0].message.content + + logger.info(f"response: \n{decider_response_str}") + pattern = re.compile(r"```json\n(.*)\n```", re.DOTALL) + match = pattern.search(decider_response_str) + if not match: + logger.error("错误:未找到有效的JSON响应") + continue + decider_response = json.loads(match.group(1)) + + reasoning = decider_response.get("reasoning") + action = decider_response.get("action") + parameters = decider_response.get("parameters") + + if action == "done": + logger.info("任务完成!") + action = { + "reasoning": reasoning, + "function": { + "name": "done", + "parameters": {} + } + } + logger.info(f"完成操作: {action}") + action_history.append(action) + 
reasoning_history.append(reasoning) + break + + elif action == "click": + target_element = parameters.get("target_element") + index = parameters.get("index") + if index is None or index < 0 or index >= len(bounds_list): + logger.error(f"错误:index {index} 超出范围,有效范围为 0 到 {len(bounds_list)-1}") + continue + bounds = bounds_list[index] + # index, bounds = decide_click_element(data_dir, action_count + 1, task_description, reasoning, target_element) + logger.info(f"选择点击元素: {target_element} (index: {index}, bounds: {bounds})") + x = (bounds[0] + bounds[2]) / 2 + y = (bounds[1] + bounds[3]) / 2 + handle_click(x, y) + action = { + "reasoning": reasoning, + "function": { + "name": "click", + "parameters": { + "position_x": x, + "position_y": y, + "bounding_box": bounds, + "target_element": target_element, + } + } + } + logger.info(f"点击操作: {action}") + action_history.append(action) + reasoning_history.append(reasoning) + + elif action == "input": + text = parameters.get("text") + handle_input(text) + action = { + "reasoning": reasoning, + "function": { + "name": "input", + "parameters": { + "text": text + } + } + } + logger.info(f"输入操作: {action}") + action_history.append(action) + reasoning_history.append(reasoning) + + elif action == "swipe": + direction = parameters.get("direction").lower() + handle_swipe(direction) + action = { + "reasoning": reasoning, + "function": { + "name": "swipe", + "parameters": { + "direction": direction + } + } + } + logger.info(f"滑动操作: {action}") + action_history.append(action) + reasoning_history.append(reasoning) + else: + raise ValueError(f"Unknown action: {action}") + + time.sleep(2.5) + + data = { + "task_description": task_description, + "action_count": len(action_history), + "actions": action_history, + } + data_file = os.path.join(data_dir, "task_data.json") + with open(data_file, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=4) + + logger.info(f"任务执行完成,共执行 {len(action_history)} 个操作") + 
logger.info(f"任务数据已保存到: {data_file}") + logger.info("日志记录完成") + +def setup_logger(data_dir): + """设置日志记录器,同时输出到控制台和文件""" + global logger + + # 创建日志目录 + log_file = os.path.join(data_dir, "execution.log") + + # 创建logger,使用特定名称避免冲突 + logger_name = f'auto_collect_{id(data_dir)}' + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # 清除已有的处理器 + logger.handlers.clear() + + # 防止日志传播到根logger + logger.propagate = False + + # 创建文件处理器 + file_handler = logging.FileHandler(log_file, mode='w', encoding='utf-8') + file_handler.setLevel(logging.INFO) + + # 创建控制台处理器 + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + + # 创建格式器 + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + file_handler.setFormatter(formatter) + console_handler.setFormatter(formatter) + + # 添加处理器到logger + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger + +def change_auto_data(data_log_path, index): + parse_error = os.path.join(data_log_path, "parse.error") + if os.path.exists(parse_error): + return + task_data = os.path.join(data_log_path, "task_data.json") + if not os.path.exists(task_data): + return + + with open(task_data, 'r', encoding='utf-8') as file: + task_data = json.load(file) + + app_name = task_data.get("app_name") + task_type = None + task_description = task_data.get("task_description") + actions = task_data.get("actions") + + new_actions = [] + for action in actions: + action_type = action["function"]["name"].lower() + if action_type == "click": + new_action = { + "type": action_type, + "position_x": int(action["function"]["parameters"]["position_x"]), + "position_y": int(action["function"]["parameters"]["position_y"]), + "bounds": action["function"]["parameters"]["bounding_box"] + } + new_actions.append(new_action) + elif action_type == "swipe": + new_action = { + "type": action_type, + "press_position_x": None, + "press_position_y": None, + "release_position_x": 
None, + "release_position_y": None, + "direction": action["function"]["parameters"]["direction"] + } + new_actions.append(new_action) + elif action_type == "input": + new_action = { + "type": action_type, + "text": action["function"]["parameters"]["text"] + } + new_actions.append(new_action) + elif action_type == "done": + new_action = { + "type": "done" + } + new_actions.append(new_action) + else: + raise ValueError(f"Unknown action type: {action_type}") + + data = { + "app_name": app_name, + "task_type": task_type, + "task_description": task_description, + "action_count": len(new_actions), + "actions": new_actions + } + + dest_path_dir = os.path.join(os.path.dirname(__file__), 'data') + if not os.path.exists(dest_path_dir): + os.makedirs(dest_path_dir) + existing_dirs = [d for d in os.listdir(dest_path_dir) if os.path.isdir(os.path.join(dest_path_dir, d)) and d.isdigit()] + if existing_dirs: + max_index = max(int(d) for d in existing_dirs) + 1 + else: + max_index = 1 + dest_path = os.path.join(dest_path_dir, str(max_index)) + os.makedirs(dest_path) + + # 复制并重命名图片文件 + for index in range(1, len(new_actions) + 2): # +2 因为通常有一张额外的截图 + screenshot_src = os.path.join(data_log_path, str(index), "screenshot.jpg") + if os.path.exists(screenshot_src): + screenshot_dest = os.path.join(dest_path, f"{index}.jpg") + shutil.copy2(screenshot_src, screenshot_dest) + print(f"复制图片: {screenshot_src} -> {screenshot_dest}") + + with open(os.path.join(dest_path, "actions.json"), 'w', encoding='utf-8') as file: + json.dump(data, file, ensure_ascii=False, indent=4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Auto collection of GUI data') + parser.add_argument('--model', type=str, required=True, help='name of the LLM model') + parser.add_argument('--api_key', type=str, required=True, help='API key for the LLM model') + parser.add_argument('--base_url', type=str, required=True, help='base URL for the LLM model API') + parser.add_argument('--max_steps', 
type=int, default=15, help='maximum steps per task (default: 15)') + + args = parser.parse_args() + + # 设置全局配置 + model = args.model + api_key = args.api_key + base_url = args.base_url + max_steps = args.max_steps + + # 初始化OpenAI客户端 + client = OpenAI( + api_key=api_key, + base_url=base_url + ) + + device = u2.connect() + # 创建数据目录 + session_base_dir = os.path.dirname(__file__) + data_base_dir = os.path.join(session_base_dir, 'data_log') + if not os.path.exists(data_base_dir): + os.makedirs(data_base_dir) + + # 读取任务列表 + task_json_path = os.path.join(os.path.dirname(__file__), "task.json") + with open(task_json_path, "r", encoding="utf-8") as f: + task_list = json.load(f) + + for task_description in task_list: + # 遍历现有数据目录,找到最大的索引 + existing_dirs = [d for d in os.listdir(data_base_dir) if os.path.isdir(os.path.join(data_base_dir, d)) and d.isdigit()] + if existing_dirs: + data_index = max(int(d) for d in existing_dirs) + 1 + else: + data_index = 1 + data_log_dir = os.path.join(data_base_dir, str(data_index)) + os.makedirs(data_log_dir) + + # 设置日志记录器 + logger = setup_logger(data_log_dir) + logger.info("程序启动") + logger.info(f"数据索引: {data_index}") + logger.info(f"数据目录: {data_log_dir}") + + do_task(task_description, data_log_dir) + change_auto_data(data_log_dir, data_index) + diff --git a/collect/auto/task.json b/collect/auto/task.json new file mode 100644 index 0000000..e065509 --- /dev/null +++ b/collect/auto/task.json @@ -0,0 +1,3 @@ +[ + "在B站中,搜索“明日方舟”" +] \ No newline at end of file diff --git a/collect/construct_sft.py b/collect/construct_sft.py new file mode 100644 index 0000000..21b2add --- /dev/null +++ b/collect/construct_sft.py @@ -0,0 +1,708 @@ +import os, json +from skimage.metrics import structural_similarity as ssim +import cv2 +from dataclasses import dataclass, asdict +from typing import List +from PIL import Image +import random +import argparse +from tqdm import tqdm + +import re +from functools import reduce + +from utils.load_md_prompt import 
load_prompt + +def calculate_index_weight(index, total_length): + # 分段权重计算 + if index <= 5: + base_weight = 1 + elif index <= 8: + base_weight = 1 + index // 4 + else: + base_weight = 1 + index // 3 + return base_weight + +def load_augmentation_rules(config_path="augment_config.json"): + """读取数据扩充配置文件,返回规则列表""" + if not os.path.exists(config_path): + print(f"警告:配置文件 '{config_path}' 不存在,使用默认规则(无扩充)。") + return [] + try: + with open(config_path, 'r', encoding='utf-8') as f: + rules = json.load(f) + for rule in rules: + if not isinstance(rule.get("dir"), list): + raise ValueError(f"无效规则:{rule},dir 必须是列表") + if not isinstance(rule.get("pattern"), str): + raise ValueError(f"无效规则:{rule},pattern 必须是字符串") + if not isinstance(rule.get("multiplier"), dict): + raise ValueError(f"无效规则:{rule},multiplier 必须是字典") + rule["compiled_pattern"] = re.compile(rule["pattern"]) + return rules + except Exception as e: + print(f"读取配置文件失败:{e},使用默认规则(无扩充)。") + return [] + +def augment_data(action, rules): + # 检查每个规则 + for rule in rules: + try: + field_value = reduce(lambda d, k: d[k], rule["dir"], action) + except (KeyError, TypeError): + continue + if not isinstance(field_value, str): + continue + if rule["compiled_pattern"].search(field_value): + return rule["multiplier"] + return {"other": 1} + +@dataclass +class AlpacaImageEntry: + instruction: str + output: str + images: List[str] + input: str = "" + +executor_prompt = load_prompt("grounder_coordinates.md") +executor_prompt_bbox = load_prompt("grounder_bbox.md") + +decider_prompt = load_prompt("decider.md") +decider_prompt_no_history = load_prompt("decider_nohistory.md") + +main_page_classification_prompt = ''' + +Is this screenshot the main page of the current app? Your answer can only be "yes" or "no". +''' + +def construct_main_page_classification_ds(data_path, out_path, factor=0.5, train_ratio=0.9): + if not os.path.exists(out_path): + raise RuntimeError(f"Output path {out_path} does not exist. 
Make sure out_path is the same as construct_ds") + entries_train = [] + entries_val = [] + + main_pages = [] + other_pages = [] + for root, dirs, files in os.walk(data_path): + if len(files) == 0: + continue + if "react.json" not in files or "actions.json" not in files or "parse.error" in files: + continue + if "1.jpg" not in files: + continue + idx = 1 + while f"{idx}.jpg" in files: + idx += 1 + largest_idx = idx - 1 + for i in range(1, largest_idx + 1): + img_path = os.path.join(root, f"{i}.jpg") + relative_path = os.path.relpath(img_path, data_path) + safe_filename = relative_path.replace(os.sep, "_").replace(":", "_") + safe_filename = f"main_{safe_filename}" + out_relpath = os.path.join(out_path, safe_filename) + out_abspath = os.path.abspath(out_relpath) + if not os.path.exists(out_abspath): + raise RuntimeError(f"Image {out_abspath} does not exist. Make sure out_path is the same as construct_ds") + if i == 1: + main_pages.append(out_abspath) + else: + other_pages.append(out_abspath) + other_pages = random.sample(other_pages, len(other_pages) // 2) + for pages in [main_pages, other_pages]: + output = "yes" if pages is main_pages else "no" + entries = [] + for abspath in pages: + entry = AlpacaImageEntry( + instruction=main_page_classification_prompt, + output=output, + images=[abspath] + ) + entries.append(entry) + random.shuffle(entries) + split_idx = int(len(entries) * train_ratio) + entries_train.extend(entries[:split_idx]) + entries_val.extend(entries[split_idx:]) + + print(f"main_page_classification entries_train: {len(entries_train)}") + print(f"main_page_classification entries_val: {len(entries_val)}") + + with open(os.path.join(out_path, "main_page_train.json"), "w", encoding="utf-8") as f: + json.dump([asdict(entry) for entry in entries_train], f, ensure_ascii=False) + with open(os.path.join(out_path, "main_page_val.json"), "w", encoding="utf-8") as f: + json.dump([asdict(entry) for entry in entries_val], f, ensure_ascii=False) + +def 
construct_ss_data(single_step_data_path, out_path, factor=0.5, train_ratio=0.9): + if not os.path.exists(single_step_data_path): + return + + augment_config_path = os.path.join(os.path.dirname(__file__), 'augment_config.json') + rules = load_augmentation_rules(augment_config_path) + + # 初始化所有返回变量 + decider_ss_entry_train = [] + decider_ss_entry_val = [] + grounder_ss_entry_train = [] + grounder_ss_entry_val = [] + + decider_ss_path = os.path.join(single_step_data_path, "decider") + if os.path.exists(decider_ss_path): + for root, dirs, files in tqdm(os.walk(decider_ss_path), desc="constructing single step decider dataset"): + if len(files) == 0: + continue + if "react.json" not in files: + continue + if "tasks.json" not in files: + continue + + react_path = os.path.join(root, "react.json") + with open(react_path, "r", encoding="UTF-8") as f: + react_data = json.load(f) + + tasks_path = os.path.join(root, "tasks.json") + with open(tasks_path, "r", encoding="UTF-8") as f: + tasks = json.load(f) + + for i, react in enumerate(react_data, 1): + is_train = random.random() < train_ratio + + augment_rule = augment_data(react, rules) + + img_path = os.path.join(root, f"{i}.jpg") + pil_img = Image.open(img_path) + width, height = pil_img.size + new_width = int(width * factor) + new_height = int(height * factor) + resized_img = pil_img.resize((new_width, new_height), Image.LANCZOS) + + relative_path = os.path.relpath(img_path, single_step_data_path) + safe_filename = relative_path.replace(os.sep, "_").replace(":", "_") + safe_filename = f"ss_{safe_filename}" + out_relpath = os.path.join(out_path, safe_filename) + resized_img.save(out_relpath) + out_abspath = os.path.abspath(out_relpath) + + reasoning = react["reasoning"] + action = react["function"]["name"] + param = react["function"]["parameters"] + + random_tasks = random.sample(tasks, 1) + + for task in random_tasks: + instruction = decider_prompt_no_history.format(task=task) + output_dict = dict(reasoning=reasoning, 
action=action, parameters=param) + output = json.dumps(output_dict, ensure_ascii=False) + entry = AlpacaImageEntry( + instruction=instruction, + output=output, + images=[out_abspath] + ) + if is_train: + num = augment_rule.get("reason_no_history", augment_rule.get("other", 1)) + decider_ss_entry_train.extend([entry] * num) + else: + decider_ss_entry_val.append(entry) + + grounder_ss_path = os.path.join(single_step_data_path, "grounder") + if os.path.exists(grounder_ss_path): + for root, dirs, files in tqdm(os.walk(grounder_ss_path), desc="constructing single step grounder dataset"): + if len(files) == 0: + continue + if "react.json" not in files: + continue + + react_path = os.path.join(root, "react.json") + with open(react_path, "r", encoding="UTF-8") as f: + react_data = json.load(f) + + for i, react in enumerate(react_data, 1): + is_train = random.random() < train_ratio + + img_path = os.path.join(root, f"{i}.jpg") + pil_img = Image.open(img_path) + width, height = pil_img.size + new_width = int(width * factor) + new_height = int(height * factor) + resized_img = pil_img.resize((new_width, new_height), Image.LANCZOS) + + relative_path = os.path.relpath(img_path, single_step_data_path) + safe_filename = relative_path.replace(os.sep, "_").replace(":", "_") + safe_filename = f"ss_{safe_filename}" + out_relpath = os.path.join(out_path, safe_filename) + resized_img.save(out_relpath) + out_abspath = os.path.abspath(out_relpath) + + reasoning = react["reasoning"] + action = react["function"]["name"] + param = react["function"]["parameters"] + + # grounder训练集 + if action == "click": + bbox = react["bbox"] + x1, y1 ,x2 ,y2 = bbox + x = (x1 + x2) // 2 + y = (y1 + y2) // 2 + coords = [int(x * factor), int(y * factor)] + + instruction = executor_prompt.format(reasoning=reasoning, description=param["target_element"]) + output = json.dumps(dict(coordinates=coords)) + entry = AlpacaImageEntry( + instruction=instruction, + output=output, + images=[out_abspath] + ) + if is_train: 
+ grounder_ss_entry_train.extend([entry]) + else: + grounder_ss_entry_val.append(entry) + + bbox = [int(x * factor) for x in bbox] + output = json.dumps(dict(bbox=bbox)) + instruction = executor_prompt_bbox.format(reasoning=reasoning, description=param["target_element"]) + entry = AlpacaImageEntry( + instruction=instruction, + output=output, + images=[out_abspath] + ) + if is_train: + num = augment_rule.get("grounder", augment_rule.get("other", 1)) + grounder_ss_entry_train.extend([entry] * num) + else: + grounder_ss_entry_val.append(entry) + + return decider_ss_entry_train, decider_ss_entry_val, grounder_ss_entry_train, grounder_ss_entry_val + +def construct_ds(data_path, single_step_data_path, unexpected_img_path, out_path, factor=0.5, train_ratio=0.9): + os.makedirs(out_path, exist_ok=True) + + # 训练集 + reason_entries_train = [] + shift_entries_train = [] + terminate_entries_train = [] + reason_no_history_entries_train = [] + grounder_entries_train = [] + + # 验证集 + reason_entries_val = [] + shift_entries_val = [] + terminate_entries_val = [] + reason_no_history_entries_val = [] + grounder_entries_val = [] + + augment_config_path = os.path.join(os.path.dirname(__file__), 'augment_config.json') + rules = load_augmentation_rules(augment_config_path) + + #TODO: unexpected_img_path 不存在情况 + unexpected_img_dir = os.path.abspath(unexpected_img_path) + unexpected_img_paths = os.listdir(unexpected_img_dir) + unexpected_img_paths = [os.path.join(unexpected_img_dir, img) for img in unexpected_img_paths] + + unexpected_img_safe_abspaths = [] + for unexpected_img_path in unexpected_img_paths: + pil_img = Image.open(unexpected_img_path) + width, height = pil_img.size + new_width = int(width * factor) + new_height = int(height * factor) + resized_img = pil_img.resize((new_width, new_height), Image.LANCZOS) + + relative_path = os.path.relpath(unexpected_img_path, unexpected_img_dir) + safe_filename = relative_path.replace(os.sep, "_").replace(":", "_") + safe_filename = 
f"unexpected_{safe_filename}" + out_relpath = os.path.join(out_path, safe_filename) + resized_img.save(out_relpath) + out_abspath = os.path.abspath(out_relpath) + unexpected_img_safe_abspaths.append(out_abspath) + + for root, dirs, files in tqdm(os.walk(data_path), desc="constructing dataset"): + if len(files) == 0: + continue + if "actions.json" not in files or "react.json" not in files or "parse.error" in files: + continue + + actions_json = os.path.join(root, "actions.json") + with open(actions_json, 'r', encoding='utf-8') as file: + data = json.load(file) + task_description = data.get("task_description") + actions = data.get("actions") + react_json = os.path.join(root, "react.json") + with open(react_json, "r", encoding="UTF-8") as f: + react_data = json.load(f) + + # 多模式适配 将没有done的react补充done,目前全部修正带done + index = 1 + while f"{index}.jpg" in files: + index += 1 + num_img = index - 1 + if num_img == len(react_data) + 1: + done_reasoning = "我已经完成了目标任务,任务已结束。" + react_data.append( + { + "reasoning": done_reasoning, + "function": { + "name": "done", + "parameters": {} + } + } + ) + elif num_img != len(react_data): + print(f"Warning: Number of images ({num_img}) does not match number of ReAct entries ({len(react_data)}) in {root}. 
Skipping this directory.") + continue + + history = [] + for i, react in enumerate(react_data, 1): + is_train = random.random() < train_ratio + + augment_rule = augment_data(react, rules) + + # Resize image并保存在同一目录下 + img_path = os.path.join(root, f"{i}.jpg") + pil_img = Image.open(img_path) + width, height = pil_img.size + new_width = int(width * factor) + new_height = int(height * factor) + resized_img = pil_img.resize((new_width, new_height), Image.LANCZOS) + + relative_path = os.path.relpath(img_path, data_path) + safe_filename = relative_path.replace(os.sep, "_").replace(":", "_") + safe_filename = f"main_{safe_filename}" + out_relpath = os.path.join(out_path, safe_filename) + resized_img.save(out_relpath) + out_abspath = os.path.abspath(out_relpath) + + # 获取相关参数 + reasoning = react["reasoning"] + action_type = react["function"]["name"] + param = react["function"]["parameters"] + + output_dict = dict(reasoning=reasoning, action=action_type, parameters=param) + output = json.dumps(output_dict, ensure_ascii=False) + + # partial_histories是当前action的前几个action + # 对input类和done类型特殊处理 + if action_type == "input" or action_type == "done": + min_history_length = min(3, len(history)) + partial_histories = [history[i:] for i in range(len(history) + 1 - min_history_length)] + else: + partial_histories = [history[i:] for i in range(len(history) + 1)] + + partial_history_entries = [] + + for partial_history in partial_histories: + if len(partial_history) == 0: + partial_history_str = "(No history)" + else: + partial_history_str = "\n".join(f"{idx}. 
{h}" for idx, h in enumerate(partial_history, 1)) + + if(isinstance(task_description, list)): + weight = calculate_index_weight(i, len(actions)) + weight = min(weight, len(task_description)) + random_tasks = random.sample(task_description, weight) + for task in random_tasks: + instruction = decider_prompt.format(task=task, history=partial_history_str) + entry = AlpacaImageEntry( + instruction=instruction, + output=output, + images=[out_abspath] + ) + partial_history_entries.append(entry) + else: + instruction = decider_prompt.format(task=task_description, history=partial_history_str) + entry = AlpacaImageEntry( + instruction=instruction, + output=output, + images=[out_abspath] + ) + partial_history_entries.append(entry) + + history.append(output) + + shifted_history_entry = [] + terminate_history_entry = [] + + synthesize_terminate = action_type != "wait" and action_type != "done" and action_type != "swipe" + # synthesize terminate samples + if synthesize_terminate: + terminate_list1 = [ + "当前页面未按预期加载", + "进入了错误的页面", + "打开了不合预期的页面", + "当前打开了错误页面", + "当前页面不合预期" + ] + terminate_list2 = [ + "需要用户介入", + "需要用户接管", + "任务无法继续执行" + ] + terminate_list3 = [ + "任务提前结束", + "中止任务执行" + ] + + terminate_reasoning = ",".join(map(random.choice, [terminate_list1, terminate_list2, terminate_list3])) + terminate_output_dict = dict(reasoning=terminate_reasoning, action="done", parameters={}) + terminate_output = json.dumps(terminate_output_dict, ensure_ascii=False) + + history_str = "\n".join(f"{idx}. 
{h}" for idx, h in enumerate(history, 1)) + if(isinstance(task_description, list)): + weight = 1 + random_tasks = random.sample(task_description, weight) + for task in random_tasks: + instruction = decider_prompt.format(task=task, history=history_str) + unexpected_img_abspath = random.choice(unexpected_img_safe_abspaths) + entry = AlpacaImageEntry( + instruction=instruction, + output=terminate_output, + images=[unexpected_img_abspath] + ) + terminate_history_entry.append(entry) + else: + instruction = decider_prompt.format(task=task_description, history=history_str) + unexpected_img_abspath = random.choice(unexpected_img_safe_abspaths) + entry = AlpacaImageEntry( + instruction=instruction, + output=terminate_output, + images=[unexpected_img_abspath] + ) + terminate_history_entry.append(entry) + + synthesize_retry = synthesize_terminate + if synthesize_retry: + # i+1.jpg must exist since action_type is not done + cv2_img = cv2.imread(os.path.join(root, f"{i}.jpg"), cv2.IMREAD_GRAYSCALE) + next_cv2_img = cv2.imread(os.path.join(root, f"{i + 1}.jpg"), cv2.IMREAD_GRAYSCALE) + if cv2_img.shape != next_cv2_img.shape: + next_cv2_img = cv2.resize(next_cv2_img, (cv2_img.shape[1], cv2_img.shape[0])) + ssim_value = ssim(cv2_img, next_cv2_img) + synthesize_retry = ssim_value < 0.9 + + # synthesize retry samples + if synthesize_retry: + retry_list1 = [ + "应用未响应", + "上一个操作没有成功", + "操作未响应", + "上一动作未正常执行" + ] + retry_list2 = [ + "需要重新执行上一个动作", + "需要再执行一次上一个操作", + "我需要进行重试", + ] + + retry_reasoning = ",".join(map(random.choice, [retry_list1, retry_list2])) + retry_output_dict = dict(reasoning=retry_reasoning, action=action_type, parameters=param) + retry_output = json.dumps(retry_output_dict, ensure_ascii=False) + + history_str = "\n".join(f"{idx}. 
{h}" for idx, h in enumerate(history, 1)) + if(isinstance(task_description, list)): + weight = 1 + random_tasks = random.sample(task_description, weight) + for task in random_tasks: + instruction = decider_prompt.format(task=task, history=history_str) + entry = AlpacaImageEntry( + instruction=instruction, + output=retry_output, + images=[out_abspath] + ) + shifted_history_entry.append(entry) + else: + instruction = decider_prompt.format(task=task_description, history=history_str) + entry = AlpacaImageEntry( + instruction=instruction, + output=retry_output, + images=[out_abspath] + ) + shifted_history_entry.append(entry) + + # 有历史action训练集 + full_history_entry = partial_history_entries[0] + partial_history_entries = partial_history_entries[1:] + partial_history_entries = random.sample(partial_history_entries, min(2, len(partial_history_entries))) + + # 按比例分配到训练集和验证集(在增强前分配) + if is_train: + num = augment_rule.get("reason", augment_rule.get("other", 1)) + reason_entries_train.extend((partial_history_entries + [full_history_entry]) * num) + shift_entries_train.extend(shifted_history_entry * num) + terminate_entries_train.extend(terminate_history_entry * num) + else: + reason_entries_val.extend(partial_history_entries + [full_history_entry]) + shift_entries_val.extend(shifted_history_entry) + terminate_entries_val.extend(terminate_history_entry) + + # 无历史action训练集 (input类型不生成no history数据) + if action_type != "done" and action_type != "input": + no_history_entries = [] + if(isinstance(task_description, list)): + weight = calculate_index_weight(i, len(actions)) + weight = min(weight, len(task_description)) + random_tasks = random.sample(task_description, weight) + for task in random_tasks: + instruction = decider_prompt_no_history.format(task=task) + entry = AlpacaImageEntry( + instruction=instruction, + output=output, + images=[out_abspath] + ) + no_history_entries.append(entry) + else: + instruction = decider_prompt_no_history.format(task=task_description) + entry = 
AlpacaImageEntry( + instruction=instruction, + output=output, + images=[out_abspath] + ) + no_history_entries.append(entry) + + # 按比例分配到训练集和验证集(在增强前分配) + if is_train: + num = augment_rule.get("reason_no_history", augment_rule.get("other", 1)) + reason_no_history_entries_train.extend(no_history_entries * num) + else: + reason_no_history_entries_val.extend(no_history_entries) + + # grounder训练集 + if action_type == "click": + action = actions[i - 1] + coords = [int(action["position_x"]* factor), int(action["position_y"]* factor)] + bbox = action.get("bounds", None) + instruction = executor_prompt.format(reasoning=reasoning, description=param["target_element"]) + output = json.dumps(dict(coordinates=coords)) + entry = AlpacaImageEntry( + instruction=instruction, + output=output, + images=[out_abspath] + ) + + # 按比例分配到训练集和验证集(在增强前分配) + if is_train: + num = augment_rule.get("grounder", augment_rule.get("other", 1)) + grounder_entries_train.extend([entry] * num) + else: + grounder_entries_val.append(entry) + + if bbox: + bbox = [int(x * factor) for x in bbox] + output = json.dumps(dict(bbox=bbox)) + instruction = executor_prompt_bbox.format(reasoning=reasoning, description=param["target_element"]) + entry = AlpacaImageEntry( + instruction=instruction, + output=output, + images=[out_abspath] + ) + + # 按比例分配到训练集和验证集(在增强前分配) + if is_train: + num = augment_rule.get("grounder", augment_rule.get("other", 1)) + grounder_entries_train.extend([entry] * num) + else: + grounder_entries_val.append(entry) + + decider_ss_entry_train, decider_ss_entry_val, grounder_ss_entry_train, grounder_ss_entry_val = construct_ss_data(single_step_data_path, out_path, factor, train_ratio) + + # 合并训练集数据 + shift_entries_train = random.sample(shift_entries_train, len(shift_entries_train) // 8) + shift_entries_val = random.sample(shift_entries_val, len(shift_entries_val) // 8) + terminate_entries_train = random.sample(terminate_entries_train, len(terminate_entries_train) // 5) + terminate_entries_val = 
random.sample(terminate_entries_val, len(terminate_entries_val) // 5) + + print(f"reason_entries_train: {len(reason_entries_train)}") + print(f"reason_entries_no_history_train: {len(reason_no_history_entries_train)}") + print(f"shift_entries_train: {len(shift_entries_train)}") + print(f"terminate_entries_train: {len(terminate_entries_train)}") + print(f"grounder_entries_train: {len(grounder_entries_train)}") + print(f"decider_ss_entry_train: {len(decider_ss_entry_train)}") + print(f"grounder_ss_entry_train: {len(grounder_ss_entry_train)}") + print(f"\n") + + data = { + "reason_entries_train": len(reason_entries_train), + "reason_entries_no_history_train": len(reason_no_history_entries_train), + "shift_entries_train": len(shift_entries_train), + "terminate_entries_train": len(terminate_entries_train), + "grounder_entries_train": len(grounder_entries_train), + "decider_ss_entry_train": len(decider_ss_entry_train), + "grounder_ss_entry_train": len(grounder_ss_entry_train) + } + + decider_entries_train = [asdict(entry) for entry in reason_entries_train] + decider_entries_train.extend([asdict(entry) for entry in reason_no_history_entries_train]) + decider_entries_train.extend([asdict(entry) for entry in shift_entries_train]) + decider_entries_train.extend([asdict(entry) for entry in terminate_entries_train]) + decider_entries_train.extend([asdict(entry) for entry in decider_ss_entry_train]) + # random.shuffle(decider_entries_train) + + grounder_entries_train = [asdict(entry) for entry in grounder_entries_train] + grounder_entries_train.extend([asdict(entry) for entry in grounder_ss_entry_train]) + # random.shuffle(grounder_entries_train) + + # 合并验证集数据 + print(f"reason_entries_val: {len(reason_entries_val)}") + print(f"reason_entries_no_history_val: {len(reason_no_history_entries_val)}") + print(f"shift_entries_val: {len(shift_entries_val)}") + print(f"terminate_entries_val: {len(terminate_entries_val)}") + print(f"grounder_entries_val: {len(grounder_entries_val)}") + 
print(f"decider_ss_entry_val: {len(decider_ss_entry_val)}") + print(f"grounder_ss_entry_val: {len(grounder_ss_entry_val)}") + + # 添加验证集统计信息到data字典 + data.update({ + "reason_entries_val": len(reason_entries_val), + "reason_entries_no_history_val": len(reason_no_history_entries_val), + "shift_entries_val": len(shift_entries_val), + "terminate_entries_val": len(terminate_entries_val), + "grounder_entries_val": len(grounder_entries_val), + "decider_ss_entry_val": len(decider_ss_entry_val), + "grounder_ss_entry_val": len(grounder_ss_entry_val) + }) + + decider_entries_val = [asdict(entry) for entry in reason_entries_val] + decider_entries_val.extend([asdict(entry) for entry in reason_no_history_entries_val]) + decider_entries_val.extend([asdict(entry) for entry in shift_entries_val]) + decider_entries_val.extend([asdict(entry) for entry in terminate_entries_val]) + decider_entries_val.extend([asdict(entry) for entry in decider_ss_entry_val]) + # random.shuffle(decider_entries_val) + + grounder_entries_val_dict = [asdict(entry) for entry in grounder_entries_val] + grounder_entries_val_dict.extend([asdict(entry) for entry in grounder_ss_entry_val]) + # random.shuffle(grounder_entries_val_dict) + + # 保存训练集 + with open(os.path.join(out_path, f"general_decider_train.json"), "w", encoding="UTF-8") as f: + json.dump(decider_entries_train, f, ensure_ascii=False) + with open(os.path.join(out_path, f"general_grounder_train.json"), "w", encoding="UTF-8") as f: + json.dump(grounder_entries_train, f, ensure_ascii=False) + + # 保存验证集 + with open(os.path.join(out_path, f"general_decider_val.json"), "w", encoding="UTF-8") as f: + json.dump(decider_entries_val, f, ensure_ascii=False) + with open(os.path.join(out_path, f"general_grounder_val.json"), "w", encoding="UTF-8") as f: + json.dump(grounder_entries_val_dict, f, ensure_ascii=False) + + with open(os.path.join(out_path, f"metadata.json"), "w", encoding="UTF-8") as f: + json.dump(data, f, ensure_ascii=False) + + +if __name__ == 
"__main__": + parser = argparse.ArgumentParser(description="Training dataset construction with Alpaca format") + parser.add_argument("--data_path", type=str, default="data", help="root path of raw data (default: data)") + parser.add_argument("--ss_data_path", type=str, default="ss_data", help="root path of single-step data (default: ss_data)") + parser.add_argument("--unexpected_img_path", type=str, default="unexpected_img", help="root path of unexpected image data (default: unexpected_data)") + parser.add_argument("--out_path", type=str, default="output", help="output path of train dataset (default: output)") + parser.add_argument("--factor", type=float, default=0.5, help="resize factor for images (default: 0.5)") + parser.add_argument("--train_ratio", type=float, default=0.9, help="ratio of training data (default: 0.9)") + args = parser.parse_args() + construct_ds( + data_path=args.data_path, + single_step_data_path=args.ss_data_path, + unexpected_img_path=args.unexpected_img_path, + out_path=args.out_path, + factor=args.factor, + train_ratio=args.train_ratio, + ) + construct_main_page_classification_ds( + data_path=args.data_path, + out_path=args.out_path, + factor=args.factor, + train_ratio=args.train_ratio + ) \ No newline at end of file diff --git a/collect/manual/server.py b/collect/manual/server.py new file mode 100644 index 0000000..1b0d334 --- /dev/null +++ b/collect/manual/server.py @@ -0,0 +1,406 @@ +from fastapi import FastAPI, HTTPException +from fastapi.responses import HTMLResponse +from fastapi.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +import os +import time +import json +import base64 +import shutil +import uvicorn +import uiautomator2 as u2 +import sys +import os + +from utils.parse_xml import find_clicked_element + +# 数据模型 +class ClickAction(BaseModel): + x: int + y: int + +class SwipeAction(BaseModel): + startX: int + startY: int + endX: int + endY: int + direction: str # 
'up', 'down', 'left', 'right' + +class InputAction(BaseModel): + text: str + +class TaskDescription(BaseModel): + description: str + app_name: str + task_type: str + +screenshot_path = "screenshot.jpg" + +currentDataIndex = 0 +action_history = [] +current_task_description = "" # 当前任务描述 +current_app_name = "" # 当前应用名称 +current_task_type = "" # 当前任务类型 + +device = None # 设备连接对象 +hierarchy = None # 层次结构数据 + +app = FastAPI() + +# 添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 在生产环境中应该设置具体的域名 + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 挂载静态文件服务 +static_dir = os.path.join(os.path.dirname(__file__), "static") +app.mount("/static", StaticFiles(directory=static_dir), name="static") + +def save_screenshot(): + action_count = len(action_history) + + # 创建数据目录 + session_base_dir = os.path.dirname(__file__) + data_base_dir = os.path.join(session_base_dir, 'data') + app_dir = os.path.join(data_base_dir, current_app_name) + task_type_dir = os.path.join(app_dir, current_task_type) + data_dir = os.path.join(task_type_dir, str(currentDataIndex)) + + # 复制当前截图到数据目录 + if os.path.exists(screenshot_path): + screenshot_save_path = os.path.join(data_dir, f'{action_count + 1}.jpg') + shutil.copy2(screenshot_path, screenshot_save_path) + +def get_current_hierarchy_and_screenshot(sleep_time = 0): + global hierarchy + time.sleep(sleep_time) + hierarchy = device.dump_hierarchy() + + # with open("hierarchy.xml", "w", encoding="utf-8") as f: + # f.write(hierarchy) + + device.screenshot(screenshot_path) + print(f"操作完成,已重新截图和获取层次结构。总操作数: {len(action_history)}") + +@app.get("/", response_class=HTMLResponse) +async def read_root(): + """返回前端页面""" + html_path = os.path.join(os.path.dirname(__file__), "static", "index.html") + with open(html_path, "r", encoding="utf-8") as f: + html_content = f.read() + return HTMLResponse(content=html_content) + +@app.get("/screenshot") +async def get_screenshot(): + """获取最新截图文件和层次结构信息""" + try: + 
get_current_hierarchy_and_screenshot() + with open(screenshot_path, "rb") as image_file: + image_data = base64.b64encode(image_file.read()).decode('utf-8') + + return { + "status": "success", + "image_data": f"data:image/jpeg;base64,{image_data}", + "hierarchy": hierarchy, + "timestamp": int(time.time() * 1000) + } + + except Exception as e: + raise HTTPException(status_code=500, detail=f"获取截图失败: {str(e)}") + +@app.post("/click") +async def handle_click(action: ClickAction): + """处理点击操作""" + try: + # 确保坐标为整数(舍入) + x = round(action.x) + y = round(action.y) + + element_bounds = find_clicked_element(hierarchy, x, y) + if element_bounds: + element_bounds = [round(coord) for coord in element_bounds] + + get_current_hierarchy_and_screenshot() + save_screenshot() + device.click(x, y) + action_record = { + "type": "click", + "position_x": x, + "position_y": y, + "bounds": element_bounds, + } + print(action_record) + action_history.append(action_record) + # get_current_hierarchy_and_screenshot(1.5) + + return { + "status": "success", + "message": f"点击操作完成: ({x}, {y})", + "action": "click", + "coordinates": {"x": x, "y": y}, + "clicked_bounds": element_bounds, + "action_count": len(action_history) + } + + except Exception as e: + print(f"点击操作失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"点击操作失败: {str(e)}") + +@app.post("/swipe") +async def handle_swipe(action: SwipeAction): + """处理滑动操作""" + try: + # 确保坐标为整数(舍入) + startX = round(action.startX) + startY = round(action.startY) + endX = round(action.endX) + endY = round(action.endY) + + get_current_hierarchy_and_screenshot() + save_screenshot() + device.swipe(startX, startY, endX, endY, duration=0.1) + action_record = { + "type": "swipe", + "press_position_x": startX, + "press_position_y": startY, + "release_position_x": endX, + "release_position_y": endY, + "direction": action.direction, + } + print(action_record) + action_history.append(action_record) + # get_current_hierarchy_and_screenshot(1.5) + + return { + 
"status": "success", + "message": f"滑动操作完成: ({startX}, {startY}) → ({endX}, {endY}) [{action.direction}]", + "action": "swipe", + "start": {"x": startX, "y": startY}, + "end": {"x": endX, "y": endY}, + "direction": action.direction, + "action_count": len(action_history) + } + + except Exception as e: + print(f"滑动操作失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"滑动操作失败: {str(e)}") + +@app.post("/input") +async def handle_input(action: InputAction): + try: + get_current_hierarchy_and_screenshot() + save_screenshot() + current_ime = device.current_ime() + device.shell(['settings', 'put', 'secure', 'default_input_method', 'com.android.adbkeyboard/.AdbIME']) + time.sleep(0.5) + charsb64 = base64.b64encode(action.text.encode('utf-8')).decode('utf-8') + device.shell(['am', 'broadcast', '-a', 'ADB_INPUT_B64', '--es', 'msg', charsb64]) + time.sleep(0.5) + device.shell(['settings', 'put', 'secure', 'default_input_method', current_ime]) + action_record = { + "type": "input", + "text": action.text, + } + print(action_record) + action_history.append(action_record) + # get_current_hierarchy_and_screenshot(1.5) + + return { + "status": "success", + "message": f"输入操作完成", + "action": "input", + "text": action.text, + "action_count": len(action_history) + } + + except Exception as e: + print(f"输入操作失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"输入操作失败: {str(e)}") + +@app.get("/action_history") +async def get_action_history(): + """获取操作历史记录""" + return { + "status": "success", + "total_actions": len(action_history), + "actions": action_history + } + +@app.post("/save_data") +async def save_current_data(): + """保存当前数据并清空历史记录""" + global currentDataIndex + global action_history + + try: + get_current_hierarchy_and_screenshot() + save_screenshot() + action_record = { + "type": "done" + } + action_history.append(action_record) + action_count = len(action_history) + + app_dir = os.path.join(os.path.dirname(__file__), 'data', current_app_name) + task_type_dir 
= os.path.join(app_dir, current_task_type) + data_dir = os.path.join(task_type_dir, str(currentDataIndex)) + json_file_path = os.path.join(data_dir, 'actions.json') + + save_data = { + "app_name": current_app_name, + "task_type": current_task_type, + "task_description": current_task_description, + "action_count": action_count, + "actions": action_history + } + with open(json_file_path, 'w', encoding='utf-8') as f: + json.dump(save_data, f, ensure_ascii=False, indent=4) + + action_history.clear() + + # [Info] + print(f"第 {currentDataIndex} 条数据已保存") + print(f"应用:{current_app_name} | 任务类型:{current_task_type}") + print(f"包含 {action_count} 个操作记录") + print("操作历史记录已清空") + + return { + "status": "success", + "message": f"第 {currentDataIndex} 条数据已保存", + "data_index": currentDataIndex, + "saved_actions": action_count + } + except Exception as e: + print(f"保存数据失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"保存数据失败: {str(e)}") + +@app.post("/delete_data") +async def delete_current_data(): + """保存当前数据并清空历史记录""" + global currentDataIndex + + try: + app_dir = os.path.join(os.path.dirname(__file__), 'data', current_app_name) + task_type_dir = os.path.join(app_dir, current_task_type) + data_dir = os.path.join(task_type_dir, str(currentDataIndex)) + + # 删除数据目录 + if os.path.exists(data_dir): + shutil.rmtree(data_dir) + + action_history.clear() + + return { + "status": "success", + "message": f"第 {currentDataIndex} 条数据已删除", + "data_index": currentDataIndex + } + except Exception as e: + print(f"保存数据失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"保存数据失败: {str(e)}") + + +app_packages ={ + "微信": "com.tencent.mm", + "QQ": "com.tencent.mobileqq", + "微博": "com.sina.weibo", + + "饿了么": "me.ele", + "美团": "com.sankuai.meituan", + + "bilibili": "tv.danmaku.bili", + "爱奇艺": "com.qiyi.video", + "腾讯视频": "com.tencent.qqlive", + "优酷": "com.youku.phone", + + "淘宝": "com.taobao.taobao", + "京东": "com.jingdong.app.mall", + + "携程": "ctrip.android.view", + "同城": 
"com.tongcheng.android", + "飞猪": "com.taobao.trip", + "去哪儿": "com.Qunar", + "华住会": "com.htinns", + + "知乎": "com.zhihu.android", + "小红书": "com.xingin.xhs", + + "QQ音乐": "com.tencent.qqmusic", + "网易云音乐": "com.netease.cloudmusic", + "酷狗音乐": "com.kugou.android", + + "高德地图": "com.autonavi.minimap" +} + +@app.post("/set_task_description") +async def set_task_description(task: TaskDescription): + """设置任务描述""" + global currentDataIndex + global current_task_description + global current_app_name + global current_task_type + try: + current_app_name = task.app_name + current_task_type = task.task_type + current_task_description = task.description + + # 创建新的目录结构:data/<应用名称>/<任务类型>/<数据索引>/ + session_base_dir = os.path.dirname(__file__) + if not os.path.exists(session_base_dir): + os.makedirs(session_base_dir) + + data_base_dir = os.path.join(session_base_dir, 'data') + if not os.path.exists(data_base_dir): + os.makedirs(data_base_dir) + + app_dir = os.path.join(data_base_dir, current_app_name) + if not os.path.exists(app_dir): + os.makedirs(app_dir) + + task_type_dir = os.path.join(app_dir, current_task_type) + if not os.path.exists(task_type_dir): + os.makedirs(task_type_dir) + + # 遍历现有数据目录,找到最大的索引 + existing_dirs = [d for d in os.listdir(task_type_dir) if os.path.isdir(os.path.join(task_type_dir, d)) and d.isdigit()] + if existing_dirs: + currentDataIndex = max(int(d) for d in existing_dirs) + 1 + else: + currentDataIndex = 1 + data_dir = os.path.join(task_type_dir, str(currentDataIndex)) + os.makedirs(data_dir) + + print(f"\n{'='*50}") + print(f"📋 新任务开始") + print(f"应用名称: {current_app_name}") + print(f"任务类型: {current_task_type}") + print(f"任务描述: {current_task_description}") + print(f"数据目录: data/{current_app_name}/{current_task_type}/{currentDataIndex}/") + print(f"{'='*50}\n") + + package_name = app_packages.get(current_app_name) + if not package_name: + raise ValueError(f"App '{app}' is not registered with a package name.") + device.app_start(package_name, stop=True) + + 
return { + "status": "success", + "message": "任务描述已设置", + "description": current_task_description, + "app_name": current_app_name, + "task_type": current_task_type + } + except Exception as e: + print(f"设置任务描述失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"设置任务描述失败: {str(e)}") + +if __name__ == "__main__": + device = u2.connect() + print("启动服务器...") + print("访问 http://localhost:9000 查看前端页面") + uvicorn.run(app, host="0.0.0.0", port=9000) \ No newline at end of file diff --git a/collect/manual/static/css/style.css b/collect/manual/static/css/style.css new file mode 100644 index 0000000..6dafd51 --- /dev/null +++ b/collect/manual/static/css/style.css @@ -0,0 +1,622 @@ +/* 基础样式 */ +body { + font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; + margin: 0; + padding: 20px; + background-color: #f5f5f5; + min-height: 100vh; +} + +.container { + background: white; + border-radius: 10px; + box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); + display: flex; + min-height: calc(100vh - 40px); + overflow: hidden; +} + +/* 左侧截图区域 */ +.screenshot-section { + flex: 1; + padding: 20px; + background: #fafafa; + border-right: 1px solid #eee; + display: flex; + flex-direction: column; + min-width: 400px; +} + +/* 右侧控制区域 */ +.control-section { + flex: 0 0 450px; + padding: 30px; + display: flex; + flex-direction: column; + text-align: center; +} + +h1 { + color: #333; + margin-bottom: 30px; + text-align: center; +} + +/* 数据收集控制按钮 */ +.data-collection-controls { + margin: 30px 0; + display: flex; + gap: 20px; + justify-content: center; + flex-wrap: wrap; +} + +.start-btn { + background: linear-gradient(135deg, #27ae60 0%, #2ecc71 100%); +} + +.start-btn:hover:not(:disabled) { + box-shadow: 0 5px 15px rgba(39, 174, 96, 0.4); +} + +.end-btn { + background: linear-gradient(135deg, #e74c3c 0%, #c0392b 100%); +} + +.end-btn:hover:not(:disabled) { + box-shadow: 0 5px 15px rgba(231, 76, 60, 0.4); +} + +.next-btn { + background: linear-gradient(135deg, #3498db 0%, #2980b9 100%); 
+} + +.next-btn:hover:not(:disabled) { + box-shadow: 0 5px 15px rgba(52, 152, 219, 0.4); +} + +.delete-btn { + background: linear-gradient(135deg, #f39c12 0%, #e67e22 100%); +} + +.delete-btn:hover:not(:disabled) { + box-shadow: 0 5px 15px rgba(243, 156, 18, 0.4); +} + +.collection-status { + background: #f8f9fa; + border: 1px solid #dee2e6; + border-radius: 8px; + padding: 15px; + margin: 20px 0; + text-align: center; + font-weight: 500; +} + +.collection-status.collecting { + background: linear-gradient(135deg, #d4edda 0%, #c3e6cb 100%); + border-color: #28a745; + color: #155724; +} + +/* 按钮样式 */ +.screenshot-btn { + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + color: white; + border: none; + padding: 15px 30px; + font-size: 18px; + border-radius: 25px; + cursor: pointer; + transition: transform 0.2s, box-shadow 0.2s; + margin-bottom: 30px; +} + +.screenshot-btn:hover { + transform: translateY(-2px); + box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4); +} + +.screenshot-btn:disabled { + background: #ccc; + cursor: not-allowed; + transform: none; + box-shadow: none; +} + +.control-buttons { + margin: 20px 0; + display: flex; + gap: 15px; + justify-content: center; +} + +.control-btn { + background: linear-gradient(135deg, #27ae60 0%, #2ecc71 100%); + color: white; + border: none; + padding: 10px 20px; + font-size: 14px; + border-radius: 20px; + cursor: pointer; + transition: transform 0.2s, box-shadow 0.2s; +} + +.control-btn:hover { + transform: translateY(-1px); + box-shadow: 0 3px 10px rgba(39, 174, 96, 0.4); +} + +.control-btn:disabled { + background: #95a5a6; + cursor: not-allowed; + transform: none; + box-shadow: none; +} + +/* 截图容器样式 */ +.screenshot-container { + flex: 1; + border: 2px dashed #ddd; + border-radius: 10px; + padding: 20px; + display: flex; + align-items: center; + justify-content: center; + position: relative; + background: white; + margin: 0; +} + +.screenshot-img { + max-width: 100%; + max-height: calc(100vh - 120px); + 
width: auto; + height: auto; + border-radius: 10px; + box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2); + cursor: crosshair; + user-select: none; +} + +/* 状态消息样式 */ +.loading { + color: #666; + font-style: italic; +} + +.error { + color: #e74c3c; + font-weight: bold; +} + +.success { + color: #27ae60; + font-weight: bold; +} + +/* 删除旧的滚动指示器样式 */ +.scroll-indicator { + display: none; +} + +/* 元素信息显示区域 */ +.element-info { + background: #f8f9fa; + border: 1px solid #dee2e6; + border-radius: 8px; + padding: 15px; + margin-top: 15px; + max-height: 300px; + overflow-y: auto; +} + +.element-info h3 { + margin: 0 0 10px 0; + color: #495057; + font-size: 16px; +} + +.element-info .element-item { + background: white; + border: 1px solid #e9ecef; + border-radius: 4px; + padding: 10px; + margin-bottom: 8px; +} + +.element-info .element-item:last-child { + margin-bottom: 0; +} + +.element-info .element-property { + margin: 3px 0; + font-size: 12px; +} + +.element-info .element-property strong { + color: #495057; +} + +.element-info .element-bounds { + color: #6c757d; + font-family: monospace; +} + +.element-info .element-text { + color: #28a745; + font-weight: 500; +} + +/* 响应式设计 */ +@media (max-width: 1024px) { + .container { + flex-direction: column; + min-height: auto; + } + + .screenshot-section { + min-width: auto; + border-right: none; + border-bottom: 1px solid #eee; + } + + .control-section { + flex: none; + } + + .screenshot-container { + min-height: 300px; + } + + .screenshot-img { + max-height: 400px; + } +} + +@media (max-width: 768px) { + body { + padding: 10px; + } + + .screenshot-section, + .control-section { + padding: 15px; + } + + .data-collection-controls { + flex-direction: column; + gap: 10px; + } + + .control-btn { + width: 100%; + margin: 5px 0; + } +} + +/* 操作提示 */ +.action-hint { + margin-top: 15px; + padding: 10px; + background-color: #f8f9fa; + border-left: 4px solid #3498db; + border-radius: 4px; + font-size: 14px; + color: #666; + text-align: left; +} + 
+.action-hint ul { + margin: 5px 0; + padding-left: 20px; +} + +.action-hint li { + margin: 3px 0; +} + +/* 交互状态指示 */ +.screenshot-img.interacting { + opacity: 0.8; + cursor: wait; +} + +/* 加载指示器 */ +.loading-spinner { + display: inline-block; + width: 20px; + height: 20px; + border: 2px solid #f3f3f3; + border-top: 2px solid #3498db; + border-radius: 50%; + animation: spin 1s linear infinite; + margin-right: 10px; +} + +@keyframes spin { + 0% { + transform: rotate(0deg); + } + + 100% { + transform: rotate(360deg); + } +} + +/* 历史记录弹窗样式 */ +.history-modal { + position: fixed; + top: 0; + left: 0; + width: 100%; + height: 100%; + background: rgba(0, 0, 0, 0.5); + display: flex; + align-items: center; + justify-content: center; + z-index: 1000; +} + +.history-content { + background: white; + border-radius: 10px; + padding: 20px; + max-width: 80%; + max-height: 80%; + overflow-y: auto; + box-shadow: 0 10px 30px rgba(0, 0, 0, 0.3); +} + +.history-header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 20px; + border-bottom: 2px solid #eee; + padding-bottom: 10px; +} + +.close-btn { + background: #e74c3c; + color: white; + border: none; + border-radius: 50%; + width: 30px; + height: 30px; + cursor: pointer; + font-size: 16px; +} + +.action-item { + padding: 10px; + margin: 10px 0; + border: 1px solid #ddd; + border-radius: 5px; + background: #f9f9f9; +} + +.action-timestamp { + color: #666; + font-size: 12px; + margin-bottom: 5px; +} + +.action-details { + font-weight: bold; + color: #333; +} + +/* 任务描述弹窗样式 */ +.task-modal { + position: fixed; + top: 0; + left: 0; + width: 100%; + height: 100%; + background: rgba(0, 0, 0, 0.5); + display: flex; + align-items: center; + justify-content: center; + z-index: 2000; +} + +.task-modal-content { + background: white; + border-radius: 10px; + padding: 0; + width: 90%; + max-width: 500px; + box-shadow: 0 10px 30px rgba(0, 0, 0, 0.3); + animation: taskModalFadeIn 0.3s ease-out; +} + 
+.task-modal-header { + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + color: white; + padding: 20px; + border-radius: 10px 10px 0 0; + text-align: center; + position: relative; + display: flex; + justify-content: space-between; + align-items: center; +} + +.task-modal-header h3 { + margin: 0; + font-size: 18px; + flex: 1; +} + +.close-btn { + background: none; + border: none; + color: white; + font-size: 18px; + cursor: pointer; + padding: 5px; + border-radius: 50%; + width: 30px; + height: 30px; + display: flex; + align-items: center; + justify-content: center; + transition: background-color 0.2s; +} + +.close-btn:hover { + background-color: rgba(255, 255, 255, 0.2); +} + +.task-modal-body { + padding: 20px; +} + +.task-modal-body label { + display: block; + margin-bottom: 10px; + font-weight: 500; + color: #333; +} + +.task-modal-body input, +.task-modal-body select { + width: 100%; + padding: 12px; + border: 2px solid #ddd; + border-radius: 8px; + font-size: 14px; + font-family: inherit; + margin-bottom: 15px; + box-sizing: border-box; + background-color: white; + cursor: pointer; +} + +.task-modal-body select { + appearance: none; + background-image: url("data:image/svg+xml;charset=UTF-8,%3csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='none' stroke='currentColor' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3e%3cpolyline points='6,9 12,15 18,9'%3e%3c/polyline%3e%3c/svg%3e"); + background-repeat: no-repeat; + background-position: right 12px center; + background-size: 16px; + padding-right: 40px; +} + +.task-modal-body input { + cursor: text; +} + +.task-modal-body input:focus, +.task-modal-body select:focus { + outline: none; + border-color: #667eea; + box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1); +} + +.task-modal-body textarea { + width: 100%; + padding: 12px; + border: 2px solid #ddd; + border-radius: 8px; + font-size: 14px; + font-family: inherit; + resize: vertical; + min-height: 100px; + 
box-sizing: border-box; +} + +.task-modal-body textarea:focus { + outline: none; + border-color: #667eea; + box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1); +} + +.task-hint { + margin-top: 10px; + padding: 8px 12px; + background: #f8f9fa; + border-left: 3px solid #667eea; + border-radius: 4px; + font-size: 12px; + color: #6c757d; + line-height: 1.4; +} + +.task-modal-footer { + padding: 20px; + display: flex; + justify-content: center; + gap: 10px; + border-top: 1px solid #eee; +} + +.task-btn { + padding: 10px 20px; + border: none; + border-radius: 6px; + font-size: 14px; + font-weight: 500; + cursor: pointer; + transition: all 0.2s ease; +} + +.cancel-btn { + background: #f8f9fa; + color: #6c757d; + border: 1px solid #dee2e6; +} + +.cancel-btn:hover { + background: #e9ecef; + border-color: #adb5bd; +} + +.confirm-btn { + background: linear-gradient(135deg, #27ae60 0%, #2ecc71 100%); + color: white; +} + +.confirm-btn:hover { + transform: translateY(-1px); + box-shadow: 0 4px 12px rgba(39, 174, 96, 0.3); +} + +.confirm-btn:disabled { + background: #95a5a6; + cursor: not-allowed; + transform: none; + box-shadow: none; +} + +@keyframes taskModalFadeIn { + from { + opacity: 0; + transform: scale(0.9) translateY(-20px); + } + + to { + opacity: 1; + transform: scale(1) translateY(0); + } +} + +/* 自动刷新状态显示样式 */ +.auto-refresh-status { + background: #d4edda; + border: 1px solid #c3e6cb; + border-radius: 8px; + padding: 10px; + margin: 15px 0; + text-align: center; + color: #155724; +} + +.refresh-indicator { + font-size: 14px; + font-weight: 500; + display: flex; + align-items: center; + justify-content: center; + gap: 8px; +} \ No newline at end of file diff --git a/collect/manual/static/index.html b/collect/manual/static/index.html new file mode 100644 index 0000000..f6288bf --- /dev/null +++ b/collect/manual/static/index.html @@ -0,0 +1,173 @@ + + + + + + + 安卓设备数据收集工具 + + + + +
+ +
+
+
请先开始数据收集
+
+
+ + +
+

📱 安卓设备数据收集工具

+ +
+ + + + +
+ +
+
点击"开始收集"按钮开始数据收集
+
+ +
+ + + +
+ + + + +
+ + + + + + +
+
+ + + + + + + + + + + + + + \ No newline at end of file diff --git a/collect/manual/static/js/script.js b/collect/manual/static/js/script.js new file mode 100644 index 0000000..2348c04 --- /dev/null +++ b/collect/manual/static/js/script.js @@ -0,0 +1,1288 @@ +// 全局变量 +let screenshotImg = null; +let isInteracting = false; +let isDragging = false; +let dragStartX = 0; +let dragStartY = 0; +let dragStartTime = 0; +let isCollecting = false; +let currentTaskDescription = ''; // 当前任务描述 +let currentAppName = ''; // 当前应用名称 +let currentTaskType = ''; // 当前任务类型 +let currentElements = []; // 当前页面的UI元素信息 +let hoveredElement = null; // 当前悬停的元素 +let elementOverlay = null; // 元素高亮覆盖层 + +let autoRefreshEnabled = false; // 是否启用自动刷新 + +// 鼠标位置追踪 +let lastMousePosition = { x: 0, y: 0 }; // 记录最后的鼠标位置 + +async function startDataCollection() { + // 显示应用信息输入弹窗 + showAppInfoModal(); +} + +async function endDataCollection() { + const startBtn = document.getElementById('startBtn'); + const endBtn = document.getElementById('endBtn'); + const nextBtn = document.getElementById('nextBtn'); + const deleteBtn = document.getElementById('deleteBtn'); + const inputBtn = document.getElementById('inputBtn'); + const historyBtn = document.getElementById('historyBtn'); + const autoRefreshBtn = document.getElementById('autoRefreshBtn'); + const collectionInfo = document.getElementById('collectionInfo'); + + try { + // 停止自动刷新 + if (autoRefreshEnabled) { + stopAutoRefresh(); + } + + await saveCurrentData(); + + // 更新UI状态 + startBtn.disabled = false; + endBtn.disabled = true; + nextBtn.disabled = true; + deleteBtn.disabled = true; + inputBtn.disabled = true; + historyBtn.disabled = true; + autoRefreshBtn.disabled = true; + isCollecting = false; + + // 隐藏自动刷新状态 + const statusPanel = document.getElementById('autoRefreshStatus'); + statusPanel.style.display = 'none'; + autoRefreshBtn.textContent = '⏰ 自动刷新'; + + // 更新状态显示 + const statusDiv = document.querySelector('.collection-status'); + 
statusDiv.classList.remove('collecting'); + collectionInfo.innerHTML = `✅ 数据收集已结束`; + + // 隐藏操作提示 + const hint = document.getElementById('actionHint'); + if (hint) { + hint.style.display = 'none'; + } + + updateStatus(`数据收集已结束,自动刷新已关闭`, 'success'); + + } catch (error) { + updateStatus(`结束收集失败: ${error.message}`, 'error'); + } +} + +async function nextDataCollection() { + try { + // 保存当前数据 + await saveCurrentData(); + + // 显示应用信息输入弹窗,为下一条数据输入新的应用信息和任务描述 + showTaskDescriptionModal(true); + + } catch (error) { + updateStatus(`切换到下一条数据失败: ${error.message}`, 'error'); + } +} + +async function deleteDataCollection() { + try { + // 删除当前数据 + await deleteCurrentData(); + + // 显示任务描述输入弹窗,为下一条数据输入新的任务描述 + showTaskDescriptionModal(true); // 传入true表示这是删除后的下一条数据 + + } catch (error) { + updateStatus(`删除数据失败: ${error.message}`, 'error'); + } +} + +async function takeScreenshot() { + const status = document.getElementById('status'); + const container = document.getElementById('screenshotContainer'); + + // 显示加载状态 + status.innerHTML = '
正在获取截图,请稍候...
'; + container.innerHTML = '
截图中...
'; + + try { + // 直接调用获取截图的API,该API会自动更新截图 + const response = await fetch('/screenshot'); + + if (response.ok) { + const result = await response.json(); + status.innerHTML = '
截图成功!可以点击或滑动进行操作
'; + + // 显示截图并添加事件监听 + container.innerHTML = ` + 设备截图 + `; + + // 获取截图元素引用 + screenshotImg = document.getElementById('screenshotImage'); + + // 直接设置截图数据 + if (result.image_data) { + screenshotImg.src = result.image_data; + + // 存储层次结构信息供后续使用 + window.currentHierarchy = result.hierarchy; + + // 解析并保存所有UI元素信息 + if (result.hierarchy) { + currentElements = parseUIElements(result.hierarchy); + const clickableElements = currentElements.filter(el => el.clickable); + console.log(`UI元素信息已加载: ${currentElements.length} 个元素 (其中 ${clickableElements.length} 个可点击)`); + } + } + + // 显示操作提示 + const hint = document.getElementById('actionHint'); + if (hint) { + hint.style.display = 'block'; + } + + // 为截图添加交互功能 + setupScreenshotInteraction(); + } else { + const error = await response.json(); + throw new Error(error.detail || '截图失败'); + } + } catch (error) { + status.innerHTML = `
错误: ${error.message}
`; + container.innerHTML = '
截图失败,请重试
'; + } +} + +function setupScreenshotInteraction() { + screenshotImg = document.getElementById('screenshotImage'); + if (!screenshotImg) { + console.error('找不到截图元素'); + return; + } + + console.log('设置截图交互功能...'); + + // 确保清除之前的状态 + clearElementHighlight(); + hoveredElement = null; + + // 添加鼠标事件处理 + screenshotImg.addEventListener('mousedown', handleMouseDown); + screenshotImg.addEventListener('mousemove', handleMouseMove); + screenshotImg.addEventListener('mouseup', handleMouseUp); + screenshotImg.addEventListener('mouseleave', handleMouseUp); // 鼠标离开时也要结束拖拽 + + // 添加元素高亮的鼠标移动处理 + screenshotImg.addEventListener('mousemove', handleScreenshotMouseMove); + screenshotImg.addEventListener('mouseleave', () => { + clearElementHighlight(); + lastMousePosition = { x: -1, y: -1 }; // 重置鼠标位置 + }); + + // 禁用图片的默认拖拽行为 + screenshotImg.addEventListener('dragstart', (e) => e.preventDefault()); + + // 添加触摸事件支持 + screenshotImg.addEventListener('touchstart', handleTouchStart); + screenshotImg.addEventListener('touchmove', handleTouchMove); + screenshotImg.addEventListener('touchend', handleTouchEnd); + + // 禁用右键菜单 + screenshotImg.addEventListener('contextmenu', (e) => e.preventDefault()); + + console.log('截图交互功能设置完成'); +} + +function handleMouseDown(event) { + if (isInteracting) return; + + isDragging = true; + dragStartX = event.clientX; + dragStartY = event.clientY; + dragStartTime = Date.now(); + + // 获取相对于图片的坐标 + const rect = screenshotImg.getBoundingClientRect(); + const relativeX = event.clientX - rect.left; + const relativeY = event.clientY - rect.top; + + // 计算在原始图片上的坐标 + const scaleX = screenshotImg.naturalWidth / screenshotImg.width; + const scaleY = screenshotImg.naturalHeight / screenshotImg.height; + + dragStartX = Math.round(relativeX * scaleX); + dragStartY = Math.round(relativeY * scaleY); + + screenshotImg.style.cursor = 'grabbing'; + event.preventDefault(); +} + +function handleMouseMove(event) { + if (!isDragging) return; + + // 更新光标样式以显示正在拖拽 + 
screenshotImg.style.cursor = 'grabbing'; + event.preventDefault(); +} + +function handleMouseUp(event) { + if (!isDragging) return; + + isDragging = false; + screenshotImg.style.cursor = 'crosshair'; + + const dragEndTime = Date.now(); + const dragDuration = dragEndTime - dragStartTime; + + // 获取相对于图片的坐标 + const rect = screenshotImg.getBoundingClientRect(); + const relativeX = event.clientX - rect.left; + const relativeY = event.clientY - rect.top; + + // 计算在原始图片上的坐标 + const scaleX = screenshotImg.naturalWidth / screenshotImg.width; + const scaleY = screenshotImg.naturalHeight / screenshotImg.height; + + const dragEndX = Math.round(relativeX * scaleX); + const dragEndY = Math.round(relativeY * scaleY); + + // 计算移动距离 + const deltaX = dragEndX - dragStartX; + const deltaY = dragEndY - dragStartY; + const distance = Math.sqrt(deltaX * deltaX + deltaY * deltaY); + + // 如果移动距离很小或时间很短,认为是点击 + if (distance < 10 || dragDuration < 150) { + handleClickAction(dragStartX, dragStartY); + } else { + // 否则认为是滑动,判断方向 + handleSwipeAction(dragStartX, dragStartY, dragEndX, dragEndY, deltaX, deltaY); + } + + event.preventDefault(); +} + +function handleTouchStart(event) { + if (isInteracting) return; + + const touch = event.touches[0]; + isDragging = true; + + const rect = screenshotImg.getBoundingClientRect(); + const relativeX = touch.clientX - rect.left; + const relativeY = touch.clientY - rect.top; + + const scaleX = screenshotImg.naturalWidth / screenshotImg.width; + const scaleY = screenshotImg.naturalHeight / screenshotImg.height; + + dragStartX = Math.round(relativeX * scaleX); + dragStartY = Math.round(relativeY * scaleY); + dragStartTime = Date.now(); + + event.preventDefault(); +} + +function handleTouchMove(event) { + if (!isDragging) return; + event.preventDefault(); +} + +function handleTouchEnd(event) { + if (!isDragging) return; + + isDragging = false; + const dragEndTime = Date.now(); + const dragDuration = dragEndTime - dragStartTime; + + const touch = 
event.changedTouches[0]; + const rect = screenshotImg.getBoundingClientRect(); + const relativeX = touch.clientX - rect.left; + const relativeY = touch.clientY - rect.top; + + const scaleX = screenshotImg.naturalWidth / screenshotImg.width; + const scaleY = screenshotImg.naturalHeight / screenshotImg.height; + + const dragEndX = Math.round(relativeX * scaleX); + const dragEndY = Math.round(relativeY * scaleY); + + const deltaX = dragEndX - dragStartX; + const deltaY = dragEndY - dragStartY; + const distance = Math.sqrt(deltaX * deltaX + deltaY * deltaY); + + if (distance < 10 || dragDuration < 150) { + handleClickAction(dragStartX, dragStartY); + } else { + handleSwipeAction(dragStartX, dragStartY, dragEndX, dragEndY, deltaX, deltaY); + } + + event.preventDefault(); +} + +async function handleClickAction(x, y) { + isInteracting = true; + + try { + // 如果正在自动刷新,暂时停止以避免冲突 + const wasAutoRefreshing = autoRefreshEnabled; + if (wasAutoRefreshing) { + console.log('点击操作开始,暂停自动刷新'); + stopAutoRefresh(); + } + + const response = await fetch('/click', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ x, y }) + }); + + if (response.ok) { + const result = await response.json(); + updateStatus(`点击操作完成: (${x}, ${y}) | 总操作数: ${result.action_count || 0}`, 'success'); + + // 显示被点击的元素信息 + if (result.clicked_elements && result.clicked_elements.length > 0) { + displayElementInfo(result.clicked_elements); + } + + // 操作完成后刷新截图和UI元素信息 + setTimeout(async () => { + await refreshScreenshot(); + console.log('点击操作后已刷新UI元素信息'); + + // 如果之前开启了自动刷新,重新开启 + if (wasAutoRefreshing && isCollecting) { + setTimeout(() => { + console.log('重新开启自动刷新'); + startAutoRefresh(); + const btn = document.getElementById('autoRefreshBtn'); + const statusPanel = document.getElementById('autoRefreshStatus'); + btn.textContent = '⏹️ 停止刷新'; + statusPanel.style.display = 'block'; + }, 500); // 延迟500ms再开启自动刷新,给操作完成留出时间 + } + }, 200); + } else { + const error = await 
response.json(); + updateStatus(`点击操作失败: ${error.detail}`, 'error'); + } + } catch (error) { + updateStatus(`点击操作错误: ${error.message}`, 'error'); + } finally { + isInteracting = false; + } +} + +async function handleSwipeAction(startX, startY, endX, endY, deltaX, deltaY) { + isInteracting = true; + + // 判断滑动方向 + let direction; + if (Math.abs(deltaX) > Math.abs(deltaY)) + direction = deltaX > 0 ? 'right' : 'left'; + else + direction = deltaY > 0 ? 'down' : 'up'; + + try { + // 如果正在自动刷新,暂时停止以避免冲突 + const wasAutoRefreshing = autoRefreshEnabled; + if (wasAutoRefreshing) { + console.log('滑动操作开始,暂停自动刷新'); + stopAutoRefresh(); + } + + const response = await fetch('/swipe', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + startX, + startY, + endX, + endY, + direction + }) + }); + + if (response.ok) { + const result = await response.json(); + updateStatus(`滑动操作完成: (${startX}, ${startY}) → (${endX}, ${endY}) [${direction}] | 总操作数: ${result.action_count || 0}`, 'success'); + + setTimeout(async () => { + await refreshScreenshot(); + console.log('滑动操作后已刷新UI元素信息'); + + // 如果之前开启了自动刷新,重新开启 + if (wasAutoRefreshing && isCollecting) { + setTimeout(() => { + console.log('重新开启自动刷新'); + startAutoRefresh(); + const btn = document.getElementById('autoRefreshBtn'); + const statusPanel = document.getElementById('autoRefreshStatus'); + btn.textContent = '⏹️ 停止刷新'; + statusPanel.style.display = 'block'; + }, 500); + } + }, 200); + } else { + const error = await response.json(); + updateStatus(`滑动操作失败: ${error.detail}`, 'error'); + } + } catch (error) { + updateStatus(`滑动操作错误: ${error.message}`, 'error'); + } finally { + isInteracting = false; + } +} + +function updateStatus(message, type) { + const status = document.getElementById('status'); + status.innerHTML = `
${message}
`; +} + +async function refreshScreenshot() { + try { + console.log('开始刷新截图和UI元素信息...'); + + const response = await fetch('/screenshot'); + const data = await response.json(); + + if (screenshotImg && data.image_data) { + screenshotImg.src = data.image_data; + + // 存储层次结构信息供后续使用 + window.currentHierarchy = data.hierarchy; + + // 解析并保存所有UI元素信息 + if (data.hierarchy) { + const oldElementsCount = currentElements.length; + currentElements = parseUIElements(data.hierarchy); + + // 统计可点击元素数量 + const clickableElements = currentElements.filter(el => el.clickable); + console.log(`UI元素信息已更新: ${oldElementsCount} -> ${currentElements.length} 个元素 (其中 ${clickableElements.length} 个可点击)`); + + // 清除当前高亮,确保下次鼠标移动时重新计算 + clearElementHighlight(); + hoveredElement = null; + + // 如果鼠标在截图区域内,重新检测鼠标位置的元素 + checkMousePositionAfterRefresh(); + } else { + console.warn('未获取到层次结构数据'); + currentElements = []; + } + + console.log('截图和UI元素信息刷新完成'); + return true; + } else { + console.error('截图数据不完整'); + return false; + } + + } catch (error) { + console.error('刷新截图时出错:', error); + return false; + } +} + +async function showActionHistory() { + try { + const response = await fetch('/action_history'); + const data = await response.json(); + + if (response.ok) { + displayHistoryModal(data.actions, data.total_actions); + } else { + updateStatus('获取操作历史失败', 'error'); + } + } catch (error) { + updateStatus(`获取操作历史错误: ${error.message}`, 'error'); + } +} + +function displayHistoryModal(actions, totalCount) { + // 创建弹窗 + const modal = document.createElement('div'); + modal.className = 'history-modal'; + + const content = document.createElement('div'); + content.className = 'history-content'; + + // 创建标题栏 + const header = document.createElement('div'); + header.className = 'history-header'; + header.innerHTML = ` +

操作历史记录 (总计: ${totalCount})

+ + `; + + content.appendChild(header); + + // 创建操作列表 + if (actions.length === 0) { + content.innerHTML += '

暂无操作记录

'; + } else { + actions.reverse().forEach((action, index) => { + const item = document.createElement('div'); + item.className = 'action-item'; + + const timestamp = new Date(action.timestamp).toLocaleString(); + let details = ''; + + if (action.type === 'click') { + details = `点击操作 - 位置: (${action.position.x}, ${action.position.y})`; + } else if (action.type === 'swipe') { + details = `滑动操作 - 从 (${action.press_position.x}, ${action.press_position.y}) 到 (${action.release_position.x}, ${action.release_position.y}) [${action.direction}]`; + } + + item.innerHTML = ` +
${timestamp}
+
${details}
+ `; + + content.appendChild(item); + }); + } + + modal.appendChild(content); + document.body.appendChild(modal); + + // 点击背景关闭弹窗 + modal.addEventListener('click', (e) => { + if (e.target === modal) { + closeHistoryModal(); + } + }); + + window.currentHistoryModal = modal; +} + +function closeHistoryModal() { + if (window.currentHistoryModal) { + document.body.removeChild(window.currentHistoryModal); + window.currentHistoryModal = null; + } +} + +async function saveCurrentData() { + try { + updateStatus(`正在保存数据...`, 'loading'); + + const response = await fetch('/save_data', { + method: 'POST' + }); + + if (response.ok) { + const result = await response.json(); + updateStatus(`第 ${result.data_index} 条数据已保存 (${result.saved_actions} 个操作)`, 'success'); + return result; + } else { + const error = await response.json(); + throw new Error(error.detail || '保存数据失败'); + } + } catch (error) { + updateStatus(`保存数据失败: ${error.message}`, 'error'); + throw error; + } +} + +async function deleteCurrentData() { + try { + updateStatus(`正在删除数据...`, 'loading'); + + const response = await fetch('/delete_data', { + method: 'POST' + }); + + if (response.ok) { + const result = await response.json(); + updateStatus(`第 ${result.data_index} 条数据已删除`, 'success'); + return result; + } else { + const error = await response.json(); + throw new Error(error.detail || '删除数据失败'); + } + } catch (error) { + updateStatus(`删除数据失败: ${error.message}`, 'error'); + throw error; + } +} + + +function showTaskDescriptionModal(isNextData = false) { + const modal = document.getElementById('taskDescriptionModal'); + const taskInput = document.getElementById('taskDescription'); + const confirmBtn = document.getElementById('confirmTaskBtn'); + const header = modal.querySelector('.task-modal-header h3'); + + // 根据场景修改标题 + if (isNextData) { + header.textContent = '📝 下一条数据 - 任务描述'; + } else { + header.textContent = '📝 任务描述'; + } + + // 清空输入框 + taskInput.value = ''; + taskInput.focus(); + + // 显示弹窗 + modal.style.display 
= 'flex'; + + // 只绑定确认按钮事件 + confirmBtn.onclick = async () => { + const description = taskInput.value.trim(); + if (description === '') { + alert('请输入任务描述才能开始任务!'); + taskInput.focus(); + return; + } + + currentTaskDescription = description; + hideTaskDescriptionModal(); + + if (isNextData) { + await continueWithNextDataCollection(); + } else { + await startDataCollectionWithDescription(); + } + }; +} + +function hideTaskDescriptionModal() { + const modal = document.getElementById('taskDescriptionModal'); + modal.style.display = 'none'; +} + +async function startDataCollectionWithDescription() { + const startBtn = document.getElementById('startBtn'); + const endBtn = document.getElementById('endBtn'); + const nextBtn = document.getElementById('nextBtn'); + const deleteBtn = document.getElementById('deleteBtn'); + const inputBtn = document.getElementById('inputBtn'); + const historyBtn = document.getElementById('historyBtn'); + const autoRefreshBtn = document.getElementById('autoRefreshBtn'); + const collectionInfo = document.getElementById('collectionInfo'); + const status = document.getElementById('status'); + const container = document.getElementById('screenshotContainer'); + + try { + // 重置UI状态 + resetUIState(); + + // 发送任务描述到后端 + await sendTaskDescription(currentTaskDescription); + + // 更新UI状态 + startBtn.disabled = true; + endBtn.disabled = false; + nextBtn.disabled = false; + deleteBtn.disabled = false; + inputBtn.disabled = false; + historyBtn.disabled = false; + autoRefreshBtn.disabled = false; + isCollecting = true; + + // 更新状态显示 + const statusDiv = document.querySelector('.collection-status'); + statusDiv.classList.add('collecting'); + collectionInfo.innerHTML = `应用:${currentAppName} | 类型:${currentTaskType}
任务:${currentTaskDescription}`; + status.innerHTML = '
正在获取初始截图...
'; + container.innerHTML = '
截图中...
'; + + // 自动获取截图 + await takeScreenshot(); + + // 自动开启自动刷新功能 + if (!autoRefreshEnabled) { + startAutoRefresh(); + autoRefreshBtn.textContent = '⏹️ 停止刷新'; + const statusPanel = document.getElementById('autoRefreshStatus'); + statusPanel.style.display = 'block'; + updateStatus('数据收集已开始,自动刷新已开启', 'success'); + } + + // 显示操作提示 + const hint = document.getElementById('actionHint'); + if (hint) { + hint.style.display = 'block'; + } + + } catch (error) { + updateStatus(`开始收集失败: ${error.message}`, 'error'); + // 恢复按钮状态 + startBtn.disabled = false; + endBtn.disabled = true; + nextBtn.disabled = true; + deleteBtn.disabled = true; + autoRefreshBtn.disabled = true; + isCollecting = false; + } +} + +async function sendTaskDescription(description) { + try { + const response = await fetch('/set_task_description', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + description: description, + app_name: currentAppName, + task_type: currentTaskType + }) + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || '发送任务描述失败'); + } + } catch (error) { + console.error('发送任务描述失败:', error); + throw error; + } +} + +// 重置UI状态函数 +function resetUIState() { + // 停止自动刷新 + if (autoRefreshEnabled) { + stopAutoRefresh(); + } + + // 清除元素高亮 + clearElementHighlight(); + + // 重置全局变量 + hoveredElement = null; + currentElements = []; + + // 如果存在元素覆盖层,移除它 + if (elementOverlay) { + elementOverlay.remove(); + elementOverlay = null; + } + + // 清除之前的鼠标事件监听器(如果有的话) + if (screenshotImg) { + // 克隆节点来移除所有事件监听器 + const newImg = screenshotImg.cloneNode(true); + screenshotImg.parentNode.replaceChild(newImg, screenshotImg); + screenshotImg = newImg; + } + + console.log('UI状态已重置'); +} + +async function continueWithNextDataCollection() { + const collectionInfo = document.getElementById('collectionInfo'); + + try { + // 重置UI状态 + resetUIState(); + + // 发送新的任务描述到后端 + await sendTaskDescription(currentTaskDescription); + + // 更新状态显示 
+ collectionInfo.innerHTML = `应用:${currentAppName} | 类型:${currentTaskType}
任务:${currentTaskDescription}`; + + // 自动获取新截图 + await takeScreenshot(); + + // 自动开启自动刷新功能 + if (!autoRefreshEnabled) { + startAutoRefresh(); + const autoRefreshBtn = document.getElementById('autoRefreshBtn'); + autoRefreshBtn.textContent = '⏹️ 停止刷新'; + const statusPanel = document.getElementById('autoRefreshStatus'); + statusPanel.style.display = 'block'; + } + + updateStatus(`已切换下一条数据,自动刷新已开启`, 'success'); + + } catch (error) { + updateStatus(`切换到下一条数据失败: ${error.message}`, 'error'); + } +} + +// 文本输入功能 +function showInputModal() { + if (!isCollecting) { + updateStatus('请先开始数据收集', 'error'); + return; + } + + const modal = document.getElementById('inputModal'); + const inputText = document.getElementById('inputText'); + + modal.style.display = 'flex'; + inputText.value = ''; + inputText.focus(); + + // 添加键盘快捷键支持 + inputText.onkeydown = function (event) { + if (event.key === 'Escape') { + hideInputModal(); + } + }; +} + +function hideInputModal() { + const modal = document.getElementById('inputModal'); + modal.style.display = 'none'; +} + +async function sendInputText() { + const inputText = document.getElementById('inputText'); + const text = inputText.value.trim(); + + if (!text) { + updateStatus('请输入文本内容', 'error'); + return; + } + + if (!isCollecting) { + updateStatus('请先开始数据收集', 'error'); + hideInputModal(); + return; + } + + try { + updateStatus('正在发送文本...', 'info'); + + // 如果正在自动刷新,暂时停止以避免冲突 + const wasAutoRefreshing = autoRefreshEnabled; + if (wasAutoRefreshing) { + console.log('文本输入操作开始,暂停自动刷新'); + stopAutoRefresh(); + } + + hideInputModal(); + const response = await fetch('/input', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + text: text + }) + }); + if (response.ok) { + const result = await response.json(); + updateStatus(`文本输入完成: "${text}"`, 'success'); + + // 操作完成后刷新截图和UI元素信息 + setTimeout(async () => { + await refreshScreenshot(); + console.log('输入操作后已刷新UI元素信息'); + + // 如果之前开启了自动刷新,重新开启 + if 
(wasAutoRefreshing && isCollecting) { + setTimeout(() => { + console.log('重新开启自动刷新'); + startAutoRefresh(); + const btn = document.getElementById('autoRefreshBtn'); + const statusPanel = document.getElementById('autoRefreshStatus'); + btn.textContent = '⏹️ 停止刷新'; + statusPanel.style.display = 'block'; + }, 500); + } + }, 200); + } else { + const error = await response.json(); + updateStatus(`输入操作失败: ${error.detail}`, 'error'); + } + + } catch (error) { + console.error('文本输入失败:', error); + updateStatus(`文本输入失败: ${error.message}`, 'error'); + } +} + +// 显示元素信息 +function displayElementInfo(elements) { + const elementInfo = document.getElementById('elementInfo'); + const elementDetails = document.getElementById('elementDetails'); + + if (!elements || elements.length === 0) { + elementInfo.style.display = 'none'; + return; + } + + let html = ''; + elements.forEach((element, index) => { + html += ` +
+
元素 ${index + 1}:
+
位置: ${element.bounds}
+
类型: ${element.class}
+ ${element['resource-id'] ? `
ID: ${element['resource-id']}
` : ''} + ${element.text ? `
文本: ${element.text}
` : ''} + ${element['content-desc'] ? `
描述: ${element['content-desc']}
` : ''} +
可点击: ${element.clickable ? '是' : '否'}
+
应用包名: ${element.package}
+
+ `; + }); + + elementDetails.innerHTML = html; + elementInfo.style.display = 'block'; +} + +// 解析UI层次结构,提取所有元素的位置信息 +function parseUIElements(hierarchyXml) { + if (!hierarchyXml) return []; + + const parser = new DOMParser(); + const xmlDoc = parser.parseFromString(hierarchyXml, 'text/xml'); + const nodes = xmlDoc.querySelectorAll('node'); + const elements = []; + + nodes.forEach(node => { + const bounds = node.getAttribute('bounds'); + if (bounds) { + // 解析bounds属性,格式如: [left,top][right,bottom] + const boundsMatch = bounds.match(/\[(\d+),(\d+)\]\[(\d+),(\d+)\]/); + if (boundsMatch) { + const left = parseInt(boundsMatch[1]); + const top = parseInt(boundsMatch[2]); + const right = parseInt(boundsMatch[3]); + const bottom = parseInt(boundsMatch[4]); + + elements.push({ + bounds: bounds, + left: left, + top: top, + right: right, + bottom: bottom, + width: right - left, + height: bottom - top, + class: node.getAttribute('class') || '', + 'resource-id': node.getAttribute('resource-id') || '', + text: node.getAttribute('text') || '', + 'content-desc': node.getAttribute('content-desc') || '', + clickable: node.getAttribute('clickable') === 'true', + package: node.getAttribute('package') || '' + }); + } + } + }); + + return elements; +} + +// 创建元素高亮覆盖层 +function createElementOverlay() { + if (elementOverlay) return elementOverlay; + + if (!screenshotImg || !screenshotImg.parentElement) { + console.error('截图元素或其父容器不存在'); + return null; + } + + const overlay = document.createElement('div'); + overlay.id = 'elementOverlay'; + overlay.style.position = 'absolute'; + overlay.style.top = '0'; + overlay.style.left = '0'; + overlay.style.width = '100%'; + overlay.style.height = '100%'; + overlay.style.pointerEvents = 'none'; + overlay.style.zIndex = '10'; + + const container = screenshotImg.parentElement; + container.style.position = 'relative'; + container.appendChild(overlay); + + elementOverlay = overlay; + + // 监听窗口大小变化,重新绘制边框 + window.addEventListener('resize', () => { + if 
// 在截图上绘制指定 UI 元素的高亮边框。
// 需要已有的截图元素(screenshotImg)与覆盖层(createElementOverlay)。
function drawElementBorder(element) {
    if (!screenshotImg || !element) {
        console.warn('绘制元素边框失败:缺少截图或元素信息');
        return;
    }

    const overlay = createElementOverlay();
    if (!overlay) {
        console.error('创建覆盖层失败,无法绘制元素边框');
        return;
    }

    // 图片与容器的屏幕位置,用于换算图片相对容器的偏移
    const imgBox = screenshotImg.getBoundingClientRect();
    const containerBox = screenshotImg.parentElement.getBoundingClientRect();
    const offsetX = imgBox.left - containerBox.left;
    const offsetY = imgBox.top - containerBox.top;

    // 原图坐标 -> 显示坐标的缩放比例
    const ratioX = screenshotImg.width / screenshotImg.naturalWidth;
    const ratioY = screenshotImg.height / screenshotImg.naturalHeight;

    // 创建边框元素:位置相对于容器,因此要加上图片偏移
    const box = document.createElement('div');
    Object.assign(box.style, {
        position: 'absolute',
        left: `${offsetX + element.left * ratioX}px`,
        top: `${offsetY + element.top * ratioY}px`,
        width: `${element.width * ratioX}px`,
        height: `${element.height * ratioY}px`,
        border: '2px solid #ff6b6b',
        backgroundColor: 'rgba(255, 107, 107, 0.1)',
        boxSizing: 'border-box',
    });

    // 先清空旧边框,再挂载新的
    overlay.innerHTML = '';
    overlay.appendChild(box);
}

// 清除当前的元素高亮并重置悬停状态
function clearElementHighlight() {
    if (elementOverlay) {
        elementOverlay.innerHTML = '';
    }
    hoveredElement = null;
}

// 根据鼠标(显示坐标)查找命中的 UI 元素;只考虑 clickable 的元素,
// 多个命中时返回面积最小者。未命中返回 null。
function findElementAtPosition(x, y) {
    if (!currentElements.length) return null;

    // 显示坐标 -> 原图坐标
    const toOrigX = screenshotImg.naturalWidth / screenshotImg.width;
    const toOrigY = screenshotImg.naturalHeight / screenshotImg.height;
    const px = x * toOrigX;
    const py = y * toOrigY;

    // 命中测试:只保留可点击且包含该点的元素
    const hits = currentElements.filter(el =>
        el.clickable &&
        px >= el.left && px <= el.right &&
        py >= el.top && py <= el.bottom
    );
    if (!hits.length) return null;

    // 取面积最小的命中元素(通常是最精确的目标)
    let best = hits[0];
    for (const el of hits) {
        if (el.width * el.height < best.width * best.height) {
            best = el;
        }
    }
    return best;
}

// 鼠标在截图上移动时的处理:记录位置、做命中测试并更新高亮
function handleScreenshotMouseMove(event) {
    if (!screenshotImg) {
        console.log('没有截图元素');
        return;
    }
    if (!currentElements.length) {
        console.log('没有UI元素信息,currentElements长度:', currentElements.length);
        return;
    }

    const bounds = screenshotImg.getBoundingClientRect();
    const localX = event.clientX - bounds.left;
    const localY = event.clientY - bounds.top;

    // 记录鼠标位置,供刷新后恢复高亮使用
    lastMousePosition = { x: localX, y: localY };

    // 鼠标移出图片范围时清除高亮
    const inside = localX >= 0 && localX <= screenshotImg.width &&
        localY >= 0 && localY <= screenshotImg.height;
    if (!inside) {
        if (hoveredElement) {
            clearElementHighlight();
        }
        return;
    }

    const target = findElementAtPosition(localX, localY);
    if (target !== hoveredElement) {
        hoveredElement = target;
        if (target) {
            drawElementBorder(target);
            console.log('高亮可点击元素:', target.class, target.clickable ? '✓可点击' : '✗不可点击');
        } else {
            clearElementHighlight();
        }
    }
}

// 截图刷新后,用最近记录的鼠标位置重新做一次命中测试并恢复高亮
function checkMousePositionAfterRefresh() {
    if (!screenshotImg || !currentElements.length) {
        return;
    }

    const { x, y } = lastMousePosition;
    // 没有有效的历史鼠标位置则不处理
    if (x < 0 || y < 0) {
        return;
    }

    // 与原实现保持一致地读取一次布局信息(强制 layout)
    screenshotImg.getBoundingClientRect();

    // 鼠标必须仍落在图片显示范围内
    if (x > screenshotImg.width || y > screenshotImg.height) {
        return;
    }

    const target = findElementAtPosition(x, y);
    if (target === hoveredElement) {
        return;
    }
    hoveredElement = target;
    if (target) {
        drawElementBorder(target);
        console.log('刷新后重新高亮元素:', target.class, target.clickable ? '✓可点击' : '✗不可点击');
    } else {
        clearElementHighlight();
    }
}

// 自动刷新开关 - 简化版本,连续刷新模式
function toggleAutoRefresh() {
    if (!isCollecting) {
        updateStatus('请先开始数据收集', 'error');
        return;
    }

    const toggleBtn = document.getElementById('autoRefreshBtn');
    const panel = document.getElementById('autoRefreshStatus');

    if (autoRefreshEnabled) {
        // 当前已开启,点击关闭
        stopAutoRefresh();
        toggleBtn.textContent = '⏰ 自动刷新';
        panel.style.display = 'none';
        updateStatus('自动刷新已关闭', 'success');
    } else {
        // 当前已关闭,点击开启
        startAutoRefresh();
        toggleBtn.textContent = '⏹️ 停止刷新';
        panel.style.display = 'block';
        updateStatus('自动刷新已开启,连续刷新模式', 'success');
    }
}

// 连续自动刷新:上一个请求完成后立即发起下一个
async function startAutoRefresh() {
    if (autoRefreshEnabled) return;
    autoRefreshEnabled = true;

    const sleep = ms => new Promise(resolve => setTimeout(resolve, ms));

    while (autoRefreshEnabled && isCollecting) {
        if (isInteracting) {
            // 不能刷新时,稍等后重新检查
            if (!isCollecting) console.log('连续刷新等待:未在收集数据');
            if (isInteracting) console.log('连续刷新等待:正在交互');
            await sleep(100);
            continue;
        }
        try {
            console.log('连续自动刷新截图...');
            const ok = await refreshScreenshot();
            console.log(ok ? '连续自动刷新完成' : '连续自动刷新跳过或失败');
        } catch (err) {
            console.error('连续自动刷新失败:', err);
            // 出错时等待一小段时间再继续,避免连续错误
            await sleep(500);
        }
    }
    console.log('连续自动刷新已停止');
}

// 停止连续自动刷新(循环在下一次条件检查时退出)
function stopAutoRefresh() {
    if (!autoRefreshEnabled) return;
    autoRefreshEnabled = false;
}

// 弹出应用信息输入框;确认后保存信息并进入任务描述弹窗
function showAppInfoModal() {
    const modal = document.getElementById('appInfoModal');
    const nameInput = document.getElementById('appName');
    const typeInput = document.getElementById('taskType');
    const okBtn = document.getElementById('confirmAppInfoBtn');

    // 重置输入并聚焦
    nameInput.value = '';
    typeInput.value = '';
    nameInput.focus();

    // 显示弹窗
    modal.style.display = 'flex';

    // 绑定确认按钮事件
    okBtn.onclick = async () => {
        const appName = nameInput.value.trim();
        const taskType = typeInput.value.trim();

        if (appName === '') {
            alert('请选择应用名称!');
            nameInput.focus();
            return;
        }
        if (taskType === '') {
            alert('请输入任务类型!');
            typeInput.focus();
            return;
        }

        // 保存应用信息
        currentAppName = appName;
        currentTaskType = taskType;

        // 隐藏应用信息弹窗,进入任务描述弹窗
        hideAppInfoModal();
        showTaskDescriptionModal();
    };
}

// 隐藏应用信息弹窗
function hideAppInfoModal() {
    document.getElementById('appInfoModal').style.display = 'none';
}
# Phrases in a decider "reasoning" string that indicate the task should be
# aborted and handed back to the user.
terminate_checklist = [
    "当前页面未按预期加载",
    "进入了错误的页面",
    "打开了不合预期的页面",
    "当前打开了错误页面",
    "当前页面不合预期",
    "需要用户介入",
    "需要用户接管",
]

# App display name -> Android package name for every app the agent can launch.
supported_apps = {
    "微信": "com.tencent.mm",
    "QQ": "com.tencent.mobileqq",
    "微博": "com.sina.weibo",
    "饿了么": "me.ele",
    "美团": "com.sankuai.meituan",
    "bilibili": "tv.danmaku.bili",
    "B站": "tv.danmaku.bili",
    "爱奇艺": "com.qiyi.video",
    "腾讯视频": "com.tencent.qqlive",
    "淘宝": "com.taobao.taobao",
    "京东": "com.jingdong.app.mall",
    "携程": "ctrip.android.view",
    "去哪儿": "com.Qunar",
    "知乎": "com.zhihu.android",
    "小红书": "com.xingin.xhs",
    "QQ音乐": "com.tencent.qqmusic",
    "网易云": "com.netease.cloudmusic",
    "高德": "com.autonavi.minimap"
}

def should_terminate(reasoning: str):
    """Return True if the reasoning text contains any termination phrase."""
    return any(phrase in reasoning for phrase in terminate_checklist)

def try_find_app(task_description: str):
    """Match the task description against the supported app names.

    Performs a case-insensitive substring search and prefers the longest
    matching name (so e.g. "QQ音乐" wins over "QQ").  Returns a
    (app_name, package_name) tuple, or (None, None) when nothing matches.
    """
    task_lower = task_description.lower()
    best = ""
    for name in supported_apps:
        if name.lower() in task_lower and len(name) > len(best):
            best = name
    if best:
        return best, supported_apps[best]
    return None, None
com.taobao.taobao +- 京东: com.jingdong.app.mall +- 携程: ctrip.android.view +- 去哪儿: com.Qunar +- 知乎: com.zhihu.android +- 小红书: com.xingin.xhs +- QQ音乐: com.tencent.qqmusic +- 网易云音乐: com.netease.cloudmusic +- 高德地图:com.autonavi.minimap + +## 默认应用列表 +以下是各个应用类别的默认应用: + +通讯应用: +- 微信: com.tencent.mm + +外卖应用: +- 饿了么: me.ele + +视频应用: +- bilibili: tv.danmaku.bili + +酒店/旅行应用: +- 携程: ctrip.android.view + +社区应用: +- 小红书: com.xingin.xhs + +音乐应用: +- 网易云音乐: com.netease.cloudmusic + +地图/打车应用: +- 高德地图:com.autonavi.minimap + + +## 输出格式 +请严格按照以下JSON格式输出: +```json +{{ + "reasoning": "分析任务内容,说明为什么选择这个应用最合适", + "app_name": "选择的应用名称", + "package_name": "选择的应用包名", +}} +``` + +## 重要规则 +1. 只能从上述可用应用列表中选择 +2. 如果应用列表中不存在能够完成用户任务的应用,或者用户显式指定了不在列表中的应用,"app_name"和"package_name"请返回空字符串,也就是"" +3. 必须选择最符合任务需求的应用 +4. 包名必须完全匹配列表中的包名,不能修改 +5. 若用户没有显式指定应用名称,你只能根据任务类型,从**默认应用列表**中挑选,**不能挑选非默认应用** +'''.strip() + +DECIDER_PROMPT = ''' +You are a phone-use AI agent. Now your task is "{task}". +Your action history is: +{history} +Please provide the next action based on the screenshot and your action history. You should do careful reasoning before providing the action. +Your action space includes: +- Name: click, Parameters: target_element (a high-level description of the UI element to click). +- Name: swipe, Parameters: direction (one of UP, DOWN, LEFT, RIGHT). +- Name: input, Parameters: text (the text to input). +- Name: wait, Parameters: (no parameters, will wait for 1 second). +- Name: done, Parameters: (no parameters). +Your output should be a JSON object with the following format: +{{"reasoning": "Your reasoning here", "action": "The next action (one of click, input, swipe, wait, done)", "parameters": {{"param1": "value1", ...}}}}''' + +GROUNDER_PROMPT = ''' +Based on the screenshot, user's intent and the description of the target UI element, provide the bounding box of the element using **absolute coordinates**. 
def get_model_output(model_client, prompt, image_b64=None):
    """Send a single-turn chat request to an OpenAI-compatible endpoint.

    Args:
        model_client: OpenAI client pointing at the target vLLM service.
        prompt: text prompt for the model.
        image_b64: optional base64-encoded JPEG; when given it is placed
            before the text part of the message.

    Returns:
        The model's reply content as a string.
    """
    content = [{"type": "text", "text": prompt}]
    if image_b64 is not None:
        # The image part goes first so the model sees it before the prompt.
        content.insert(0, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}})
    messages = [{"role": "user", "content": content}]

    # model="" because vLLM serves a single model per endpoint.
    response = model_client.chat.completions.create(
        model="",
        messages=messages,
        temperature=0,
    )
    return response.choices[0].message.content

def validate_history(history: List[str]):
    """Sanitize a history of JSON-encoded actions.

    Drops entries whose action is not supported and strips any parameters
    that are not expected for that action (e.g. the grounder-added x/y
    coordinates on clicks).  Returns re-serialized JSON strings.
    """
    allowed = {
        "click": {"target_element"},
        "input": {"text"},
        "swipe": {"direction"},
        "done": {}
    }
    cleaned = []
    for entry in history:
        record = json.loads(entry)
        kind = record["action"]
        if kind not in allowed:
            # e.g. "wait" and "open_app" carry no useful context — drop them.
            continue
        kept = copy.deepcopy(record)
        for key in record["parameters"]:
            if key not in allowed[kind]:
                kept["parameters"].pop(key)
        cleaned.append(kept)

    return [json.dumps(item, ensure_ascii=False) for item in cleaned]
planner_output.replace("```json", "").replace("```", "") + planner_output_json = json.loads(planner_output) + app_name = planner_output_json["app_name"] + package_name = planner_output_json["package_name"] + if package_name not in supported_apps.values(): + app_name, package_name = None, "" + if app_name is None or app_name == "" or package_name == "": + reasoning = f"无法识别用户任务\"{request_body.task}\"需要打开的应用,任务终止" + return ResponseBody( + reasoning=reasoning, + action="terminate", + # action="done", + parameters={} + ) + else: + reasoning = f"为了完成用户任务\"{request_body.task}\", 我需要打开应用\"{app_name}\"" + return ResponseBody( + reasoning=reasoning, + action="open_app", + parameters={ + "package_name": package_name, + } + ) + + # print("raw history: ", history) + history = validate_history(history) + # print("cleaned history: ", history) + if len(history) == 0: + history_str = "(No history)" + else: + history_str = "\n".join(f"{idx}. {act}" for idx, act in enumerate(history, start=1)) + + img_b64 = request_body.image + decider_prompt = DECIDER_PROMPT.format(task=request_body.task, history=history_str) + decider_output = get_model_output(decider_client, decider_prompt, img_b64) + print(decider_output) + decider_output_json = json.loads(decider_output) + reasoning = decider_output_json["reasoning"] + if should_terminate(reasoning): + return ResponseBody( + reasoning=reasoning, + action="terminate", + parameters={} + ) + action = decider_output_json["action"] + parameters = decider_output_json["parameters"] + if action == "click": + grounder_prompt = GROUNDER_PROMPT.format(reasoning=reasoning, description=parameters["target_element"]) + grounder_output = get_model_output(grounder_client, grounder_prompt, img_b64) + print(grounder_output) + grounder_output_json = json.loads(grounder_output) + bbox = grounder_output_json["bbox"] + parameters["x"] = (bbox[0] + bbox[2]) // 2 + parameters["y"] = (bbox[1] + bbox[3]) // 2 + response = ResponseBody( + reasoning=reasoning, + 
action=action, + parameters=parameters + ) + return response + + except Exception as e: + traceback.print_exc() + # Handle potential errors + raise HTTPException( + status_code=500, + detail=f"An error occurred: {str(e)}" + ) + +# Optional: Add a root endpoint for health checks +@app.get("/") +async def root(): + return {} + +if __name__ == "__main__": + import uvicorn, argparse + parser = argparse.ArgumentParser() + parser.add_argument("--service_ip", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=22334) + parser.add_argument("--planner_port", type=int, default=18003) + parser.add_argument("--decider_port", type=int, default=18001) + parser.add_argument("--grounder_port", type=int, default=18002) + args = parser.parse_args() + decider_client = OpenAI(api_key="0", base_url=f"http://{args.service_ip}:{args.decider_port}/v1") + grounder_client = OpenAI(api_key="0", base_url=f"http://{args.service_ip}:{args.grounder_port}/v1") + planner_client = OpenAI(api_key="0", base_url=f"http://{args.service_ip}:{args.planner_port}/v1") + uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/msyh.ttf b/msyh.ttf new file mode 100644 index 0000000..e46ff3c Binary files /dev/null and b/msyh.ttf differ diff --git a/prompts/annotation_en_general.md b/prompts/annotation_en_general.md new file mode 100644 index 0000000..111f0e4 --- /dev/null +++ b/prompts/annotation_en_general.md @@ -0,0 +1,142 @@ +# System Prompt for React Agent Simulation + +## Background + +You are an AI assistant for understanding human-annotated mobile app trajectories and simulating a ReAct agent to reproduce the trajectories on real mobile devices. +Your task is to simulate an AI agent with ReAct (Reasoning + Acting) workflow, and reconstruct the reasoning process and the function call to **reproduce** each action, which is the ground truth in the corresponding time step. 
Your reconstructed high-level semantics must be **consistent** with the ground truth; do not include your own thinking.
**OPEN APP ``**: The user opened an application. The `` is the name of the application that was launched or opened by the user. +## Output + +Each screenshot contains auxiliary information about the action, and you must analyze each screenshot and provide **the matched reasoning for the action**, which must match the user's action. Each screenshot must have a matched reasoning, **neither too much nor too little**. +Your final output should be a list of JSON objects, each matching to an action in the trajectory. Keep the action order consistent with the input trajectory. + +### Output Action Space + +The functions that the ReAct agent can call are as follows: + +```json +[ + {{ + "name": "click", + "description": "Click on the screen at the target UI element", + "parameters": {{ + "properties": {{ + "target_element": {{ + "type": "string", + "description": "The description of the target UI element, which should contain enough information to locate the element without ambiguity. Possible information includes the element type, the content, the relative position, the color, the parent element, the order as a list item, etc." + }} + }}, + "required": ["target_element"] + }} + }}, + {{ + "name": "input", + "description": "Input the text into the activated text input field", + "parameters": {{ + "properties": {{ + "text": {{ + "type": "string", + "description": "The text to input" + }} + }}, + "required": ["text"] + }} + }}, + {{ + "name": "swipe", + "description": "Swipe on the screen", + "parameters": {{ + "properties": {{ + "direction": {{ + "type": "string", + "enum": ["UP", "DOWN", "LEFT", "RIGHT"], + "description": "The direction of the user's swipe gesture. UP: swipe finger upward to swipe content up and reveal content below (press position is below release position). DOWN: swipe finger downward to swipe content down and reveal content above (press position is above release position). 
The reasoning process and function parameters should be in Chinese.
The length of your output JSON list **must be strictly equal to {screenshot_count}**, which is the length of the screenshot sequence provided by the user.
Please analyze the actions matched to these screenshots based on the task information and provide the corresponding reasoning for each action. \ No newline at end of file diff --git a/prompts/auto_decider.md b/prompts/auto_decider.md new file mode 100644 index 0000000..4e3c3fb --- /dev/null +++ b/prompts/auto_decider.md @@ -0,0 +1,146 @@ +## 角色定义 +你是一个手机操作AI助手,需要帮助用户完成以下任务:"{task_description}" + +## 输入说明 +我会提供给你: +1. **操作历史**:之前的所有操作记录 +2. **屏幕截图**:当前手机屏幕的完整截图 +3. **{layer_count}张标注截图**:基于屏幕截图生成的可点击元素标注图层。为避免元素重叠,将所有可点击元素分配到不同图层中显示,所有图层包含的元素合集即为全部可点击元素 + - 可点击元素用红色方框标出 + - 每个元素的编号(index)显示在红色方框的左上角内侧,红底白字数字 + +## 操作历史 +{history} + +## 任务要求 +请仔细分析当前屏幕状态和操作历史,然后决定下一步最合适的操作。 + +## 可用操作 +1. **点击操作 (click)** + - 参数:index (整数,对应标注截图中要点击的UI元素编号) + - 参数:target_element (字符串,描述要点击的UI元素) + - **重要**:必须仔细观察标注截图,找到与target_element描述最匹配的红色方框,使用该方框左上角内侧红底白字显示的数字作为index + - **防止误选**:在reasoning中必须明确说明为什么选择这个红色方框而不是其相邻的红色方框 + +2. **滑动操作 (swipe)** + - 参数:direction (字符串,必须是 UP、DOWN、LEFT、RIGHT 之一) + - **重要**:滑动方向说明:UP表示向上滑动手指来向上滚动内容并显示下方内容;DOWN表示向下滑动手指来向下滚动内容并显示上方内容;LEFT表示向左滑动手指来向左滚动内容;RIGHT表示向右滑动手指来向右滚动内容。 + +3. **文本输入 (input)** + - 参数:text (字符串,要输入的文本内容) + +4. 
**完成任务 (done)** + - 无参数,表示任务已完成 + +## 输出格式 +请严格按照以下JSON格式输出: +```json +{{ + "reasoning": "详细说明你的分析思路和选择这个操作的原因。如果是点击操作,必须包含以下完整过程:\n1)目标元素的详细描述:包括元素内容、颜色、形状、大小等视觉特征\n2)目标元素的精确位置:在屏幕中的具体位置,使用周围元素作为参考\n3)标注图查找过程:说明在第几张标注图中找到了匹配的红色方框\n4)红色方框验证:确认该红色方框的位置、包含内容、边界是否与目标元素完全匹配\n5)index读取确认:明确说明选中红色方框左上角的数字是多少\n6)最终确认:再次确认这个选择是正确的,没有选错相邻元素\n如果是输入操作,必须先明确说明:1)当前屏幕是否显示软键盘 2)目标输入框是否已激活(有光标或高亮) 3)如果前两项有任何一项为否,则必须选择click操作先激活输入框,而不是input操作。", + "action": "操作名称(click/swipe/input/done)", + "parameters": {{ + "参数名": "参数值" + }} +}} +``` + +## index选择的关键步骤(点击操作必读) +**步骤1: 精确描述目标元素的视觉特征** +- 详细描述目标元素的内容、颜色、形状等视觉特征 +- 精确描述元素在屏幕中的位置(如:屏幕上方1/3处、左侧边缘、右下角等) +- 描述元素周围的其他UI元素作为参考点 +- 例如:"需要点击带有'搜索'文字的白色输入框,位于屏幕最上方,在应用标题下方" + +**步骤2: 系统性查找红色方框** +- **必须按顺序逐张查看每一张标注图**,不能跳过任何一张 +- 对于每张图,先整体观察所有红色方框的分布 +- 重点查找与步骤1描述的位置和特征完全匹配的红色方框 +- **关键要求:红色方框必须完全包围目标元素,边界贴合** + +**步骤3: 多重验证确保选择正确(最重要步骤)** +- **位置验证**:确认红色方框的位置与步骤1描述的位置完全一致 +- **内容验证**:仔细观察红色方框内部包含的内容是否就是目标元素 +- **边界验证**:红色方框的边界应该紧贴目标元素,不应该包含过多空白区域 +- **排除干扰**:如果有多个相似的红色方框,必须选择位置最精确匹配的那个 +- **避免相邻选择**:绝对不能选择目标元素旁边或附近的红色方框 + +**步骤4: 读取index数字(执行前的最后确认)** +- 再次确认选中的红色方框确实包围了正确的目标元素 +- 查看该红色方框**左上角内侧**的红底白字数字 +- **严格要求:必须是左上角,数字必须清晰可见** +- 该数字就是要使用的index值 + +**步骤5: 最终验证** +- 在reasoning中明确说明:"我选择的红色方框位于[具体位置],框内包含[具体内容],左上角数字为[X]" +- 如果对选择有任何不确定,必须重新从步骤1开始 + +## 文本输入的关键步骤(输入操作必读) +**重要前提:绝对禁止在未激活输入框时直接使用input操作!** + +**步骤1: 强制检查软键盘状态(必须执行,不可跳过)** +- **必须检查**:仔细观察当前屏幕最底部是否显示了软键盘(虚拟键盘界面) +- **判断标准**:如果屏幕底部没有显示包含字母、数字键的软键盘界面,说明没有任何输入框被激活 +- **关键规则**:**只有当软键盘完全显示在屏幕底部时,才允许进行input操作** +- **reasoning必须写明**:"检查软键盘状态:[已显示/未显示]" + +**步骤2: 输入框激活操作(如果第1步检查失败则必须执行)** +- **严格禁止**:如果没有软键盘或输入框未激活,绝对不能使用input操作 +- **必须操作**:必须先使用click操作点击目标输入框来激活它 +- **reasoning必须写明**:"软键盘未显示/输入框未激活,必须先点击激活输入框" + +**步骤3: 处理现有内容** +- 如果输入框中有默认文本,可以先尝试清除或直接覆盖 +- 根据具体情况选择处理方式 + +**步骤4: 执行文本输入(仅在前置条件满足时)** +- 确认软键盘已显示且输入框已激活后,才能使用input操作 +- 输入后检查输入框内的文本是否正确,确保没有输入错误、遗漏或多余字符 +- 在reasoning中必须明确说明"已确认软键盘显示且输入框已激活" + +**步骤5: 输入后的软键盘处理(重要)** +- **输入完成后必须检查**:观察软键盘上的按键类型 +- 
**隐藏软键盘的判断标准**: + - 如果软键盘上有"搜索"、"确定"、"完成"、"发送"等提交按钮,应该点击这些按钮 + - 如果软键盘上只有"下一项"、"换行"等非提交按钮,且软键盘遮挡了重要的界面元素,应该点击软键盘右上角的"向下箭头"按钮来隐藏软键盘 +- **reasoning必须说明**:"检查软键盘按键类型:[提交类型/导航类型]。软键盘是否遮挡重要元素:[是/否]。决定[点击提交按钮/隐藏软键盘/保持现状]" + +**严格禁止的input操作模式** +1. 没有软键盘但直接input +2. 未在reasoning中说明检查过程但直接input +3. 看到输入框就直接input(必须先检查激活状态) +4. 输入完成后不考虑软键盘遮挡问题 + +**唯一正确的input操作模式** +- reasoning包含:"检查软键盘状态:已显示。检查输入框状态:已激活。已确认可以进行文本输入。" +- 只有包含以上完整检查过程的reasoning才允许使用input操作 +- **输入后处理**:输入完成后,必须检查软键盘按键类型和是否遮挡重要元素,决定是否需要隐藏软键盘 + +## 重要规则 +1. **位置匹配优先**:先确定元素在原图中的准确位置,再找标注图中对应位置的红色方框 +2. **数字读取准确**:index必须是红色方框左上角内侧红底白字显示的实际数字 +3. **避免误选相邻元素**:这是最容易出错的地方!必须确保选择的红色方框完全包围目标元素,而不是相邻的类似元素 +4. **强制性相邻元素排除检查**:在选择任何index前,必须明确说明为什么没有选择周围的其他红色方框 +5. **软键盘遮挡处理**:输入完成后,如果软键盘遮挡了重要元素且没有提交按钮,应该点击右上角向下箭头隐藏软键盘 +6. **多步骤操作**:对于复杂选择(如日期范围、时间段、级联选项),需要多个连续操作 +7. **日期选择特别注意**: + - 在日期选择界面时,必须先确认当前显示的月份是否正确 + - 不能仅仅看到相同的日期数字就直接选择,必须确保月份匹配任务要求 + - 如果月份不对,需要先切换到正确的月份,然后再选择日期 +8. **任务完成判断**:只有在确实完成了指定任务时才使用done操作 +9. **操作连贯性**:每个操作都应该基于当前屏幕状态和任务目标进行合理选择 +10. 
**页面错误处理**:如果遇到进入错误页面或加载失败,可以尝试返回上一级界面(通过手势自屏幕最左侧向右滑动或使用点击返回按钮) + +## index选择示例 +**错误示例1**: +- reasoning: "需要点击搜索按钮" +- 问题:没有描述元素的具体位置和视觉特征 + +**错误示例2**: +- reasoning: "需要点击搜索框,位于屏幕上方。在标注图中找到了搜索框,选择数字8。" +- 问题:描述过于简单,没有验证过程,容易选错相邻元素 + +**正确示例**: +- reasoning: "1)目标元素详细描述:需要点击带有'搜索'提示文字的白色输入框,该输入框呈长方形,有浅灰色边框。2)精确位置描述:该搜索框位于屏幕最上方,在状态栏下方约50像素处,占据屏幕宽度的80%左右,位置居中。3)标注图查找:在第2张标注图中,我找到了位于屏幕上方中央位置的红色方框。4)红色方框验证:该红色方框完全包围了搜索输入框,边界与输入框的边缘完全贴合,框内确实包含带有'搜索'文字的白色输入框。5)index读取:该红色方框的左上角内侧清晰显示数字'15'。6)最终确认:确认该方框没有包含其他无关元素,也不是相邻的其他UI元素,正是我要点击的搜索框。" +- parameters: {{"index": 15, "target_element": "搜索输入框"}} + +**记住:每次点击操作都必须在reasoning中包含完整的6步验证过程,确保精确匹配而不是选择相邻元素!每次输入操作后都要考虑软键盘遮挡问题!** \ No newline at end of file diff --git a/prompts/change_task_description.md b/prompts/change_task_description.md new file mode 100644 index 0000000..daad8ea --- /dev/null +++ b/prompts/change_task_description.md @@ -0,0 +1,27 @@ +你是一个任务描述改写专家。给定一个应用名称和原始任务描述,你需要生成{count}个语义完全相同但表达方式不同的任务描述。 + +重要要求: +1. **语义必须完全相同** - 任务的目标、操作、内容都不能改变 +2. **生成规则**: + - 总共生成{count}条描述 + - 前3条:不包含应用名称,使用通用表达 + - 后3条:必须包含应用名称(如"在{app_name}中..."、"用{app_name}..."、"打开{app_name}..."等) +3. **表达自然** - 改写后的描述应该自然、符合中文日常表达习惯 +4. **不改变具体内容** - 搜索关键词、目标对象、操作步骤等具体内容不能改变 + +请严格按照以下JSON格式返回结果(前3条不带应用名称,后3条带应用名称): +```json +[ + "不带应用名称的任务描述1", + "不带应用名称的任务描述2", + "不带应用名称的任务描述3", + "带应用名称的任务描述1", + "带应用名称的任务描述2", + "带应用名称的任务描述3" +] +``` + +应用名称:{app_name} +原始任务描述:{original_task} + +请生成{count}个语义相同但表达不同的任务描述 \ No newline at end of file diff --git a/prompts/decider.md b/prompts/decider.md new file mode 100644 index 0000000..0a6a29f --- /dev/null +++ b/prompts/decider.md @@ -0,0 +1,13 @@ + +You are a phone-use AI agent. Now your task is "{task}". +Your action history is: +{history} +Please provide the next action based on the screenshot and your action history. You should do careful reasoning before providing the action. 
Please provide the next action based on the screenshot. You should do careful reasoning before providing the action.
+Your output should be a JSON object with the following format: +{{"reasoning": "Your reasoning here", "action": "The next action (one of click, input, swipe, wait, done)", "parameters": {{"param1": "value1", ...}}}} \ No newline at end of file diff --git a/prompts/grounder_bbox.md b/prompts/grounder_bbox.md new file mode 100644 index 0000000..990eef0 --- /dev/null +++ b/prompts/grounder_bbox.md @@ -0,0 +1,6 @@ + +Based on the screenshot, user's intent and the description of the target UI element, provide the bounding box of the element using **absolute coordinates**. +User's intent: {reasoning} +Target element's description: {description} +Your output should be a JSON object with the following format: +{{"bbox": [x1, y1, x2, y2]}} \ No newline at end of file diff --git a/prompts/grounder_coordinates.md b/prompts/grounder_coordinates.md new file mode 100644 index 0000000..c48aa4e --- /dev/null +++ b/prompts/grounder_coordinates.md @@ -0,0 +1,6 @@ + +Based on the screenshot, user's intent and the description of the target UI element, provide the coordinates of the element using **absolute coordinates**. 
+User's intent: {reasoning} +Target element's description: {description} +Your output should be a JSON object with the following format: +{{"coordinates": [x, y]}} \ No newline at end of file diff --git a/prompts/planner.md b/prompts/planner.md new file mode 100644 index 0000000..adabd5a --- /dev/null +++ b/prompts/planner.md @@ -0,0 +1,58 @@ +## 角色定义 +你是一个任务描述优化专家和智能手机应用选择助手。你需要根据用户的任务描述选择最合适的应用,并同时生成一个更准确、更贴合用户日常使用习惯、语义必须完全相同的任务描述。 + +## 任务描述 +用户想要完成的任务是:"{task_description}" + +## 可用应用列表 +以下是可用的应用及其包名: +- 微信: com.tencent.mm +- QQ: com.tencent.mobileqq +- 新浪微博: com.sina.weibo +- 饿了么: me.ele +- 美团: com.sankuai.meituan +- bilibili: tv.danmaku.bili +- 爱奇艺: com.qiyi.video +- 腾讯视频: com.tencent.qqlive +- 优酷: com.youku.phone +- 淘宝: com.taobao.taobao +- 京东: com.jingdong.app.mall +- 携程: ctrip.android.view +- 同城: com.tongcheng.android +- 飞猪: com.taobao.trip +- 去哪儿: com.Qunar +- 华住会: com.htinns +- 知乎: com.zhihu.android +- 小红书: com.xingin.xhs +- QQ音乐: com.tencent.qqmusic +- 网易云音乐: com.netease.cloudmusic +- 酷狗音乐: com.kugou.android +- 抖音: com.ss.android.ugc.aweme +- 高德地图: com.autonavi.minimap + +## 任务要求 +1. 分析任务描述,选择最合适的应用来完成该任务 +2. 生成一个更准确、更贴合用户日常使用习惯、语义必须完全相同的任务描述 + +## 任务描述优化要求 +1. **语义必须完全相同** - 任务的目标、操作、内容都不能改变 +2. **表达自然** - 改写后的描述应该自然、符合中文日常表达习惯 +3. **不改变具体内容** - 搜索关键词、目标对象、操作步骤等具体内容不能改变 + +## 输出格式 +请严格按照以下JSON格式输出: +```json +{{ + "reasoning": "分析任务内容,说明为什么选择这个应用最合适", + "app_name": "选择的应用名称", + "package_name": "选择的应用包名", + "task_description": "优化后的任务描述" +}} +``` + +## 重要规则 +1. 只能从上述可用应用列表中选择 +2. 必须选择最符合任务需求的应用 +3. 如果任务涉及多个可能的应用,选择最主要和最常用的那个 +4. 包名必须完全匹配列表中的包名,不能修改 +5. 
优化后的任务描述应该更准确、更贴合用户日常使用习惯、语义必须完全相同的任务描述 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..50d7dc9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,39 @@ +# torch +torch +torchvision +torchaudio + +# PaddlePaddle and OCR +paddlepaddle +paddleocr==2.10.0 +pytesseract + +# Computer Vision and ML +ultralytics +transformers==4.47.0 +supervision + +# Image Processing +Pillow +opencv-python + +# Scientific Computing +numpy +scipy + +# LangChain +langchain-openai +langchain-core + +# API and Mobile +openai +uiautomator2 + +# Web Framework +fastapi +uvicorn + +# UI +matplotlib +requests +ui-tars diff --git a/runner/README.md b/runner/README.md new file mode 100644 index 0000000..3ac29f5 --- /dev/null +++ b/runner/README.md @@ -0,0 +1,74 @@ +# Agent Runner + +## MobiAgent Runner + +**支持功能** +1. 论坛文章视频类(小红书,b站,知乎等) +- 关注xx,进入主页 +- 搜索,打开,播放 +- 在用户主页搜索,打开,播放 +- 点赞,收藏,评论,转发 + +2. 社交软件类(微信QQ等) +- 发消息,打电话,打视频,查找聊天内容 +- @某人+发消息 +- 打开小程序,打开朋友圈(打开朋友圈评论我们这个框架肯定可以) + +3. 购物类(淘宝,京东等) +- 搜索,按照价格销量等排序搜索,打开搜索结果 +- 加入购物车和下单,选择对应规格加入购物车和下单 +- 关注店铺 + +4. 外卖类(饿了么,美团) +- 点外卖,包括选择规格和数量 + +5. 旅游类(飞猪,去哪儿,携程,同城,华住会) +- 查询酒店价格(地点,地标附近,指定酒店,日期) +- 预定酒店(地点,地标附近,指定酒店,日期,房间类型) +- 购买火车票飞机票(和设定始发地和目的地,以及日期时间段) + +6. 地图类(高德) +- 导航,打车(始发地,目的地可以更改) + +7. 
听歌类(网易云,QQ音乐) +- 搜索歌曲,歌手,乐队 +- 搜索并播放 + +### 模型部署 +下载好 `decider`、`grounder` 和 `planner` 三个模型后,使用 vLLM 部署模型推理服务: + +**默认端口部署** +```bash +vllm serve IPADS-SAI/MobiMind-Decider-7B --port +vllm serve IPADS-SAI/MobiMind-Grounder-3B --port +vllm serve Qwen/Qwen3-4B-Instruct --port +``` + +**注意事项** +- 确保部署的服务端口与后续启动 MobiMind-Agent 时指定的端口参数一致 +- 如果使用非默认端口,需要在启动 Agent 时通过 `--decider_port`、`--grounder_port`、`--planner_port` 参数指定对应端口 + +### 设置任务 +在 `runner/mobiagent/task.json` 中写入要测试的任务列表 + +### 项目启动 + +**基本启动**(使用默认配置) +```bash +python -m runner.mobiagent.mobiagent +``` + +**自定义配置启动** +```bash +python -m runner.mobiagent.mobiagent --service_ip <服务IP> --decider_port <决策服务端口> --grounder_port <定位服务端口> --planner_port <规划服务端口> +``` + +**参数说明** +- `--service_ip`:服务IP(默认:`localhost`) +- `--decider_port`:决策服务端口(默认:`8000`) +- `--grounder_port`:定位服务端口(默认:`8001`) +- `--planner_port`:规划服务端口(默认:`8002`) + +## UI-TARS Runner + + \ No newline at end of file diff --git a/runner/UI-TARS-agent/.gitignore b/runner/UI-TARS-agent/.gitignore new file mode 100644 index 0000000..d5e4f15 --- /dev/null +++ b/runner/UI-TARS-agent/.gitignore @@ -0,0 +1 @@ +**/__pycache__ \ No newline at end of file diff --git a/runner/UI-TARS-agent/README.md b/runner/UI-TARS-agent/README.md new file mode 100644 index 0000000..1e6b430 --- /dev/null +++ b/runner/UI-TARS-agent/README.md @@ -0,0 +1,157 @@ +# UI-TARS 自动化执行框架 + +基于UI-TARS-7B-SFT模型的移动应用智能自动化框架,支持视觉理解和自然语言指令执行,能够收集到支持当前项目MobiAgent格式的trace数据。 + +## 核心文件 + +- `automation_framework_simple.py` - 核心自动化框架 +- `automation_examples.py` - 任务示例和交互界面 +- `requirements.txt` - Python依赖包 + +## 快速开始 + +### 1. 安装依赖 + +安装项目根目录下的相关依赖即可 + +```bash +pip install -r requirements.txt +``` + +### 2. 启动模型服务 +```bash +# 使用vLLM启动UI-TARS模型服务 +python -m vllm.entrypoints.openai.api_server \ + --model UI-TARS-7B-SFT \ + --served-model-name UI-TARS-7B-SFT \ + --host 0.0.0.0 \ + --port 8000 +``` + +### 3. 
连接Android设备 +- 开启USB调试 +- 安装ADB键盘 (用于文本输入)(在手机上安装项目根目录中的 ADBKeyboard.apk 文件) +- 确保设备可通过ADB连接 + +### 4. 运行示例 + +下面是最小步骤,帮助你在本地运行 UI-TARS完成手机任务的核心脚本。 + +1. 安装依赖: +```bash +pip install -r requirements.txt +``` + +1. 运行单任务示例: +```bash +python quick_start.py +``` +quick_start 会使用仓库内默认配置并引导你通过交互或指定参数运行单个测试任务。 + +1. 批量执行任务(示例): +```bash +python batch_task_executor.py --config auto_task.json +``` +默认配置文件 `auto_task.json` 放在仓库根目录,用于演示批量任务的 JSON 格式。 + +## 目录与主要文件说明(补充) +- `quick_start.py`:单任务执行的快速入口,适合手工调试和快速验证。 +- `batch_task_executor.py`:批量任务执行器,支持从 JSON 列表加载多个任务并顺序执行。 +- `ui_tars_automation/`:核心库,包含设备/坐标处理、数据管理、日志封装和执行框架。 + - `data_manager.py`:负责保存截图、XML、动作记录等执行数据。 + - `config.py`:包含执行相关的可调整配置项(例如保存路径、最大步数等)。 + - `framework.py`:任务执行流程入口与调度逻辑。 + +## 常见问题与提示(补充) +- 日志:脚本会输出到标准输出并保存到当前目录下的日志文件,若需要修改日志级别或格式,可编辑 `ui_tars_automation/logger.py`。 +- ADB 输入:如需通过键盘输入模拟,请先将 `ADBKeyboard.apk` 安装到目标设备。 +- 数据输出路径:可在 `ui_tars_automation/config.py` 中调整 `data_base_dir`。 + +## 预定义任务 + +框架包含以下预定义任务: + +1. **淘宝购物**: 打开淘宝应用,搜索'手机壳',查看搜索结果 +2. **微信聊天**: 打开微信,找到好友列表,查看最近的聊天记录 +3. **系统设置**: 打开系统设置,找到WiFi设置,查看当前连接的WiFi信息 +4. **网页浏览**: 打开浏览器,搜索'UI-TARS'相关信息 +5. **短视频**: 打开抖音或B站,浏览视频内容 +6. **地图导航**: 打开地图应用,搜索附近的餐厅 +7. **音乐播放**: 打开音乐应用,搜索并播放一首歌曲 + +## 自定义任务 + +支持三种方式执行任务: + +1. **预定义任务**: 从任务列表选择执行 +2. **交互模式**: 输入自然语言描述执行自定义任务 +3. 
**动态添加**: 将自定义任务添加到任务列表 + +## 支持的操作 + +- `click(point)` - 点击指定坐标 +- `long_press(point)` - 长按指定坐标 +- `type(content)` - 输入文本 +- `scroll(point, direction)` - 滚动(上/下/左/右) +- `drag(start_point, end_point)` - 拖拽操作 +- `press_home()` - 按Home键 +- `press_back()` - 按返回键 +- `finished(content)` - 任务完成 + +## 配置参数 + +```python +config = ExecutionConfig( + model_base_url="http://192.168.12.152:8000/v1", # 模型服务地址 + model_name="UI-TARS-7B-SFT", # 模型名称 + max_steps=30, # 最大执行步数 + step_delay=2.0, # 步骤间延迟(秒) + language="Chinese" # 思考语言 +) +``` + +## 使用示例 + +### 编程方式使用 +```python +from automation_framework_simple import UITarsSimpleFramework, ExecutionConfig + +# 创建配置 +config = ExecutionConfig( + model_base_url="http://192.168.12.152:8000/v1", + max_steps=20, + step_delay=2.0 +) + +# 创建框架实例 +framework = UITarsSimpleFramework(config) + +# 执行任务 +success = framework.execute_task("打开微信,查看朋友圈") + +# 获取执行摘要 +summary = framework.get_execution_summary() +print(f"任务{'成功' if success else '失败'}, 共{summary['total_steps']}步") +``` + +### 命令行使用 +```bash +python automation_examples.py +``` + +## 输出文件 + +- `automation.log` - 详细执行日志 +- `execution_log_*.json` - 任务执行摘要 +- `automation_screenshots/` - 每步执行的截图 + +## 注意事项 + +1. 确保模型服务正常运行 +2. 设备网络连接稳定 +3. 合理设置最大步数和延迟时间 +4. 任务描述要清晰具体 + +## 许可证 + +Apache 2.0 License diff --git a/runner/UI-TARS-agent/USAGE_GUIDE.md b/runner/UI-TARS-agent/USAGE_GUIDE.md new file mode 100644 index 0000000..c88fd3e --- /dev/null +++ b/runner/UI-TARS-agent/USAGE_GUIDE.md @@ -0,0 +1,341 @@ +# UI-TARS 批量任务执行器使用指南 + +## 快速开始 + +### 1. 环境准备 + +确保已安装必要的依赖: +```bash +pip install -r requirements.txt +``` + +确保Android设备已连接并启用USB调试: +```bash +adb devices +``` + +可选:安装 `ADBKeyboard.apk`(用于发送文本输入) +```bash +adb install -r ADBKeyboard.apk +``` + +### 2. 
脚本说明 + +| 脚本文件 | 功能说明 | 适用场景 | +|---------|----------|----------| +| `batch_task_executor.py` | 执行auto_task-3.json中的所有任务 | 生产环境批量执行 | +| `test_batch_executor.py` | 执行少量测试任务 | 测试验证 | +| `test_single_task.py` | 执行单个自定义任务 | 调试和测试 | +| `quick_start.py` | 原有的快速启动示例 | 单任务快速测试 | + +### 3. 使用步骤 + +#### 方案一:测试单个任务(推荐新手) +```bash +python3 test_single_task.py +``` +1. 输入模型服务地址 +2. 选择应用(bilibili、淘宝等) +3. 输入任务描述 +4. 确认执行 + +#### 方案二:测试少量任务(推荐调试) +```bash +python3 test_batch_executor.py +``` +- 自动执行bilibili的前2个type1任务 +- 验证批量执行流程是否正常 + +#### 方案三:执行所有任务(生产环境) +```bash +python3 batch_task_executor.py +``` +- 执行auto_task.json中的所有210个任务 +- 需要较长时间完成 + +## 配置说明 + +### 模型服务配置 +- 默认地址:http://192.168.12.152:8000/v1 +- 模型名称:UI-TARS-7B-SFT +- 可在运行时修改地址 + +### 执行参数 +- 最大步数:30步 +- 步骤延迟:2秒 +- 语言:中文 +- 任务间隔:5秒 + +### 数据保存 +执行结果保存在以下目录结构: +``` +data_example/ # 生产数据 +test_data_example/ # 测试数据 +├── bilibili/ +│ ├── type1/ +│ │ ├── 1/ +│ │ │ ├── task_data.json # 任务执行数据(原格式) +│ │ │ ├── actions.json # 操作记录(参考淘宝格式) +│ │ │ ├── 1/screenshot_1.jpg # 第1步截图 +│ │ │ ├── 1/hierarchy_1.xml # 第1步XML +│ │ │ ├── 2/screenshot_2.jpg # 第2步截图 +│ │ │ ├── 2/hierarchy_2.xml # 第2步XML +│ │ │ └── ... +│ │ ├── 2/ +│ │ └── 3/ +│ └── type2/ +└── ... 
+``` + +### 数据格式说明 + +#### task_data.json格式(保持原有格式) +```json +{ + "task_description": "在B站搜一下\"元神4.8版本更新\"", + "app_name": "bilibili", + "task_type": "type1", + "task_index": 1, + "package_name": "tv.danmaku.bili", + "execution_time": "2025-08-27T22:05:03.639290", + "action_count": 3, + "actions": [ + { + "reasoning": "需要点击搜索框以便输入搜索内容", + "function": { + "name": "click", + "parameters": {"x": 546, "y": 201} + } + } + ], + "success": true +} +``` + +#### actions.json格式(参考淘宝格式) +```json +{ + "app_name": "bilibili", + "task_type": "type1", + "task_description": "在B站搜一下\"元神4.8版本更新\"", + "action_count": 3, + "actions": [ + { + "type": "click", + "position_x": 546, + "position_y": 201, + "bounds": [204, 153, 773, 264], + "action_index": 1 + }, + { + "type": "input", + "text": "元神4.8版本更新", + "action_index": 2 + }, + { + "type": "done", + "action_index": 3 + } + ] +} +``` + +## 任务结构 + +### auto_task-3.json结构 +```json +[ + { + "app": "bilibili", + "type": "type1", + "tasks": [ + "在B站搜一下"元神4.8版本更新"", + "在B站搜一下"决战库班之王"", + "在B站搜一下"黑神话悟空最新实机"" + ] + } +] +``` + +### 支持的应用 +- bilibili (B站) +- 淘宝 +- 携程 +- 网易云音乐 +- 小红书 +- 高德地图 +- 饿了么 + +## 监控和日志 + +### 日志文件 +- `batch_execution.log` - 批量执行日志 +- `test_batch_execution.log` - 测试执行日志 +- `automation.log` - 框架执行日志 + +### 实时监控 +执行过程中会在终端显示: +- 当前执行的任务 +- 执行步骤详情 +- 成功/失败状态 +- 执行统计 + +## 故障排除 + +### 常见问题 + +1. **设备连接失败** +```bash +adb devices +adb kill-server +adb start-server +``` + +2. **应用启动失败** +- 检查应用是否已安装 +- 检查包名是否正确 +- 尝试手动启动应用 + +3. **模型调用失败** +- 检查网络连接 +- 验证模型服务地址 +- 检查服务是否运行 + +4. **任务执行超时** +- 检查任务复杂度 +- 增加最大步数限制 +- 检查设备响应速度 + +### 调试技巧 + +1. **先测试单个任务** +```bash +python3 test_single_task.py +``` + +2. **检查保存的截图** +查看步骤目录中的screenshot文件 + +3. **查看详细日志** +```bash +tail -f batch_execution.log +``` + +4. **验证应用状态** +确保应用处于预期界面 + +## 性能优化 + +### 提高成功率 +1. 确保设备性能良好 +2. 关闭不必要的后台应用 +3. 保持网络连接稳定 +4. 适当增加步骤延迟 + +### 提高执行效率 +1. 批量执行相同应用的任务 +2. 合理设置任务间隔 +3. 
优化任务描述的准确性 + +## 数据分析 + +### task_data.json 字段说明 +- `task_description`: 任务描述 +- `app_name`: 应用名称 +- `task_type`: 任务类型 +- `task_index`: 任务索引 +- `package_name`: 应用包名 +- `execution_time`: 执行时间 +- `action_count`: 操作步数 +- `actions`: 详细操作记录 +- `success`: 是否成功 + +### 统计分析 +可以通过分析task_data.json文件来: +- 计算各应用的成功率 +- 分析平均执行步数 +- 识别常见失败原因 +- 优化任务描述 + +## 扩展开发 + +### 添加新应用 +在app_packages字典中添加新的包名映射: +```python +app_packages = { + "新应用名": "com.example.package" +} +``` + +### 修改执行参数 +调整ExecutionConfig中的配置: +```python +config = ExecutionConfig( + max_steps=50, # 增加最大步数 + step_delay=3.0, # 增加延迟 + temperature=0.1 # 调整模型温度 +) +``` + +### 自定义数据保存 +重写save_task_data方法以自定义数据格式。 + +## 注意事项 + +1. **资源占用**:长时间执行会产生大量截图数据 +2. **设备稳定性**:建议定期重启设备和清理缓存 +3. **网络稳定**:确保模型服务连接稳定 +4. **权限设置**:确保应用有必要的权限 +5. **存储空间**:预留足够的存储空间保存数据 + +## 版本历史 + +- v1.0: 基础批量执行功能 +- v1.1: 添加测试脚本和数据保存优化 +- v1.2: 完善错误处理和日志记录 + +## 快速使用指南(补充) +此文档补充了 `quick_start.py` 与 `batch_task_executor.py` 的常用运行示例、配置说明和故障排查要点。 + +### 环境准备 +1. 确保已安装依赖: +```bash +pip install -r requirements.txt +``` +2. 准备 Android 设备并启用 adb: +```bash +adb devices +``` +3. 
可选:安装 `ADBKeyboard.apk`(用于发送文本输入) +```bash +adb install -r ADBKeyboard.apk +``` + +### quick_start.py(手动/交互式运行) +- 目的:快速启动单任务执行,用于调试和试验。 +- 运行: +```bash +python quick_start.py --task-file +``` +- 常见参数: + - `--task-file`:指定单个任务的 JSON 配置文件(如果不传则使用内置示例)。 + - `--no-screenshot`:禁用截图保存以加速执行(若需要节省空间)。 + +### batch_task_executor.py(批量运行) +- 目的:从 JSON 列表中批量执行任务,适合收集多次样本或回放历史动作集。 +- 运行: +```bash +python batch_task_executor.py --config auto_task.json +``` +- 建议:先运行小批量(例如 2~5 个任务)进行 smoke 测试,再扩大规模。 + +### 日志与数据输出 +- 执行数据(截图、dump XML、动作日志)由 `DataManager` 保存,默认会在 `data_base_dir` 下按时间戳和任务名建立子目录。 +- 如果需要自定义保存策略,请查看并调整 `ui_tars_automation/data_manager.py` 中的 `DataManager` 类。 + +### 常见问题与排查 +- adb 无设备:确认设备已连接,且运行 `adb devices` 可以看到设备 ID。 +- 权限问题:某些设备需要开启开发者选项与 USB 调试权限。 +- 执行卡住:查看日志文件(同目录下)以定位是哪一步失败,常见为坐标匹配或超时。 + +如果你希望我把 `USAGE_GUIDE.md` 扩展为带参数文档的 CLI 参考(包含 `--help` 输出模拟),告诉我优先级,我会继续完善。 diff --git a/runner/UI-TARS-agent/auto_task.json b/runner/UI-TARS-agent/auto_task.json new file mode 100644 index 0000000..3093e52 --- /dev/null +++ b/runner/UI-TARS-agent/auto_task.json @@ -0,0 +1,468 @@ +[ + { + "app": "bilibili", + "type": "type1", + "tasks": [ + "在B站搜一下“元神4.8版本更新”", + "在B站搜一下“决战库班之王”", + "在B站搜一下“黑神话悟空最新实机”", + "在B站搜一下“AI绘画Stable Diffusion教程”", + "在B站搜一下“高考查分名场面”", + "在B站搜一下“巴黎奥运会开幕式”", + "在B站搜一下“三伏天避暑指南”", + "在B站搜一下“华为Mate70爆料”", + "在B站搜一下“科目三舞蹈原版”", + "在B站搜一下“深夜泡面番推荐”" + ] + }, + { + "app": "bilibili", + "type": "type2", + "tasks": [ + "在B站播放《进击的巨人》最终季", + "在B站播放LPL夏季赛总决赛", + "在B站播放周杰伦演唱会4K修复版", + "在B站播放《星空》游戏实况", + "在B站播放《流浪地球3》预告片" + ] + }, + { + "app": "bilibili", + "type": "type3", + "tasks": [ + "在B站搜一下UP主老番茄", + "在B站搜一下UP主老师好我叫何同学", + "在B站搜一下UP主罗翔说刑法", + "在B站搜一下UP主小潮院长", + "在B站搜一下UP主木鱼水心" + ] + }, + { + "app": "bilibili", + "type": "type4", + "tasks": [ + "在B站进入UP主半佛仙人的主页", + "在B站进入UP主花少北的主页", + "在B站进入UP主敖厂长的主页", + "在B站进入UP主某幻君的主页", + "在B站进入UP主大祥哥来啦的主页" + ] + }, + { + "app": "bilibili", + "type": "type5", + "tasks": [ + "B站搜索up主麻薯波比呀的《马斯克组建美国党》视频", + "B站搜索up主一岸舟的武林外传视频", 
+ "B站搜索up主老师好我叫何同学的5G测评视频", + "B站搜索up主罗翔说刑法的视频《拐卖男友去电诈》", + "B站搜索up主小潮院长的《走到哪算哪》视频" + ] + }, + { + "app": "bilibili", + "type": "type6", + "tasks": [ + "搜索up主木鱼水心的《明朝整顿职场第一人》解析,并播放", + "搜索up主某幻君的《当我第八次尝试rap》剧情MV,并播放", + "搜索up主敖厂长的《敖厂长开高达》视频,并播放", + "在B站搜索up主花少北的《我教会了它如何做人》,并播放", + "在B站搜索up主大祥哥来啦的“犒劳犒劳自己吃点好的吧”视频,并播放" + ] + }, + { + "app": "bilibili", + "type": "type7", + "tasks": [ + "在B站关注up主逍遥散人", + "在B站关注up主麻薯波比呀", + "在B站关注up主火山哥哥", + "在B站关注up主自来卷三木", + "在B站关注up主帅农鸟哥" + ] + }, + { + "app": "bilibili", + "type": "type8", + "tasks": [ + "在B站视频《黑神话:悟空》实机演示下评论“画质炸裂,期待发售!”", + "在B站视频《史上最骚杀手(第一集)》下评论“感谢UP,已经转发给表弟”", + "在B站视频《相机大战》下评论“截图当壁纸了,太美了”", + "在B站视频《Kpop随机舞蹈》下评论“小姐姐跳得比原版还齐”", + "在B站视频《法考经验贴》下评论“收藏了,明年必过!”" + ] + }, + { + "app": "bilibili", + "type": "type9", + "tasks": [ + "在up主老番茄的视频《驯虫高手》下评论“番茄别回头,我是你粉丝!”", + "在up主老师好我叫何同学5G测评下评论“这期科普太硬核,已三连”", + "在up主罗翔说刑法的视频《我们为什么要读书》下评论“法外狂徒张三又出现了”", + "在up主噢呼w的视频《敢杀我的马》下评论“申遗成功!”", + "在up主洛温阿特金森的视频一个关于正常人的故事下评论“现在看到的是自己”" + ] + }, + { + "app": "bilibili", + "type": "type10", + "tasks": [ + "搜索up主木鱼水心的《明朝整顿职场第一人》解析,播放后点赞", + "搜索up主某幻君的《当我第八次尝试rap》剧情MV,播放后点赞", + "搜索up主敖厂长的《敖厂长开高达》视频,并播放后点赞", + "在B站搜索up主花少北的《我教会了它如何做人》,并播放后点赞", + "在B站搜索up主大祥哥来啦的“犒劳犒劳自己吃点好的吧”视频,并播放后点赞" + ] + }, + { + "app": "淘宝", + "type": "type1", + "tasks": [ + "淘宝搜索最畅销的机械键盘", + "淘宝中搜索价格最低的智能手环", + "淘宝中搜索戴森的吹风机", + "淘宝中搜索销量最高的儿童安全座椅", + "淘宝中搜索价格最高的空气炸锅" + ] + }, + { + "app": "淘宝", + "type": "type2", + "tasks": [ + "淘宝中选择销量最高的机械键盘,并选择第一个", + "淘宝中查找一款价格最低的智能手环,并选择第一个", + "淘宝中查找最便宜的品牌为戴森的吹风机,并选择第一个", + "淘宝中搜索销量最高的儿童安全座椅,并选择第一个", + "淘宝中搜索价格最高的空气炸锅,并选择第一个" + ] + }, + { + "app": "淘宝", + "type": "type3", + "tasks": [ + "淘宝中搜索电动牙刷,加入购物车", + "淘宝中搜索智能手环加入购物车", + "淘宝中将RTX4060显卡加入购物车", + "淘宝中淘宝中将一件防晒衣,加入购物车", + "淘宝中将冰丝凉席三件套加入购物车" + ] + }, + { + "app": "淘宝", + "type": "type4", + "tasks": [ + "淘宝中将销量最高的智能门锁加入购物车", + "淘宝中将价格最低的儿童安全座椅加入购物车", + "淘宝中将品牌为戴森的V15吸尘器加入购物车", + "淘宝中将销量最高的空气炸锅加入购物车", + "淘宝中将价格最高的骨传导耳机加入购物车" + ] + }, + { + "app": "淘宝", + "type": "type5", + "tasks": [ 
+ "淘宝中将型号为M码的冰丝防晒裤加入购物车", + "淘宝中将长度为1.5米的绿联USB-C快充线加入购物车", + "淘宝中将型号为XL码的李宁速干运动男短袖加入购物车", + "淘宝中将型号为L码的户外冲锋衣加入购物车", + "淘宝中将型号为42码的阿迪达斯跑鞋加入购物车" + ] + }, + { + "app": "淘宝", + "type": "type6", + "tasks": [ + "淘宝中将销量最高型号为42码的亚瑟士 Gel-Kayano 跑鞋加入购物车", + "淘宝中将销量最高颜色为白色的破壁机加入购物车", + "淘宝中将价格最低型号为L码的始祖鸟冲锋衣加入购物车", + "淘宝中将销量最高的16GB+512GB黑色的小米14Ultra手机加入购物车" + ] + }, + { + "app": "网易云音乐", + "type": "type1", + "tasks": [ + "网易云音乐中帮我搜索周深的歌曲", + "网易云音乐中找一下张惠妹的歌曲", + "网易云音乐中帮我搜索陈奕迅的歌曲", + "网易云音乐中帮我搜索Taylor Swift的歌曲", + "网易云音乐中帮我搜一下陈楚生的歌曲" + ] + }, + { + "app": "网易云音乐", + "type": "type2", + "tasks": [ + "网易云音乐中播放林俊杰的《江南》", + "网易云音乐中播放歌曲稻香", + "网易云音乐中播放邓紫棋的歌曲", + "网易云音乐中播放陈奕迅的《十年》", + "网易云音乐中播放周深的《大鱼》" + ] + }, + { + "app": "携程", + "type": "type1", + "tasks": [ + "携程中搜索北京到上海的飞机票", + "携程中搜索广州到深圳的火车票", + "携程中搜索成都到重庆的飞机票", + "携程中搜索杭州到南京的火车票", + "携程中搜索西安到郑州的飞机票" + ] + }, + { + "app": "携程", + "type": "type2", + "tasks": [ + "携程中搜索9月25日上海到北京的火车票", + "携程中搜索9月30日深圳到广州的飞机票", + "携程中搜索10月5日重庆到成都的火车票", + "携程中搜索10月10日南京到杭州的飞机票", + "携程中搜索10月15日郑州到西安的火车票" + ] + }, + { + "app": "携程", + "type": "type3", + "tasks": [ + "携程中搜索9月26日北京到广州、出发时间08:00-12:00的航班", + "携程中搜索9月28日上海到杭州、出发时间14:00-18:00的火车票", + "携程中搜索10月2日广州到成都、出发时间06:00-10:00的航班", + "携程中搜索10月8日成都到西安、出发时间10:00-14:00的火车票", + "携程中搜索10月12日杭州到北京、出发时间16:00-20:00的航班" + ] + }, + { + "app": "携程", + "type": "type4", + "tasks": [ + "携程中搜索9月27日北京到上海、到达时间12:00-16:00的火车票", + "携程中搜索10月1日上海到广州、到达时间18:00-22:00的航班", + "携程中搜索10月4日广州到深圳、到达时间09:00-11:00的火车票", + "携程中搜索10月7日深圳到成都、到达时间13:00-17:00的航班", + "携程中搜索10月14日成都到杭州、到达时间20:00-23:59的火车票" + ] + }, + { + "app": "携程", + "type": "type6", + "tasks": [ + "在携程里帮我找一下万豪酒店", + "携程中帮我查询一下洲际酒店的相关信息", + "帮我在携程上找一找凯悦酒店", + "携程里帮我搜一下香格里拉酒店", + "在携程中帮我查寻一下喜来登酒店" + ] + }, + { + "app": "携程", + "type": "type7", + "tasks": [ + "携程中帮我查一下上海外滩附近的万豪酒店", + "在携程里找一找广州珠江新城周边的洲际酒店", + "携程上帮我查询一下深圳福田CBD附近的凯悦酒店", + "帮我在携程中查寻一下成都春熙路周边的香格里拉酒店", + "携程里帮我搜一下杭州西湖景区附近的喜来登酒店" + ] + }, + { + "app": "携程", + "type": "type8", + "tasks": [ + 
"携程中帮我预定一间上海外滩附近万豪酒店的双床房", + "携程中帮我找一间广州珠江新城周边洲际酒店的大床房", + "携程中帮我预定一间深圳福田CBD附近凯悦酒店的双床房", + "携程中帮我找一间成都春熙路周边香格里拉酒店的江景房", + "携程中帮我预定一间杭州西湖景区附近喜来登酒店的大床房" + ] + }, + { + "app": "携程", + "type": "type9", + "tasks": [ + "携程中帮我预定一间上海外滩附近评分最高的万豪酒店双床房", + "携程中帮我找一间广州珠江新城周边价格最低的洲际酒店大床房", + "携程中帮我预定一间深圳福田CBD距离最近的凯悦酒店双床房", + "携程中帮我找一间成都春熙路周边价格最低的香格里拉酒店大床房", + "携程中帮我预定一间杭州西湖景区附近评价最高的喜来登酒店大床房" + ] + }, + { + "app": "小红书", + "type": "type1", + "tasks": [ + "小红书关注博主阿喵", + "小红书关注博主抽象一坨", + "小红书关注博主侯绿萝", + "小红书关注博主木梓蓝", + "小红书关注博主不知名鸽子" + ] + }, + { + "app": "小红书", + "type": "type2", + "tasks": [ + "进入小红书博主混子哥边画边讲的主页", + "进入小红书博主李福贵的主页", + "进入小红书博主许二木的主页", + "进入小红书博主阿喵的主页", + "进入小红书博主抽象一坨的主页" + ] + }, + { + "app": "小红书", + "type": "type3", + "tasks": [ + "小红书搜索博主赵露思,查看ta第一个笔记", + "小红书搜索博主张曼玉Maggie,查看ta第一个内容", + "小红书搜索博主木梓蓝,查看他第一个文章", + "小红书搜索博主混子哥边画边讲,查看他第一个文章", + "小红书搜索博主李福贵,查看他的第一个笔记" + ] + }, + { + "app": "小红书", + "type": "type4", + "tasks": [ + "在小红书里搜索创意摄影", + "在小红书里搜索风景园林科普", + "在小红书里搜索手工沙发制作", + "在小红书里搜索朗读经典文学", + "在小红书里搜索中式恐怖美学摄影" + ] + }, + { + "app": "小红书", + "type": "type5", + "tasks": [ + "在小红书里搜索平替好物推荐,并点击查看第一个结果", + "在小红书里搜索家居收纳改造,并点击查看第一个结果", + "在小红书里搜索明星穿搭分享,并点击查看第一个结果", + "在小红书里搜索美食制作教程,并点击查看第一个结果", + "在小红书里搜索宠物日常趣事,并点击查看第一个结果" + ] + }, + { + "app": "小红书", + "type": "type6", + "tasks": [ + "小红书搜索博主不知名鸽子的中式恐怖美学摄影笔记", + "小红书搜索博主马俊达Mars的风景园林知识科普文章", + "小红书搜索博主大脸陈咔咔的创意沙发制作视频", + "小红书搜索博主阿好好好的朗读经典文学内容", + "小红书搜索博主平替君的老外的钱容易赚,细数中国制造笔记" + ] + }, + { + "app": "小红书", + "type": "type7", + "tasks": [ + "查看小红书博主桂非妃的挑战用一张纸换XX视频", + "查看小红书博主王冰冰的穿上正装变成大人模样笔记", + "查看小红书博主佳减乘除的在公司摆一天摊能赚多少钱内容", + "查看小红书博主豆皮一点都不皮的需要电子宠物吗视频", + "查看小红书博主飓风课堂的小红书正式官宣笔记" + ] + }, + { + "app": "高德", + "type": "type1", + "tasks": [ + "高德导航至上海迪士尼乐园", + "高德从当前位置导航至上海外滩", + "高德导航至上海科技馆", + "高德导航至上海虹桥站", + "高德导航至上海复旦大学邯郸校区" + ] + }, + { + "app": "高德", + "type": "type2", + "tasks": [ + "高德打车前往上海陆家嘴金融中心", + "高德打车前往上海南京东路步行街", + "高德打车至上海静安寺商圈", + "高德打车前往上海徐汇滨江绿地", + "高德打车至上海普陀区环球港购物中心" + ] + }, + { + "app": "高德", + 
"type": "type3", + "tasks": [ + "高德从上海浦东国际机场打车至上海外滩华尔道夫酒店", + "高德从上海虹桥国际机场T2航站楼打车至人民广场", + "高德从上海火车站南广场打车至上海杨浦区五角场万达广场", + "高德从上海迪士尼乐园停车场打车至上海浦东新区世纪公园", + "高德从上海闵行区莘庄地铁站打车至上海松江区欢乐谷景区" + ] + }, + { + "app": "饿了么", + "type": "type1", + "tasks": [ + "饿了么帮我搜索麦当劳", + "饿了么帮我搜索喜茶", + "饿了么帮我搜索老乡鸡", + "饿了么帮我搜索瑞幸", + "饿了么帮我搜索肯德基" + ] + }, + { + "app": "饿了么", + "type": "type2", + "tasks": [ + "饿了么中帮我点经典香辣鸡腿堡", + "饿了么中帮我点杨枝甘露多肉葡萄", + "饿了么中帮我点农家小炒肉盖饭", + "饿了么中帮我点生椰拿铁", + "饿了么中帮我点老北京鸡肉卷" + ] + }, + { + "app": "饿了么", + "type": "type3", + "tasks": [ + "在饿了么帮我点麦当劳的板烧鸡腿堡", + "在饿了么帮我点喜茶的清爽芭乐提", + "在饿了么帮我点老乡鸡的梅菜扣肉饭", + "在饿了么帮我点瑞幸的生椰拿铁", + "在饿了么帮我点肯德基的吮指原味鸡" + ] + }, + { + "app": "饿了么", + "type": "type4", + "tasks": [ + "在饿了么帮我点2份麦当劳的麦辣鸡翅", + "在饿了么帮我点1杯喜茶的芒芒甘露", + "在饿了么帮我点3份老乡鸡的西红柿炒蛋", + "在饿了么帮我点2杯瑞幸的橙C美式", + "在饿了么帮我点1份肯德基的香辣鸡腿汉堡" + ] + }, + { + "app": "饿了么", + "type": "type5", + "tasks": [ + "在饿了么点麦当劳的可乐,要中杯", + "在饿了么点喜茶的芒芒甘露,要常规杯、少糖、少冰", + "帮我点蜜雪冰城的芋圆葡萄,要少冰,5分糖,加珍珠", + "在饿了么帮我点瑞幸的小黄油拿铁,要大杯、冰、少甜", + "在饿了么帮我点肯德基的薯条,要大份" + ] + }, + { + "app": "饿了么", + "type": "type6", + "tasks": [ + "饿了么里帮我点2杯麦当劳的雪碧,要中杯、去冰、正常甜度", + "饿了么里帮我点3杯喜茶的多肉葡萄,要大杯、少糖、常温", + "饿了么里帮我点2份蜜雪冰城的葡萄冰美式,要少冰,七分糖", + "饿了么里帮我点1杯瑞幸的生椰拿铁,要大杯、热、不加糖", + "饿了么里帮我点4份肯德基的蛋挞,要原味" + ] + } +] \ No newline at end of file diff --git a/runner/UI-TARS-agent/batch_task_executor.py b/runner/UI-TARS-agent/batch_task_executor.py new file mode 100644 index 0000000..26be426 --- /dev/null +++ b/runner/UI-TARS-agent/batch_task_executor.py @@ -0,0 +1,750 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +批量任务执行器 - 读取auto_task-3.json中的任务并自动执行 +""" + +import json +import os +import time +import logging +import traceback +import shutil +import base64 +import re +from datetime import datetime +from openai import OpenAI +from ui_tars_automation import UITarsAutomationFramework, ExecutionConfig +from ui_tars_automation.action_parser import ActionParser + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - 
%(levelname)s - %(message)s', + handlers=[ + logging.FileHandler('batch_execution.log', encoding='utf-8'), + logging.StreamHandler() + ] +) + +logger = logging.getLogger(__name__) + +class BatchTaskExecutor: + """批量任务执行器""" + + def __init__(self, model_url: str, data_base_dir: str = "data_uitars-test"): + """初始化批量执行器 + + Args: + model_url: 模型服务地址 + data_base_dir: 数据保存基目录 + """ + self.model_url = model_url + self.data_base_dir = data_base_dir + + # 初始化模型客户端用于智能应用选择(使用专门的模型) + try: + self.app_selection_client = OpenAI( + api_key="sk-rfCIGhxrzcdsMV4jC17e406bE56c47CbA5416068A62318D3", + base_url="http://ipads.chat.gpt:3006/v1" + ) + self.app_selection_model = "gemini-2.5-pro-preview-06-05" + logger.info(f"已连接到应用选择模型服务: http://ipads.chat.gpt:3006/v1") + except Exception as e: + logger.warning(f"应用选择模型客户端初始化失败,将只使用预定义包名映射: {e}") + self.app_selection_client = None + self.app_selection_model = None + + # 应用包名映射(从open_app.py参考) + self.app_packages = { + "微信": "com.tencent.mm", + "QQ": "com.tencent.mobileqq", + "新浪微博": "com.sina.weibo", + "饿了么": "me.ele", + "美团": "com.sankuai.meituan", + "bilibili": "tv.danmaku.bili", + "爱奇艺": "com.qiyi.video", + "腾讯视频": "com.tencent.qqlive", + "优酷": "com.youku.phone", + "淘宝": "com.taobao.taobao", + "京东": "com.jingdong.app.mall", + "携程": "ctrip.android.view", + "同城": "com.tongcheng.android", + "飞猪": "com.taobao.trip", + "去哪儿": "com.Qunar", + "华住会": "com.htinns", + "知乎": "com.zhihu.android", + "小红书": "com.xingin.xhs", + "QQ音乐": "com.tencent.qqmusic", + "网易云音乐": "com.netease.cloudmusic", + "酷狗音乐": "com.kugou.android", + "高德": "com.autonavi.minimap" + } + + def parse_json_response(self, response_str: str) -> dict: + """解析JSON响应 + + Args: + response_str: 模型返回的响应字符串 + + Returns: + 解析后的JSON对象 + """ + try: + # 尝试直接解析JSON + return json.loads(response_str) + except json.JSONDecodeError: + # 如果直接解析失败,尝试提取JSON部分 + try: + # 查找JSON代码块 + json_match = re.search(r'```json\s*(\{.*?\})\s*```', response_str, re.DOTALL) + if json_match: + return 
json.loads(json_match.group(1)) + + # 查找花括号包围的JSON + json_match = re.search(r'(\{.*?\})', response_str, re.DOTALL) + if json_match: + return json.loads(json_match.group(1)) + + raise ValueError("无法在响应中找到有效的JSON") + except Exception as e: + logger.error(f"JSON解析失败: {e}") + logger.error(f"原始响应: {response_str}") + raise ValueError(f"无法解析JSON响应: {e}") + + + def get_app_package_name_by_ai(self, task_description: str) -> str: + """使用AI根据任务描述获取应用包名 + + Args: + task_description: 任务描述 + + Returns: + 应用包名 + """ + if not self.app_selection_client: + logger.error("应用选择模型客户端未初始化") + return None + + # 从open_app.py复制的提示模板 + app_selection_prompt_template = """ +## 角色定义 +你是一个智能手机应用选择助手,需要根据用户的任务描述选择最合适的应用。 + +## 任务描述 +用户想要完成的任务是:"{task_description}" + +## 可用应用列表 +以下是可用的应用及其包名: +- 微信: com.tencent.mm +- QQ: com.tencent.mobileqq +- 新浪微博: com.sina.weibo +- 饿了么: me.ele +- 美团: com.sankuai.meituan +- bilibili: tv.danmaku.bili +- 爱奇艺: com.qiyi.video +- 腾讯视频: com.tencent.qqlive +- 优酷: com.youku.phone +- 淘宝: com.taobao.taobao +- 京东: com.jingdong.app.mall +- 携程: ctrip.android.view +- 同城: com.tongcheng.android +- 飞猪: com.taobao.trip +- 去哪儿: com.Qunar +- 华住会: com.htinns +- 知乎: com.zhihu.android +- 小红书: com.xingin.xhs +- QQ音乐: com.tencent.qqmusic +- 网易云音乐: com.netease.cloudmusic +- 酷狗音乐: com.kugou.android +- 高德: com.autonavi.minimap + +## 任务要求 +请分析任务描述,选择最合适的应用来完成该任务。 + +## 输出格式 +请严格按照以下JSON格式输出: +```json +{{ + "reasoning": "分析任务内容,说明为什么选择这个应用最合适", + "package_name": "选择的应用包名" +}} +``` + +## 重要规则 +1. 只能从上述可用应用列表中选择 +2. 必须选择最符合任务需求的应用 +3. 如果任务涉及多个可能的应用,选择最主要和最常用的那个 +4. 
包名必须完全匹配列表中的包名,不能修改 +""" + + app_selection_prompt = app_selection_prompt_template.format(task_description=task_description) + + max_retries = 3 + for attempt in range(max_retries): + try: + response_str = self.app_selection_client.chat.completions.create( + model=self.app_selection_model, + messages=[ + { + "role": "user", + "content": app_selection_prompt + } + ] + ).choices[0].message.content + + logger.info(f"应用选择响应 (尝试 {attempt + 1}): \n{response_str}") + + # 解析响应 + response = self.parse_json_response(response_str) + package_name = response.get("package_name") + reasoning = response.get("reasoning") + + if package_name: + logger.info(f"AI选择应用原因: {reasoning}") + logger.info(f"AI选择的包名: {package_name}") + return package_name + else: + logger.warning(f"AI响应中没有包名信息: {response}") + + except Exception as e: + logger.error(f"AI应用选择失败 (尝试 {attempt + 1}): {e}") + if attempt == max_retries - 1: + logger.error("AI应用选择最终失败") + + return None + + + def load_tasks(self, task_file: str) -> list: + """加载任务文件 + + Args: + task_file: 任务文件路径 + + Returns: + 任务列表 + """ + try: + with open(task_file, 'r', encoding='utf-8') as f: + tasks = json.load(f) + logger.info(f"成功加载任务文件: {task_file}") + return tasks + except Exception as e: + logger.error(f"加载任务文件失败: {e}") + raise + + def get_package_name(self, app_name: str, task_description: str = None) -> str: + """根据应用名称获取包名,如果失败则使用AI分析任务描述 + + Args: + app_name: 应用名称 + task_description: 任务描述(可选,用于AI分析) + + Returns: + 应用包名 + """ + # 首先尝试从预定义映射中查找 + package_name = self.app_packages.get(app_name) + + if package_name: + logger.info(f"从预定义映射找到应用包名: {app_name} -> {package_name}") + return package_name + + # 如果预定义映射中没有找到,且提供了任务描述,则使用AI分析 + if task_description and self.app_selection_client: + logger.info(f"预定义映射中未找到应用 '{app_name}',尝试使用AI分析任务描述") + ai_package_name = self.get_app_package_name_by_ai(task_description) + + if ai_package_name: + logger.info(f"AI成功识别应用包名: {ai_package_name}") + # 将AI识别的结果添加到缓存中,避免重复查询 + self.app_packages[app_name] = 
ai_package_name + return ai_package_name + else: + logger.warning(f"AI也无法识别应用 '{app_name}' 的包名") + + # 所有方法都失败了 + logger.error(f"无法获取应用 '{app_name}' 的包名") + return None + + def create_task_directory(self, app_name: str, task_type: str, task_index: int) -> str: + """创建任务目录 + + Args: + app_name: 应用名称 + task_type: 任务类型 + task_index: 任务索引 + + Returns: + 任务目录路径 + """ + task_dir = os.path.join(self.data_base_dir, app_name, task_type, str(task_index)) + os.makedirs(task_dir, exist_ok=True) + return task_dir + + def save_task_data(self, task_dir: str, task_info: dict, execution_result: dict): + """保存任务数据到task_data.json和actions.json + + Args: + task_dir: 任务目录 + task_info: 任务信息 + execution_result: 执行结果 + """ + try: + # 构建task_data.json(保持原有格式) + task_data = { + "task_description": task_info["task_description"], + "app_name": task_info["app_name"], + "task_type": task_info["task_type"], + "task_index": task_info["task_index"], + "package_name": task_info["package_name"], + "execution_time": datetime.now().isoformat(), + "action_count": execution_result.get("total_steps", 0), + "actions": [], + "success": execution_result.get("success", False) + } + + # 构建actions.json(参考淘宝格式) + actions_data = { + "app_name": task_info["app_name"], + "task_type": task_info["task_type"], + "task_description": task_info["task_description"], + "action_count": execution_result.get("total_steps", 0), + "actions": [] + } + + # 转换action历史为标准格式 + for i, action in enumerate(execution_result.get("action_history", []), 1): + # task_data.json格式(保持原有格式) + action_record = { + "reasoning": action["thought"], + "function": { + "name": action["parsed_action"]["action_type"], + "parameters": action["parsed_action"]["action_params"] + } + } + task_data["actions"].append(action_record) + + # actions.json格式(参考淘宝格式) + action_type = action["parsed_action"]["action_type"] + action_params = action["parsed_action"]["action_params"] + + action_item = { + "type": action_type, + "action_index": i + } + + # 根据不同操作类型添加相应参数 + 
if action_type == "click": + action_item.update({ + "position_x": action_params.get("x"), + "position_y": action_params.get("y") + }) + # 如果有bounding_box信息,添加bounds + if "bounding_box" in action_params: + action_item["bounds"] = action_params["bounding_box"] + + elif action_type == "type": + action_item["text"] = action_params.get("text", "") + + elif action_type == "scroll": + action_item.update({ + "position_x": action_params.get("x"), + "position_y": action_params.get("y"), + "direction": action_params.get("direction", "down") + }) + + elif action_type == "swipe" or action_type == "drag": + action_item.update({ + "press_position_x": action_params.get("start_x"), + "press_position_y": action_params.get("start_y"), + "release_position_x": action_params.get("end_x"), + "release_position_y": action_params.get("end_y"), + "direction": action_params.get("direction", "") + }) + + elif action_type == "long_press": + action_item.update({ + "position_x": action_params.get("x"), + "position_y": action_params.get("y") + }) + + elif action_type in ["finished", "done"]: + action_item["type"] = "done" + + actions_data["actions"].append(action_item) + + # 保存task_data.json + task_data_path = os.path.join(task_dir, "task_data.json") + with open(task_data_path, "w", encoding="utf-8") as f: + json.dump(task_data, f, ensure_ascii=False, indent=4) + + # 保存actions.json(参考淘宝格式) + actions_path = os.path.join(task_dir, "actions.json") + with open(actions_path, "w", encoding="utf-8") as f: + json.dump(actions_data, f, ensure_ascii=False, indent=4) + + # 生成 react.json(每步的 reasoning 和 function),并复制每步的截图与 xml 到 task_dir,命名为 1.jpg / 1.xml ... 
+ reacts = [] + for i, action in enumerate(execution_result.get("action_history", []), 1): + reasoning = action.get('thought', '') + parsed = action.get('parsed_action', {}) or {} + func_name = parsed.get('action_type') + func_params = parsed.get('action_params', {}) + + # 映射内部类型到采集格式(例如 type -> input, finished/done -> done) + out_func_name = func_name + if func_name == 'type': + out_func_name = 'input' + elif func_name in ['finished', 'done']: + out_func_name = 'done' + + reacts.append({ + 'reasoning': reasoning, + 'function': { + 'name': out_func_name, + 'parameters': func_params + }, + 'action_index': i + }) + + # 复制截图和 xml(如果存在)。支持两种位置: + # 1) action 中记录的绝对/相对路径; + # 2) task_dir//screenshot_{i}.jpg 和 task_dir//hierarchy_{i}.xml(框架增强执行时的保存位置)。 + screenshot_src = action.get('screenshot_path') + xml_src = action.get('xml_path') + try: + copied = False + if screenshot_src and os.path.exists(screenshot_src): + dst_img = os.path.join(task_dir, f"{i}.jpg") + shutil.copyfile(screenshot_src, dst_img) + copied = True + else: + # 备选位置: task_dir//screenshot_{i}.jpg + alt_img = os.path.join(task_dir, str(i), f"screenshot_{i}.jpg") + if os.path.exists(alt_img): + dst_img = os.path.join(task_dir, f"{i}.jpg") + shutil.copyfile(alt_img, dst_img) + copied = True + + if xml_src and os.path.exists(xml_src): + dst_xml = os.path.join(task_dir, f"{i}.xml") + shutil.copyfile(xml_src, dst_xml) + copied = True + else: + # 备选位置: task_dir//hierarchy_{i}.xml + alt_xml = os.path.join(task_dir, str(i), f"hierarchy_{i}.xml") + if os.path.exists(alt_xml): + dst_xml = os.path.join(task_dir, f"{i}.xml") + shutil.copyfile(alt_xml, dst_xml) + copied = True + + if not copied: + logger.debug(f"未找到步骤 {i} 的截图或 xml,跳过复制") + except Exception as e: + logger.warning(f"复制步骤资源失败 (step {i}): {e}") + + reacts_path = os.path.join(task_dir, "reacts.json") + with open(reacts_path, 'w', encoding='utf-8') as f: + json.dump(reacts, f, ensure_ascii=False, indent=4) + + # 如果 reacts 或 action_history 为空,仍然尝试从 step 
子目录复制截图和 xml + try: + # 查找 task_dir 下的数字子目录,如 1, 2, ... + step_dirs = [] + for name in os.listdir(task_dir): + subp = os.path.join(task_dir, name) + if os.path.isdir(subp) and name.isdigit(): + step_dirs.append(int(name)) + step_dirs.sort() + + for i in step_dirs: + dst_img = os.path.join(task_dir, f"{i}.jpg") + dst_xml = os.path.join(task_dir, f"{i}.xml") + + # 仅在根目录中不存在时复制 + if not os.path.exists(dst_img): + alt_img = os.path.join(task_dir, str(i), f"screenshot_{i}.jpg") + if os.path.exists(alt_img): + try: + shutil.copyfile(alt_img, dst_img) + logger.info(f"已复制截图到根目录: {dst_img}") + except Exception as e: + logger.warning(f"复制截图失败 ({alt_img} -> {dst_img}): {e}") + + if not os.path.exists(dst_xml): + alt_xml = os.path.join(task_dir, str(i), f"hierarchy_{i}.xml") + if os.path.exists(alt_xml): + try: + shutil.copyfile(alt_xml, dst_xml) + logger.info(f"已复制 xml 到根目录: {dst_xml}") + except Exception as e: + logger.warning(f"复制 xml 失败 ({alt_xml} -> {dst_xml}): {e}") + except Exception as e: + logger.debug(f"扫描 step 子目录并复制资源时出错: {e}") + + logger.info(f"任务数据已保存到: {task_dir}") + logger.info(f" - task_data.json: {task_data_path}") + logger.info(f" - actions.json: {actions_path}") + + except Exception as e: + logger.error(f"保存任务数据失败: {e}") + + def execute_single_task(self, app_name: str, task_type: str, task_index: int, task_description: str) -> bool: + """执行单个任务 + + Args: + app_name: 应用名称 + task_type: 任务类型 (type1, type2, ...) + task_index: 任务索引 (1, 2, 3, ...) 
+ task_description: 任务描述 + + Returns: + 是否执行成功 + """ + logger.info(f"开始执行任务: {app_name} - {task_type} - 任务{task_index}") + logger.info(f"任务描述: {task_description}") + + # 获取包名 + package_name = self.get_package_name(app_name, task_description) + if not package_name: + logger.error(f"未找到应用 {app_name} 的包名") + return False + + # 创建任务目录 - 直接使用目标格式:app名称/typex/任务序号 + task_dir = self.create_task_directory(app_name, task_type, task_index) + + try: + # 配置 - 禁用框架自动数据保存,我们手动管理 + config = ExecutionConfig( + model_base_url=self.model_url, + model_name="UI-TARS-7B-SFT", + max_steps=20, # 设置为20步 + step_delay=2.0, + language="Chinese", + save_data=False, # 禁用框架自动保存,我们手动保存到指定位置 + save_screenshots=False, + save_xml=False + ) + + # 创建框架实例 + framework = UITarsAutomationFramework(config) + + # 启动应用(参考open_app.py) + logger.info(f"启动应用: {package_name}") + framework.device.app_start(package_name, stop=True) + time.sleep(3) # 等待应用启动 + + # 手动保存数据到我们的目录结构 + original_execute = framework.execute_task + + def enhanced_execute_task(task_desc): + """增强的执行方法,手动保存每步数据""" + framework.task_description = task_desc + framework.step_count = 0 + framework.action_history = [] + + logger.info(f"开始执行任务: {task_desc}") + + try: + while framework.step_count < framework.config.max_steps: + framework.step_count += 1 + logger.info(f"\n{'='*50}") + logger.info(f"第 {framework.step_count} 步") + logger.info(f"{'='*50}") + + # 创建步骤目录 + step_dir = os.path.join(task_dir, str(framework.step_count)) + os.makedirs(step_dir, exist_ok=True) + + # 保存截图和XML + screenshot_path = os.path.join(step_dir, f"screenshot_{framework.step_count}.jpg") + xml_path = os.path.join(step_dir, f"hierarchy_{framework.step_count}.xml") + + framework.device.screenshot(screenshot_path) + hierarchy = framework.device.dump_hierarchy() + with open(xml_path, "w", encoding="utf-8") as f: + f.write(hierarchy) + + # 获取图片数据用于模型调用 + with open(screenshot_path, "rb") as image_file: + image_data = base64.b64encode(image_file.read()).decode('utf-8') + 
image_data_url = f"data:image/jpeg;base64,{image_data}" + + # 构建消息并调用模型 + messages = framework._build_messages(image_data_url) + response = framework._call_model(messages) + + # 解析响应 + from PIL import Image + + image_height, image_width = None, None + try: + with Image.open(screenshot_path) as img: + image_width, image_height = img.size + except Exception as e: + logger.warning(f"无法获取图片尺寸: {e}") + + thought, raw_action, parsed_action = ActionParser.parse_response( + response, image_height, image_width + ) + + logger.info(f"思考: {thought}") + logger.info(f"操作: {raw_action}") + + # 执行操作 + result = framework._execute_action(parsed_action, screenshot_path) + result.screenshot_path = screenshot_path + result.xml_path = xml_path + + # 记录历史 + action_record = { + 'step': framework.step_count, + 'thought': thought, + 'raw_action': raw_action, + 'parsed_action': parsed_action, + 'result': result, + 'screenshot_path': screenshot_path, + 'xml_path': xml_path + } + framework.action_history.append(action_record) + + # 检查是否完成 + if not result.success: + logger.error(f"操作失败: {result.message}") + break + + if result.error == "FINISHED": + logger.info("任务执行完成!") + break + + # 等待 + time.sleep(framework.config.step_delay) + + success = len(framework.action_history) > 0 and framework.action_history[-1]['result'].error == "FINISHED" + return success + + except Exception as e: + logger.error(f"任务执行失败: {e}") + return False + + # 替换执行方法 + framework.execute_task = enhanced_execute_task + + # 执行任务 + success = framework.execute_task(task_description) + execution_result = framework.get_execution_summary() + + # 保存任务数据 + task_info = { + "task_description": task_description, + "app_name": app_name, + "task_type": task_type, + "task_index": task_index, + "package_name": package_name + } + self.save_task_data(task_dir, task_info, execution_result) + + logger.info(f"任务执行完成: {'成功' if success else '失败'}") + return success + + except Exception as e: + logger.error(f"任务执行异常: {e}") + 
logger.error(traceback.format_exc()) + return False + + def execute_all_tasks(self, task_file: str): + """执行所有任务 + + Args: + task_file: 任务文件路径 + """ + logger.info("开始批量任务执行") + + # 加载任务 + all_tasks = self.load_tasks(task_file) + + total_tasks = 0 + successful_tasks = 0 + failed_tasks = 0 + + # 统计总任务数 + for app_group in all_tasks: + total_tasks += len(app_group["tasks"]) + + logger.info(f"总共需要执行 {total_tasks} 个任务") + + # 逐个执行任务 + for app_group in all_tasks: + app_name = app_group["app"] + task_type = app_group["type"] + tasks = app_group["tasks"] + + logger.info(f"开始执行应用 {app_name} 的 {task_type} 类型任务") + + for task_index, task_description in enumerate(tasks, 1): + try: + success = self.execute_single_task( + app_name=app_name, + task_type=task_type, + task_index=task_index, + task_description=task_description + ) + + if success: + successful_tasks += 1 + else: + failed_tasks += 1 + + # 任务间间隔 + time.sleep(5) + + except Exception as e: + logger.error(f"任务执行出错: {e}") + failed_tasks += 1 + + # 执行总结 + logger.info("="*60) + logger.info("批量任务执行完成") + logger.info(f"总任务数: {total_tasks}") + logger.info(f"成功: {successful_tasks}") + logger.info(f"失败: {failed_tasks}") + logger.info(f"成功率: {successful_tasks/total_tasks*100:.1f}%" if total_tasks > 0 else "N/A") + logger.info("="*60) + +def main(): + """主函数""" + print("UI-TARS 批量任务执行器") + print("=" * 50) + + # 获取模型服务地址 + model_url = input("请输入模型服务地址 (默认: http://192.168.12.152:8001/v1): ").strip() + if not model_url: + model_url = "http://192.168.12.152:8001/v1" + + # 任务文件路径 + task_file = "auto_task-3-sup.json" + if not os.path.exists(task_file): + print(f"错误: 任务文件 {task_file} 不存在") + return + + # 创建批量执行器 + executor = BatchTaskExecutor(model_url) + + # 询问是否继续 + print(f"将读取 {task_file} 并批量执行任务") + confirm = input("是否继续?(y/N): ").strip().lower() + if confirm != 'y': + print("已取消执行") + return + + # 执行所有任务 + try: + executor.execute_all_tasks(task_file) + except KeyboardInterrupt: + print("\n用户中断执行") + except Exception as e: + 
+        print(f"执行过程中出现错误: {e}")
+
+if __name__ == "__main__":
+    main()
diff --git a/runner/UI-TARS-agent/quick_start.py b/runner/UI-TARS-agent/quick_start.py
new file mode 100644
index 0000000..13ed0ab
--- /dev/null
+++ b/runner/UI-TARS-agent/quick_start.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+UI-TARS 自动化框架快速启动脚本
+"""
+
+from ui_tars_automation import UITarsAutomationFramework, ExecutionConfig
+
+def quick_start():
+    """快速启动示例 - 饿了么点餐任务"""
+    print("UI-TARS 快速启动 - 饿了么点餐任务")
+    print("=" * 50)
+
+    # 获取模型服务地址
+    model_url = input("请输入模型服务地址 (默认: http://192.168.12.152:8000/v1): ").strip()
+    if not model_url:
+        model_url = "http://192.168.12.152:8000/v1"
+
+    # 配置
+    config = ExecutionConfig(
+        model_base_url=model_url,
+        model_name="UI-TARS-7B-SFT",
+        max_steps=25,
+        step_delay=2.0,
+        language="Chinese"
+    )
+
+    # 创建框架实例
+    try:
+        framework = UITarsAutomationFramework(config)
+        print("✅ 框架初始化成功!")
+    except Exception as e:
+        print(f"❌ 框架初始化失败: {e}")
+        return
+
+    # 执行饿了么点餐任务
+    task_description = "饿了么中帮我点蜜雪冰城的芋圆葡萄,要少冰,5分糖,加珍珠"
+
+    print(f"\n开始执行任务: {task_description}")
+    print("-" * 50)
+
+    try:
+        success = framework.execute_task(task_description)
+        summary = framework.get_execution_summary()
+
+        print("\n" + "="*60)
+        print("执行结果:")
+        print(f"任务: {summary['task_description']}")
+        print(f"总步数: {summary['total_steps']}")
+        print(f"成功: {'✅ 是' if summary['success'] else '❌ 否'}")
+        print("="*60)
+
+        if summary['action_history']:
+            print("\n最近几步操作:")
+            for i, action in enumerate(summary['action_history'][-3:], len(summary['action_history'])-2):
+                print(f"  步骤{i}: {action['thought'][:50]}...")
+
+    except Exception as e:
+        print(f"❌ 任务执行失败: {e}")
+
+if __name__ == "__main__":
+    quick_start()
diff --git a/runner/UI-TARS-agent/test_single_task.py b/runner/UI-TARS-agent/test_single_task.py
new file mode 100644
index 0000000..ba0ab35
--- /dev/null
+++ b/runner/UI-TARS-agent/test_single_task.py
@@ -0,0 +1,118 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+单任务测试器 - 用于测试单个任务执行 +""" + +import json +import os +import time +import logging +from datetime import datetime +from ui_tars_automation import UITarsAutomationFramework, ExecutionConfig + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +logger = logging.getLogger(__name__) + +def test_single_task(): + """测试单个任务""" + print("UI-TARS 单任务测试器") + print("=" * 50) + + # 应用包名映射 + app_packages = { + "bilibili": "tv.danmaku.bili", + "淘宝": "com.taobao.taobao", + "携程": "ctrip.android.view", + "网易云音乐": "com.netease.cloudmusic", + "小红书": "com.xingin.xhs", + "高德": "com.autonavi.minimap", + "饿了么": "me.ele" + } + + # 获取模型服务地址 + model_url = input("请输入模型服务地址 (默认: http://192.168.12.152:8000/v1): ").strip() + if not model_url: + model_url = "http://192.168.12.152:8000/v1" + + # 选择要测试的任务 + print("\n可用应用:") + for i, app in enumerate(app_packages.keys(), 1): + print(f"{i}. {app}") + + app_choice = input("\n请选择应用 (输入数字): ").strip() + try: + app_names = list(app_packages.keys()) + app_name = app_names[int(app_choice) - 1] + package_name = app_packages[app_name] + except (ValueError, IndexError): + print("无效选择") + return + + # 输入任务描述 + task_description = input(f"\n请输入{app_name}的任务描述: ").strip() + if not task_description: + print("任务描述不能为空") + return + + print(f"\n将执行任务: {task_description}") + print(f"目标应用: {app_name} ({package_name})") + + confirm = input("是否继续?(y/N): ").strip().lower() + if confirm != 'y': + print("已取消执行") + return + + try: + # 配置 + config = ExecutionConfig( + model_base_url=model_url, + model_name="UI-TARS-7B-SFT", + max_steps=30, + step_delay=2.0, + language="Chinese", + save_data=True, + data_base_dir="test_data" + ) + + # 创建框架实例 + framework = UITarsAutomationFramework(config) + print("✅ 框架初始化成功!") + + # 启动应用 + print(f"启动应用: {package_name}") + framework.device.app_start(package_name, stop=True) + time.sleep(3) # 等待应用启动 + + # 执行任务 + print(f"\n开始执行任务: {task_description}") + print("-" * 50) + + 
success = framework.execute_task(task_description) + summary = framework.get_execution_summary() + + print("\n" + "="*60) + print("执行结果:") + print(f"任务: {summary['task_description']}") + print(f"总步数: {summary['total_steps']}") + print(f"成功: {'✅ 是' if summary['success'] else '❌ 否'}") + print(f"数据保存位置: {summary.get('task_directory', 'N/A')}") + print("="*60) + + if summary['action_history']: + print("\n最近几步操作:") + for i, action in enumerate(summary['action_history'][-3:], len(summary['action_history'])-2): + print(f" 步骤{i}: {action['thought'][:50]}...") + + except Exception as e: + print(f"❌ 任务执行失败: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + test_single_task() diff --git a/runner/UI-TARS-agent/ui_tars_automation/__init__.py b/runner/UI-TARS-agent/ui_tars_automation/__init__.py new file mode 100644 index 0000000..79456cc --- /dev/null +++ b/runner/UI-TARS-agent/ui_tars_automation/__init__.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +UI-TARS 自动化框架包 +""" + +# 首先设置日志 +from .logger import setup_logging + +# 然后导入主要模块 +from .framework import UITarsAutomationFramework +from .config import ExecutionConfig, ActionResult +from .action_parser import ActionParser +from .data_manager import DataManager +from .coordinate_processor import CoordinateProcessor + +__version__ = "1.0.0" +__all__ = [ + "UITarsAutomationFramework", + "ExecutionConfig", + "ActionResult", + "ActionParser", + "DataManager", + "CoordinateProcessor", + "setup_logging" +] diff --git a/runner/UI-TARS-agent/ui_tars_automation/action_parser.py b/runner/UI-TARS-agent/ui_tars_automation/action_parser.py new file mode 100644 index 0000000..edbae53 --- /dev/null +++ b/runner/UI-TARS-agent/ui_tars_automation/action_parser.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +UI-TARS 动作解析器 - 修复坐标转换问题 +""" + +import re +import ast +import logging +from typing import Dict, Tuple, Optional + +# 直接使用UI-TARS官方库 +import 
ui_tars.action_parser as ui_tars_parser + +logger = logging.getLogger(__name__) + + +class ActionParser: + """动作解析器 - 完全使用UI-TARS官方解析功能""" + + @staticmethod + def parse_response(response: str, image_height: Optional[int] = None, + image_width: Optional[int] = None, model_type: str = "qwen25vl") -> Tuple[str, str, Dict]: + """ + 解析模型响应,使用UI-TARS官方完整流程: + 1. parse_action_to_structure_output - 解析响应为结构化数据 + 2. parsing_response_to_pyautogui_code - 转换为PyAutoGUI代码 + 3. 从PyAutoGUI代码中提取最终坐标 + + Args: + response: 模型原始响应 + image_height: 原始图片高度 + image_width: 原始图片宽度 + model_type: 模型类型 + + Returns: + (thought, raw_action, parsed_action) + """ + if not image_height or not image_width: + raise ValueError("需要提供image_height和image_width进行坐标转换") + + try: + # 步骤1: 计算smart resize尺寸 + smart_height, smart_width = ui_tars_parser.smart_resize( + image_height, image_width, + factor=ui_tars_parser.IMAGE_FACTOR + ) + + logger.debug(f"原始图像尺寸: {image_width}x{image_height}") + logger.debug(f"Smart resize尺寸: {smart_width}x{smart_height}") + logger.debug(f"模型响应: {response}") + + # 步骤2: 使用官方parse_action_to_structure_output解析 + actions = ui_tars_parser.parse_action_to_structure_output( + response, + factor=ui_tars_parser.IMAGE_FACTOR, + origin_resized_height=smart_height, + origin_resized_width=smart_width, + model_type=model_type + ) + + if not actions: + raise ValueError("UI-TARS解析未返回有效动作") + + # 取第一个动作 + action = actions[0] + + # 提取thought和raw_action + thought = action.get('thought', '') + raw_action = ActionParser._extract_raw_action(response) + + # 步骤3: 使用官方parsing_response_to_pyautogui_code生成代码 + pyautogui_code = ui_tars_parser.parsing_response_to_pyautogui_code( + action, image_height, image_width + ) + + # 调试信息 + logger.debug(f"Action结构化输出: {action}") + logger.info(f"PyAutoGUI代码: {pyautogui_code}") + + # 步骤4: 从PyAutoGUI代码中提取坐标并转换为框架格式 + parsed_action = ActionParser._convert_pyautogui_to_internal( + action, pyautogui_code + ) + + logger.info(f"UI-TARS官方解析成功: {parsed_action['action_type']}") + 
logger.info(f"最终解析结果: {parsed_action}") + + # 检查是否需要将相对坐标转换为绝对坐标 + if parsed_action['action_type'] in ['click', 'double_click', 'right_click', 'hover']: + x = parsed_action['action_params'].get('x', 0) + y = parsed_action['action_params'].get('y', 0) + + # 如果x,y的值在0-1区间,则判定为相对位置 + if 0 <= x <= 1 and 0 <= y <= 1: + # 转换为绝对坐标 + parsed_action['action_params']['x'] = int(x * image_width) + parsed_action['action_params']['y'] = int(y * image_height) + logger.info(f"转换相对坐标({x:.3f}, {y:.3f})为绝对坐标: {parsed_action['action_params']}") + else: + # 确保坐标是整数 + parsed_action['action_params']['x'] = int(x) + parsed_action['action_params']['y'] = int(y) + + elif parsed_action['action_type'] == 'scroll' and 'x' in parsed_action['action_params']: + x = parsed_action['action_params'].get('x', 0) + y = parsed_action['action_params'].get('y', 0) + + if 0 <= x <= 1 and 0 <= y <= 1: + parsed_action['action_params']['x'] = int(x * image_width) + parsed_action['action_params']['y'] = int(y * image_height) + logger.info(f"转换滚动相对坐标为绝对坐标: {parsed_action['action_params']}") + else: + parsed_action['action_params']['x'] = int(x) + parsed_action['action_params']['y'] = int(y) + + elif parsed_action['action_type'] == 'drag': + # 处理拖拽坐标 + start_x = parsed_action['action_params'].get('start_x', 0) + start_y = parsed_action['action_params'].get('start_y', 0) + end_x = parsed_action['action_params'].get('end_x', 0) + end_y = parsed_action['action_params'].get('end_y', 0) + + if (0 <= start_x <= 1 and 0 <= start_y <= 1 and + 0 <= end_x <= 1 and 0 <= end_y <= 1): + parsed_action['action_params']['start_x'] = int(start_x * image_width) + parsed_action['action_params']['start_y'] = int(start_y * image_height) + parsed_action['action_params']['end_x'] = int(end_x * image_width) + parsed_action['action_params']['end_y'] = int(end_y * image_height) + logger.info(f"转换拖拽相对坐标为绝对坐标: {parsed_action['action_params']}") + else: + parsed_action['action_params']['start_x'] = int(start_x) + 
parsed_action['action_params']['start_y'] = int(start_y) + parsed_action['action_params']['end_x'] = int(end_x) + parsed_action['action_params']['end_y'] = int(end_y) + + return thought, raw_action, parsed_action + + except Exception as e: + logger.error(f"UI-TARS官方解析失败: {e}") + # 使用简单备用解析 + return ActionParser._parse_fallback(response, image_height, image_width) + + @staticmethod + def _convert_pyautogui_to_internal(action: Dict, pyautogui_code: str) -> Dict: + """ + 将UI-TARS官方生成的PyAutoGUI代码转换为框架内部格式 + """ + action_type = action.get('action_type', '') + + # 处理DONE情况 + if pyautogui_code.strip() == "DONE": + return {'action_type': 'finished', 'action_params': {}} + + if action_type in ["click", "left_single", "left_double", "right_single", "hover"]: + # 从PyAutoGUI代码中提取点击坐标 + click_match = re.search(r'pyautogui\.click\((\d+(?:\.\d+)?), (\d+(?:\.\d+)?)', pyautogui_code) + double_click_match = re.search(r'pyautogui\.doubleClick\((\d+(?:\.\d+)?), (\d+(?:\.\d+)?)', pyautogui_code) + move_match = re.search(r'pyautogui\.moveTo\((\d+(?:\.\d+)?), (\d+(?:\.\d+)?)', pyautogui_code) + + match = click_match or double_click_match or move_match + if match: + x = float(match.group(1)) + y = float(match.group(2)) + + # 映射动作类型 + if action_type == "left_double" or "doubleClick" in pyautogui_code: + device_action_type = "double_click" + elif action_type == "right_single" or "button='right'" in pyautogui_code: + device_action_type = "right_click" + elif action_type == "hover" or "moveTo" in pyautogui_code: + device_action_type = "hover" + else: + device_action_type = "click" + + return {'action_type': device_action_type, 'action_params': {'x': x, 'y': y}} + + elif action_type == "scroll": + # 从PyAutoGUI代码中提取滚动信息 + scroll_match = re.search(r'pyautogui\.scroll\((-?\d+)', pyautogui_code) + if scroll_match: + scroll_value = int(scroll_match.group(1)) + direction = 'up' if scroll_value > 0 else 'down' + + # 检查是否有坐标 + coord_match = re.search(r'x=(\d+(?:\.\d+)?), y=(\d+(?:\.\d+)?)', 
pyautogui_code) + params = {'direction': direction} + if coord_match: + params['x'] = float(coord_match.group(1)) + params['y'] = float(coord_match.group(2)) + + return {'action_type': 'scroll', 'action_params': params} + + elif action_type in ["drag", "select"]: + # 从PyAutoGUI代码中提取拖拽坐标 + move_match = re.search(r'pyautogui\.moveTo\((\d+(?:\.\d+)?), (\d+(?:\.\d+)?)', pyautogui_code) + drag_match = re.search(r'pyautogui\.dragTo\((\d+(?:\.\d+)?), (\d+(?:\.\d+)?)', pyautogui_code) + + if move_match and drag_match: + start_x = float(move_match.group(1)) + start_y = float(move_match.group(2)) + end_x = float(drag_match.group(1)) + end_y = float(drag_match.group(2)) + + return { + 'action_type': 'drag', + 'action_params': { + 'start_x': start_x, 'start_y': start_y, + 'end_x': end_x, 'end_y': end_y + } + } + + elif action_type == "type": + # 从PyAutoGUI代码中提取文本 + write_match = re.search(r"pyautogui\.write\('([^']*)'", pyautogui_code) + copy_match = re.search(r"pyperclip\.copy\('([^']*)'", pyautogui_code) + + if write_match: + text = write_match.group(1) + elif copy_match: + text = copy_match.group(1) + else: + text = action.get('action_inputs', {}).get('content', '') + + return {'action_type': 'type', 'action_params': {'text': text}} + + elif action_type == "hotkey": + # 处理热键 + hotkey_match = re.search(r"pyautogui\.hotkey\(([^)]+)\)", pyautogui_code) + keydown_match = re.search(r"pyautogui\.keyDown\(([^)]+)\)", pyautogui_code) + press_match = re.search(r"pyautogui\.press\(([^)]+)\)", pyautogui_code) + + match = hotkey_match or keydown_match or press_match + if match: + keys_str = match.group(1).replace("'", "").replace('"', '') + + # 映射常见热键到我们的操作 + if 'home' in keys_str.lower(): + return {'action_type': 'press_home', 'action_params': {}} + elif 'back' in keys_str.lower() or 'escape' in keys_str.lower(): + return {'action_type': 'press_back', 'action_params': {}} + else: + return {'action_type': 'hotkey', 'action_params': {'keys': keys_str}} + + # 如果没找到匹配,使用原始输入 + 
hotkey_input = action.get('action_inputs', {}) + keys = hotkey_input.get('key', '') or hotkey_input.get('hotkey', '') + return {'action_type': 'hotkey', 'action_params': {'keys': keys}} + + elif action_type == "finished": + return {'action_type': 'finished', 'action_params': {}} + + # 默认返回原始信息 + logger.warning(f"未处理的动作类型: {action_type}, PyAutoGUI代码: {pyautogui_code}") + return {'action_type': action_type, 'action_params': action.get('action_inputs', {})} + + @staticmethod + def _extract_raw_action(response: str) -> str: + """从响应中提取原始动作字符串""" + try: + if "Action:" in response: + return response.split("Action:")[-1].strip().split('\n')[0] + return response.strip() + except: + return response.strip() + + @staticmethod + def _parse_fallback(response: str, image_height: int, image_width: int) -> Tuple[str, str, Dict]: + """简单备用解析方法""" + try: + # 提取Thought和Action + thought_match = re.search(r"Thought:\s*(.*?)(?=\nAction:|\Z)", response, re.DOTALL) + action_match = re.search(r"Action:\s*(.*?)(?=\n\n|\Z)", response, re.DOTALL) + + thought = thought_match.group(1).strip() if thought_match else "" + action_str = action_match.group(1).strip() if action_match else "" + + if not action_str: + raise ValueError("未找到有效的Action") + + # 简单解析 + if 'click(' in action_str: + coord_match = re.search(r"[\(,\s](\d+(?:\.\d+)?)[,\s]+(\d+(?:\.\d+)?)[\),\s]", action_str) + if coord_match: + x, y = float(coord_match.group(1)), float(coord_match.group(2)) + + # 简单的坐标标准化 + if 0 <= x <= 1 and 0 <= y <= 1: + x *= image_width + y *= image_height + + return thought, action_str, { + 'action_type': 'click', + 'action_params': {'x': int(x), 'y': int(y)} + } + + # 其他简单情况 + if 'scroll(' in action_str: + direction_match = re.search(r'direction[=:]\s*["\']?(\w+)["\']?', action_str) + direction = direction_match.group(1) if direction_match else 'down' + return thought, action_str, { + 'action_type': 'scroll', + 'action_params': {'direction': direction} + } + + raise ValueError(f"备用解析无法处理: {action_str}") + + 
except Exception as e: + logger.error(f"备用解析失败: {e}") + raise diff --git a/runner/UI-TARS-agent/ui_tars_automation/config.py b/runner/UI-TARS-agent/ui_tars_automation/config.py new file mode 100644 index 0000000..3ccc4ea --- /dev/null +++ b/runner/UI-TARS-agent/ui_tars_automation/config.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +配置和数据结构定义 +""" + +from dataclasses import dataclass +from typing import Optional + +@dataclass +class ActionResult: + """操作结果""" + success: bool + message: str + screenshot_path: Optional[str] = None + xml_path: Optional[str] = None + error: Optional[str] = None + +@dataclass +class ExecutionConfig: + """执行配置""" + model_base_url: str = "http://192.168.12.152:8000/v1" + model_name: str = "UI-TARS-7B-SFT" + device_ip: Optional[str] = None # None表示USB连接 + max_steps: int = 50 + step_delay: float = 1.5 + language: str = "Chinese" + temperature: float = 0.0 + max_tokens: int = 400 + + # 数据保存配置 + save_data: bool = True + data_base_dir: str = "automation_data" + save_screenshots: bool = True + save_xml: bool = True + +# 应用包名映射 +APP_PACKAGES = { + "微信": "com.tencent.mm", + "QQ": "com.tencent.mobileqq", + "微博": "com.sina.weibo", + + "饿了么": "me.ele", + "美团": "com.sankuai.meituan", + + "bilibili": "tv.danmaku.bili", + "爱奇艺": "com.qiyi.video", + "腾讯视频": "com.tencent.qqlive", + "优酷": "com.youku.phone", + + "淘宝": "com.taobao.taobao", + "京东": "com.jingdong.app.mall", + + "携程": "ctrip.android.view", + "同城": "com.tongcheng.android", + "飞猪": "com.taobao.trip", + "去哪儿": "com.Qunar", + "华住会": "com.htinns", + + "知乎": "com.zhihu.android", + "小红书": "com.xingin.xhs", + + "QQ音乐": "com.tencent.qqmusic", + "网易云音乐": "com.netease.cloudmusic", + "酷狗音乐": "com.kugou.android" +} + +# 内置提示词模板 +MOBILE_PROMPT_TEMPLATE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. + +## Output Format +``` +Thought: ... +Action: ... 
+``` + +## Action Space + +click(point='x1 y1') +long_press(point='x1 y1') +type(content='') #If you want to submit your input, use "\\n" at the end of `content`. +scroll(point='x1 y1', direction='down or up or right or left') +drag(start_point='x1 y1', end_point='x2 y2') +press_home() +press_back() +finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format. + +## Note +- Use {language} in `Thought` part. +- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part. +- To open an app, use click() to tap on the app icon you can see in the screenshot, don't use open_app(). +- Always look for app icons, buttons, or UI elements in the current screenshot and click on them. + +## User Instruction +{instruction}""" diff --git a/runner/UI-TARS-agent/ui_tars_automation/coordinate_processor.py b/runner/UI-TARS-agent/ui_tars_automation/coordinate_processor.py new file mode 100644 index 0000000..af9f0be --- /dev/null +++ b/runner/UI-TARS-agent/ui_tars_automation/coordinate_processor.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +坐标处理和可视化模块 +参考UI-TARS/README_coordinates.md的实现 +""" + +import math +import logging +from PIL import Image, ImageDraw +import matplotlib.pyplot as plt +import os +from typing import Tuple, Optional + +logger = logging.getLogger(__name__) + +# 常量定义(参考README_coordinates.md) +IMAGE_FACTOR = 28 +MIN_PIXELS = 100 * 28 * 28 +MAX_PIXELS = 16384 * 28 * 28 +MAX_RATIO = 200 + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + +def floor_by_factor(number: int, 
factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + +def smart_resize( + height: int, + width: int, + factor: int = IMAGE_FACTOR, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS +) -> Tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + 1. Both dimensions (height and width) are divisible by 'factor'. + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + +class CoordinateProcessor: + """坐标处理器""" + + @staticmethod + def convert_model_coords_to_actual( + model_x: int, + model_y: int, + original_width: int, + original_height: int + ) -> Tuple[int, int]: + """ + 将模型输出的坐标转换为实际设备坐标 + 参考README_coordinates.md的实现 + """ + try: + # 计算调整后的尺寸 + new_height, new_width = smart_resize(original_height, original_width) + + # 转换坐标 + actual_x = int(model_x / new_width * original_width) + actual_y = int(model_y / new_height * original_height) + + logger.info(f"坐标转换: 模型({model_x}, {model_y}) -> 实际({actual_x}, {actual_y})") + logger.info(f"原始尺寸: {original_width}x{original_height}, 调整尺寸: {new_width}x{new_height}") + + return actual_x, actual_y + + except 
Exception as e: + logger.error(f"坐标转换失败: {e}") + # 如果转换失败,返回原始坐标 + return model_x, model_y + + @staticmethod + def create_visualization_image( + screenshot_path: str, + click_x: int, + click_y: int, + output_path: str, + marker_size: int = 20, + marker_color: str = 'red' + ) -> bool: + """ + 创建带有点击位置标记的可视化图像 + """ + try: + # 打开原始截图 + img = Image.open(screenshot_path) + img_copy = img.copy() + + # 创建绘图对象 + draw = ImageDraw.Draw(img_copy) + + # 绘制点击位置标记(十字架) + marker_half = marker_size // 2 + + # 绘制红色十字架 + draw.line([ + (click_x - marker_half, click_y), + (click_x + marker_half, click_y) + ], fill=marker_color, width=3) + + draw.line([ + (click_x, click_y - marker_half), + (click_x, click_y + marker_half) + ], fill=marker_color, width=3) + + # 绘制圆形标记 + draw.ellipse([ + (click_x - marker_half//2, click_y - marker_half//2), + (click_x + marker_half//2, click_y + marker_half//2) + ], outline=marker_color, width=2) + + # 保存可视化图像 + img_copy.save(output_path) + logger.info(f"可视化图像已保存: {output_path}") + + return True + + except Exception as e: + logger.error(f"创建可视化图像失败: {e}") + return False + + @staticmethod + def create_matplotlib_visualization( + screenshot_path: str, + click_x: int, + click_y: int, + output_path: str + ) -> bool: + """ + 使用matplotlib创建可视化图像 + """ + try: + # 打开图像 + img = Image.open(screenshot_path) + + # 创建matplotlib图像 + plt.figure(figsize=(12, 8)) + plt.imshow(img) + plt.scatter([click_x], [click_y], c='red', s=100, marker='x') # 用红色X标记点击位置 + plt.title(f'Click Visualization at ({click_x}, {click_y})') + plt.axis('off') # 隐藏坐标轴 + + # 保存图像 + plt.savefig(output_path, dpi=150, bbox_inches='tight') + plt.close() # 关闭图像以释放内存 + + logger.info(f"matplotlib可视化图像已保存: {output_path}") + return True + + except Exception as e: + logger.error(f"创建matplotlib可视化图像失败: {e}") + return False diff --git a/runner/UI-TARS-agent/ui_tars_automation/data_manager.py b/runner/UI-TARS-agent/ui_tars_automation/data_manager.py new file mode 100644 index 0000000..754f0d1 --- /dev/null 
+++ b/runner/UI-TARS-agent/ui_tars_automation/data_manager.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +数据管理模块 - 负责保存截图、XML、操作记录等数据 +""" + +import os +import json +import time +import datetime +from typing import Dict, Any, Optional +from .config import ExecutionConfig + +class DataManager: + """数据管理器""" + + def __init__(self, config: ExecutionConfig, task_description: str): + self.config = config + self.task_description = task_description + self.task_start_time = datetime.datetime.now() + self.task_dir = None + self.current_step = 0 + + if config.save_data: + self._create_task_directory() + + def _create_task_directory(self): + """创建任务目录""" + # 生成任务目录名:基于时间戳和任务描述前几个字 + timestamp = self.task_start_time.strftime("%Y%m%d_%H%M%S") + task_name = "".join(c for c in self.task_description[:20] if c.isalnum() or c in ('_', '-')) + if not task_name: + task_name = "task" + + self.task_dir = os.path.join( + self.config.data_base_dir, + f"{timestamp}_{task_name}" + ) + + # 创建目录结构 + os.makedirs(self.task_dir, exist_ok=True) + + # 保存任务基本信息 + task_info = { + "task_description": self.task_description, + "start_time": self.task_start_time.isoformat(), + "config": { + "model_base_url": self.config.model_base_url, + "model_name": self.config.model_name, + "max_steps": self.config.max_steps, + "language": self.config.language + } + } + + with open(os.path.join(self.task_dir, "task_info.json"), "w", encoding="utf-8") as f: + json.dump(task_info, f, ensure_ascii=False, indent=2) + + def start_new_step(self, step_number: int): + """开始新的步骤""" + self.current_step = step_number + if self.config.save_data: + step_dir = os.path.join(self.task_dir, str(step_number)) + os.makedirs(step_dir, exist_ok=True) + + def save_screenshot(self, device, step_number: int) -> Optional[str]: + """保存截图""" + if not self.config.save_data or not self.config.save_screenshots: + return None + + screenshot_path = os.path.join( + self.task_dir, + str(step_number), + 
f"screenshot_{step_number}.jpg" + ) + + try: + device.screenshot(screenshot_path) + return screenshot_path + except Exception as e: + print(f"保存截图失败: {e}") + return None + + def save_xml(self, device, step_number: int) -> Optional[str]: + """保存XML层次结构""" + if not self.config.save_data or not self.config.save_xml: + return None + + xml_path = os.path.join( + self.task_dir, + str(step_number), + f"hierarchy_{step_number}.xml" + ) + + try: + hierarchy = device.dump_hierarchy() + with open(xml_path, "w", encoding="utf-8") as f: + f.write(hierarchy) + return xml_path + except Exception as e: + print(f"保存XML失败: {e}") + return None + + def save_step_data(self, step_number: int, step_data: Dict[str, Any]): + """保存单步数据""" + if not self.config.save_data: + return + + step_file = os.path.join( + self.task_dir, + str(step_number), + f"step_{step_number}.json" + ) + + try: + with open(step_file, "w", encoding="utf-8") as f: + json.dump(step_data, f, ensure_ascii=False, indent=2, default=str) + except Exception as e: + print(f"保存步骤数据失败: {e}") + + def save_execution_summary(self, execution_data: Dict[str, Any]): + """保存执行总结""" + if not self.config.save_data: + return + + # 保存actions.json格式的数据 + actions_data = { + "task_description": self.task_description, + "start_time": self.task_start_time.isoformat(), + "end_time": datetime.datetime.now().isoformat(), + "action_count": execution_data.get("total_steps", 0), + "success": execution_data.get("success", False), + "actions": [] + } + + # 转换action历史为标准格式 + for action in execution_data.get("action_history", []): + action_record = { + "step": action["step"], + "thought": action["thought"], + "raw_action": action["raw_action"], + "action_type": action["parsed_action"]["action_type"], + "action_params": action["parsed_action"]["action_params"], + "result": { + "success": action["result"].success, + "message": action["result"].message, + "error": action["result"].error + }, + "screenshot_path": action.get("screenshot_path"), + "xml_path": 
action.get("xml_path") + } + actions_data["actions"].append(action_record) + + # 保存actions.json + actions_path = os.path.join(self.task_dir, "actions.json") + with open(actions_path, "w", encoding="utf-8") as f: + json.dump(actions_data, f, ensure_ascii=False, indent=2) + + # 保存react.json格式的数据(简化版) + react_data = [] + for action in execution_data.get("action_history", []): + react_record = { + "reasoning": action["thought"], + "function": { + "name": action["parsed_action"]["action_type"], + "parameters": action["parsed_action"]["action_params"] + }, + "action_index": action["step"] + } + react_data.append(react_record) + + react_path = os.path.join(self.task_dir, "react.json") + with open(react_path, "w", encoding="utf-8") as f: + json.dump(react_data, f, ensure_ascii=False, indent=2) + + print(f"执行数据已保存到: {self.task_dir}") + + def get_task_directory(self) -> Optional[str]: + """获取任务目录路径""" + return self.task_dir diff --git a/runner/UI-TARS-agent/ui_tars_automation/framework.py b/runner/UI-TARS-agent/ui_tars_automation/framework.py new file mode 100644 index 0000000..332cacc --- /dev/null +++ b/runner/UI-TARS-agent/ui_tars_automation/framework.py @@ -0,0 +1,477 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +UI-TARS 自动化执行框架 - 主框架模块 +""" + +import base64 +import time +import logging +from typing import List, Dict, Any +from openai import OpenAI +import uiautomator2 as u2 +from PIL import Image + +from .config import ExecutionConfig, ActionResult, APP_PACKAGES, MOBILE_PROMPT_TEMPLATE +from .action_parser import ActionParser +from .data_manager import DataManager +from .coordinate_processor import CoordinateProcessor + +logger = logging.getLogger(__name__) + +class UITarsAutomationFramework: + """UI-TARS自动化框架""" + + def __init__(self, config: ExecutionConfig): + self.config = config + self.device = None + self.client = None + self.action_history = [] + self.step_count = 0 + self.task_description = "" + self.data_manager = None + + self._initialize_client() 
+ self._initialize_device() + + def _initialize_client(self): + """初始化OpenAI客户端""" + try: + self.client = OpenAI( + base_url=self.config.model_base_url, + api_key="EMPTY" + ) + logger.info(f"已连接到模型服务: {self.config.model_base_url}") + except Exception as e: + logger.error(f"模型客户端初始化失败: {e}") + raise + + def _initialize_device(self): + """初始化设备连接""" + try: + if self.config.device_ip: + self.device = u2.connect(self.config.device_ip) + else: + self.device = u2.connect() + + # 获取设备信息 + device_info = self.device.info + logger.info(f"已连接设备: {device_info.get('productName', 'Unknown')} " + f"- {device_info.get('version', 'Unknown')}") + except Exception as e: + logger.error(f"设备连接失败: {e}") + raise + + def _capture_screenshot_and_data(self, step_number: int) -> str: + """截图并保存相关数据""" + try: + # 保存截图和XML + screenshot_path = self.data_manager.save_screenshot(self.device, step_number) + xml_path = self.data_manager.save_xml(self.device, step_number) + + # 读取截图并转换为base64 + if screenshot_path: + with open(screenshot_path, "rb") as image_file: + image_data = base64.b64encode(image_file.read()).decode('utf-8') + image_data_url = f"data:image/jpeg;base64,{image_data}" + else: + # 临时截图用于模型调用 + temp_path = f"temp_screenshot_{step_number}.jpg" + self.device.screenshot(temp_path) + with open(temp_path, "rb") as image_file: + image_data = base64.b64encode(image_file.read()).decode('utf-8') + image_data_url = f"data:image/jpeg;base64,{image_data}" + import os + os.remove(temp_path) # 删除临时文件 + + logger.info(f"步骤 {step_number} 数据已保存") + return image_data_url, screenshot_path, xml_path + + except Exception as e: + logger.error(f"截图和数据保存失败: {e}") + raise + + def _build_messages(self, image_data: str) -> List[Dict]: + """构建发送给模型的消息""" + # 构建系统提示 + system_prompt = MOBILE_PROMPT_TEMPLATE.format( + language=self.config.language, + instruction=self.task_description + ) + + messages = [ + { + "role": "user", + "content": system_prompt + } + ] + + # 添加历史操作记录 + for action in self.action_history: + 
if action.get('thought') and action.get('raw_action'): + messages.append({ + "role": "assistant", + "content": f"Thought: {action['thought']}\nAction: {action['raw_action']}" + }) + + # 添加当前截图 + messages.append({ + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_data} + } + ] + }) + + return messages + + def _call_model(self, messages: List[Dict]) -> str: + """调用模型获取响应""" + try: + logger.info("正在调用模型...") + chat_completion = self.client.chat.completions.create( + model=self.config.model_name, + messages=messages, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + stream=True + ) + + response = "" + for chunk in chat_completion: + if chunk.choices[0].delta.content: + response += chunk.choices[0].delta.content + + logger.info(f"模型响应: {response}") + return response + except Exception as e: + logger.error(f"模型调用失败: {e}") + raise + + def _execute_action(self, action: Dict, screenshot_path: str = None) -> ActionResult: + """执行具体的操作""" + try: + action_type = action['action_type'] + params = action['action_params'] + + logger.info(f"执行操作: {action_type} - {params}") + + if action_type == 'click': + x, y = params['x'], params['y'] + + # 创建可视化图像(如果有截图路径) + if screenshot_path: + try: + vis_path = screenshot_path.replace('.jpg', '_visualization.jpg') + CoordinateProcessor.create_visualization_image( + screenshot_path, x, y, vis_path + ) + logger.info(f"坐标可视化已保存: {vis_path}") + except Exception as e: + logger.warning(f"坐标可视化失败: {e}") + + # 执行点击操作(坐标已经是绝对坐标) + logger.info(f"设备点击坐标: ({x}, {y})") + + # 确保坐标为整数 + x, y = round(x), round(y) + + # 执行点击 + self.device.click(x, y) + + # 等待操作完成 + time.sleep(0.5) + + return ActionResult(True, f"点击 ({x}, {y})") + + elif action_type == 'long_press': + x, y = params['x'], params['y'] + + # 创建可视化图像(如果有截图路径) + if screenshot_path: + try: + vis_path = screenshot_path.replace('.jpg', '_longpress_visualization.jpg') + CoordinateProcessor.create_visualization_image( + screenshot_path, 
x, y, vis_path + ) + logger.info(f"长按坐标可视化已保存: {vis_path}") + except Exception as e: + logger.warning(f"坐标可视化失败: {e}") + + logger.info(f"设备长按坐标: ({x}, {y})") + x, y = round(x), round(y) + self.device.long_click(x, y) + time.sleep(0.5) + + return ActionResult(True, f"长按 ({x}, {y})") + + elif action_type == 'type': + # 使用ADB键盘进行输入 + text = params['text'] + logger.info(f"输入文本: {text}") + + try: + # 获取当前输入法 + current_ime = self.device.current_ime() + + # 切换到ADB键盘 + self.device.shell(['settings', 'put', 'secure', 'default_input_method', + 'com.android.adbkeyboard/.AdbIME']) + time.sleep(0.5) + + # 发送文本 + charsb64 = base64.b64encode(text.encode('utf-8')).decode('utf-8') + self.device.shell(['am', 'broadcast', '-a', 'ADB_INPUT_B64', '--es', 'msg', charsb64]) + time.sleep(0.5) + + # 恢复原输入法 + self.device.shell(['settings', 'put', 'secure', 'default_input_method', current_ime]) + time.sleep(0.5) + + return ActionResult(True, f"输入文本: {text}") + + except Exception as e: + logger.error(f"文本输入失败: {e}") + return ActionResult(False, f"文本输入失败: {e}") + + elif action_type == 'scroll': + direction = params['direction'].lower() + + # 获取坐标,如果没有提供则使用屏幕中心 + if 'x' in params and 'y' in params: + x, y = params['x'], params['y'] + else: + # 使用设备屏幕中心作为滚动起点 + device_info = self.device.info + x = device_info['displayWidth'] // 2 + y = device_info['displayHeight'] // 2 + + # 坐标转换(仅当有原始坐标时) + if screenshot_path and 'x' in params and 'y' in params: + try: + img = Image.open(screenshot_path) + width, height = img.size + actual_x, actual_y = CoordinateProcessor.convert_model_coords_to_actual( + params['x'], params['y'], width, height + ) + x, y = actual_x, actual_y + except Exception as e: + logger.warning(f"坐标转换失败,使用原始坐标: {e}") + + logger.info(f"滚动操作: {direction} at ({x}, {y})") + x, y = round(x), round(y) + + # 参考滚动实现,使用较短的duration + if direction == 'down': + self.device.swipe(x, y, x, y - 300, duration=0.1) + elif direction == 'up': + self.device.swipe(x, y, x, y + 300, duration=0.1) + elif 
direction == 'left': + self.device.swipe(x, y, x + 300, y, duration=0.1) + elif direction == 'right': + self.device.swipe(x, y, x - 300, y, duration=0.1) + + time.sleep(0.5) + return ActionResult(True, f"滚动 {direction}") + + elif action_type == 'drag': + start_x, start_y = params['start_x'], params['start_y'] + end_x, end_y = params['end_x'], params['end_y'] + + # 坐标转换 + if screenshot_path: + try: + img = Image.open(screenshot_path) + width, height = img.size + actual_start_x, actual_start_y = CoordinateProcessor.convert_model_coords_to_actual( + start_x, start_y, width, height + ) + actual_end_x, actual_end_y = CoordinateProcessor.convert_model_coords_to_actual( + end_x, end_y, width, height + ) + start_x, start_y = actual_start_x, actual_start_y + end_x, end_y = actual_end_x, actual_end_y + except Exception as e: + logger.warning(f"坐标转换失败,使用原始坐标: {e}") + + logger.info(f"拖拽操作: ({start_x}, {start_y}) -> ({end_x}, {end_y})") + start_x, start_y = round(start_x), round(start_y) + end_x, end_y = round(end_x), round(end_y) + + # 参考拖拽实现 + self.device.swipe(start_x, start_y, end_x, end_y, duration=0.1) + time.sleep(0.5) + + return ActionResult(True, f"拖拽 ({start_x}, {start_y}) → ({end_x}, {end_y})") + + elif action_type == 'press_home': + logger.info("按下Home键") + self.device.press("home") + time.sleep(0.5) + return ActionResult(True, "按下Home键") + + elif action_type == 'press_back': + logger.info("按下返回键") + self.device.press("back") + time.sleep(0.5) + return ActionResult(True, "按下返回键") + + elif action_type == 'open_app': + app_name = params.get('app_name', '') + logger.info(f"尝试打开应用: {app_name}") + + # 从映射表中获取包名 + package_name = APP_PACKAGES.get(app_name) + if package_name: + try: + # 使用device.app_start启动应用 + self.device.app_start(package_name, stop=True) + time.sleep(2.0) # 等待应用启动 + logger.info(f"成功启动应用: {app_name} ({package_name})") + return ActionResult(True, f"已打开应用: {app_name}") + except Exception as e: + logger.error(f"启动应用失败: {e}") + return ActionResult(False, 
f"启动应用失败: {app_name} - {e}") + else: + # 如果没有找到包名,尝试通过图标点击 + logger.warning(f"未找到应用包名: {app_name},请手动点击应用图标") + return ActionResult(False, f"未找到应用包名: {app_name},请使用click操作点击应用图标") + + elif action_type == 'finished': + return ActionResult(True, "任务完成", error="FINISHED") + + elif action_type == 'wait': + # 等待指定秒数 + seconds = params.get('seconds', 1) + try: + seconds = float(seconds) + except Exception: + seconds = 1 + logger.info(f"等待 {seconds} 秒") + time.sleep(seconds) + return ActionResult(True, f"等待 {seconds} 秒") + else: + return ActionResult(False, f"不支持的操作类型: {action_type}") + + except Exception as e: + logger.error(f"操作执行失败: {e}") + return ActionResult(False, f"操作执行失败: {e}", error=str(e)) + + def execute_task(self, task_description: str) -> bool: + """执行完整的任务""" + self.task_description = task_description + self.step_count = 0 + self.action_history = [] + + # 初始化数据管理器 + self.data_manager = DataManager(self.config, task_description) + + logger.info(f"开始执行任务: {task_description}") + + try: + while self.step_count < self.config.max_steps: + self.step_count += 1 + logger.info(f"\n{'='*50}") + logger.info(f"第 {self.step_count} 步") + logger.info(f"{'='*50}") + + # 开始新步骤 + self.data_manager.start_new_step(self.step_count) + + # 1. 截图并保存数据 + image_data, screenshot_path, xml_path = self._capture_screenshot_and_data(self.step_count) + + # 2. 构建消息 + messages = self._build_messages(image_data) + + # 3. 调用模型 + response = self._call_model(messages) + + # 4. 解析响应(传递图片尺寸信息用于坐标转换) + image_height, image_width = None, None + if screenshot_path: + try: + from PIL import Image + with Image.open(screenshot_path) as img: + image_width, image_height = img.size + except Exception as e: + logger.warning(f"无法获取图片尺寸: {e}") + + thought, raw_action, parsed_action = ActionParser.parse_response( + response, image_height, image_width + ) + + logger.info(f"思考: {thought}") + logger.info(f"操作: {raw_action}") + + # 5. 
执行操作 + result = self._execute_action(parsed_action, screenshot_path) + result.screenshot_path = screenshot_path + result.xml_path = xml_path + + # 6. 记录历史 + action_record = { + 'step': self.step_count, + 'thought': thought, + 'raw_action': raw_action, + 'parsed_action': parsed_action, + 'result': result, + 'screenshot_path': screenshot_path, + 'xml_path': xml_path + } + self.action_history.append(action_record) + + # 7. 保存步骤数据 + step_data = { + 'step': self.step_count, + 'thought': thought, + 'raw_action': raw_action, + 'parsed_action': parsed_action, + 'result': { + 'success': result.success, + 'message': result.message, + 'error': result.error + }, + 'screenshot_path': screenshot_path, + 'xml_path': xml_path, + 'timestamp': time.time() + } + self.data_manager.save_step_data(self.step_count, step_data) + + # 8. 检查是否完成 + if not result.success: + logger.error(f"操作失败: {result.message}") + break + + if result.error == "FINISHED": + logger.info("任务执行完成!") + break + + # 9. 等待 + time.sleep(self.config.step_delay) + + # 保存执行总结 + execution_summary = self.get_execution_summary() + self.data_manager.save_execution_summary(execution_summary) + + success = execution_summary.get('success', False) + if not success and self.step_count >= self.config.max_steps: + logger.warning(f"达到最大步数限制 ({self.config.max_steps})") + + return success + + except Exception as e: + logger.error(f"任务执行失败: {e}") + return False + + def get_execution_summary(self) -> Dict[str, Any]: + """获取执行摘要""" + return { + 'task_description': self.task_description, + 'total_steps': self.step_count, + 'action_history': self.action_history, + 'success': len(self.action_history) > 0 and self.action_history[-1]['result'].error == "FINISHED", + 'task_directory': self.data_manager.get_task_directory() if self.data_manager else None + } diff --git a/runner/UI-TARS-agent/ui_tars_automation/logger.py b/runner/UI-TARS-agent/ui_tars_automation/logger.py new file mode 100644 index 0000000..4a2850e --- /dev/null +++ 
def setup_logging(log_level=logging.INFO, log_file="automation.log"):
    """Configure the root logger with a console handler and an optional file handler.

    Any handlers already attached to the root logger are removed first, so
    repeated calls reconfigure cleanly instead of duplicating output.

    Args:
        log_level: Threshold applied to the root logger and every handler.
        log_file: Path of the UTF-8 log file; a falsy value disables file logging.
    """
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    root = logging.getLogger()
    root.setLevel(log_level)

    # Start from a clean slate: drop whatever handlers are already attached.
    while root.handlers:
        root.removeHandler(root.handlers[0])

    # Console handler always; file handler only when a path is given.
    handlers = [logging.StreamHandler()]
    if log_file:
        handlers.append(logging.FileHandler(log_file, encoding='utf-8'))

    for handler in handlers:
        handler.setLevel(log_level)
        handler.setFormatter(formatter)
        root.addHandler(handler)
def init(service_ip, decider_port, grounder_port, planner_port):
    """Create the three OpenAI-compatible clients (decider / grounder / planner).

    Fix: the grounder and planner ``base_url`` values were missing the
    ``http://`` scheme that the decider URL already had, producing invalid
    endpoints like ``"localhost:8001/v1"``; all three now use the same scheme.
    The ``global`` declaration also dropped ``general_client``,
    ``general_model`` and ``apps``, which were never assigned here (a no-op).

    Args:
        service_ip: Host/IP where the model services are exposed.
        decider_port: TCP port of the decider service.
        grounder_port: TCP port of the grounder service.
        planner_port: TCP port of the planner service.
    """
    global decider_client, grounder_client, planner_client
    decider_client = OpenAI(
        api_key="0",
        base_url=f"http://{service_ip}:{decider_port}/v1",
    )
    grounder_client = OpenAI(
        api_key="0",
        base_url=f"http://{service_ip}:{grounder_port}/v1",
    )
    planner_client = OpenAI(
        api_key="0",
        base_url=f"http://{service_ip}:{planner_port}/v1",
    )
def get_screenshot(device):
    """Capture the device screen, downscale by ``factor`` and return it base64-encoded.

    Args:
        device: Device wrapper exposing ``screenshot(path)``; the image is
            written to the module-level ``screenshot_path``.

    Returns:
        Base64-encoded JPEG bytes of the resized screenshot, as an ASCII str.
    """
    device.screenshot(screenshot_path)
    # Shrink the screenshot to cut payload size before sending to the model.
    with Image.open(screenshot_path) as raw:
        target_size = (int(raw.width * factor), int(raw.height * factor))
        resized = raw.resize(target_size, Image.Resampling.LANCZOS)
    payload = io.BytesIO()
    resized.save(payload, format="JPEG")
    return base64.b64encode(payload.getvalue()).decode("utf-8")
{h}" for idx, h in enumerate(history, 1)) + + screenshot = get_screenshot(device) + + decider_prompt = decider_prompt_template.format( + task=task, + history=history_str + ) + # logging.info(f"Decider prompt: \n{decider_prompt}") + decider_response_str = decider_client.chat.completions.create( + model="decider", + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{screenshot}"}}, + {"type": "text", "text": decider_prompt}, + ] + } + ], + temperature=0 + ).choices[0].message.content + + logging.info(f"Decider response: \n{decider_response_str}") + + decider_response = json.loads(decider_response_str) + converted_item = { + "reasoning": decider_response["reasoning"], + "function": { + "name": decider_response["action"], + "parameters": decider_response["parameters"] + } + } + reacts.append(converted_item) + action = decider_response["action"] + + current_dir = os.getcwd() + img_path = os.path.join(current_dir, f"screenshot.jpg") + save_path = os.path.join(data_dir, f"{len(actions) + 1}.jpg") + img = Image.open(img_path) + img.save(save_path) + + hierarchy_path = os.path.join(data_dir, f"{len(actions) + 1}.xml") + hierarchy = device.dump_hierarchy() + with open(hierarchy_path, "w", encoding="utf-8") as f: + f.write(hierarchy) + + if action == "done": + print("Task completed.") + actions.append({ + "type": "done" + }) + break + if action == "click": + reasoning = decider_response["reasoning"] + target_element = decider_response["parameters"]["target_element"] + grounder_prompt = (grounder_prompt_template_bbox if bbox_flag else grounder_prompt_template_no_bbox).format(reasoning=reasoning, description=target_element) + # logging.info(f"Grounder prompt: \n{grounder_prompt}") + + grounder_response_str = grounder_client.chat.completions.create( + model="", + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{screenshot}"}}, + {"type": 
"text", "text": grounder_prompt}, + ] + } + ], + temperature=0 + ).choices[0].message.content + logging.info(f"Grounder response: \n{grounder_response_str}") + grounder_response = json.loads(grounder_response_str) + if(bbox_flag): + bbox = grounder_response["bbox"] + + x1, y1, x2, y2 = [int(coord / factor) for coord in bbox] + position_x = (x1 + x2) // 2 + position_y = (y1 + y2) // 2 + device.click(position_x, position_y) + actions.append({ + "type": "click", + "position_x": position_x, + "position_y": position_y, + "bounds": [ + x1, y1, x2, y2 + ] + }) + history.append(decider_response_str) + + current_dir = os.getcwd() + img_path = os.path.join(current_dir, f"screenshot.jpg") + save_path = os.path.join(data_dir, f"{len(actions)}_highlighted.jpg") + img = Image.open(img_path) + draw = ImageDraw.Draw(img) + font = ImageFont.truetype("msyh.ttf", 40) + text = f"CLICK [{position_x}, {position_y}]" + text = textwrap.fill(text, width=20) + text_width, text_height = draw.textbbox((0, 0), text, font=font)[2:] + draw.text((img.width / 2 - text_width / 2, 0), text, fill="red", font=font) + img.save(save_path) + + # 拉框 + bounds_path = os.path.join(data_dir, f"{len(actions)}_bounds.jpg") + img_bounds = Image.open(save_path) + draw_bounds = ImageDraw.Draw(img_bounds) + draw_bounds.rectangle([x1, y1, x2, y2], outline='red', width=5) + img_bounds.save(bounds_path) + + # # 画点 + # with open(save_path, 'rb') as f: + # image_data = f.read() + # nparr = np.frombuffer(image_data, np.uint8) + # cv2image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + # if action["type"] == "click": + # x = int(action['position_x']) + # y = int(action['position_y']) + # cv2.circle(cv2image, (x, y), 50, (0, 0, 255), 10) + # elif action["type"] == "swipe": + # x1 = int(action['press_position_x']) + # y1 = int(action['press_position_y']) + # x2 = int(action['release_position_x']) + # y2 = int(action['release_position_y']) + # cv2.arrowedLine(cv2image, (x1, y1), (x2, y2), (0, 0, 255), 5) + # success, encoded_img 
from utils.load_md_prompt import load_prompt
app_selection_prompt_template = load_prompt("planner.md")

def get_app_package_name(task_description, max_retries=5):
    """根据任务描述获取需要启动的app包名和改写后的任务描述

    Queries the planner model and parses the ```json fenced block from its
    reply.

    Fix: the original ``while True`` retried forever (repeated API calls)
    whenever the planner never produced a fenced JSON block; the loop is now
    bounded by ``max_retries`` (new keyword argument, default 5, so existing
    callers are unaffected) and raises on exhaustion. The regex is also
    compiled once, outside the retry loop.

    Args:
        task_description: Natural-language task from the task list.
        max_retries: Planner calls to attempt before giving up.

    Returns:
        ``(app_name, package_name, new_task_description)``; the description
        falls back to the input when the planner omits a rewrite.

    Raises:
        RuntimeError: when no parseable fenced JSON reply is obtained.
        json.JSONDecodeError: when the fenced block is not valid JSON.
    """
    app_selection_prompt = app_selection_prompt_template.format(task_description=task_description)
    # Compile once; the pattern is loop-invariant.
    pattern = re.compile(r"```json\n(.*)\n```", re.DOTALL)

    match = None
    for _ in range(max_retries):
        response_str = planner_client.chat.completions.create(
            model="planner",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": app_selection_prompt},
                    ]
                }
            ]
        ).choices[0].message.content

        logging.info(f"应用选择响应: \n{response_str}")

        match = pattern.search(response_str)
        if match:
            break

    if match is None:
        # Fix: fail loudly instead of spinning forever on malformed replies.
        raise RuntimeError("planner did not return a ```json fenced response")

    response = json.loads(match.group(1))
    app_name = response.get("app_name")
    package_name = response.get("package_name")
    # Fall back to the original description when the planner omits a rewrite.
    new_task_description = response.get("task_description", task_description)

    return app_name, package_name, new_task_description
new_task_description = get_app_package_name(task_description) + + device.app_start(package_name) + print(f"Starting task '{new_task_description}' in app '{app_name}'") + task_in_app(app_name, task_description, new_task_description, device, data_dir, True) \ No newline at end of file diff --git a/runner/mobiagent/task.json b/runner/mobiagent/task.json new file mode 100644 index 0000000..818947a --- /dev/null +++ b/runner/mobiagent/task.json @@ -0,0 +1,27 @@ +[ + "在淘宝上搜索电动牙刷,选最畅销的那款", + "用淘宝搜一下苹果笔记本电脑", + "用淘宝查一下哪个蓝牙耳机最便宜", + "在淘宝上搜索男士休闲夹克", + "用淘宝查一下哪款空气净化器销量第一", + "用淘宝帮我找一下耐克运动鞋,然后放进购物车", + "在淘宝上将香奈儿口红加入购物车", + "用淘宝帮我搜一下小天鹅洗衣机,再放进购物车", + "用淘宝找一下荣耀Magic6手机,并加到购物车", + "淘宝查找iPad Pro,然后把它放进购物车", + "用淘宝把最便宜的蓝牙音箱加到购物车", + "淘宝将天猫自营的佳能单反相机加入购物车", + "淘宝请将销量最高的那款跑步机放进购物车", + "淘宝请将售价最高的那款奢侈品手表添加到购物车", + "在淘宝上,将销量最高的华为Mate 70加入购物车", + "用淘宝将L码的优衣库短袖T恤添加到购物车", + "用淘宝将M码的运动鞋加入购物车", + "在淘宝上,将12号的足球加入购物车", + "淘宝选中黑色的iPhone 14 Pro,然后加到购物车里", + "在淘宝上将1米的Anker USB-C数据线加入购物车", + "在淘宝上,将价格最低的64GB蓝色iPhone 15加入购物车", + "打开淘宝,找到大号的北欧风地毯,并把销量最高的那款放进购物车", + "用淘宝找到金色32GB优盘里价格最高的那一款并加购", + "在淘宝上,将销量最高的粉色AirPods保护壳加到购物车里", + "在淘宝上,将售价最便宜的2米Type-C快充线加入购物车" +] \ No newline at end of file diff --git a/utils/box_annotator.py b/utils/box_annotator.py new file mode 100644 index 0000000..82f7116 --- /dev/null +++ b/utils/box_annotator.py @@ -0,0 +1,262 @@ +from typing import List, Optional, Union, Tuple + +import cv2 +import numpy as np + +from supervision.detection.core import Detections +from supervision.draw.color import Color, ColorPalette + + +class BoxAnnotator: + """ + A class for drawing bounding boxes on an image using detections provided. 
+ + Attributes: + color (Union[Color, ColorPalette]): The color to draw the bounding box, + can be a single color or a color palette + thickness (int): The thickness of the bounding box lines, default is 2 + text_color (Color): The color of the text on the bounding box, default is white + text_scale (float): The scale of the text on the bounding box, default is 0.5 + text_thickness (int): The thickness of the text on the bounding box, + default is 1 + text_padding (int): The padding around the text on the bounding box, + default is 5 + + """ + + def __init__( + self, + color: Union[Color, ColorPalette] = ColorPalette.DEFAULT, + thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo + text_color: Color = Color.BLACK, + text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web + text_thickness: int = 2, #1, # 2 for demo + text_padding: int = 10, + avoid_overlap: bool = True, + ): + self.color: Union[Color, ColorPalette] = color + self.thickness: int = thickness + self.text_color: Color = text_color + self.text_scale: float = text_scale + self.text_thickness: int = text_thickness + self.text_padding: int = text_padding + self.avoid_overlap: bool = avoid_overlap + + def annotate( + self, + scene: np.ndarray, + detections: Detections, + labels: Optional[List[str]] = None, + skip_label: bool = False, + image_size: Optional[Tuple[int, int]] = None, + ) -> np.ndarray: + """ + Draws bounding boxes on the frame using the detections provided. + + Args: + scene (np.ndarray): The image on which the bounding boxes will be drawn + detections (Detections): The detections for which the + bounding boxes will be drawn + labels (Optional[List[str]]): An optional list of labels + corresponding to each detection. If `labels` are not provided, + corresponding `class_id` will be used as label. + skip_label (bool): Is set to `True`, skips bounding box label annotation. 
+ Returns: + np.ndarray: The image with the bounding boxes drawn on it + + Example: + ```python + import supervision as sv + + classes = ['person', ...] + image = ... + detections = sv.Detections(...) + + box_annotator = sv.BoxAnnotator() + labels = [ + f"{classes[class_id]} {confidence:0.2f}" + for _, _, confidence, class_id, _ in detections + ] + annotated_frame = box_annotator.annotate( + scene=image.copy(), + detections=detections, + labels=labels + ) + ``` + """ + font = cv2.FONT_HERSHEY_SIMPLEX + for i in range(len(detections)): + x1, y1, x2, y2 = detections.xyxy[i].astype(int) + class_id = ( + detections.class_id[i] if detections.class_id is not None else None + ) + idx = class_id if class_id is not None else i + color = ( + self.color.by_idx(idx) + if isinstance(self.color, ColorPalette) + else self.color + ) + cv2.rectangle( + img=scene, + pt1=(x1, y1), + pt2=(x2, y2), + color=color.as_bgr(), + thickness=self.thickness, + ) + if skip_label: + continue + + text = ( + f"{class_id}" + if (labels is None or len(detections) != len(labels)) + else labels[i] + ) + + text_width, text_height = cv2.getTextSize( + text=text, + fontFace=font, + fontScale=self.text_scale, + thickness=self.text_thickness, + )[0] + + if not self.avoid_overlap: + text_x = x1 + self.text_padding + text_y = y1 - self.text_padding + + text_background_x1 = x1 + text_background_y1 = y1 - 2 * self.text_padding - text_height + + text_background_x2 = x1 + 2 * self.text_padding + text_width + text_background_y2 = y1 + # text_x = x1 - self.text_padding - text_width + # text_y = y1 + self.text_padding + text_height + # text_background_x1 = x1 - 2 * self.text_padding - text_width + # text_background_y1 = y1 + # text_background_x2 = x1 + # text_background_y2 = y1 + 2 * self.text_padding + text_height + else: + text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, 
image_size) + + cv2.rectangle( + img=scene, + pt1=(text_background_x1, text_background_y1), + pt2=(text_background_x2, text_background_y2), + color=color.as_bgr(), + thickness=cv2.FILLED, + ) + # import pdb; pdb.set_trace() + box_color = color.as_rgb() + luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2] + text_color = (0,0,0) if luminance > 160 else (255,255,255) + cv2.putText( + img=scene, + text=text, + org=(text_x, text_y), + fontFace=font, + fontScale=self.text_scale, + # color=self.text_color.as_rgb(), + color=text_color, + thickness=self.text_thickness, + lineType=cv2.LINE_AA, + ) + return scene + + +def box_area(box): + return (box[2] - box[0]) * (box[3] - box[1]) + +def intersection_area(box1, box2): + x1 = max(box1[0], box2[0]) + y1 = max(box1[1], box2[1]) + x2 = min(box1[2], box2[2]) + y2 = min(box1[3], box2[3]) + return max(0, x2 - x1) * max(0, y2 - y1) + +def IoU(box1, box2, return_max=True): + intersection = intersection_area(box1, box2) + union = box_area(box1) + box_area(box2) - intersection + if box_area(box1) > 0 and box_area(box2) > 0: + ratio1 = intersection / box_area(box1) + ratio2 = intersection / box_area(box2) + else: + ratio1, ratio2 = 0, 0 + if return_max: + return max(intersection / union, ratio1, ratio2) + else: + return intersection / union + + +def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size): + """ check overlap of text and background detection box, and get_optimal_label_pos, + pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. 
outer right + Threshold: default to 0.3 + """ + + def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size): + is_overlap = False + for i in range(len(detections)): + detection = detections.xyxy[i].astype(int) + if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3: + is_overlap = True + break + # check if the text is out of the image + if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]: + is_overlap = True + return is_overlap + + # if pos == 'top left': + text_x = x1 + text_padding + text_y = y1 - text_padding + + text_background_x1 = x1 + text_background_y1 = y1 - 2 * text_padding - text_height + + text_background_x2 = x1 + 2 * text_padding + text_width + text_background_y2 = y1 + is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size) + if not is_overlap: + return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 + + # elif pos == 'outer left': + text_x = x1 - text_padding - text_width + text_y = y1 + text_padding + text_height + + text_background_x1 = x1 - 2 * text_padding - text_width + text_background_y1 = y1 + + text_background_x2 = x1 + text_background_y2 = y1 + 2 * text_padding + text_height + is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size) + if not is_overlap: + return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 + + + # elif pos == 'outer right': + text_x = x2 + text_padding + text_y = y1 + text_padding + text_height + + text_background_x1 = x2 + text_background_y1 = y1 + + text_background_x2 = x2 + 2 * text_padding + text_width + text_background_y2 = y1 + 2 * text_padding + text_height + + is_overlap = 
def load_prompt(md_name):
    """Load an app-selection prompt template from ``prompts/<md_name>``.

    The file is resolved relative to this module's directory, read as
    UTF-8, stripped of any ````markdown fences, and returned with
    surrounding whitespace removed.

    Args:
        md_name: File name of the markdown prompt inside ``prompts/``.

    Returns:
        The cleaned prompt text.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    template_path = os.path.join(base_dir, "..", "prompts", md_name)

    with open(template_path, "r", encoding="utf-8") as handle:
        raw = handle.read()

    cleaned = raw.replace("````markdown", "").replace("````", "")
    return cleaned.strip()
def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
    """Load a captioning model + processor pair for describing icon crops.

    Args:
        model_name: Model family selector, "blip2" or "florence2".
        model_name_or_path: HF hub id or local path of the model weights.
        device: Torch device string; defaults to "cuda" when available.

    Returns:
        dict with keys 'model' (moved to *device*) and 'processor'.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if model_name == "blip2":
        from transformers import Blip2Processor, Blip2ForConditionalGeneration
        # NOTE(review): the processor is always the 2.7b variant even when
        # model_name_or_path points at different weights — confirm intended.
        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        if device == 'cpu':
            # float32 on CPU: fp16 kernels are generally unavailable there.
            model = Blip2ForConditionalGeneration.from_pretrained(
                model_name_or_path, device_map=None, torch_dtype=torch.float32
            )
        else:
            model = Blip2ForConditionalGeneration.from_pretrained(
                model_name_or_path, device_map=None, torch_dtype=torch.float16
            ).to(device)
    elif model_name == "florence2":
        from transformers import AutoProcessor, AutoModelForCausalLM
        processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
        if device == 'cpu':
            model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
    # NOTE: any other model_name leaves `model` unbound and raises
    # UnboundLocalError here.
    return {'model': model.to(device), 'processor': processor}
@torch.inference_mode()
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=128):
    """Caption the non-OCR (icon) boxes with the given caption model.

    Args:
        filtered_boxes: (N, 4) boxes in relative xyxy coordinates.
        starting_idx: Index of the first non-OCR box; boxes before it are
            OCR text boxes and are skipped. Falsy (0/-1 treated as
            "caption everything" by the truthiness check below).
        image_source: HWC RGB image as a numpy array.
        caption_model_processor: dict with 'model' and 'processor'
            (see get_caption_model_processor).
        prompt: Optional text prompt; a model-family default is used
            when omitted.
        batch_size: Crops captioned per forward pass; ~128 crops take
            roughly 4 GB of GPU memory for the florence2 model.

    Returns:
        List of caption strings, one per successfully cropped box.
    """
    to_pil = ToPILImage()
    if starting_idx:
        non_ocr_boxes = filtered_boxes[starting_idx:]
    else:
        non_ocr_boxes = filtered_boxes
    croped_pil_image = []
    for i, coord in enumerate(non_ocr_boxes):
        try:
            xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
            ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
            cropped_image = image_source[ymin:ymax, xmin:xmax, :]
            cropped_image = cv2.resize(cropped_image, (64, 64))
            croped_pil_image.append(to_pil(cropped_image))
        except Exception:
            # Degenerate/out-of-range boxes produce empty crops that
            # cv2.resize rejects; skip them. (Was a bare `except:`, which
            # also swallowed KeyboardInterrupt/SystemExit.)
            continue

    model, processor = caption_model_processor['model'], caption_model_processor['processor']
    if not prompt:
        # Florence models expect an empty task prompt; BLIP-style models
        # get a caption-continuation prefix.
        prompt = "" if 'florence' in model.config.name_or_path else "The image shows"

    generated_texts = []
    device = model.device
    for i in range(0, len(croped_pil_image), batch_size):
        batch = croped_pil_image[i:i+batch_size]
        if model.device.type == 'cuda':
            inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
        else:
            inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
        if 'florence' in model.config.name_or_path:
            generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=20, num_beams=1, do_sample=False)
        else:
            generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        generated_texts.extend(gen.strip() for gen in generated_text)

    return generated_texts
def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
    """Suppress boxes that heavily overlap a smaller box.

    Of any overlapping pair only the smaller box survives. OCR boxes,
    when given, are always kept and additionally suppress detection
    boxes that overlap them without essentially containing them.

    Args:
        boxes: (N, 4) tensor of xyxy boxes.
        iou_threshold: Overlap ratio above which a box is suppressed.
        ocr_bbox: Optional list of xyxy OCR boxes.

    Returns:
        torch.Tensor of surviving boxes, OCR boxes first.
    """
    assert ocr_bbox is None or isinstance(ocr_bbox, List)

    def area(b):
        return (b[2] - b[0]) * (b[3] - b[1])

    def inter_area(a, b):
        left, top = max(a[0], b[0]), max(a[1], b[1])
        right, bottom = min(a[2], b[2]), min(a[3], b[3])
        return max(0, right - left) * max(0, bottom - top)

    def overlap_ratio(a, b):
        # Max of IoU and the two intersection/own-area fractions, so
        # near-containment also counts as heavy overlap.
        inter = inter_area(a, b)
        union = area(a) + area(b) - inter + 1e-6
        if area(a) > 0 and area(b) > 0:
            frac_a, frac_b = inter / area(a), inter / area(b)
        else:
            frac_a = frac_b = 0
        return max(inter / union, frac_a, frac_b)

    def contained(a, b):
        # a lies "inside" b when nearly all of a's area overlaps b.
        return inter_area(a, b) / area(a) > 0.95

    candidates = boxes.tolist()
    kept = list(ocr_bbox) if ocr_bbox else []
    for idx, cand in enumerate(candidates):
        # Drop cand when some *smaller* box overlaps it beyond the threshold.
        dominated = any(
            idx != other_idx
            and overlap_ratio(cand, other) > iou_threshold
            and area(cand) > area(other)
            for other_idx, other in enumerate(candidates)
        )
        if dominated:
            continue
        # Drop cand when it collides with an OCR box it does not contain.
        if ocr_bbox and any(
            overlap_ratio(cand, ocr) > iou_threshold and not contained(cand, ocr)
            for ocr in ocr_bbox
        ):
            continue
        kept.append(cand)
    return torch.tensor(kept)
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
             text_padding=5, text_thickness=2, thickness=3) -> Tuple[np.ndarray, dict]:
    """
    Annotate an image with bounding boxes and per-box labels.

    Parameters:
        image_source (np.ndarray): HWC image to annotate.
        boxes (torch.Tensor): Boxes in cxcywh format, relative coordinates.
        logits (torch.Tensor): Confidence scores (accepted but unused here).
        phrases (List[str]): One label per box.
        text_scale (float): cv2 font scale. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web

    Returns:
        Tuple of (annotated image, {label: xywh coordinates} mapping).
        (The previous annotation claimed a bare np.ndarray return.)
    """
    h, w, _ = image_source.shape
    boxes = boxes * torch.Tensor([w, h, w, h])
    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
    detections = sv.Detections(xyxy=xyxy)

    # Bug fix: labels previously iterated range(boxes.shape[0]) instead of
    # the provided phrases, silently replacing every caption with its index.
    # (Unchanged output for callers that pass numeric indices as phrases.)
    labels = [f"{phrase}" for phrase in phrases]

    box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding, text_thickness=text_thickness, thickness=thickness)  # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
    annotated_frame = image_source.copy()
    annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w, h))

    label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
    return annotated_frame, label_coordinates
def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9, prompt=None, scale_img=False, imgsz=None, batch_size=128):
    """Run Set-of-Mark labeling on a screenshot.

    Detects UI elements with YOLO, merges them with OCR boxes, optionally
    captions the icon boxes, and draws numbered boxes on the image.

    Args:
        image_source: File path (str) or PIL Image object.
        model: YOLO detection model (see get_yolo_model).
        BOX_TRESHOLD: YOLO confidence threshold (name kept for callers).
        output_coord_in_ratio: Return coordinates normalized to [0, 1].
        ocr_bbox: OCR boxes in pixel xyxy, parallel to ocr_text.
        text_scale / text_padding: Label drawing parameters.
        draw_bbox_config: Optional dict of annotate() kwargs overriding them.
        caption_model_processor: Captioning bundle for icon semantics.
        ocr_text: OCR strings parallel to ocr_bbox.
        use_local_semantics: When True, caption un-labeled icon boxes.
        iou_threshold: Overlap threshold for box de-duplication.
        prompt / batch_size: Forwarded to the icon captioner.
        scale_img / imgsz: Forwarded to predict_yolo.

    Returns:
        (base64-encoded annotated PNG, {label: coordinates}, element dicts)
    """
    if isinstance(image_source, str):
        image_source = Image.open(image_source)
    image_source = image_source.convert("RGB")  # for CLIP
    w, h = image_source.size
    if not imgsz:
        imgsz = (h, w)
    # YOLO proposals, then normalize to relative xyxy.
    xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
    xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
    image_source = np.asarray(image_source)
    phrases = [str(i) for i in range(len(phrases))]

    if ocr_bbox:
        ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
        ocr_bbox = ocr_bbox.tolist()
    else:
        print('no ocr bbox!!!')
        # Bug fix: was `ocr_bbox = None`, which made zip(ocr_bbox, ocr_text)
        # below raise TypeError whenever no OCR boxes were supplied.
        ocr_bbox = []

    ocr_bbox_elem = [{'type': 'text', 'bbox': box, 'interactivity': False, 'content': txt, 'source': 'box_ocr_content_ocr'}
                     for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
    xyxy_elem = [{'type': 'icon', 'bbox': box, 'interactivity': True, 'content': None}
                 for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
    filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)

    # Sort so boxes still needing captions ('content' is None) come last,
    # and remember where that suffix starts (-1 when every box has content).
    filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
    starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
    filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
    print('len(filtered_boxes):', len(filtered_boxes), starting_idx)

    # Caption the icon boxes with the local caption model, if requested.
    time1 = time.time()
    if use_local_semantics:
        caption_model = caption_model_processor['model']
        if 'phi3_v' in caption_model.config.model_type:
            parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
        else:
            parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt, batch_size=batch_size)
        ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
        icon_start = len(ocr_text)
        parsed_content_icon_ls = []
        # Fill the None contents with captions, in order.
        for i, box in enumerate(filtered_boxes_elem):
            if box['content'] is None:
                box['content'] = parsed_content_icon.pop(0)
        # NOTE(review): captions were popped above, so this loop only emits
        # leftovers when more captions exist than un-captioned boxes —
        # confirm that is the intended behavior.
        for i, txt in enumerate(parsed_content_icon):
            parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
        parsed_content_merged = ocr_text + parsed_content_icon_ls
    else:
        ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
        parsed_content_merged = ocr_text
    print('time to get parsed content:', time.time()-time1)

    filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")

    phrases = [i for i in range(len(filtered_boxes))]

    # Draw the numbered boxes.
    if draw_bbox_config:
        annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
    else:
        annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)

    pil_img = Image.fromarray(annotated_frame)
    buffered = io.BytesIO()
    pil_img.save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
    if output_coord_in_ratio:
        label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
    assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]

    return encoded_image, label_coordinates, filtered_boxes_elem
def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=True):
    """OCR an image with PaddleOCR and return the texts plus their boxes.

    Args:
        image_source: File path (str) or PIL Image.
        display_img: When True, draw the boxes with matplotlib instead of
            returning them in output_bb_format.
        output_bb_format: 'xywh' or 'xyxy' box format (when not displaying).
        goal_filtering: Passed through unchanged in the return value.
        easyocr_args: Optional dict; only 'text_threshold' is honored.
        use_paddleocr: Must be True; the EasyOCR branch is disabled.

    Returns:
        ((texts, boxes), goal_filtering)

    Raises:
        NotImplementedError: When use_paddleocr is False.
    """
    if not use_paddleocr:
        # Bug fix: the EasyOCR branch is commented out, so this path used
        # to crash later with NameError on undefined `coord`/`text`.
        # Fail loudly and explicitly instead.
        raise NotImplementedError("Only PaddleOCR is supported (use_paddleocr=True)")
    if isinstance(image_source, str):
        image_source = Image.open(image_source)
    if image_source.mode == 'RGBA':
        # Convert RGBA to RGB to avoid alpha channel issues
        image_source = image_source.convert('RGB')
    image_np = np.array(image_source)
    w, h = image_source.size
    if easyocr_args is None:
        text_threshold = 0.5
    else:
        text_threshold = easyocr_args['text_threshold']
    result = paddle_ocr.ocr(image_np, cls=False)[0]
    if result is None:
        # Bug fix: PaddleOCR returns None (not []) when it detects no
        # text, which made the comprehensions below raise TypeError.
        result = []
    coord = [item[0] for item in result if item[1][1] > text_threshold]
    text = [item[1][0] for item in result if item[1][1] > text_threshold]
    if display_img:
        opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        bb = []
        for item in coord:
            x, y, a, b = get_xywh(item)
            bb.append((x, y, a, b))
            cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
        # matplotlib expects RGB
        plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB))
    else:
        if output_bb_format == 'xywh':
            bb = [get_xywh(item) for item in coord]
        elif output_bb_format == 'xyxy':
            bb = [get_xyxy(item) for item in coord]
    return (text, bb), goal_filtering
def find_clicked_element(bounds_list, click_x, click_y):
    """Return the smallest bounding box containing the click point.

    Args:
        bounds_list: Iterable of [left, top, right, bottom] boxes.
        click_x: Click x coordinate, same pixel space as the boxes.
        click_y: Click y coordinate.

    Returns:
        The smallest-area box containing (click_x, click_y), or None
        when no box contains the point. Ties keep the earliest box.
    """
    hits = [
        box for box in bounds_list
        if box[0] <= click_x <= box[2] and box[1] <= click_y <= box[3]
    ]
    if not hits:
        return None
    return min(hits, key=lambda box: (box[2] - box[0]) * (box[3] - box[1]))
"""解析bounds字符串,返回(left, top, right, bottom)""" + if not bounds_str: + return None + + # 使用正则表达式提取坐标 + match = re.match(r'\[(\d+),(\d+)\]\[(\d+),(\d+)\]', bounds_str) + if match: + left, top, right, bottom = map(int, match.groups()) + return [left, top, right, bottom] + return None + +def is_point_in_bounds(x, y, bounds): + """检查点(x,y)是否在bounds范围内""" + if not bounds: + return False + + [left, top, right, bottom] = bounds + return left <= x <= right and top <= y <= bottom + +def extract_all_bounds(hierarchy_xml, need_clickable=False): + """从hierarchy.xml中提取所有bounds""" + try: + root = ET.fromstring(hierarchy_xml) + bounds_set = set() + + # 递归遍历所有节点 + def traverse_node(node): + clickable = node.get('clickable', 'false') + bounds_str = node.get('bounds', '') + if bounds_str and (need_clickable is False or clickable == 'true'): + bounds = parse_bounds(bounds_str) + if bounds: + # 将列表转换为元组添加到集合中,避免重复 + bounds_set.add(tuple(bounds)) + + # 递归处理子节点 + for child in node: + traverse_node(child) + + traverse_node(root) + # 将集合转换回列表形式返回 + bounds_list = [list(bounds) for bounds in bounds_set] + return bounds_list + + except Exception as e: + print(f"解析层次结构时出错: {str(e)}") + return [] + +def find_clicked_element(hierarchy_xml, click_x, click_y): + bounds_list = extract_all_bounds(hierarchy_xml, need_clickable=True) + + smallest_bounds = None + smallest_area = float('inf') + + for bounds in bounds_list: + if is_point_in_bounds(click_x, click_y, bounds): + left, top, right, bottom = bounds + area = (right - left) * (bottom - top) + if area < smallest_area: + smallest_area = area + smallest_bounds = bounds + + return smallest_bounds \ No newline at end of file