Skip to content
This repository was archived by the owner on Aug 29, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 69 additions & 32 deletions evojax/task/slimevolley.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,18 @@ class AgentState(object):


def initAgentState(direction, x, y):
return AgentState(direction, x, y, 1.5, 0, 0, 0, 0, MAXLIVES)
return AgentState(
direction=jnp.int32(direction),
x=jnp.float32(x),
y=jnp.float32(y),
r=jnp.float32(1.5),
vx=jnp.float32(0),
vy=jnp.float32(0),
desired_vx=jnp.float32(0),
desired_vy=jnp.float32(0),
life=jnp.int32(MAXLIVES)
)



@dataclass
Expand Down Expand Up @@ -335,8 +346,10 @@ def __init__(self, p: ParticleState, c):
self.c = c

def display(self, canvas):
return circle(canvas, toX(float(self.p.x)), toY(float(self.p.y)),
toP(float(self.p.r)), color=self.c)
x = float(self.p.x.item())
y = float(self.p.y.item())
r = float(self.p.r.item())
return circle(canvas, toX(x), toY(y), toP(r), color=self.c)

def move(self):
self.p = ParticleState(self.p.x+self.p.vx*TIMESTEP,
Expand Down Expand Up @@ -558,12 +571,22 @@ def update(self):
p.life)

def updateLife(self, result):
""" updates the life based on result and internal direction """
"""Updates the life based on result and internal direction."""
p = self.p
updateAmount = p.direction*result # only update if this value is -1
new_life = jnp.where(updateAmount < 0, p.life-1, p.life)
self.p = AgentState(p.direction, p.x, p.y, p.r, p.vx, p.vy,
p.desired_vx, p.desired_vy, new_life)
updateAmount = p.direction * result # This should be a scalar
new_life = jnp.where(updateAmount < 0, p.life - 1, p.life)
self.p = AgentState(
direction=p.direction,
x=p.x,
y=p.y,
r=p.r,
vx=p.vx,
vy=p.vy,
desired_vx=p.desired_vx,
desired_vy=p.desired_vy,
life=new_life
)


def updateState(self, ball: ParticleState, opponent: AgentState):
""" normalized to side, customized for each agent's perspective"""
Expand Down Expand Up @@ -591,13 +614,13 @@ def getObservation(self):
return getObsArray(self.state)

def display(self, canvas, ball_x, ball_y):
bx = float(ball_x)
by = float(ball_y)
bx = float(ball_x.item())
by = float(ball_y.item())
p = self.p
x = float(p.x)
y = float(p.y)
r = float(p.r)
direction = int(p.direction)
x = float(p.x.item())
y = float(p.y.item())
r = float(p.r.item())
direction = int(p.direction.item())

angle = math.pi * 60 / 180
if direction == 1:
Expand All @@ -608,31 +631,45 @@ def display(self, canvas, ball_x, ball_y):
canvas = half_circle(canvas, toX(x), toY(y), toP(r), color=self.c)

# track ball with eyes (replace with observed info later):
c = math.cos(angle)
s = math.sin(angle)
ballX = bx-(x+(0.6)*r*c)
ballY = by-(y+(0.6)*r*s)

dist = math.sqrt(ballX*ballX+ballY*ballY)
eyeX = ballX/dist
eyeY = ballY/dist

canvas = circle(canvas, toX(x+(0.6)*r*c), toY(y+(0.6)*r*s),
toP(r)*0.3, color=(255, 255, 255))
canvas = circle(canvas, toX(x+(0.6)*r*c+eyeX*0.15*r),
toY(y+(0.6)*r*s+eyeY*0.15*r), toP(r)*0.1,
color=(0, 0, 0))
c_angle = math.cos(angle)
s_angle = math.sin(angle)
ballX = bx - (x + 0.6 * r * c_angle)
ballY = by - (y + 0.6 * r * s_angle)

dist = math.sqrt(ballX * ballX + ballY * ballY)
eyeX = ballX / dist
eyeY = ballY / dist

canvas = circle(
canvas,
toX(x + 0.6 * r * c_angle),
toY(y + 0.6 * r * s_angle),
toP(r) * 0.3,
color=(255, 255, 255)
)
canvas = circle(
canvas,
toX(x + 0.6 * r * c_angle + eyeX * 0.15 * r),
toY(y + 0.6 * r * s_angle + eyeY * 0.15 * r),
toP(r) * 0.1,
color=(0, 0, 0)
)

