Skip to content

Commit 653afa8

Browse files
committed
add file
1 parent b2066db commit 653afa8

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

neurallogic/symbolic_generation.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,17 @@ def symbolic_bind(prim, *args, **params):
2828
'lt': symbolic_primitives.symbolic_lt,
2929
'ge': symbolic_primitives.symbolic_ge,
3030
'gt': symbolic_primitives.symbolic_gt,
31-
'abs': symbolic_primitives.symbolic_abs,
3231
'add': symbolic_primitives.symbolic_add,
3332
'sub': symbolic_primitives.symbolic_sub,
3433
'mul': symbolic_primitives.symbolic_mul,
3534
'div': symbolic_primitives.symbolic_div,
35+
'tan': symbolic_primitives.symbolic_tan,
3636
'max': symbolic_primitives.symbolic_max,
3737
'min': symbolic_primitives.symbolic_min,
38+
'abs': symbolic_primitives.symbolic_abs,
39+
'round': symbolic_primitives.symbolic_round,
40+
'floor': symbolic_primitives.symbolic_floor,
41+
'ceil': symbolic_primitives.symbolic_ceil,
3842
'and': symbolic_primitives.symbolic_and,
3943
'or': symbolic_primitives.symbolic_or,
4044
'xor': symbolic_primitives.symbolic_xor,
@@ -190,10 +194,9 @@ def eval_jaxpr_impl(jaxpr):
190194
outvals = [outvals]
191195
symbolic_outvals = [symbolic_outvals]
192196
if not symbolic:
193-
# Check that the concrete and symbolic values are equal
194-
# print(
195-
# f'outvals: {outvals} and symbolic_outvals: {symbolic_outvals}'
196-
# )
197+
# Always check that the symbolic binding generates the same values as the
198+
# standard jax binding in order to detect bugs early.
199+
# print(f'outvals: {outvals} and symbolic_outvals: {symbolic_outvals}')
197200
assert numpy.allclose(
198201
numpy.array(outvals), symbolic_outvals, equal_nan=True
199202
)

0 commit comments

Comments
 (0)