1
1
import numpy
2
2
import jax
3
- import jax ._src .lax_reference as lax_reference
4
3
from jax import core
5
4
from jax ._src .util import safe_map
6
5
import flax
7
6
from neurallogic import symbolic_primitives
8
7
from plum import dispatch
9
8
import typing
10
- from typing import (Any , Mapping )
11
-
12
- # TODO: rename this file to symbolic.py
9
+ from typing import Any , Mapping
13
10
14
11
15
12
def symbolic_bind (prim , * args , ** params ):
16
13
# print("\nprimitive: ", prim.name)
17
14
# print("args: ", args)
18
15
# print("params: ", params)
19
16
symbolic_outvals = {
20
- 'broadcast_in_dim' : symbolic_primitives .symbolic_broadcast_in_dim ,
21
- 'reshape' : lax_reference .reshape ,
22
- 'convert_element_type' : symbolic_primitives .symbolic_convert_element_type ,
23
- 'and' : symbolic_primitives .symbolic_and ,
24
- 'or' : symbolic_primitives .symbolic_or ,
25
- 'xor' : symbolic_primitives .symbolic_xor ,
26
- 'not' : symbolic_primitives .symbolic_not ,
27
- 'ne' : symbolic_primitives .symbolic_ne ,
28
- 'gt' : symbolic_primitives .symbolic_gt ,
29
- 'reduce_and' : symbolic_primitives .symbolic_reduce_and ,
30
- 'reduce_or' : symbolic_primitives .symbolic_reduce_or ,
31
- 'reduce_sum' : symbolic_primitives .symbolic_reduce_sum ,
17
+ "broadcast_in_dim" : symbolic_primitives .symbolic_broadcast_in_dim ,
18
+ "reshape" : symbolic_primitives .symbolic_reshape ,
19
+ "transpose" : symbolic_primitives .symbolic_transpose ,
20
+ "convert_element_type" : symbolic_primitives .symbolic_convert_element_type ,
21
+ "eq" : symbolic_primitives .symbolic_eq ,
22
+ "ne" : symbolic_primitives .symbolic_ne ,
23
+ "le" : symbolic_primitives .symbolic_le ,
24
+ "lt" : symbolic_primitives .symbolic_lt ,
25
+ "gt" : symbolic_primitives .symbolic_gt ,
26
+ "add" : symbolic_primitives .symbolic_add ,
27
+ "sub" : symbolic_primitives .symbolic_sub ,
28
+ "mul" : symbolic_primitives .symbolic_mul ,
29
+ "div" : symbolic_primitives .symbolic_div ,
30
+ "max" : symbolic_primitives .symbolic_max ,
31
+ "min" : symbolic_primitives .symbolic_min ,
32
+ "and" : symbolic_primitives .symbolic_and ,
33
+ "or" : symbolic_primitives .symbolic_or ,
34
+ "xor" : symbolic_primitives .symbolic_xor ,
35
+ "not" : symbolic_primitives .symbolic_not ,
36
+ "reduce_and" : symbolic_primitives .symbolic_reduce_and ,
37
+ "reduce_or" : symbolic_primitives .symbolic_reduce_or ,
38
+ "reduce_sum" : symbolic_primitives .symbolic_reduce_sum ,
39
+ "select_n" : symbolic_primitives .symbolic_select_n ,
32
40
}[prim .name ](* args , ** params )
33
41
return symbolic_outvals
34
42
43
+
35
44
def scope_put_variable (self , col : str , name : str , value : Any ):
36
45
variables = self ._collection (col )
37
46
38
47
def put (target , key , val ):
39
- if (key in target and isinstance (target [key ], dict ) and
40
- isinstance (val , Mapping )):
48
+ if key in target and isinstance (target [key ], dict ) and isinstance (val , Mapping ):
41
49
for k , v in val .items ():
42
50
put (target [key ], k , v )
43
51
else :
@@ -50,11 +58,16 @@ def put_variable(self, col: str, name: str, value: Any):
50
58
self .scope ._variables = self .scope .variables ().unfreeze ()
51
59
scope_put_variable (self .scope , col , name , value )
52
60
61
+
53
62
def make_symbolic_flax_jaxpr (flax_layer , x ):
54
63
actual_weights = flax_layer .get_variable ("params" , "weights" )
55
64
# Convert actual weights to dummy numeric weights (if needed)
56
- if isinstance (actual_weights , list ) or (isinstance (actual_weights , numpy .ndarray ) and actual_weights .dtype == object ):
57
- numeric_weights = symbolic_primitives .map_at_elements (actual_weights , lambda x : 0 )
65
+ if isinstance (actual_weights , list ) or (
66
+ isinstance (actual_weights , numpy .ndarray ) and actual_weights .dtype == object
67
+ ):
68
+ numeric_weights = symbolic_primitives .map_at_elements (
69
+ actual_weights , lambda x : 0
70
+ )
58
71
numeric_weights = numpy .asarray (numeric_weights , dtype = numpy .int32 )
59
72
put_variable (flax_layer , "params" , "weights" , numeric_weights )
60
73
# Convert input to dummy numeric input (if needed)
@@ -130,33 +143,40 @@ def eval_jaxpr_impl(jaxpr):
130
143
symbolic_invals = safe_map (symbolic_read , eqn .invars )
131
144
prim = eqn .primitive
132
145
if type (prim ) is jax .core .CallPrimitive :
133
- call_jaxpr = eqn .params [' call_jaxpr' ]
146
+ call_jaxpr = eqn .params [" call_jaxpr" ]
134
147
if not symbolic :
135
148
safe_map (write , call_jaxpr .invars , map (read , eqn .invars ))
136
149
try :
137
- safe_map (symbolic_write , call_jaxpr .invars ,
138
- map (symbolic_read , eqn .invars ))
150
+ safe_map (
151
+ symbolic_write ,
152
+ call_jaxpr .invars ,
153
+ map (symbolic_read , eqn .invars ),
154
+ )
139
155
except :
140
156
pass
141
157
eval_jaxpr_impl (call_jaxpr )
142
158
if not symbolic :
143
159
safe_map (write , eqn .outvars , map (read , call_jaxpr .outvars ))
144
- safe_map (symbolic_write , eqn .outvars , map (
145
- symbolic_read , call_jaxpr .outvars ))
160
+ safe_map (
161
+ symbolic_write , eqn .outvars , map (symbolic_read , call_jaxpr .outvars )
162
+ )
146
163
else :
147
164
if not symbolic :
148
165
outvals = prim .bind (* invals , ** eqn .params )
149
- symbolic_outvals = symbolic_bind (
150
- prim , * symbolic_invals , ** eqn .params )
166
+ symbolic_outvals = symbolic_bind (prim , * symbolic_invals , ** eqn .params )
151
167
# Primitives may return multiple outputs or not
152
168
if not prim .multiple_results :
153
169
if not symbolic :
154
170
outvals = [outvals ]
155
171
symbolic_outvals = [symbolic_outvals ]
156
172
if not symbolic :
157
173
# Check that the concrete and symbolic values are equal
158
- assert numpy .array_equal (
159
- numpy .array (outvals ), symbolic_outvals )
174
+ #print(
175
+ # f"outvals: {outvals} and symbolic_outvals: {symbolic_outvals}"
176
+ #)
177
+ assert numpy .allclose (
178
+ numpy .array (outvals ), symbolic_outvals , equal_nan = True
179
+ )
160
180
# Write the results of the primitive into the environment
161
181
if not symbolic :
162
182
safe_map (write , eqn .outvars , outvals )
@@ -169,24 +189,36 @@ def eval_jaxpr_impl(jaxpr):
169
189
else :
170
190
return safe_map (symbolic_read , jaxpr .outvars )[0 ]
171
191
192
+
172
193
# TODO: parameterise these functions by the element conversion function
173
194
174
195
# TODO: use union types to consolidate these functions
175
196
@dispatch
176
197
def make_symbolic (x : dict ):
177
- return symbolic_primitives .map_at_elements (x , symbolic_primitives .to_boolean_value_string )
198
+ return symbolic_primitives .map_at_elements (
199
+ x , symbolic_primitives .to_boolean_value_string
200
+ )
201
+
178
202
179
203
@dispatch
180
204
def make_symbolic (x : list ):
181
- return symbolic_primitives .map_at_elements (x , symbolic_primitives .to_boolean_value_string )
205
+ return symbolic_primitives .map_at_elements (
206
+ x , symbolic_primitives .to_boolean_value_string
207
+ )
208
+
182
209
183
210
@dispatch
184
211
def make_symbolic (x : numpy .ndarray ):
185
- return symbolic_primitives .map_at_elements (x , symbolic_primitives .to_boolean_value_string )
212
+ return symbolic_primitives .map_at_elements (
213
+ x , symbolic_primitives .to_boolean_value_string
214
+ )
215
+
186
216
187
217
@dispatch
188
218
def make_symbolic (x : jax .numpy .ndarray ):
189
- return symbolic_primitives .map_at_elements (convert_jax_to_numpy_arrays (x ), symbolic_primitives .to_boolean_value_string )
219
+ return symbolic_primitives .map_at_elements (
220
+ convert_jax_to_numpy_arrays (x ), symbolic_primitives .to_boolean_value_string
221
+ )
190
222
191
223
192
224
@dispatch
@@ -214,25 +246,28 @@ def make_symbolic(x: flax.core.FrozenDict):
214
246
x = convert_jax_to_numpy_arrays (x .unfreeze ())
215
247
return flax .core .FrozenDict (make_symbolic (x ))
216
248
249
+
217
250
@dispatch
218
251
def make_symbolic (* args ):
219
252
return tuple ([make_symbolic (arg ) for arg in args ])
220
253
254
+
221
255
@dispatch
222
256
def make_symbolic_jaxpr (func : typing .Callable , * args ):
223
257
return jax .make_jaxpr (lambda * args : func (* args ))(* args )
224
258
225
259
226
260
def eval_symbolic (symbolic_function , * args ):
227
- if hasattr (symbolic_function , 'literals' ):
228
- return eval_jaxpr (False , symbolic_function .jaxpr , symbolic_function .literals , * args )
261
+ if hasattr (symbolic_function , "literals" ):
262
+ return eval_jaxpr (
263
+ False , symbolic_function .jaxpr , symbolic_function .literals , * args
264
+ )
229
265
return eval_jaxpr (False , symbolic_function .jaxpr , [], * args )
230
266
231
267
232
268
def symbolic_expression (jaxpr , * args ):
233
- if hasattr (jaxpr , 'literals' ):
234
- sym_expr = eval_jaxpr (True , jaxpr .jaxpr ,
235
- jaxpr .literals , * args )
269
+ if hasattr (jaxpr , "literals" ):
270
+ sym_expr = eval_jaxpr (True , jaxpr .jaxpr , jaxpr .literals , * args )
236
271
else :
237
272
sym_expr = eval_jaxpr (True , jaxpr .jaxpr , [], * args )
238
273
return sym_expr
0 commit comments