9 : esshape_rank{esshape.size()}, total_esidx{idx_vec.size()} {
16 for (
int k =
esshape_rank - 1, lastsize = 1, lastldim = 1; k >= 0; --k) {
17 ldims[k] = lastsize * lastldim;
19 while (idx_vec[++q].label != esshape[k]) {};
20 dims[k] = idx_vec[q].dim;
27 char c = idx_vec[i].label;
44 std::vector<std::vector<std::size_t>> shape_inputs,
45 std::vector<std::size_t> shape_output
47 std::stringstream ss{einsum_expression};
48 std::string esshape =
"";
50 bool auto_deduction =
true;
51 for (
char c; ss >> c;) {
54 if (esshape.size() != shape_inputs[ishape].size()) {
55 throw std::runtime_error(
"mismatch einsum rule with shape!");
63 std::string label_name =
"";
71 c = (char) ((
int)
'0' + ipos);
78 [c](
EinsumIdx idx) { return c == idx.label; });
79 if (it2->dim != shape_inputs[ishape][esshape.size()]) {
81 throw std::runtime_error(
"bad einsum shape!");
86 if (
fixed_label_names.size() > 10)
throw std::runtime_error(
"too many fixed einsum idx!");
93 auto_deduction =
false;
96 if ((
int) c < (
int)
'a' || (
int) c > (
int)
'z') {
97 throw std::runtime_error(
"only allowed [a-z] for normal einsum label");
100 [c](
EinsumIdx idx) { return c == idx.label; });
102 if (shape_output.size() > 0 && it->dim != shape_output[
esshape_output.size()]) {
103 throw std::runtime_error(
"bad einsum shape!");
107 throw std::runtime_error(
"bad einsum einsum_expression!");
111 if (
esshape_output.size() != shape_output.size() && shape_output.size() != 0) {
113 throw std::runtime_error(
"mismatch einsum rule with shape!");
118 if ((
int) c < (
int)
'a' || (
int) c > (
int)
'z') {
119 throw std::runtime_error(
"only allowed [a-z] for normal einsum label");
122 [c](
EinsumIdx idx) { return c == idx.label; });
124 if (it->dim == shape_inputs[ishape][esshape.size()]) {
128 throw std::runtime_error(
"bad einsum shape!");
140 if (auto_deduction) {
147 if (idx.cnt == 1) idx.cnt = 2;
149 if (idx.label == label) idx.cnt = 1;
165 if (idx.cnt <= 0)
count1++;
166 if (idx.cnt <= 1)
count2++;
This file provides the einsum operation.
std::vector< EinsumIdx > einsum_idxs
the EinsumIdx System
EinsumHelper(const std::string &einsum_expression, std::vector< std::vector< std::size_t > > shape_inputs, std::vector< std::size_t > shape_output={})
std::vector< DimenHelper > dh_inputs
DimenHelper for input tensors.
std::vector< std::size_t > einsum_dims
Size of each dimension in the EinsumIdx System.
DimenHelper dh_output
DimenHelper for the output tensor.
std::vector< std::size_t > einsum_iposes
idx placeholder for EinsumIdx System
std::vector< std::string > esshape_inputs
Stores the einsum strings of the input tensors.
std::string esshape_output
Stores (or deduces) the einsum string for the output tensor.
std::size_t total_esidx
total number of EinsumIdx in EinsumIdx System
std::size_t total_tensor
Total number of tensors in the einsum rule.
std::vector< std::string > fixed_label_names
Storage for fixed labels.
std::vector< std::size_t > ipos_inputs
idx placeholder for input tensors
< http://warp.povusers.org/FunctionParser/fparser.html
DimenHelper is a struct that provides dimension utilities on the original/einsum indexes for a given tensor.
std::size_t total_esidx
Size of the EinsumIdx System.
std::vector< std::size_t > es_ldims
leading dimensions of the tensor represented in einsum indexes
std::vector< std::size_t > mapldims
Utility holding sums of several leading dimensions, used as the shift step.
std::vector< std::size_t > ldims
leading dimensions of the tensor
std::vector< std::size_t > dims
dimensions of the tensor
std::size_t esshape_rank
the rank of the tensor
EinsumIdx is a struct that stores information about an index used in the einsum operation.