# draw coins (lives) left
num_lives = int(p.life)
num_lives = int(p.life.item())
for i in range(1, num_lives):
canvas = circle(canvas, toX(direction*(REF_W/2+0.5-i*2.)),
WINDOW_HEIGHT-toY(1.5), toP(0.5),
color=COIN_COLOR)
canvas = circle(
canvas,
toX(direction * (REF_W / 2 + 0.5 - i * 2.0)),
WINDOW_HEIGHT - toY(1.5),
toP(0.5),
color=COIN_COLOR
)

return canvas



class Wall:
""" used for the fence, and also the ground """
def __init__(self, x, y, w, h, c):
Expand Down
31 changes: 23 additions & 8 deletions examples/train_slimevolley.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def main(config):
trainer.model_dir = log_dir
trainer.run(demo_mode=True)

# Visualize the policy.
# Visualize the policy
task_reset_fn = jax.jit(test_task.reset)
policy_reset_fn = jax.jit(policy.reset)
step_fn = jax.jit(test_task.step)
Expand All @@ -136,19 +136,34 @@ def main(config):
task_state = task_reset_fn(key)
policy_state = policy_reset_fn(task_state)
screens = []

for _ in range(max_steps):
action, policy_state = action_fn(task_state, best_params, policy_state)
task_state, reward, done = step_fn(task_state, action)
screens.append(SlimeVolley.render(task_state))

gif_file = os.path.join(log_dir, 'slimevolley.gif')
screens[0].save(gif_file, save_all=True, append_images=screens[1:],
duration=40, loop=0)
logger.info('GIF saved to {}.'.format(gif_file))

# Extract scalar values from the batched state
state_numpy = jax.device_get(task_state)
# Assuming the first element of the batch is what we want to visualize
unbatched_state = jax.tree.map(lambda x: x[0], state_numpy)

try:
screen = SlimeVolley.render(unbatched_state)
screens.append(screen)
except Exception as e:
logger.error(f"Error during rendering: {e}")
break

if screens:
gif_file = os.path.join(log_dir, 'slimevolley.gif')
screens[0].save(gif_file, save_all=True, append_images=screens[1:],
duration=40, loop=0)
logger.info('GIF saved to {}.'.format(gif_file))
else:
logger.error('No frames were rendered successfully.')


