19 changes: 10 additions & 9 deletions include/mscclpp/gpu.hpp
@@ -4,9 +4,9 @@
#ifndef MSCCLPP_GPU_HPP_
#define MSCCLPP_GPU_HPP_

#if defined(__HIP_PLATFORM_AMD__)
#include <mscclpp/device.hpp>

#include <hip/hip_runtime.h>
#if defined(MSCCLPP_DEVICE_HIP)

using cudaError_t = hipError_t;
using cudaEvent_t = hipEvent_t;
@@ -62,6 +62,7 @@ constexpr auto CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL = HIP_POINTER_ATTRIBUTE_DEVIC
#define CUDA_ERROR_DEINITIALIZED hipErrorDeinitialized
#define CUDA_ERROR_CONTEXT_IS_DESTROYED hipErrorContextIsDestroyed
#define CUDA_ERROR_LAUNCH_FAILED hipErrorLaunchFailure
#define CUDA_ERROR_NOT_SUPPORTED hipErrorNotSupported
#define CUDA_ERROR_INVALID_VALUE hipErrorInvalidValue

#define cudaEventCreate(...) hipEventCreate(__VA_ARGS__)
@@ -122,29 +123,29 @@ constexpr auto CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL = HIP_POINTER_ATTRIBUTE_DEVIC
#define cuMemGetAllocationGranularity(...) hipMemGetAllocationGranularity(__VA_ARGS__)
#define cuPointerGetAttribute(...) hipPointerGetAttribute(__VA_ARGS__)

#else
#else // !defined(MSCCLPP_DEVICE_HIP)

#include <cuda.h>
#include <cuda_runtime.h>

#endif
#endif // !defined(MSCCLPP_DEVICE_HIP)

// NVLS
#if !defined(__HIP_PLATFORM_AMD__)
#if !defined(MSCCLPP_DEVICE_HIP)
#include <linux/version.h>
#if CUDART_VERSION < 12030
#define CU_MEM_HANDLE_TYPE_FABRIC ((CUmemAllocationHandleType)0x8ULL)
#endif
// We need CUDA 12.3 above and kernel 5.6.0 above for NVLS API
#define CUDA_NVLS_API_AVAILABLE ((CUDART_VERSION >= 12030) && (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 6, 0)))
#else // defined(__HIP_PLATFORM_AMD__)
#else // defined(MSCCLPP_DEVICE_HIP)
#define CUDA_NVLS_API_AVAILABLE 0
// NVLS is not supported on AMD platform, just to avoid compilation error
#define CU_MEM_HANDLE_TYPE_FABRIC (0x8ULL)
#endif // !defined(__HIP_PLATFORM_AMD__)
#define CU_MEM_HANDLE_TYPE_FABRIC ((hipMemAllocationHandleType)0x8ULL)
#endif // defined(MSCCLPP_DEVICE_HIP)

// GPU sync threads
#if defined(__HIP_PLATFORM_AMD__)
#if defined(MSCCLPP_DEVICE_HIP)
#define __syncshm() asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier");
#else
#define __syncshm() __syncthreads();
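For reference, a minimal sketch (not part of this diff) of how calling code might consult the `CUDA_NVLS_API_AVAILABLE` guard defined above. The `tryCreateMulticastHandle` helper is hypothetical; error handling and the multicast-granularity rounding of `prop.size` are elided.

```cpp
// Hypothetical helper, shown only to illustrate the compile-time guard.
// cuMulticastCreate / CUmulticastObjectProp are the public CUDA driver-API names;
// on HIP builds or older CUDA/kernel versions the macro is 0 and the stub branch is compiled.
#include <mscclpp/gpu.hpp>

bool tryCreateMulticastHandle(int numDevices, size_t sizeBytes) {
#if CUDA_NVLS_API_AVAILABLE
  CUmulticastObjectProp prop = {};
  prop.numDevices = static_cast<unsigned int>(numDevices);
  prop.size = sizeBytes;  // should be a multiple of the multicast granularity
  prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
  CUmemGenericAllocationHandle handle;
  return cuMulticastCreate(&handle, &prop) == CUDA_SUCCESS;
#else
  (void)numDevices;
  (void)sizeBytes;
  return false;  // NVLS needs CUDA >= 12.3 and Linux kernel >= 5.6; not supported on HIP.
#endif
}
```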
6 changes: 0 additions & 6 deletions include/mscclpp/switch_channel.hpp
@@ -34,19 +34,13 @@ class NvlsConnection {
NvlsConnection() = delete;
std::vector<char> serialize();

// Everyone needs to synchronize after creating a NVLS connection before adding devices
void addDevice();
void addDevice(int cudaDeviceId);

/// Bind the memory allocated via mscclpp::GpuBuffer to the multicast handle. The behavior
/// is undefined if the devicePtr is not allocated by mscclpp::GpuBuffer.
/// @param devicePtr The device pointer returned by `mscclpp::GpuBuffer::data()`.
/// @param size The bytes of the memory to bind to the multicast handle.
/// @return SwitchChannel with devicePtr, mcPtr and bufferSize
SwitchChannel bindAllocatedMemory(CUdeviceptr devicePtr, size_t size);

size_t getMultiCastMinGranularity();

private:
class Impl;
std::shared_ptr<Impl> pimpl_;
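A minimal usage sketch (not from this PR) of the `bindAllocatedMemory` API that remains after this change. The `bindForMulticast` helper is hypothetical; it assumes `devPtr` really came from `mscclpp::GpuBuffer::data()` (as the doc comment requires) and that the connection was set up via `connectNvlsCollective` with all participating ranks synchronized.

```cpp
// Hypothetical helper around NvlsConnection::bindAllocatedMemory.
#include <mscclpp/switch_channel.hpp>

mscclpp::SwitchChannel bindForMulticast(mscclpp::NvlsConnection& conn, void* devPtr, size_t sizeBytes) {
  // Behavior is undefined if devPtr was not allocated through mscclpp::GpuBuffer.
  return conn.bindAllocatedMemory(reinterpret_cast<CUdeviceptr>(devPtr), sizeBytes);
}
```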
3 changes: 1 addition & 2 deletions python/csrc/switch_channel_py.cpp
@@ -29,8 +29,7 @@ void register_nvls(nb::module_& m) {
});

nb::class_<NvlsConnection>(m, "NvlsConnection")
.def("bind_allocated_memory", &NvlsConnection::bindAllocatedMemory, nb::arg("device_ptr"), nb::arg("size"))
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
.def("bind_allocated_memory", &NvlsConnection::bindAllocatedMemory, nb::arg("device_ptr"), nb::arg("size"));

m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("all_ranks"),
nb::arg("buffer_size"));