Skip to content

Commit 07a7d12

Browse files
committed
wip
1 parent 016afd0 commit 07a7d12

File tree

4 files changed

+300
-141
lines changed

4 files changed

+300
-141
lines changed

neurallogic/symbolic_generation.py

Lines changed: 75 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,51 @@
1-
import numpy
1+
import sys
2+
import typing
3+
from typing import Any, Mapping
4+
5+
import flax
26
import jax
7+
import numpy
38
from jax import core
49
from jax._src.util import safe_map
5-
import flax
6-
from neurallogic import symbolic_primitives
710
from plum import dispatch
8-
import typing
9-
from typing import Any, Mapping
11+
12+
from neurallogic import symbolic_primitives
13+
14+
# Imports required for evaluating symbolic expressions with eval()
15+
import jax._src.lax_reference as lax_reference
1016

1117

1218
def symbolic_bind(prim, *args, **params):
13-
# print("\nprimitive: ", prim.name)
14-
# print("args: ", args)
15-
# print("params: ", params)
19+
# print('\nprimitive: ', prim.name)
20+
# print('\targs:\n\t\t', args)
21+
# print('\tparams\n\t\t: ', params)
1622
symbolic_outvals = {
17-
"broadcast_in_dim": symbolic_primitives.symbolic_broadcast_in_dim,
18-
"reshape": symbolic_primitives.symbolic_reshape,
19-
"transpose": symbolic_primitives.symbolic_transpose,
20-
"convert_element_type": symbolic_primitives.symbolic_convert_element_type,
21-
"eq": symbolic_primitives.symbolic_eq,
22-
"ne": symbolic_primitives.symbolic_ne,
23-
"le": symbolic_primitives.symbolic_le,
24-
"lt": symbolic_primitives.symbolic_lt,
25-
"gt": symbolic_primitives.symbolic_gt,
26-
"abs": symbolic_primitives.symbolic_abs,
27-
"add": symbolic_primitives.symbolic_add,
28-
"sub": symbolic_primitives.symbolic_sub,
29-
"mul": symbolic_primitives.symbolic_mul,
30-
"div": symbolic_primitives.symbolic_div,
31-
"max": symbolic_primitives.symbolic_max,
32-
"min": symbolic_primitives.symbolic_min,
33-
"and": symbolic_primitives.symbolic_and,
34-
"or": symbolic_primitives.symbolic_or,
35-
"xor": symbolic_primitives.symbolic_xor,
36-
"not": symbolic_primitives.symbolic_not,
37-
"reduce_and": symbolic_primitives.symbolic_reduce_and,
38-
"reduce_or": symbolic_primitives.symbolic_reduce_or,
39-
"reduce_sum": symbolic_primitives.symbolic_reduce_sum,
40-
"select_n": symbolic_primitives.symbolic_select_n,
23+
'broadcast_in_dim': symbolic_primitives.symbolic_broadcast_in_dim,
24+
'reshape': symbolic_primitives.symbolic_reshape,
25+
'transpose': symbolic_primitives.symbolic_transpose,
26+
'convert_element_type': symbolic_primitives.symbolic_convert_element_type,
27+
'eq': symbolic_primitives.symbolic_eq,
28+
'ne': symbolic_primitives.symbolic_ne,
29+
'le': symbolic_primitives.symbolic_le,
30+
'lt': symbolic_primitives.symbolic_lt,
31+
'gt': symbolic_primitives.symbolic_gt,
32+
'abs': symbolic_primitives.symbolic_abs,
33+
'add': symbolic_primitives.symbolic_add,
34+
'sub': symbolic_primitives.symbolic_sub,
35+
'mul': symbolic_primitives.symbolic_mul,
36+
'div': symbolic_primitives.symbolic_div,
37+
'max': symbolic_primitives.symbolic_max,
38+
'min': symbolic_primitives.symbolic_min,
39+
'and': symbolic_primitives.symbolic_and,
40+
'or': symbolic_primitives.symbolic_or,
41+
'xor': symbolic_primitives.symbolic_xor,
42+
'not': symbolic_primitives.symbolic_not,
43+
'reduce_and': symbolic_primitives.symbolic_reduce_and,
44+
'reduce_or': symbolic_primitives.symbolic_reduce_or,
45+
'reduce_sum': symbolic_primitives.symbolic_reduce_sum,
46+
'select_n': symbolic_primitives.symbolic_select_n,
4147
}[prim.name](*args, **params)
48+
# print('\tresult:\n\t\t', symbolic_outvals)
4249
return symbolic_outvals
4350

