KIDS  ver-0.0.1
KIDS : Kernel Integrated Dynamics Simulator
Loading...
Searching...
No Matches
generate_pybind11.py
Go to the documentation of this file.
1#!/bin/python
2
3import json
4import glob
5import re
6import sys
7import os
8
# Command line: generate_pybind11.py <root> <config_file>
# A config.json found directly under <root> takes precedence over <config_file>.
if len(sys.argv) < 3:
    sys.exit('usage: %s <root> <config_file>' % sys.argv[0])

root = sys.argv[1]
config_file = sys.argv[2]

if os.path.exists(root + '/config.json'):
    config_file = root + '/config.json'

# Module-level assignment already makes `data` global; no `global` needed.
with open(config_file, 'r', encoding='utf-8') as load_f:
    data = json.load(load_f)
19
# Declaration qualifiers that parse_decro_type_flag strips from a declaration
# and reports separately (before / after the type).
decro_list = [
    'const', 'static', 'inline', 'virtual', 'constexpr',
]

# Return types for which bindings and trampoline overrides are generated;
# functions whose parsed '_type' is not listed here are skipped.
type_list = [
    'void', 'bool', 'int', 'double', 'kids_real', 'kids_complex',
    'std::string', 'std::map', 'enum', 'class', 'namespace',
]
23
def parse_decro_type_flag(line, decorators=None):
    '''
    Split one C++ declaration line into qualifiers, type, names and kind.

    Returns (decro1, typeinfo, decro2, flags, se_type):
      decro1   -- qualifiers from `decorators` seen before the type, each
                  followed by a space (e.g. 'static const ')
      typeinfo -- the type token ('' when there is none, e.g. constructors)
      decro2   -- qualifiers from `decorators` seen after the type
      flags    -- the declared name(s), e.g. ['a', 'b'] for "int a, b;"
      se_type  -- '_var'   : the line has no '('
                  '_init'  : a single token precedes '(' (constructor)
                  '_fun'   : a function with a return type
                  '_final' : a function without a return type, assumed to be
                             a (virtual) destructor

    `decorators` defaults to the module-level `decro_list`.
    Only the first <...> group is protected from being split on spaces/commas.
    '''
    if decorators is None:
        decorators = decro_list
    decro1, decro2 = '', ''

    # Collapse the first template argument list so the spaces / commas inside
    # it do not break tokenisation.  Plain str.replace is used because the
    # template text may contain regex metacharacters (e.g. '<double*>').
    res = re.search(r'<(.*?)>', line, re.S)
    old = ''
    if res:
        old = res.group(0)
        line = line.replace(old, '<__>')

    head = re.split(r'\(|;|\{', line)[0]
    tokens = [t for t in head.split(' ') if t != '']

    if '(' not in line:
        se_type = '_var'
    elif len(tokens) == 1:
        se_type = '_init'
    else:
        se_type = '_fun'

    # Separate qualifiers from the "type name ..." remainder.
    xxx = ''
    for t in tokens:
        if t in decorators:
            if xxx == '':
                decro1 += t + ' '
            else:
                decro2 += t + ' '
        else:
            xxx += t + ' '

    # Drop any initialiser, then split "type name1, name2".
    xxx2 = re.split('=', xxx)[0]
    tokens = [t for t in re.split(' |,|;', xxx2) if t != '']

    typeinfo = ''
    if len(tokens) > 1:
        typeinfo = tokens[0]
        flags = tokens[1:]
    else:
        flags = tokens[:]

    if res:
        typeinfo = typeinfo.replace('<__>', old)

    # NOTE: assumes every return-type-less non-constructor is a destructor
    # (e.g. "virtual ~X();").
    if se_type == '_fun' and typeinfo == '':
        se_type = '_final'

    return decro1, typeinfo, decro2, flags, se_type
72
def remove_comment(txt):
    '''
    Remove C/C++ comments ("// ..." and "/* ... */") from txt.

    Both comment kinds are matched in one left-to-right scan, so a "//"
    inside a block comment no longer cuts off its closing "*/" (which left
    an unterminated "/*" that could swallow following code).
    String literals are not recognised: comment markers inside quotes are
    removed as well.
    '''
    pattern = re.compile(r'//[^\n]*|/\*.*?\*/', re.S)
    return pattern.sub('', txt)
