Changes from all commits
57 commits
9fae8b2
add simulation envs
beanie00 Dec 2, 2025
bfada88
add simulation agent and trainer
beanie00 Dec 3, 2025
ae3c490
add run scripts
beanie00 Dec 3, 2025
275d2d1
add span grouping and modify the adapter
beanie00 Dec 4, 2025
d33946c
apply step reward and message to get_train_data_batch in demon
beanie00 Dec 6, 2025
8818137
error fixed
beanie00 Dec 6, 2025
d512d8a
Merge branch 'microsoft:main' into feature/agl-simulation
beanie00 Dec 6, 2025
30109a2
clean simulation/envs
beanie00 Dec 10, 2025
76b56ff
clean task_data
beanie00 Dec 10, 2025
34b5c89
change task data path
beanie00 Dec 10, 2025
d9c069d
Merge branch 'microsoft:main' into feature/agl-simulation
beanie00 Dec 11, 2025
49bfcab
Merge branch 'microsoft:main' into feature/agl-simulation
beanie00 Dec 12, 2025
da50a7a
move examples/simulation/* to contrib folder
beanie00 Dec 12, 2025
c9312e3
rollback prev agentlightning triplet and daemon
beanie00 Dec 12, 2025
e119c68
link with PR #407
beanie00 Dec 12, 2025
966e571
clean files
beanie00 Dec 15, 2025
140ecb1
clean files
beanie00 Dec 15, 2025
cee8188
clean prompt builder and script files
beanie00 Dec 15, 2025
59e2c31
move prompt_builder to agent
beanie00 Dec 15, 2025
7340bc6
Merge branch 'microsoft:main' into feature/agl-simulation
beanie00 Dec 18, 2025
a7bde15
update readme and clean files
beanie00 Dec 19, 2025
d647fc5
update readme
beanie00 Dec 20, 2025
7e2698f
(working) add empo2
beanie00 Dec 20, 2025
591d02d
prevent empty triplet
beanie00 Dec 20, 2025
8cc512c
refactor triplet group
beanie00 Dec 22, 2025
db9f004
update naive instruction
beanie00 Dec 22, 2025
91a7b3d
update training configs
beanie00 Dec 22, 2025
c3c9725
fix intrinsic list length mismatch
beanie00 Dec 23, 2025
9d5194f
fix validation error
beanie00 Dec 23, 2025
c1fe7d9
apply verl entropy_from_logits_with_chunking
beanie00 Jan 14, 2026
dc0b30b
update instruction and relative path
beanie00 Feb 3, 2026
dcc07f4
minor update
beanie00 Feb 3, 2026
a65876b
update output dir path in yaml
beanie00 Feb 4, 2026
bac4818
update empo2 exp name
beanie00 Feb 4, 2026
2b8badf
set all tips
beanie00 Feb 4, 2026
1bb1274
update exp name
beanie00 Feb 4, 2026
d05d64f
add max total length
beanie00 Feb 5, 2026
29df046
fix tip config error in trainer
beanie00 Feb 5, 2026
0cc2c71
fix tip config error in trainer
beanie00 Feb 6, 2026
77ff0ee
minor updates
beanie00 Feb 6, 2026
a283dea
Update empo2_qwen_7b_instruct.yaml
beanie00 Feb 7, 2026
84ad50c
Update empo2_qwen_7b_instruct.yaml
beanie00 Feb 7, 2026
62cb025
update tip end pattern
beanie00 Feb 7, 2026
6971044
Update core_empo2.py
beanie00 Feb 7, 2026
50b1f82
Update trainer.py
beanie00 Feb 8, 2026
6867b96
update off-policy old log prob calculation
beanie00 Feb 8, 2026
c1dd325
change low prob masking in empo2
beanie00 Feb 8, 2026
9112988
update removing tip pattern logic
beanie00 Feb 9, 2026
090c5c0
minor update
beanie00 Feb 9, 2026
bc023cf
remove verl and change folder structure
beanie00 Feb 10, 2026
5271d84
remove verl and change folder structure
beanie00 Feb 10, 2026
f46582a
merge
beanie00 Feb 10, 2026
92db4e3
Merge branch 'main' into feature/empo2-all-off-policy2
beanie00 Feb 10, 2026
051564e
fix wrong relative path
beanie00 Feb 10, 2026
487a69a
update old log prob calculation
beanie00 Feb 10, 2026
4779761
update max train len
beanie00 Feb 10, 2026
d38d31c
update empo2 readme
beanie00 Feb 12, 2026
4 changes: 4 additions & 0 deletions .gitignore
@@ -218,3 +218,7 @@ agentlightning/dashboard/**/*.svg

# Docker data
docker/data/

# AGL simulation
agl_envs/
wandb/
5 changes: 4 additions & 1 deletion contrib/agentlightning/contrib/adapter/triplet_group.py
@@ -111,7 +111,10 @@ def message(span: Optional[Span]) -> Optional[str]:

for group in span_groups.values():
call_span = group.get("call_span")
if not token_ids(call_span, "prompt_token_ids") and not token_ids(call_span, "response_token_ids"):
if (
not token_ids(call_span, "prompt_token_ids")
and not token_ids(call_span, "response_token_ids")
):
continue

object_span = group.get("object_span")
259 changes: 259 additions & 0 deletions contrib/agentlightning/contrib/agent/empo2_agent.py
@@ -0,0 +1,259 @@
import copy
import numpy as np
import requests
import logging
from typing import Any, Dict

from add_instruction import add_chat_instruction, add_chat_tips, add_chat_all_tips
from agentlightning import (
LLM,
NamedResources,
Rollout,
configure_logger,
emit_reward,
operation
)
from agentlightning.utils.otel import make_link_attributes

from agl_envs import make_env_manager
from contrib.recipes.envs.prompt_builder import HistoryPromptBuilder

from contrib.agentlightning.contrib.agent.env_agent import EnvAgent

configure_logger()
logger = configure_logger(name=__name__, level=logging.ERROR)

def do_compress(text):
url = "http://127.0.0.1:8000/key_cal/"
headers = {"Content-Type": "application/json"} # 明确指定 JSON 格式
data = {"text": text}
response = requests.post(url, json=data, headers=headers)  # use the json argument so requests serializes the body
return response.json()

url_mem = "http://127.0.0.1:8001/mem/"

def retrieve_memory(idx, key):
response = requests.post(url_mem, json={
"key": key,
"idx": idx
})
count, data = response.json()
return count, data

def reset_memory(mem_list_num):
requests.post(url_mem, json={
"key": [],
"idx": mem_list_num, # 用于初始化多个 memory slot
"content": "Reset"
})

def add_memory(idx, key, content, score):
requests.post(url_mem, json={
"key": key,
"idx": idx,
"content": content,
"score": score
})

def gather_chats(prompt):
chat_list = []
for item in prompt:
role = item.type
content = item.content
if "System" in role:
continue
elif "User" in role:
role = "user"
else:
role = "assistant"
chat_list.append(f"{role}: {content}")
text = " ".join(chat_list)
return text

class EMPO2Agent(EnvAgent):
def __init__(self, config, trained_agents: str | None = None) -> None:
super().__init__(config=config, trained_agents=trained_agents)

def _get_tip_prompt(self, prompt, tips):
prompt_type = self.config.captioner.prompt_type

if prompt_type == "chat":
return add_chat_tips(prompt, tips)
else:
raise ValueError(f"Unsupported prompt_type '{prompt_type}' for _get_tip_obs (expected 'chat')")

def _get_all_tip_prompt(self, prompt, tip_list):
prompt_type = self.config.captioner.prompt_type
if prompt_type == "chat":
return add_chat_all_tips(prompt, tip_list)
else:
raise ValueError(f"Unsupported prompt_type '{prompt_type}' for _get_tip_obs (expected 'chat')")

def _get_tip_generation_prompt(self, prompt):
return add_chat_instruction(prompt, "tip")

async def rollout_async(
self,
task: Dict[str, Any],
resources: NamedResources,
rollout: Rollout,
) -> float | None:
rollout_id = rollout.rollout_id
logger.info(f"[Rollout {rollout_id}] Task: {task}")

reward_scale = float(self.config["reawrd_scale"])

# Setup LLM + agent
llm: LLM = resources.get("main_llm")
print("Training with model:", llm.model, "on endpoint:", llm.endpoint)
self.agent = self._build_agent(llm, 1.0 if rollout.mode == "train" else 0.4)

if rollout.mode == "train":
train_mode = task["train_mode"]
global_steps = task["global_steps"]
else:
train_mode = "on-policy"

if rollout.mode == "train" and (train_mode == "off-policy" or train_mode == "on-policy-with-tips"):
use_tips = True
else:
use_tips = False

variation_idx = task["variation_idx"]

try:
# Setup environment
prompt_builder = HistoryPromptBuilder(max_history=self.config.captioner.max_history, prompt_type=self.config.captioner.prompt_type)

self.env = make_env_manager(self.config.env_name, task, self.config)
env_obs, infos, available_actions_hint = self.env.reset()

prompt_builder.init(self.env)
prompt_builder.update_observation(env_obs)
# prompt_builder.update_admissible_actions(available_actions_hint)

prompt = prompt_builder.get_prompt()

episode_reward, done = 0.0, False

pure_prompt_for_mem = []
history_actions_for_mem = []
tip_list = []

step_count = 0
while not done:
if use_tips:
text = gather_chats(prompt)
key = np.array(do_compress(text)['key']).reshape(-1, ).tolist()
count, mem_list = retrieve_memory(variation_idx, key)
else:
count, mem_list = 0, []

ret_tips, intrinsic_reward = "", 0.0

if use_tips:
if count > 0:
ret_tips = "Here are some memories you collected in your previous exploration:\n"
for mem in mem_list:
ret_tips += mem+"\n"

tip_list.append(ret_tips)
intrinsic_reward = 1 / (count+1)
else:
tip_list.append("")
intrinsic_reward = 1

try:
if count > 0:
tip_prompt = self._get_all_tip_prompt(prompt, tip_list)
instructed_prompt = self._get_instructed_prompt(tip_prompt, sep="")
else:
instructed_prompt = self._get_instructed_prompt(prompt)

# Main agent step
with operation(step_count=step_count):
result = await self.agent._model_client.create(instructed_prompt)
output = result.content
logger.info(f"[LLM output]: {output}")

except Exception as e:
logger.error(f"[Rollout {rollout_id}] Error during training rollout: {e}", exc_info=True)
break

# Environment step
pure_prompt_for_mem.append([copy.deepcopy(prompt), None])
env_obs, executed_action, is_valid, step_reward, terminated, truncated, info, available_actions_hint = self.env.step(
output, use_reasoning=self.config.captioner.type == "cot"
)
history_actions_for_mem.append(executed_action)

prompt_builder.update_step_count()
prompt_builder.update_action(executed_action)
prompt_builder.update_observation(env_obs)
# prompt_builder.update_admissible_actions(available_actions_hint)

prompt = prompt_builder.get_prompt()

if rollout.mode == "train":
step_reward = reward_scale * step_reward

emit_reward(
{
"extrinsic_reward": step_reward,
"intrinsic_reward": intrinsic_reward,
},
primary_key="extrinsic_reward",
attributes=make_link_attributes({"step_count": str(step_count)}),
)

episode_reward += float(step_reward)
done = np.logical_or(terminated, truncated)

step_count += 1

if (
rollout.mode == "train"
and self.config.captioner.prompt_type == "chat"
and self.config.save_rollout
):
filename = f"empo2_rollouts/variant_{variation_idx}/step_{global_steps}/{rollout_id}_{round(episode_reward, 1)}_use_tip_{use_tips}.json"
if use_tips:
_rollout = self._get_all_tip_prompt(prompt, tip_list)
else:
_rollout = prompt
self._save_chat_rollout(_rollout, filename)

if rollout.mode == "train":
prompt_builder.prompt_type = "chat"
prompt_builder.max_history = -1
prompt = prompt_builder.get_prompt()
prompt.pop()

tip_generation_prompt = self._get_tip_generation_prompt(prompt)

self.agent._model_client.max_tokens = 128
result = await self.agent._model_client.create(tip_generation_prompt)
tips = result.content
logger.info(f"Tips: {tips}")

#! Attach the generated tips, the executed action, and the final score to each step's memory content
for i in range(len(pure_prompt_for_mem)):
max_score = 100 * reward_scale
pure_prompt_for_mem[i][1] = tips + f'; At that timestep, the specific action you took was {history_actions_for_mem[i]}; Eventually you got the score {round(episode_reward, 1)}/{int(max_score)}.'

#! Compute the key for each step and save the (key, content, score) entry to memory
for i in range(len(pure_prompt_for_mem)):
text = gather_chats(pure_prompt_for_mem[i][0])
key = np.array(do_compress(text)['key']).reshape(-1, ).tolist()
content = pure_prompt_for_mem[i][1]
score = episode_reward
add_memory(variation_idx, key, content, round(score, 1))

if self.config.use_success_rate:
return self.env.get_success_score() * reward_scale
else:
return episode_reward

finally:
if self.env is not None:
self.env.close()
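
The helper functions at the top of this file (do_compress, retrieve_memory, reset_memory, add_memory) assume two local HTTP services: a key-compression endpoint at http://127.0.0.1:8000/key_cal/ and an episodic-memory endpoint at http://127.0.0.1:8001/mem/. The server side is not shown in this diff, so the following is only a minimal sketch of a /mem/ service that would satisfy those client calls; FastAPI, the cosine-similarity retrieval, and the SIM_THRESHOLD/TOP_K constants are assumptions, while the request fields (key, idx, content, score) and the [count, contents] response shape are taken from the client code above.

# sketch_memory_server.py -- illustrative sketch only, not part of the PR
from typing import Any, Dict, List, Optional

import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
# One list of {key, content, score} entries per environment variation.
memory_slots: List[List[Dict[str, Any]]] = []

SIM_THRESHOLD = 0.9  # assumed cosine-similarity cutoff for a "matching" memory
TOP_K = 3            # assumed number of memories returned per query


class MemRequest(BaseModel):
    key: List[float] = []
    idx: int
    content: Optional[str] = None
    score: Optional[float] = None


@app.post("/mem/")
def mem(req: MemRequest):
    global memory_slots
    if req.content == "Reset":
        # reset_memory(mem_list_num): allocate `idx` empty slots
        memory_slots = [[] for _ in range(req.idx)]
        return {"status": "reset"}
    if req.content is not None:
        # add_memory(idx, key, content, score): append one entry to slot `idx`
        memory_slots[req.idx].append(
            {"key": np.asarray(req.key), "content": req.content, "score": req.score}
        )
        return {"status": "added"}
    # retrieve_memory(idx, key): return [count, contents] for sufficiently similar keys
    query = np.asarray(req.key)
    scored = []
    for entry in memory_slots[req.idx]:
        k = entry["key"]
        sim = float(query @ k / (np.linalg.norm(query) * np.linalg.norm(k) + 1e-8))
        if sim >= SIM_THRESHOLD:
            scored.append((sim, entry["content"]))
    scored.sort(reverse=True)
    hits = [content for _, content in scored[:TOP_K]]
    return [len(hits), hits]

Served with something like uvicorn sketch_memory_server:app --port 8001, the retrieve_memory/add_memory/reset_memory calls above should round-trip against it; the real service may of course use a different storage and retrieval scheme.
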
65 changes: 65 additions & 0 deletions contrib/agentlightning/contrib/algorithm/env_verl/core_empo2.py
@@ -0,0 +1,65 @@
import torch
from typing import List, Any

def is_sublist(sub, full):
n, m = len(sub), len(full)
return any(full[i:i+n] == sub for i in range(m - n + 1))

# Function to remove segments of a list between a start pattern and an end pattern
def remove_pattern_ranges(seq: List[Any],
start_pat: List[Any],
end_pat: List[Any]) -> List[Any]:
"""Remove every [start_pat ... end_pat] slice (inclusive) from seq."""

out: List[Any] = []
i = 0
n = len(seq)
ls, le = len(start_pat), len(end_pat)

while i < n:
# Check if the start pattern matches at the current position
if i + ls <= n and seq[i:i+ls] == start_pat:
# Look for the first occurrence of the end pattern after the start pattern
j = i + ls
found_end = -1
while j + le <= n:
if seq[j:j+le] == end_pat:
found_end = j
break # Stop when the end pattern is found
j += 1

# If the end pattern is found, skip the whole segment from start to end
if found_end != -1:
i = found_end + le # Move the index past the end pattern
continue # Skip the current iteration and go to the next
else:
# If the end pattern is not found, keep the current element and move one step forward
out.append(seq[i])
i += 1
else:
# If the start pattern is not found, just append the current element
out.append(seq[i])
i += 1

# Return the filtered list with the start-end pattern segments removed
return out

def low_prob_token_masking(batch):
response_mask = batch.batch["response_mask"] # [N, T]
old_log_prob = batch.batch["old_log_probs"] # [N, T]
# advantages = batch.batch["advantages"] # [N, T]

masked_old_log_prob = old_log_prob.masked_fill(response_mask == 0, 1e9)
min_values, _ = torch.min(masked_old_log_prob, dim=1) # [N]

mask = min_values < -5 # [N]

combined_mask = mask.unsqueeze(1) & (response_mask == 1)

# zero out the response mask for sequences flagged by the low-probability mask
response_mask = response_mask.masked_fill(combined_mask, 0)
batch.batch["response_mask"] = response_mask

print(f"Number of tokens masked: {combined_mask.sum().item()}")

return batch
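
remove_pattern_ranges appears to be what strips the injected tip text back out of token-id sequences before training (cf. the "update tip end pattern" and "update removing tip pattern logic" commits), and low_prob_token_masking zeroes the response mask of any sequence whose minimum old log-prob over response tokens is below -5, so those sequences contribute no policy-gradient signal. A toy illustration of the pattern removal, with made-up marker ids (the real start/end patterns are presumably the tokenized tip delimiters configured elsewhere in the trainer):

# Toy check of remove_pattern_ranges; the marker ids 900/901 are hypothetical.
seq = [1, 2, 900, 5, 6, 7, 901, 3, 4]
start_pat = [900]  # hypothetical "tip start" marker
end_pat = [901]    # hypothetical "tip end" marker

print(remove_pattern_ranges(seq, start_pat, end_pat))
# -> [1, 2, 3, 4]: the whole 900 ... 901 span (inclusive) is dropped

# An unmatched start pattern leaves the sequence untouched instead of truncating it:
print(remove_pattern_ranges([1, 900, 2, 3], start_pat, end_pat))
# -> [1, 900, 2, 3]
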