diff --git a/cuda/extensions.bzl b/cuda/extensions.bzl index 4eb5211d..40c42905 100644 --- a/cuda/extensions.bzl +++ b/cuda/extensions.bzl @@ -57,6 +57,7 @@ cuda_redist_json_tag = tag_class(attrs = { doc = "Generate a URL by using the specified version." + "This URL will be tried after all URLs specified in the `urls` attribute.", ), + "archs": attr.string_list(doc = "Target architectures to support. If not specified, only x86_64 will be supported."), }) cuda_toolkit_tag = tag_class(attrs = { @@ -100,7 +101,8 @@ def _redist_json_impl(module_ctx, attr): mapping = {} for spec in component_specs: repo_name = redist_json_helper.get_repo_name(module_ctx, spec) - mapping[spec["component_name"]] = "@" + repo_name + component_name = spec["component_name"] + mapping[redist_json_helper.get_repo_mapping_key(component_name, spec["arch"])] = "@" + repo_name attr = {key: value for key, value in spec.items()} attr["name"] = repo_name diff --git a/cuda/private/redist_json_helper.bzl b/cuda/private/redist_json_helper.bzl index 29544bb8..119e3acc 100644 --- a/cuda/private/redist_json_helper.bzl +++ b/cuda/private/redist_json_helper.bzl @@ -77,27 +77,30 @@ def _collect_specs(ctx, attr, redist, the_url): elif _is_windows(ctx): os = "windows" - arch = "x86_64" # TODO: support cross compiling - platform = "{os}-{arch}".format(os = os, arch = arch) - components = attr.components if attr.components else [k for k, v in FULL_COMPONENT_NAME.items() if v in redist] - - for c in components: - c_full = FULL_COMPONENT_NAME[c] - - payload = redist[c_full][platform] - payload_relative_path = payload["relative_path"] - payload_url = the_url.rsplit("/", 1)[0] + "/" + payload_relative_path - archive_name = payload_relative_path.rsplit("/", 1)[1].split("-archive.")[0] + "-archive" - desc_name = redist[c_full].get("name", c_full) - - specs.append({ - "component_name": c, - "descriptive_name": desc_name, - "urls": [payload_url], - "sha256": payload["sha256"], - "strip_prefix": archive_name, - "version": redist[c_full]["version"], - }) + archs = attr.archs or ["x86_64"] + + for arch in archs: + platform = "{os}-{arch}".format(os = os, arch = arch) + components = attr.components if attr.components else [k for k, v in FULL_COMPONENT_NAME.items() if v in redist] + + for c in components: + c_full = FULL_COMPONENT_NAME[c] + + payload = redist[c_full][platform] + payload_relative_path = payload["relative_path"] + payload_url = the_url.rsplit("/", 1)[0] + "/" + payload_relative_path + archive_name = payload_relative_path.rsplit("/", 1)[1].split("-archive.")[0] + "-archive" + desc_name = redist[c_full].get("name", c_full) + + specs.append({ + "component_name": c, + "descriptive_name": desc_name, + "urls": [payload_url], + "sha256": payload["sha256"], + "strip_prefix": archive_name, + "version": redist[c_full]["version"], + "arch": arch, + }) return specs @@ -113,12 +116,25 @@ def _get_repo_name(ctx, spec): version = spec.get("version", None) if version != None: repo_name = repo_name + "_v" + version + arch = spec.get("arch", None) + if arch != None: + repo_name = repo_name + "_" + arch return repo_name +def _get_repo_mapping_key(spec): + """Generate a key (string) to identify a component/arch pair (e.g., "cublas__x86_64"). + + Args: + spec: cuda_component attrs + """ + + return "{}__{}".format(spec["component_name"], spec["arch"]) + redist_json_helper = struct( get = _get, get_redist_version = _get_redist_version, collect_specs = _collect_specs, get_repo_name = _get_repo_name, + get_repo_mapping_key = _get_repo_mapping_key, ) diff --git a/cuda/private/repositories.bzl b/cuda/private/repositories.bzl index 935291cc..846472e0 100644 --- a/cuda/private/repositories.bzl +++ b/cuda/private/repositories.bzl @@ -13,6 +13,9 @@ def _is_linux(ctx): def _is_windows(ctx): return ctx.os.name.lower().startswith("windows") +def _is_aarch64(ctx): + return ctx.os.arch == "aarch64" + def _get_nvcc_version(repository_ctx, nvcc_root): result = repository_ctx.execute([nvcc_root + "/bin/nvcc", "--version"]) if result.return_code != 0: @@ -90,10 +93,16 @@ def _detect_deliverable_cuda_toolkit(repository_ctx): # NOTE: component nvcc contains some headers that will be used. required_components = ["cccl", "cudart", "nvcc"] for rc in required_components: - if rc not in repository_ctx.attr.components_mapping: - fail('component "{}" is required.'.format(rc)) + for arch in repository_ctx.attr.archs: + mapping_key = redist_json_helper.get_repo_mapping_key(rc, arch) + if mapping_key not in repository_ctx.attr.components_mapping: + fail('component "{}" for {} is required.'.format(rc, arch)) - nvcc_repo = repository_ctx.attr.components_mapping["nvcc"] + nvcc_repo = repository_ctx.attr.components_mapping.get("nvcc", None) + if not nvcc_repo: + host_arch = "aarch64" if _is_aarch64(repository_ctx) else "x86_64" + nvcc_mapping_key = redist_json_helper.get_repo_mapping_key("nvcc", host_arch) + nvcc_repo = repository_ctx.attr.components_mapping[nvcc_mapping_key] bin_ext = ".exe" if _is_windows(repository_ctx) else "" nvcc = "{}//:nvcc/bin/nvcc{}".format(nvcc_repo, bin_ext) @@ -276,6 +285,7 @@ cuda_toolkit = repository_rule( "nvcc_version": attr.string( doc = "nvcc version. Required for deliverable toolkit only. Fallback to version if omitted.", ), + "archs": attr.string_list(doc = "list of host target architectures to support (e.g., x86_64)"), }, configure = True, local = True, @@ -372,6 +382,7 @@ cuda_component = repository_rule( "If all downloads fail, the rule will fail.", ), "version": attr.string(doc = "A unique version number for component. Store in version.json file"), + "arch": attr.string(doc = "The supported target host architecture (e.g., x86_64, aarch64). If omitted, x86_64 is assumed."), }, ) @@ -425,7 +436,7 @@ def rules_cuda_dependencies(): ], ) -def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, version = None, nvcc_version = None, register_toolchains = False): +def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, version = None, nvcc_version = None, register_toolchains = False, archs = None): """Populate the @cuda repo. Args: @@ -442,6 +453,7 @@ def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, versio components_mapping = components_mapping, version = version, nvcc_version = nvcc_version, + archs = archs, ) if register_toolchains: diff --git a/cuda/private/rules/cuda_library.bzl b/cuda/private/rules/cuda_library.bzl index b0f4044d..27d61c9d 100644 --- a/cuda/private/rules/cuda_library.bzl +++ b/cuda/private/rules/cuda_library.bzl @@ -144,7 +144,7 @@ def _cuda_library_impl(ctx): ), ] -cuda_library = rule( +_cuda_library = rule( doc = """This rule compiles and creates static library for CUDA kernel code. The resulting targets can then be consumed by [C/C++ Rules](https://bazel.build/reference/be/c-cpp#rules).""", implementation = _cuda_library_impl, @@ -179,3 +179,20 @@ cuda_library = rule( toolchains = use_cpp_toolchain() + use_cuda_toolchain(), provides = [DefaultInfo, OutputGroupInfo, CcInfo, CudaInfo], ) + +def cuda_library(name, **kwargs): + copts = kwargs.get("copts", []) + select({ + "@platforms//cpu:x86_64": [ + "-Xcompiler", + "--target=x86_64-unknown-linux-gnu", + ], + "@platforms//cpu:aarch64": [ + "-Xcompiler", + "--target=aarch64-unknown-linux-gnu", + ], + }) + _cuda_library( + name = name, + copts = copts, + **kwargs + ) diff --git a/cuda/private/template_helper.bzl b/cuda/private/template_helper.bzl index ee024f24..392e0ae4 100644 --- a/cuda/private/template_helper.bzl +++ b/cuda/private/template_helper.bzl @@ -7,8 +7,11 @@ def _to_forward_slash(s): def _is_linux(ctx): return ctx.os.name.startswith("linux") -def _is_windows(ctx): - return ctx.os.name.lower().startswith("windows") +def _is_x86_64(ctx): + return ctx.os.arch == "amd64" + +def _is_aarch64(ctx): + return ctx.os.arch == "aarch64" def _expand_template(repository_ctx, tpl_label, substitutions): template_content = "# Generated from fragment " + str(tpl_label) + "\n" @@ -80,11 +83,28 @@ def _generate_build_impl(repository_ctx, libpath, components, is_cuda_repo, is_d template_content.append(repository_ctx.read(frag)) if is_cuda_repo and is_deliverable: # generate `@cuda//BUILD` for CTK with deliverables - for comp in components: + component_names = set([name.split("__")[0] for name in components]) + archs = set([name.split("__")[1] for name in components]) + + for comp in component_names: + repo_x86_64 = components[comp + "__x86_64"] + repo_aarch64 = components[comp + "__aarch64"] for target in REGISTRY[comp]: - repo = components[comp] - line = 'alias(name = "{target}", actual = "{repo}//:{target}")'.format(target = target, repo = repo) - template_content.append(line) + if comp == "nvcc": + repo_host_architecture = repo_aarch64 if _is_aarch64(repository_ctx) else repo_x86_64 + alias_line = 'alias(name = "{target}", actual = "{repo}//:{target}")'.format(target = target, repo = repo_host_architecture) + else: + alias_line = """ +alias( + name = "{target}", + actual = select({{ + "@platforms//cpu:x86_64": "{repo_x86_64}//:{target}", + "@platforms//cpu:aarch64": "{repo_aarch64}//:{target}", + }}) +) +""".format(target = target, repo_x86_64 = repo_x86_64, repo_aarch64 = repo_aarch64) + + template_content.append(alias_line) # add an empty line to separate aliased targets from different components template_content.append("") diff --git a/examples/.bazelrc b/examples/.bazelrc index 7ebc8e4a..d775a642 100644 --- a/examples/.bazelrc +++ b/examples/.bazelrc @@ -8,6 +8,9 @@ build --flag_alias=cuda_copts=@rules_cuda//cuda:copts build --flag_alias=cuda_runtime=@rules_cuda//cuda:runtime build --enable_cuda=True +build --cuda_archs=sm_86,sm_89,sm_120 + +build:aarch64 --cuda_archs=sm_87 # Use --config=clang to build with clang instead of gcc and nvcc. build:clang --repo_env=CC=clang @@ -15,3 +18,13 @@ build:clang --@rules_cuda//cuda:compiler=clang # https://github.com/bazel-contrib/rules_cuda/issues/1 # build --ui_event_filters=-INFO + +common --enable_bzlmod +common --noenable_workspace + +build --platforms=//:linux_x86_64 + +# add a config for aarch64 cross compilation +build:aarch64 --platforms=//:linux_aarch64 +query:aarch64 --platforms=//:linux_aarch64 + diff --git a/examples/BUILD.bazel b/examples/BUILD.bazel new file mode 100644 index 00000000..47ea8334 --- /dev/null +++ b/examples/BUILD.bazel @@ -0,0 +1,15 @@ +platform( + name = "linux_x86_64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], +) + +platform( + name = "linux_aarch64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:aarch64", + ], +) diff --git a/examples/MODULE.bazel b/examples/MODULE.bazel index 2a32ebb8..b8dd6ce4 100644 --- a/examples/MODULE.bazel +++ b/examples/MODULE.bazel @@ -4,6 +4,8 @@ module( compatibility_level = 1, ) +bazel_dep(name = "platforms", version = "0.0.11") + bazel_dep(name = "rules_cuda", version = "0.2.3") local_path_override( module_name = "rules_cuda", @@ -11,9 +13,22 @@ local_path_override( ) cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain") +cuda.redist_json( + name = "rules_cuda_redist_json", + components = [ + "cccl", + "cudart", + "nvcc", + "cublas", + ], + archs = [ + "x86_64", + "aarch64", + ], + version = "12.8.1", +) cuda.toolkit( name = "cuda", - toolkit_path = "", ) use_repo(cuda, "cuda") @@ -21,4 +36,66 @@ use_repo(cuda, "cuda") # Dependencies for nccl example # ################################# # See WORKSPACE.bzlmod for the remaining parts -bazel_dep(name = "bazel_skylib", version = "1.4.2") +bazel_dep(name = "bazel_skylib", version = "1.7.1") + +################################# +# LLVM toolchain +################################# + +bazel_dep(name = "toolchains_llvm", version = "1.5.0") + +# Configure and register the toolchain. +llvm = use_extension("@toolchains_llvm//toolchain/extensions:llvm.bzl", "llvm") +llvm.toolchain( + name = "llvm_toolchain", + cxx_standard = {"": "c++20"}, + llvm_version = "17.0.6", + stdlib = { + "": "dynamic-stdc++", + }, +) + +http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "org_chromium_sysroot_linux_x86_64", + build_file_content = """ +filegroup( + name = "sysroot", + srcs = glob(["*/**"], exclude=["lib/systemd/*/**"]), + visibility = ["//visibility:public"], +) +""", + sha256 = "36a164623d03f525e3dfb783a5e9b8a00e98e1ddd2b5cff4e449bd016dd27e50", + urls = ["https://commondatastorage.googleapis.com/chrome-linux-sysroot/36a164623d03f525e3dfb783a5e9b8a00e98e1ddd2b5cff4e449bd016dd27e50"], + type = "tar.xz", +) + +llvm.sysroot( + name = "llvm_toolchain", + label = "@org_chromium_sysroot_linux_x86_64//:sysroot", + targets = ["linux-x86_64"], +) + +http_archive( + name = "org_chromium_sysroot_linux_aarch64", + build_file_content = """ +filegroup( + name = "sysroot", + srcs = glob(["*/**"], exclude=["lib/systemd/*/**"]), + visibility = ["//visibility:public"], +) +""", + sha256 = "2f915d821eec27515c0c6d21b69898e23762908d8d7ccc1aa2a8f5f25e8b7e18", + urls = ["https://commondatastorage.googleapis.com/chrome-linux-sysroot/2f915d821eec27515c0c6d21b69898e23762908d8d7ccc1aa2a8f5f25e8b7e18"], + type = "tar.xz", +) + +llvm.sysroot( + name = "llvm_toolchain", + label = "@org_chromium_sysroot_linux_aarch64//:sysroot", + targets = ["linux-aarch64"], +) +use_repo(llvm, "llvm_toolchain") + +register_toolchains("@llvm_toolchain//:all") diff --git a/examples/cublas/BUILD.bazel b/examples/cublas/BUILD.bazel index b120d273..9232f74e 100644 --- a/examples/cublas/BUILD.bazel +++ b/examples/cublas/BUILD.bazel @@ -2,5 +2,8 @@ cc_binary( name = "main", srcs = ["cublas.cpp"], - deps = ["@cuda//:cublas"], + deps = [ + "@cuda//:cublas", + "@cuda//:cuda_runtime", + ], ) diff --git a/examples/platforms.bzl b/examples/platforms.bzl new file mode 100644 index 00000000..df878b04 --- /dev/null +++ b/examples/platforms.bzl @@ -0,0 +1,9 @@ +linux_x86_64 = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", +] + +linux_aarch64 = [ + "@platforms//os:linux", + "@platforms//cpu:aarch64", +]