diff --git a/.gitignore b/.gitignore
index c89af90d7..693fcadb1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -218,3 +218,7 @@ agentlightning/dashboard/**/*.svg
 # Docker data
 docker/data/
+
+# AGL simulation
+agl_envs/
+wandb/
diff --git a/contrib/agentlightning/contrib/adapter/triplet_group.py b/contrib/agentlightning/contrib/adapter/triplet_group.py
index 864c0bb39..969edce92 100644
--- a/contrib/agentlightning/contrib/adapter/triplet_group.py
+++ b/contrib/agentlightning/contrib/adapter/triplet_group.py
@@ -111,7 +111,10 @@ def message(span: Optional[Span]) -> Optional[str]:
     for group in span_groups.values():
         call_span = group.get("call_span")
-        if not token_ids(call_span, "prompt_token_ids") and not token_ids(call_span, "response_token_ids"):
+        if (
+            not token_ids(call_span, "prompt_token_ids")
+            and not token_ids(call_span, "response_token_ids")
+        ):
             continue
         object_span = group.get("object_span")
diff --git a/contrib/agentlightning/contrib/agent/empo2_agent.py b/contrib/agentlightning/contrib/agent/empo2_agent.py
new file mode 100644
index 000000000..f55b92dd1
--- /dev/null
+++ b/contrib/agentlightning/contrib/agent/empo2_agent.py
@@ -0,0 +1,259 @@
+import copy
+import numpy as np
+import requests
+import logging
+from typing import Any, Dict
+
+from add_instruction import add_chat_instruction, add_chat_tips, add_chat_all_tips
+from agentlightning import (
+    LLM,
+    NamedResources,
+    Rollout,
+    configure_logger,
+    emit_reward,
+    operation
+)
+from agentlightning.utils.otel import make_link_attributes
+
+from agl_envs import make_env_manager
+from contrib.recipes.envs.prompt_builder import HistoryPromptBuilder
+
+from contrib.agentlightning.contrib.agent.env_agent import EnvAgent
+
+configure_logger()
+logger = configure_logger(name=__name__, level=logging.ERROR)
+
+def do_compress(text):
+    url = "http://127.0.0.1:8000/key_cal/"
+    headers = {"Content-Type": "application/json"}  # explicitly specify the JSON content type
+    data = {"text": text}
+    response = requests.post(url, json=data, headers=headers)  # use the json= parameter to send the payload
+    return response.json()
+
+url_mem = "http://127.0.0.1:8001/mem/"
+
+def retrieve_memory(idx, key):
+    response = requests.post(url_mem, json={
+        "key": key,
+        "idx": idx
+    })
+    count, data = response.json()
+    return count, data
+
+def reset_memory(mem_list_num):
+    requests.post(url_mem, json={
+        "key": [],
+        "idx": mem_list_num,  # number of memory slots to initialize
+        "content": "Reset"
+    })
+
+def add_memory(idx, key, content, score):
+    requests.post(url_mem, json={
+        "key": key,
+        "idx": idx,
+        "content": content,
+        "score": score
+    })
+
+def gather_chats(prompt):
+    chat_list = []
+    for item in prompt:
+        role = item.type
+        content = item.content
+        if "System" in role:
+            continue
+        elif "User" in role:
+            role = "user"
+        else:
+            role = "assistant"
+        chat_list.append(f"{role}: {content}")
+    text = " ".join(chat_list)
+    return text
+
+class EMPO2Agent(EnvAgent):
+    def __init__(self, config, trained_agents: str | None = None) -> None:
+        super().__init__(config=config, trained_agents=trained_agents)
+
+    def _get_tip_prompt(self, prompt, tips):
+        prompt_type = self.config.captioner.prompt_type
+
+        if prompt_type == "chat":
+            return add_chat_tips(prompt, tips)
+        else:
+            raise ValueError(f"Unsupported prompt_type '{prompt_type}' for _get_tip_prompt (expected 'chat')")
+
+    def _get_all_tip_prompt(self, prompt, tip_list):
+        prompt_type = self.config.captioner.prompt_type
+        if prompt_type == "chat":
+            return add_chat_all_tips(prompt, tip_list)
+        else:
+            raise ValueError(f"Unsupported prompt_type '{prompt_type}' for _get_all_tip_prompt 
(expected 'chat')") + + def _get_tip_generation_prompt(self, prompt): + return add_chat_instruction(prompt, "tip") + + async def rollout_async( + self, + task: Dict[str, Any], + resources: NamedResources, + rollout: Rollout, + ) -> float | None: + rollout_id = rollout.rollout_id + logger.info(f"[Rollout {rollout_id}] Task: {task}") + + reward_scale = float(self.config["reawrd_scale"]) + + # Setup LLM + agent + llm: LLM = resources.get("main_llm") + print("Training with model:", llm.model, "on endpoint:", llm.endpoint) + self.agent = self._build_agent(llm, 1.0 if rollout.mode == "train" else 0.4) + + if rollout.mode == "train": + train_mode = task["train_mode"] + global_steps = task["global_steps"] + else: + train_mode = "on-policy" + + if rollout.mode == "train" and (train_mode == "off-policy" or train_mode == "on-policy-with-tips"): + use_tips = True + else: + use_tips = False + + variation_idx = task["variation_idx"] + + try: + # Setup environment + prompt_builder = HistoryPromptBuilder(max_history=self.config.captioner.max_history, prompt_type=self.config.captioner.prompt_type) + + self.env = make_env_manager(self.config.env_name, task, self.config) + env_obs, infos, available_actions_hint = self.env.reset() + + prompt_builder.init(self.env) + prompt_builder.update_observation(env_obs) + # prompt_builder.update_admissible_actions(available_actions_hint) + + prompt = prompt_builder.get_prompt() + + episode_reward, done = 0.0, False + + pure_prompt_for_mem = [] + history_actions_for_mem = [] + tip_list = [] + + step_count = 0 + while not done: + if use_tips: + text = gather_chats(prompt) + key = np.array(do_compress(text)['key']).reshape(-1, ).tolist() + count, mem_list = retrieve_memory(variation_idx, key) + else: + count, mem_list = 0, [] + + ret_tips, intrinsic_reward = "", 0.0 + + if use_tips: + if count > 0: + ret_tips = "Here are some memories you collected in your previous exploration:\n" + for mem in mem_list: + ret_tips += mem+"\n" + + tip_list.append(ret_tips) + intrinsic_reward = 1 / (count+1) + else: + tip_list.append("") + intrinsic_reward = 1 + + try: + if count > 0: + tip_prompt = self._get_all_tip_prompt(prompt, tip_list) + instructed_prompt = self._get_instructed_prompt(tip_prompt, sep="") + else: + instructed_prompt = self._get_instructed_prompt(prompt) + + # Main agent step + with operation(step_count=step_count): + result = await self.agent._model_client.create(instructed_prompt) + output = result.content + logger.info(f"[LLM output]: {output}") + + except Exception as e: + logger.error(f"[Rollout {rollout_id}] Error during training rollout: {e}", exc_info=True) + break + + # Environment step + pure_prompt_for_mem.append([copy.deepcopy(prompt), None]) + env_obs, executed_action,is_valid, step_reward, terminated, truncated, info, available_actions_hint = self.env.step( + output, use_reasoning=self.config.captioner.type == "cot" + ) + history_actions_for_mem.append(executed_action) + + prompt_builder.update_step_count() + prompt_builder.update_action(executed_action) + prompt_builder.update_observation(env_obs) + # prompt_builder.update_admissible_actions(available_actions_hint) + + prompt = prompt_builder.get_prompt() + + if rollout.mode == "train": + step_reward = reward_scale * step_reward + + emit_reward( + { + "extrinsic_reward": step_reward, + "intrinsic_reward": intrinsic_reward, + }, + primary_key="extrinsic_reward", + attributes=make_link_attributes({"step_count": str(step_count)}), + ) + + episode_reward += float(step_reward) + done = np.logical_or(terminated, 
truncated) + + step_count += 1 + + if ( + rollout.mode == "train" + and self.config.captioner.prompt_type == "chat" + and self.config.save_rollout + ): + filename = f"empo2_rollouts/variant_{variation_idx}/step_{global_steps}/{rollout_id}_{round(episode_reward, 1)}_use_tip_{use_tips}.json" + if use_tips: + _rollout = self._get_all_tip_obs(obs, tip_list) + else: + _rollout = obs + self._save_chat_rollout(_rollout, filename) + + if rollout.mode == "train": + prompt_builder.prompt_type = "chat" + prompt_builder.max_history = -1 + prompt = prompt_builder.get_prompt() + prompt.pop() + + tip_generation_prompt = self._get_tip_generation_prompt(prompt) + + self.agent._model_client.max_tokens = 128 + result = await self.agent._model_client.create(tip_generation_prompt) + tips = result.content + logger.info(f"Tips: {tips}") + + #! Fill the ret and tip + for i in range(len(pure_prompt_for_mem)): + max_score = 100 * reward_scale + pure_prompt_for_mem[i][1] = tips + f'; At that timestep, the specific action your took was {history_actions_for_mem[i]}; Eventually you got the score {round(episode_reward, 1)}/{int(max_score)}.' + + #! Generate the tips and save the mem + for i in range(len(pure_prompt_for_mem)): + text = gather_chats(pure_prompt_for_mem[i][0]) + key = np.array(do_compress(text)['key']).reshape(-1, ).tolist() + content = pure_prompt_for_mem[i][1] + score = episode_reward + add_memory(variation_idx, key, content, round(score, 1)) + + if self.config.use_success_rate: + return self.env.get_success_score() * reward_scale + else: + return episode_reward + + finally: + if self.env is not None: + self.env.close() \ No newline at end of file diff --git a/contrib/agentlightning/contrib/algorithm/env_verl/core_empo2.py b/contrib/agentlightning/contrib/algorithm/env_verl/core_empo2.py new file mode 100644 index 000000000..94c9f5b3b --- /dev/null +++ b/contrib/agentlightning/contrib/algorithm/env_verl/core_empo2.py @@ -0,0 +1,65 @@ +import torch +from typing import List, Any + +def is_sublist(sub, full): + n, m = len(sub), len(full) + return any(full[i:i+n] == sub for i in range(m - n + 1)) + +# Function to remove segments of a list between a start pattern and an end pattern +def remove_pattern_ranges(seq: List[Any], + start_pat: List[Any], + end_pat: List[Any]) -> List[Any]: + """Remove every [start_pat ... 
end_pat] slice (inclusive) from seq.""" + + out: List[Any] = [] + i = 0 + n = len(seq) + ls, le = len(start_pat), len(end_pat) + + while i < n: + # Check if the start pattern matches at the current position + if i + ls <= n and seq[i:i+ls] == start_pat: + # Look for the first occurrence of the end pattern after the start pattern + j = i + ls + found_end = -1 + while j + le <= n: + if seq[j:j+le] == end_pat: + found_end = j + break # Stop when the end pattern is found + j += 1 + + # If the end pattern is found, skip the whole segment from start to end + if found_end != -1: + i = found_end + le # Move the index past the end pattern + continue # Skip the current iteration and go to the next + else: + # If the end pattern is not found, keep the current element and move one step forward + out.append(seq[i]) + i += 1 + else: + # If the start pattern is not found, just append the current element + out.append(seq[i]) + i += 1 + + # Return the filtered list with the start-end pattern segments removed + return out + +def low_prob_token_masking(batch): + response_mask = batch.batch["response_mask"] # [N, T] + old_log_prob = batch.batch["old_log_probs"] # [N, T] + # advantages = batch.batch["advantages"] # [N, T] + + masked_old_log_prob = old_log_prob.masked_fill(response_mask == 0, 1e9) + min_values, _ = torch.min(masked_old_log_prob, dim=1) # [N] + + mask = min_values < -5 # [N] + + combined_mask = mask.unsqueeze(1) & (response_mask == 1) + + # advantages masking + response_mask = response_mask.masked_fill(combined_mask, 0) + batch.batch["response_mask"] = response_mask + + print(f"Number of tokens masked: {combined_mask.sum().item()}") + + return batch diff --git a/contrib/agentlightning/contrib/algorithm/env_verl/daemon.py b/contrib/agentlightning/contrib/algorithm/env_verl/daemon.py index 10c947c16..92b707158 100644 --- a/contrib/agentlightning/contrib/algorithm/env_verl/daemon.py +++ b/contrib/agentlightning/contrib/algorithm/env_verl/daemon.py @@ -7,6 +7,7 @@ import threading import time import uuid +import copy from collections import defaultdict from collections.abc import Mapping from typing import Any, Dict, List, Literal, Optional, Tuple, cast @@ -19,12 +20,13 @@ from verl import DataProto from agentlightning import LLM, AgentLightningServer, NamedResources, RolloutLegacy -from agentlightning.adapter.triplet import TraceToTripletBase from agentlightning.llm_proxy import LLMProxy, ModelConfig -from agentlightning.reward import find_final_reward from agentlightning.store.base import LightningStore from agentlightning.types import EnqueueRolloutRequest, Rollout, RolloutConfig, Task +from agentlightning.reward import find_final_reward + from contrib.agentlightning.contrib.adapter.triplet_group import TracerTraceToTripletGroup +import contrib.agentlightning.contrib.algorithm.env_verl.core_empo2 as core_empo2 __all__ = [ "AgentModeDaemon", @@ -145,7 +147,7 @@ def __init__( mode: Literal["v0", "v1"] = "v1", llm_proxy: LLMProxy | None = None, store: LightningStore | None = None, - adapter: TraceToTripletBase | None = None, + adapter: TracerTraceToTripletGroup | None = None, ): self.mode = mode self.llm_timeout_seconds = llm_timeout_seconds @@ -169,7 +171,7 @@ def __init__( ) else: # Reuse the existing LLM proxy (probably configured by user) - self.llm_proxy = llm_proxy + self.llm_proxy = llm_proxy # if adapter is None: # self.adapter = TracerTraceToTripletGroup() @@ -658,9 +660,11 @@ def get_train_data_batch( max_prompt_length: int, max_response_length: int, device: torch.device, + 
max_train_length: int = -1, use_final_reward_as_step_reward: bool = True, use_intrinsic_reward: bool = False, is_gigpo: bool = False, + empo2_train_mode: bool = False ): """ Processes completed rollouts to generate a training data batch. @@ -700,7 +704,7 @@ def get_train_data_batch( "response_ids": t.response.get("token_ids", []), "step_reward": t.reward, "step_intrinsic_reward": t.metadata.get("intrinsic_reward", 0.0), - "message": t.metadata.get("message", ""), + "message": t.metadata.get("message", "") } trace_list.append(trace_dict) @@ -741,13 +745,21 @@ def get_train_data_batch( for rollout_id, sample_info in finished_id_to_sample_info.items(): for turn_index, trace in enumerate(sample_info["trace_list"]): + prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"] + + if max_train_length > -1 and len(prompt_ids) + len(response_ids) > max_train_length: + continue final_reward_list.append(sample_info["final_reward"]) step_reward_list.append(trace["step_reward"]) step_intrinsic_reward_list.append(trace["step_intrinsic_reward"]) message_list.append(trace["message"]) - prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"] + if empo2_train_mode == "off-policy": + START_PATTERN = self.tokenizer.encode("") + END_PATTERN = self.tokenizer.encode("\n\n") + if core_empo2.is_sublist(START_PATTERN, prompt_ids): + prompt_ids = core_empo2.remove_pattern_ranges(prompt_ids, START_PATTERN, END_PATTERN) # Mark samples with prompts exceeding max_prompt_length to be dropped later if len(prompt_ids) > max_prompt_length: @@ -787,6 +799,7 @@ def get_train_data_batch( batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1) attention_mask = torch.cat([input_attention_mask, response_attention_mask], dim=-1) position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0) + is_drop_mask = torch.BoolTensor(is_drop_list).to(device) if use_final_reward_as_step_reward: scores = torch.tensor(final_reward_list, dtype=torch.float32).to(device) diff --git a/contrib/agentlightning/contrib/algorithm/env_verl/trainer.py b/contrib/agentlightning/contrib/algorithm/env_verl/trainer.py index 3dd2458f4..23c78db73 100644 --- a/contrib/agentlightning/contrib/algorithm/env_verl/trainer.py +++ b/contrib/agentlightning/contrib/algorithm/env_verl/trainer.py @@ -39,6 +39,8 @@ from agentlightning.store.base import LightningStore from .daemon import EnvAgentModeDaemon + +import contrib.agentlightning.contrib.algorithm.env_verl.core_empo2 as core_empo2 __all__ = [ "EnvAgentLightningTrainer", @@ -250,6 +252,21 @@ def _train_step(self, batch_dict: dict) -> dict: # generate a batch with _timer("gen", timing_raw): self.async_rollout_manager.wake_up() + + num_problems = self.config.data.train_batch_size + gen_batch.non_tensor_batch["global_steps"] = [self.global_steps for _ in range(num_problems)] + + if hasattr(self.config, 'tips') and self.config.tips.use_tips: + touzi = random.random() + if touzi < 0.17: + self.empo2_train_mode = "off-policy" # Update with Tips and give them to the pure_chats + elif touzi < 0.25: + self.empo2_train_mode = "on-policy-with-tips" + else: + self.empo2_train_mode = "on-policy" # Normal Update, No Tips + + gen_batch.non_tensor_batch["train_mode"] = [self.empo2_train_mode for _ in range(num_problems)] + self.agent_mode_daemon.set_up_data_and_server( gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses ) @@ -257,9 +274,11 @@ def _train_step(self, batch_dict: dict) -> dict: batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch( 
max_prompt_length=self.config.data.max_prompt_length, max_response_length=self.config.data.max_response_length, + max_train_length=getattr(self.config.data, 'max_train_length', -1), device=gen_batch.batch["fake_ids"].device, use_final_reward_as_step_reward=self.config.algorithm.use_final_reward_as_step_reward, use_intrinsic_reward=self.config.algorithm.use_intrinsic_reward, + empo2_train_mode=getattr(self, "empo2_train_mode", None), ) metrics.update(agent_metrics) self.agent_mode_daemon.clear_data_and_server() @@ -340,9 +359,7 @@ def _train_step(self, batch_dict: dict) -> dict: metrics.update(kl_metrics) else: if self.config.algorithm.use_intrinsic_reward: - batch.batch["token_level_rewards"] = ( - batch.batch["token_level_scores"] + batch.batch["token_level_intrinsic_rewards"] - ) # (bs, seq_len) + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + batch.batch["token_level_intrinsic_rewards"] # (bs, seq_len) else: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] @@ -362,6 +379,9 @@ def _train_step(self, batch_dict: dict) -> dict: config=self.config.algorithm, ) + if hasattr(self.config, 'tips') and self.config.tips.use_tips: + batch = core_empo2.low_prob_token_masking(batch) + # Calculate the metrics before processing. Refer to the comments of function `compute_data_metrics` for details. metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_before_processing")) @@ -500,6 +520,14 @@ def fit(self): # train step metrics = self._train_step(batch_dict) + if hasattr(self.config, 'tips') and self.config.tips.use_tips: + mode_map = { + "off-policy": 0, + "on-policy-with-tips": 1, + "on-policy": 2, + } + metrics["empo2/train_mode"] = mode_map.get(self.empo2_train_mode) + # validate if ( self.val_reward_fn is not None diff --git a/contrib/recipes/envs/README.md b/contrib/recipes/envs/README.md index 105190d5e..44e9b6404 100644 --- a/contrib/recipes/envs/README.md +++ b/contrib/recipes/envs/README.md @@ -80,15 +80,28 @@ We follow the single-mode prompt for ALFWorld from [verl-agent](https://github.c   -## Run RL Training (GRPO) +## Run RL Training + +### GRPO ```bash # Run alfworld -python3 train_env_agent.py --algorithm grpo --env alfworld +python3 train_env_agent.py --algorithm grpo_qwen_1.5b_instruct --env alfworld # Run scienceworld single task task_num 0 -python3 train_env_agent.py --algorithm grpo --env scienceworld --task_num 0 +python3 train_env_agent.py --algorithm grpo_qwen_1.5b_instruct --env scienceworld --task_num 0 # Run scienceworld multi-task -python3 train_env_agent.py --algorithm grpo --env scienceworld --task_num -1 +python3 train_env_agent.py --algorithm grpo_qwen_1.5b_instruct --env scienceworld --task_num -1 ``` + +### EMPO² Integration + +We integrate **EMPO²** (*Memory-Augmented LLM Agent via Online Self-Distillation*, ICLR 2026) [[paper]](https://openreview.net/forum?id=UOzxviKVFO) into our framework. EMPO² leverages a memory-augmented mechanism combined with online self-distillation to enhance LLM agent performance. In our experiments, EMPO² consistently outperforms GRPO, demonstrating stronger learning efficiency. 
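+
+Under the hood, `train_env_agent.py` starts two small local services before an EMPO² run: an embedding server (`empo2_server/server_bert.py`, port 8000) that turns chat text into a key vector, and a memory server (`empo2_server/server_mem.py`, port 8001) that stores one tip list per task variation and returns up to ten of the highest-scoring entries whose keys have cosine similarity above 0.5. The sketch below mirrors the helpers in `empo2_agent.py` (`do_compress`, `retrieve_memory`, `add_memory`); the names `embed`, `fetch_tips`, and `store_tip` are illustrative and not part of the codebase.
+
+```python
+import numpy as np
+import requests
+
+EMBED_URL = "http://127.0.0.1:8000/key_cal/"  # served by empo2_server/server_bert.py
+MEM_URL = "http://127.0.0.1:8001/mem/"        # served by empo2_server/server_mem.py
+
+def embed(text: str) -> list:
+    # server_bert.py encodes the text with a SentenceTransformer and returns {"key": [...]}
+    resp = requests.post(EMBED_URL, json={"text": text})
+    return np.array(resp.json()["key"]).reshape(-1).tolist()
+
+def fetch_tips(variation_idx: int, text: str):
+    # Without a "content" field, server_mem.py treats the request as a query and
+    # responds with (count, [tip strings]) for similar memories in this slot.
+    count, tips = requests.post(MEM_URL, json={"key": embed(text), "idx": variation_idx}).json()
+    return count, tips
+
+def store_tip(variation_idx: int, text: str, tip: str, score: float):
+    # With "content" set, server_mem.py deduplicates and appends the tip (each slot is capped at 1000 entries).
+    requests.post(MEM_URL, json={"key": embed(text), "idx": variation_idx, "content": tip, "score": score})
+```
+
+During a rollout the agent attaches the retrieved tips to the chat prompt, and at the end of the trajectory it distills the episode into a one-sentence tip that is written back to the memory server together with the episode score.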
+
+```bash
+# Run scienceworld single task task_num 25
+python3 train_env_agent.py --algorithm empo2_qwen_7b_instruct --env scienceworld2 --task_num 25
+```
+
+![agl_empo2_25](./assets/agl_empo2_25.png)
diff --git a/contrib/recipes/envs/add_instruction.py b/contrib/recipes/envs/add_instruction.py
index 627500153..c5b1c3bdb 100644
--- a/contrib/recipes/envs/add_instruction.py
+++ b/contrib/recipes/envs/add_instruction.py
@@ -11,22 +11,31 @@
 """.strip()
 
 NAIVE_INSTRUCTION = """
+You could try to explore different actions, especially when you are not sure what the best action is for your current observation.
 Please response with only one line with one sentence, following the possible action format shown above. No extra words are allowed.
 """.strip()
 
+TIP_INSTRUCTION = """
+Thank you for playing.
+You have now finished a trajectory and collected some information, whether meaningless or valuable, from your interactions with the environment.
+Please summarize the trajectory, what information you gained from it, and how far it is from fully completing the task.
+Please respond with only one sentence on a single line, and do not include any extra words.
+Your sentence should be less than 100 words.
+""".strip()
+
 # Mapping for instruction text types
 INSTRUCTION_MAP = {
     "cot": COT_INSTRUCTION,
     "naive": NAIVE_INSTRUCTION,
+    "tip": TIP_INSTRUCTION,
 }
 
-
 def _get_instruction(type: str, env_name: str = None):
     """
     Retrieve an instruction string from INSTRUCTION_MAP based on the given type.
 
     Args:
-        type (str): Instruction type key (e.g., "cot", "naive", "critic", "tip").
+        type (str): Instruction type key (e.g., "cot", "naive", "tip").
         env_name (str, optional): Currently unused. Reserved for future
             environment-specific instruction handling.
 
@@ -60,11 +69,18 @@ def add_chat_instruction(prompt, type: str, sep: str = "\n\n", env_name: str = N
     Returns:
         list: A new prompt list with the instruction appended to the last message.
""" - new_prompt = copy.deepcopy(prompt) - instruction = _get_instruction(type, env_name) - new_prompt[-1].content += sep + instruction + if type == "tip": + new_prompt = copy.deepcopy(prompt) + tip_instruction = _get_instruction(type, env_name) + new_prompt.append(UserMessage(source="user", content=tip_instruction)) - return new_prompt + return new_prompt + else: + new_prompt = copy.deepcopy(prompt) + instruction = _get_instruction(type, env_name) + new_prompt[-1].content += sep + instruction + + return new_prompt def add_single_instruction(prompt, type: str, sep: str = "\n\n", env_name: str = None): @@ -99,3 +115,22 @@ def add_single_instruction(prompt, type: str, sep: str = "\n\n", env_name: str = return new_prompt else: raise TypeError("Prompt must be a string or a list of strings") + +def add_chat_tips(prompt, tips): + new_prompt = copy.deepcopy(prompt) + new_prompt[-1].content += f"\n\n {tips}\n\n\n" + return new_prompt + +def add_chat_all_tips(prompt, tip_list): + new_prompt = copy.deepcopy(prompt) + tips_iter = iter(tip_list) + + for item in new_prompt: + if "User" in item.type: + tip = next(tips_iter, None) + if tip is None: + break + if not tip == "": + item.content += f"\n\n {tip}\n\n\n" + + return new_prompt diff --git a/contrib/recipes/envs/assets/agl_empo2_25.png b/contrib/recipes/envs/assets/agl_empo2_25.png new file mode 100644 index 000000000..6233ac16e Binary files /dev/null and b/contrib/recipes/envs/assets/agl_empo2_25.png differ diff --git a/contrib/recipes/envs/assets/prompt_type.png b/contrib/recipes/envs/assets/prompt_type.png index da6349ddc..98c0ed0f6 100644 Binary files a/contrib/recipes/envs/assets/prompt_type.png and b/contrib/recipes/envs/assets/prompt_type.png differ diff --git a/contrib/recipes/envs/clean.sh b/contrib/recipes/envs/clean.sh new file mode 100644 index 000000000..c4721425f --- /dev/null +++ b/contrib/recipes/envs/clean.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -e + +echo "Stopping AgentLightning and simulation_agent..." +pkill -f AgentLightning || true +pkill -f simulation_agent || true + +echo "Stopping Ray cluster..." +ray stop + +echo "Killing VLLM::EngineCore processes..." +ps aux | grep VLLM::EngineCore | grep -v grep | awk '{print $2}' | xargs --no-run-if-empty kill -9 + +echo "✅ Cleanup complete." 
\ No newline at end of file diff --git a/contrib/recipes/envs/config_env/scienceworld2.yaml b/contrib/recipes/envs/config_env/scienceworld2.yaml new file mode 100644 index 000000000..cded2fe6b --- /dev/null +++ b/contrib/recipes/envs/config_env/scienceworld2.yaml @@ -0,0 +1,16 @@ +env_name: scienceworld # scienceworld, babyai, alfworld +seed: 0 +format_penalty: 0.0 +binary_reward: False +save_rollout: False +log_env_obs: False # True for GiGPO +reawrd_scale: 1.0 +use_success_rate: False + +# only for scienceworld +use_action_correction: True + +captioner: + type: naive # naive or cot + prompt_type: chat # chat or single + max_history: -1 \ No newline at end of file diff --git a/contrib/recipes/envs/config_verl/alfworld/grpo_qwen_1.5b_instruct.yaml b/contrib/recipes/envs/config_verl/alfworld/grpo_qwen_1.5b_instruct.yaml new file mode 100644 index 000000000..3fed97e60 --- /dev/null +++ b/contrib/recipes/envs/config_verl/alfworld/grpo_qwen_1.5b_instruct.yaml @@ -0,0 +1,86 @@ +# ========================== +# Variable definitions +# ========================== +variables: + NUM_GPUS: 2 + MINI_BATCH_SIZE: 32 + PER_GPU_BATCH_SIZE: 16 + TENSOR_MODEL_PARALLEL_SIZE: 2 + NUM_ROLLOUTS: 8 + BASE_MODEL: Qwen/Qwen2.5-1.5B-Instruct + PROJECT_NAME: AGL-Simulation-ALFWorld + TRIAL: ${oc.env:TRIAL,0} + EXPERIMENT_NAME: grpo-alfworld-${variables.TRIAL} + DATA_DIR: agl_envs/task_data/alfworld + +# ========================== +# Main Config +# ========================== +agentlightning: + port: 9999 + +algorithm: + adv_estimator: grpo + use_kl_in_reward: false + use_final_reward_as_step_reward: true + use_intrinsic_reward: true + +data: + train_files: ${variables.DATA_DIR}/train.parquet + val_files: ${variables.DATA_DIR}/test.parquet + train_batch_size: 32 + val_batch_size: 140 + max_prompt_length: 2048 + max_response_length: 512 + truncation: error + return_raw_chat: true + +actor_rollout_ref: + rollout: + tensor_model_parallel_size: ${variables.TENSOR_MODEL_PARALLEL_SIZE} + n: ${variables.NUM_ROLLOUTS} + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + multi_turn: + format: hermes + name: vllm + gpu_memory_utilization: 0.6 + enable_chunked_prefill: false + enforce_eager: false + free_cache_engine: true + val_kwargs: + temperature: 0.4 + do_sample: true + actor: + ppo_mini_batch_size: ${variables.MINI_BATCH_SIZE} + ppo_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + optim: + lr: 1.0e-6 + use_kl_loss: true + kl_loss_coef: 0.01 + kl_loss_type: low_var_kl + entropy_coeff: 0.001 + fsdp_config: + param_offload: false + optimizer_offload: false + ref: + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + fsdp_config: + param_offload: true + model: + path: ${variables.BASE_MODEL} + use_remove_padding: true + enable_gradient_checkpointing: true + +trainer: + n_gpus_per_node: ${variables.NUM_GPUS} + val_before_train: false + critic_warmup: 0 + logger: + - console + - wandb + project_name: ${variables.PROJECT_NAME} + experiment_name: ${variables.EXPERIMENT_NAME} + nnodes: 1 + save_freq: 100 + test_freq: 5 + total_epochs: 200 diff --git a/contrib/recipes/envs/config_verl/scienceworld/empo2_qwen_7b_instruct.yaml b/contrib/recipes/envs/config_verl/scienceworld/empo2_qwen_7b_instruct.yaml new file mode 100644 index 000000000..64d79f48e --- /dev/null +++ b/contrib/recipes/envs/config_verl/scienceworld/empo2_qwen_7b_instruct.yaml @@ -0,0 +1,101 @@ +# ========================== +# Variable definitions +# ========================== +variables: + NUM_GPUS: 8 + MINI_BATCH_SIZE: 16 + 
PER_GPU_BATCH_SIZE: 1 + TENSOR_MODEL_PARALLEL_SIZE: 2 + NUM_ROLLOUTS: 8 + BASE_MODEL: Qwen/Qwen2.5-7B-Instruct + PROJECT_NAME: EMPO2-ScienceWorld2 + TASK_NUM: ${oc.env:TASK_NUM,25} + TRIAL: ${oc.env:TRIAL,0} + EXPERIMENT_NAME: (all-off-policy-final-reward)empo2-${variables.TASK_NUM}-sciworld-${variables.TRIAL} + DATA_DIR: agl_envs/task_data/scienceworld/single_data/${variables.TASK_NUM} + OUTPUT_DIR: /mnt/jeonghyekim/empo2_checkpoint/0211/${variables.EXPERIMENT_NAME} + +# ========================== +# Main Config +# ========================== +agentlightning: + port: 9999 + +algorithm: + adv_estimator: grpo + use_kl_in_reward: false + use_final_reward_as_step_reward: true + use_intrinsic_reward: true + +data: + train_files: ${variables.DATA_DIR}/train.parquet + val_files: ${variables.DATA_DIR}/test.parquet + train_batch_size: 16 + val_batch_size: 80 + max_prompt_length: 16384 + max_response_length: 32 + max_train_length: 8192 + truncation: error + return_raw_chat: true + +actor_rollout_ref: + rollout: + tensor_model_parallel_size: ${variables.TENSOR_MODEL_PARALLEL_SIZE} + n: ${variables.NUM_ROLLOUTS} + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + multi_turn: + format: hermes + name: vllm + gpu_memory_utilization: 0.5 + enable_chunked_prefill: false + enforce_eager: false + free_cache_engine: true + val_kwargs: + temperature: 0.4 + do_sample: true + actor: + ppo_mini_batch_size: ${variables.MINI_BATCH_SIZE} + ppo_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + optim: + lr: 1.0e-6 + use_kl_loss: false + kl_loss_coef: 0.00 + entropy_coeff: 0.0 + clip_ratio_high: 0.30 + clip_ratio_low: 0.20 + clip_ratio_c: 10.0 + entropy_checkpointing: true + entropy_from_logits_with_chunking: true + fsdp_config: + param_offload: true + optimizer_offload: true + forward_prefetch: true + ref: + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + entropy_checkpointing: true + entropy_from_logits_with_chunking: true + fsdp_config: + param_offload: true + forward_prefetch: true + model: + path: ${variables.BASE_MODEL} + use_remove_padding: true + enable_gradient_checkpointing: true + +trainer: + default_local_dir: ${variables.OUTPUT_DIR}/checkpoints + n_gpus_per_node: ${variables.NUM_GPUS} + val_before_train: false + critic_warmup: 0 + logger: + - console + - wandb + project_name: ${variables.PROJECT_NAME} + experiment_name: ${variables.EXPERIMENT_NAME} + nnodes: 1 + save_freq: 50 + test_freq: 20 + total_epochs: 500 + +tips: + use_tips: true diff --git a/contrib/recipes/envs/config_verl/scienceworld/grpo_qwen_1.5b_instruct.yaml b/contrib/recipes/envs/config_verl/scienceworld/grpo_qwen_1.5b_instruct.yaml new file mode 100644 index 000000000..b90418b8b --- /dev/null +++ b/contrib/recipes/envs/config_verl/scienceworld/grpo_qwen_1.5b_instruct.yaml @@ -0,0 +1,87 @@ +# ========================== +# Variable definitions +# ========================== +variables: + NUM_GPUS: 2 + MINI_BATCH_SIZE: 32 + PER_GPU_BATCH_SIZE: 16 + TENSOR_MODEL_PARALLEL_SIZE: 2 + NUM_ROLLOUTS: 8 + BASE_MODEL: Qwen/Qwen2.5-1.5B-Instruct + PROJECT_NAME: AGL-Simulation-ScienceWorld + TASK_NUM: ${oc.env:TASK_NUM,-1} + TRIAL: ${oc.env:TRIAL,0} + EXPERIMENT_NAME: grpo-sciworld-${variables.TRIAL} + DATA_DIR: agl_envs/task_data/scienceworld/multi_data + +# ========================== +# Main Config +# ========================== +agentlightning: + port: 9999 + +algorithm: + adv_estimator: grpo + use_kl_in_reward: false + use_final_reward_as_step_reward: true + use_intrinsic_reward: true + +data: + 
train_files: ${variables.DATA_DIR}/train.parquet + val_files: ${variables.DATA_DIR}/test.parquet + train_batch_size: 32 + val_batch_size: 144 + max_prompt_length: 6000 + max_response_length: 1024 + truncation: error + return_raw_chat: true + +actor_rollout_ref: + rollout: + tensor_model_parallel_size: ${variables.TENSOR_MODEL_PARALLEL_SIZE} + n: ${variables.NUM_ROLLOUTS} + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + multi_turn: + format: hermes + name: vllm + gpu_memory_utilization: 0.6 + enable_chunked_prefill: false + enforce_eager: false + free_cache_engine: true + val_kwargs: + temperature: 0.4 + do_sample: true + actor: + ppo_mini_batch_size: ${variables.MINI_BATCH_SIZE} + ppo_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + optim: + lr: 1.0e-6 + use_kl_loss: true + kl_loss_coef: 0.01 + kl_loss_type: low_var_kl + entropy_coeff: 0.001 + fsdp_config: + param_offload: false + optimizer_offload: false + ref: + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + fsdp_config: + param_offload: true + model: + path: ${variables.BASE_MODEL} + use_remove_padding: true + enable_gradient_checkpointing: true + +trainer: + n_gpus_per_node: ${variables.NUM_GPUS} + val_before_train: false + critic_warmup: 0 + logger: + - console + - wandb + project_name: ${variables.PROJECT_NAME} + experiment_name: ${variables.EXPERIMENT_NAME} + nnodes: 1 + save_freq: 100 + test_freq: 5 + total_epochs: 500 diff --git a/contrib/recipes/envs/config_verl/scienceworld/grpo_qwen_7b_instruct.yaml b/contrib/recipes/envs/config_verl/scienceworld/grpo_qwen_7b_instruct.yaml new file mode 100644 index 000000000..3dbba7dcd --- /dev/null +++ b/contrib/recipes/envs/config_verl/scienceworld/grpo_qwen_7b_instruct.yaml @@ -0,0 +1,92 @@ +# ========================== +# Variable definitions +# ========================== +variables: + NUM_GPUS: 8 + MINI_BATCH_SIZE: 16 + PER_GPU_BATCH_SIZE: 1 + TENSOR_MODEL_PARALLEL_SIZE: 2 + NUM_ROLLOUTS: 8 + BASE_MODEL: Qwen/Qwen2.5-7B-Instruct + PROJECT_NAME: EMPO2-ScienceWorld2 + TASK_NUM: ${oc.env:TASK_NUM,25} + TRIAL: ${oc.env:TRIAL,0} + EXPERIMENT_NAME: (final-reward)grpo-${variables.TASK_NUM}-sciworld-${variables.TRIAL} + DATA_DIR: agl_envs/task_data/scienceworld/single_data/${variables.TASK_NUM} + OUTPUT_DIR: /mnt/jeonghyekim/empo2_grpo_checkpoint/0211/${variables.EXPERIMENT_NAME} + +# ========================== +# Main Config +# ========================== +agentlightning: + port: 9999 + +algorithm: + adv_estimator: grpo + use_kl_in_reward: false + use_final_reward_as_step_reward: true + use_intrinsic_reward: true + +data: + train_files: ${variables.DATA_DIR}/train.parquet + val_files: ${variables.DATA_DIR}/test.parquet + train_batch_size: 16 + val_batch_size: 80 + max_prompt_length: 16384 + max_response_length: 32 + max_train_length: 8192 + truncation: error + return_raw_chat: true + +actor_rollout_ref: + rollout: + tensor_model_parallel_size: ${variables.TENSOR_MODEL_PARALLEL_SIZE} + n: ${variables.NUM_ROLLOUTS} + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + multi_turn: + format: hermes + name: vllm + gpu_memory_utilization: 0.5 + enable_chunked_prefill: false + enforce_eager: false + free_cache_engine: true + val_kwargs: + temperature: 0.4 + do_sample: true + actor: + ppo_mini_batch_size: ${variables.MINI_BATCH_SIZE} + ppo_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + optim: + lr: 1.0e-6 + use_kl_loss: false + kl_loss_coef: 0.00 + entropy_coeff: 0.0 + clip_ratio_high: 0.30 + clip_ratio_low: 0.20 + 
clip_ratio_c: 10.0 + fsdp_config: + param_offload: true + optimizer_offload: true + ref: + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + fsdp_config: + param_offload: true + model: + path: ${variables.BASE_MODEL} + use_remove_padding: true + enable_gradient_checkpointing: true + +trainer: + default_local_dir: ${variables.OUTPUT_DIR}/checkpoints + n_gpus_per_node: ${variables.NUM_GPUS} + val_before_train: false + critic_warmup: 0 + logger: + - console + - wandb + project_name: ${variables.PROJECT_NAME} + experiment_name: ${variables.EXPERIMENT_NAME} + nnodes: 1 + save_freq: 50 + test_freq: 20 + total_epochs: 500 \ No newline at end of file diff --git a/contrib/recipes/envs/empo2_server/server_bert.py b/contrib/recipes/envs/empo2_server/server_bert.py new file mode 100644 index 000000000..943a9a9df --- /dev/null +++ b/contrib/recipes/envs/empo2_server/server_bert.py @@ -0,0 +1,32 @@ +from fastapi import FastAPI, Request +import time +from pydantic import BaseModel +import uvicorn + +import torch +import time +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +torch.cuda.set_per_process_memory_fraction(0.1, 0) + +num_works = 1 + +app = FastAPI() + +from sentence_transformers import SentenceTransformer +model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") + +@app.post("/key_cal/") +async def compress(request: Request): + try: + data = await request.json() + text = data.get("text", "") + except: + text = (await request.body()).decode("utf-8") + + key = model.encode(text) + return {"key": key.tolist()} + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000, workers=num_works) \ No newline at end of file diff --git a/contrib/recipes/envs/empo2_server/server_mem.py b/contrib/recipes/envs/empo2_server/server_mem.py new file mode 100644 index 000000000..96ec8ebc3 --- /dev/null +++ b/contrib/recipes/envs/empo2_server/server_mem.py @@ -0,0 +1,68 @@ +from fastapi import FastAPI, Request +from pydantic import BaseModel +import uvicorn +import numpy as np +import time +import random +from collections import deque + +num_works = 1 +app = FastAPI() + +mem_list = None +content_set = None + +class MemRequest(BaseModel): + key: list + idx: int = None + content: str = None + score: float = None + +@app.post("/mem/") +async def mem_handler(mem_req: MemRequest): + global cnt, mem_list, content_set + + key = mem_req.key + idx = mem_req.idx + content = mem_req.content + score = mem_req.score + + if content=="Reset": + mem_list_num = idx + content_set = {id: set() for id in range(mem_list_num)} + mem_list = {id: [] for id in range(mem_list_num)} + cnt = {id: 0 for id in range(mem_list_num)} + print(f"Clean all the mem. 
The num of mem_list is {mem_list_num}") + return None + + if content is not None: + if content not in content_set[idx]: + content_set[idx].add(content) + mem_list[idx].append({ + "cnt": cnt[idx], + "key": key, + "content": content, + "score": score, + }) + cnt[idx] += 1 + if len(mem_list[idx]) > 1000: + oldest_hash = mem_list[idx][0]["content"] + content_set[idx].discard(oldest_hash) + mem_list[idx] = mem_list[idx][-1000:] + print("Add,", "id", idx, "cnt", cnt[idx], "content", content, "score", score) + else: + data = [] + for mem in mem_list[idx]: + mem_key = mem["key"] + sim = np.dot(key, mem_key) / (np.linalg.norm(key) * np.linalg.norm(mem_key)) + if sim > 0.5: + data.append(mem) + # data = random.sample(data, min(len(data), 10)) if len(data) > 0 else [] + data = sorted(data, key=lambda x: -x["score"])[:10] if len(data) > 0 else [] + data = [x["content"] for x in data] + count = len(data) + print("Load", count, data) + return count, data + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8001, workers=num_works) \ No newline at end of file diff --git a/contrib/recipes/envs/prompt_builder.py b/contrib/recipes/envs/prompt_builder.py index 905285d55..3029fc6f1 100644 --- a/contrib/recipes/envs/prompt_builder.py +++ b/contrib/recipes/envs/prompt_builder.py @@ -4,7 +4,6 @@ from autogen_core.models import AssistantMessage, UserMessage - class HistoryPromptBuilder: """ Builds prompts using a history of observations and actions. @@ -14,7 +13,11 @@ class HistoryPromptBuilder: - single: a single formatted prompt with optional history """ - def __init__(self, max_history: int = -1, prompt_type: str = "chat"): + def __init__( + self, + max_history: int = -1, + prompt_type: str = "chat" + ): """ Args: max_history (int): Maximum number of past steps to include @@ -80,7 +83,7 @@ def init(self, env): self._events.clear() if self.prompt_type == "chat": - inst_prompt = env.get_instruction_prompt(info) + inst_prompt = env.get_instruction_prompt() self.update_instruction_prompt(inst_prompt) elif self.prompt_type == "single": template_wo_his, template = env.get_single_prompt_template() @@ -179,5 +182,5 @@ def get_prompt(self): prompt = self.get_chat_prompt() elif self.prompt_type == "single": prompt = self.get_single_prompt() - + return prompt diff --git a/contrib/recipes/envs/train_env_agent.py b/contrib/recipes/envs/train_env_agent.py index 81aefe841..5dcba1e49 100644 --- a/contrib/recipes/envs/train_env_agent.py +++ b/contrib/recipes/envs/train_env_agent.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. 
- -import argparse import os +import re +import time +import argparse import subprocess from omegaconf import OmegaConf @@ -11,7 +12,6 @@ from contrib.agentlightning.contrib.algorithm.env_verl.daemon import EnvAgentModeDaemon from contrib.agentlightning.contrib.algorithm.env_verl.trainer import EnvAgentLightningTrainer - def run_cmd(cmd): """Execute a shell command and print its output""" print(f"👉 Running: {cmd}") @@ -51,9 +51,8 @@ def get_config(path): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--env", type=str, default="scienceworld") - parser.add_argument("--algorithm", type=str, default="grpo") - parser.add_argument("--debug", action="store_true") + parser.add_argument("--env", type=str, default="scienceworld2") + parser.add_argument("--algorithm", type=str, default="empo2_qwen_7b_instruct") parser.add_argument("--n_workers", type=int, default=64, help="Number of workers for training") parser.add_argument("--trial", type=int, default=0, help="Number of trials") parser.add_argument("--task_num", type=int, default=25, help="ScienceWorld Task number to inject as env var") @@ -68,17 +67,16 @@ def get_config(path): # set environment variable before loading configs os.environ["TRIAL"] = str(args.trial) - if args.env == "scienceworld": + if "scienceworld" in args.env: os.environ["TASK_NUM"] = str(args.task_num) # Load configs agent_config_path = f"config_env/{args.env}.yaml" - if args.debug: - trainer_config_path = f"config_verl/{args.env}/debug/{args.algorithm}.yaml" - else: - trainer_config_path = f"config_verl/{args.env}/{args.algorithm}.yaml" agent_config = get_config(agent_config_path) + env_prefix = re.sub(r"\d+$", "", args.env) + trainer_config_path = f"config_verl/{env_prefix}/{args.algorithm}.yaml" + if "gigpo" in args.algorithm: agent_config.log_env_obs = True rl_training_config = get_config(trainer_config_path) @@ -87,9 +85,31 @@ def get_config(path): train_dataset, val_dataset = train_val_dataset(rl_training_config) # Initialize agent - from contrib.agentlightning.contrib.agent.env_agent import EnvAgent + if "empo2" in args.algorithm: + from contrib.agentlightning.contrib.agent.empo2_agent import EMPO2Agent, reset_memory - agent = EnvAgent(agent_config) + kill_process_on_port(8000) + kill_process_on_port(8001) + + os.makedirs("logs", exist_ok=True) + + subprocess.Popen( + f"nohup python empo2_server/server_bert.py > logs/bert_{args.task_num}.log 2>&1 &", + shell=True + ) + subprocess.Popen( + f"nohup python empo2_server/server_mem.py > logs/mem_{args.task_num}.log 2>&1 &", + shell=True + ) + + NUM_MEMORY = 5 + time.sleep(1) + reset_memory(NUM_MEMORY) + + agent = EMPO2Agent(agent_config) + else: + from contrib.agentlightning.contrib.agent.env_agent import EnvAgent + agent = EnvAgent(agent_config) # Initialize trainer and start training trainer = Trainer(