82
def safe_split_comma(line_in):
    '''
    Split line_in on commas that are not inside <...> template brackets.

    e.g. "std::map<int, double> a, int b" -> ["std::map<int, double> a", " int b"]
    Nested groups such as <std::map<int, double>> are handled.
    '''
    innermost = re.compile('(<[^<]*?>)', re.S)
    saved = []
    text = line_in

    # Hide template groups innermost-first behind numbered placeholders.
    while True:
        found = innermost.search(text)
        if not found:
            break
        start, stop = found.span()
        saved.append(text[start:stop])
        text = text[:start] + "I__%d__I" % len(saved) + text[stop:]

    text = text.replace(',', '@')

    # Restore outermost-first; restored text may re-expose inner placeholders.
    for idx in range(len(saved), 0, -1):
        text = text.replace("I__%d__I" % idx, saved[idx - 1])

    return text.split('@')
105
def unique_type(typi):
    '''
    Normalise a C++ type spelling: keep a single space between two word
    tokens (e.g. "unsigned long long", "const std::string") and drop every
    other whitespace (e.g. "double * " -> "double*").
    '''
    # Collapse whitespace runs (the former ' +?' replaced each space by itself).
    typi = re.sub(r'\s+', ' ', typi)
    # Lookarounds do not consume the words, so every gap in
    # "unsigned long long" is protected, not just every other one.
    typi = re.sub(r'(?<=\w) (?=\w)', 'I__0__I', typi)
    typi = typi.replace(' ', '')
    return typi.replace('I__0__I', ' ')
112
def parse_type_and_name(term):
    '''
    Split a single declaration term into its (normalised) type and name.

    Handles, e.g.:
        double* abc=0;
        double* abc_d;
        double* abc_d[];
        double* abc_d[2];
        double*&abc_d[];
        const std::string &abc_d[];

    Returns (typeinfo, name); an array suffix is moved into the type:
    "double* abc_d[2]" -> ("double*[2]", "abc_d").
    A term with no trailing identifier gives name ''; a bare "int" is
    returned as name 'int' with an empty type.
    '''
    # drop the initialiser ('=') and the statement end (';')
    _tmps = re.split(';|=', term)[0].strip()

    # split off a trailing array suffix, e.g. 'A[2]' -> 'A' and '[2]'
    liter = ''
    res = re.search(r'\[.*?$', _tmps, re.S)
    if res:
        i1, i2 = res.span()
        _tmps, liter = _tmps[:i1], _tmps[i1:i2]

    # the name is the trailing run of word characters; the rest is the type
    res = re.search(r'[\w]*?$', _tmps, re.S)
    i1, i2 = res.span()
    name = _tmps[i1:i2]
    typi = _tmps[:i1] + liter

    # make the type spelling unique
    typi = unique_type(typi)

    return typi, name
144
146 '''
147 this function split multiple typeinfo & variable name.
148 passed for:
149 "const int *a, b, **c;"
150 '''
151
152 _tmps = safe_split_comma(re.split(';', term)[0]);
153
154 type0, name0 = parse_type_and_name(_tmps[0])
155 arg_type = [type0]
156 arg_name = [name0]
157
158 # find base type for 0-th variable
159 res = re.search('(\W)*?$', type0, re.S)
160 if res:
161 i1, i2 = res.span()
162 type0 = type0[:i1]
163
164 # generate derived type for other variable
165 for i in range(1, len(_tmps)):
166 res = re.search('^(\W)*?\w', _tmps[i], re.S)
167 typex = type0
168 namex = _tmps[i]
169 if res:
170 i1, i2 = res.span()
171 typex += _tmps[i][i1:i2-1]
172 namex = _tmps[i][i2-1:]
173 typex = unique_type(typex)
174 arg_type += [typex]
175 arg_name += [namex]
176
177 return arg_type, arg_name
178
179
def dict_append(d, keys, c):
    '''
    Merge c into the nested dict d at the dot-separated path keys.

    Missing intermediate levels are created as empty dicts.  If the leaf is
    absent it is set to c; otherwise c is merged according to its type:
    lists are extended, sets are unioned, strings are appended on a new line,
    and dicts are combined (keys from c win on conflict).

    e.g. dict_append(info, '.ns.Cls._var', [...]) walks
    info[''] -> ['ns'] -> ['Cls'] and extends its '_var' list.
    '''
    *path, leaf = keys.split('.')
    node = d
    for key in path:
        node = node.setdefault(key, {})

    if leaf not in node:
        node[leaf] = c
    elif isinstance(c, list):
        node[leaf] += c
    elif isinstance(c, set):
        node[leaf] |= c
    elif isinstance(c, str):
        node[leaf] += '\n' + c
    elif isinstance(c, dict):
        # Python 3 dict views cannot be concatenated with '+'.
        node[leaf] = {**node[leaf], **c}
