Thanks for the mention!
Regarding the naming of axes: einx.vmap doesn't know anything about my_op, other than that it has the signature "m, n, m n -> " in the first case and "a, b, c d -> " in the second case. Both are valid if you pass the right input shapes. You get different behavior for incorrect input shapes, though: in the first case, einx will raise an exception before calling my_op, because shape resolution fails (e.g. due to multiple different values for m). In the second case, einx will assume the shapes are correct (it has no way of knowing they aren't before calling my_op), so the error will be raised somewhere inside my_op.
The decorator for einx.vmap is a good point. I only realized while typing the above comment that wrapping is a nice way of writing the operation in the first place. :D
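For instance, a toy version of such a wrapper (the vmapped decorator and its loop-based batching are made up for illustration; einx.vmap would dispatch to the backend's native vmap instead of looping):

```python
import numpy as np

def vmapped(signature):
    # Hypothetical decorator: attach an einx-style signature to an op and
    # map it over a leading batch axis with a plain Python loop.
    def wrap(op):
        def batched(*tensors):
            n = tensors[0].shape[0]
            return np.stack([op(*(t[i] for t in tensors)) for i in range(n)])
        batched.signature = signature
        return batched
    return wrap

@vmapped("m, n, m n -> ")
def my_op(x, y, A):
    # one scalar per batch element: x (m,), y (n,), A (m, n)
    return x @ A @ y

out = my_op(np.ones((2, 3)), np.ones((2, 4)), np.ones((2, 3, 4)))
# out.shape == (2,)
```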
I did at some point consider adding a new symbol (like ":") that would act as a new axis with a unique name, but I have been hesitant so far, since it adds to the complexity of the notation. There are a bunch of ideas for improving quality-of-life in einx, but so far I've tried to err on the side of less complexity (and there are probably some cases where I should have adhered to this more), to keep a low barrier of entry and to avoid ending up with the mess of many different rules that classical tensor notation is (you made good points about that here...). There are indeed cases where the operation depends on the names in the brackets, e.g. einx.dot("a [b], [b c], [c] d", ...), so the ":" would be an additional variant rather than a simplification. What I also like better about using actual names is that they convey semantics and indicate corresponding axes in consecutive operations (although this is not enforced strictly).
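For reference, reading the brackets in that example as the contracted axes (b shared by the first two inputs, c by the last two), the call corresponds roughly to an einsum; the concrete shapes below are made up:

```python
import numpy as np

x = np.random.rand(2, 3)  # "a b"
y = np.random.rand(3, 5)  # "b c"
z = np.random.rand(5, 7)  # "c d"

# einx.dot("a [b], [b c], [c] d", x, y, z) contracts b and c, which only
# works because the bracketed names say *which* axes line up across inputs;
# an anonymous ":" could not express this. Roughly the same as:
out = np.einsum("ab,bc,cd->ad", x, y, z)
# out.shape == (2, 7)
```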
Defining the shapes of a vmapped operation in the decorator sounds like a good idea. It would probably require a kind of pattern matching to align the inner and outer expression (e.g. to also allow for n-dimensional inputs or variadic arguments to the custom, vmapped function).
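As a rough sketch of what that alignment might look like (split_pattern is hypothetical; real matching would also need to handle brackets, n-dimensional batch axes, and variadic arguments):

```python
def split_pattern(outer, inner):
    # Hypothetical: align the trailing axes of the outer expression with the
    # inner signature declared on the op; the remaining leading axes are the
    # batch axes to vmap over.
    outer_axes = outer.split()
    inner_axes = inner.split()
    k = len(outer_axes) - len(inner_axes)
    if k < 0 or outer_axes[k:] != inner_axes:
        raise ValueError(f"inner '{inner}' does not match trailing axes of '{outer}'")
    return outer_axes[:k]

split_pattern("b m n", "m n")  # ['b']: vmap once over b
split_pattern("b1 b2 m", "m")  # ['b1', 'b2']: vmap over two batch axes
```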