KIDS  ver-0.0.1
KIDS : Kernel Integrated Dynamics Simulator
Loading...
Searching...
No Matches
Einsum.cpp
Go to the documentation of this file.
1#include "kids/Einsum.h"
2
3#include <algorithm>
4#include <sstream>
5
6namespace PROJECT_NS {
7
8DimenHelper::DimenHelper(const std::string& esshape, std::vector<EinsumIdx>& idx_vec)
9 : esshape_rank{esshape.size()}, total_esidx{idx_vec.size()} {
10 dims.resize(esshape_rank);
11 ldims.resize(esshape_rank);
12 es_ldims.resize(total_esidx);
13 mapldims.resize(total_esidx);
14
15 // calculate the normal leading dimensions of the tensor
16 for (int k = esshape_rank - 1, lastsize = 1, lastldim = 1; k >= 0; --k) {
17 ldims[k] = lastsize * lastldim;
18 int q = -1;
19 while (idx_vec[++q].label != esshape[k]) {};
20 dims[k] = idx_vec[q].dim;
21 lastsize = dims[k];
22 lastldim = ldims[k];
23 }
24
25 // calculate the leading dimensions of the tensor represented in einsum indexes
26 for (int i = 0; i < total_esidx; ++i) {
27 char c = idx_vec[i].label;
28 es_ldims[i] = 0;
29 for (int k = esshape_rank - 1; k >= 0; --k) {
30 if (c == esshape[k]) es_ldims[i] += ldims[k];
31 }
32 }
33
34 // calculate sum of several leading dimensions as the shift step
35 for (int i = 0; i < total_esidx; ++i) { //
36 mapldims[i] = es_ldims[i];
37 for (int k = i + 1; k < total_esidx; ++k) { //
38 mapldims[i] -= (idx_vec[k].dim - 1) * es_ldims[k];
39 }
40 }
41}
42
// Parse an einsum expression, check it against the supplied tensor shapes, and build the
// einsum index system (einsum_idxs) plus one DimenHelper per input tensor.
//
// einsum_expression : rule text; ',' separates input operands, "[name]" denotes a fixed
//                     label, '>' starts the explicit output labels, '-' is skipped, and
//                     every other character must be in [a-z].
// shape_inputs      : one shape per input operand, indexed in operand order.
// shape_output      : shape of the output; when empty, output extents are not checked.
//
// Throws std::runtime_error when the number of labels of an operand closed by ',' differs
// from the rank of its shape, when a label is reused with a different extent, when a
// character outside [a-z] is used as a normal label, when more than 10 fixed labels are
// defined, or when an output label does not occur among the inputs.
//
// NOTE(review): shape_inputs[ishape] and shape_inputs[ishape][esshape.size()] are indexed
// without a bounds check, so an expression with more operands or more labels than the
// supplied shapes reads out of range before any mismatch is reported.
43EinsumHelper::EinsumHelper(const std::string& einsum_expression, //
44 std::vector<std::vector<std::size_t>> shape_inputs, //
45 std::vector<std::size_t> shape_output //
46) {
47 std::stringstream ss{einsum_expression};
48 std::string esshape = "";
49 int ishape = 0;
50 bool auto_deduction = true;
// Formatted extraction into a char skips whitespace by default, so blanks in the
// expression are normally consumed before the switch is reached.
51 for (char c; ss >> c;) {
52 switch (c) {
// ',' closes the current operand: its label count must equal the rank of its shape.
53 case ',': {
54 if (esshape.size() != shape_inputs[ishape].size()) {
55 throw std::runtime_error("mismatch einsum rule with shape!");
56 }
57 esshape_inputs.push_back(esshape);
58 esshape = "";
59 ishape++;
60 break;
61 }
// "[name]" is a fixed label. Each distinct name is mapped to the character '0' + position
// in fixed_label_names, and that character is used as the label inside esshape.
// NOTE(review): if the closing ']' is missing, the inner loop simply stops at end of input
// and the partial name is accepted.
62 case '[': {
63 std::string label_name = "";
64 while (ss >> c) {
65 if (c == ']') break;
66 label_name += c;
67 }
68 auto it = std::find(fixed_label_names.begin(), fixed_label_names.end(), label_name);
69 auto found = (it != fixed_label_names.end());
70 int ipos = found ? int(it - fixed_label_names.begin()) : fixed_label_names.size();
71 c = (char) ((int) '0' + ipos);
72
73 if (!found) {
// First occurrence: register the index (second constructor argument is 0 here, whereas
// normal labels below start at 1 — presumably this is the cnt field) and remember the name.
74 einsum_idxs.push_back(EinsumIdx(c, 0, shape_inputs[ishape][esshape.size()], 0));
75 fixed_label_names.push_back(label_name);
76 } else {
// Repeated fixed label: its extent must agree with the one recorded earlier.
77 auto it2 = std::find_if(einsum_idxs.begin(), einsum_idxs.end(),
78 [c](EinsumIdx idx) { return c == idx.label; });
79 if (it2->dim != shape_inputs[ishape][esshape.size()]) {
80 // std::cout << c << shape_inputs[ishape][esshape.size()] << "\n";
81 throw std::runtime_error("bad einsum shape!");
82 }
83 }
84 esshape += c;
85
// Checked after the push, so up to 10 names (characters '0'..'9') are accepted.
86 if (fixed_label_names.size() > 10) throw std::runtime_error("too many fixed einsum idx!");
87 break;
88 }
89 case ' ':
90 case '-':
91 break;
// '>' switches to a user-specified output: the rest of the stream is consumed here as the
// output labels, each of which must already exist in einsum_idxs.
92 case '>': {
93 auto_deduction = false; // then by user's deduction
94 esshape_output = "";
95 while (ss >> c) {
96 if ((int) c < (int) 'a' || (int) c > (int) 'z') {
97 throw std::runtime_error("only allowed [a-z] for normal einsum label");
98 }
99 auto it = std::find_if(einsum_idxs.begin(), einsum_idxs.end(), //
100 [c](EinsumIdx idx) { return c == idx.label; });
101 if (it != einsum_idxs.end()) {
// NOTE(review): shape_output is indexed by the number of output labels read so far; an
// output with more labels than shape_output has entries is read out of range before the
// size check that follows the loop.
102 if (shape_output.size() > 0 && it->dim != shape_output[esshape_output.size()]) {
103 throw std::runtime_error("bad einsum shape!");
104 }
105 it->cnt = 1;
106 } else {
107 throw std::runtime_error("bad einsum einsum_expression!");
108 }
109 esshape_output += c;
110 }
111 if (esshape_output.size() != shape_output.size() && shape_output.size() != 0) {
112 // if shape_output.size() == 0, we don't check
113 throw std::runtime_error("mismatch einsum rule with shape!");
114 }
115 break;
116 }
// Normal label in [a-z]: the first occurrence registers the index with cnt 1; every later
// occurrence must have the same extent and increments cnt.
117 default: {
118 if ((int) c < (int) 'a' || (int) c > (int) 'z') {
119 throw std::runtime_error("only allowed [a-z] for normal einsum label");
120 }
121 auto it = std::find_if(einsum_idxs.begin(), einsum_idxs.end(), //
122 [c](EinsumIdx idx) { return c == idx.label; });
123 if (it != einsum_idxs.end()) {
124 if (it->dim == shape_inputs[ishape][esshape.size()]) {
125 it->cnt++; // update as inner label
126 } else {
127 // std::cout << c << shape_inputs[ishape][esshape.size()] << "\n";
128 throw std::runtime_error("bad einsum shape!");
129 }
130 } else {
131 einsum_idxs.push_back(EinsumIdx(c, 1, shape_inputs[ishape][esshape.size()], 0));
132 }
133 esshape += c;
134 break;
135 }
136 }
137 }
// The last operand is stored here.
// NOTE(review): unlike operands closed by ',', its label count is not compared against
// shape_inputs[ishape].size().
138 esshape_inputs.push_back(esshape);
139
140 if (auto_deduction) {
// No '>' was given: the output consists of every label whose cnt is exactly 1, in the
// order the labels were first registered.
141 esshape_output = "";
142 for (auto& idx : einsum_idxs) {
143 if (idx.cnt == 1) esshape_output += idx.label;
144 }
145 } else {
// Explicit output: a label with cnt 1 that is not in the output becomes 2, and every label
// listed in the output is set to 1 regardless of how often it occurred in the inputs.
146 for (auto& idx : einsum_idxs) {
147 if (idx.cnt == 1) idx.cnt = 2; // revise to inner label
148 for (auto& label : esshape_output) {
149 if (idx.label == label) idx.cnt = 1; // revise to outer label
150 }
151 }
152 }
// An empty output is represented by a synthetic '*' index of extent 1 and cnt 0.
153 if (esshape_output == "") { // allow return a scalar
154 einsum_idxs.push_back(EinsumIdx('*', 0, 1, 0));
155 esshape_output = "*";
156 }
// Order the index system by ascending cnt: 0 first, then 1, then larger values.
// std::sort is not stable, so the relative order of indexes with equal cnt is unspecified.
157 std::sort(einsum_idxs.begin(), einsum_idxs.end(),
158 [](EinsumIdx idx1, EinsumIdx idx2) { return idx1.cnt < idx2.cnt; });
159
// count1 = number of indexes with cnt <= 0, count2 = number with cnt <= 1, count3 = number
// of all indexes; after the sort these are prefix lengths of einsum_idxs.
// total_loop = product of the extents of all indexes with cnt > 0.
160 count1 = 0;
161 count2 = 0;
162 count3 = 0;
163 total_loop = 1;
164 for (auto& idx : einsum_idxs) {
165 if (idx.cnt <= 0) count1++;
166 if (idx.cnt <= 1) count2++;
167 if (idx.cnt > 0) total_loop *= idx.dim;
168 count3++;
169 }
170
171 // for (auto& idx : einsum_idxs) {
172 // std::cout << idx.label << ", " << idx.cnt << "," << idx.dim << ", " << idx.val << "\n";
173 // }
174 // for (auto& esshape : esshape_inputs) std::cout << esshape << "\n";
175 // std::cout << "->" << esshape_output << "\n";
176 // std::cout << "count1 : " << count1 << "\n";
177 // std::cout << "count2 : " << count2 << "\n";
178 // std::cout << "count3 : " << count3 << "\n";
179
// One DimenHelper per input operand, built against the sorted index system.
180 for (auto& esshape : esshape_inputs) { dh_inputs.push_back(DimenHelper(esshape, einsum_idxs)); }
// NOTE(review): the listing's own numbering jumps from 180 to 182 here and later skips
// 186, 188 and 189; statements of the original file may be absent from this view —
// confirm against the repository before editing this function.
182
// Cache the extent of every einsum index, in sorted order.
183 for (auto& idx : einsum_idxs) { einsum_dims.push_back(idx.dim); }
184
185 total_esidx = einsum_idxs.size();
187
190}
191
192}; // namespace PROJECT_NS
this file provides einsum operation
std::vector< EinsumIdx > einsum_idxs
the EinsumIdx System
Definition Einsum.h:182
EinsumHelper(const std::string &einsum_expression, std::vector< std::vector< std::size_t > > shape_inputs, std::vector< std::size_t > shape_output={})
Definition Einsum.cpp:43
std::vector< DimenHelper > dh_inputs
DimenHelper for input tensors.
Definition Einsum.h:190
std::vector< std::size_t > einsum_dims
each dimension of EinsumIdx System
Definition Einsum.h:183
DimenHelper dh_output
DimenHelper for the output tensor.
Definition Einsum.h:191
std::vector< std::size_t > einsum_iposes
idx placeholder for EinsumIdx System
Definition Einsum.h:193
std::vector< std::string > esshape_inputs
store einsum's strings of input tensors
Definition Einsum.h:187
std::string esshape_output
stores/deduces the einsum string of the output tensor
Definition Einsum.h:188
std::size_t total_esidx
total number of EinsumIdx in EinsumIdx System
Definition Einsum.h:179
std::size_t total_tensor
total number of tensor in einsum rule
Definition Einsum.h:180
std::vector< std::string > fixed_label_names
store for fixed labels
Definition Einsum.h:185
std::vector< std::size_t > ipos_inputs
idx placeholder for input tensors
Definition Einsum.h:194
< http://warp.povusers.org/FunctionParser/fparser.html
Definition Context.h:39
DimenHelper is a struct that provides dimension utilities on the original/einsum index for a given tensor.
Definition Einsum.h:156
std::size_t total_esidx
size of the EinsumIdx System
Definition Einsum.h:159
std::vector< std::size_t > es_ldims
leading dimensions of the tensor represented in einsum indexes
Definition Einsum.h:163
std::vector< std::size_t > mapldims
utils for sum of several leading dimensions as the shift step
Definition Einsum.h:164
std::vector< std::size_t > ldims
leading dimensions of the tensor
Definition Einsum.h:162
std::vector< std::size_t > dims
dimensions (extent of each axis) of the tensor
Definition Einsum.h:161
std::size_t esshape_rank
the rank of the tensor
Definition Einsum.h:158
EinsumIdx is a struct that stores information about an index used in an einsum operation.
Definition Einsum.h:142