import pytest import torch import torch.nn.functional as F from sgl_kernel import create_greenctx_stream_by_value, get_sm_available def test_green_ctx(): A = torch.randn(5120, 5120).cuda() B = torch.randn(5120, 5120).cuda() C = torch.matmul(A, B) sm_counts = get_sm_available(0) stream_group = create_greenctx_stream_by_value(sm_counts // 2, sm_counts // 2, 0) with torch.cuda.stream(stream_group[0]): for _ in range(100): result_0 = torch.matmul(A, B) with torch.cuda.stream(stream_group[1]): for _ in range(100): result_1 = torch.matmul(A, B) torch.cuda.synchronize() assert torch.allclose(result_0, C) assert torch.allclose(result_1, C) if __name__ == "__main__": pytest.main([__file__])