diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-16 12:00:54 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-16 12:00:54 -0700 |
commit | de354f7ded52bfa857089769225cdf1ee694bfd6 (patch) | |
tree | f103ac674762293e4e7e0d52dbee9351ec87bbae /tests | |
parent | dede343033991c32735f01a94019e13fb4968b3c (diff) |
Added fused bias to matmullt.
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_autograd.py | 49 |
1 files changed, 38 insertions, 11 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py index f1a15f5..0cd17c9 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -1,4 +1,4 @@ -from itertools import product +from itertools import product, permutations import pytest import torch @@ -241,11 +241,20 @@ decomp = [0.0, 6.0] funcs = [(torch.matmul, bnb.matmul)] str_funcs = ["matmul"] req_grad = [(False, False), (True, False), (True, True), (False, True)] -req_grad_str = ["FF", "TF", "TT", "FT"] +req_grad = list(product([True, False], repeat=3)) +req_grad_str = [] +for c in req_grad: + strval = '' + for v in c: + if v == True: strval += 'T' + else: strval += 'F' + req_grad_str.append(strval) + transpose = [(False, True), (False, False)] str_transpose = ["NT", "NN"] dtype = [torch.float16] has_fp16_weights = [True, False] +has_bias = [True, False] values = list( product( dim1, @@ -258,6 +267,7 @@ values = list( transpose, decomp, has_fp16_weights, + has_bias ) ) str_values = list( @@ -272,18 +282,14 @@ str_values = list( str_transpose, decomp, has_fp16_weights, + has_bias ) ) -names = [ - "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format( - *vals - ) - for vals in str_values -] +names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values] @pytest.mark.parametrize( - "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", + "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias", values, ids=names, ) @@ -298,10 +304,14 @@ def test_matmullt( transpose, decomp, has_fp16_weights, + has_bias ): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda") + if has_bias == False: + req_grad = list(req_grad) + req_grad[2] = False for i in range(k): @@ -322,6 +332,11 @@ def test_matmullt( requires_grad=req_grad[1], dtype=dtype, ) + bias = None + bias2 = None + if has_bias: + bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) B2 = B.clone() @@ -342,10 +357,13 @@ def test_matmullt( if not transpose[0] and transpose[1]: out_torch = funcs[0](A, B.t()) - out_bnb = funcs[1](A, B2, state=state) + out_bnb = funcs[1](A, B2, state=state, bias=bias2) elif not transpose[0] and not transpose[1]: out_torch = funcs[0](A, B) - out_bnb = funcs[1](A, B2.t(), state=state) + out_bnb = funcs[1](A, B2.t(), state=state, bias=bias2) + + if has_bias: + out_torch += bias n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).mean().item() @@ -367,6 +385,9 @@ def test_matmullt( gradB1 = B.grad A.grad = None B.grad = None + if has_bias: + gradBias1 = bias.grad + bias.grad = None loss_torch = torch.nn.functional.mse_loss( out_torch, target @@ -376,6 +397,9 @@ def test_matmullt( gradB2 = B.grad A.grad = None B.grad = None + if has_bias: + gradBias2 = bias.grad + bias.grad = None if req_grad[0]: torch.testing.assert_allclose( @@ -397,3 +421,6 @@ def test_matmullt( torch.testing.assert_allclose( gradB1, gradB2, atol=0.18, rtol=0.3 ) + + if req_grad[2]: + torch.testing.assert_allclose(gradBias1, gradBias2) |