26 lines
779 B
Python
26 lines
779 B
Python
import pytest
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from sgl_kernel import create_greenctx_stream_by_value, get_sm_available
|
|
|
|
|
|
def test_green_ctx():
    """Check that matmuls issued on two green-context streams match a reference.

    A reference product ``C = A @ B`` is computed on the current stream. The
    same product is then recomputed 100 times on each of the two streams
    returned by ``create_greenctx_stream_by_value``, and the last result from
    each stream is compared against ``C`` with ``torch.allclose``.

    The test is skipped when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        pytest.skip("green context streams require a CUDA device")

    A = torch.randn(5120, 5120).cuda()
    B = torch.randn(5120, 5120).cuda()
    C = torch.matmul(A, B)  # reference result, produced on the current stream

    # NOTE(review): the literal 0 is presumably the CUDA device index and
    # sm_counts the number of SMs available on it -- confirm against sgl_kernel.
    sm_counts = get_sm_available(0)
    stream_group = create_greenctx_stream_by_value(sm_counts // 2, sm_counts // 2, 0)

    # A, B and C were enqueued on the current stream; the side streams below
    # are not ordered against it. Drain all pending work first so the matmuls
    # on the side streams cannot read A or B before they are fully written.
    torch.cuda.synchronize()

    with torch.cuda.stream(stream_group[0]):
        for _ in range(100):
            result_0 = torch.matmul(A, B)
    with torch.cuda.stream(stream_group[1]):
        for _ in range(100):
            result_1 = torch.matmul(A, B)

    # Wait for both side streams to finish before reading their results.
    torch.cuda.synchronize()
    assert torch.allclose(result_0, C)
    assert torch.allclose(result_1, C)
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the session's exit code; propagate it so that running
    # this file directly reports failure to the shell instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))
|