"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import List, Tuple
import torch
try:
import cuda.bindings.driver as driver
import cuda.bindings.runtime as runtime
from cuda.bindings.driver import CUdevice, CUdevResource
except ImportError as e:
raise ImportError(
"Could not import the 'cuda' module. "
"Please install cuda-python that matches your CUDA version."
) from e
from .cuda_utils import checkCudaErrors
from .utils import get_compute_capability, round_up
def get_sm_count_constraint(major: int, minor: int) -> Tuple[int, int]:
    """Return the (minimum SM count, SM count granularity) for green-context
    partitions on a device with compute capability ``major.minor``."""
if major == 6:
return (1, 1)
elif major == 7:
return (2, 2)
elif major == 8:
return (4, 2)
elif major >= 9:
return (8, 8)
else:
raise ValueError(f"Unsupported CUDA capability: {major}.{minor}")
def get_cudevice(dev: torch.device) -> CUdevice:
    """Return the ``CUdevice`` handle for a ``torch.device``, initializing the device if needed."""
    try:
        cu_dev = checkCudaErrors(driver.cuDeviceGet(dev.index))
    except RuntimeError:
        # cuDeviceGet can fail if CUDA has not been initialized yet; initialize the
        # device through the runtime API and retry.
        runtime.cudaInitDevice(dev.index, 0, 0)
        cu_dev = checkCudaErrors(driver.cuDeviceGet(dev.index))
return cu_dev
def get_device_resource(cu_dev: CUdevice) -> CUdevResource:
    """Return the full SM resource (all SMs) of the given device."""
return checkCudaErrors(
driver.cuDeviceGetDevResource(
cu_dev, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM
)
)
def split_resource(
    resource: CUdevResource,
    num_groups: int,
    min_count: int,
) -> Tuple[List[CUdevResource], CUdevResource]:
    """Split ``resource`` into ``num_groups`` partitions of at least ``min_count``
    SMs each, returning the per-partition resources and the remaining resource."""
results, _, remaining = checkCudaErrors(
driver.cuDevSmResourceSplitByCount(
num_groups,
resource,
0, # useFlags
min_count,
)
)
return results, remaining
def split_resource_by_sm_count(
    cu_dev: CUdevice, resource: CUdevResource, sm_counts: List[int]
) -> Tuple[List[CUdevResource], CUdevResource]:
    """Carve one partition per entry of ``sm_counts`` out of ``resource``,
    returning the partitions and a resource holding the leftover SMs."""
results = []
for sm_count in sm_counts:
result, remaining = split_resource(resource, 1, sm_count)
results.extend(result)
# Refresh the remaining resource for the next iteration
desc = checkCudaErrors(driver.cuDevResourceGenerateDesc([remaining], 1))
green_ctx = checkCudaErrors(
driver.cuGreenCtxCreate(
desc, cu_dev, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM
)
)
resource = checkCudaErrors(
driver.cuGreenCtxGetDevResource(
green_ctx, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM
)
)
return results, resource
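
# Illustrative sketch (not part of the original module): carving fixed-size partitions
# out of a device's full SM resource with the helper above. The counts passed in are
# assumed to already satisfy the device's minimum/granularity constraints; the public
# split_device_green_ctx_by_sm_count API performs that rounding for the caller.
def _example_carve_partitions(dev: torch.device) -> None:
    cu_dev = get_cudevice(dev)
    full = get_device_resource(cu_dev)
    parts, rest = split_resource_by_sm_count(cu_dev, full, [8, 16])
    # parts[0].sm.smCount == 8 and parts[1].sm.smCount == 16; rest holds whatever
    # SMs are left on the device (the exact count is device dependent).
    print([p.sm.smCount for p in parts], rest.sm.smCount)
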
def create_green_ctx_streams(
    cu_dev: CUdevice, resources: List[CUdevResource]
) -> List[torch.Stream]:
    """Create a green context and a non-blocking stream for each resource in
    ``resources``, returning the streams wrapped as ``torch.Stream`` objects."""
streams = []
for split in resources:
desc = checkCudaErrors(driver.cuDevResourceGenerateDesc([split], 1))
green_ctx = checkCudaErrors(
driver.cuGreenCtxCreate(
desc, cu_dev, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM
)
)
stream = checkCudaErrors(
driver.cuGreenCtxStreamCreate(
green_ctx,
driver.CUstream_flags.CU_STREAM_NON_BLOCKING,
0, # priority
)
)
streams.append(torch.cuda.get_stream_from_external(stream))
return streams
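
# Illustrative sketch (not part of the original module): the streams produced above are
# ordinary torch.Stream objects, so work can be scoped to a single SM partition with the
# usual stream context manager.
def _example_run_on_partition(stream: torch.Stream, dev: torch.device) -> torch.Tensor:
    with torch.cuda.stream(stream):
        a = torch.randn(1024, 1024, device=dev)
        b = torch.randn(1024, 1024, device=dev)
        return a @ b  # executes on the SMs of the partition backing `stream`
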
def split_device_green_ctx(
dev: torch.device, num_groups: int, min_count: int
) -> Tuple[List[torch.Stream], List[CUdevResource]]:
r"""
Split the device into multiple `green contexts <https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html>`_,
    returning the corresponding streams and `CUdevResource` for each group, plus the remaining SMs.
Green contexts allow concurrent execution of multiple kernels on different SM partitions.
Args:
dev: The device to split.
num_groups: The number of groups to split the device into.
        min_count: Minimum number of SMs required for each group; it will be adjusted
            to meet the alignment and granularity requirements.
Returns:
        streams: The list of torch.Stream objects corresponding to the green contexts.
resources: The list of CUdevResource objects corresponding to the green contexts.
Example:
>>> from flashinfer.green_ctx import split_device_green_ctx
>>> import torch
>>> dev = torch.device("cuda:0")
>>> streams, resources = split_device_green_ctx(dev, 2, 16)
>>> print([r.sm.smCount for r in resources])
[16, 16, 100]
>>> with torch.cuda.stream(streams[0]):
... x = torch.randn(8192, 8192, device=dev, dtype=torch.bfloat16)
... y = torch.randn(8192, 8192, device=dev, dtype=torch.bfloat16)
... z = x @ y
... print(z.shape)
...
torch.Size([8192, 8192])
Note:
The length of the returned streams and resources is ``num_groups + 1``,
where the last one is the remaining SMs.
        The following examples show how the SM count is rounded up to meet the alignment and
        granularity requirements (on compute capability 9.0+, where both the minimum and the granularity are 8 SMs):
- Requested 7 SMs → Allocated 8 SMs (rounded up to minimum)
- Requested 10 SMs → Allocated 16 SMs (rounded up to multiple of 8)
- Requested 16 SMs → Allocated 16 SMs (no rounding needed)
- Requested 17 SMs → Allocated 24 SMs (rounded up to multiple of 8)
Raises:
        RuntimeError: If the requested SM allocation exceeds device capacity:
``num_groups * rounded_min_count > total_device_sms``
"""
cu_dev = get_cudevice(dev)
resource = get_device_resource(cu_dev)
results, remaining = split_resource(resource, num_groups, min_count)
resources = results + [remaining]
streams = create_green_ctx_streams(cu_dev, resources)
return streams, resources
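
# Illustrative sketch (not part of the original module): estimating up front whether a
# requested split can fit on the device, using the same rounding rules described in the
# docstring above. `multi_processor_count` comes from torch's device properties.
def _example_split_fits(dev: torch.device, num_groups: int, min_count: int) -> bool:
    min_sm_count, sm_alignment = get_sm_count_constraint(*get_compute_capability(dev))
    rounded = round_up(max(min_count, min_sm_count), sm_alignment)
    total_sms = torch.cuda.get_device_properties(dev).multi_processor_count
    return num_groups * rounded <= total_sms
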
def split_device_green_ctx_by_sm_count(
dev: torch.device, sm_counts: List[int]
) -> Tuple[List[torch.Stream], List[CUdevResource]]:
r"""
Split the device into multiple green contexts, each with a fixed number of SMs,
    returning the corresponding streams and `CUdevResource` for each group, plus the remaining SMs.
Green contexts allow concurrent execution of multiple kernels on different SM partitions.
Args:
dev: The device to split.
sm_counts: List of SM counts for each partition. Each count will be rounded up
to meet the minimum and alignment requirements.
Returns:
        streams: The list of torch.Stream objects corresponding to the green contexts.
resources: The list of CUdevResource objects corresponding to the green contexts.
Raises:
RuntimeError: If the requested SM allocation exceeds device capacity:
- When sum(rounded_sm_counts) > total_device_sms
- When CUDA operations fail due to invalid resource types
- When the device is not properly initialized
ValueError: If sm_counts is empty or contains invalid values (e.g., negative values).
Example:
>>> from flashinfer.green_ctx import split_device_green_ctx_by_sm_count
>>> import torch
>>> dev = torch.device("cuda:0")
>>>
>>> # Create three partitions with specific SM counts
>>> streams, resources = split_device_green_ctx_by_sm_count(dev, [8, 16, 24])
>>> print([r.sm.smCount for r in resources])
[8, 16, 24, 84] # Last value is remaining SMs
>>>
>>> # Execute kernels on different partitions
>>> with torch.cuda.stream(streams[0]):
... x = torch.randn(4096, 4096, device=dev, dtype=torch.bfloat16)
... y = torch.randn(4096, 4096, device=dev, dtype=torch.bfloat16)
... z = x @ y
... print(f"Partition 0 result: {z.shape}")
...
>>> with torch.cuda.stream(streams[1]):
... # Different computation on partition 1
... a = torch.randn(2048, 2048, device=dev, dtype=torch.bfloat16)
... b = torch.randn(2048, 2048, device=dev, dtype=torch.bfloat16)
... c = a @ b
... print(f"Partition 1 result: {c.shape}")
Note:
The length of the returned streams and resources is ``len(sm_counts) + 1``,
where the last one contains the remaining SMs that were not allocated.
SM count alignment examples for Compute Capability 9.0+:
- Requested 7 SMs → Allocated 8 SMs (rounded up to minimum)
- Requested 10 SMs → Allocated 16 SMs (rounded up to multiple of 8)
- Requested 16 SMs → Allocated 16 SMs (no rounding needed)
- Requested 17 SMs → Allocated 24 SMs (rounded up to multiple of 8)
The actual SM count can be obtained from the ``.sm.smCount`` field of the returned resources.
See `CUDA Green Contexts <https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html>`_
for more details.
"""
cu_dev = get_cudevice(dev)
resource = get_device_resource(cu_dev)
    # Round SM counts to meet the device's minimum and granularity requirements
    if not sm_counts:
        raise ValueError("sm_counts must not be empty")
    min_sm_count, sm_alignment = get_sm_count_constraint(*get_compute_capability(dev))
    rounded_sm_counts = []
    for sm_count in sm_counts:
        if sm_count <= 0:
            raise ValueError(f"SM count must be positive, got {sm_count}")
        rounded_sm_counts.append(round_up(max(sm_count, min_sm_count), sm_alignment))
# Split the device into multiple green contexts
results, remaining = split_resource_by_sm_count(cu_dev, resource, rounded_sm_counts)
resources = results + [remaining]
streams = create_green_ctx_streams(cu_dev, resources)
return streams, resources
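
# Illustrative sketch (not part of the original module): requested counts are rounded up
# before partitioning, so read the actual partition sizes back from the returned
# resources instead of assuming they equal the request.
def _example_inspect_partition_sizes(dev: torch.device) -> None:
    _, resources = split_device_green_ctx_by_sm_count(dev, [7, 10, 17])
    actual = [r.sm.smCount for r in resources[:-1]]  # e.g. [8, 16, 24] on SM 9.x
    leftover = resources[-1].sm.smCount  # SMs not assigned to any requested group
    print(actual, leftover)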