203
def sweep_scope(lines, istart):
    '''
    Starting at lines[istart], find the first line where both () and {}
    nesting return to zero.

    Returns (end_index, peak_paren_depth, peak_brace_depth), or
    (istart, 0, 0) when the brackets never balance.
    '''
    depth_paren = depth_brace = 0
    peak_paren = peak_brace = 0
    for idx in range(istart, len(lines)):
        text = lines[idx]
        # openers first so a line like "f(x)" still registers depth 1
        depth_paren += text.count('(')
        depth_brace += text.count('{')
        peak_paren = max(peak_paren, depth_paren)
        peak_brace = max(peak_brace, depth_brace)
        depth_paren -= text.count(')')
        depth_brace -= text.count('}')
        if depth_paren == 0 and depth_brace == 0:
            return idx, peak_paren, peak_brace
    return istart, 0, 0
219
def parse_argument(content):
    '''
    Parse the parameter list of a declaration.

    Takes the text between the first ')' in content and the '(' preceding
    it, splits it on top-level commas, and returns two parallel lists
    (types, names).  An empty list "()" yields ([''], ['']).
    '''
    inside = content.strip().split(')')[0].split('(')[-1]
    pairs = [parse_type_and_name(term) for term in safe_split_comma(inside)]
    arg_type = [typ for typ, _ in pairs]
    arg_name = [name for _, name in pairs]
    return arg_type, arg_name
229
def parse_scope(info, kstring, lines, init_attr):
    '''
    Parse the body lines of a namespace/class/struct and record its members.

    info      -- result dict being filled through dict_append
    kstring   -- dotted key of the enclosing scope, e.g. '.ns.MyClass'
    lines     -- the lines between the scope's opening and closing brace
    init_attr -- access level in effect at the start of the scope
                 (file_parse passes 'public' for namespaces, 'private'
                 for classes/structs)

    Each member is appended under kstring + '._var' / '._fun' / '._init' /
    '._final' / '._bind' as {name: {'_type', '_raw', '_attr', ...}}.
    '''
    il = 0
    scope_attr = init_attr
    while il < len(lines):
        line = lines[il]
        if line.strip() == '' or line[0] == '#':
            il += 1
            continue

        # access specifiers change the attribute of the following members
        if 'public:' in line:
            scope_attr = 'public'
            il += 1
            continue
        if 'protected:' in line:
            scope_attr = 'protected'
            il += 1
            continue
        if 'private:' in line:
            scope_attr = 'private'
            il += 1
            continue

        # Handle DEFINE_POINTER(...) before the generic path so the macro is
        # not mistaken for a function.  Assumes the form
        # DEFINE_POINTER(<element type>, <name>, ...) -- TODO confirm.
        if 'DEFINE_POINTER' in line:
            tmps = [t for t in re.split(r' |,|\(|\)', line) if t != '']
            flag = tmps[2]
            dict_append(info, kstring + '._bind', [
                {flag: {'_type': 'EigMX<%s>' % tmps[1], '_raw': line.strip(), '_attr': scope_attr}}
            ])
            il += 1
            continue

        # sweep to the end of this declaration and classify it
        il2, m1, _ = sweep_scope(lines, il)
        content = '\n'.join(lines[il:il2 + 1])

        _, typi, _, flags, se_type = parse_decro_type_flag(line)

        if m1 == 0 and se_type != '_var':
            # The line contains '(' but sweep_scope found no balanced end.
            # stdout carries the generated C++, so report on stderr.
            print('Error' * 10, content.strip(), file=sys.stderr)
        else:
            arg_type, arg_name = [], []
            if se_type in ('_fun', '_init'):
                arg_type, arg_name = parse_argument(content)
            for flag in flags:
                dict_append(info, kstring + '.' + se_type, [{
                    flag: {'_type': typi, '_raw': content.strip(), '_attr': scope_attr,
                           '_arg_type': ', '.join(arg_type),
                           '_arg_name': ', '.join(arg_name)}
                }])
        il = il2 + 1
