# Owner(s): ["module: intel"] import contextlib import functools import inspect import itertools import math import random from functools import partial from itertools import product import numpy as np import torch from torch.testing import make_tensor from torch.testing._internal.common_device_type import ( dtypes, instantiate_device_type_tests, precisionOverride, ) from torch.testing._internal.common_utils import ( iter_indices, parametrize, run_tests, TestCase, ) @contextlib.contextmanager def tf32_off(): enabled = torch.backends.mkldnn.enabled deterministic = torch.backends.mkldnn.deterministic with torch.backends.mkldnn.flags( enabled=enabled, deterministic=deterministic, allow_tf32=False ): yield @contextlib.contextmanager def tf32_on(self, tf32_precision=1e-5): enabled = torch.backends.mkldnn.enabled deterministic = torch.backends.mkldnn.deterministic old_precision = self.precision try: self.precision = tf32_precision with torch.backends.mkldnn.flags( enabled=enabled, deterministic=deterministic, allow_tf32=True ): yield finally: self.precision = old_precision # This is a wrapper that wraps a test to run this test twice, one with # allow_tf32=True, another with allow_tf32=False. When running with # allow_tf32=True, it will use reduced precision as specified by the # argument. For example: # @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) # @tf32_on_and_off(0.005) # def test_matmul(self, device, dtype): # a = ...; b = ...; # c = torch.matmul(a, b) # self.assertEqual(c, expected) # In the above example, when testing torch.float32 , the matmul will be running at # TF32 mode and TF32 mode off, and on TF32 mode, the assertEqual will use reduced # precision to check values. # # This decorator can be used for function with or without device/dtype, such as # @tf32_on_and_off(0.005) # def test_my_op(self) # @tf32_on_and_off(0.005) # def test_my_op(self, device) # @tf32_on_and_off(0.005) # def test_my_op(self, device, dtype) # @tf32_on_and_off(0.005) # def test_my_op(self, dtype) def tf32_on_and_off(tf32_precision=1e-5): def with_tf32_disabled(self, function_call): with tf32_off(): function_call() def with_tf32_enabled(self, function_call): with tf32_on(self, tf32_precision): function_call() def wrapper(f): params = inspect.signature(f).parameters arg_names = tuple(params.keys()) @functools.wraps(f) def wrapped(*args, **kwargs): kwargs.update(zip(arg_names, args)) cond = True if "device" in kwargs: cond = cond and (torch.device(kwargs["device"]).type == "xpu") if "dtype" in kwargs: cond = cond and ( kwargs["dtype"] in {torch.float32} ) # TODO: add complex64 if cond: with_tf32_disabled(kwargs["self"], lambda: f(**kwargs)) with_tf32_enabled(kwargs["self"], lambda: f(**kwargs)) else: f(**kwargs) return wrapped return wrapper # This is a wrapper that wraps a test to run it with TF32 turned off. # This wrapper is designed to be used when a test uses matmul or convolutions # but the purpose of that test is not testing matmul or convolutions. # Disabling TF32 will enforce torch.float tensors to be always computed # at full precision. def with_tf32_off(f): @functools.wraps(f) def wrapped(*args, **kwargs): with tf32_off(): return f(*args, **kwargs) return wrapped class TestBasicGEMM(TestCase): def _test_addmm_addmv( self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None ): dtype = t.dtype numpy_dtype = dtype if dtype in {torch.bfloat16, torch.half}: numpy_dtype = torch.float if dtype.is_complex: alpha = 0.9 + 0.3j if alpha is None else alpha beta = 0.5 + 0.6j if beta is None else beta else: alpha = 1.2 if alpha is None else alpha beta = 0.8 if beta is None else beta if activation == "gelu": res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True) else: res1 = f(t, m, v, alpha=alpha, beta=beta) res2 = torch.full_like(res1, math.nan) if transpose_out: res2 = res2.t().clone(memory_format=torch.contiguous_format).t() if activation == "gelu": f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True) else: f(t, m, v, alpha=alpha, beta=beta, out=res2) m.to(numpy_dtype).cpu().numpy() v.to(numpy_dtype).cpu().numpy() res3 = alpha * ( m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy() ) if beta != 0: res3 += (beta * t).to(numpy_dtype).cpu().numpy() if activation == "relu": res3 = res3 * (res3 > 0) elif activation == "gelu": res3_t = torch.from_numpy(res3).to(dtype) approximate = "tanh" if t.is_cuda else "none" res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate) res3 = res3_t.to(numpy_dtype).cpu().numpy() else: assert activation is None, f"unsupported activation {activation}" res3 = torch.from_numpy(res3).to(dtype) self.assertEqual(res1, res2) self.assertEqual(res1, res3) def _test_addmm_impl(self, func, activation, device, dtype): M = torch.randn(10, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device) m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) self._test_addmm_addmv(func, M, m1, m2, activation=activation) # vector-shaped bias and beta=1 result in epilogue fusion in CUDA V = torch.randn(25, device="cpu", dtype=torch.float32).to(dtype).to(device) self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation) # Test 0-strided M = ( torch.randn(10, 1, device="cpu", dtype=torch.float32) .to(dtype) .expand(10, 25) .to(device) ) m1 = ( torch.randn(10, 1, device="cpu", dtype=torch.float32) .to(dtype) .expand(10, 50) .to(device) ) m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) self._test_addmm_addmv(func, M, m1, m2, activation=activation) # Test beta=0, M=nan M = ( torch.full((10, 25), math.nan, device="cpu", dtype=torch.float32) .to(dtype) .to(device) ) m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device) m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation) # Test transpose for t1, t2, t3, t4 in itertools.product([True, False], repeat=4): def maybe_transpose(cond, m): if not cond: return m return m.t().clone(memory_format=torch.contiguous_format).t() M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype)) m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype)) m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) self._test_addmm_addmv( func, M, m1, m2, transpose_out=t4, activation=activation ) if t1: # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1) self._test_addmm_addmv( func, V, m1, m2, beta=1, transpose_out=t4, activation=activation, ) @precisionOverride({torch.float: 1e-4, torch.double: 1e-6, torch.half: 1e-1}) @dtypes(torch.float32, torch.half, torch.double) @tf32_on_and_off(0.05) def test_addmm(self, device, dtype): self._test_addmm_impl(torch.addmm, None, device, dtype) @precisionOverride({torch.float: 1e-4, torch.double: 1e-6, torch.half: 1e-1}) @dtypes(torch.float, torch.half, torch.double) def test_addmm_badmm_scalar_tnesor_input(self, device, dtype): input = torch.tensor(1).to(device=device, dtype=dtype) # test addmm mat1 = torch.randn(10, 25, device=device).to(dtype) mat2 = torch.randn(25, 10, device=device).to(dtype) result = torch.addmm(input, mat1, mat2) ref = mat1.cpu().numpy() @ mat2.cpu().numpy() + 1 self.assertEqual(result, ref) # test baddbmm mat1 = torch.randn(3, 10, 25, device=device).to(dtype) mat2 = torch.randn(3, 25, 10, device=device).to(dtype) result = torch.baddbmm(input, mat1, mat2) ref = mat1.cpu().numpy() @ mat2.cpu().numpy() + 1 self.assertEqual(result, ref) @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4}) @dtypes(torch.bfloat16, torch.half, torch.float, torch.double) @tf32_on_and_off(0.005) def test_addmv(self, device, dtype): # have to use torch.randn(...).to(bfloat16) instead of # torch.randn(..., dtype=bfloat16). randn does not support # bfloat16 yet. # "*0.2" to reduce errors for low precision ts = [ 0.2 * torch.randn(50, device=device).to(dtype), 0.2 * torch.randn(1, device=device).to(dtype).expand(50), ] vs = [ 0.2 * torch.randn(100, device=device).to(dtype), 0.2 * torch.ones(1, device=device) .to(dtype) .expand(100), # to reduce errors for low precision ] ms = [ # 0d 0.2 * torch.ones((), device=device) .to(dtype) .expand(50, 100), # to reduce errors for low precision # 1d 0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100), # this initialization reduces errors for low precision for broadcasted matrices # by making sure that intermediate and result values are exactly representable # in low precision type 0.2 * torch.randint(3, (50, 1), dtype=torch.float, device=device) .to(dtype) .expand(50, 100), # 2d 0.2 * torch.randn((50, 100), device=device).to(dtype), 0.2 * torch.randn((100, 50), device=device).to(dtype).t(), ] for m, v, t in itertools.product(ms, vs, ts): self._test_addmm_addmv(torch.addmv, t, m, v) # Test beta=0, t=nan t = torch.full((50,), math.nan, device=device).to(dtype) for m, v in itertools.product(ms, vs): self._test_addmm_addmv(torch.addmv, t, m, v, beta=0) @dtypes( torch.half, torch.float32, torch.float64, ) @tf32_on_and_off(0.05) def test_mm(self, device, dtype): def _test_mm(n, m, p, dtype, genf): # helper function def matrixmultiply(mat1, mat2): n = mat1.size(0) m = mat1.size(1) p = mat2.size(1) dtype_ = torch.float if dtype == torch.half else dtype if dtype == torch.half: mat1 = mat1.float() mat2 = mat2.float() res = torch.zeros(n, p, dtype=dtype_, device=device) for i, j in iter_indices(res): res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m)) return res.half() if dtype == torch.half else res # contiguous case mat1 = genf(n, m) mat2 = genf(m, p) res = torch.mm(mat1, mat2) res2 = matrixmultiply(mat1, mat2) self.assertEqual(res, res2) # non contiguous case 1 mat1 = genf(n, m) mat2 = genf(p, m).t() res = torch.mm(mat1, mat2) res2 = matrixmultiply(mat1, mat2) self.assertEqual(res, res2) # non contiguous case 2 mat1 = genf(m, n).t() mat2 = genf(m, p) res = torch.mm(mat1, mat2) res2 = matrixmultiply(mat1, mat2) self.assertEqual(res, res2) # non contiguous case 3 mat1 = genf(m, n).t() mat2 = genf(p, m).t() res = torch.mm(mat1, mat2) res2 = matrixmultiply(mat1, mat2) self.assertEqual(res, res2) # test with zero stride mat1 = genf(n, m) mat2 = genf(m, 1).expand(m, p) res = torch.mm(mat1, mat2) res2 = matrixmultiply(mat1, mat2) self.assertEqual(res, res2) # explicitly exercise the _out variant in torch.mm(). # contiguous case mat1 = genf(n, m) mat2 = genf(m, p) res = genf(n, p) torch.mm(mat1, mat2, out=res) res2 = matrixmultiply(mat1, mat2) self.assertEqual(res, res2) # explicitly exercise the _out variant in torch.mm(). # non contiguous case 3 mat1 = genf(m, n).t() mat2 = genf(p, m).t() res = genf(n, p) torch.mm(mat1, mat2, out=res) res2 = matrixmultiply(mat1, mat2) self.assertEqual(res, res2) def genf_int(x, y): return torch.randint(0, 100, (x, y), dtype=dtype, device=device) def genf_bfloat(x, y): return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1 def genf_float(x, y): return torch.randn(x, y, dtype=dtype, device=device) def genf_Half(x, y): return torch.randn(x, y, dtype=dtype, device=device) for n, m, p in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]: if (dtype == torch.int32) or (dtype == torch.int64): genf = genf_int elif dtype == torch.bfloat16: genf = genf_bfloat elif dtype == torch.half: genf = genf_Half else: genf = genf_float _test_mm(n, m, p, dtype, genf) @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) @dtypes(torch.float32, torch.bfloat16, torch.half, torch.float64) @tf32_on_and_off(0.05) def test_bmm(self, device, dtype): batch_sizes = [1, 10] M, N, O = 23, 15, 12 numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 def invert_perm(p): d = {x: i for i, x in enumerate(p)} return (d[0], d[1], d[2]) def generate_inputs(num_batches): # transposed tensors for perm1, perm2 in itertools.product( itertools.permutations((0, 1, 2)), repeat=2 ): b1 = make_tensor( (num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1 ) b2 = make_tensor( (num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1 ) b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) yield b1, b2 # broadcasting tensors for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6): shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1) shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1) b1 = make_tensor( shape1, dtype=dtype, device=device, low=-0.1, high=0.1 ).expand(num_batches, M, N) b2 = make_tensor( shape2, dtype=dtype, device=device, low=-0.1, high=0.1 ).expand(num_batches, N, O) yield b1, b2 # zero-sized tensors for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) b1 = torch.randn(shape1, dtype=dtype, device=device) b2 = torch.randn(shape2, dtype=dtype, device=device) yield b1, b2 for num_batches in batch_sizes: for (b1, b2), perm3 in itertools.product( generate_inputs(num_batches), itertools.permutations((0, 1, 2)) ): res1 = torch.bmm(b1, b2) res2 = ( torch.full( (num_batches, M, O), math.nan, dtype=dtype, device=device ) .permute(perm3) .contiguous() .permute(invert_perm(perm3)) ) torch.bmm(b1, b2, out=res2) expect = torch.from_numpy( b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() ).to(device=device, dtype=dtype) self.assertEqual(expect, res1) self.assertEqual(expect, res2) if self.device_type == "cuda": # check that mixed arguments are rejected self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu())) self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2)) self.assertRaises( RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu()) ) def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): getattr(out_tensor, func + "_")(b1, b2) self.assertEqual(out_tensor, ref) res3 = out_tensor.clone() with self.assertWarnsOnceRegex( UserWarning, f"This overload of {func}_ is deprecated" ): getattr(out_tensor, func + "_")(1, b1, b2) self.assertEqual(out_tensor, ref * 2) getattr(res3, func + "_")(b1, b2, beta=1) self.assertEqual(out_tensor, res3) with self.assertWarnsOnceRegex( UserWarning, f"This overload of {func}_ is deprecated" ): getattr(out_tensor, func + "_")(1.0, 0.5, b1, b2) self.assertEqual(out_tensor, ref * 2.5) getattr(res3, func + "_")(b1, b2, beta=1.0, alpha=0.5) self.assertEqual(out_tensor, res3) with self.assertWarnsOnceRegex( UserWarning, f"This overload of {func} is deprecated" ): self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2)) res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=0.5) self.assertEqual(res4, ref * 3) nan = torch.full_like(out_tensor, math.nan) res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1) self.assertEqual(res5, ref) if b1.is_complex(): res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1j, alpha=0.5j) self.assertEqual(res6, out_tensor * 0.1j + 0.5j * ref) else: res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1, alpha=0.5) self.assertEqual(res6, out_tensor * 0.1 + 0.5 * ref) res7 = torch.full_like(out_tensor, math.nan) getattr(torch, func)(nan, b1, b2, beta=0, out=res7) self.assertEqual(res7, ref) @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half) @tf32_on_and_off(0.005) def test_addbmm(self, device, dtype): num_batches = 2 M, N, O = 16, 17, 18 def invert_perm(p): d = {x: i for i, x in enumerate(p)} return (d[0], d[1], d[2]) def generate_tensor(): numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 # transposed tensors for perm1, perm2 in itertools.product( itertools.permutations((0, 1, 2)), repeat=2 ): for perm3 in itertools.permutations((0, 1)): b1 = ( make_tensor( (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1, ) * 0.1 ) b2 = ( make_tensor( (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1, ) * 0.1 ) b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) ref = ( torch.from_numpy( b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() ) .to(device=device, dtype=dtype) .sum(0) ) out_tensor = ( torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3) ) yield b1, b2, ref, out_tensor # broadcasting tensors for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1) b1 = ( make_tensor( shape1, dtype=dtype, device=device, low=-1, high=1 ).expand(num_batches, M, N) * 0.1 ) b2 = ( make_tensor( shape2, dtype=dtype, device=device, low=-1, high=1 ).expand(num_batches, N, O) * 0.1 ) ref = ( torch.from_numpy( b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() ) .to(device=device, dtype=dtype) .sum(0) ) out_tensor = torch.zeros_like(ref) yield b1, b2, ref, out_tensor # zero-sized tensors for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) b1 = ( make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1) * 0.1 ) b2 = ( make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1) * 0.1 ) ref = ( torch.from_numpy( b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() ) .to(device=device, dtype=dtype) .sum(0) ) out_tensor = torch.zeros_like(ref) yield b1, b2, ref, out_tensor for b1, b2, ref, out_tensor in generate_tensor(): self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor) @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5, torch.float64: 1e-6}) @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half) @tf32_on_and_off(0.01) def test_baddbmm(self, device, dtype): num_batches = 10 M, N, O = 12, 8, 50 def invert_perm(p): d = {x: i for i, x in enumerate(p)} return (d[0], d[1], d[2]) def generate_tensor(): numpy_dtype = ( dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32 ) # transposed tensors for perm1, perm2, perm3 in itertools.product( itertools.permutations((0, 1, 2)), repeat=3 ): b1 = make_tensor( (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1 ) b2 = make_tensor( (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1 ) b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) ref = torch.from_numpy( b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() ).to(device=device, dtype=dtype) out_tensor = torch.zeros_like(ref) out_tensor = ( out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3)) ) yield b1, b2, ref, out_tensor # broadcasting tensors for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1) b1 = make_tensor( shape1, dtype=dtype, device=device, low=-1, high=1 ).expand(num_batches, M, N) b2 = make_tensor( shape2, dtype=dtype, device=device, low=-1, high=1 ).expand(num_batches, N, O) ref = torch.from_numpy( b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() ).to(device=device, dtype=dtype) out_tensor = torch.zeros_like(ref) yield b1, b2, ref, out_tensor # zero-sized tensors for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2) b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2) ref = torch.from_numpy( b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() ).to(device=device, dtype=dtype) out_tensor = torch.zeros_like(ref) yield b1, b2, ref, out_tensor for b1, b2, ref, out_tensor in generate_tensor(): self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor) @tf32_on_and_off(0.05) def test_tensordot(self, device): a = torch.arange(60.0, device=device).reshape(3, 4, 5) b = torch.arange(24.0, device=device).reshape(4, 3, 2) c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu() cn = torch.from_numpy( np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=([1, 0], [0, 1])) ) self.assertEqual(c, cn) cout = torch.zeros((5, 2), device=device) torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu() self.assertEqual(c, cout) a = torch.randn(2, 3, 4, 5, device=device) b = torch.randn(4, 5, 6, 7, device=device) c = torch.tensordot(a, b, dims=2).cpu() cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=2)) with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"): torch.tensordot(a, b, dims=-1) self.assertEqual(c, cn) c = torch.tensordot(a, b).cpu() cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy())) self.assertEqual(c, cn) a = torch.tensordot(torch.tensor(0.0), torch.tensor(0.0), 0) an = torch.from_numpy( np.tensordot( np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0 ) ) self.assertEqual(a, an) @dtypes(torch.float, torch.double) @precisionOverride({torch.float32: 1e-4}) @tf32_on_and_off(0.005) def test_1_sized_with_0_strided(self, device, dtype): a = make_tensor((8, 1, 64), dtype=dtype, device=device) a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1]) b = make_tensor((8, 64, 512), dtype=dtype, device=device) b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512]) res = torch.bmm(a_strided, b_strided) expect = torch.from_numpy(a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to( device=device, dtype=dtype ) self.assertEqual(expect, res) def _select_broadcastable_dims(self, dims_full=None): # select full dimensionality if dims_full is None: dims_full = [] ndims = random.randint(1, 4) dims_full = [random.randint(1, 8) for _ in range(ndims)] else: ndims = len(dims_full) # select actual dimensions for ops: # larger: full ndims, individual sizes may be reduced # smaller: possibly reduced ndims, sizes may be reduced smaller_ndims = random.randint(1, ndims) dims_small = [] dims_large = [] for i in range(ndims - 1, -1, -1): j = random.randint(1, 3) if j == 1: # no reduced singleton dimension ds = dims_full[i] dl = dims_full[i] elif j == 2: # larger may have reduced singleton dimension ds = dims_full[i] dl = 1 if len(dims_small) < smaller_ndims else dims_full[i] elif j == 3: # smaller may have reduced singleton dimension ds = 1 dl = dims_full[i] dims_large = [dl] + dims_large if len(dims_small) < smaller_ndims: dims_small = [ds] + dims_small return (dims_small, dims_large, dims_full) @tf32_on_and_off(0.005) def test_broadcast_fused_matmul(self, device): fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"] for fn in fns: batch_dim = random.randint(1, 8) n_dim = random.randint(1, 8) m_dim = random.randint(1, 8) p_dim = random.randint(1, 8) def dims_full_for_fn(): if fn == "baddbmm": return ( [batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim], ) elif fn == "addbmm": return ( [n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim], ) elif fn == "addmm": return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim]) elif fn == "addmv": return ([n_dim], [n_dim, m_dim], [m_dim]) elif fn == "addr": return ([n_dim, m_dim], [n_dim], [m_dim]) else: raise AssertionError("unknown function") (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn() (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full) t0_small = torch.randn(*t0_dims_small, device=device).float() t1 = torch.randn(*t1_dims, device=device).float() t2 = torch.randn(*t2_dims, device=device).float() t0_full = t0_small.expand(*t0_dims_full).to(device) fntorch = getattr(torch, fn) r0 = fntorch(t0_small, t1, t2) r1 = fntorch(t0_full, t1, t2) self.assertEqual(r0, r1) @dtypes(torch.float32, torch.float64) @tf32_on_and_off(0.005) def test_strided_mm_bmm(self, device, dtype): # Tests strided view case with stride smaller than corresponding dimension size x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype, device=device) new_shape = [2, 2, 2] new_stride = [3, 1, 1] sx = torch.as_strided(x, size=new_shape, stride=new_stride) torch_fn = lambda x: torch.bmm(x, x) # noqa: E731 np_fn = lambda x: np.matmul(x, x) # noqa: E731 self.compare_with_numpy(torch_fn, np_fn, sx) torch_fn = lambda x: torch.mm(x, x) # noqa: E731 self.compare_with_numpy(torch_fn, np_fn, sx[0]) @tf32_on_and_off(0.005) def test_mm_empty_inputs_mixed_dtype_errors(self, device): a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device) b = torch.randn(10, 20, dtype=torch.float32, device=device) with self.assertRaisesRegex( RuntimeError, "expected .* and .* to have the same dtype, but got:" ): torch.mm(a, b) @tf32_on_and_off(0.005) def test_matmul_45724(self, device): # https://github.com/pytorch/pytorch/issues/45724 a = torch.rand(65537, 22, 64, device=device, dtype=torch.half) b = torch.rand(65537, 64, 22, device=device, dtype=torch.half) c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device) cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).half() torch.matmul(a, b, out=c) self.assertEqual(c, cpu_result) @dtypes( torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64, ) @tf32_on_and_off(0.005) def test_baddbmm_input_dtypes_compatibility(self, device, dtype): batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device) batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device) input_tensor = torch.rand((1, 2, 2), device=device).to(dtype) if dtype != torch.float32: with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"): torch.baddbmm(input_tensor, batch1, batch2, beta=0.0) else: out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan) y_ref = torch.bmm(batch1, batch2) torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out) self.assertEqual(out, y_ref) @dtypes(torch.float) @tf32_on_and_off(0.005) def test_baddbmm_nan_input_with_zero_beta(self, device, dtype): for shape in [[3, 2, 2], [2, 20, 20]]: mat1, mat2 = ( torch.randn(shape, dtype=dtype, device=device) for _ in range(2) ) inputs = [ torch.randn(shape, dtype=dtype, device=device), torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan), ] outs = [ None, torch.randn(shape, dtype=dtype, device=device), torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan), ] options = itertools.product(inputs, outs) for input, out in options: y_ref = torch.bmm(mat1, mat2) y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out) self.assertEqual(y_ref, y) @precisionOverride({torch.double: 1e-6}) @dtypes(torch.float, torch.double) @tf32_on_and_off(0.005) def test_addmm_sizes(self, device, dtype): for m in [0, 1, 25]: for n in [0, 1, 10]: for k in [0, 1, 8]: M = torch.randn(n, m, device=device).to(dtype) m1 = torch.randn(n, k, device=device).to(dtype) m2 = torch.randn(k, m, device=device).to(dtype) self._test_addmm_addmv(torch.addmm, M, m1, m2) m1 = torch.randn(n, k + 1, device=device).to(dtype) m2 = torch.randn(k, m, device=device).to(dtype) self.assertRaisesRegex( RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.addmm(M, m1, m2), ) self.assertRaisesRegex( RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2) ) @precisionOverride( { torch.double: 1e-6, torch.float: 1e-4, torch.bfloat16: 5e-2, torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8, } ) @dtypes(torch.double, torch.float32, torch.bfloat16, torch.half) @tf32_on_and_off(0.05) def test_addmm_gelu(self, device, dtype): self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype) @precisionOverride( { torch.double: 1e-6, torch.float: 1e-4, torch.bfloat16: 5e-2, torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8, } ) @dtypes(torch.double, torch.float32, torch.bfloat16, torch.half) @tf32_on_and_off(0.05) def test_addmm_relu(self, device, dtype): self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype) @dtypes(torch.float, torch.bfloat16, torch.half) @tf32_on_and_off(0.005) def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype): # tests (o, s)*(s). o is output size, s is summed size. o = 5 s = 3 a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s) x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype) y_data = torch.ones(o, device=device, dtype=dtype) def _test(row_major, incx, incy, lda_tail): if row_major: a_storage = torch.full( (o, s + lda_tail), float("nan"), device=device, dtype=dtype ) else: a_storage = torch.full( (s, o + lda_tail), float("nan"), device=device, dtype=dtype ).permute(1, 0) a = a_storage[:o, :s].copy_(a_data) x_storage = torch.full((s, incx), float("nan"), device=device, dtype=dtype) x = x_storage[:, 0].copy_(x_data) y_storage = torch.full((o, incy), float("nan"), device=device, dtype=dtype) y = y_storage[:, 0].copy_(y_data) self._test_addmm_addmv(torch.addmv, y, a, x) for row_major, incx, incy, lda_tail in itertools.product( (False, True), (1, 2), (1, 2), (0, 1) ): _test(row_major, incx, incy, lda_tail) @precisionOverride( { torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8, } ) @dtypes(torch.double, torch.bfloat16, torch.half, torch.float32) @tf32_on_and_off(0.005) def test_corner_cases_of_cublasltmatmul(self, device, dtype): # common case M = torch.randn(128, device=device).to(dtype) m1 = torch.randn(2048, 2400, device=device).to(dtype) m2 = torch.randn(128, 2400, device=device).to(dtype) torch.nn.functional.linear(m1, m2, M) # Ntrans_B has ld >> rows m1 = torch.rand([128, 2400]).to(dtype).to(device).t() m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340] M = torch.rand([128]).to(dtype).to(device) torch.addmm(M, m2.t(), m1) # trans_A has ld >> rows m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t() m2 = torch.randn(2048, 2400, device=device).to(dtype) M = torch.rand([128]).to(dtype).to(device) torch.addmm(M, m2, m1) # large tensor dim > 65535 M = torch.randn(16, device=device).to(dtype) m1 = torch.randn(32, 131071, device=device).to(dtype) m2 = torch.randn(16, 131071, device=device).to(dtype) torch.nn.functional.linear(m1, m2, M) def test_blas_empty(self, device): def fn(torchfn, *args, test_out=False, **kwargs): def call_torch_fn(*args, **kwargs): return torchfn( *tuple( torch.randn(shape, device=device) if isinstance(shape, tuple) else shape for shape in args ), **kwargs, ) result = call_torch_fn(*args, **kwargs) if not test_out: return result else: out = torch.full_like(result, math.nan) out1 = call_torch_fn(*args, **kwargs, out=out) # noqa: F841 # FIXME(rec): should this return out1? return out # mm, addmm self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape) self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape) self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape) self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape) self.assertEqual( torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6)) ) self.assertEqual( torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6), test_out=True), ) self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape) self.assertEqual((0, 1), fn(torch.addmm, (1,), (0, 17), (17, 1)).shape) t = torch.randn((5, 6), device=device) self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6))) self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True)) # mv, addmv self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape) self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape) self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,))) self.assertEqual( torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True) ) self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape) t = torch.randn((3,), device=device) self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,))) self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True)) # bmm, baddbmm self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape) self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape) self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape) self.assertEqual( torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6)) ) self.assertEqual( torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True), ) self.assertEqual( (0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape ) self.assertEqual( (3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape ) self.assertEqual( (0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape ) self.assertEqual( (3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape ) c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5) self.assertEqual( -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2) ) # Issue #33467 self.assertEqual( -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True) ) # Issue #33467 # addbmm self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape) self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape) t = torch.randn((5, 6), device=device) self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6))) self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True)) # matmul self.assertEqual(torch.tensor(0.0, device=device), fn(torch.matmul, (0,), (0,))) self.assertEqual( torch.tensor(0.0, device=device), fn(torch.matmul, (0,), (0,), test_out=True), ) self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape) self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape) self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape) self.assertEqual( torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4)), ) self.assertEqual( torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True), ) # dot self.assertEqual(torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,))) self.assertEqual( torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,), test_out=True) ) @tf32_on_and_off(0.005) def test_large_bmm_backward(self, device): A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT B = torch.randn([1, 1024, 65536], device=device, requires_grad=True) G = torch.randn([1024, 2, 65536], device=device) # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM (A @ B).backward(G) @tf32_on_and_off(0.005) def test_large_bmm_mm_backward(self, device): A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT B = torch.randn([1024, 65536], device=device, requires_grad=True) G = torch.randn([1024, 2, 65536], device=device) # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM (A @ B).backward(G) def check_single_matmul(self, x, y): def assertEqual(answer, expected): if x.dtype.is_floating_point or x.dtype.is_complex: k = max(x.shape[-1], 1) # Scale the atol with the size of the matrix self.assertEqual( answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}", atol=k * 5e-5, rtol=1e-4, ) else: self.assertEqual( answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}" ) # test x @ y expected = np.matmul(x.cpu(), y.cpu()) ans = torch.matmul(x, y) self.assertTrue(ans.is_contiguous()) assertEqual(ans, expected) # test out out = torch.empty_like(ans) ans = torch.matmul(x, y, out=out) self.assertIs(ans, out) self.assertTrue(ans.is_contiguous()) assertEqual(ans, expected) def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3): """ Generates sequences of tuples (x, y) of with size(x) = x_dim and size(y) <= y_dim that are compatible wrt. matmul """ assert x_dim >= 1 assert y_dim >= 2 x = x_dim for y in range(1, y_dim + 1): for batch, mn in product( product(range(batch_size), repeat=max(x - 2, y - 2, 0)), product(range(matrix_size), repeat=min(y, 2)), ): if x == 1: size_x = mn[:1] size_y = batch + mn yield size_x, size_y else: for k in range(matrix_size): size_x = (k,) + mn[:1] if x > 2: size_x = batch[-(x - 2) :] + size_x size_y = mn if y > 2: size_y = batch[-(y - 2) :] + size_y yield size_x, size_y @dtypes(torch.float) def test_matmul_small_brute_force_1d_Nd(self, device, dtype): make_arg = partial(make_tensor, device=device, dtype=dtype) for (size_x, size_y), nctg_x, nctg_y in product( self.gen_sizes_matmul(1), (True, False), (True, False) ): x = make_arg(size_x, noncontiguous=nctg_x) y = make_arg(size_y, noncontiguous=nctg_y) self.check_single_matmul(x, y) @dtypes(torch.float) def test_matmul_small_brute_force_2d_Nd(self, device, dtype): make_arg = partial(make_tensor, device=device, dtype=dtype) for (size_x, size_y), nctg_x, nctg_y in product( self.gen_sizes_matmul(2), (True, False), (True, False) ): x = make_arg(size_x, noncontiguous=nctg_x) y = make_arg(size_y, noncontiguous=nctg_y) self.check_single_matmul(x, y) @dtypes(torch.float) def test_matmul_small_brute_force_3d_Nd(self, device, dtype): make_arg = partial(make_tensor, device=device, dtype=dtype) for (size_x, size_y), nctg_x, nctg_y in product( self.gen_sizes_matmul(3), (True, False), (True, False) ): x = make_arg(size_x, noncontiguous=nctg_x) y = make_arg(size_y, noncontiguous=nctg_y) self.check_single_matmul(x, y) @dtypes(torch.float) @tf32_on_and_off(0.005) def test_matmul_out_kernel_errors_with_autograd(self, device, dtype): a = torch.empty( (256, 512), device=device, dtype=dtype, requires_grad=True ).unsqueeze(0) b = torch.empty( (4, 128, 512), device=device, dtype=dtype, requires_grad=True ).transpose(-1, -2) c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0) torch.matmul(a.detach(), b.detach(), out=c) with self.assertRaisesRegex( RuntimeError, "functions with out=... arguments don't support automatic differentiation", ): torch.matmul(a, b, out=c) with torch.no_grad(): torch.matmul(a, b, out=c) def _group_quantize_tensor(self, w, n_bit=4, q_group_size=16): # w [k, n] = [32, 48] assert w.dim() == 2 # w [n, k] = [48, 32] w = w.transpose(0, 1).contiguous() assert q_group_size > 1 assert w.shape[-1] % q_group_size == 0 # to_quant: [n * k / group_size, group_size] to_quant = w.reshape(-1, q_group_size) assert torch.isnan(to_quant).sum() == 0 max_val = to_quant.amax(dim=1, keepdim=True) min_val = to_quant.amin(dim=1, keepdim=True) max_int = 2**n_bit - 1 min_int = 0 scales = (max_val - min_val).clamp(min=1e-6) / max_int assert torch.isnan(scales).sum() == 0 zeros = min_int - min_val.div(scales).round() zeros = torch.clamp(zeros, min_int, max_int) zeros = zeros.to(torch.int8) assert torch.isnan(zeros).sum() == 0 out = to_quant.div(scales).add(zeros).round().clamp_(min_int, max_int) assert torch.isnan(out).sum() == 0 # [n, k] out = out.to(dtype=torch.int32).reshape(w.shape) if out.device != torch.device("cpu"): out = (out[::, 1::2] << 4 | out[::, 0::2]).to(torch.uint8) # Scales and zeros for the same q-group should be contiguous, so we can # load as a 32-bit word scales = scales.view(w.shape[0], -1).transpose(0, 1).contiguous() zeros = zeros.view(w.shape[0], -1).transpose(0, 1).contiguous() return out, scales, zeros @parametrize("m", [128]) @parametrize("k", [512, 1024]) @parametrize("n", [512, 1024]) def test__int4_mm(self, device, m, k, n): q_group = 32 inner_k_tiles = 2 torch.manual_seed(1) a_bf16 = torch.rand((m, k), dtype=torch.float32, device=device) b_bf16 = torch.rand((k, n), dtype=torch.float32, device=device) def convert_weight_to_int4pack(b): # b_uint8 [n, k //2] b_uint8, scales, zeros = self._group_quantize_tensor( b, n_bit=4, q_group_size=q_group ) # b_int4pack [k//8, n] b_int4pack = torch._convert_weight_to_int4pack(b_uint8, inner_k_tiles) return b_int4pack, scales, zeros def weight_int4pack_mm(a, b_int4pack, qscale, qzeros): return torch._weight_int4pack_mm_with_scales_and_zeros( a, b_int4pack, q_group, qscale, qzeros ) b_int4pack, b_scales, zeros_int8 = convert_weight_to_int4pack(b_bf16) for dtype in [torch.bfloat16, torch.float16]: a = a_bf16.to(dtype=dtype) b = b_bf16.to(dtype=dtype) b_scales = b_scales.to(dtype=dtype) ref = torch.mm(a, b) res = weight_int4pack_mm(a, b_int4pack, b_scales, zeros_int8) mean_err = ((res - ref).abs() / ref).mean() self.assertTrue(mean_err < 0.05) def test_mm_with_offset(self, device): from torch._dynamo.testing import rand_strided offset = 997 a = rand_strided( (2, 4, 128, 64), (65536, 16384, 64, 1), dtype=torch.float16, device=device, extra_size=offset, ) a = a.as_strided((2, 4, 128, 64), (65536, 16384, 64, 1), storage_offset=offset) b = rand_strided( (2, 4, 64, 256), (65536, 16384, 1, 64), dtype=torch.float16, device=device ) gpu_out = torch.matmul(a, b) cpu_out = torch.matmul(a.cpu(), b.cpu()) self.assertEqual(gpu_out.cpu(), cpu_out) instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True) if __name__ == "__main__": run_tests()