Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cuda/extensions.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ cuda_redist_json_tag = tag_class(attrs = {
doc = "Generate a URL by using the specified version." +
"This URL will be tried after all URLs specified in the `urls` attribute.",
),
"archs": attr.string_list(doc = "Target architectures to support. If not specified, only x86_64 will be supported."),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`attr.components` currently documents "If not specified, all components will be used." I think `archs` should behave similarly. This would also simplify the usage of `redist_json` quite a bit.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue with this is that not all releases provide aarch64 packages. If "all" means both aarch64 and x86_64, some releases (such as 13.0.0) won't work. We could check the JSON file for available archs, so that "all" would include aarch64 only if it's available. However, the behavior would then change depending on the selected release, which could be quite surprising for users.

A further complication is that there is also sbsa, which so far I'm ignoring.

})

cuda_toolkit_tag = tag_class(attrs = {
Expand Down Expand Up @@ -100,7 +101,8 @@ def _redist_json_impl(module_ctx, attr):
mapping = {}
for spec in component_specs:
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)
mapping[spec["component_name"]] = "@" + repo_name
component_name = spec["component_name"]
mapping[redist_json_helper.get_repo_mapping_key(component_name, spec["arch"])] = "@" + repo_name

attr = {key: value for key, value in spec.items()}
attr["name"] = repo_name
Expand Down
58 changes: 37 additions & 21 deletions cuda/private/redist_json_helper.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -77,27 +77,30 @@ def _collect_specs(ctx, attr, redist, the_url):
elif _is_windows(ctx):
os = "windows"

arch = "x86_64" # TODO: support cross compiling
platform = "{os}-{arch}".format(os = os, arch = arch)
components = attr.components if attr.components else [k for k, v in FULL_COMPONENT_NAME.items() if v in redist]

for c in components:
c_full = FULL_COMPONENT_NAME[c]

payload = redist[c_full][platform]
payload_relative_path = payload["relative_path"]
payload_url = the_url.rsplit("/", 1)[0] + "/" + payload_relative_path
archive_name = payload_relative_path.rsplit("/", 1)[1].split("-archive.")[0] + "-archive"
desc_name = redist[c_full].get("name", c_full)

specs.append({
"component_name": c,
"descriptive_name": desc_name,
"urls": [payload_url],
"sha256": payload["sha256"],
"strip_prefix": archive_name,
"version": redist[c_full]["version"],
})
archs = attr.archs or ["x86_64"]
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea here is to iterate over the specified archs and add everything to the specs list.


for arch in archs:
platform = "{os}-{arch}".format(os = os, arch = arch)
components = attr.components if attr.components else [k for k, v in FULL_COMPONENT_NAME.items() if v in redist]

for c in components:
c_full = FULL_COMPONENT_NAME[c]

payload = redist[c_full][platform]
payload_relative_path = payload["relative_path"]
payload_url = the_url.rsplit("/", 1)[0] + "/" + payload_relative_path
archive_name = payload_relative_path.rsplit("/", 1)[1].split("-archive.")[0] + "-archive"
desc_name = redist[c_full].get("name", c_full)

specs.append({
"component_name": c,
"descriptive_name": desc_name,
"urls": [payload_url],
"sha256": payload["sha256"],
"strip_prefix": archive_name,
"version": redist[c_full]["version"],
"arch": arch,
})

return specs

Expand All @@ -113,12 +116,25 @@ def _get_repo_name(ctx, spec):
version = spec.get("version", None)
if version != None:
repo_name = repo_name + "_v" + version
arch = spec.get("arch", None)
if arch != None:
repo_name = repo_name + "_" + arch

return repo_name

def _get_repo_mapping_key(spec, arch = None):
    """Generate a key (string) to identify a component/arch pair (e.g., "cublas__x86_64").

    The helper is invoked as `get_repo_mapping_key(component_name, arch)` by
    its callers, whereas the previous definition only accepted a single spec
    dict and therefore failed with an arity error. Both call shapes are
    accepted here.

    Args:
        spec: either a cuda_component attrs dict providing the keys
            "component_name" and "arch", or, when `arch` is given, the bare
            component name (e.g., "cublas").
        arch: target architecture (e.g., "x86_64"). When omitted, it is read
            from `spec["arch"]`.

    Returns:
        The string "<component_name>__<arch>".
    """
    if arch == None:
        # Single-argument form: `spec` is the attrs dict.
        component_name = spec["component_name"]
        arch = spec["arch"]
    else:
        # Two-argument form: `spec` already is the component name.
        component_name = spec

    return "{}__{}".format(component_name, arch)

# Export table for this file: exposes the private `_`-prefixed helpers under
# un-prefixed field names so other .bzl files can call them as
# `redist_json_helper.<name>(...)`.
redist_json_helper = struct(
get = _get,
get_redist_version = _get_redist_version,
collect_specs = _collect_specs,
get_repo_name = _get_repo_name,
get_repo_mapping_key = _get_repo_mapping_key,
)
20 changes: 16 additions & 4 deletions cuda/private/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ def _is_linux(ctx):
def _is_windows(ctx):
    """Return True when the lower-cased `ctx.os.name` begins with "windows"."""
    os_name = ctx.os.name.lower()
    return os_name.startswith("windows")

def _is_aarch64(ctx):
    """Return True when `ctx.os.arch` is exactly "aarch64"."""
    reported_arch = ctx.os.arch
    return reported_arch == "aarch64"

def _get_nvcc_version(repository_ctx, nvcc_root):
result = repository_ctx.execute([nvcc_root + "/bin/nvcc", "--version"])
if result.return_code != 0:
Expand Down Expand Up @@ -90,10 +93,16 @@ def _detect_deliverable_cuda_toolkit(repository_ctx):
# NOTE: component nvcc contains some headers that will be used.
required_components = ["cccl", "cudart", "nvcc"]
for rc in required_components:
if rc not in repository_ctx.attr.components_mapping:
fail('component "{}" is required.'.format(rc))
for arch in repository_ctx.attr.archs:
mapping_key = redist_json_helper.get_repo_mapping_key(rc, arch)
if mapping_key not in repository_ctx.attr.components_mapping:
fail('component "{}" for {} is required.'.format(rc, arch))

nvcc_repo = repository_ctx.attr.components_mapping["nvcc"]
nvcc_repo = repository_ctx.attr.components_mapping.get("nvcc", None)
if not nvcc_repo:
host_arch = "aarch64" if _is_aarch64(repository_ctx) else "x86_64"
nvcc_mapping_key = redist_json_helper.get_repo_mapping_key("nvcc", host_arch)
nvcc_repo = repository_ctx.attr.components_mapping[nvcc_mapping_key]

bin_ext = ".exe" if _is_windows(repository_ctx) else ""
nvcc = "{}//:nvcc/bin/nvcc{}".format(nvcc_repo, bin_ext)
Expand Down Expand Up @@ -276,6 +285,7 @@ cuda_toolkit = repository_rule(
"nvcc_version": attr.string(
doc = "nvcc version. Required for deliverable toolkit only. Fallback to version if omitted.",
),
"archs": attr.string_list(doc = "list of host target architectures to support (e.g., x86_64)"),
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not fully plumbed yet.

},
configure = True,
local = True,
Expand Down Expand Up @@ -372,6 +382,7 @@ cuda_component = repository_rule(
"If all downloads fail, the rule will fail.",
),
"version": attr.string(doc = "A unique version number for component. Store in version.json file"),
"arch": attr.string(doc = "The supported target host architecture (e.g., x86_64, aarch64). If omitted, x86_64 is assumed."),
},
)