290
292
def file_parse(fn):
    '''
    Parse the C++ header fn and return a nested description of it.

    The result holds 'name' (fn), 'deps' (the #include "..." targets in
    order) and, when the file declares anything, '' holding namespaces,
    classes and top-level declarations keyed as built by dict_append.
    '''
    with open(fn, 'r', encoding='utf-8') as f:
        txt = remove_comment(f.read())

    info = {
        "name": fn,
        "deps": re.findall('#include "(.*?)"', txt, re.S),
    }
    kstring = ''
    scope_attr = ''

    lines = txt.split('\n')
    il = 0
    while il < len(lines):
        line = lines[il]
        if line.strip() == '' or line[0] == '#':
            il += 1
            continue

        terms = line.split(' ')
        if terms[0] == 'using':
            # skip the (possibly multi-line) using-declaration; stop at EOF
            # if its ';' never appears
            while il < len(lines) and ';' not in lines[il]:
                il += 1
            il += 1
            continue

        if '{' in terms:
            if terms[0] == 'namespace':
                flag = terms[1]
                kstring = kstring + '.' + flag
                dict_append(info, kstring, {'_type': 'namespace', '_var': []})
                il2, _, _ = sweep_scope(lines, il)
                parse_scope(info, kstring, lines[il + 1:il2], 'public')
                kstring = '.'.join(kstring.split('.')[:-1])
                il = il2 + 1
                continue

            if terms[0] in ('class', 'struct'):
                flag = terms[1]
                superclass = ''
                # only the "class X : <access> Base {" spelling is recognised
                if terms[2] == ':':
                    superclass = terms[4]
                kstring = kstring + '.' + flag
                dict_append(info, kstring, {'_type': terms[0], '_superclass': superclass,
                                            '_var': [], '_fun': [], '_init': [], '_bind': [],
                                            })
                il2, _, _ = sweep_scope(lines, il)
                parse_scope(info, kstring, lines[il + 1:il2], 'private')
                kstring = '.'.join(kstring.split('.')[:-1])
                il = il2 + 1
                continue

        # sweep a scope and judge it to be _var, _fun or _init
        il2, m1, _ = sweep_scope(lines, il)
        content = '\n'.join(lines[il:il2 + 1])

        _, typi, _, flags, se_type = parse_decro_type_flag(line)

        if m1 == 0 and se_type != '_var':
            # stdout carries the generated C++, so report on stderr
            print('Error' * 10, content.strip(), file=sys.stderr)
        else:
            for flag in flags:
                dict_append(info, kstring + '.' + se_type, [{
                    flag: {'_type': typi, '_raw': content.strip(), '_attr': scope_attr}}
                ])
        il = il2 + 1
    return info
