# Owner(s): ["module: inductor"]

import contextlib
import sys
import unittest

import torch
from torch._inductor import config
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    MACOS_VERSION,
    parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, RUN_CPU, RUN_GPU


try:
    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise

TestCase = test_torchinductor.TestCase
check_model = test_torchinductor.check_model
check_model_gpu = test_torchinductor.check_model_gpu
skip_if_cpp_wrapper = test_torchinductor.skip_if_cpp_wrapper
copy_tests = test_torchinductor.copy_tests
define_custom_op_for_test = test_torchinductor.define_custom_op_for_test
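

# These tests exercise how Inductor handles tensors that are not 16-byte aligned
# (e.g. views with a nonzero storage_offset), both as graph inputs and as outputs
# of custom ops.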
@instantiate_parametrized_tests
class CommonTemplate:
    def test_unaligned_input(self):
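        # Slicing off the first fp32 element leaves the input with storage_offset=1
        # (a 4-byte offset), so its data pointer is not 16-byte aligned.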
        def fn(x):
            return torch.nn.functional.relu(x)

        x = torch.randn(1024 + 16, device=self.device)[1:-15]
        # TODO (malfet): Investigate failures on MacOS-14
        with (
            contextlib.nullcontext()
            if self.device != "mps" or MACOS_VERSION >= 15.0
            else self.assertRaises(AssertionError)
        ):
            self.common(fn, (x,), check_lowp=False)

    def test_unaligned_input_2d(self):
        def fn(x):
            return torch.nn.functional.relu(x)

        x = torch.randn(1024, 1024 + 16, device=self.device)[:, 1:-15]
        self.common(fn, (x,), check_lowp=False)

    def test_alignment_without_custom_op(self):
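        # Here the unaligned view is an intermediate produced inside the compiled
        # graph (a slice of 3 * a), not a graph input.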
        def fn(x):
            a = torch.nn.functional.relu(x)
            b = (3 * a)[1:-15]
            c = torch.cos(b)
            return c

        x = torch.randn(1024 + 16, device=self.device)
        self.common(fn, (x,), check_lowp=False)

    @config.patch(implicit_fallbacks=True)
    def test_no_align_for_custom_op(self):
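        # With implicit_fallbacks=True the custom op runs as an eager fallback and
        # returns a view with storage_offset=1, so the generated code consuming its
        # output should not assume 16-byte alignment.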
        def slice1d(x):
            return (3 * x)[1:-15]

        def slice1d_meta(x):
            return torch.empty_like(x)[1:-15]

        define_custom_op_for_test("slice1d", slice1d, slice1d_meta)

        def fn(x):
            a = torch.nn.functional.relu(x)
            b = torch.ops.test.slice1d(a)
            c = torch.cos(b)
            return c

        x = torch.randn(1024 + 16, device=self.device)
        self.common(fn, (x,), check_lowp=False)

    @config.patch(implicit_fallbacks=True)
    def test_no_align_for_custom_op_2d(self):
        def slice2d(x):
            return (3 * x)[..., 1:-15]

        def slice2d_meta(x):
            return torch.empty_like(x)[..., 1:-15]

        define_custom_op_for_test("slice2d", slice2d, slice2d_meta)

        def fn(x):
            a = torch.nn.functional.relu(x)
            b = torch.ops.test.slice2d(a)
            c = torch.cos(b)
            return c

        x = torch.randn(1024, 1024 + 16, device=self.device)
        self.common(fn, (x,), check_lowp=False)

    @config.patch(implicit_fallbacks=True, alignment_asserts=True)
    @skip_if_cpp_wrapper(
        "Inductor does not generate alignment assertion for cpp_wrapper right now"
    )
    def test_incorrect_meta_for_custom_op_2d(self):
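        # The meta function claims storage_offset=0 while the real implementation
        # returns a view with storage_offset=1; with alignment_asserts=True the
        # generated alignment assertion should flag the mismatch at runtime.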
        def slice2d(x):
            return (3 * x)[..., 1:-15]

        def slice2d_meta(x):
            return torch.empty_like(x)[..., 0:-16]

        define_custom_op_for_test("slice2d_incorrect_meta", slice2d, slice2d_meta)

        def fn(x):
            a = torch.nn.functional.relu(x)
            b = torch.ops.test.slice2d_incorrect_meta(a)
            c = torch.cos(b)
            return c

        x = torch.randn(1024, 1024 + 16, device=self.device)

        expected_error = "Expect the tensor to be 16 bytes aligned. Fail due to storage_offset=1 itemsize=4"
        with self.assertRaisesRegex(AssertionError, expected_error):
            self.common(fn, (x,), check_lowp=False)

    def test_slice(self):
        def f(x):
            return x[1:] + 1

        x = torch.randn(1025, device=self.device)
        self.common(f, (x,))

    def test_view_dtype_slice(self):
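        # Reinterpreting the bfloat16 buffer as float32 and then slicing off one
        # element leaves the viewed tensor with a 4-byte storage offset.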
        def f(x):
            return x.view(dtype=torch.float32)[1:] + 1

        x = torch.randn(1025 * 2, dtype=torch.bfloat16, device=self.device)
        self.common(f, (x,), reference_in_float=False)

    @parametrize(
        "size",
        (
            # wrapper for size = 128: https://gist.github.com/shunting314/88f1e72957b9fc5e9826aaa346a0e652
            # ptx: https://gist.github.com/shunting314/eb657ee8821eef9f0685b7b91e2ad5c2
            # the ptx file uses ld.global.b32 to load input buffer
            128,
            # wrapper for size = 1024: https://gist.github.com/shunting314/d7f64e1f52f6b1e2ec25e1a51052ce43
            # ptx: https://gist.github.com/shunting314/a24ff7563bb6b04523d11b119ab0f2b2
            # the ptx file uses ld.global.v2.b32 to load input buffer
            1024,
            # wrapper for size = 1024 * 1024: https://gist.github.com/shunting314/016b95cf0b6e9a75c25f5c9d5ed0a2ba
            # ptx: https://gist.github.com/shunting314/360112a4893c759b114c12fc99958297
            # the ptx file uses ld.global.v4.b32 to load input buffer
            1024 * 1024,
        ),
    )
    def test_slice_view_dtype(self, size):
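        # Slicing two bfloat16 elements off the front (4 bytes) before viewing as
        # float32 starts the viewed tensor at a 4-byte offset into the storage.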
        offset = 1

        def f(x):
            return x[2:].view(dtype=torch.float32) + 1

        x = torch.randn((size + offset) * 2, dtype=torch.bfloat16, device=self.device)
        self.common(f, (x,), reference_in_float=False)

    def test_Q4_K_dequantization(self):
        """
        Test the alignment issue for Q4_K dequantization.
        """

        QK_K = 256
        K_SCALE_SIZE = 12

        def get_scale_min(scales):
            n_blocks = scales.shape[0]
            scales = scales.view(torch.uint8)
            scales = scales.reshape((n_blocks, 3, 4))

            d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

            sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
            min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

            return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))

        def split_block_dims(blocks, *args):
            n_max = blocks.shape[1]
            dims = list(args) + [n_max - sum(args)]
            return torch.split(blocks, dims, dim=1)

        def dequantize_blocks_Q4_K(blocks, block_size, type_size):
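            # Each 144-byte block is split as: d (2-byte fp16 scale), dmin (2-byte
            # fp16 minimum), 12 bytes of packed sub-block scales/mins, and 128 bytes
            # of 4-bit quants covering QK_K = 256 values.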
            n_blocks = blocks.shape[0]

            d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
            d = d.view(torch.float16)
            dmin = dmin.view(torch.float16)

            sc, m = get_scale_min(scales)

            d = (d * sc).reshape((n_blocks, -1, 1))
            dm = (dmin * m).reshape((n_blocks, -1, 1))

            qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(
                [0, 4], device=d.device, dtype=torch.uint8
            ).reshape((1, 1, 2, 1))
            qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

            return (d * qs - dm).reshape((n_blocks, QK_K))

        data = torch.randint(
            0, 16, (18432, 1728), device=self.device, dtype=torch.uint8
        )

        def dequantize(data):
            block_size, type_size = 256, 144
            rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
            n_blocks = rows.numel() // type_size
            blocks = rows.reshape((n_blocks, type_size))
            blocks = dequantize_blocks_Q4_K(blocks, block_size, type_size)
            return blocks.reshape(18432, 3072)

        self.common(dequantize, (data,), check_lowp=False, atol=1e-3, rtol=1e-3)
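

# copy_tests clones every CommonTemplate test onto the device-specific TestCase
# classes below, so each test runs once per available backend.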
if RUN_CPU:

    class CpuTests(TestCase):
        common = check_model
        device = "cpu"

    copy_tests(CommonTemplate, CpuTests, "cpu")

if RUN_GPU:

    class GPUTests(TestCase):
        common = check_model_gpu
        device = GPU_TYPE

    copy_tests(CommonTemplate, GPUTests, GPU_TYPE)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if RUN_CPU or RUN_GPU:
        run_tests()