4451

@@ -64,7 +71,7 @@ def put_variable(self, col: str, name: str, value: Any):
6471

6572

6673
def convert_to_numeric_params(flax_layer, param_names: str):
67-
actual_weights = flax_layer.get_variable("params", param_names)
74+
actual_weights = flax_layer.get_variable('params', param_names)
6875
# Convert actual weights to dummy numeric weights (if needed)
6976
if isinstance(actual_weights, list) or (
7077
isinstance(actual_weights, numpy.ndarray) and actual_weights.dtype == object
@@ -73,13 +80,13 @@ def convert_to_numeric_params(flax_layer, param_names: str):
7380
actual_weights, lambda x: 0
7481
)
7582
numeric_weights = numpy.asarray(numeric_weights, dtype=numpy.int32)
76-
put_variable(flax_layer, "params", param_names, numeric_weights)
83+
put_variable(flax_layer, 'params', param_names, numeric_weights)
7784
return flax_layer, actual_weights
7885

7986

8087
def make_symbolic_flax_jaxpr(flax_layer, x):
81-
flax_layer, bit_weights = convert_to_numeric_params(flax_layer, "bit_weights")
82-
flax_layer, thresholds = convert_to_numeric_params(flax_layer, "thresholds")
88+
flax_layer, bit_weights = convert_to_numeric_params(flax_layer, 'bit_weights')
89+
flax_layer, thresholds = convert_to_numeric_params(flax_layer, 'thresholds')
8390
# Convert input to dummy numeric input (if needed)
8491
if isinstance(x, list) or (isinstance(x, numpy.ndarray) and x.dtype == object):
8592
x = symbolic_primitives.map_at_elements(x, lambda x: 0)
@@ -94,7 +101,7 @@ def make_symbolic_flax_jaxpr(flax_layer, x):
94101

95102