Expand Down Expand Up @@ -425,7 +436,7 @@ def rules_cuda_dependencies():
],
)

def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, version = None, nvcc_version = None, register_toolchains = False):
def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, version = None, nvcc_version = None, register_toolchains = False, archs = None):
"""Populate the @cuda repo.

Args:
Expand All @@ -442,6 +453,7 @@ def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, versio
components_mapping = components_mapping,
version = version,
nvcc_version = nvcc_version,
archs = archs,
)

if register_toolchains:
Expand Down
19 changes: 18 additions & 1 deletion cuda/private/rules/cuda_library.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _cuda_library_impl(ctx):
),
]

cuda_library = rule(
_cuda_library = rule(
doc = """This rule compiles and creates static library for CUDA kernel code. The resulting targets can then be consumed by
[C/C++ Rules](https://bazel.build/reference/be/c-cpp#rules).""",
implementation = _cuda_library_impl,
Expand Down Expand Up @@ -179,3 +179,20 @@ cuda_library = rule(
toolchains = use_cpp_toolchain() + use_cuda_toolchain(),
provides = [DefaultInfo, OutputGroupInfo, CcInfo, CudaInfo],
)

def cuda_library(name, **kwargs):
    """Macro around `_cuda_library` that appends a per-CPU host-compiler target flag.

    Args:
        name: target name, forwarded to `_cuda_library`.
        **kwargs: remaining attributes, forwarded to `_cuda_library`. A
            caller-supplied `copts` is preserved; the target flags are
            appended after it.
    """

    # Pop (rather than get) `copts`: it is passed explicitly below, and leaving
    # it in `kwargs` would hand the same keyword to `_cuda_library` twice.
    user_copts = kwargs.pop("copts", [])

    # NOTE (from the change author): this information is already available in
    # the CC toolchain. It is passed to the host compiler so that it looks in
    # the right places in the sysroot, but this is probably not the best
    # solution.
    target_copts = select({
        "@platforms//cpu:x86_64": [
            "-Xcompiler",
            "--target=x86_64-unknown-linux-gnu",
        ],
        "@platforms//cpu:aarch64": [
            "-Xcompiler",
            "--target=aarch64-unknown-linux-gnu",
        ],
        # Without a default branch, every other target CPU would fail analysis
        # because no condition matches; add no extra flags there instead.
        "//conditions:default": [],
    })

    _cuda_library(
        name = name,
        copts = user_copts + target_copts,
        **kwargs
    )
32 changes: 26 additions & 6 deletions cuda/private/template_helper.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ def _to_forward_slash(s):
def _is_linux(ctx):
    """Return True when `ctx.os.name` begins with "linux"."""
    os_name = ctx.os.name
    return os_name.startswith("linux")

def _is_windows(ctx):
    """Return True when the lower-cased `ctx.os.name` begins with "windows"."""
    normalized_name = ctx.os.name.lower()
    return normalized_name.startswith("windows")
def _is_x86_64(ctx):
    """Return True when the host reported by `ctx.os.arch` is a 64-bit x86 machine.

    Args:
        ctx: a context object exposing `os.arch`.

    Returns:
        True if `ctx.os.arch` is "amd64" or "x86_64", False otherwise.
    """

    # The same CPU family is reported under two spellings depending on the
    # host OS ("amd64" vs. "x86_64"); matching only "amd64" misses the latter.
    return ctx.os.arch in ("amd64", "x86_64")

def _is_aarch64(ctx):
    """Return True when `ctx.os.arch` is exactly "aarch64"."""
    host_arch = ctx.os.arch
    return host_arch == "aarch64"

def _expand_template(repository_ctx, tpl_label, substitutions):
template_content = "# Generated from fragment " + str(tpl_label) + "\n"
Expand Down Expand Up @@ -80,11 +83,28 @@ def _generate_build_impl(repository_ctx, libpath, components, is_cuda_repo, is_d
template_content.append(repository_ctx.read(frag))

if is_cuda_repo and is_deliverable: # generate `@cuda//BUILD` for CTK with deliverables
for comp in components:
component_names = set([name.split("__")[0] for name in components])
archs = set([name.split("__")[1] for name in components])

for comp in component_names:
repo_x86_64 = components[comp + "__x86_64"]
repo_aarch64 = components[comp + "__aarch64"]
for target in REGISTRY[comp]:
repo = components[comp]
line = 'alias(name = "{target}", actual = "{repo}//:{target}")'.format(target = target, repo = repo)
template_content.append(line)
if comp == "nvcc":
repo_host_architecture = repo_aarch64 if _is_aarch64(repository_ctx) else repo_x86_64
alias_line = 'alias(name = "{target}", actual = "{repo}//:{target}")'.format(target = target, repo = repo_host_architecture)
else:
alias_line = """
alias(
name = "{target}",
actual = select({{
"@platforms//cpu:x86_64": "{repo_x86_64}//:{target}",
"@platforms//cpu:aarch64": "{repo_aarch64}//:{target}",
}})
)
""".format(target = target, repo_x86_64 = repo_x86_64, repo_aarch64 = repo_aarch64)

template_content.append(alias_line)

# add an empty line to separate aliased targets from different components
template_content.append("")
Expand Down
13 changes: 13 additions & 0 deletions examples/.bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,23 @@ build --flag_alias=cuda_copts=@rules_cuda//cuda:copts
build --flag_alias=cuda_runtime=@rules_cuda//cuda:runtime

build --enable_cuda=True
build --cuda_archs=sm_86,sm_89,sm_120

build:aarch64 --cuda_archs=sm_87

# Use --config=clang to build with clang instead of gcc and nvcc.
build:clang --repo_env=CC=clang
build:clang --@rules_cuda//cuda:compiler=clang

# https://github.com/bazel-contrib/rules_cuda/issues/1
# build --ui_event_filters=-INFO

common --enable_bzlmod
common --noenable_workspace

build --platforms=//:linux_x86_64

# add a config for aarch64 cross compilation
build:aarch64 --platforms=//:linux_aarch64
query:aarch64 --platforms=//:linux_aarch64

15 changes: 15 additions & 0 deletions examples/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Platform definition: Linux OS constraint combined with the x86_64 CPU constraint.
platform(
name = "linux_x86_64",
constraint_values = [
"@platforms//os:linux",
"@platforms//cpu:x86_64",
],
)

# Platform definition: Linux OS constraint combined with the aarch64 CPU constraint.
platform(
name = "linux_aarch64",
constraint_values = [
"@platforms//os:linux",
"@platforms//cpu:aarch64",
],
)
81 changes: 79 additions & 2 deletions examples/MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,98 @@ module(
compatibility_level = 1,
)

bazel_dep(name = "platforms", version = "0.0.11")

bazel_dep(name = "rules_cuda", version = "0.2.3")
local_path_override(
module_name = "rules_cuda",
path = "..",
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.redist_json(
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I extended the examples to test this. We will probably need something in tests/integration?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we need a test — at least one that cross-compiles from x86_64 to aarch64.

name = "rules_cuda_redist_json",
components = [
"cccl",
"cudart",
"nvcc",
"cublas",
],
archs = [
"x86_64",
"aarch64",
],
version = "12.8.1",
)
cuda.toolkit(
name = "cuda",
toolkit_path = "",
)
use_repo(cuda, "cuda")

#################################
# Dependencies for nccl example #
#################################
# See WORKSPACE.bzlmod for the remaining parts
bazel_dep(name = "bazel_skylib", version = "1.4.2")
bazel_dep(name = "bazel_skylib", version = "1.7.1")

#################################
# LLVM toolchain
#################################

bazel_dep(name = "toolchains_llvm", version = "1.5.0")

# Configure and register the toolchain.
llvm = use_extension("@toolchains_llvm//toolchain/extensions:llvm.bzl", "llvm")
llvm.toolchain(
name = "llvm_toolchain",
cxx_standard = {"": "c++20"},
llvm_version = "17.0.6",
stdlib = {
"": "dynamic-stdc++",
},
)

http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
name = "org_chromium_sysroot_linux_x86_64",
build_file_content = """
filegroup(
name = "sysroot",
srcs = glob(["*/**"], exclude=["lib/systemd/*/**"]),
visibility = ["//visibility:public"],
)
""",
sha256 = "36a164623d03f525e3dfb783a5e9b8a00e98e1ddd2b5cff4e449bd016dd27e50",
urls = ["https://commondatastorage.googleapis.com/chrome-linux-sysroot/36a164623d03f525e3dfb783a5e9b8a00e98e1ddd2b5cff4e449bd016dd27e50"],
type = "tar.xz",
)

llvm.sysroot(
name = "llvm_toolchain",
label = "@org_chromium_sysroot_linux_x86_64//:sysroot",
targets = ["linux-x86_64"],
)

http_archive(
name = "org_chromium_sysroot_linux_aarch64",
build_file_content = """
filegroup(
name = "sysroot",
srcs = glob(["*/**"], exclude=["lib/systemd/*/**"]),
visibility = ["//visibility:public"],
)
""",
sha256 = "2f915d821eec27515c0c6d21b69898e23762908d8d7ccc1aa2a8f5f25e8b7e18",
urls = ["https://commondatastorage.googleapis.com/chrome-linux-sysroot/2f915d821eec27515c0c6d21b69898e23762908d8d7ccc1aa2a8f5f25e8b7e18"],
type = "tar.xz",
)

llvm.sysroot(
name = "llvm_toolchain",
label = "@org_chromium_sysroot_linux_aarch64//:sysroot",
targets = ["linux-aarch64"],
)
use_repo(llvm, "llvm_toolchain")

register_toolchains("@llvm_toolchain//:all")
5 changes: 4 additions & 1 deletion examples/cublas/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,8 @@
cc_binary(
name = "main",
srcs = ["cublas.cpp"],
deps = ["@cuda//:cublas"],
deps = [
"@cuda//:cublas",
"@cuda//:cuda_runtime",
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't compile without depending on cuda_runtime. This is a separate issue, though; it has nothing to do with cross compilation.

],
)
9 changes: 9 additions & 0 deletions examples/platforms.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Constraint-value labels for Linux on x86_64.
linux_x86_64 = [
"@platforms//os:linux",
"@platforms//cpu:x86_64",
]

# Constraint-value labels for Linux on aarch64.
linux_aarch64 = [
"@platforms//os:linux",
"@platforms//cpu:aarch64",
]