-
-
Notifications
You must be signed in to change notification settings - Fork 63
Implement initial cross-compilation support #396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -77,27 +77,30 @@ def _collect_specs(ctx, attr, redist, the_url): | |
| elif _is_windows(ctx): | ||
| os = "windows" | ||
|
|
||
| arch = "x86_64" # TODO: support cross compiling | ||
| platform = "{os}-{arch}".format(os = os, arch = arch) | ||
| components = attr.components if attr.components else [k for k, v in FULL_COMPONENT_NAME.items() if v in redist] | ||
|
|
||
| for c in components: | ||
| c_full = FULL_COMPONENT_NAME[c] | ||
|
|
||
| payload = redist[c_full][platform] | ||
| payload_relative_path = payload["relative_path"] | ||
| payload_url = the_url.rsplit("/", 1)[0] + "/" + payload_relative_path | ||
| archive_name = payload_relative_path.rsplit("/", 1)[1].split("-archive.")[0] + "-archive" | ||
| desc_name = redist[c_full].get("name", c_full) | ||
|
|
||
| specs.append({ | ||
| "component_name": c, | ||
| "descriptive_name": desc_name, | ||
| "urls": [payload_url], | ||
| "sha256": payload["sha256"], | ||
| "strip_prefix": archive_name, | ||
| "version": redist[c_full]["version"], | ||
| }) | ||
| archs = attr.archs or ["x86_64"] | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The idea here is to iterate over the specified archs and append one spec per component/arch pair to the specs list. |
||
|
|
||
| for arch in archs: | ||
| platform = "{os}-{arch}".format(os = os, arch = arch) | ||
| components = attr.components if attr.components else [k for k, v in FULL_COMPONENT_NAME.items() if v in redist] | ||
|
|
||
| for c in components: | ||
| c_full = FULL_COMPONENT_NAME[c] | ||
|
|
||
| payload = redist[c_full][platform] | ||
| payload_relative_path = payload["relative_path"] | ||
| payload_url = the_url.rsplit("/", 1)[0] + "/" + payload_relative_path | ||
| archive_name = payload_relative_path.rsplit("/", 1)[1].split("-archive.")[0] + "-archive" | ||
| desc_name = redist[c_full].get("name", c_full) | ||
|
|
||
| specs.append({ | ||
| "component_name": c, | ||
| "descriptive_name": desc_name, | ||
| "urls": [payload_url], | ||
| "sha256": payload["sha256"], | ||
| "strip_prefix": archive_name, | ||
| "version": redist[c_full]["version"], | ||
| "arch": arch, | ||
| }) | ||
|
|
||
| return specs | ||
|
|
||
|
|
@@ -113,12 +116,25 @@ def _get_repo_name(ctx, spec): | |
| version = spec.get("version", None) | ||
| if version != None: | ||
| repo_name = repo_name + "_v" + version | ||
| arch = spec.get("arch", None) | ||
| if arch != None: | ||
| repo_name = repo_name + "_" + arch | ||
|
|
||
| return repo_name | ||
|
|
||
def _get_repo_mapping_key(spec, arch = None):
    """Generate a key (string) to identify a component/arch pair (e.g., "cublas__x86_64").

    Two call styles are supported so that existing callers keep working:

        get_repo_mapping_key(spec)                  # spec is a dict of cuda_component attrs
        get_repo_mapping_key(component_name, arch)  # explicit pair, e.g. ("nvcc", "aarch64")

    Args:
        spec: cuda_component attrs (a dict with "component_name" and, optionally,
            "arch"), or the component name itself when `arch` is given.
        arch: target host architecture. When omitted, it is read from `spec`;
            a missing or empty "arch" entry falls back to "x86_64".

    Returns:
        The string "<component_name>__<arch>".
    """
    if arch == None:
        # Single-argument form: `spec` is the attrs dict.
        component_name = spec["component_name"]

        # x86_64 is the documented default when no arch is specified.
        arch = spec.get("arch") or "x86_64"
    else:
        # Two-argument form: `spec` is already the component name.
        component_name = spec

    return "{}__{}".format(component_name, arch)
