1
- import numpy
1
+ import sys
2
+ import typing
3
+ from typing import Any , Mapping
4
+
5
+ import flax
2
6
import jax
7
+ import numpy
3
8
from jax import core
4
9
from jax ._src .util import safe_map
5
- import flax
6
- from neurallogic import symbolic_primitives
7
10
from plum import dispatch
8
- import typing
9
- from typing import Any , Mapping
11
+
12
+ from neurallogic import symbolic_primitives
13
+
14
+ # Imports required for evaluating symbolic expressions with eval()
15
+ import jax ._src .lax_reference as lax_reference
10
16
11
17
12
18
def symbolic_bind (prim , * args , ** params ):
13
- # print(" \nprimitive: " , prim.name)
14
- # print("args: " , args)
15
- # print("params: " , params)
19
+ # print(' \nprimitive: ' , prim.name)
20
+ # print('\targs:\n\t\t' , args)
21
+ # print('\tparams\n\t\t: ' , params)
16
22
symbolic_outvals = {
17
- " broadcast_in_dim" : symbolic_primitives .symbolic_broadcast_in_dim ,
18
- " reshape" : symbolic_primitives .symbolic_reshape ,
19
- " transpose" : symbolic_primitives .symbolic_transpose ,
20
- " convert_element_type" : symbolic_primitives .symbolic_convert_element_type ,
21
- "eq" : symbolic_primitives .symbolic_eq ,
22
- "ne" : symbolic_primitives .symbolic_ne ,
23
- "le" : symbolic_primitives .symbolic_le ,
24
- "lt" : symbolic_primitives .symbolic_lt ,
25
- "gt" : symbolic_primitives .symbolic_gt ,
26
- " abs" : symbolic_primitives .symbolic_abs ,
27
- " add" : symbolic_primitives .symbolic_add ,
28
- " sub" : symbolic_primitives .symbolic_sub ,
29
- " mul" : symbolic_primitives .symbolic_mul ,
30
- " div" : symbolic_primitives .symbolic_div ,
31
- " max" : symbolic_primitives .symbolic_max ,
32
- " min" : symbolic_primitives .symbolic_min ,
33
- " and" : symbolic_primitives .symbolic_and ,
34
- "or" : symbolic_primitives .symbolic_or ,
35
- " xor" : symbolic_primitives .symbolic_xor ,
36
- " not" : symbolic_primitives .symbolic_not ,
37
- " reduce_and" : symbolic_primitives .symbolic_reduce_and ,
38
- " reduce_or" : symbolic_primitives .symbolic_reduce_or ,
39
- " reduce_sum" : symbolic_primitives .symbolic_reduce_sum ,
40
- " select_n" : symbolic_primitives .symbolic_select_n ,
23
+ ' broadcast_in_dim' : symbolic_primitives .symbolic_broadcast_in_dim ,
24
+ ' reshape' : symbolic_primitives .symbolic_reshape ,
25
+ ' transpose' : symbolic_primitives .symbolic_transpose ,
26
+ ' convert_element_type' : symbolic_primitives .symbolic_convert_element_type ,
27
+ 'eq' : symbolic_primitives .symbolic_eq ,
28
+ 'ne' : symbolic_primitives .symbolic_ne ,
29
+ 'le' : symbolic_primitives .symbolic_le ,
30
+ 'lt' : symbolic_primitives .symbolic_lt ,
31
+ 'gt' : symbolic_primitives .symbolic_gt ,
32
+ ' abs' : symbolic_primitives .symbolic_abs ,
33
+ ' add' : symbolic_primitives .symbolic_add ,
34
+ ' sub' : symbolic_primitives .symbolic_sub ,
35
+ ' mul' : symbolic_primitives .symbolic_mul ,
36
+ ' div' : symbolic_primitives .symbolic_div ,
37
+ ' max' : symbolic_primitives .symbolic_max ,
38
+ ' min' : symbolic_primitives .symbolic_min ,
39
+ ' and' : symbolic_primitives .symbolic_and ,
40
+ 'or' : symbolic_primitives .symbolic_or ,
41
+ ' xor' : symbolic_primitives .symbolic_xor ,
42
+ ' not' : symbolic_primitives .symbolic_not ,
43
+ ' reduce_and' : symbolic_primitives .symbolic_reduce_and ,
44
+ ' reduce_or' : symbolic_primitives .symbolic_reduce_or ,
45
+ ' reduce_sum' : symbolic_primitives .symbolic_reduce_sum ,
46
+ ' select_n' : symbolic_primitives .symbolic_select_n ,
41
47
}[prim .name ](* args , ** params )
48
+ # print('\tresult:\n\t\t', symbolic_outvals)
42
49
return symbolic_outvals
43
50
44
51
@@ -64,7 +71,7 @@ def put_variable(self, col: str, name: str, value: Any):
64
71
65
72
66
73
def convert_to_numeric_params (flax_layer , param_names : str ):
67
- actual_weights = flax_layer .get_variable (" params" , param_names )
74
+ actual_weights = flax_layer .get_variable (' params' , param_names )
68
75
# Convert actual weights to dummy numeric weights (if needed)
69
76
if isinstance (actual_weights , list ) or (
70
77
isinstance (actual_weights , numpy .ndarray ) and actual_weights .dtype == object
@@ -73,13 +80,13 @@ def convert_to_numeric_params(flax_layer, param_names: str):
73
80
actual_weights , lambda x : 0
74
81
)
75
82
numeric_weights = numpy .asarray (numeric_weights , dtype = numpy .int32 )
76
- put_variable (flax_layer , " params" , param_names , numeric_weights )
83
+ put_variable (flax_layer , ' params' , param_names , numeric_weights )
77
84
return flax_layer , actual_weights
78
85
79
86
80
87
def make_symbolic_flax_jaxpr (flax_layer , x ):
81
- flax_layer , bit_weights = convert_to_numeric_params (flax_layer , " bit_weights" )
82
- flax_layer , thresholds = convert_to_numeric_params (flax_layer , " thresholds" )
88
+ flax_layer , bit_weights = convert_to_numeric_params (flax_layer , ' bit_weights' )
89
+ flax_layer , thresholds = convert_to_numeric_params (flax_layer , ' thresholds' )
83
90
# Convert input to dummy numeric input (if needed)
84
91
if isinstance (x , list ) or (isinstance (x , numpy .ndarray ) and x .dtype == object ):
85
92
x = symbolic_primitives .map_at_elements (x , lambda x : 0 )
@@ -94,7 +101,7 @@ def make_symbolic_flax_jaxpr(flax_layer, x):
94
101
95
102
96
103
def eval_jaxpr (symbolic , jaxpr , consts , * args ):
97
- """ Evaluates a jaxpr by interpreting it as Python code.
104
+ ''' Evaluates a jaxpr by interpreting it as Python code.
98
105
99
106
Parameters
100
107
----------
@@ -113,7 +120,9 @@ def eval_jaxpr(symbolic, jaxpr, consts, *args):
113
120
-------
114
121
out : tuple
115
122
The result of evaluating the jaxpr.
116
- """
123
+ '''
124
+ if symbolic :
125
+ numpy .set_printoptions (threshold = sys .maxsize )
117
126
118
127
# Mapping from variable -> value
119
128
env = {}
@@ -155,7 +164,7 @@ def eval_jaxpr_impl(jaxpr):
155
164
symbolic_invals = safe_map (symbolic_read , eqn .invars )
156
165
prim = eqn .primitive
157
166
if type (prim ) is jax .core .CallPrimitive :
158
- call_jaxpr = eqn .params [" call_jaxpr" ]
167
+ call_jaxpr = eqn .params [' call_jaxpr' ]
159
168
if not symbolic :
160
169
safe_map (write , call_jaxpr .invars , map (read , eqn .invars ))
161
170
try :
@@ -184,7 +193,7 @@ def eval_jaxpr_impl(jaxpr):
184
193
if not symbolic :
185
194
# Check that the concrete and symbolic values are equal
186
195
# print(
187
- # f" outvals: {outvals} and symbolic_outvals: {symbolic_outvals}"
196
+ # f' outvals: {outvals} and symbolic_outvals: {symbolic_outvals}'
188
197
# )
189
198
assert numpy .allclose (
190
199
numpy .array (outvals ), symbolic_outvals , equal_nan = True
@@ -202,45 +211,46 @@ def eval_jaxpr_impl(jaxpr):
202
211
return safe_map (symbolic_read , jaxpr .outvars )[0 ]
203
212
204
213
205
- # TODO: parameterise these functions by the element conversion function
214
+ def to_string (x ):
215
+ return str (x )
206
216
207
217
# TODO: use union types to consolidate these functions
208
218
@dispatch
209
219
def make_symbolic (x : dict ):
210
220
return symbolic_primitives .map_at_elements (
211
- x , symbolic_primitives . to_boolean_value_string
221
+ x , to_string
212
222
)
213
223
214
224
215
225
@dispatch
216
226
def make_symbolic (x : list ):
217
227
return symbolic_primitives .map_at_elements (
218
- x , symbolic_primitives . to_boolean_value_string
228
+ x , to_string
219
229
)
220
230
221
231
222
232
@dispatch
223
233
def make_symbolic (x : numpy .ndarray ):
224
234
return symbolic_primitives .map_at_elements (
225
- x , symbolic_primitives . to_boolean_value_string
235
+ x , to_string
226
236
)
227
237
228
238
229
239
@dispatch
230
240
def make_symbolic (x : jax .numpy .ndarray ):
231
241
return symbolic_primitives .map_at_elements (
232
- convert_jax_to_numpy_arrays (x ), symbolic_primitives . to_boolean_value_string
242
+ convert_jax_to_numpy_arrays (x ), to_string
233
243
)
234
244
235
245
236
246
@dispatch
237
247
def make_symbolic (x : bool ):
238
- return symbolic_primitives . to_boolean_value_string (x )
248
+ return to_string (x )
239
249
240
250
241
251
@dispatch
242
252
def make_symbolic (x : str ):
243
- return symbolic_primitives . to_boolean_value_string (x )
253
+ return to_string (x )
244
254
245
255
246
256
@dispatch
@@ -270,15 +280,15 @@ def make_symbolic_jaxpr(func: typing.Callable, *args):
270
280
271
281
272
282
def eval_symbolic (symbolic_function , * args ):
273
- if hasattr (symbolic_function , " literals" ):
283
+ if hasattr (symbolic_function , ' literals' ):
274
284
return eval_jaxpr (
275
285
False , symbolic_function .jaxpr , symbolic_function .literals , * args
276
286
)
277
287
return eval_jaxpr (False , symbolic_function .jaxpr , [], * args )
278
288
279
289
280
290
def symbolic_expression (jaxpr , * args ):
281
- if hasattr (jaxpr , " literals" ):
291
+ if hasattr (jaxpr , ' literals' ):
282
292
sym_expr = eval_jaxpr (True , jaxpr .jaxpr , jaxpr .literals , * args )
283
293
else :
284
294
sym_expr = eval_jaxpr (True , jaxpr .jaxpr , [], * args )
@@ -287,14 +297,23 @@ def symbolic_expression(jaxpr, *args):
287
297
288
298
@dispatch
289
299
def eval_symbolic_expression (x : str ):
290
- return eval (x )
300
+ # Setting up python evaluation context
301
+ # TODO: distinguish python code-gen from other possible code-gen
302
+ #eval_str = 'import numpy\n'
303
+ #eval_str = 'import jax._src.lax_reference as lax_reference\n'
304
+ eval_str = x .replace ('inf' , 'numpy.inf' )
305
+ #print(f'evaluating\n{eval_str}')
306
+ #return exec(eval_str, globals(), locals())
307
+ #print(f'attempting to evaluate\n{eval_str}')
308
+ return eval (eval_str )
291
309
292
310
293
- @dispatch
294
- def eval_symbolic_expression (x : numpy .ndarray ):
295
- return numpy .vectorize (eval )(x )
311
+ # @dispatch
312
+ # def eval_symbolic_expression(x: numpy.ndarray):
313
+ # return numpy.vectorize(eval)(x)
296
314
297
315
298
- @dispatch
299
- def eval_symbolic_expression (x : list ):
300
- return numpy .vectorize (eval )(x )
316
+ #@dispatch
317
+ #def eval_symbolic_expression(x: list):
318
+ # return numpy.vectorize(eval)(x)
319
+
0 commit comments