Multiplies matrix a by matrix b, producing a * b.
tf.compat.v1.linalg.matmul(
a,
b,
transpose_a=False,
transpose_b=False,
adjoint_a=False,
adjoint_b=False,
a_is_sparse=False,
b_is_sparse=False,
output_type=None,
grad_a=False,
grad_b=False,
name=None
)
The inputs must, following any transpositions, be tensors of rank >= 2 where the inner 2 dimensions specify valid matrix multiplication dimensions, and any further outer dimensions specify matching batch size.
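The shape rule above can be checked with a small sketch. It assumes TensorFlow 2.x running eagerly and uses tf.matmul, as the examples on this page do; the tensors are made up for illustration.

```python
import tensorflow as tf

# A batch of 4 matrices, each 2x3, times a batch of 4 matrices, each 3x5.
# The inner dimensions (3 and 3) match, and so do the batch dimensions (4 and 4).
a = tf.ones([4, 2, 3])
b = tf.ones([4, 3, 5])
c = tf.matmul(a, b)
print(c.shape)  # (4, 2, 5)

# Mismatched inner dimensions (3 vs. 4) are rejected.
# Eager mode raises InvalidArgumentError; graph construction raises ValueError.
try:
    tf.matmul(tf.ones([2, 3]), tf.ones([4, 5]))
    inner_mismatch_raised = False
except (tf.errors.InvalidArgumentError, ValueError):
    inner_mismatch_raised = True
print(inner_mismatch_raised)
```

The result keeps the batch dimension and takes its rows from a and its columns from b.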
Both matrices must be of the same type. The supported types are: bfloat16, float16, float32, float64, int32, int64, complex64, complex128.
Either matrix can be transposed or adjointed (conjugated and transposed) on the fly by setting the corresponding flag to True. These flags are False by default.
If one or both of the matrices contain a lot of zeros, a more efficient multiplication algorithm can be used by setting the corresponding a_is_sparse or b_is_sparse flag to True. These are False by default.
This optimization is only available for plain matrices (rank-2 tensors) with datatypes bfloat16 or float32.
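As a sketch of the sparsity hint, assuming TensorFlow 2.x in eager mode: the hint only selects a different algorithm, so the result should match the plain product. The matrices below are made up for illustration and are far too small for any speedup to show.

```python
import tensorflow as tf

# A mostly-zero float32 matrix of rank 2, as the optimization requires.
a = tf.constant([[0., 0., 3.],
                 [0., 0., 0.],
                 [1., 0., 0.]])
b = tf.constant([[1., 2.],
                 [3., 4.],
                 [5., 6.]])

dense = tf.matmul(a, b)
hinted = tf.matmul(a, b, a_is_sparse=True)  # a is still a dense tf.Tensor

# Both calls compute the same product; only the algorithm differs.
print(dense.numpy())
print(hinted.numpy())
```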
A simple 2-D tensor matrix multiplication:
a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
a # 2-D tensor
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[1, 2, 3],
[4, 5, 6]], dtype=int32)>
b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2])
b # 2-D tensor
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[ 7, 8],
[ 9, 10],
[11, 12]], dtype=int32)>
c = tf.matmul(a, b)
c # `a` * `b`
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[ 58, 64],
[139, 154]], dtype=int32)>
A batch matrix multiplication with batch shape [2]:
import numpy as np
a = tf.constant(np.arange(1, 13, dtype=np.int32), shape=[2, 2, 3])
a # 3-D tensor
<tf.Tensor: shape=(2, 2, 3), dtype=int32, numpy=
array([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]], dtype=int32)>
b = tf.constant(np.arange(13, 25, dtype=np.int32), shape=[2, 3, 2])
b # 3-D tensor
<tf.Tensor: shape=(2, 3, 2), dtype=int32, numpy=
array([[[13, 14],
[15, 16],
[17, 18]],
[[19, 20],
[21, 22],
[23, 24]]], dtype=int32)>
c = tf.matmul(a, b)
c # `a` * `b`
<tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
array([[[ 94, 100],
[229, 244]],
[[508, 532],
[697, 730]]], dtype=int32)>
Since Python 3.5, the @ operator is supported (see PEP 465). In TensorFlow, it simply calls the tf.matmul() function, so the following lines are equivalent:
d = a @ b @ [[10], [11]]
d = tf.matmul(tf.matmul(a, b), [[10], [11]])
Args | |
---|---|
a | tf.Tensor of type float16, float32, float64, int32, complex64, complex128 and rank > 1. |
b | tf.Tensor with same type and rank as a. |
transpose_a | If True, a is transposed before multiplication. |
transpose_b | If True, b is transposed before multiplication. |
adjoint_a | If True, a is conjugated and transposed before multiplication. |
adjoint_b | If True, b is conjugated and transposed before multiplication. |
a_is_sparse | If True, a is treated as a sparse matrix. Note that this does not support tf.sparse.SparseTensor; it just makes optimizations that assume most values in a are zero. See tf.sparse.sparse_dense_matmul for some support for tf.sparse.SparseTensor multiplication. |
b_is_sparse | If True, b is treated as a sparse matrix. Note that this does not support tf.sparse.SparseTensor; it just makes optimizations that assume most values in b are zero. See tf.sparse.sparse_dense_matmul for some support for tf.sparse.SparseTensor multiplication. |
output_type | The output datatype, if needed. Defaults to None, in which case the output type is the same as the input type. Currently only works when the input tensors are of type (u)int8, in which case output_type can be int32. |
grad_a | Set it to True to hint that Tensor a is for the backward pass. |
grad_b | Set it to True to hint that Tensor b is for the backward pass. |
name | Name for the operation (optional). |
Returns | |
---|---|
A tf.Tensor of the same type as a and b where each inner-most matrix is the product of the corresponding matrices in a and b, e.g. if all transpose or adjoint attributes are False: output[..., i, j] = sum_k (a[..., i, k] * b[..., k, j]), for all indices i, j. |
Note | This is a matrix product, not an element-wise product. |
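To make the distinction concrete, here is a short sketch, assuming TensorFlow 2.x in eager mode. It contrasts tf.matmul with the element-wise * operator and writes out the sum over the shared index k by hand; the values are illustrative.

```python
import tensorflow as tf

a = tf.constant([[1, 2], [3, 4]])
b = tf.constant([[5, 6], [7, 8]])

matrix_product = tf.matmul(a, b)  # rows of a combined with columns of b
elementwise = a * b               # entries multiplied position by position

# The matrix product written as an explicit sum over the shared index k:
# (i, k, 1) * (1, k, j) broadcasts to (i, k, j), then k is summed out.
by_hand = tf.reduce_sum(a[:, :, None] * b[None, :, :], axis=1)

print(matrix_product.numpy())  # [[19 22] [43 50]]
print(elementwise.numpy())     # [[ 5 12] [21 32]]
```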