361
362# info = file_parse('/home/public/hexin/share/github/opendf/solvers/solvers_md/traj.h')
363# print(json.dumps(info, indent=4))
364# exit(-1)
365
366
367objs = []
368for i in data['models']:
369 list1 = glob.glob(os.path.abspath(root+'/'+i))
370 for j in list1:
371 objs += [ '%s.h'%j[:-4] ]
372
373for i in data['solvers']:
374 list1 = glob.glob(os.path.abspath(root+'/'+i))
375 for j in list1:
376 objs += [ '%s.h'%j[:-4] ]
377
378# print(objs)
379# objs = [
380# '/home/public/hexin/share/github/opendf/models/nad_forcefield/systembath.h',
381# '/home/public/hexin/share/github/opendf/models/forcefieldbase.h',
382# '/home/public/hexin/share/github/opendf/solvers/solvers_md/traj.h'
383# '/home/public/hexin/share/github/opendf/solvers/solvers_nad/solver_mmd.h'
384# ]
385# exit(-1)
386
387complete=False
388infos = {'incl': [], 'deps':[], '': {}}
389while not complete:
390 infos1 = {'incl': [], 'deps':[], '': {}}
391 for i in objs:
392 info = file_parse(i)
393 infos1['incl'] += [ info['name'] ]
394 for j in info['deps']:
395 infos1['deps'] += [ os.path.abspath(
396 os.path.dirname(info['name']) + '/' + j
397 )]
398 infos1[''] = {**infos1[''], **info['']}
399
400 infos['incl'] += infos1['incl']
401 infos['deps'] += infos1['deps']
402 infos[''] = {**infos[''], **infos1['']}
403
404 s1 = set(infos['deps'])
405 s2 = set(infos['incl'])
406 s3 = set()
407 for i in s1-s2:
408 # print('checking deps: ', i)
409 if 'opendf/models' in i or 'opendf/solvers' in i:
410 s3 = s3 | {i}
411 if s3:
412 objs = list(s3)
413 else:
414 complete=True
415 # print("****")
416
417# print(json.dumps(infos, sort_keys=False, indent=4))
418# print(infos[''])
419# exit(-1)
420
421
422
423
424
def get_fathers(dd, child):
    '''
    Return child followed by its ancestors (via '_superclass'), joined by
    ', ', e.g. 'Derived, Base, Root'.
    '''
    parent = dd[child].get('_superclass', '')
    if parent == '':
        return child
    return child + ', ' + get_fathers(dd, parent)
429
def creat_trampoline_fun(ns, ns0, dd, coll):
    '''
    Print pybind11 trampoline overrides for the virtual methods of class ns.

    ns   -- class whose '_fun' entries are scanned (the class or an ancestor)
    ns0  -- class the trampoline is generated for (currently unused)
    dd   -- parsed class descriptions keyed by class name
    coll -- "name(arg types)" signatures already emitted; extended in place
            so a method seen in several ancestors is printed once

    Only methods whose return type is in type_list and whose raw text
    contains 'virtual' are emitted.  Pure virtual methods ("= 0") use
    PYBIND11_OVERRIDE_PURE, the others PYBIND11_OVERRIDE.
    Returns coll.
    '''
    for f in dd[ns]['_fun']:
        for i in f:
            fun = f[i]
            if fun['_type'] not in type_list:
                continue
            raw = fun['_raw']
            if 'virtual' not in raw:
                continue

            # Text between the parameter list and any inline body, e.g.
            # " const = 0;" -- checked for pure-virtual and const markers.
            tail = raw.partition(')')[2].split('{')[0]
            if re.match(r'[^;]*=\s*0\s*;', tail):
                macro = 'PYBIND11_OVERRIDE_PURE'
            else:
                macro = 'PYBIND11_OVERRIDE'

            tag = '%s(%s)' % (i, fun['_arg_type'])
            if tag in coll:
                continue
            coll += [tag]

            token = raw.split(')')[0] + ')'
            token = token.replace('virtual ', '')
            # keep a const qualifier; without it the method would hide,
            # not override, the base-class method
            if 'const' in re.split('[;=]', tail)[0].split():
                token += ' const'
            print('''
 %s override {
 %s(
 %s, // return type
 %s, // parent class
 %s, // func name
 %s
 );
 }''' % (token, macro,
                   fun['_type'],
                   ns,
                   i,
                   fun['_arg_name']
                   )
                  )
    return coll
463
def creat_trampoline(ns, dd):
    '''
    Print the trampoline class PyTrampoline_<ns> for class ns.

    The trampoline derives from ns, inherits its constructors, and gets one
    override per virtual method of ns and of every ancestor reached via
    '_superclass' (each signature once, nearest class first).
    '''
    print('''
 class PyTrampoline_%s : public %s {
 public:
 using %s::%s;''' % (ns, ns, ns, ns)
    )
    coll = creat_trampoline_fun(ns, ns, dd, [])
    ancestor = ns
    while dd[ancestor]['_superclass'] != '':
        ancestor = dd[ancestor]['_superclass']
        coll = creat_trampoline_fun(ancestor, ns, dd, coll)
    print(' };\n')
