diff --git a/.bazelrc b/.bazelrc index 2584bd43..4640259f 100644 --- a/.bazelrc +++ b/.bazelrc @@ -2,7 +2,9 @@ build --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-lm build --action_env=BAZEL_LINKOPTS=-static-libgcc build --action_env=CUDA_DIR=/usr/local/cuda build --action_env=LD_LIBRARY_PATH=/usr/local/lib:/usr/lib/nvidia -build --incompatible_strict_action_env --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --client_env=BAZEL_CXXOPTS=-std=c++17 +build --incompatible_strict_action_env --cxxopt=-std=c++20 --host_cxxopt=-std=c++20 --client_env=BAZEL_CXXOPTS=-std=c++20 +build --action_env=CC=clang +build --action_env=CXX=clang++ build:debug --cxxopt=-DENVPOOL_TEST --compilation_mode=dbg -s build:test --cxxopt=-DENVPOOL_TEST --copt=-g0 --copt=-O3 --copt=-DNDEBUG --copt=-msse --copt=-msse2 --copt=-mmmx build:release --copt=-g0 --copt=-O3 --copt=-DNDEBUG --copt=-msse --copt=-msse2 --copt=-mmmx diff --git a/.circleci/config.yml b/.circleci/config.yml index 2fbcb01f..a2855ec2 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -9,7 +9,7 @@ parameters: docker_img_version: # Docker image version for running tests. type: string - default: "8d8cf1a-envpool-ci" + default: "614e97a-envpool-devbox" workflows: test-jobs: @@ -26,7 +26,7 @@ workflows: jobs: lint: docker: - - image: ghcr.io/alignmentresearch/learned-planners:<< pipeline.parameters.docker_img_version >> + - image: ghcr.io/alignmentresearch/train-learned-planner:<< pipeline.parameters.docker_img_version >> auth: username: "$GHCR_DOCKER_USER" password: "$GHCR_DOCKER_TOKEN" @@ -51,10 +51,10 @@ jobs: name: clang-format command: | make clang-format - - run: - name: clang-tidy - command: | - make clang-tidy + # - run: + # name: clang-tidy + # command: | + # make clang-tidy - run: name: buildifier command: | @@ -67,7 +67,7 @@ jobs: tests: docker: - - image: ghcr.io/alignmentresearch/learned-planners:<< pipeline.parameters.docker_img_version >> + - image: ghcr.io/alignmentresearch/train-learned-planner:<< pipeline.parameters.docker_img_version >> auth: username: "$GHCR_DOCKER_USER" password: "$GHCR_DOCKER_TOKEN" diff --git a/.clang-tidy b/.clang-tidy index fed4549b..771a0733 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -54,4 +54,3 @@ CheckOptions: - { key: readability-identifier-naming.VariableCase, value: lower_case } WarningsAsErrors: '*' HeaderFilterRegex: '/envpool/' -AnalyzeTemporaryDtors: true diff --git a/.gitignore b/.gitignore index d784028f..ede4e92e 100644 --- a/.gitignore +++ b/.gitignore @@ -148,3 +148,6 @@ log _vizdoom* MUJOCO_LOG.TXT .vscode/ + + +MODULE.bazel.lock diff --git a/BUILD b/BUILD index 8050ac27..865d8fce 100644 --- a/BUILD +++ b/BUILD @@ -1,4 +1,5 @@ load("@pip_requirements//:requirements.bzl", "requirement") +load("@rules_python//python:packaging.bzl", "py_package", "py_wheel", "py_wheel_dist") filegroup( name = "clang_tidy_config", @@ -22,3 +23,39 @@ py_binary( requirement("wheel"), ], ) + +# Collect transitive dependencies of envpool +py_package( + name = "pkg", + packages = [], + deps = ["//envpool"], +) + +py_wheel( + name = "wheel", + abi = "cp312", + distribution = "far_envpool", + platform = "manylinux2014_x86_64", + python_tag = "cp312", + requires=[ + "numpy>=2.2.0", + "dm-env>=1.6", + "gym>=0.26", + "gymnasium>=0.26,!=0.27.0", + "optree>=0.6.0", + "jax>=0.5.0", + "pytest", + ], + python_requires = ">=3.10,<3.13", + twine = None, + version = "0.9.0", + deps = [ + ":pkg", + ], +) + +py_wheel_dist( + name = "wheel_dist", + out = "dist", + wheel = "wheel", +) diff --git a/MODULE.bazel b/MODULE.bazel new file mode 100644 index 00000000..ce30cf05 --- /dev/null +++ b/MODULE.bazel @@ -0,0 +1,447 @@ +module( + name = "envpool", + repo_name = "envpool", + # Adjust the version if you wish. + version = "0.0.1", +) + +# Pull in rules_python with Bzlmod +bazel_dep(name = "rules_python", version = "1.1.0") + +# +# 1) Declare and configure the Python toolchain for Python 3.12. +# +python = use_extension("@rules_python//python/extensions:python.bzl", "python") +python.toolchain( + python_version = "3.12", + is_default = True, +) + +# Actually create the repos for the python_3_12 toolchain +use_repo(python, "python_3_12") + +# Load CUDA module extension and use it +cuda_configure = use_repo_rule("//third_party/cuda:cuda.bzl", "cuda_configure") +cuda_configure(name="cuda") + +# +# 2) Parse and install your pip packages. This uses the new pip extension +# to directly list dependencies in the MODULE.bazel file. +# +pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip") +pip.parse( + # "hub_name" is an identifier for the generated repos (like "pip__wheel__numpy__...") + hub_name = "pip_requirements", + python_version = "3.12", + # We list dependencies for "linux_x86_64" by referencing our file + requirements_by_platform = { + "//third_party/pip_requirements:requirements-dev-locked.txt": "linux_x86_64", + }, +) + +# Actually create the repos for these pip dependencies +use_repo(pip, "pip_requirements") + +# Other dependencies +bazel_dep(name = "platforms", version = "0.0.11") +bazel_dep(name = "pybind11_bazel", version = "2.13.6") + +# Add bazel_skylib dependency +bazel_dep(name = "bazel_skylib", version = "1.7.1") + +# rules_foreign_cc dependency with toolchain registration +bazel_dep(name = "rules_foreign_cc", version = "0.14.0") +# Register all the toolchains for rules_foreign_cc +register_toolchains("@rules_foreign_cc//toolchains:all") + +bazel_dep(name = "glog") +archive_override( + module_name = "glog", + urls = ["https://github.com/google/glog/archive/4f007d96212d3dfd11dfaaf9ed7758fd1ea37a25.tar.gz"], + strip_prefix = "glog-4f007d96212d3dfd11dfaaf9ed7758fd1ea37a25", + integrity = "sha256-9hkaq2gE/U4mGZinLTBbLpi1wXgkThpSwBzEBCSF/Hw=", +) + +bazel_dep(name = "rules_boost", repo_name = "com_github_nelhage_rules_boost") +archive_override( + module_name = "rules_boost", + urls = ["https://github.com/nelhage/rules_boost/archive/refs/heads/master.tar.gz"], + strip_prefix = "rules_boost-master", + integrity = "sha256-MKo5D1Ifiiwk2OUaZCwOfZQTmgUfASeve9ug5CeUC8g=", +) + +non_module_boost_repositories = use_extension("@com_github_nelhage_rules_boost//:boost/repositories.bzl", "non_module_dependencies") +use_repo( + non_module_boost_repositories, + "boost", +) + +http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +# NOTE: The best way to configure Qt with Bazel MODULE.bazel would be: +# 1. Use qt_configure to create a local_config_qt repository +# 2. Use local_qt_path to get the path to the Qt installation +# 3. Create a new_local_repository called "qt" with the build file from com_justbuchanan_rules_qt +# 4. Register the Qt toolchains +# +# However, this approach has issues with the current Bazel module system. +# For a WORKSPACE-based project, the configuration would look like: +# +# load("@com_justbuchanan_rules_qt//:qt_configure.bzl", "qt_configure") +# qt_configure() +# load("@local_config_qt//:local_qt.bzl", "local_qt_path") +# new_local_repository( +# name = "qt", +# build_file = "@com_justbuchanan_rules_qt//:qt.BUILD", +# path = local_qt_path(), +# ) +# load("@com_justbuchanan_rules_qt//tools:qt_toolchain.bzl", "register_qt_toolchains") +# register_qt_toolchains() + +# Commenting out Qt configuration to move on without it +# qt_configure = use_repo_rule("@com_justbuchanan_rules_qt//:qt_configure.bzl", "qt_configure") +# qt_configure(name = "local_config_qt") +# +# local_qt_path = use_repo_rule("@local_config_qt//:local_qt.bzl", "local_qt_path") +# new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:new_local_repository.bzl", "new_local_repository") +# new_local_repository( +# name = "qt", +# build_file = "@com_justbuchanan_rules_qt//:qt.BUILD", +# path = local_qt_path(), +# ) + + +############################################## +# 2) Declare each http_archive exactly once. +# The "name" attribute is your final repo name. +############################################## + +# com_google_absl +http_archive( + name = "com_google_absl", + sha256 = "497ebdc3a4885d9209b9bd416e8c3f71e7a1fb8af249f6c2a80b7cbeefcd7e21", + strip_prefix = "abseil-cpp-20230802.1", + urls = [ + "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.1.zip", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/abseil/abseil-cpp/20230802.1.zip", + ], +) + +# com_github_gflags_gflags +http_archive( + name = "com_github_gflags_gflags", + sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf", + strip_prefix = "gflags-2.2.2", + urls = [ + "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/gflags/gflags/v2.2.2.tar.gz", + ], +) + +# com_github_google_glog +http_archive( + name = "com_github_google_glog", + sha256 = "122fb6b712808ef43fbf80f75c52a21c9760683dae470154f02bddfc61135022", + strip_prefix = "glog-0.6.0", + urls = [ + "https://github.com/google/glog/archive/v0.6.0.zip", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/google/glog/v0.6.0.zip", + ], +) + +# com_google_googletest +http_archive( + name = "com_google_googletest", + sha256 = "8ad598c73ad796e0d8280b082cebd82a630d73e73cd3c70057938a6501bba5d7", + strip_prefix = "googletest-1.14.0", + urls = [ + "https://github.com/google/googletest/archive/refs/tags/v1.14.0.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/google/googletest/v1.14.0.tar.gz", + ], +) + +# com_justbuchanan_rules_qt +http_archive( + name = "com_justbuchanan_rules_qt", + sha256 = "6b42a58f062b3eea10ada5340cd8f63b47feb986d16794b0f8e0fde750838348", + strip_prefix = "bazel_rules_qt-3196fcf2e6ee81cf3a2e2b272af3d4259b84fcf9", + urls = [ + "https://github.com/justbuchanan/bazel_rules_qt/archive/3196fcf2e6ee81cf3a2e2b272af3d4259b84fcf9.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/justbuchanan/bazel_rules_qt/3196fcf2e6ee81cf3a2e2b272af3d4259b84fcf9.tar.gz", + ], +) + +# glibc_version_header +http_archive( + name = "glibc_version_header", + sha256 = "57db74f933b7a9ea5c653498640431ce0e52aaef190d6bb586711ec4f8aa2b9e", + strip_prefix = "glibc_version_header-0.1/version_headers/", + urls = [ + "https://github.com/wheybags/glibc_version_header/archive/refs/tags/0.1.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/wheybags/glibc_version_header/0.1.tar.gz", + ], + build_file = "//third_party/glibc_version_header:glibc_version_header.BUILD", +) + +# concurrentqueue +http_archive( + name = "concurrentqueue", + sha256 = "87fbc9884d60d0d4bf3462c18f4c0ee0a9311d0519341cac7cbd361c885e5281", + strip_prefix = "concurrentqueue-1.0.4", + urls = [ + "https://github.com/cameron314/concurrentqueue/archive/refs/tags/v1.0.4.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/cameron314/concurrentqueue/v1.0.4.tar.gz", + ], + build_file = "//third_party/concurrentqueue:concurrentqueue.BUILD", +) + +# threadpool +http_archive( + name = "threadpool", + sha256 = "18854bb7ecc1fc9d7dda9c798a1ef0c81c2dd331d730c76c75f648189fa0c20f", + strip_prefix = "ThreadPool-9a42ec1329f259a5f4881a291db1dcb8f2ad9040", + urls = [ + "https://github.com/progschj/ThreadPool/archive/9a42ec1329f259a5f4881a291db1dcb8f2ad9040.zip", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/progschj/ThreadPool/9a42ec1329f259a5f4881a291db1dcb8f2ad9040.zip", + ], + build_file = "//third_party/threadpool:threadpool.BUILD", +) + +# zlib +http_archive( + name = "zlib", + sha256 = "ff0ba4c292013dbc27530b3a81e1f9a813cd39de01ca5e0f8bf355702efa593e", + strip_prefix = "zlib-1.3", + urls = [ + "https://github.com/madler/zlib/releases/download/v1.3/zlib-1.3.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/madler/zlib/zlib-1.3.tar.gz", + ], + build_file = "//third_party/zlib:zlib.BUILD", +) + +# opencv +http_archive( + name = "opencv", + sha256 = "62f650467a60a38794d681ae7e66e3e8cfba38f445e0bf87867e2f2cdc8be9d5", + strip_prefix = "opencv-4.8.1", + urls = [ + "https://github.com/opencv/opencv/archive/refs/tags/4.8.1.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/opencv/opencv/4.8.1.tar.gz", + ], + build_file = "//third_party/opencv:opencv.BUILD", +) + +# pugixml +http_archive( + name = "pugixml", + sha256 = "610f98375424b5614754a6f34a491adbddaaec074e9044577d965160ec103d2e", + strip_prefix = "pugixml-1.14/src", + urls = [ + "https://github.com/zeux/pugixml/archive/refs/tags/v1.14.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/zeux/pugixml/v1.14.tar.gz", + ], + build_file = "//third_party/pugixml:pugixml.BUILD", +) + +# ale +http_archive( + name = "ale", + sha256 = "28960616cd89c18925ced7bbdeec01ab0b2ebd2d8ce5b7c88930e97381b4c3b5", + strip_prefix = "Arcade-Learning-Environment-0.8.1", + urls = [ + "https://github.com/mgbellemare/Arcade-Learning-Environment/archive/refs/tags/v0.8.1.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/mgbellemare/Arcade-Learning-Environment/v0.8.1.tar.gz", + ], + build_file = "//third_party/ale:ale.BUILD", +) + +# atari_roms +http_archive( + name = "atari_roms", + sha256 = "e39e9fc379fe3f336911d928ce0a52e6ff6861258906efc5e849390867ff35f5", + urls = [ + "https://roms8.s3.us-east-2.amazonaws.com/Roms.tar.gz", + "https://cdn.sail.sea.com/sail/Roms.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/atari/Roms.tar.gz", + ], + build_file = "//third_party/atari_roms:atari_roms.BUILD", +) + +# libjpeg_turbo +http_archive( + name = "libjpeg_turbo", + sha256 = "b3090cd37b5a8b3e4dbd30a1311b3989a894e5d3c668f14cbc6739d77c9402b7", + strip_prefix = "libjpeg-turbo-2.0.5", + urls = [ + "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.5.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/libjpeg-turbo/libjpeg-turbo/2.0.5.tar.gz", + ], + build_file = "//third_party/jpeg:jpeg.BUILD", +) + +# nasm +http_archive( + name = "nasm", + sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011", + strip_prefix = "nasm-2.13.03", + urls = [ + "https://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/nasm/nasm-2.13.03.tar.bz2", + ], + build_file = "//third_party/nasm:nasm.BUILD", +) + +# sdl2 +http_archive( + name = "sdl2", + sha256 = "888b8c39f36ae2035d023d1b14ab0191eb1d26403c3cf4d4d5ede30e66a4942c", + strip_prefix = "SDL2-2.28.4", + urls = [ + "https://www.libsdl.org/release/SDL2-2.28.4.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/libsdl/SDL2-2.28.4.tar.gz", + ], + build_file = "//third_party/sdl2:sdl2.BUILD", +) + +# freedoom +http_archive( + name = "freedoom", + sha256 = "f42c6810fc89b0282de1466c2c9c7c9818031a8d556256a6db1b69f6a77b5806", + strip_prefix = "freedoom-0.12.1/", + urls = [ + "https://github.com/freedoom/freedoom/releases/download/v0.12.1/freedoom-0.12.1.zip", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/freedoom/freedoom/freedoom-0.12.1.zip", + ], + build_file = "//third_party/freedoom:freedoom.BUILD", +) + +# vizdoom +http_archive( + name = "vizdoom", + sha256 = "e379a242ada7e1028b7a635da672b0936d99da3702781b76a4400b83602d78c4", + strip_prefix = "ViZDoom-1.1.13/src/vizdoom/", + urls = [ + "https://github.com/Farama-Foundation/ViZDoom/archive/refs/tags/1.1.13.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/Farama-Foundation/ViZDoom/1.1.13.tar.gz", + ], + build_file = "//third_party/vizdoom:vizdoom.BUILD", + patches = ["//third_party/vizdoom:sdl_thread.patch"], +) + +# vizdoom_lib +http_archive( + name = "vizdoom_lib", + sha256 = "e379a242ada7e1028b7a635da672b0936d99da3702781b76a4400b83602d78c4", + strip_prefix = "ViZDoom-1.1.13/", + urls = [ + "https://github.com/Farama-Foundation/ViZDoom/archive/refs/tags/1.1.13.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/Farama-Foundation/ViZDoom/1.1.13.tar.gz", + ], + build_file = "//third_party/vizdoom_lib:vizdoom_lib.BUILD", +) + +# vizdoom_extra_maps +http_archive( + name = "vizdoom_extra_maps", + sha256 = "325440fe566ff478f35947c824ea5562e2735366845d36c5a0e40867b59f7d69", + strip_prefix = "DirectFuturePrediction-b4757769f167f1bd7fb1ece5fdc6d874409c68a9/", + urls = [ + "https://github.com/isl-org/DirectFuturePrediction/archive/b4757769f167f1bd7fb1ece5fdc6d874409c68a9.zip", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/isl-org/DirectFuturePrediction/b4757769f167f1bd7fb1ece5fdc6d874409c68a9.zip", + ], + build_file = "//third_party/vizdoom_extra_maps:vizdoom_extra_maps.BUILD", +) + +# mujoco +http_archive( + name = "mujoco", + sha256 = "d1cb3a720546240d894cd315b7fd358a2b96013a1f59b6d718036eca6f6edac2", + strip_prefix = "mujoco-2.2.1", + urls = [ + "https://github.com/deepmind/mujoco/releases/download/2.2.1/mujoco-2.2.1-linux-x86_64.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/deepmind/mujoco/mujoco-2.2.1-linux-x86_64.tar.gz", + ], + build_file = "//third_party/mujoco:mujoco.BUILD", +) + +# mujoco_gym_xml +http_archive( + name = "mujoco_gym_xml", + sha256 = "96a5fc8345bd92b73a15fc25112d53a294f86fcace1c5e4ef7f0e052b5e1bdf4", + strip_prefix = "gym-0.26.2/gym/envs/mujoco", + urls = [ + "https://github.com/openai/gym/archive/refs/tags/0.26.2.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/openai/gym/0.26.2.tar.gz", + ], + build_file = "//third_party/mujoco_gym_xml:mujoco_gym_xml.BUILD", +) + +# mujoco_dmc_xml +http_archive( + name = "mujoco_dmc_xml", + sha256 = "fb8d57cbeb92bebe56a992dab8401bc00b3bff61b62526eb563854adf3dfb595", + strip_prefix = "dm_control-1.0.9/dm_control", + urls = [ + "https://github.com/deepmind/dm_control/archive/refs/tags/1.0.9.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/deepmind/dm_control/1.0.9.tar.gz", + ], + build_file = "//third_party/mujoco_dmc_xml:mujoco_dmc_xml.BUILD", +) + +# box2d +http_archive( + name = "box2d", + sha256 = "d6b4650ff897ee1ead27cf77a5933ea197cbeef6705638dd181adc2e816b23c2", + strip_prefix = "box2d-2.4.1", + urls = [ + "https://github.com/erincatto/box2d/archive/refs/tags/v2.4.1.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/erincatto/box2d/v2.4.1.tar.gz", + ], + build_file = "//third_party/box2d:box2d.BUILD", +) + +# pretrain_weight +http_archive( + name = "pretrain_weight", + sha256 = "b1b64e0db84cf7317c2a96b27f549147dfcb4074ed2d799334c23a067075ac1c", + urls = [ + "https://cdn.sail.sea.com/sail/pretrain.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/pretrain.tar.gz", + ], + build_file = "//third_party/pretrain_weight:pretrain_weight.BUILD", +) + +# procgen +http_archive( + name = "procgen", + sha256 = "d5620394418b885f9028f98759189a5f78bc4ba71fb6605f910ae22fca870c8e", + strip_prefix = "procgen-0.10.8/procgen", + urls = [ + "https://github.com/Trinkle23897/procgen/archive/refs/tags/0.10.8.tar.gz", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/Trinkle23897/procgen/0.10.8.tar.gz", + ], + build_file = "//third_party/procgen:procgen.BUILD", +) + +# gym3_libenv +http_archive( + name = "gym3_libenv", + sha256 = "9a764d79d4215609c2612b2c84fec8bcea6609941bdcb7051f3335ed4576b8ef", + strip_prefix = "gym3-4c3824680eaf9dd04dce224ee3d4856429878226/gym3", + urls = [ + "https://github.com/openai/gym3/archive/4c3824680eaf9dd04dce224ee3d4856429878226.zip", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/openai/gym3/4c3824680eaf9dd04dce224ee3d4856429878226.zip", + ], + build_file = "//third_party/gym3_libenv:gym3_libenv.BUILD", +) + +# bazel_clang_tidy +http_archive( + name = "bazel_clang_tidy", + sha256 = "ec8c5bf0c02503b928c2e42edbd15f75e306a05b2cae1f34a7bc84724070b98b", + strip_prefix = "bazel_clang_tidy-783aa523aafb4a6798a538c61e700b6ed27975a7", + urls = [ + "https://github.com/erenon/bazel_clang_tidy/archive/783aa523aafb4a6798a538c61e700b6ed27975a7.zip", + "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/erenon/bazel_clang_tidy/783aa523aafb4a6798a538c61e700b6ed27975a7.zip", + ], +) diff --git a/Makefile b/Makefile index bcce0fad..7a61f9bb 100644 --- a/Makefile +++ b/Makefile @@ -40,7 +40,7 @@ clang-tidy-install: go-install: # requires go >= 1.16 - command -v go || (sudo apt-get install -y golang-1.18 && sudo ln -sf /usr/lib/go-1.18/bin/go /usr/bin/go) + command -v go || (sudo apt-get install -y golang-1.21 && sudo ln -sf /usr/lib/go-1.18/bin/go /usr/bin/go) bazel-install: go-install command -v bazel || (go install github.com/bazelbuild/bazelisk@latest && ln -sf $(HOME)/go/bin/bazelisk $(HOME)/go/bin/bazel) @@ -102,19 +102,13 @@ clang-tidy: clang-tidy-install bazel-pip-requirement-dev bazel build $(BAZELOPT) //envpool/core/... //envpool/sokoban/... --config=clang-tidy --config=test bazel-debug: bazel-install bazel-pip-requirement-dev - bazel run $(BAZELOPT) //:setup --config=debug -- bdist_wheel - mkdir -p dist - cp bazel-bin/setup.runfiles/$(PROJECT_NAME)/dist/*.whl ./dist + bazel build $(BAZELOPT) //:wheel --config=debug bazel-build: bazel-install bazel-pip-requirement-dev - bazel run $(BAZELOPT) //:setup --config=test -- bdist_wheel - mkdir -p dist - cp bazel-bin/setup.runfiles/$(PROJECT_NAME)/dist/*.whl ./dist + bazel build $(BAZELOPT) //:wheel --config=test bazel-release: bazel-install bazel-pip-requirement-release - bazel run $(BAZELOPT) //:setup --config=release -- bdist_wheel - mkdir -p dist - cp bazel-bin/setup.runfiles/$(PROJECT_NAME)/dist/*.whl ./dist + bazel build $(BAZELOPT) //:wheel_dist bazel-test: bazel-install bazel-pip-requirement-dev bazel test --test_output=all $(BAZELOPT) //envpool/core/... //envpool/sokoban/... --config=test --spawn_strategy=local --color=yes diff --git a/WORKSPACE b/WORKSPACE deleted file mode 100644 index bb0ab8e3..00000000 --- a/WORKSPACE +++ /dev/null @@ -1,27 +0,0 @@ -workspace(name = "envpool") - -load("//envpool:workspace0.bzl", workspace0 = "workspace") - -workspace0() - -load("//envpool:workspace1.bzl", workspace1 = "workspace") - -workspace1() - -# QT special, cannot move to workspace2.bzl, not sure why - -load("@local_config_qt//:local_qt.bzl", "local_qt_path") - -new_local_repository( - name = "qt", - build_file = "@com_justbuchanan_rules_qt//:qt.BUILD", - path = local_qt_path(), -) - -load("@com_justbuchanan_rules_qt//tools:qt_toolchain.bzl", "register_qt_toolchains") - -register_qt_toolchains() - -load("//envpool:pip.bzl", pip_workspace = "workspace") - -pip_workspace() diff --git a/envpool/BUILD b/envpool/BUILD index b1a7a0dc..3eff1bf3 100644 --- a/envpool/BUILD +++ b/envpool/BUILD @@ -16,11 +16,6 @@ load("@pip_requirements//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) -exports_files([ - "workspace0.bzl", - "workspace1.bzl", -]) - py_library( name = "registration", srcs = ["registration.py"], @@ -30,6 +25,14 @@ py_library( name = "entry", srcs = ["entry.py"], deps = [ + "//envpool/atari:atari_registration", + "//envpool/box2d:box2d_registration", + "//envpool/classic_control:classic_control_registration", + "//envpool/mujoco:mujoco_dmc_registration", + "//envpool/mujoco:mujoco_gym_registration", + # "//envpool/procgen:procgen_registration", # Disabled, we have not installed qt5 in envpool dockerfile + "//envpool/toy_text:toy_text_registration", + "//envpool/vizdoom:vizdoom_registration", "//envpool/sokoban:registration", ], ) @@ -40,6 +43,15 @@ py_library( deps = [ ":entry", ":registration", + "//envpool/python", + "//envpool/atari", + "//envpool/box2d", + "//envpool/classic_control", + "//envpool/mujoco:mujoco_dmc", + "//envpool/mujoco:mujoco_gym", + # "//envpool/procgen", # Disabled, we have not installed qt5 in envpool dockerfile + "//envpool/toy_text", + "//envpool/vizdoom", "//envpool/sokoban", ], ) diff --git a/envpool/__init__.py b/envpool/__init__.py index 1c46bb34..f7b150f2 100644 --- a/envpool/__init__.py +++ b/envpool/__init__.py @@ -24,7 +24,7 @@ register, ) -__version__ = "0.8.4" +__version__ = "0.9.0" __all__ = [ "register", "make", diff --git a/envpool/atari/BUILD b/envpool/atari/BUILD index c11e6bc0..876dd93b 100644 --- a/envpool/atari/BUILD +++ b/envpool/atari/BUILD @@ -14,16 +14,10 @@ load("@pip_requirements//:requirements.bzl", "requirement") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") +load("@envpool//third_party:common.bzl", "copy_directory") package(default_visibility = ["//visibility:public"]) -genrule( - name = "gen_atari_roms", - srcs = ["@atari_roms//:roms"], - outs = ["roms"], - cmd = "mkdir -p $(OUTS) && cp $(SRCS) $(OUTS)", -) - genrule( name = "gen_pretrain_weight", srcs = [ @@ -43,11 +37,19 @@ py_library( deps = ["//envpool/python:api"], ) + +copy_directory( + name = "roms", + src = "@atari_roms//:roms_sources", + out = "", + visibility = ["//visibility:public"], +) + cc_library( name = "atari_env", hdrs = ["atari_env.h"], data = [ - ":gen_atari_roms", + ":roms", ], deps = [ "//envpool/core:async_envpool", diff --git a/envpool/core/BUILD b/envpool/core/BUILD index 79cb9bdd..acafea58 100644 --- a/envpool/core/BUILD +++ b/envpool/core/BUILD @@ -30,7 +30,7 @@ cc_library( name = "spec", hdrs = ["spec.h"], deps = [ - "@com_github_google_glog//:glog", + "@glog", ], ) @@ -39,7 +39,7 @@ cc_library( hdrs = ["array.h"], deps = [ ":spec", - "@com_github_google_glog//:glog", + "@glog", ], ) @@ -51,7 +51,7 @@ cc_library( ":spec", ":tuple_utils", ":type_utils", - "@com_github_google_glog//:glog", + "@glog", ], ) @@ -99,8 +99,8 @@ cc_test( srcs = ["circular_buffer_test.cc"], deps = [ ":circular_buffer", - "@com_github_google_glog//:glog", "@com_google_googletest//:gtest_main", + "@glog", ], ) @@ -185,33 +185,10 @@ cc_library( ], ) -cc_library( - name = "xla_template", - hdrs = ["xla_template.h"], - linkopts = [ - "-ldl", - "-lrt", - ], - deps = [ - "@cuda//:cudart_static", - ], -) - -cc_library( - name = "xla", - hdrs = ["xla.h"], - deps = [ - ":array", - ":xla_template", - "@cuda//:cudart_static", - ], -) - pybind_library( name = "py_envpool", hdrs = ["py_envpool.h"], deps = [ ":envpool", - ":xla", ], ) diff --git a/envpool/core/action_buffer_queue_test.cc b/envpool/core/action_buffer_queue_test.cc index ed4a12e9..f832974c 100644 --- a/envpool/core/action_buffer_queue_test.cc +++ b/envpool/core/action_buffer_queue_test.cc @@ -49,7 +49,7 @@ TEST(ActionBufferQueueTest, Concurrent) { std::thread send([&] { for (std::size_t m = 0; m < mul; ++m) { - while (flag[m] == 1) { + while (flag[m] == 1) { // NOLINT[bugprone-infinite-loop] } actions.clear(); for (std::size_t i = 0; i < env_num[m]; ++i) { diff --git a/envpool/core/py_envpool.h b/envpool/core/py_envpool.h index 35a33e45..2527e2e7 100644 --- a/envpool/core/py_envpool.h +++ b/envpool/core/py_envpool.h @@ -30,7 +30,6 @@ #include #include "envpool/core/envpool.h" -#include "envpool/core/xla.h" namespace py = pybind11; @@ -214,29 +213,6 @@ class PyEnvPool : public EnvPool { explicit PyEnvPool(const PySpec& py_spec) : EnvPool(py_spec), py_spec(py_spec) {} - /** - * get xla functions - */ - auto Xla() { - if (HasContainerType(EnvPool::spec.state_spec)) { - throw std::runtime_error( - "State of this env has dynamic shaped container, xla is disabled"); - } - if (HasDynamicDim(EnvPool::spec.state_spec)) { - throw std::runtime_error( - "State of this env has dynamic (-1) shape, xla is disabled"); - } - if (EnvPool::spec.config["max_num_players"_] != 1) { - throw std::runtime_error( - "Xla is not available for multiplayer environment."); - } - return std::make_tuple( - std::make_tuple("recv", - CustomCall>::Xla(this)), - std::make_tuple("send", - CustomCall>::Xla(this))); - } - /** * py api */ @@ -307,7 +283,6 @@ py::object abc_meta = py::module::import("abc").attr("ABCMeta"); .def("_send", &ENVPOOL::PySend) \ .def("_reset", &ENVPOOL::PyReset) \ .def_readonly_static("_state_keys", &ENVPOOL::py_state_keys) \ - .def_readonly_static("_action_keys", &ENVPOOL::py_action_keys) \ - .def("_xla", &ENVPOOL::Xla); + .def_readonly_static("_action_keys", &ENVPOOL::py_action_keys); #endif // ENVPOOL_CORE_PY_ENVPOOL_H_ diff --git a/envpool/core/state_buffer_queue.h b/envpool/core/state_buffer_queue.h index 56f946c3..dde4e99d 100644 --- a/envpool/core/state_buffer_queue.h +++ b/envpool/core/state_buffer_queue.h @@ -58,7 +58,7 @@ class StateBufferQueue { s.shape[0] == -1); })), specs_(Transform(specs, - [=](ShapeSpec s) { + [this](ShapeSpec s) { if (!s.shape.empty() && s.shape[0] == -1) { // If first dim is num_players s.shape[0] = batch_ * max_num_players_; @@ -85,7 +85,7 @@ class StateBufferQueue { // hardcode here :( std::size_t create_buffer_thread_num = std::max(1UL, processor_count / 64); for (std::size_t i = 0; i < create_buffer_thread_num; ++i) { - create_buffer_thread_.emplace_back(std::thread([&]() { + create_buffer_thread_.emplace_back([&]() { while (true) { stock_buffer_.Put(std::make_unique( batch_, max_num_players_, specs_, is_player_state_)); @@ -93,7 +93,7 @@ class StateBufferQueue { break; } } - })); + }); } } diff --git a/envpool/core/xla.h b/envpool/core/xla.h deleted file mode 100644 index f4d9afff..00000000 --- a/envpool/core/xla.h +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Copyright 2021 Garena Online Private Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef ENVPOOL_CORE_XLA_H_ -#define ENVPOOL_CORE_XLA_H_ - -#include - -#include -#include -#include -#include -#include -#include - -#include "envpool/core/array.h" -#include "envpool/core/xla_template.h" - -template -constexpr bool is_container_v = false; // NOLINT -template -constexpr bool is_container_v> = true; // NOLINT -template -constexpr bool HasContainerType(std::tuple /*unused*/) { - return (is_container_v || ...); -} -bool HasDynamicDim(const std::vector& shape) { - return std::any_of(shape.begin() + 1, shape.end(), - [](int s) { return s == -1; }); -} -template -bool HasDynamicDim(const std::tuple& state_spec) { - bool dyn = false; - std::apply([&](auto&&... spec) { dyn = (HasDynamicDim(spec.shape) || ...); }, - state_spec); - return dyn; -} - -template -Array CpuBufferToArray(const void* buffer, ::Spec spec, int batch_size, - int max_num_players) { - if (!spec.shape.empty() && - spec.shape[0] == -1) { // If first dim is max_num_players - spec.shape[0] = max_num_players * batch_size; - } else { - spec = spec.Batch(batch_size); - } - Array ret(spec); - ret.Assign(reinterpret_cast(buffer), ret.size); - return ret; -} - -template -Array GpuBufferToArray(cudaStream_t stream, const void* buffer, - ::Spec spec, int batch_size, - int max_num_players) { - if (!spec.shape.empty() && - spec.shape[0] == -1) { // If first dim is max_num_players - spec.shape[0] = max_num_players * batch_size; - } else { - spec = spec.Batch(batch_size); - } - Array ret(spec); - cudaMemcpyAsync(ret.Data(), buffer, ret.size * ret.element_size, - cudaMemcpyDeviceToHost, stream); - return ret; -} - -template -::Spec NormalizeSpec(const ::Spec& spec, int batch_size, - int max_num_players) { - std::vector shape({0}); - if (!spec.shape.empty() && spec.shape[0] == -1) { - shape[0] = batch_size * max_num_players; - shape.insert(shape.end(), spec.shape.begin() + 1, spec.shape.end()); - } else { - shape[0] = batch_size; - shape.insert(shape.end(), spec.shape.begin(), spec.shape.end()); - } - return ::Spec(shape); -} - -/** - * If Spec is a container, the xla interface should be disabled. - */ -template -::Spec NormalizeSpec(const ::Spec>& spec, int batch_size, - int max_num_players) { - std::vector shape({0}); - if (!spec.shape.empty() && spec.shape[0] == -1) { - shape[0] = batch_size * max_num_players; - shape.insert(shape.end(), spec.shape.begin() + 1, spec.shape.end()); - } else { - shape[0] = batch_size; - shape.insert(shape.end(), spec.shape.begin(), spec.shape.end()); - } - return ::Spec(shape); -} - -template -struct XlaSend { - using In = - std::array>; - using Out = std::array; - - static decltype(auto) InSpecs(EnvPool* envpool) { - int batch_size = envpool->spec.config["batch_size"_]; - int max_num_players = envpool->spec.config["max_num_players"_]; - return std::apply( - [&](auto&&... s) { - return std::make_tuple( - NormalizeSpec(s, batch_size, max_num_players)...); - }, - envpool->spec.action_spec.AllValues()); - } - - static decltype(auto) OutSpecs(EnvPool* envpool) { return std::tuple<>(); } - - static void Cpu(EnvPool* envpool, const In& in, const Out& out) { - std::vector action; - action.reserve(std::tuple_size_v); - int batch_size = envpool->spec.config["batch_size"_]; - int max_num_players = envpool->spec.config["max_num_players"_]; - auto action_spec = envpool->spec.action_spec.AllValues(); - std::size_t index = 0; - std::apply( - [&](auto&&... spec) { - ((action.emplace_back(CpuBufferToArray(in[index++], spec, batch_size, - max_num_players))), - ...); - }, - action_spec); - envpool->Send(action); - } - - static void Gpu(EnvPool* envpool, cudaStream_t stream, const In& in, - const Out& out) { - std::vector action; - action.reserve(std::tuple_size_v); - int batch_size = envpool->spec.config["batch_size"_]; - int max_num_players = envpool->spec.config["max_num_players"_]; - auto action_spec = envpool->spec.action_spec.AllValues(); - std::size_t index = 0; - std::apply( - [&](auto&&... spec) { - ((action.emplace_back(GpuBufferToArray(stream, in[index++], spec, - batch_size, max_num_players))), - ...); - }, - action_spec); - cudaStreamSynchronize(stream); - envpool->Send(action); - } -}; - -template -struct XlaRecv { - using In = std::array; - using Out = - std::array>; - - static decltype(auto) InSpecs(EnvPool* envpool) { return std::tuple<>(); } - - static decltype(auto) OutSpecs(EnvPool* envpool) { - int batch_size = envpool->spec.config["batch_size"_]; - int max_num_players = envpool->spec.config["max_num_players"_]; - return std::apply( - [&](auto&&... s) { - return std::make_tuple( - NormalizeSpec(s, batch_size, max_num_players)...); - }, - envpool->spec.state_spec.AllValues()); - } - - static void Cpu(EnvPool* envpool, const In& in, const Out& out) { - int batch_size = envpool->spec.config["batch_size"_]; - int max_num_players = envpool->spec.config["max_num_players"_]; - std::vector recv = envpool->Recv(); - for (std::size_t i = 0; i < recv.size(); ++i) { - CHECK_LE(recv[i].Shape(0), (std::size_t)batch_size * max_num_players); - std::memcpy(out[i], recv[i].Data(), recv[i].size * recv[i].element_size); - } - } - - static void Gpu(EnvPool* envpool, cudaStream_t stream, const In& in, - const Out& out) { - int batch_size = envpool->spec.config["batch_size"_]; - int max_num_players = envpool->spec.config["max_num_players"_]; - std::vector recv = envpool->Recv(); - for (std::size_t i = 0; i < recv.size(); ++i) { - CHECK_LE(recv[i].Shape(0), (std::size_t)batch_size * max_num_players); - cudaMemcpyAsync(out[i], recv[i].Data(), - recv[i].size * recv[i].element_size, - cudaMemcpyHostToDevice, stream); - } - } -}; - -#endif // ENVPOOL_CORE_XLA_H_ diff --git a/envpool/core/xla_template.h b/envpool/core/xla_template.h deleted file mode 100644 index 2da813db..00000000 --- a/envpool/core/xla_template.h +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright 2022 Garena Online Private Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef ENVPOOL_CORE_XLA_TEMPLATE_H_ -#define ENVPOOL_CORE_XLA_TEMPLATE_H_ - -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace py = pybind11; - -template -static auto SpecToTuple(const Spec& spec) { - return std::make_tuple(py::dtype::of(), spec.shape); -} - -template -void ToArray(const void** raw, std::array* array) { - int i = 0; - std::apply([&](auto&&... a) { ((a = const_cast(raw[i++])), ...); }, - *array); -} - -template -void ToArray(void** raw, std::array* array) { - int i = 0; - std::apply([&](auto&&... a) { ((a = raw[i++]), ...); }, *array); -} - -template -struct CustomCall { - using InSpecs = - typename std::invoke_result::type; - using OutSpecs = - typename std::invoke_result::type; - using In = std::array>; - using Out = std::array>; - - static py::bytes Handle(Class* obj) { - return py::bytes( - std::string(reinterpret_cast(&obj), sizeof(Class*))); - } - - static void Cpu(void* out, const void** in) { - Class* obj = *reinterpret_cast(const_cast(in[0])); - in += 1; - In in_arr; - Out out_arr; - ToArray(in, &in_arr); - if (std::tuple_size::value == 0) { - std::memcpy(out, &obj, sizeof(Class*)); - } else { - void** outs = reinterpret_cast(out); - std::memcpy(outs[0], &obj, sizeof(Class*)); - ToArray(outs + 1, &out_arr); - } - CC::Cpu(obj, in_arr, out_arr); - } - - static void Gpu(cudaStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len) { - Class* obj = *reinterpret_cast(const_cast(opaque)); - buffers += 1; - In in_arr; - Out out_arr; - ToArray(buffers, &in_arr); - buffers += std::tuple_size::value; - buffers += 1; - ToArray(buffers, &out_arr); - CC::Gpu(obj, stream, in_arr, out_arr); - } - - static auto Specs(Class* obj) { - auto handle_spec = - std::make_tuple(SpecToTuple(Spec({sizeof(Class*)}))); - auto in_specs = CC::InSpecs(obj); - auto in = std::apply( - [&](auto&&... a) { return std::make_tuple(SpecToTuple(a)...); }, - in_specs); - auto out_specs = CC::OutSpecs(obj); - auto out = std::apply( - [&](auto&&... a) { return std::make_tuple(SpecToTuple(a)...); }, - out_specs); - return std::make_tuple(std::tuple_cat(handle_spec, in), - std::tuple_cat(handle_spec, out)); - } - - static auto Capsules() { - return std::make_tuple( - py::capsule(reinterpret_cast(Cpu), "xla._CUSTOM_CALL_TARGET"), - py::capsule(reinterpret_cast(Gpu), "xla._CUSTOM_CALL_TARGET")); - } - - static auto Xla(Class* obj) { - return std::make_tuple(Handle(obj), Specs(obj), Capsules()); - } -}; - -#endif // ENVPOOL_CORE_XLA_TEMPLATE_H_ diff --git a/envpool/mujoco/BUILD b/envpool/mujoco/BUILD index fe92bd6c..03f0f52f 100644 --- a/envpool/mujoco/BUILD +++ b/envpool/mujoco/BUILD @@ -12,23 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@envpool//third_party:common.bzl", "copy_directory") load("@pip_requirements//:requirements.bzl", "requirement") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") package(default_visibility = ["//visibility:public"]) -genrule( - name = "gen_mujoco_gym_xml", - srcs = ["@mujoco_gym_xml"], - outs = ["assets_gym"], - cmd = "mkdir -p $(OUTS) && cp -r $(SRCS) $(OUTS)", +copy_directory( + name = "assets_gym", + src = "@mujoco_gym_xml", + out = "", ) -genrule( - name = "gen_mujoco_dmc_xml", - srcs = ["@mujoco_dmc_xml"], - outs = ["assets_dmc"], - cmd = "mkdir -p $(OUTS) && cp -r $(SRCS) $(OUTS)", +copy_directory( + name = "assets_dmc", + src = "@mujoco_dmc_xml", + out = "", ) genrule( @@ -55,7 +54,7 @@ cc_library( "gym/walker2d.h", ], data = [ - ":gen_mujoco_gym_xml", + ":assets_gym", ], deps = [ "//envpool/core:async_envpool", @@ -99,7 +98,7 @@ cc_library( "dmc/utils.h", "dmc/walker.h", ], - data = [":gen_mujoco_dmc_xml"], + data = [":assets_dmc"], deps = [ "//envpool/core:async_envpool", "@mujoco//:mujoco_lib", @@ -122,7 +121,7 @@ py_library( name = "mujoco_dmc", srcs = ["dmc/__init__.py"], data = [ - ":gen_mujoco_dmc_xml", + ":assets_dmc", ":gen_mujoco_so", ":mujoco_dmc_envpool.so", ], @@ -133,7 +132,7 @@ py_library( name = "mujoco_gym", srcs = ["gym/__init__.py"], data = [ - ":gen_mujoco_gym_xml", + ":assets_gym", ":gen_mujoco_so", ":mujoco_gym_envpool.so", ], diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 23a75264..7188a65a 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -20,6 +20,7 @@ #include #include +#include #include #include #include diff --git a/envpool/pip.bzl b/envpool/pip.bzl deleted file mode 100644 index b6323108..00000000 --- a/envpool/pip.bzl +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2021 Garena Online Private Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""EnvPool pip requirements initialization, this is loaded in WORKSPACE.""" - -load("@rules_python//python:pip.bzl", "pip_install") - -def workspace(): - """Configure pip requirements.""" - - if "pip_requirements" not in native.existing_rules().keys(): - pip_install( - name = "pip_requirements", - python_interpreter = "python3", - # default timeout value is 600, change it if you failed. - # timeout = 3600, - quiet = False, - requirements = "@envpool//third_party/pip_requirements:requirements.txt", - # extra_pip_args = ["--extra-index-url", "https://mirrors.aliyun.com/pypi/simple"], - ) diff --git a/envpool/python/BUILD b/envpool/python/BUILD index f95b1573..849d82ad 100644 --- a/envpool/python/BUILD +++ b/envpool/python/BUILD @@ -72,27 +72,6 @@ py_library( ], ) -py_library( - name = "xla_template", - srcs = ["xla_template.py"], - deps = [ - requirement("jax"), - ], -) - -py_library( - name = "lax", - srcs = ["lax.py"], - deps = [ - requirement("jax"), - requirement("dm-env"), - requirement("numpy"), - requirement("absl-py"), - ":protocol", - ":xla_template", - ], -) - py_library( name = "dm_envpool", srcs = ["dm_envpool.py"], @@ -102,7 +81,6 @@ py_library( requirement("numpy"), ":data", ":envpool", - ":lax", ":utils", ], ) @@ -117,7 +95,6 @@ py_library( requirement("numpy"), ":data", ":envpool", - ":lax", ":utils", ], ) @@ -132,7 +109,6 @@ py_library( requirement("numpy"), ":data", ":envpool", - ":lax", ":utils", ], ) diff --git a/envpool/python/lax.py b/envpool/python/lax.py deleted file mode 100644 index bb136167..00000000 --- a/envpool/python/lax.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2022 Garena Online Private Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Provide xla mixin for envpool.""" - -from abc import ABC -from typing import Any, Callable, Dict, Optional, Tuple, Union - -from dm_env import TimeStep -from jax import numpy as jnp - -from .xla_template import make_xla - - -class XlaMixin(ABC): - """Mixin to provide XLA for envpool class.""" - - def xla(self: Any) -> Tuple[Any, Callable, Callable, Callable]: - """Return the XLA version of send/recv/step functions.""" - _handle, _recv, _send = make_xla(self) - - def recv(handle: jnp.ndarray) -> Union[TimeStep, Tuple]: - ret = _recv(handle) - new_handle = ret[0] - state_list = ret[1:] - return new_handle, self._to(state_list, reset=False, return_info=True) - - def send( - handle: jnp.ndarray, - action: Union[Dict[str, Any], jnp.ndarray], - env_id: Optional[jnp.ndarray] = None - ) -> Any: - action = self._from(action, env_id) - self._check_action(action) - return _send(handle, *action) - - def step( - handle: jnp.ndarray, - action: Union[Dict[str, Any], jnp.ndarray], - env_id: Optional[jnp.ndarray] = None - ) -> Any: - return recv(send(handle, action, env_id)) - - return _handle, recv, send, step diff --git a/envpool/python/xla_template.py b/envpool/python/xla_template.py deleted file mode 100644 index 4238f945..00000000 --- a/envpool/python/xla_template.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2022 Garena Online Private Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""xla template on python side.""" - -from collections import namedtuple -from functools import partial -from typing import Any, Callable, List, Tuple, Union - -import numpy as np -from jax import core, dtypes -from jax import numpy as jnp -from jax.core import ShapedArray -from jax.interpreters import xla -from jax.lib import xla_client - - -def _shape_with_layout( - specs: Tuple[Tuple[List[int], Any], ...] -) -> Tuple[xla_client.Shape, ...]: - return tuple( - xla_client.Shape - .array_shape(dtype, shape, tuple(range(len(shape) - - 1, -1, -1))) if len(shape) > - 0 else xla_client.Shape.scalar_shape(dtype) for shape, dtype in specs - ) - - -def _normalize_specs( - specs: Tuple[Tuple[Any, List[int]], ...] -) -> Tuple[Tuple[List[int], Any], ...]: - return tuple( - (shape, dtypes.canonicalize_dtype(dtype)) for dtype, shape in specs - ) - - -def _make_xla_function( - obj: Any, handle: bytes, name: str, specs: Tuple[Tuple[Any], Tuple[Any]], - capsules: Tuple[Any, Any] -) -> Callable: - in_specs, out_specs = specs - in_specs = _normalize_specs(in_specs) - out_specs = _normalize_specs(out_specs) - cpu_capsule, gpu_capsule = capsules - xla_client.register_custom_call_target( - f"{type(obj).__name__}_{id(obj)}_{name}_cpu".encode(), - cpu_capsule, - platform="cpu" - ) - xla_client.register_custom_call_target( - f"{type(obj).__name__}_{id(obj)}_{name}_gpu".encode(), - gpu_capsule, - platform="gpu", - ) - - def abstract( - *args: List[jnp.ndarray] - ) -> Union[ShapedArray, Tuple[ShapedArray, ...]]: - if len(out_specs) > 1: - return tuple(ShapedArray(*spec) for spec in out_specs) - else: - return ShapedArray(*out_specs[0]) - - def translation(c: Any, *args: Any, platform: str = "cpu") -> Any: - output_shape_with_layout = _shape_with_layout(out_specs) - if len(out_specs) == 1: - output_shape = output_shape_with_layout[0] - else: - output_shape = xla_client.Shape.tuple_shape(output_shape_with_layout) - return xla_client.ops.CustomCallWithLayout( - c, - f"{type(obj).__name__}_{id(obj)}_{name}_{platform}".encode(), - operands=args, - operand_shapes_with_layout=_shape_with_layout(in_specs), - shape_with_layout=output_shape, - opaque=handle, - has_side_effect=True, - ) - - prim = core.Primitive(f"{type(obj).__name__}_{id(obj)}_{name}") - prim.multiple_results = (len(out_specs) > 1) - prim.def_impl(partial(xla.apply_primitive, prim)) - prim.def_abstract_eval(abstract) - xla.backend_specific_translations["cpu"][prim] = partial( - translation, platform="cpu" - ) - xla.backend_specific_translations["gpu"][prim] = partial( - translation, platform="gpu" - ) - - def call(*args: Any) -> Any: - return prim.bind(*args) - - return call - - -def make_xla(obj: Any) -> Any: - """Return callables that can be jitted in a namedtuple. - - Args: - obj: The object that has a `_xla` function. - All instances of envpool has a `_xla` function that returns - the necessary information for creating jittable send/recv functions. - - Returns: - XlaFunctions: A namedtuple, the first element is a handle - representing `obj`. The rest of the elements are jittable functions. - """ - xla_native = obj._xla() - method_names = [] - methods = [] - for name, (handle, specs, capsules) in xla_native: - method_names.append(name) - methods.append(_make_xla_function(obj, handle, name, specs, capsules)) - XlaFunctions = namedtuple( # type: ignore - "XlaFunctions", - ["handle", *method_names] - ) - return XlaFunctions( # type: ignore - np.frombuffer(handle, dtype=np.uint8), - *methods - ) diff --git a/envpool/sokoban/BUILD b/envpool/sokoban/BUILD index b79ee42f..b0117979 100644 --- a/envpool/sokoban/BUILD +++ b/envpool/sokoban/BUILD @@ -17,69 +17,80 @@ load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") package(default_visibility = ["//visibility:public"]) -py_library( - name = "sokoban", - srcs = ["__init__.py"], - data = [":sokoban_envpool.so"], - deps = ["//envpool/python:api"], +# 1. Level Loader +cc_library( + name = "level_loader_lib", + srcs = ["level_loader.cc"], + hdrs = [ + "level_loader.h", + "utils.h", + ], ) -py_library( - name = "registration", - srcs = ["registration.py"], +# 2. Sokoban Node +cc_library( + name = "sokoban_node_lib", + srcs = ["sokoban_node.cc"], + hdrs = ["sokoban_node.h"], deps = [ - "//envpool:registration", + ":level_loader_lib", + "//third_party/astar_stl:astar_stl_h", ], ) +# 4. Sokoban Env cc_library( - name = "sokoban_envpool_h", + name = "sokoban_envpool_lib", + srcs = [ + "sokoban_envpool.cc", + ], hdrs = [ - "level_loader.h", "sokoban_envpool.h", "utils.h", ], deps = [ + ":level_loader_lib", "//envpool/core:async_envpool", "//envpool/core:env", "//envpool/core:env_spec", + "//envpool/core:py_envpool", ], ) +# 5. astar_log cc_library( - name = "sokoban_node_h", - hdrs = [ - "level_loader.h", - "sokoban_node.h", - "utils.h", + name = "astar_log_lib", + srcs = [ + "astar_log.cc", + ], + deps = [ + ":level_loader_lib", + ":sokoban_node_lib", + "//third_party/astar_stl:astar_stl_h", ], - deps = ["//third_party/astar_stl:astar_stl_h"], ) +# The actual CLI tool with the main: cc_binary( name = "astar_log", srcs = [ - "astar_log.cc", - "level_loader.cc", - "sokoban_node.cc", + "astar_log_main.cc", ], deps = [ - ":sokoban_node_h", + ":astar_log_lib", ], ) cc_binary( name = "astar_log_level", - srcs = [ - "astar_log_level.cc", - "level_loader.cc", - "sokoban_node.cc", - ], + srcs = ["astar_log_level.cc"], deps = [ - ":sokoban_node_h", + ":level_loader_lib", + ":sokoban_node_lib", ], ) +# 6. Python tests py_test( name = "test", srcs = ["sokoban_py_envpool_test.py"], @@ -93,17 +104,44 @@ py_test( ], ) -pybind_extension( - name = "sokoban_envpool", +# Now your test doesn't accidentally invoke that main. +cc_test( + name = "astar_log_test", srcs = [ - "level_loader.cc", - "sokoban_envpool.cc", + "astar_log_test.cc", + ], + deps = [ + ":astar_log_lib", + "@com_google_googletest//:gtest_main", ], +) + +# 7. Python code and extension +py_library( + name = "sokoban", + srcs = ["__init__.py"], + data = [":sokoban_envpool.so"], + deps = [ + "//envpool/python:api", + ], +) + +py_library( + name = "registration", + srcs = ["registration.py"], + deps = [ + "//envpool:registration", + ], +) + +pybind_extension( + name = "sokoban_envpool", + srcs = ["sokoban_envpool.cc"], linkopts = [ "-ldl", ], deps = [ - ":sokoban_envpool_h", + ":sokoban_envpool_lib", "//envpool/core:py_envpool", ], ) diff --git a/envpool/sokoban/astar_log.cc b/envpool/sokoban/astar_log.cc index 75be9ce9..f5c0e237 100644 --- a/envpool/sokoban/astar_log.cc +++ b/envpool/sokoban/astar_log.cc @@ -20,10 +20,10 @@ namespace sokoban { void RunAStar(const std::string& level_file_name, - const std::string& log_file_name, int total_levels_to_run = 1000, - int fsa_limit = 1000000) { + const std::string& log_file_name, int total_levels_to_run, + int fsa_limit) { std::cout << "Running A* on file " << level_file_name << " and logging to " - << log_file_name << " with fsa_limit " << fsa_limit << std::endl; + << log_file_name << " with fsa_limit " << fsa_limit << '\n'; const int dim_room = 10; int level_idx = 0; LevelLoader level_loader(level_file_name, true, -1); @@ -33,7 +33,7 @@ void RunAStar(const std::string& level_file_name, std::ifstream log_file_in(log_file_name); // check if the file is empty if (log_file_in.peek() == std::ifstream::traits_type::eof()) { - log_file_out << "Level,Actions,Steps,SearchSteps" << std::endl; + log_file_out << "Level,Actions,Steps,SearchSteps" << '\n'; } else { // skip levels that have already been run std::string line; std::getline(log_file_in, line); // skip header @@ -49,7 +49,7 @@ void RunAStar(const std::string& level_file_name, while (level_idx < total_levels_to_run) { std::AStarSearch astarsearch(fsa_limit); - std::cout << "Running level " << level_idx << std::endl; + std::cout << "Running level " << level_idx << '\n'; SokobanLevel level = level_loader.GetLevel(gen).data; SokobanNode node_start(dim_room, level, false); @@ -57,7 +57,7 @@ void RunAStar(const std::string& level_file_name, astarsearch.SetStartAndGoalStates(node_start, node_end); unsigned int search_state; unsigned int search_steps = 0; - std::cout << "Starting search" << std::endl; + std::cout << "Starting search" << '\n'; do { search_state = astarsearch.SearchStep(); search_steps++; @@ -92,10 +92,9 @@ void RunAStar(const std::string& level_file_name, prev_y = curr_y; } if (!correct_solution) { - loglinestream << ",INCORRECT_SOLUTION_FOUND," << search_steps - << std::endl; + loglinestream << ",INCORRECT_SOLUTION_FOUND," << search_steps << '\n'; } else { - loglinestream << "," << steps << "," << search_steps << std::endl; + loglinestream << "," << steps << "," << search_steps << '\n'; } log_file_out << loglinestream.str(); astarsearch.FreeSolutionNodes(); @@ -103,55 +102,30 @@ void RunAStar(const std::string& level_file_name, } else if (search_state == std::AStarSearch::SEARCH_STATE_FAILED) { log_file_out << level_idx << "," - << "SEARCH_STATE_FAILED,-1," << search_steps << std::endl; + << "SEARCH_STATE_FAILED,-1," << search_steps << '\n'; } else if (search_state == std::AStarSearch::SEARCH_STATE_NOT_INITIALISED) { log_file_out << level_idx << "," << "SEARCH_STATE_NOT_INITIALISED,-1," << search_steps - << std::endl; + << '\n'; } else if (search_state == std::AStarSearch::SEARCH_STATE_SEARCHING) { log_file_out << level_idx << "," - << "SEARCH_STATE_SEARCHING,-1," << search_steps << std::endl; + << "SEARCH_STATE_SEARCHING,-1," << search_steps << '\n'; } else if (search_state == std::AStarSearch::SEARCH_STATE_OUT_OF_MEMORY) { log_file_out << level_idx << "," - << "SEARCH_STATE_OUT_OF_MEMORY,-1," << search_steps - << std::endl; + << "SEARCH_STATE_OUT_OF_MEMORY,-1," << search_steps << '\n'; } else if (search_state == std::AStarSearch::SEARCH_STATE_INVALID) { log_file_out << level_idx << "," - << "SEARCH_STATE_INVALID,-1," << search_steps << std::endl; + << "SEARCH_STATE_INVALID,-1," << search_steps << '\n'; } else { log_file_out << level_idx << "," - << "UNKNOWN,-1," << search_steps << std::endl; + << "UNKNOWN,-1," << search_steps << '\n'; } log_file_out.flush(); level_idx++; } } } // namespace sokoban - -int main(int argc, char** argv) { - int total_levels_to_run = 1000; - int fsa_limit = 1000000; - if (argc < 3) { - std::cout - << "Usage: " << argv[0] - << " level_file_name log_file_name [total_levels_to_run] [fsa_limit]" - << std::endl; - return 1; - } - std::string level_file_name = argv[1]; - std::string log_file_name = argv[2]; - if (argc > 3) { - total_levels_to_run = std::stoi(argv[3]); - } - if (argc > 4) { - fsa_limit = std::stoi(argv[4]); - } - - sokoban::RunAStar(level_file_name, log_file_name, total_levels_to_run, - fsa_limit); - return 0; -} diff --git a/envpool/sokoban/astar_log_level.cc b/envpool/sokoban/astar_log_level.cc index 91b59620..fd6f1834 100644 --- a/envpool/sokoban/astar_log_level.cc +++ b/envpool/sokoban/astar_log_level.cc @@ -24,7 +24,7 @@ void RunAStar(const std::string& level_file_name, int fsa_limit = 1000000) { std::cout << "Running A* on file " << level_file_name << " and logging to " << log_file_name << " with fsa_limit " << fsa_limit << "on level " - << level_to_run << std::endl; + << level_to_run << '\n'; const int dim_room = 10; int level_idx = 0; LevelLoader level_loader(level_file_name, true, -1); @@ -40,7 +40,7 @@ void RunAStar(const std::string& level_file_name, level_idx++; } std::AStarSearch astarsearch(fsa_limit); - std::cout << "Running level " << level_idx << std::endl; + std::cout << "Running level " << level_idx << '\n'; SokobanLevel level = level_loader.GetLevel(gen).data; SokobanNode node_start(dim_room, level, false); @@ -48,7 +48,7 @@ void RunAStar(const std::string& level_file_name, astarsearch.SetStartAndGoalStates(node_start, node_end); unsigned int search_state; unsigned int search_steps = 0; - std::cout << "Starting search" << std::endl; + std::cout << "Starting search" << '\n'; do { search_state = astarsearch.SearchStep(); search_steps++; @@ -83,10 +83,9 @@ void RunAStar(const std::string& level_file_name, prev_y = curr_y; } if (!correct_solution) { - loglinestream << ",INCORRECT_SOLUTION_FOUND," << search_steps - << std::endl; + loglinestream << ",INCORRECT_SOLUTION_FOUND," << search_steps << '\n'; } else { - loglinestream << "," << steps << "," << search_steps << std::endl; + loglinestream << "," << steps << "," << search_steps << '\n'; } log_file_out << loglinestream.str(); astarsearch.FreeSolutionNodes(); @@ -94,48 +93,54 @@ void RunAStar(const std::string& level_file_name, } else if (search_state == std::AStarSearch::SEARCH_STATE_FAILED) { log_file_out << level_idx << "," - << "SEARCH_STATE_FAILED,-1," << search_steps << std::endl; + << "SEARCH_STATE_FAILED,-1," << search_steps << '\n'; } else if (search_state == std::AStarSearch::SEARCH_STATE_NOT_INITIALISED) { log_file_out << level_idx << "," - << "SEARCH_STATE_NOT_INITIALISED,-1," << search_steps - << std::endl; + << "SEARCH_STATE_NOT_INITIALISED,-1," << search_steps << '\n'; } else if (search_state == std::AStarSearch::SEARCH_STATE_SEARCHING) { log_file_out << level_idx << "," - << "SEARCH_STATE_SEARCHING,-1," << search_steps << std::endl; + << "SEARCH_STATE_SEARCHING,-1," << search_steps << '\n'; } else if (search_state == std::AStarSearch::SEARCH_STATE_OUT_OF_MEMORY) { log_file_out << level_idx << "," - << "SEARCH_STATE_OUT_OF_MEMORY,-1," << search_steps - << std::endl; + << "SEARCH_STATE_OUT_OF_MEMORY,-1," << search_steps << '\n'; } else if (search_state == std::AStarSearch::SEARCH_STATE_INVALID) { log_file_out << level_idx << "," - << "SEARCH_STATE_INVALID,-1," << search_steps << std::endl; + << "SEARCH_STATE_INVALID,-1," << search_steps << '\n'; } else { log_file_out << level_idx << "," - << "UNKNOWN,-1," << search_steps << std::endl; + << "UNKNOWN,-1," << search_steps << '\n'; } log_file_out.flush(); } } // namespace sokoban int main(int argc, char** argv) { - int fsa_limit = 1000000; - if (argc < 4) { - std::cout << "Usage: " << argv[0] - << " level_file_name log_file_name level_to_run [fsa_limit]" - << std::endl; + try { + int fsa_limit = 1000000; + if (argc < 4) { + std::cout << "Usage: " << argv[0] + << " level_file_name log_file_name level_to_run [fsa_limit]" + << '\n'; + return 1; + } + std::string level_file_name = argv[1]; + std::string log_file_name = argv[2]; + int level_to_run = std::stoi(argv[3]); + if (argc > 4) { + fsa_limit = std::stoi(argv[4]); + } + + sokoban::RunAStar(level_file_name, log_file_name, level_to_run, fsa_limit); + return 0; + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << '\n'; + return 1; + } catch (...) { + std::cerr << "Unknown error occurred\n"; return 1; } - std::string level_file_name = argv[1]; - std::string log_file_name = argv[2]; - int level_to_run = std::stoi(argv[3]); - if (argc > 4) { - fsa_limit = std::stoi(argv[4]); - } - - sokoban::RunAStar(level_file_name, log_file_name, level_to_run, fsa_limit); - return 0; } diff --git a/envpool/sokoban/astar_log_main.cc b/envpool/sokoban/astar_log_main.cc new file mode 100644 index 00000000..d1128ca3 --- /dev/null +++ b/envpool/sokoban/astar_log_main.cc @@ -0,0 +1,47 @@ +// Copyright 2023-2024 FAR AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "envpool/sokoban/sokoban_node.h" + +namespace sokoban { +// forward-declare RunAStar +void RunAStar(const std::string& level_file_name, + const std::string& log_file_name, int total_levels_to_run = 1000, + int fsa_limit = 1000000); +} // namespace sokoban + +int main(int argc, char** argv) { + int total_levels_to_run = 1000; + int fsa_limit = 1000000; + if (argc < 3) { + std::cout + << "Usage: " << argv[0] + << " level_file_name log_file_name [total_levels_to_run] [fsa_limit]\n"; + return 1; + } + std::string level_file_name = argv[1]; + std::string log_file_name = argv[2]; + if (argc > 3) { + total_levels_to_run = std::stoi(argv[3]); + } + if (argc > 4) { + fsa_limit = std::stoi(argv[4]); + } + sokoban::RunAStar(level_file_name, log_file_name, total_levels_to_run, + fsa_limit); + return 0; +} diff --git a/envpool/sokoban/astar_log_test.cc b/envpool/sokoban/astar_log_test.cc new file mode 100644 index 00000000..3948cdf0 --- /dev/null +++ b/envpool/sokoban/astar_log_test.cc @@ -0,0 +1,50 @@ +// Copyright 2023-2024 FAR AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include + +#include "envpool/sokoban/sokoban_node.h" + +namespace sokoban { + +// Declare the RunAStar function from astar_log.cc +void RunAStar(const std::string& level_file_name, + const std::string& log_file_name, int total_levels_to_run = 1000, + int fsa_limit = 1000000); + +TEST(AStarLogTest, ValidateSolution) { + // Create a temporary file for the log + std::string level_file_name = "/app/envpool/sokoban/sample_levels/small.txt"; + std::string log_file_name = testing::TempDir() + "/test_log_file.csv"; + + // Run A* on the first level only + RunAStar(level_file_name, log_file_name, 1); + + // Read the log file and check for the expected solution + std::ifstream log_file(log_file_name); + std::string line; + + // Skip header + ASSERT_TRUE(std::getline(log_file, line)); + // Read the solution line + ASSERT_TRUE(std::getline(log_file, line)); + + const std::string expected = "0,222200001112330322210,21,1380"; + EXPECT_EQ(line, expected); +} + +} // namespace sokoban diff --git a/envpool/sokoban/level_loader.cc b/envpool/sokoban/level_loader.cc index e2803663..18e9df17 100644 --- a/envpool/sokoban/level_loader.cc +++ b/envpool/sokoban/level_loader.cc @@ -15,6 +15,7 @@ #include "level_loader.h" #include +#include #include #include #include @@ -67,7 +68,7 @@ void AddLine(SokobanLevel& level, const std::string& line) { if ((start != '#') || (end != '#')) { std::stringstream msg; msg << "Line '" << line << "' does not start (" << start << ") and end (" - << end << ") with '#', as it should." << std::endl; + << end << ") with '#', as it should." << '\n'; throw std::runtime_error(msg.str()); } for (const char& r : line) { @@ -90,7 +91,7 @@ void AddLine(SokobanLevel& level, const std::string& line) { default: std::stringstream msg; msg << "Line '" << line << "'has character '" << r - << "' which is not in the valid set '#@$. '." << std::endl; + << "' which is not in the valid set '#@$. '." << '\n'; throw std::runtime_error(msg.str()); break; } @@ -107,7 +108,7 @@ void PrintLevel(std::ostream& os, const SokobanLevel& vec) { for (size_t i = 0; i < vec.size(); i++) { os << kPrintLevelKey.at(vec.at(i)); if ((i + 1) % dim_room == 0) { - os << std::endl; + os << '\n'; } } } @@ -153,7 +154,7 @@ void LevelLoader::LoadFile(std::mt19937& gen) { if (line.length() != dim_room) { std::stringstream msg; msg << "Irregular line '" << line - << "' does not match dim_room=" << dim_room << std::endl; + << "' does not match dim_room=" << dim_room << '\n'; throw std::runtime_error(msg.str()); } AddLine(cur_level, line); @@ -162,11 +163,10 @@ void LevelLoader::LoadFile(std::mt19937& gen) { if (cur_level.size() != dim_room * dim_room) { std::stringstream msg; msg << "Room is not square: " << cur_level.size() << " != " << dim_room - << "x" << dim_room << std::endl; + << "x" << dim_room << '\n'; throw std::runtime_error(msg.str()); } - levels_.emplace_back( - std::make_pair(cur_level_idx++, std::move(cur_level))); + levels_.emplace_back(cur_level_idx++, std::move(cur_level)); } } if (!load_sequentially_) { @@ -174,18 +174,18 @@ void LevelLoader::LoadFile(std::mt19937& gen) { } if (levels_.empty()) { std::stringstream msg; - msg << "No levels loaded from file '" << file_path << std::endl; + msg << "No levels loaded from file '" << file_path << '\n'; throw std::runtime_error(msg.str()); } if (verbose >= 1) { std::cout << "***Loaded " << levels_.size() << " levels from " << file_path - << std::endl; + << '\n'; if (verbose >= 2) { PrintLevel(std::cout, levels_.at(0).second); - std::cout << std::endl; + std::cout << '\n'; PrintLevel(std::cout, levels_.at(1).second); - std::cout << std::endl; + std::cout << '\n'; } } } @@ -193,7 +193,7 @@ void LevelLoader::LoadFile(std::mt19937& gen) { TaggedSokobanLevel LevelLoader::GetLevel(std::mt19937& gen) { if (n_levels_to_load_ > 0 && levels_loaded_ >= n_levels_to_load_) { // std::cerr << "Warning: All levels loaded. Looping around now." << - // std::endl; + // '\n'; levels_loaded_ = 0; cur_file_ = level_file_paths_.begin(); cur_level_file_ = -1; @@ -203,8 +203,8 @@ TaggedSokobanLevel LevelLoader::GetLevel(std::mt19937& gen) { } // Load new files until the current level index is within the loaded levels // this is required when new files have lesser levels than the number of envs - while (cur_level_ >= levels_.size()) { - cur_level_ -= levels_.size(); + while (cur_level_ >= std::size(levels_)) { + cur_level_ -= std::size(levels_); LoadFile(gen); } // no need for bound checks since it is checked in the while loop above diff --git a/envpool/sokoban/sokoban_envpool.cc b/envpool/sokoban/sokoban_envpool.cc index 7705f958..ee7891a1 100644 --- a/envpool/sokoban/sokoban_envpool.cc +++ b/envpool/sokoban/sokoban_envpool.cc @@ -41,7 +41,7 @@ void SokobanEnv::ResetWithoutWrite() { if (world_.size() != dim_room_ * dim_room_) { std::stringstream msg; msg << "Loaded level is not dim_room x dim_room. world_.size()=" - << world_.size() << ", dim_room_=" << dim_room_ << std::endl; + << world_.size() << ", dim_room_=" << dim_room_ << '\n'; throw std::runtime_error(msg.str()); } unmatched_boxes_ = 0; @@ -193,7 +193,7 @@ void SokobanEnv::WriteState(float reward) { std::stringstream msg; msg << "Obs size and level size are different: obs_size=" << obs.size << "/3, level_size=" << world_.size() << ", dim_room=" << dim_room_ - << std::endl; + << '\n'; throw std::runtime_error(msg.str()); } diff --git a/envpool/sokoban/sokoban_envpool.h b/envpool/sokoban/sokoban_envpool.h index 6658057e..f1b0b2e3 100644 --- a/envpool/sokoban/sokoban_envpool.h +++ b/envpool/sokoban/sokoban_envpool.h @@ -83,14 +83,14 @@ class SokobanEnv : public Env { if (max_num_players_ != spec_.config["max_num_players"_]) { std::stringstream msg; msg << "max_num_players_ != spec_['max_num_players'] " << max_num_players_ - << " != " << spec_.config["max_num_players"_] << std::endl; + << " != " << spec_.config["max_num_players"_] << '\n'; throw std::runtime_error(msg.str()); } if (max_num_players_ != spec.config["max_num_players"_]) { std::stringstream msg; msg << "max_num_players_ != spec['max_num_players'] " << max_num_players_ - << " != " << spec.config["max_num_players"_] << std::endl; + << " != " << spec.config["max_num_players"_] << '\n'; throw std::runtime_error(msg.str()); } } diff --git a/envpool/sokoban/sokoban_node.cc b/envpool/sokoban/sokoban_node.cc index 7be85918..b6ffd68b 100644 --- a/envpool/sokoban/sokoban_node.cc +++ b/envpool/sokoban/sokoban_node.cc @@ -27,7 +27,7 @@ bool SokobanNode::IsSameState(SokobanNode& rhs) const { } void SokobanNode::PrintNodeInfo(std::vector>* goals) { - std::cout << "Action: " << action_from_parent << std::endl; + std::cout << "Action: " << action_from_parent << '\n'; for (int y = 0; y < dim_room; y++) { for (int x = 0; x < dim_room; x++) { bool is_wall = walls->at(x + y * dim_room); @@ -68,7 +68,7 @@ void SokobanNode::PrintNodeInfo(std::vector>* goals) { std::cout << " "; } } - std::cout << std::endl; + std::cout << '\n'; } } diff --git a/envpool/sokoban/sokoban_node.h b/envpool/sokoban/sokoban_node.h index ef789ed2..14c5d45d 100644 --- a/envpool/sokoban/sokoban_node.h +++ b/envpool/sokoban/sokoban_node.h @@ -15,7 +15,10 @@ #ifndef ENVPOOL_SOKOBAN_SOKOBAN_NODE_H_ #define ENVPOOL_SOKOBAN_SOKOBAN_NODE_H_ +#include +#include #include +#include #include #include @@ -54,27 +57,31 @@ class SokobanNode { case kBox: if (!is_goal_node) { total_boxes++; - boxes.emplace_back(std::make_pair(x, y)); + boxes.emplace_back(x, y); } break; case kTarget: if (is_goal_node) { total_boxes++; - boxes.emplace_back(std::make_pair(x, y)); + boxes.emplace_back(x, y); } break; case kBoxOnTarget: total_boxes++; - boxes.emplace_back(std::make_pair(x, y)); + boxes.emplace_back(x, y); break; case kPlayerOnTarget: player_x = x; player_y = y; break; - } - - if (world.at(x + y * dim_room) == kWall) { - walls->at(x + y * dim_room) = true; + case kWall: + walls->at(x + y * dim_room) = true; + case kEmpty: + break; + default: + std::stringstream msg; + msg << "Invalid character in Sokoban level: " << static_cast(world.at(x + y * dim_room)) << '\n'; + throw std::runtime_error(msg.str()); } } } diff --git a/envpool/sokoban/sokoban_py_envpool_test.py b/envpool/sokoban/sokoban_py_envpool_test.py index 4c15779d..cf67d575 100644 --- a/envpool/sokoban/sokoban_py_envpool_test.py +++ b/envpool/sokoban/sokoban_py_envpool_test.py @@ -15,7 +15,6 @@ import glob import re -import subprocess import sys import time from pathlib import Path @@ -160,22 +159,6 @@ def test_envpool_load_sequentially(capfd) -> None: assert lev2 == levels_by_files[i][1][1] -def test_xla() -> None: - num_envs = 10 - env = envpool.make( - "Sokoban-v0", - env_type="dm", - num_envs=num_envs, - batch_size=num_envs, - seed=2346890, - max_episode_steps=60, - reward_step=-0.1, - dim_room=10, - levels_dir="/app/envpool/sokoban/sample_levels", - ) - handle, recv, send, step = env.xla() - - SOLVE_LEVEL_ZERO: str = "222200001112330322210" TINY_COLORS: list[tuple[tuple[int, int, int], str]] = [ ((0, 0, 0), "#"), @@ -335,27 +318,6 @@ def test_load_sequentially_with_multiple_envs() -> None: assert printed_obs[i][j] == line, f"Level {i} is not loaded correctly." -def test_astar_log(tmp_path) -> None: - level_file_name = "/app/envpool/sokoban/sample_levels/small.txt" - log_file_name = tmp_path / "log_file.csv" - subprocess.run( - [ - "/root/go/bin/bazel", f"--output_base={str(tmp_path)}", "run", - "//envpool/sokoban:astar_log", "--", level_file_name, - str(log_file_name), "1" - ], - check=True, - cwd="/app/envpool", - env={ - "HOME": "/root", - "PATH": "/opt/conda/bin:/usr/bin", - "USE_BAZEL_VERSION": "6.4.0", - }, - ) - log = log_file_name.read_text() - assert f"0,{SOLVE_LEVEL_ZERO},21,1380" == log.split("\n")[1] - - def test_sneaky_noop(): """ Even though an action < 0 is not part of the environment, we overload it to @@ -432,5 +394,5 @@ def test_noop_action(): if __name__ == "__main__": - retcode = pytest.main(["-v", __file__]) + retcode = pytest.main(["-v", *sys.argv[1:]]) sys.exit(retcode) diff --git a/envpool/vizdoom/BUILD b/envpool/vizdoom/BUILD index 76ccffa7..c06c5fc3 100644 --- a/envpool/vizdoom/BUILD +++ b/envpool/vizdoom/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@envpool//third_party:common.bzl", "copy_directory") load("@pip_requirements//:requirements.bzl", "requirement") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") @@ -29,14 +30,18 @@ genrule( cmd = "cp $(SRCS) $(@D)", ) -genrule( - name = "gen_vizdoom_maps", +filegroup( + name = "vizdoom_maps_sources", srcs = [ - "@vizdoom_lib//:vizdoom_maps", "@vizdoom_extra_maps//:vizdoom_maps", + "@vizdoom_lib//:vizdoom_maps", ], - outs = ["maps"], - cmd = "mkdir -p $(OUTS) && cp $(SRCS) $(OUTS)", +) + +copy_directory( + name = "maps", + src = ":vizdoom_maps_sources", + out = "", ) cc_library( @@ -71,7 +76,7 @@ py_library( name = "vizdoom", srcs = ["__init__.py"], data = [ - ":gen_vizdoom_maps", + ":maps", ":vizdoom_envpool.so", "//envpool/vizdoom/bin:freedoom", "//envpool/vizdoom/bin:vizdoom_bin", diff --git a/envpool/workspace0.bzl b/envpool/workspace0.bzl deleted file mode 100644 index 98c44cde..00000000 --- a/envpool/workspace0.bzl +++ /dev/null @@ -1,433 +0,0 @@ -# Copyright 2021 Garena Online Private Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""EnvPool workspace initialization, this is loaded in WORKSPACE.""" - -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") -load("//third_party/cuda:cuda.bzl", "cuda_configure") - -def workspace(): - """Load requested packages.""" - - # we cannot upgrade rules_python because it requires requirements_lock.txt after 0.13.0 - maybe( - http_archive, - name = "rules_python", - sha256 = "b593d13bb43c94ce94b483c2858e53a9b811f6f10e1e0eedc61073bd90e58d9c", - strip_prefix = "rules_python-0.12.0", - urls = [ - "https://github.com/bazelbuild/rules_python/archive/refs/tags/0.12.0.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/bazelbuild/rules_python/0.12.0.tar.gz", - ], - ) - - maybe( - http_archive, - name = "rules_foreign_cc", - sha256 = "476303bd0f1b04cc311fc258f1708a5f6ef82d3091e53fd1977fa20383425a6a", - strip_prefix = "rules_foreign_cc-0.10.1", - urls = [ - "https://github.com/bazelbuild/rules_foreign_cc/archive/0.10.1.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/bazelbuild/rules_foreign_cc/0.10.1.tar.gz", - ], - ) - - maybe( - http_archive, - name = "pybind11_bazel", - sha256 = "2c466c9b3cca7852b47e0785003128984fcf0d5d61a1a2e4c5aceefd935ac220", - strip_prefix = "pybind11_bazel-2.11.1", - urls = [ - "https://github.com/pybind/pybind11_bazel/archive/refs/tags/v2.11.1.zip", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/pybind/pybind11_bazel/v2.11.1.zip", - ], - ) - - maybe( - http_archive, - name = "pybind11", - build_file = "@pybind11_bazel//:pybind11.BUILD", - sha256 = "d475978da0cdc2d43b73f30910786759d593a9d8ee05b1b6846d1eb16c6d2e0c", - strip_prefix = "pybind11-2.11.1", - urls = [ - "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/pybind/pybind11/v2.11.1.tar.gz", - ], - ) - - maybe( - http_archive, - name = "com_google_absl", - sha256 = "497ebdc3a4885d9209b9bd416e8c3f71e7a1fb8af249f6c2a80b7cbeefcd7e21", - strip_prefix = "abseil-cpp-20230802.1", - urls = [ - "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.1.zip", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/abseil/abseil-cpp/20230802.1.zip", - ], - ) - - maybe( - http_archive, - name = "com_github_gflags_gflags", - sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf", - strip_prefix = "gflags-2.2.2", - urls = [ - "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/gflags/gflags/v2.2.2.tar.gz", - ], - ) - - maybe( - http_archive, - name = "com_github_google_glog", - sha256 = "122fb6b712808ef43fbf80f75c52a21c9760683dae470154f02bddfc61135022", - strip_prefix = "glog-0.6.0", - urls = [ - "https://github.com/google/glog/archive/v0.6.0.zip", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/google/glog/v0.6.0.zip", - ], - ) - - maybe( - http_archive, - name = "com_google_googletest", - sha256 = "8ad598c73ad796e0d8280b082cebd82a630d73e73cd3c70057938a6501bba5d7", - strip_prefix = "googletest-1.14.0", - urls = [ - "https://github.com/google/googletest/archive/refs/tags/v1.14.0.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/google/googletest/v1.14.0.tar.gz", - ], - ) - - maybe( - http_archive, - name = "com_justbuchanan_rules_qt", - sha256 = "6b42a58f062b3eea10ada5340cd8f63b47feb986d16794b0f8e0fde750838348", - strip_prefix = "bazel_rules_qt-3196fcf2e6ee81cf3a2e2b272af3d4259b84fcf9", - urls = [ - "https://github.com/justbuchanan/bazel_rules_qt/archive/3196fcf2e6ee81cf3a2e2b272af3d4259b84fcf9.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/justbuchanan/bazel_rules_qt/3196fcf2e6ee81cf3a2e2b272af3d4259b84fcf9.tar.gz", - ], - ) - - maybe( - http_archive, - name = "glibc_version_header", - sha256 = "57db74f933b7a9ea5c653498640431ce0e52aaef190d6bb586711ec4f8aa2b9e", - strip_prefix = "glibc_version_header-0.1/version_headers/", - urls = [ - "https://github.com/wheybags/glibc_version_header/archive/refs/tags/0.1.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/wheybags/glibc_version_header/0.1.tar.gz", - ], - build_file = "//third_party/glibc_version_header:glibc_version_header.BUILD", - ) - - maybe( - http_archive, - name = "concurrentqueue", - sha256 = "87fbc9884d60d0d4bf3462c18f4c0ee0a9311d0519341cac7cbd361c885e5281", - strip_prefix = "concurrentqueue-1.0.4", - urls = [ - "https://github.com/cameron314/concurrentqueue/archive/refs/tags/v1.0.4.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/cameron314/concurrentqueue/v1.0.4.tar.gz", - ], - build_file = "//third_party/concurrentqueue:concurrentqueue.BUILD", - ) - - maybe( - http_archive, - name = "threadpool", - sha256 = "18854bb7ecc1fc9d7dda9c798a1ef0c81c2dd331d730c76c75f648189fa0c20f", - strip_prefix = "ThreadPool-9a42ec1329f259a5f4881a291db1dcb8f2ad9040", - urls = [ - "https://github.com/progschj/ThreadPool/archive/9a42ec1329f259a5f4881a291db1dcb8f2ad9040.zip", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/progschj/ThreadPool/9a42ec1329f259a5f4881a291db1dcb8f2ad9040.zip", - ], - build_file = "//third_party/threadpool:threadpool.BUILD", - ) - - maybe( - http_archive, - name = "zlib", - sha256 = "ff0ba4c292013dbc27530b3a81e1f9a813cd39de01ca5e0f8bf355702efa593e", - strip_prefix = "zlib-1.3", - urls = [ - "https://github.com/madler/zlib/releases/download/v1.3/zlib-1.3.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/madler/zlib/zlib-1.3.tar.gz", - ], - build_file = "//third_party/zlib:zlib.BUILD", - ) - - maybe( - http_archive, - name = "opencv", - sha256 = "62f650467a60a38794d681ae7e66e3e8cfba38f445e0bf87867e2f2cdc8be9d5", - strip_prefix = "opencv-4.8.1", - urls = [ - "https://github.com/opencv/opencv/archive/refs/tags/4.8.1.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/opencv/opencv/4.8.1.tar.gz", - ], - build_file = "//third_party/opencv:opencv.BUILD", - ) - - maybe( - http_archive, - name = "pugixml", - sha256 = "610f98375424b5614754a6f34a491adbddaaec074e9044577d965160ec103d2e", - strip_prefix = "pugixml-1.14/src", - urls = [ - "https://github.com/zeux/pugixml/archive/refs/tags/v1.14.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/zeux/pugixml/v1.14.tar.gz", - ], - build_file = "//third_party/pugixml:pugixml.BUILD", - ) - - maybe( - http_archive, - name = "ale", - sha256 = "28960616cd89c18925ced7bbdeec01ab0b2ebd2d8ce5b7c88930e97381b4c3b5", - strip_prefix = "Arcade-Learning-Environment-0.8.1", - urls = [ - "https://github.com/mgbellemare/Arcade-Learning-Environment/archive/refs/tags/v0.8.1.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/mgbellemare/Arcade-Learning-Environment/v0.8.1.tar.gz", - ], - build_file = "//third_party/ale:ale.BUILD", - ) - - maybe( - http_archive, - name = "atari_roms", - sha256 = "e39e9fc379fe3f336911d928ce0a52e6ff6861258906efc5e849390867ff35f5", - urls = [ - "https://roms8.s3.us-east-2.amazonaws.com/Roms.tar.gz", - "https://cdn.sail.sea.com/sail/Roms.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/atari/Roms.tar.gz", - ], - build_file = "//third_party/atari_roms:atari_roms.BUILD", - ) - - maybe( - http_archive, - name = "libjpeg_turbo", - sha256 = "b3090cd37b5a8b3e4dbd30a1311b3989a894e5d3c668f14cbc6739d77c9402b7", - strip_prefix = "libjpeg-turbo-2.0.5", - urls = [ - "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.5.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/libjpeg-turbo/libjpeg-turbo/2.0.5.tar.gz", - ], - build_file = "//third_party/jpeg:jpeg.BUILD", - ) - - maybe( - http_archive, - name = "nasm", - sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011", - strip_prefix = "nasm-2.13.03", - urls = [ - "https://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/nasm/nasm-2.13.03.tar.bz2", - ], - build_file = "//third_party/nasm:nasm.BUILD", - ) - - maybe( - http_archive, - name = "sdl2", - sha256 = "888b8c39f36ae2035d023d1b14ab0191eb1d26403c3cf4d4d5ede30e66a4942c", - strip_prefix = "SDL2-2.28.4", - urls = [ - "https://www.libsdl.org/release/SDL2-2.28.4.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/libsdl/SDL2-2.28.4.tar.gz", - ], - build_file = "//third_party/sdl2:sdl2.BUILD", - ) - - maybe( - http_archive, - name = "com_github_nelhage_rules_boost", - # sha256 = "2215e6910eb763a971b1f63f53c45c0f2b7607df38c96287666d94d954da8cdc", - strip_prefix = "rules_boost-e60cf50996da9fe769b6e7a31b88c54966ecb191", - urls = [ - "https://github.com/nelhage/rules_boost/archive/e60cf50996da9fe769b6e7a31b88c54966ecb191.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/nelhage/rules_boost/e60cf50996da9fe769b6e7a31b88c54966ecb191.tar.gz", - ], - ) - - maybe( - http_archive, - name = "boost", - build_file = "@com_github_nelhage_rules_boost//:boost.BUILD", - patch_cmds = ["rm -f doc/pdf/BUILD"], - sha256 = "6478edfe2f3305127cffe8caf73ea0176c53769f4bf1585be237eb30798c3b8e", - strip_prefix = "boost_1_83_0", - urls = [ - "https://boostorg.jfrog.io/artifactory/main/release/1.83.0/source/boost_1_83_0.tar.bz2", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/boost/boost_1_83_0.tar.bz2", - ], - ) - - maybe( - http_archive, - name = "freedoom", - sha256 = "f42c6810fc89b0282de1466c2c9c7c9818031a8d556256a6db1b69f6a77b5806", - strip_prefix = "freedoom-0.12.1/", - urls = [ - "https://github.com/freedoom/freedoom/releases/download/v0.12.1/freedoom-0.12.1.zip", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/freedoom/freedoom/freedoom-0.12.1.zip", - ], - build_file = "//third_party/freedoom:freedoom.BUILD", - ) - - maybe( - http_archive, - name = "vizdoom", - sha256 = "e379a242ada7e1028b7a635da672b0936d99da3702781b76a4400b83602d78c4", - strip_prefix = "ViZDoom-1.1.13/src/vizdoom/", - urls = [ - "https://github.com/Farama-Foundation/ViZDoom/archive/refs/tags/1.1.13.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/Farama-Foundation/ViZDoom/1.1.13.tar.gz", - ], - build_file = "//third_party/vizdoom:vizdoom.BUILD", - patches = [ - "//third_party/vizdoom:sdl_thread.patch", - ], - ) - - maybe( - http_archive, - name = "vizdoom_lib", - sha256 = "e379a242ada7e1028b7a635da672b0936d99da3702781b76a4400b83602d78c4", - strip_prefix = "ViZDoom-1.1.13/", - urls = [ - "https://github.com/Farama-Foundation/ViZDoom/archive/refs/tags/1.1.13.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/Farama-Foundation/ViZDoom/1.1.13.tar.gz", - ], - build_file = "//third_party/vizdoom_lib:vizdoom_lib.BUILD", - ) - - maybe( - http_archive, - name = "vizdoom_extra_maps", - sha256 = "325440fe566ff478f35947c824ea5562e2735366845d36c5a0e40867b59f7d69", - strip_prefix = "DirectFuturePrediction-b4757769f167f1bd7fb1ece5fdc6d874409c68a9/", - urls = [ - "https://github.com/isl-org/DirectFuturePrediction/archive/b4757769f167f1bd7fb1ece5fdc6d874409c68a9.zip", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/isl-org/DirectFuturePrediction/b4757769f167f1bd7fb1ece5fdc6d874409c68a9.zip", - ], - build_file = "//third_party/vizdoom_extra_maps:vizdoom_extra_maps.BUILD", - ) - - maybe( - http_archive, - name = "mujoco", - sha256 = "d1cb3a720546240d894cd315b7fd358a2b96013a1f59b6d718036eca6f6edac2", - strip_prefix = "mujoco-2.2.1", - urls = [ - "https://github.com/deepmind/mujoco/releases/download/2.2.1/mujoco-2.2.1-linux-x86_64.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/deepmind/mujoco/mujoco-2.2.1-linux-x86_64.tar.gz", - ], - build_file = "//third_party/mujoco:mujoco.BUILD", - ) - - maybe( - http_archive, - name = "mujoco_gym_xml", - sha256 = "96a5fc8345bd92b73a15fc25112d53a294f86fcace1c5e4ef7f0e052b5e1bdf4", - strip_prefix = "gym-0.26.2/gym/envs/mujoco", - urls = [ - "https://github.com/openai/gym/archive/refs/tags/0.26.2.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/openai/gym/0.26.2.tar.gz", - ], - build_file = "//third_party/mujoco_gym_xml:mujoco_gym_xml.BUILD", - ) - - maybe( - http_archive, - name = "mujoco_dmc_xml", - sha256 = "fb8d57cbeb92bebe56a992dab8401bc00b3bff61b62526eb563854adf3dfb595", - strip_prefix = "dm_control-1.0.9/dm_control", - urls = [ - "https://github.com/deepmind/dm_control/archive/refs/tags/1.0.9.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/deepmind/dm_control/1.0.9.tar.gz", - ], - build_file = "//third_party/mujoco_dmc_xml:mujoco_dmc_xml.BUILD", - ) - - maybe( - http_archive, - name = "box2d", - sha256 = "d6b4650ff897ee1ead27cf77a5933ea197cbeef6705638dd181adc2e816b23c2", - strip_prefix = "box2d-2.4.1", - urls = [ - "https://github.com/erincatto/box2d/archive/refs/tags/v2.4.1.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/erincatto/box2d/v2.4.1.tar.gz", - ], - build_file = "//third_party/box2d:box2d.BUILD", - ) - - # Atari/VizDoom pretrained weight for testing pipeline - - maybe( - http_archive, - name = "pretrain_weight", - sha256 = "b1b64e0db84cf7317c2a96b27f549147dfcb4074ed2d799334c23a067075ac1c", - urls = [ - "https://cdn.sail.sea.com/sail/pretrain.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/pretrain.tar.gz", - ], - build_file = "//third_party/pretrain_weight:pretrain_weight.BUILD", - ) - - maybe( - http_archive, - name = "procgen", - sha256 = "d5620394418b885f9028f98759189a5f78bc4ba71fb6605f910ae22fca870c8e", - strip_prefix = "procgen-0.10.8/procgen", - urls = [ - "https://github.com/Trinkle23897/procgen/archive/refs/tags/0.10.8.tar.gz", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/Trinkle23897/procgen/0.10.8.tar.gz", - ], - build_file = "//third_party/procgen:procgen.BUILD", - ) - - maybe( - http_archive, - name = "gym3_libenv", - sha256 = "9a764d79d4215609c2612b2c84fec8bcea6609941bdcb7051f3335ed4576b8ef", - strip_prefix = "gym3-4c3824680eaf9dd04dce224ee3d4856429878226/gym3", - urls = [ - "https://github.com/openai/gym3/archive/4c3824680eaf9dd04dce224ee3d4856429878226.zip", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/openai/gym3/4c3824680eaf9dd04dce224ee3d4856429878226.zip", - ], - build_file = "//third_party/gym3_libenv:gym3_libenv.BUILD", - ) - - maybe( - http_archive, - name = "bazel_clang_tidy", - sha256 = "ec8c5bf0c02503b928c2e42edbd15f75e306a05b2cae1f34a7bc84724070b98b", - strip_prefix = "bazel_clang_tidy-783aa523aafb4a6798a538c61e700b6ed27975a7", - urls = [ - "https://github.com/erenon/bazel_clang_tidy/archive/783aa523aafb4a6798a538c61e700b6ed27975a7.zip", - "https://ml.cs.tsinghua.edu.cn/~jiayi/envpool/erenon/bazel_clang_tidy/783aa523aafb4a6798a538c61e700b6ed27975a7.zip", - ], - ) - - maybe( - cuda_configure, - name = "cuda", - ) - -workspace0 = workspace diff --git a/envpool/workspace1.bzl b/envpool/workspace1.bzl deleted file mode 100644 index 94100b63..00000000 --- a/envpool/workspace1.bzl +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2021 Garena Online Private Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""EnvPool workspace initialization, load after workspace0.""" - -load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps") -load("@com_justbuchanan_rules_qt//:qt_configure.bzl", "qt_configure") -load("@pybind11_bazel//:python_configure.bzl", "python_configure") -load("@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies") - -def workspace(): - """Configure pip requirements.""" - python_configure( - name = "local_config_python", - python_version = "3", - ) - - rules_foreign_cc_dependencies() - - boost_deps() - - qt_configure() - -workspace1 = workspace diff --git a/setup.py b/setup.py deleted file mode 100644 index 367778f9..00000000 --- a/setup.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python3 - -from setuptools import setup -from setuptools.command.install import install -from setuptools.dist import Distribution - - -class InstallPlatlib(install): - """Fix auditwheel error, https://github.com/google/or-tools/issues/616""" - - def finalize_options(self) -> None: - install.finalize_options(self) - if self.distribution.has_ext_modules(): - self.install_lib = self.install_platlib - - -class BinaryDistribution(Distribution): - - def is_pure(self) -> bool: - return False - - def has_ext_modules(foo) -> bool: - return True - - -if __name__ == '__main__': - setup(distclass=BinaryDistribution, cmdclass={'install': InstallPlatlib}) diff --git a/third_party/ale/ale.BUILD b/third_party/ale/ale.BUILD index ba1ddb3e..4f869638 100644 --- a/third_party/ale/ale.BUILD +++ b/third_party/ale/ale.BUILD @@ -4,10 +4,13 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "irregular_files", - hdrs = glob([ - "src/**/*.def", - "src/**/*.ins", - ]), + hdrs = glob( + [ + "src/**/*.def", + "src/**/*.ins", + ], + allow_empty = True, + ), ) template_rule( @@ -34,6 +37,7 @@ cc_library( "src/**/*.cpp", "src/**/*.cxx", ], + allow_empty = True, exclude = [ "src/python/*", ], @@ -41,6 +45,9 @@ cc_library( ":ale_version", ], hdrs = ["src/ale_interface.hpp"], + copts = [ + "-include stdint.h", + ], includes = [ "src", "src/common", diff --git a/third_party/atari_roms/atari_roms.BUILD b/third_party/atari_roms/atari_roms.BUILD index c55497e5..e7945a68 100644 --- a/third_party/atari_roms/atari_roms.BUILD +++ b/third_party/atari_roms/atari_roms.BUILD @@ -1,5 +1,7 @@ +load("@envpool//third_party:common.bzl", "copy_directory") + filegroup( - name = "roms", + name = "roms_sources", srcs = glob( ["ROM/*/*.bin"], exclude = [ @@ -9,5 +11,5 @@ filegroup( "ROM/warlords/warlords.bin", ], ), - visibility = ["//visibility:public"], + visibility=["//visibility:public"], ) diff --git a/third_party/common.bzl b/third_party/common.bzl index 95c972ea..42032d5a 100644 --- a/third_party/common.bzl +++ b/third_party/common.bzl @@ -59,3 +59,37 @@ template_rule = rule( output_to_genfiles = True, implementation = template_rule_impl, ) + +def _copy_directory_impl(ctx): + # Use the label as the output directory name, so that the output directory is unique (within a BUILD file) + output_dir = ctx.actions.declare_directory(ctx.label.name) + + # Create commands that copy all input files to the output directory + commands = [] + dest_dir = output_dir.path + "/" + ctx.attr.out + commands.append("mkdir -p %s" % dest_dir) + for src in ctx.files.src: + commands.append("cp -r %s %s/" % (src.path, dest_dir)) + + ctx.actions.run_shell( + inputs = ctx.files.src, + outputs = [output_dir], + command = " && ".join(commands), + ) + + return [DefaultInfo(files = depset([output_dir]))] + +copy_directory = rule( + implementation = _copy_directory_impl, + attrs = { + "src": attr.label( + mandatory = True, + allow_files = True, + doc = "Source directory or files to copy", + ), + "out": attr.string( + default = "", + doc = "Optional subdirectory path within the output directory to copy files to", + ), + }, +) diff --git a/third_party/pip_requirements/BUILD b/third_party/pip_requirements/BUILD index 3c506229..fe04025d 100644 --- a/third_party/pip_requirements/BUILD +++ b/third_party/pip_requirements/BUILD @@ -14,4 +14,7 @@ package(default_visibility = ["//visibility:public"]) -exports_files(["requirements.txt"]) +exports_files([ + "requirements.txt", + "requirements-release.txt", +]) diff --git a/third_party/pip_requirements/requirements-dev-locked.txt b/third_party/pip_requirements/requirements-dev-locked.txt new file mode 100644 index 00000000..26e66447 --- /dev/null +++ b/third_party/pip_requirements/requirements-dev-locked.txt @@ -0,0 +1,249 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile -o requirements-dev-locked.txt requirements-dev.txt +absl-py==2.1.0 + # via + # -r requirements-dev.txt + # dm-control + # dm-env + # dm-tree + # labmaze + # mujoco + # tensorboard +attrs==25.1.0 + # via dm-tree +box2d-py==2.3.8 + # via -r requirements-dev.txt +certifi==2025.1.31 + # via requests +cffi==1.17.1 + # via mujoco-py +charset-normalizer==3.4.1 + # via requests +cloudpickle==3.1.1 + # via + # gym + # gymnasium +cython==3.0.12 + # via mujoco-py +dm-control==1.0.7 + # via -r requirements-dev.txt +dm-env==1.6 + # via + # -r requirements-dev.txt + # dm-control +dm-tree==0.1.9 + # via + # dm-control + # dm-env +farama-notifications==0.0.4 + # via gymnasium +fasteners==0.19 + # via mujoco-py +filelock==3.17.0 + # via torch +fsspec==2025.2.0 + # via torch +glfw==2.8.0 + # via + # dm-control + # mujoco + # mujoco-py +grpcio==1.70.0 + # via tensorboard +gym==0.26.2 + # via -r requirements-dev.txt +gym-notices==0.0.8 + # via gym +gymnasium==1.1.0 + # via + # -r requirements-dev.txt + # minigrid + # pettingzoo + # tianshou +h5py==3.13.0 + # via tianshou +idna==3.10 + # via requests +imageio==2.37.0 + # via mujoco-py +iniconfig==2.0.0 + # via pytest +jax==0.5.1 + # via -r requirements-dev.txt +jaxlib==0.5.1 + # via jax +jinja2==3.1.5 + # via torch +labmaze==1.0.6 + # via dm-control +llvmlite==0.36.0 + # via numba +lxml==5.3.1 + # via dm-control +markdown==3.7 + # via tensorboard +markupsafe==3.0.2 + # via + # jinja2 + # werkzeug +minigrid==3.0.0 + # via -r requirements-dev.txt +ml-dtypes==0.5.1 + # via + # jax + # jaxlib +mpmath==1.3.0 + # via sympy +mujoco==2.2.2 + # via + # -r requirements-dev.txt + # dm-control +mujoco-py==2.1.2.14 + # via -r requirements-dev.txt +networkx==3.4.2 + # via torch +numba==0.53.1 + # via tianshou +numpy==2.2.3 + # via + # -r requirements-dev.txt + # dm-control + # dm-env + # dm-tree + # gym + # gymnasium + # h5py + # imageio + # jax + # jaxlib + # labmaze + # minigrid + # ml-dtypes + # mujoco + # mujoco-py + # numba + # opencv-python-headless + # pettingzoo + # scipy + # tensorboard + # tianshou +nvidia-cublas-cu12==12.4.5.8 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.4.127 + # via torch +nvidia-cuda-nvrtc-cu12==12.4.127 + # via torch +nvidia-cuda-runtime-cu12==12.4.127 + # via torch +nvidia-cudnn-cu12==9.1.0.70 + # via torch +nvidia-cufft-cu12==11.2.1.3 + # via torch +nvidia-curand-cu12==10.3.5.147 + # via torch +nvidia-cusolver-cu12==11.6.1.9 + # via torch +nvidia-cusparse-cu12==12.3.1.170 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-cusparselt-cu12==0.6.2 + # via torch +nvidia-nccl-cu12==2.21.5 + # via torch +nvidia-nvjitlink-cu12==12.4.127 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvtx-cu12==12.4.127 + # via torch +opencv-python-headless==4.11.0.86 + # via -r requirements-dev.txt +opt-einsum==3.4.0 + # via jax +optree==0.14.1 + # via -r requirements-dev.txt +packaging==24.2 + # via + # -r requirements-dev.txt + # pytest + # tensorboard + # tianshou +pettingzoo==1.24.3 + # via tianshou +pillow==11.1.0 + # via imageio +pluggy==1.5.0 + # via pytest +protobuf==3.20.3 + # via + # -r requirements-dev.txt + # dm-control + # tensorboard +pycparser==2.22 + # via cffi +pygame==2.6.1 + # via + # -r requirements-dev.txt + # minigrid +pyopengl==3.1.9 + # via + # dm-control + # mujoco +pyparsing==2.4.7 + # via dm-control +pytest==8.3.4 + # via -r requirements-dev.txt +requests==2.32.3 + # via dm-control +scipy==1.15.2 + # via + # dm-control + # jax + # jaxlib +setuptools==75.8.2 + # via + # -r requirements-dev.txt + # dm-control + # labmaze + # numba + # tensorboard + # torch +six==1.17.0 + # via tensorboard +sympy==1.13.1 + # via torch +tensorboard==2.19.0 + # via tianshou +tensorboard-data-server==0.7.2 + # via tensorboard +tianshou==0.5.1 + # via -r requirements-dev.txt +torch==2.6.0 + # via + # -r requirements-dev.txt + # tianshou +tqdm==4.67.1 + # via + # -r requirements-dev.txt + # dm-control + # tianshou +triton==3.2.0 + # via torch +typing-extensions==4.12.2 + # via + # gymnasium + # optree + # torch +urllib3==2.3.0 + # via requests +werkzeug==3.1.3 + # via tensorboard +wheel==0.45.1 + # via -r requirements-dev.txt +wrapt==1.17.2 + # via dm-tree diff --git a/third_party/pip_requirements/requirements-release.txt b/third_party/pip_requirements/requirements-release.txt index de0f9770..d9b6ab28 100644 --- a/third_party/pip_requirements/requirements-release.txt +++ b/third_party/pip_requirements/requirements-release.txt @@ -1,10 +1,9 @@ -setuptools==70.3.0 # last setuptools version that doesn't give the Lorem Ipsum.txt error -wheel -numpy==1.26.4 # test_load_sequentially_with_multiple_envs fails with latest version -dm-env +# Used for the actual requirements of the envpool wheel that we build + +numpy>=2.2.0 +dm-env>=1.6 gym>=0.26 gymnasium>=0.26,!=0.27.0 optree>=0.6.0 -jax[cpu]==0.4.27 # test_xla fails with latest version -packaging +jax>=0.5.0 pytest diff --git a/third_party/pip_requirements/requirements-sokoban.txt b/third_party/pip_requirements/requirements-sokoban.txt deleted file mode 120000 index 6829e68e..00000000 --- a/third_party/pip_requirements/requirements-sokoban.txt +++ /dev/null @@ -1 +0,0 @@ -requirements-release.txt \ No newline at end of file diff --git a/third_party/procgen/procgen.BUILD b/third_party/procgen/procgen.BUILD index b89fb9e5..7dde83b9 100644 --- a/third_party/procgen/procgen.BUILD +++ b/third_party/procgen/procgen.BUILD @@ -17,8 +17,17 @@ filegroup( cc_library( name = "procgen", - srcs = glob(["src/**/*.cpp"]) + glob(["src/*.h"]), - hdrs = glob(["src/*.h"]), + srcs = glob( + ["src/**/*.cpp"], + allow_empty = True, + ) + glob( + ["src/*.h"], + allow_empty = True, + ), + hdrs = glob( + ["src/*.h"], + allow_empty = True, + ), copts = [ "-fpic", ], diff --git a/third_party/vizdoom/vizdoom.BUILD b/third_party/vizdoom/vizdoom.BUILD index 2bbfef52..06b8760d 100644 --- a/third_party/vizdoom/vizdoom.BUILD +++ b/third_party/vizdoom/vizdoom.BUILD @@ -336,8 +336,6 @@ cc_library( copts = [ "-Dstricmp=strcasecmp", "-Dstrnicmp=strncasecmp", - "-fno-tree-dominator-opts", - "-fno-tree-fre", "-include $(execpath @glibc_version_header//:glibc_2_17)", ], includes = [ @@ -754,6 +752,9 @@ cc_binary( "-mmmx", "-include $(execpath @glibc_version_header//:glibc_2_17)", ], + cxxopts = [ + "-std=c++11", # vizdoom uses register in class variables, which is forbidden in C++17 + ], data = [ ":vizdoom_pk3", ],