if __name__ == '__main__':
configs = parse_args()
if configs.gpu_id is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = configs.gpu_id
main(configs)
main(configs)
98 changes: 98 additions & 0 deletions log/slimevolley/SlimeVolley.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
SlimeVolley: 2024-10-25 16:13:04,640 [INFO] EvoJAX SlimeVolley
SlimeVolley: 2024-10-25 16:13:04,640 [INFO] ==============================
SlimeVolley: 2024-10-25 16:13:18,435 [INFO] use_for_loop=False
SlimeVolley: 2024-10-25 16:13:18,444 [INFO] Start to train for 500 iterations.
SlimeVolley: 2024-10-25 16:13:50,393 [INFO] Iter=10, size=128, max=-27.6875, avg=-32.0161, min=-37.3750, std=1.8466
SlimeVolley: 2024-10-25 16:14:18,662 [INFO] Iter=20, size=128, max=-22.1250, avg=-31.0859, min=-36.5625, std=2.1711
SlimeVolley: 2024-10-25 16:14:46,955 [INFO] Iter=30, size=128, max=-22.8750, avg=-29.9482, min=-36.1250, std=2.9832
SlimeVolley: 2024-10-25 16:15:15,193 [INFO] Iter=40, size=128, max=-22.8125, avg=-28.8965, min=-36.5625, std=2.8154
SlimeVolley: 2024-10-25 16:15:43,034 [INFO] Iter=50, size=128, max=-19.5000, avg=-28.2910, min=-35.4375, std=3.6427
SlimeVolley: 2024-10-25 16:15:43,958 [INFO] [TEST] Iter=50, #tests=100, max=-3.0000, avg=-4.8300, min=-5.0000, std=0.4014
SlimeVolley: 2024-10-25 16:16:11,675 [INFO] Iter=60, size=128, max=-11.4375, avg=-26.6978, min=-36.4375, std=4.4829
SlimeVolley: 2024-10-25 16:16:39,569 [INFO] Iter=70, size=128, max=-10.5625, avg=-26.7144, min=-37.2500, std=4.3725
SlimeVolley: 2024-10-25 16:17:07,747 [INFO] Iter=80, size=128, max=-14.7500, avg=-25.8159, min=-35.5625, std=4.3953
SlimeVolley: 2024-10-25 16:17:35,884 [INFO] Iter=90, size=128, max=-12.9375, avg=-26.1147, min=-36.4375, std=5.3905
SlimeVolley: 2024-10-25 16:18:03,634 [INFO] Iter=100, size=128, max=-12.5000, avg=-24.5405, min=-35.5625, std=5.1251
SlimeVolley: 2024-10-25 16:18:03,851 [INFO] [TEST] Iter=100, #tests=100, max=-3.0000, avg=-4.8500, min=-5.0000, std=0.4093
SlimeVolley: 2024-10-25 16:18:31,920 [INFO] Iter=110, size=128, max=-13.1875, avg=-22.8726, min=-34.5000, std=4.9581
SlimeVolley: 2024-10-25 16:18:59,693 [INFO] Iter=120, size=128, max=-12.3750, avg=-22.1538, min=-32.6250, std=4.5800
SlimeVolley: 2024-10-25 16:19:27,462 [INFO] Iter=130, size=128, max=-9.8750, avg=-21.3740, min=-34.2500, std=5.3786
SlimeVolley: 2024-10-25 16:19:55,616 [INFO] Iter=140, size=128, max=-10.1250, avg=-20.8486, min=-33.7500, std=5.2729
SlimeVolley: 2024-10-25 16:20:23,589 [INFO] Iter=150, size=128, max=-8.8750, avg=-20.5293, min=-34.4375, std=5.3173
SlimeVolley: 2024-10-25 16:20:23,807 [INFO] [TEST] Iter=150, #tests=100, max=0.0000, avg=-4.0300, min=-5.0000, std=1.2446
SlimeVolley: 2024-10-25 16:20:51,596 [INFO] Iter=160, size=128, max=-6.1875, avg=-20.3066, min=-33.8750, std=5.7341
SlimeVolley: 2024-10-25 16:21:19,786 [INFO] Iter=170, size=128, max=-8.5625, avg=-19.2305, min=-31.2500, std=5.5489
SlimeVolley: 2024-10-25 16:21:47,856 [INFO] Iter=180, size=128, max=-6.3125, avg=-18.9951, min=-33.1875, std=5.8729
SlimeVolley: 2024-10-25 16:22:15,849 [INFO] Iter=190, size=128, max=-6.5625, avg=-19.1255, min=-33.5625, std=5.9446
SlimeVolley: 2024-10-25 16:22:43,580 [INFO] Iter=200, size=128, max=-4.6875, avg=-19.2197, min=-34.6875, std=6.3952
SlimeVolley: 2024-10-25 16:22:43,797 [INFO] [TEST] Iter=200, #tests=100, max=1.0000, avg=-2.5800, min=-5.0000, std=1.6136
SlimeVolley: 2024-10-25 16:23:11,585 [INFO] Iter=210, size=128, max=-5.3750, avg=-17.8315, min=-30.6875, std=5.7723
SlimeVolley: 2024-10-25 16:23:39,598 [INFO] Iter=220, size=128, max=-4.5625, avg=-15.6294, min=-34.0000, std=6.3958
SlimeVolley: 2024-10-25 16:24:07,345 [INFO] Iter=230, size=128, max=-2.0625, avg=-15.5928, min=-32.0000, std=6.7743
SlimeVolley: 2024-10-25 16:24:35,088 [INFO] Iter=240, size=128, max=-1.6875, avg=-14.7544, min=-32.7500, std=6.4414
SlimeVolley: 2024-10-25 16:25:03,188 [INFO] Iter=250, size=128, max=-3.4375, avg=-15.2583, min=-32.5000, std=6.0321
SlimeVolley: 2024-10-25 16:25:03,406 [INFO] [TEST] Iter=250, #tests=100, max=2.0000, avg=-0.6900, min=-4.0000, std=1.1636
SlimeVolley: 2024-10-25 16:25:31,853 [INFO] Iter=260, size=128, max=-2.8125, avg=-14.3882, min=-32.7500, std=6.8517
SlimeVolley: 2024-10-25 16:25:59,919 [INFO] Iter=270, size=128, max=-0.8750, avg=-13.9692, min=-30.2500, std=6.7459
SlimeVolley: 2024-10-25 16:26:27,877 [INFO] Iter=280, size=128, max=-2.3750, avg=-13.5034, min=-32.8750, std=6.8843
SlimeVolley: 2024-10-25 16:26:55,908 [INFO] Iter=290, size=128, max=-3.0000, avg=-13.7354, min=-28.9375, std=6.7271
SlimeVolley: 2024-10-25 16:27:23,551 [INFO] Iter=300, size=128, max=-0.8125, avg=-13.2305, min=-29.0000, std=6.8488
SlimeVolley: 2024-10-25 16:27:23,774 [INFO] [TEST] Iter=300, #tests=100, max=4.0000, avg=0.1600, min=-2.0000, std=0.9972
SlimeVolley: 2024-10-25 16:27:51,388 [INFO] Iter=310, size=128, max=-0.3750, avg=-13.1182, min=-32.3125, std=7.1701
SlimeVolley: 2024-10-25 16:28:19,521 [INFO] Iter=320, size=128, max=-0.4375, avg=-11.9629, min=-31.1875, std=7.8670
SlimeVolley: 2024-10-25 16:28:47,465 [INFO] Iter=330, size=128, max=-0.5000, avg=-11.8164, min=-33.6875, std=7.1326
SlimeVolley: 2024-10-25 16:29:15,827 [INFO] Iter=340, size=128, max=0.3750, avg=-10.4497, min=-29.6250, std=7.0822
SlimeVolley: 2024-10-25 16:29:44,369 [INFO] Iter=350, size=128, max=0.2500, avg=-9.5391, min=-29.6250, std=6.8241
SlimeVolley: 2024-10-25 16:29:44,587 [INFO] [TEST] Iter=350, #tests=100, max=3.0000, avg=0.1900, min=-1.0000, std=0.7442
SlimeVolley: 2024-10-25 16:30:13,174 [INFO] Iter=360, size=128, max=0.4375, avg=-8.4819, min=-31.2500, std=6.3598
SlimeVolley: 2024-10-25 16:30:40,825 [INFO] Iter=370, size=128, max=0.0625, avg=-8.2344, min=-26.4375, std=5.8934
SlimeVolley: 2024-10-25 16:31:08,775 [INFO] Iter=380, size=128, max=0.1875, avg=-8.5503, min=-26.1875, std=5.9854
SlimeVolley: 2024-10-25 16:31:36,716 [INFO] Iter=390, size=128, max=0.1250, avg=-7.1060, min=-28.8750, std=5.6706
SlimeVolley: 2024-10-25 16:32:04,384 [INFO] Iter=400, size=128, max=0.3125, avg=-6.4419, min=-25.8125, std=5.9703
SlimeVolley: 2024-10-25 16:32:04,598 [INFO] [TEST] Iter=400, #tests=100, max=3.0000, avg=0.3500, min=-1.0000, std=0.7263
SlimeVolley: 2024-10-25 16:32:32,778 [INFO] Iter=410, size=128, max=0.1875, avg=-7.0156, min=-24.5625, std=5.3086
SlimeVolley: 2024-10-25 16:33:00,940 [INFO] Iter=420, size=128, max=-0.1875, avg=-6.1362, min=-19.8750, std=4.0926
SlimeVolley: 2024-10-25 16:33:28,953 [INFO] Iter=430, size=128, max=0.6875, avg=-6.3506, min=-27.6875, std=5.3467
SlimeVolley: 2024-10-25 16:33:56,587 [INFO] Iter=440, size=128, max=0.5000, avg=-5.7314, min=-19.6250, std=4.7511
SlimeVolley: 2024-10-25 16:34:24,380 [INFO] Iter=450, size=128, max=0.5000, avg=-6.0801, min=-25.5000, std=5.5417
SlimeVolley: 2024-10-25 16:34:24,591 [INFO] [TEST] Iter=450, #tests=100, max=2.0000, avg=0.3800, min=-2.0000, std=0.6600
SlimeVolley: 2024-10-25 16:34:52,561 [INFO] Iter=460, size=128, max=0.5625, avg=-6.3086, min=-19.6250, std=5.1042
SlimeVolley: 2024-10-25 16:35:20,585 [INFO] Iter=470, size=128, max=0.3125, avg=-4.7056, min=-21.8125, std=4.5107
SlimeVolley: 2024-10-25 16:35:48,447 [INFO] Iter=480, size=128, max=0.6875, avg=-4.5259, min=-22.7500, std=4.2573
SlimeVolley: 2024-10-25 16:36:16,096 [INFO] Iter=490, size=128, max=0.6875, avg=-4.6367, min=-22.5625, std=4.8341
SlimeVolley: 2024-10-25 16:36:41,075 [INFO] [TEST] Iter=500, #tests=100, max=3.0000, avg=0.4200, min=-1.0000, std=0.8388
SlimeVolley: 2024-10-25 16:36:41,075 [INFO] Training done, best_score=0.4200
SlimeVolley: 2024-10-25 16:36:41,078 [INFO] Loaded model parameters from ./log/slimevolley.
SlimeVolley: 2024-10-25 16:36:41,078 [INFO] Start to test the parameters.
SlimeVolley: 2024-10-25 16:36:41,292 [INFO] [TEST] #tests=100, max=3.0000, avg=0.5100, min=-1.0000, std=0.7810
SlimeVolley: 2024-10-25 16:43:27,997 [INFO] EvoJAX SlimeVolley
SlimeVolley: 2024-10-25 16:43:27,997 [INFO] ==============================
SlimeVolley: 2024-10-25 16:43:28,891 [INFO] use_for_loop=False
SlimeVolley: 2024-10-25 16:43:28,899 [INFO] Start to train for 500 iterations.
SlimeVolley: 2024-10-25 16:43:46,867 [INFO] EvoJAX SlimeVolley
SlimeVolley: 2024-10-25 16:43:46,867 [INFO] ==============================
SlimeVolley: 2024-10-25 16:43:47,423 [INFO] use_for_loop=False
SlimeVolley: 2024-10-25 16:43:47,432 [INFO] Start to train for 10 iterations.
SlimeVolley: 2024-10-25 16:44:16,864 [INFO] [TEST] Iter=10, #tests=100, max=-4.0000, avg=-4.8900, min=-5.0000, std=0.3129
SlimeVolley: 2024-10-25 16:44:16,872 [INFO] Training done, best_score=-4.8900
SlimeVolley: 2024-10-25 16:44:16,874 [INFO] Loaded model parameters from ./log/slimevolley.
SlimeVolley: 2024-10-25 16:44:16,874 [INFO] Start to test the parameters.
SlimeVolley: 2024-10-25 16:44:17,084 [INFO] [TEST] #tests=100, max=-3.0000, avg=-4.7300, min=-5.0000, std=0.5071
SlimeVolley: 2024-10-25 16:46:17,997 [INFO] EvoJAX SlimeVolley
SlimeVolley: 2024-10-25 16:46:17,997 [INFO] ==============================
SlimeVolley: 2024-10-25 16:46:18,695 [INFO] use_for_loop=False
SlimeVolley: 2024-10-25 16:46:18,704 [INFO] Start to train for 10 iterations.
SlimeVolley: 2024-10-25 16:46:48,159 [INFO] [TEST] Iter=10, #tests=100, max=-4.0000, avg=-4.8900, min=-5.0000, std=0.3129
SlimeVolley: 2024-10-25 16:46:48,166 [INFO] Training done, best_score=-4.8900
SlimeVolley: 2024-10-25 16:46:48,167 [INFO] Loaded model parameters from ./log/slimevolley.
SlimeVolley: 2024-10-25 16:46:48,167 [INFO] Start to test the parameters.
SlimeVolley: 2024-10-25 16:46:48,381 [INFO] [TEST] #tests=100, max=-3.0000, avg=-4.7300, min=-5.0000, std=0.5071
SlimeVolley: 2024-10-25 16:46:58,171 [INFO] EvoJAX SlimeVolley
SlimeVolley: 2024-10-25 16:46:58,171 [INFO] ==============================
SlimeVolley: 2024-10-25 16:46:58,845 [INFO] use_for_loop=False
SlimeVolley: 2024-10-25 16:46:58,854 [INFO] Start to train for 10 iterations.
SlimeVolley: 2024-10-25 16:47:28,143 [INFO] [TEST] Iter=10, #tests=100, max=-4.0000, avg=-4.8900, min=-5.0000, std=0.3129
SlimeVolley: 2024-10-25 16:47:28,149 [INFO] Training done, best_score=-4.8900
SlimeVolley: 2024-10-25 16:47:28,150 [INFO] Loaded model parameters from ./log/slimevolley.
SlimeVolley: 2024-10-25 16:47:28,150 [INFO] Start to test the parameters.
SlimeVolley: 2024-10-25 16:47:28,364 [INFO] [TEST] #tests=100, max=-3.0000, avg=-4.7300, min=-5.0000, std=0.5071