|
|
||
| redist_json_helper = struct( | ||
| get = _get, | ||
| get_redist_version = _get_redist_version, | ||
| collect_specs = _collect_specs, | ||
| get_repo_name = _get_repo_name, | ||
| get_repo_mapping_key = _get_repo_mapping_key, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,9 @@ def _is_linux(ctx): | |
| def _is_windows(ctx): | ||
| return ctx.os.name.lower().startswith("windows") | ||
|
|
||
| def _is_aarch64(ctx): | ||
| return ctx.os.arch == "aarch64" | ||
|
|
||
| def _get_nvcc_version(repository_ctx, nvcc_root): | ||
| result = repository_ctx.execute([nvcc_root + "/bin/nvcc", "--version"]) | ||
| if result.return_code != 0: | ||
|
|
@@ -90,10 +93,16 @@ def _detect_deliverable_cuda_toolkit(repository_ctx): | |
| # NOTE: component nvcc contains some headers that will be used. | ||
| required_components = ["cccl", "cudart", "nvcc"] | ||
| for rc in required_components: | ||
| if rc not in repository_ctx.attr.components_mapping: | ||
| fail('component "{}" is required.'.format(rc)) | ||
| for arch in repository_ctx.attr.archs: | ||
| mapping_key = redist_json_helper.get_repo_mapping_key(rc, arch) | ||
| if mapping_key not in repository_ctx.attr.components_mapping: | ||
| fail('component "{}" for {} is required.'.format(rc, arch)) | ||
|
|
||
| nvcc_repo = repository_ctx.attr.components_mapping["nvcc"] | ||
| nvcc_repo = repository_ctx.attr.components_mapping.get("nvcc", None) | ||
| if not nvcc_repo: | ||
| host_arch = "aarch64" if _is_aarch64(repository_ctx) else "x86_64" | ||
| nvcc_mapping_key = redist_json_helper.get_repo_mapping_key("nvcc", host_arch) | ||
| nvcc_repo = repository_ctx.attr.components_mapping[nvcc_mapping_key] | ||
|
|
||
| bin_ext = ".exe" if _is_windows(repository_ctx) else "" | ||
| nvcc = "{}//:nvcc/bin/nvcc{}".format(nvcc_repo, bin_ext) | ||
|
|
@@ -276,6 +285,7 @@ cuda_toolkit = repository_rule( | |
| "nvcc_version": attr.string( | ||
| doc = "nvcc version. Required for deliverable toolkit only. Fallback to version if omitted.", | ||
| ), | ||
| "archs": attr.string_list(doc = "list of host target architectures to support (e.g., x86_64)"), | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The `archs` attribute is not fully plumbed through yet. |
||
| }, | ||
| configure = True, | ||
| local = True, | ||
|
|
@@ -372,6 +382,7 @@ cuda_component = repository_rule( | |
| "If all downloads fail, the rule will fail.", | ||
| ), | ||
| "version": attr.string(doc = "A unique version number for component. Store in version.json file"), | ||
| "arch": attr.string(doc = "The supported target host architecture (e.g., x86_64, aarch64). If omitted, x86_64 is assumed."), | ||
| }, | ||
| ) | ||
|
|
||
|
|
@@ -425,7 +436,7 @@ def rules_cuda_dependencies(): | |
| ], | ||
| ) | ||
|
|
||
| def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, version = None, nvcc_version = None, register_toolchains = False): | ||
| def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, version = None, nvcc_version = None, register_toolchains = False, archs = None): | ||
| """Populate the @cuda repo. | ||
|
|
||
| Args: | ||
|
|
@@ -442,6 +453,7 @@ def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, versio | |
| components_mapping = components_mapping, | ||
| version = version, | ||
| nvcc_version = nvcc_version, | ||
| archs = archs, | ||
| ) | ||
|
|
||
| if register_toolchains: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -144,7 +144,7 @@ def _cuda_library_impl(ctx): | |
| ), | ||
| ] | ||
|
|
||
| cuda_library = rule( | ||
| _cuda_library = rule( | ||
| doc = """This rule compiles and creates static library for CUDA kernel code. The resulting targets can then be consumed by | ||
| [C/C++ Rules](https://bazel.build/reference/be/c-cpp#rules).""", | ||
| implementation = _cuda_library_impl, | ||
|
|
@@ -179,3 +179,20 @@ cuda_library = rule( | |
| toolchains = use_cpp_toolchain() + use_cuda_toolchain(), | ||
| provides = [DefaultInfo, OutputGroupInfo, CcInfo, CudaInfo], | ||
| ) | ||
|
|
||
def cuda_library(name, **kwargs):
    """Macro around the `_cuda_library` rule that passes the target triple to the host compiler.

    Args:
        name: name of the target.
        **kwargs: forwarded to `_cuda_library`. `copts`, if present, is extended
            with the per-CPU `-Xcompiler --target=...` flags before forwarding.
    """

    # Pop (rather than get) `copts` so it is not passed a second time through
    # **kwargs below, which would be a duplicate keyword argument.
    user_copts = kwargs.pop("copts", [])

    # NOTE: this information is already available in the CC toolchain. It is
    # passed to the host compiler here so that it looks in the right places in
    # the sysroot, but this is probably not the best solution.
    # NOTE(review): the select keys only constrain the CPU, so the linux-gnu
    # triples are also applied on non-Linux platforms with these CPUs.
    copts = user_copts + select({
        "@platforms//cpu:x86_64": [
            "-Xcompiler",
            "--target=x86_64-unknown-linux-gnu",
        ],
        "@platforms//cpu:aarch64": [
            "-Xcompiler",
            "--target=aarch64-unknown-linux-gnu",
        ],
        # Without a default branch, analysis fails on every other CPU.
        "//conditions:default": [],
    })
    _cuda_library(
        name = name,
        copts = copts,
        **kwargs
    )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| platform( | ||
| name = "linux_x86_64", | ||
| constraint_values = [ | ||
| "@platforms//os:linux", | ||
| "@platforms//cpu:x86_64", | ||
| ], | ||
| ) | ||
|
|
||
| platform( | ||
| name = "linux_aarch64", | ||
| constraint_values = [ | ||
| "@platforms//os:linux", | ||
| "@platforms//cpu:aarch64", | ||
| ], | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,21 +4,98 @@ module( | |
| compatibility_level = 1, | ||
| ) | ||
|
|
||
| bazel_dep(name = "platforms", version = "0.0.11") | ||
|
|
||
| bazel_dep(name = "rules_cuda", version = "0.2.3") | ||
| local_path_override( | ||
| module_name = "rules_cuda", | ||
| path = "..", | ||
| ) | ||
|
|
||
| cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain") | ||
| cuda.redist_json( | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I extended the examples to test this. We will probably need something in
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, we need a test: at the very least, one that cross-compiles from x86_64 to aarch64. |
||
| name = "rules_cuda_redist_json", | ||
| components = [ | ||
| "cccl", | ||
| "cudart", | ||
| "nvcc", | ||
| "cublas", | ||
| ], | ||
| archs = [ | ||
| "x86_64", | ||
| "aarch64", | ||
| ], | ||
| version = "12.8.1", | ||
| ) | ||
| cuda.toolkit( | ||
| name = "cuda", | ||
| toolkit_path = "", | ||
| ) | ||
| use_repo(cuda, "cuda") | ||
|
|
||
| ################################# | ||
| # Dependencies for nccl example # | ||
| ################################# | ||
| # See WORKSPACE.bzlmod for the remaining parts | ||
| bazel_dep(name = "bazel_skylib", version = "1.4.2") | ||
| bazel_dep(name = "bazel_skylib", version = "1.7.1") | ||
|
|
||
| ################################# | ||
| # LLVM toolchain | ||
| ################################# | ||
|
|
||
| bazel_dep(name = "toolchains_llvm", version = "1.5.0") | ||
|
|
||
| # Configure and register the toolchain. | ||
| llvm = use_extension("@toolchains_llvm//toolchain/extensions:llvm.bzl", "llvm") | ||
| llvm.toolchain( | ||
| name = "llvm_toolchain", | ||
| cxx_standard = {"": "c++20"}, | ||
| llvm_version = "17.0.6", | ||
| stdlib = { | ||
| "": "dynamic-stdc++", | ||
| }, | ||
| ) | ||
|
|
||
| http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") | ||
|
|
||
| http_archive( | ||
| name = "org_chromium_sysroot_linux_x86_64", | ||
| build_file_content = """ | ||
| filegroup( | ||
| name = "sysroot", | ||
| srcs = glob(["*/**"], exclude=["lib/systemd/*/**"]), | ||
| visibility = ["//visibility:public"], | ||
| ) | ||
| """, | ||
| sha256 = "36a164623d03f525e3dfb783a5e9b8a00e98e1ddd2b5cff4e449bd016dd27e50", | ||
| urls = ["https://commondatastorage.googleapis.com/chrome-linux-sysroot/36a164623d03f525e3dfb783a5e9b8a00e98e1ddd2b5cff4e449bd016dd27e50"], | ||
| type = "tar.xz", | ||
| ) | ||
|
|
||
| llvm.sysroot( | ||
| name = "llvm_toolchain", | ||
| label = "@org_chromium_sysroot_linux_x86_64//:sysroot", | ||
| targets = ["linux-x86_64"], | ||
| ) | ||
|
|
||
| http_archive( | ||
| name = "org_chromium_sysroot_linux_aarch64", | ||
| build_file_content = """ | ||
| filegroup( | ||
| name = "sysroot", | ||
| srcs = glob(["*/**"], exclude=["lib/systemd/*/**"]), | ||
| visibility = ["//visibility:public"], | ||
| ) | ||
| """, | ||
| sha256 = "2f915d821eec27515c0c6d21b69898e23762908d8d7ccc1aa2a8f5f25e8b7e18", | ||
| urls = ["https://commondatastorage.googleapis.com/chrome-linux-sysroot/2f915d821eec27515c0c6d21b69898e23762908d8d7ccc1aa2a8f5f25e8b7e18"], | ||
| type = "tar.xz", | ||
| ) | ||
|
|
||
| llvm.sysroot( | ||
| name = "llvm_toolchain", | ||
| label = "@org_chromium_sysroot_linux_aarch64//:sysroot", | ||
| targets = ["linux-aarch64"], | ||
| ) | ||
| use_repo(llvm, "llvm_toolchain") | ||
|
|
||
| register_toolchains("@llvm_toolchain//:all") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,5 +2,8 @@ | |
| cc_binary( | ||
| name = "main", | ||
| srcs = ["cublas.cpp"], | ||
| deps = ["@cuda//:cublas"], | ||
| deps = [ | ||
| "@cuda//:cublas", | ||
| "@cuda//:cuda_runtime", | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This doesn't compile without depending on `@cuda//:cuda_runtime`. |
||
| ], | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| linux_x86_64 = [ | ||
| "@platforms//os:linux", | ||
| "@platforms//cpu:x86_64", | ||
| ] | ||
|
|
||
| linux_aarch64 = [ | ||
| "@platforms//os:linux", | ||
| "@platforms//cpu:aarch64", | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For attr.components, the behavior is now "If not specified, all components will be used." I think archs should behave similarly. This would also simplify the usage of
`redist_json` quite a bit. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The issue with this is that not all releases provide
`aarch64` packages. If using "all" means aarch64 and x86_64, some releases (such as 13.0.0) won't work. We could check the JSON file for available archs, so that "all" would include aarch64 only if it's available. However, then the behavior would change depending on the selected release, which could be quite surprising for users. A further complication is that there is also
`sbsa`, which so far I'm ignoring.