476 print(' };\n')
477
478# def creat_public_members(ns, ns0, dd, coll):
479# # for f in dd[ns]['_var']:
480# # for i in f:
481# # if f[i]['_attr'] == 'protected':
482# # print(' using %s::%s;'%(ns, i))
483# for f in dd[ns]['_bind']:
484# for i in f:
485# if f[i]['_attr'] == 'protected':
486# print(' using %s::%s;'%(ns, i))
487# print(' using %s::%s_eigen_container;'%(ns, i))
488# def creat_publicist(ns, dd):
489# print('''
490# class PyPublicist_%s : public %s {
491# public:'''%(ns, ns)
492# )
493# coll = []
494# coll = creat_public_members(ns, ns, dd, coll)
495# # ns1 = ns
496# # while dd[ns1]['_superclass'] != '':
497# # ns1 = dd[ns1]['_superclass']
498# # coll = creat_trampoline_fun(ns1, ns, dd, coll)
499# print(' };\n')
500
def creat_class_init(ns, dd):
    '''
    Print one `.def(py::init<...>())` per constructor of class ns.

    Output continues the current py::class_ chain (no trailing newline) and
    uses the comma-joined argument types stored in '_arg_type'.
    '''
    for ctor in dd[ns]['_init']:
        for name in ctor:
            print('\n .def(py::init<%s>())' % ctor[name]['_arg_type'], end='')
505
def creat_class_var(ns, dd):
    '''
    Print `.def_readwrite` for each public data member of ns with a plain type.

    Only int/double/kids_real/kids_complex/std::string/bool members are
    exposed; names starting with '*' (pointer declarators) are skipped.
    '''
    simple_types = ['int', 'double', 'kids_real', 'kids_complex',
                    'std::string', 'bool']
    for entry in dd[ns]['_var']:
        for name in entry:
            meta = entry[name]
            if meta['_attr'] != 'public':
                continue
            if meta['_type'] not in simple_types or name[0] == '*':
                continue
            print('\n .def_readwrite("%s", &%s::%s)' % (name, ns, name), end='')
515
def creat_class_bind(ns, dd):
    '''
    Print a `ref_<name>` method binding for every '_bind' entry of ns
    (recorded from DEFINE_POINTER lines), returned with
    py::return_value_policy::reference_internal.
    Assumes DEFINE_POINTER declares a matching ref_<name>() member -- TODO confirm.
    '''
    for entry in dd[ns]['_bind']:
        for name in entry:
            print('\n .def("ref_%s", &%s::ref_%s, py::return_value_policy::reference_internal)' % (
                name, ns, name
            ), end='')
522
def creat_class_fun(ns, ns0, dd):
    '''
    Print the .def / .def_static bindings for the methods of class ns.

    ns  -- class whose '_fun' entries are emitted (bound class or an ancestor)
    ns0 -- the class being bound; an ancestor's 'name' method is skipped
           (presumably so it is not bound twice -- confirm)
    dd  -- parsed class descriptions keyed by class name

    Methods whose return type is not in type_list, or whose text mentions
    'kids_complex*', are skipped.  Methods taking raw pointers are wrapped in
    a lambda that accepts py::array_t and forwards mutable_data().
    '''
    for f in dd[ns]['_fun']:
        for i in f:
            fun = f[i]
            if fun['_type'] not in type_list or 'kids_complex*' in fun['_raw']:
                continue

            if i == 'name' and ns != ns0:
                continue

            # Decide staticness from the declaration head only; an inline
            # body using e.g. static_cast must not produce def_static.
            if re.search(r'\bstatic\b', fun['_raw'].split('(')[0]):
                print('\n .def_static("%s", &%s::%s)' % (i, ns, i), end='')
            elif '*' not in fun['_arg_type']:
                print('\n .def("%s", static_cast<%s (%s::*)(%s)>(&%s::%s))' % (
                    i, fun['_type'], ns, fun['_arg_type'], ns, i
                ), end='')
            else:
                arg_type, arg_name = parse_argument(fun['_raw'])
                call_args = []
                pairs = []
                for typ, name in zip(arg_type, arg_name):
                    # endswith: an unnamed parameter yields an empty type
                    if typ.endswith('*'):
                        typ = 'py::array_t<%s, py::array::c_style | py::array::forcecast>' % typ[:-1]
                        pairs.append(typ + ' ' + name + '_arr')
                        call_args.append(name + '_arr.mutable_data()')
                    else:
                        pairs.append(typ + ' ' + name)
                        call_args.append(name)
                print('\n .def("%s", [](%s& self, %s) {\n return self.%s(%s); \n }\n )' % (
                    i, ns, ', \n '.join(pairs), i, ', '.join(call_args)
                ), end='')
