Skip to content

Commit a2624d0

Browse files
committed
all tests passing locally
1 parent a76529f commit a2624d0

File tree

5 files changed

+78
-73
lines changed

5 files changed

+78
-73
lines changed

neurallogic/symbolic_generation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def eval_jaxpr(symbolic, jaxpr, consts, *args):
121121
out : tuple
122122
The result of evaluating the jaxpr.
123123
'''
124-
if symbolic:
125-
numpy.set_printoptions(threshold=sys.maxsize)
124+
#if symbolic:
125+
# numpy.set_printoptions(threshold=sys.maxsize)
126126

127127
# Mapping from variable -> value
128128
env = {}
@@ -313,7 +313,7 @@ def eval_symbolic_expression(x: numpy.ndarray):
313313
return numpy.vectorize(eval_symbolic_expression)(x)
314314

315315

316-
#@dispatch
317-
#def eval_symbolic_expression(x: list):
318-
# return numpy.vectorize(eval)(x)
316+
@dispatch
317+
def eval_symbolic_expression(x: list):
318+
return numpy.vectorize(eval)(x)
319319

neurallogic/symbolic_primitives.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
# TODO: remove me?
1010
def convert_element_type(x, dtype):
11+
return x
1112
if dtype == numpy.int32 or dtype == numpy.int64:
1213
#dtype = 'int'
1314
return x
@@ -199,6 +200,7 @@ def symbolic_operator(operator: str, x: list):
199200

200201

201202
# TODO: remove infix_operator?
203+
"""
202204
@dispatch
203205
def symbolic_infix_operator(operator: str, a: str, b: str) -> str:
204206
return f'{a} {operator} {b}'.replace('\'', '')
@@ -248,7 +250,7 @@ def symbolic_infix_operator(operator: str, a: list, b: list):
248250
@dispatch
249251
def symbolic_infix_operator(operator: str, a: bool, b: str):
250252
return symbolic_infix_operator(operator, str(a), b)
251-
253+
"""
252254

253255
def all_concrete_values(data):
254256
if isinstance(data, str):
@@ -275,39 +277,34 @@ def symbolic_eq(*args, **kwargs):
275277
if all_concrete_values([*args]):
276278
return lax_reference.eq(*args, **kwargs)
277279
else:
278-
#return '(' + symbolic_infix_operator('==', *args, **kwargs) + ')'
279280
return symbolic_operator('lax_reference.eq', *args, **kwargs)
280281

281282

282283
def symbolic_ne(*args, **kwargs):
283284
if all_concrete_values([*args]):
284285
return lax_reference.ne(*args, **kwargs)
285286
else:
286-
#return '(' + symbolic_infix_operator('!=', *args, **kwargs) + ')'
287287
return symbolic_operator('lax_reference.ne', *args, **kwargs)
288288

289289

290290
def symbolic_le(*args, **kwargs):
291291
if all_concrete_values([*args]):
292292
return lax_reference.le(*args, **kwargs)
293293
else:
294-
#return '(' + symbolic_infix_operator('<=', *args, **kwargs) + ')'
295294
return symbolic_operator('lax_reference.le', *args, **kwargs)
296295

297296

298297
def symbolic_lt(*args, **kwargs):
299298
if all_concrete_values([*args]):
300299
return lax_reference.lt(*args, **kwargs)
301300
else:
302-
#return '(' + symbolic_infix_operator('<', *args, **kwargs) + ')'
303301
return symbolic_operator('lax_reference.lt', *args, **kwargs)
304302

305303

306304
def symbolic_gt(*args, **kwargs):
307305
if all_concrete_values([*args]):
308306
return lax_reference.gt(*args, **kwargs)
309307
else:
310-
#return symbolic_infix_operator('>', *args, **kwargs)
311308
return symbolic_operator('lax_reference.gt', *args, **kwargs)
312309

313310

@@ -393,7 +390,6 @@ def symbolic_and(*args, **kwargs):
393390
if all_concrete_values([*args]):
394391
return numpy.logical_and(*args, **kwargs)
395392
else:
396-
#return symbolic_infix_operator('and', *args, **kwargs)
397393
return symbolic_operator('numpy.logical_and', *args, **kwargs)
398394

399395

@@ -402,33 +398,27 @@ def symbolic_or(*args, **kwargs):
402398
if all_concrete_values([*args]):
403399
return numpy.logical_or(*args, **kwargs)
404400
else:
405-
#return '(' + symbolic_infix_operator('or', *args, **kwargs) + ')'
406401
return symbolic_operator('numpy.logical_or', *args, **kwargs)
407402

408403

409404
def symbolic_xor(*args, **kwargs):
410405
if all_concrete_values([*args]):
411406
return numpy.logical_xor(*args, **kwargs)
412407
else:
413-
#return symbolic_infix_operator('^', *args, **kwargs)
414408
return symbolic_operator('numpy.logical_xor', *args, **kwargs)
415409

416410

417411
def symbolic_sum(*args, **kwargs):
418412
if all_concrete_values([*args]):
419413
return lax_reference.sum(*args, **kwargs)
420414
else:
421-
#return symbolic_infix_operator('+', *args, **kwargs)
422415
return symbolic_operator('lax_reference.sum', *args, **kwargs)
423416

424417

425418
def symbolic_broadcast_in_dim(*args, **kwargs):
426-
assert len(args) == 1
427-
arg = args[0]
428-
if isinstance(args, (list, tuple)):
429-
# reference implementation demands a numpy array
430-
arg = numpy.array(arg)
431-
return lax_reference.broadcast_in_dim(arg, **kwargs)
419+
# broadcast_in_dim requires numpy arrays not lists
420+
args = tuple([numpy.array(arg) if isinstance(arg, list) else arg for arg in args])
421+
return lax_reference.broadcast_in_dim(*args, **kwargs)
432422

433423

434424
def symbolic_reshape(*args, **kwargs):

0 commit comments

Comments
 (0)