96103
def eval_jaxpr(symbolic, jaxpr, consts, *args):
97-
"""Evaluates a jaxpr by interpreting it as Python code.
104+
'''Evaluates a jaxpr by interpreting it as Python code.
98105
99106
Parameters
100107
----------
@@ -113,7 +120,9 @@ def eval_jaxpr(symbolic, jaxpr, consts, *args):
113120
-------
114121
out : tuple
115122
The result of evaluating the jaxpr.
116-
"""
123+
'''
124+
if symbolic:
125+
numpy.set_printoptions(threshold=sys.maxsize)
117126

118127
# Mapping from variable -> value
119128
env = {}
@@ -155,7 +164,7 @@ def eval_jaxpr_impl(jaxpr):
155164
symbolic_invals = safe_map(symbolic_read, eqn.invars)
156165
prim = eqn.primitive
157166
if type(prim) is jax.core.CallPrimitive:
158-
call_jaxpr = eqn.params["call_jaxpr"]
167+
call_jaxpr = eqn.params['call_jaxpr']
159168
if not symbolic:
160169
safe_map(write, call_jaxpr.invars, map(read, eqn.invars))
161170
try:
@@ -184,7 +193,7 @@ def eval_jaxpr_impl(jaxpr):
184193
if not symbolic:
185194
# Check that the concrete and symbolic values are equal
186195
# print(
187-
# f"outvals: {outvals} and symbolic_outvals: {symbolic_outvals}"
196+
# f'outvals: {outvals} and symbolic_outvals: {symbolic_outvals}'
188197
# )
189198
assert numpy.allclose(
190199
numpy.array(outvals), symbolic_outvals, equal_nan=True
@@ -202,45 +211,46 @@ def eval_jaxpr_impl(jaxpr):
202211
return safe_map(symbolic_read, jaxpr.outvars)[0]
203212

204213

205-
# TODO: parameterise these functions by the element conversion function
214+
def to_string(x):
215+
return str(x)
206216

207217
# TODO: use union types to consolidate these functions
208218
@dispatch
209219
def make_symbolic(x: dict):
210220
return symbolic_primitives.map_at_elements(
211-
x, symbolic_primitives.to_boolean_value_string
221+
x, to_string
212222
)
213223

214224

215225
@dispatch
216226
def make_symbolic(x: list):
217227
return symbolic_primitives.map_at_elements(
218-
x, symbolic_primitives.to_boolean_value_string
228+
x, to_string
219229
)
220230

221231

222232
@dispatch
223233
def make_symbolic(x: numpy.ndarray):
224234
return symbolic_primitives.map_at_elements(
225-
x, symbolic_primitives.to_boolean_value_string
235+
x, to_string
226236
)
227237

228238

229239
@dispatch
230240
def make_symbolic(x: jax.numpy.ndarray):
231241
return symbolic_primitives.map_at_elements(
232-
convert_jax_to_numpy_arrays(x), symbolic_primitives.to_boolean_value_string
242+
convert_jax_to_numpy_arrays(x), to_string
233243
)
234244

235245

236246
@dispatch
237247
def make_symbolic(x: bool):
238-
return symbolic_primitives.to_boolean_value_string(x)
248+
return to_string(x)
239249

240250

241251
@dispatch
242252
def make_symbolic(x: str):
243-
return symbolic_primitives.to_boolean_value_string(x)
253+
return to_string(x)
244254

245255

246256
@dispatch
@@ -270,15 +280,15 @@ def make_symbolic_jaxpr(func: typing.Callable, *args):
270280

271281

272282
def eval_symbolic(symbolic_function, *args):
273-
if hasattr(symbolic_function, "literals"):
283+
if hasattr(symbolic_function, 'literals'):
274284
return eval_jaxpr(
275285
False, symbolic_function.jaxpr, symbolic_function.literals, *args
276286
)
277287
return eval_jaxpr(False, symbolic_function.jaxpr, [], *args)
278288

279289

280290
def symbolic_expression(jaxpr, *args):
281-
if hasattr(jaxpr, "literals"):
291+
if hasattr(jaxpr, 'literals'):
282292
sym_expr = eval_jaxpr(True, jaxpr.jaxpr, jaxpr.literals, *args)
283293
else:
284294
sym_expr = eval_jaxpr(True, jaxpr.jaxpr, [], *args)
@@ -287,14 +297,23 @@ def symbolic_expression(jaxpr, *args):
287297

288298
@dispatch
289299
def eval_symbolic_expression(x: str):
290-
return eval(x)
300+
# Setting up python evaluation context
301+
# TODO: distinguish python code-gen from other possible code-gen
302+
#eval_str = 'import numpy\n'
303+
#eval_str = 'import jax._src.lax_reference as lax_reference\n'
304+
eval_str = x.replace('inf', 'numpy.inf')
305+
#print(f'evaluating\n{eval_str}')
306+
#return exec(eval_str, globals(), locals())
307+
#print(f'attempting to evaluate\n{eval_str}')
308+
return eval(eval_str)
291309

292310

293-
@dispatch
294-
def eval_symbolic_expression(x: numpy.ndarray):
295-
return numpy.vectorize(eval)(x)
311+
#@dispatch
312+
#def eval_symbolic_expression(x: numpy.ndarray):
313+
# return numpy.vectorize(eval)(x)
296314

297315

298-
@dispatch
299-
def eval_symbolic_expression(x: list):
300-
return numpy.vectorize(eval)(x)
316+
#@dispatch
317+
#def eval_symbolic_expression(x: list):
318+
# return numpy.vectorize(eval)(x)
319+

0 commit comments

Comments
 (0)