From 0d5476331e75eaf1b635914a20d9471b44e2854d Mon Sep 17 00:00:00 2001 From: DorsaRoh Date: Fri, 25 Oct 2024 16:48:02 -0400 Subject: [PATCH 1/6] attempt: fix slime --- evojax/task/slimevolley.py | 16 +++--- examples/train_slimevolley.py | 2 +- log/slimevolley/SlimeVolley.txt | 98 +++++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 8 deletions(-) create mode 100644 log/slimevolley/SlimeVolley.txt diff --git a/evojax/task/slimevolley.py b/evojax/task/slimevolley.py index 3e12c988..0d284af3 100644 --- a/evojax/task/slimevolley.py +++ b/evojax/task/slimevolley.py @@ -38,6 +38,8 @@ import math import numpy as np +import jax.numpy as jnp + from typing import Tuple @@ -591,13 +593,13 @@ def getObservation(self): return getObsArray(self.state) def display(self, canvas, ball_x, ball_y): - bx = float(ball_x) - by = float(ball_y) + bx = float(jnp.squeeze(ball_x)) + by = float(jnp.squeeze(ball_y)) p = self.p - x = float(p.x) - y = float(p.y) - r = float(p.r) - direction = int(p.direction) + x = float(jnp.squeeze(p.x)) + y = float(jnp.squeeze(p.y)) + r = float(jnp.squeeze(p.r)) + direction = int(jnp.squeeze(p.direction)) angle = math.pi * 60 / 180 if direction == 1: @@ -624,7 +626,7 @@ def display(self, canvas, ball_x, ball_y): color=(0, 0, 0)) # draw coins (lives) left - num_lives = int(p.life) + num_lives = int(jnp.squeeze(p.life)) for i in range(1, num_lives): canvas = circle(canvas, toX(direction*(REF_W/2+0.5-i*2.)), WINDOW_HEIGHT-toY(1.5), toP(0.5), diff --git a/examples/train_slimevolley.py b/examples/train_slimevolley.py index 6f694831..72a532f1 100644 --- a/examples/train_slimevolley.py +++ b/examples/train_slimevolley.py @@ -59,7 +59,7 @@ def parse_args(): parser.add_argument( '--n-repeats', type=int, default=16, help='Training repetitions.') parser.add_argument( - '--max-iter', type=int, default=500, help='Max training iterations.') + '--max-iter', type=int, default=10, help='Max training iterations.') parser.add_argument( '--test-interval', type=int, default=50, help='Test interval.') parser.add_argument( diff --git a/log/slimevolley/SlimeVolley.txt b/log/slimevolley/SlimeVolley.txt new file mode 100644 index 00000000..7d249157 --- /dev/null +++ b/log/slimevolley/SlimeVolley.txt @@ -0,0 +1,98 @@ +SlimeVolley: 2024-10-25 16:13:04,640 [INFO] EvoJAX SlimeVolley +SlimeVolley: 2024-10-25 16:13:04,640 [INFO] ============================== +SlimeVolley: 2024-10-25 16:13:18,435 [INFO] use_for_loop=False +SlimeVolley: 2024-10-25 16:13:18,444 [INFO] Start to train for 500 iterations. +SlimeVolley: 2024-10-25 16:13:50,393 [INFO] Iter=10, size=128, max=-27.6875, avg=-32.0161, min=-37.3750, std=1.8466 +SlimeVolley: 2024-10-25 16:14:18,662 [INFO] Iter=20, size=128, max=-22.1250, avg=-31.0859, min=-36.5625, std=2.1711 +SlimeVolley: 2024-10-25 16:14:46,955 [INFO] Iter=30, size=128, max=-22.8750, avg=-29.9482, min=-36.1250, std=2.9832 +SlimeVolley: 2024-10-25 16:15:15,193 [INFO] Iter=40, size=128, max=-22.8125, avg=-28.8965, min=-36.5625, std=2.8154 +SlimeVolley: 2024-10-25 16:15:43,034 [INFO] Iter=50, size=128, max=-19.5000, avg=-28.2910, min=-35.4375, std=3.6427 +SlimeVolley: 2024-10-25 16:15:43,958 [INFO] [TEST] Iter=50, #tests=100, max=-3.0000, avg=-4.8300, min=-5.0000, std=0.4014 +SlimeVolley: 2024-10-25 16:16:11,675 [INFO] Iter=60, size=128, max=-11.4375, avg=-26.6978, min=-36.4375, std=4.4829 +SlimeVolley: 2024-10-25 16:16:39,569 [INFO] Iter=70, size=128, max=-10.5625, avg=-26.7144, min=-37.2500, std=4.3725 +SlimeVolley: 2024-10-25 16:17:07,747 [INFO] Iter=80, size=128, max=-14.7500, avg=-25.8159, min=-35.5625, std=4.3953 +SlimeVolley: 2024-10-25 16:17:35,884 [INFO] Iter=90, size=128, max=-12.9375, avg=-26.1147, min=-36.4375, std=5.3905 +SlimeVolley: 2024-10-25 16:18:03,634 [INFO] Iter=100, size=128, max=-12.5000, avg=-24.5405, min=-35.5625, std=5.1251 +SlimeVolley: 2024-10-25 16:18:03,851 [INFO] [TEST] Iter=100, #tests=100, max=-3.0000, avg=-4.8500, min=-5.0000, std=0.4093 +SlimeVolley: 2024-10-25 16:18:31,920 [INFO] Iter=110, size=128, max=-13.1875, avg=-22.8726, min=-34.5000, std=4.9581 +SlimeVolley: 2024-10-25 16:18:59,693 [INFO] Iter=120, size=128, max=-12.3750, avg=-22.1538, min=-32.6250, std=4.5800 +SlimeVolley: 2024-10-25 16:19:27,462 [INFO] Iter=130, size=128, max=-9.8750, avg=-21.3740, min=-34.2500, std=5.3786 +SlimeVolley: 2024-10-25 16:19:55,616 [INFO] Iter=140, size=128, max=-10.1250, avg=-20.8486, min=-33.7500, std=5.2729 +SlimeVolley: 2024-10-25 16:20:23,589 [INFO] Iter=150, size=128, max=-8.8750, avg=-20.5293, min=-34.4375, std=5.3173 +SlimeVolley: 2024-10-25 16:20:23,807 [INFO] [TEST] Iter=150, #tests=100, max=0.0000, avg=-4.0300, min=-5.0000, std=1.2446 +SlimeVolley: 2024-10-25 16:20:51,596 [INFO] Iter=160, size=128, max=-6.1875, avg=-20.3066, min=-33.8750, std=5.7341 +SlimeVolley: 2024-10-25 16:21:19,786 [INFO] Iter=170, size=128, max=-8.5625, avg=-19.2305, min=-31.2500, std=5.5489 +SlimeVolley: 2024-10-25 16:21:47,856 [INFO] Iter=180, size=128, max=-6.3125, avg=-18.9951, min=-33.1875, std=5.8729 +SlimeVolley: 2024-10-25 16:22:15,849 [INFO] Iter=190, size=128, max=-6.5625, avg=-19.1255, min=-33.5625, std=5.9446 +SlimeVolley: 2024-10-25 16:22:43,580 [INFO] Iter=200, size=128, max=-4.6875, avg=-19.2197, min=-34.6875, std=6.3952 +SlimeVolley: 2024-10-25 16:22:43,797 [INFO] [TEST] Iter=200, #tests=100, max=1.0000, avg=-2.5800, min=-5.0000, std=1.6136 +SlimeVolley: 2024-10-25 16:23:11,585 [INFO] Iter=210, size=128, max=-5.3750, avg=-17.8315, min=-30.6875, std=5.7723 +SlimeVolley: 2024-10-25 16:23:39,598 [INFO] Iter=220, size=128, max=-4.5625, avg=-15.6294, min=-34.0000, std=6.3958 +SlimeVolley: 2024-10-25 16:24:07,345 [INFO] Iter=230, size=128, max=-2.0625, avg=-15.5928, min=-32.0000, std=6.7743 +SlimeVolley: 2024-10-25 16:24:35,088 [INFO] Iter=240, size=128, max=-1.6875, avg=-14.7544, min=-32.7500, std=6.4414 +SlimeVolley: 2024-10-25 16:25:03,188 [INFO] Iter=250, size=128, max=-3.4375, avg=-15.2583, min=-32.5000, std=6.0321 +SlimeVolley: 2024-10-25 16:25:03,406 [INFO] [TEST] Iter=250, #tests=100, max=2.0000, avg=-0.6900, min=-4.0000, std=1.1636 +SlimeVolley: 2024-10-25 16:25:31,853 [INFO] Iter=260, size=128, max=-2.8125, avg=-14.3882, min=-32.7500, std=6.8517 +SlimeVolley: 2024-10-25 16:25:59,919 [INFO] Iter=270, size=128, max=-0.8750, avg=-13.9692, min=-30.2500, std=6.7459 +SlimeVolley: 2024-10-25 16:26:27,877 [INFO] Iter=280, size=128, max=-2.3750, avg=-13.5034, min=-32.8750, std=6.8843 +SlimeVolley: 2024-10-25 16:26:55,908 [INFO] Iter=290, size=128, max=-3.0000, avg=-13.7354, min=-28.9375, std=6.7271 +SlimeVolley: 2024-10-25 16:27:23,551 [INFO] Iter=300, size=128, max=-0.8125, avg=-13.2305, min=-29.0000, std=6.8488 +SlimeVolley: 2024-10-25 16:27:23,774 [INFO] [TEST] Iter=300, #tests=100, max=4.0000, avg=0.1600, min=-2.0000, std=0.9972 +SlimeVolley: 2024-10-25 16:27:51,388 [INFO] Iter=310, size=128, max=-0.3750, avg=-13.1182, min=-32.3125, std=7.1701 +SlimeVolley: 2024-10-25 16:28:19,521 [INFO] Iter=320, size=128, max=-0.4375, avg=-11.9629, min=-31.1875, std=7.8670 +SlimeVolley: 2024-10-25 16:28:47,465 [INFO] Iter=330, size=128, max=-0.5000, avg=-11.8164, min=-33.6875, std=7.1326 +SlimeVolley: 2024-10-25 16:29:15,827 [INFO] Iter=340, size=128, max=0.3750, avg=-10.4497, min=-29.6250, std=7.0822 +SlimeVolley: 2024-10-25 16:29:44,369 [INFO] Iter=350, size=128, max=0.2500, avg=-9.5391, min=-29.6250, std=6.8241 +SlimeVolley: 2024-10-25 16:29:44,587 [INFO] [TEST] Iter=350, #tests=100, max=3.0000, avg=0.1900, min=-1.0000, std=0.7442 +SlimeVolley: 2024-10-25 16:30:13,174 [INFO] Iter=360, size=128, max=0.4375, avg=-8.4819, min=-31.2500, std=6.3598 +SlimeVolley: 2024-10-25 16:30:40,825 [INFO] Iter=370, size=128, max=0.0625, avg=-8.2344, min=-26.4375, std=5.8934 +SlimeVolley: 2024-10-25 16:31:08,775 [INFO] Iter=380, size=128, max=0.1875, avg=-8.5503, min=-26.1875, std=5.9854 +SlimeVolley: 2024-10-25 16:31:36,716 [INFO] Iter=390, size=128, max=0.1250, avg=-7.1060, min=-28.8750, std=5.6706 +SlimeVolley: 2024-10-25 16:32:04,384 [INFO] Iter=400, size=128, max=0.3125, avg=-6.4419, min=-25.8125, std=5.9703 +SlimeVolley: 2024-10-25 16:32:04,598 [INFO] [TEST] Iter=400, #tests=100, max=3.0000, avg=0.3500, min=-1.0000, std=0.7263 +SlimeVolley: 2024-10-25 16:32:32,778 [INFO] Iter=410, size=128, max=0.1875, avg=-7.0156, min=-24.5625, std=5.3086 +SlimeVolley: 2024-10-25 16:33:00,940 [INFO] Iter=420, size=128, max=-0.1875, avg=-6.1362, min=-19.8750, std=4.0926 +SlimeVolley: 2024-10-25 16:33:28,953 [INFO] Iter=430, size=128, max=0.6875, avg=-6.3506, min=-27.6875, std=5.3467 +SlimeVolley: 2024-10-25 16:33:56,587 [INFO] Iter=440, size=128, max=0.5000, avg=-5.7314, min=-19.6250, std=4.7511 +SlimeVolley: 2024-10-25 16:34:24,380 [INFO] Iter=450, size=128, max=0.5000, avg=-6.0801, min=-25.5000, std=5.5417 +SlimeVolley: 2024-10-25 16:34:24,591 [INFO] [TEST] Iter=450, #tests=100, max=2.0000, avg=0.3800, min=-2.0000, std=0.6600 +SlimeVolley: 2024-10-25 16:34:52,561 [INFO] Iter=460, size=128, max=0.5625, avg=-6.3086, min=-19.6250, std=5.1042 +SlimeVolley: 2024-10-25 16:35:20,585 [INFO] Iter=470, size=128, max=0.3125, avg=-4.7056, min=-21.8125, std=4.5107 +SlimeVolley: 2024-10-25 16:35:48,447 [INFO] Iter=480, size=128, max=0.6875, avg=-4.5259, min=-22.7500, std=4.2573 +SlimeVolley: 2024-10-25 16:36:16,096 [INFO] Iter=490, size=128, max=0.6875, avg=-4.6367, min=-22.5625, std=4.8341 +SlimeVolley: 2024-10-25 16:36:41,075 [INFO] [TEST] Iter=500, #tests=100, max=3.0000, avg=0.4200, min=-1.0000, std=0.8388 +SlimeVolley: 2024-10-25 16:36:41,075 [INFO] Training done, best_score=0.4200 +SlimeVolley: 2024-10-25 16:36:41,078 [INFO] Loaded model parameters from ./log/slimevolley. +SlimeVolley: 2024-10-25 16:36:41,078 [INFO] Start to test the parameters. +SlimeVolley: 2024-10-25 16:36:41,292 [INFO] [TEST] #tests=100, max=3.0000, avg=0.5100, min=-1.0000, std=0.7810 +SlimeVolley: 2024-10-25 16:43:27,997 [INFO] EvoJAX SlimeVolley +SlimeVolley: 2024-10-25 16:43:27,997 [INFO] ============================== +SlimeVolley: 2024-10-25 16:43:28,891 [INFO] use_for_loop=False +SlimeVolley: 2024-10-25 16:43:28,899 [INFO] Start to train for 500 iterations. +SlimeVolley: 2024-10-25 16:43:46,867 [INFO] EvoJAX SlimeVolley +SlimeVolley: 2024-10-25 16:43:46,867 [INFO] ============================== +SlimeVolley: 2024-10-25 16:43:47,423 [INFO] use_for_loop=False +SlimeVolley: 2024-10-25 16:43:47,432 [INFO] Start to train for 10 iterations. +SlimeVolley: 2024-10-25 16:44:16,864 [INFO] [TEST] Iter=10, #tests=100, max=-4.0000, avg=-4.8900, min=-5.0000, std=0.3129 +SlimeVolley: 2024-10-25 16:44:16,872 [INFO] Training done, best_score=-4.8900 +SlimeVolley: 2024-10-25 16:44:16,874 [INFO] Loaded model parameters from ./log/slimevolley. +SlimeVolley: 2024-10-25 16:44:16,874 [INFO] Start to test the parameters. +SlimeVolley: 2024-10-25 16:44:17,084 [INFO] [TEST] #tests=100, max=-3.0000, avg=-4.7300, min=-5.0000, std=0.5071 +SlimeVolley: 2024-10-25 16:46:17,997 [INFO] EvoJAX SlimeVolley +SlimeVolley: 2024-10-25 16:46:17,997 [INFO] ============================== +SlimeVolley: 2024-10-25 16:46:18,695 [INFO] use_for_loop=False +SlimeVolley: 2024-10-25 16:46:18,704 [INFO] Start to train for 10 iterations. +SlimeVolley: 2024-10-25 16:46:48,159 [INFO] [TEST] Iter=10, #tests=100, max=-4.0000, avg=-4.8900, min=-5.0000, std=0.3129 +SlimeVolley: 2024-10-25 16:46:48,166 [INFO] Training done, best_score=-4.8900 +SlimeVolley: 2024-10-25 16:46:48,167 [INFO] Loaded model parameters from ./log/slimevolley. +SlimeVolley: 2024-10-25 16:46:48,167 [INFO] Start to test the parameters. +SlimeVolley: 2024-10-25 16:46:48,381 [INFO] [TEST] #tests=100, max=-3.0000, avg=-4.7300, min=-5.0000, std=0.5071 +SlimeVolley: 2024-10-25 16:46:58,171 [INFO] EvoJAX SlimeVolley +SlimeVolley: 2024-10-25 16:46:58,171 [INFO] ============================== +SlimeVolley: 2024-10-25 16:46:58,845 [INFO] use_for_loop=False +SlimeVolley: 2024-10-25 16:46:58,854 [INFO] Start to train for 10 iterations. +SlimeVolley: 2024-10-25 16:47:28,143 [INFO] [TEST] Iter=10, #tests=100, max=-4.0000, avg=-4.8900, min=-5.0000, std=0.3129 +SlimeVolley: 2024-10-25 16:47:28,149 [INFO] Training done, best_score=-4.8900 +SlimeVolley: 2024-10-25 16:47:28,150 [INFO] Loaded model parameters from ./log/slimevolley. +SlimeVolley: 2024-10-25 16:47:28,150 [INFO] Start to test the parameters. +SlimeVolley: 2024-10-25 16:47:28,364 [INFO] [TEST] #tests=100, max=-3.0000, avg=-4.7300, min=-5.0000, std=0.5071 From e35e5352f730b75da0d0a21e7679971a52427562 Mon Sep 17 00:00:00 2001 From: DorsaRoh Date: Fri, 25 Oct 2024 17:11:54 -0400 Subject: [PATCH 2/6] fix: error in slimevolley --- evojax/task/slimevolley.py | 103 +++++++++++++++++++++++++------------ 1 file changed, 69 insertions(+), 34 deletions(-) diff --git a/evojax/task/slimevolley.py b/evojax/task/slimevolley.py index 0d284af3..0d3d2dc9 100644 --- a/evojax/task/slimevolley.py +++ b/evojax/task/slimevolley.py @@ -38,8 +38,6 @@ import math import numpy as np -import jax.numpy as jnp - from typing import Tuple @@ -271,7 +269,18 @@ class AgentState(object): def initAgentState(direction, x, y): - return AgentState(direction, x, y, 1.5, 0, 0, 0, 0, MAXLIVES) + return AgentState( + direction=jnp.int32(direction), + x=jnp.float32(x), + y=jnp.float32(y), + r=jnp.float32(1.5), + vx=jnp.float32(0), + vy=jnp.float32(0), + desired_vx=jnp.float32(0), + desired_vy=jnp.float32(0), + life=jnp.int32(MAXLIVES) + ) + @dataclass @@ -337,8 +346,10 @@ def __init__(self, p: ParticleState, c): self.c = c def display(self, canvas): - return circle(canvas, toX(float(self.p.x)), toY(float(self.p.y)), - toP(float(self.p.r)), color=self.c) + x = float(self.p.x.item()) + y = float(self.p.y.item()) + r = float(self.p.r.item()) + return circle(canvas, toX(x), toY(y), toP(r), color=self.c) def move(self): self.p = ParticleState(self.p.x+self.p.vx*TIMESTEP, @@ -560,12 +571,22 @@ def update(self): p.life) def updateLife(self, result): - """ updates the life based on result and internal direction """ + """Updates the life based on result and internal direction.""" p = self.p - updateAmount = p.direction*result # only update if this value is -1 - new_life = jnp.where(updateAmount < 0, p.life-1, p.life) - self.p = AgentState(p.direction, p.x, p.y, p.r, p.vx, p.vy, - p.desired_vx, p.desired_vy, new_life) + updateAmount = p.direction * result # This should be a scalar + new_life = jnp.where(updateAmount < 0, p.life - 1, p.life) + self.p = AgentState( + direction=p.direction, + x=p.x, + y=p.y, + r=p.r, + vx=p.vx, + vy=p.vy, + desired_vx=p.desired_vx, + desired_vy=p.desired_vy, + life=new_life + ) + def updateState(self, ball: ParticleState, opponent: AgentState): """ normalized to side, customized for each agent's perspective""" @@ -593,13 +614,13 @@ def getObservation(self): return getObsArray(self.state) def display(self, canvas, ball_x, ball_y): - bx = float(jnp.squeeze(ball_x)) - by = float(jnp.squeeze(ball_y)) + bx = float(ball_x.item()) + by = float(ball_y.item()) p = self.p - x = float(jnp.squeeze(p.x)) - y = float(jnp.squeeze(p.y)) - r = float(jnp.squeeze(p.r)) - direction = int(jnp.squeeze(p.direction)) + x = float(p.x.item()) + y = float(p.y.item()) + r = float(p.r.item()) + direction = int(p.direction.item()) angle = math.pi * 60 / 180 if direction == 1: @@ -610,31 +631,45 @@ def display(self, canvas, ball_x, ball_y): canvas = half_circle(canvas, toX(x), toY(y), toP(r), color=self.c) # track ball with eyes (replace with observed info later): - c = math.cos(angle) - s = math.sin(angle) - ballX = bx-(x+(0.6)*r*c) - ballY = by-(y+(0.6)*r*s) - - dist = math.sqrt(ballX*ballX+ballY*ballY) - eyeX = ballX/dist - eyeY = ballY/dist - - canvas = circle(canvas, toX(x+(0.6)*r*c), toY(y+(0.6)*r*s), - toP(r)*0.3, color=(255, 255, 255)) - canvas = circle(canvas, toX(x+(0.6)*r*c+eyeX*0.15*r), - toY(y+(0.6)*r*s+eyeY*0.15*r), toP(r)*0.1, - color=(0, 0, 0)) + c_angle = math.cos(angle) + s_angle = math.sin(angle) + ballX = bx - (x + 0.6 * r * c_angle) + ballY = by - (y + 0.6 * r * s_angle) + + dist = math.sqrt(ballX * ballX + ballY * ballY) + eyeX = ballX / dist + eyeY = ballY / dist + + canvas = circle( + canvas, + toX(x + 0.6 * r * c_angle), + toY(y + 0.6 * r * s_angle), + toP(r) * 0.3, + color=(255, 255, 255) + ) + canvas = circle( + canvas, + toX(x + 0.6 * r * c_angle + eyeX * 0.15 * r), + toY(y + 0.6 * r * s_angle + eyeY * 0.15 * r), + toP(r) * 0.1, + color=(0, 0, 0) + ) # draw coins (lives) left - num_lives = int(jnp.squeeze(p.life)) + num_lives = int(p.life.item()) for i in range(1, num_lives): - canvas = circle(canvas, toX(direction*(REF_W/2+0.5-i*2.)), - WINDOW_HEIGHT-toY(1.5), toP(0.5), - color=COIN_COLOR) + canvas = circle( + canvas, + toX(direction * (REF_W / 2 + 0.5 - i * 2.0)), + WINDOW_HEIGHT - toY(1.5), + toP(0.5), + color=COIN_COLOR + ) return canvas + class Wall: """ used for the fence, and also the ground """ def __init__(self, x, y, w, h, c): From 2a7963120e4fa478c13a8a9bfda5624ee12cd8d5 Mon Sep 17 00:00:00 2001 From: DorsaRoh Date: Fri, 25 Oct 2024 17:15:44 -0400 Subject: [PATCH 3/6] change back default iterations to 500 --- examples/train_slimevolley.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/train_slimevolley.py b/examples/train_slimevolley.py index 72a532f1..6f694831 100644 --- a/examples/train_slimevolley.py +++ b/examples/train_slimevolley.py @@ -59,7 +59,7 @@ def parse_args(): parser.add_argument( '--n-repeats', type=int, default=16, help='Training repetitions.') parser.add_argument( - '--max-iter', type=int, default=10, help='Max training iterations.') + '--max-iter', type=int, default=500, help='Max training iterations.') parser.add_argument( '--test-interval', type=int, default=50, help='Test interval.') parser.add_argument( From e747f9fdef00b766a694d916a36aef052ec85c88 Mon Sep 17 00:00:00 2001 From: DorsaRoh Date: Fri, 25 Oct 2024 17:26:16 -0400 Subject: [PATCH 4/6] recheck cla --- evojax/task/slimevolley.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evojax/task/slimevolley.py b/evojax/task/slimevolley.py index 0d3d2dc9..af080376 100644 --- a/evojax/task/slimevolley.py +++ b/evojax/task/slimevolley.py @@ -349,7 +349,7 @@ def display(self, canvas): x = float(self.p.x.item()) y = float(self.p.y.item()) r = float(self.p.r.item()) - return circle(canvas, toX(x), toY(y), toP(r), color=self.c) + return circle(canvas, toX(x), toY(y), toP(r), color=self.c) def move(self): self.p = ParticleState(self.p.x+self.p.vx*TIMESTEP, From 1d8099f876b2d608af176b888c9244f052633691 Mon Sep 17 00:00:00 2001 From: DorsaRoh Date: Fri, 25 Oct 2024 18:09:02 -0400 Subject: [PATCH 5/6] fix: JAX dimension error --- examples/train_slimevolley.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/examples/train_slimevolley.py b/examples/train_slimevolley.py index 6f694831..1f03eccd 100644 --- a/examples/train_slimevolley.py +++ b/examples/train_slimevolley.py @@ -59,7 +59,7 @@ def parse_args(): parser.add_argument( '--n-repeats', type=int, default=16, help='Training repetitions.') parser.add_argument( - '--max-iter', type=int, default=500, help='Max training iterations.') + '--max-iter', type=int, default=10, help='Max training iterations.') parser.add_argument( '--test-interval', type=int, default=50, help='Test interval.') parser.add_argument( @@ -125,7 +125,7 @@ def main(config): trainer.model_dir = log_dir trainer.run(demo_mode=True) - # Visualize the policy. + # Visualize the policy task_reset_fn = jax.jit(test_task.reset) policy_reset_fn = jax.jit(policy.reset) step_fn = jax.jit(test_task.step) @@ -136,19 +136,34 @@ def main(config): task_state = task_reset_fn(key) policy_state = policy_reset_fn(task_state) screens = [] + for _ in range(max_steps): action, policy_state = action_fn(task_state, best_params, policy_state) task_state, reward, done = step_fn(task_state, action) - screens.append(SlimeVolley.render(task_state)) - - gif_file = os.path.join(log_dir, 'slimevolley.gif') - screens[0].save(gif_file, save_all=True, append_images=screens[1:], - duration=40, loop=0) - logger.info('GIF saved to {}.'.format(gif_file)) + + # Extract scalar values from the batched state + state_numpy = jax.device_get(task_state) + # Assuming the first element of the batch is what we want to visualize + unbatched_state = jax.tree_map(lambda x: x[0], state_numpy) + + try: + screen = SlimeVolley.render(unbatched_state) + screens.append(screen) + except Exception as e: + logger.error(f"Error during rendering: {e}") + break + + if screens: + gif_file = os.path.join(log_dir, 'slimevolley.gif') + screens[0].save(gif_file, save_all=True, append_images=screens[1:], + duration=40, loop=0) + logger.info('GIF saved to {}.'.format(gif_file)) + else: + logger.error('No frames were rendered successfully.') if __name__ == '__main__': configs = parse_args() if configs.gpu_id is not None: os.environ['CUDA_VISIBLE_DEVICES'] = configs.gpu_id - main(configs) + main(configs) \ No newline at end of file From 6530e3f633b336593676bcd298a1c0183591b036 Mon Sep 17 00:00:00 2001 From: DorsaRoh Date: Fri, 25 Oct 2024 18:15:11 -0400 Subject: [PATCH 6/6] fix: deprecated jax tree.map --- examples/train_slimevolley.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/train_slimevolley.py b/examples/train_slimevolley.py index 1f03eccd..8b4205be 100644 --- a/examples/train_slimevolley.py +++ b/examples/train_slimevolley.py @@ -59,7 +59,7 @@ def parse_args(): parser.add_argument( '--n-repeats', type=int, default=16, help='Training repetitions.') parser.add_argument( - '--max-iter', type=int, default=10, help='Max training iterations.') + '--max-iter', type=int, default=500, help='Max training iterations.') parser.add_argument( '--test-interval', type=int, default=50, help='Test interval.') parser.add_argument( @@ -144,7 +144,7 @@ def main(config): # Extract scalar values from the batched state state_numpy = jax.device_get(task_state) # Assuming the first element of the batch is what we want to visualize - unbatched_state = jax.tree_map(lambda x: x[0], state_numpy) + unbatched_state = jax.tree.map(lambda x: x[0], state_numpy) try: screen = SlimeVolley.render(unbatched_state)