@@ -28,13 +28,17 @@ def symbolic_bind(prim, *args, **params):
28
28
'lt' : symbolic_primitives .symbolic_lt ,
29
29
'ge' : symbolic_primitives .symbolic_ge ,
30
30
'gt' : symbolic_primitives .symbolic_gt ,
31
- 'abs' : symbolic_primitives .symbolic_abs ,
32
31
'add' : symbolic_primitives .symbolic_add ,
33
32
'sub' : symbolic_primitives .symbolic_sub ,
34
33
'mul' : symbolic_primitives .symbolic_mul ,
35
34
'div' : symbolic_primitives .symbolic_div ,
35
+ 'tan' : symbolic_primitives .symbolic_tan ,
36
36
'max' : symbolic_primitives .symbolic_max ,
37
37
'min' : symbolic_primitives .symbolic_min ,
38
+ 'abs' : symbolic_primitives .symbolic_abs ,
39
+ 'round' : symbolic_primitives .symbolic_round ,
40
+ 'floor' : symbolic_primitives .symbolic_floor ,
41
+ 'ceil' : symbolic_primitives .symbolic_ceil ,
38
42
'and' : symbolic_primitives .symbolic_and ,
39
43
'or' : symbolic_primitives .symbolic_or ,
40
44
'xor' : symbolic_primitives .symbolic_xor ,
@@ -190,10 +194,9 @@ def eval_jaxpr_impl(jaxpr):
190
194
outvals = [outvals ]
191
195
symbolic_outvals = [symbolic_outvals ]
192
196
if not symbolic :
193
- # Check that the concrete and symbolic values are equal
194
- # print(
195
- # f'outvals: {outvals} and symbolic_outvals: {symbolic_outvals}'
196
- # )
197
+ # Always check that the symbolic binding generates the same values as the
198
+ # standard jax binding in order to detect bugs early.
199
+ # print(f'outvals: {outvals} and symbolic_outvals: {symbolic_outvals}')
197
200
assert numpy .allclose (
198
201
numpy .array (outvals ), symbolic_outvals , equal_nan = True
199
202
)
0 commit comments