559
def creat_class_(ns, dd, field):
    '''
    Print the complete py::class_<...> registration for class ns.

    Template arguments are ns, its direct superclass (if any) and
    PyTrampoline_<ns>.  The class is registered on the module variable named
    by field (e.g. 'models_m') with py::dynamic_attr().  The output then lists
    the constructors, plain members and container accessors, followed by the
    methods of ns and of every ancestor.
    '''
    parent = dd[ns]['_superclass']
    template_args = ns if parent == '' else '%s, %s' % (ns, parent)
    template_args += ', PyTrampoline_%s' % (ns)

    print(' py::class_<%s>(%s, "%s", py::dynamic_attr())'
          % (template_args, field, ns), end='')
    creat_class_init(ns, dd)
    creat_class_var(ns, dd)
    creat_class_bind(ns, dd)

    current = ns
    while True:
        creat_class_fun(current, ns, dd)
        if dd[current]['_superclass'] == '':
            break
        current = dd[current]['_superclass']

    print(';', end='\n')
580
# Emit the C++ preamble: pybind11 headers, one #include per parsed header
# (relative to <root>/python), then the opening of the PYBIND11_MODULE body.
print('''
#include <pybind11/embed.h>
#include <pybind11/numpy.h>
#include <pybind11/eigen.h>
#include <pybind11/complex.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
''')
for header in infos['incl']:
    print('#include "%s"' % os.path.relpath(header, root + '/python'))

print('''

#include "../utils/definitions.h"

namespace py = pybind11;

// clang-format off
PYBIND11_MODULE(libopendf, m) {

 #include "opendf_phys.bind"

 py::module models_m = m.def_submodule("models");
 py::module solvers_m = m.def_submodule("solvers");

''')
608
dd = infos['']
created_class = []


def try_creat_class(ns):
    '''
    Emit the trampoline and py::class_ registration for class ns, once.

    The superclass is emitted first (recursively), because pybind11 needs a
    base type registered before any class deriving from it.  Classes whose
    name contains 'SCF', 'Atomic_BasisSet' or 'PPIMD' are skipped.  Names
    containing 'Solver' go to solvers_m; all others go to models_m.
    '''
    global created_class
    if (ns not in created_class
            and 'SCF' not in ns
            and 'Atomic_BasisSet' not in ns
            and 'PPIMD' not in ns):
        if dd[ns]['_superclass'] != '':
            try_creat_class(dd[ns]['_superclass'])  # father must previously created!!
        creat_trampoline(ns, dd)
        if 'Solver' in ns:
            creat_class_(ns, dd, 'solvers_m')
        else:
            creat_class_(ns, dd, 'models_m')

        created_class += [ns]


for ns in dd.keys():
    if '_type' in dd[ns] and dd[ns]['_type'] == 'class':
        try_creat_class(ns)
634
635# for i in infos['deps']:
636# if i not in infos['incl']:
637# print(i)
# Close PYBIND11_MODULE and re-enable clang-format (it was switched off just
# before the module body).
print('''}
// clang-format on
''')
641
creat_trampoline_fun(ns, ns0, dd, coll)
sweep_scope(lines, istart)
parse_scope(info, kstring, lines, init_attr)
creat_class_(ns, dd, field)
creat_class_fun(ns, ns0, dd)