matmul

fun DTensor.matmul(right: DTensor): DTensor

Matrix product of two tensors. See https://pytorch.org/docs/stable/generated/torch.matmul.html for the specification of how operand ranks determine the result shape.
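The linked torch.matmul specification determines the result shape from the ranks of the two operands. The following is a minimal sketch of those shape rules, assuming this overload follows them exactly. matmulShape is a hypothetical helper that works on plain lists of dimensions. It is not part of this library, and it computes only shapes, not values.

```kotlin
// Hypothetical helper (not a library API): result shape of torch.matmul-style multiplication.
fun matmulShape(left: List<Int>, right: List<Int>): List<Int> {
    require(left.isNotEmpty() && right.isNotEmpty()) { "matmul needs inputs of at least 1-D" }
    // Promote 1-D operands: a left vector becomes a row (1, m), a right vector a column (m, 1).
    val l = if (left.size == 1) listOf(1) + left else left
    val r = if (right.size == 1) right + listOf(1) else right
    val n = l[l.size - 2]
    val m = l[l.size - 1]
    val m2 = r[r.size - 2]
    val p = r[r.size - 1]
    require(m == m2) { "inner dimensions differ: $m vs $m2" }
    // Broadcast the leading (batch) dimensions, aligned from the right.
    val lb = l.dropLast(2)
    val rb = r.dropLast(2)
    val rank = maxOf(lb.size, rb.size)
    val batch = (0 until rank).map { i ->
        // Missing leading dimensions behave like size 1 (getOrElse covers negative indices).
        val x = lb.getOrElse(lb.size - rank + i) { 1 }
        val y = rb.getOrElse(rb.size - rank + i) { 1 }
        require(x == y || x == 1 || y == 1) { "batch dims $x and $y do not broadcast" }
        if (x == 1) y else x
    }
    // Remove the dimensions that were only added by the 1-D promotion.
    val out = batch.toMutableList()
    if (left.size != 1) out.add(n)
    if (right.size != 1) out.add(p)
    return out
}
```

For example, a (10, 3, 4) tensor multiplied by a vector of length 4 yields shape (10, 3), because the column dimension appended to the vector is removed again. Two vectors yield an empty shape, that is, a scalar.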


fun DTensor.matmul(right: DTensor, a: Shape, b: Shape, c: Shape, d: Shape): DTensor

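Generalized matrix multiplication of two tensors.

Given lists of integer dimensions A, B, C, and D (passed as a, b, c, and d):

Takes the receiver (left) of Shape(A, B, C) and right of Shape(A, C, D), and returns an output of Shape(A, B, D), where Shape(X, Y, Z) denotes the concatenation of the dimension lists X, Y, and Z. The dimensions in C are contracted (summed over), so they do not appear in the output.

A is the list of batch dimensions. Both inputs share A, and it is carried through unchanged to the output.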
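The shape rule above can be illustrated with a minimal sketch. It assumes both operands are stored as flat, row-major DoubleArrays rather than DTensors. generalizedMatmul is a hypothetical helper, not this library's implementation. Because the layout is row-major, each group A, B, C, and D collapses to a single extent, so the operation reduces to a batched 2-D matrix multiply that sums over the C index.

```kotlin
// Hypothetical sketch, not the library's implementation: left has Shape(a, b, c),
// right has Shape(a, c, d), and both are stored flat in row-major order.
fun generalizedMatmul(
    left: DoubleArray, right: DoubleArray,
    a: List<Int>, b: List<Int>, c: List<Int>, d: List<Int>
): DoubleArray {
    // In row-major order, each group of dimensions flattens into one extent (empty group -> 1).
    val sa = a.fold(1) { acc, x -> acc * x }
    val sb = b.fold(1) { acc, x -> acc * x }
    val sc = c.fold(1) { acc, x -> acc * x }
    val sd = d.fold(1) { acc, x -> acc * x }
    require(left.size == sa * sb * sc) { "left does not have Shape(a, b, c)" }
    require(right.size == sa * sc * sd) { "right does not have Shape(a, c, d)" }
    val out = DoubleArray(sa * sb * sd) // Shape(a, b, d)
    for (i in 0 until sa) {             // batch index over A
        for (j in 0 until sb) {         // row index over B
            for (k in 0 until sd) {     // column index over D
                var sum = 0.0
                for (l in 0 until sc) { // contracted index over C
                    sum += left[(i * sb + j) * sc + l] * right[(i * sc + l) * sd + k]
                }
                out[(i * sb + j) * sd + k] = sum
            }
        }
    }
    return out
}
```

With A, B, and D empty and C = [2, 2], every dimension is contracted, and the call returns a one-element array holding the dot product of the two flattened inputs. With A empty and B = C = D = [2], it is an ordinary 2 x 2 matrix product.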