Skip to content

Commit 4beee42

Browse files
committed
remove obsolete code
1 parent a2624d0 commit 4beee42

File tree

2 files changed

+29
-141
lines changed

2 files changed

+29
-141
lines changed

neurallogic/symbolic_generation.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,6 @@ def eval_jaxpr(symbolic, jaxpr, consts, *args):
121121
out : tuple
122122
The result of evaluating the jaxpr.
123123
'''
124-
#if symbolic:
125-
# numpy.set_printoptions(threshold=sys.maxsize)
126124

127125
# Mapping from variable -> value
128126
env = {}
@@ -297,14 +295,8 @@ def symbolic_expression(jaxpr, *args):
297295

298296
@dispatch
299297
def eval_symbolic_expression(x: str):
300-
# Setting up python evaluation context
301298
# TODO: distinguish python code-gen from other possible code-gen
302-
#eval_str = 'import numpy\n'
303-
#eval_str = 'import jax._src.lax_reference as lax_reference\n'
304299
eval_str = x.replace('inf', 'numpy.inf')
305-
#print(f'evaluating\n{eval_str}')
306-
#return exec(eval_str, globals(), locals())
307-
#print(f'attempting to evaluate\n{eval_str}')
308300
return eval(eval_str)
309301

310302

neurallogic/symbolic_primitives.py

Lines changed: 29 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,6 @@
66
from plum import dispatch
77

88

9-
# TODO: remove me?
10-
def convert_element_type(x, dtype):
11-
return x
12-
if dtype == numpy.int32 or dtype == numpy.int64:
13-
#dtype = 'int'
14-
return x
15-
elif dtype == bool:
16-
#dtype = 'bool'
17-
#dtype = 'numpy.all'
18-
# Don't force to bool because sometimes x is a tensor that is used in a numpy.where clause
19-
return x
20-
elif dtype == numpy.float32:
21-
#dtype = 'float'
22-
return x
23-
else:
24-
raise NotImplementedError(
25-
f'Symbolic conversion of type {type(x)} to {dtype} not implemented'
26-
)
27-
28-
def convert(x):
29-
return f'{dtype}({x})'
30-
31-
return map_at_elements(x, convert)
32-
33-
349
# TODO: allow func callable to control the type of the numpy.array or jax.numpy.array
3510

3611
# map_at_elements should alter the elements but not the type of the container
@@ -88,49 +63,6 @@ def map_at_elements(x: tuple, func: typing.Callable):
8863
return tuple(map_at_elements(list(x), func))
8964

9065

91-
'''
92-
@dispatch
93-
def to_boolean_value_string(x: bool):
94-
return 'True' if x else 'False'
95-
96-
97-
@dispatch
98-
def to_boolean_value_string(x: numpy.bool_):
99-
return 'True' if x else 'False'
100-
101-
102-
@dispatch
103-
def to_boolean_value_string(x: int):
104-
return 'True' if x >= 1 else 'False'
105-
106-
107-
@dispatch
108-
def to_boolean_value_string(x: float):
109-
return 'True' if x >= 1.0 else 'False'
110-
111-
112-
@dispatch
113-
def to_boolean_value_string(x: str):
114-
if x == '1' or x == '1.0' or x == 'True':
115-
return 'True'
116-
elif x == '0' or x == '0.0' or x == 'False':
117-
return 'False'
118-
else:
119-
return x
120-
121-
@dispatch
122-
def to_numeric_value(x):
123-
if x == 'True' or x:
124-
return 1
125-
elif x == 'False' or not x:
126-
return 0
127-
elif isinstance(x, int) or isinstance(x, float):
128-
return x
129-
else:
130-
return 0
131-
'''
132-
133-
# TODO: add tests, and handle more cases
13466
@dispatch
13567
def symbolic_representation(x: numpy.ndarray):
13668
return repr(x).replace('array', 'numpy.array').replace('\n', '').replace('float32', 'numpy.float32').replace('\'', '')
@@ -146,47 +78,69 @@ def symbolic_operator(operator: str, x: str) -> str:
14678
return f'{operator}({x})'.replace('\'', '')
14779

14880

149-
# XXX
15081
@dispatch
15182
def symbolic_operator(operator: str, x: float, y: str):
15283
return symbolic_operator(operator, str(x), y)
84+
85+
15386
@dispatch
15487
def symbolic_operator(operator: str, x: str, y: float):
15588
return symbolic_operator(operator, x, str(y))
89+
90+
15691
@dispatch
15792
def symbolic_operator(operator: str, x: float, y: numpy.ndarray):
15893
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
94+
95+
15996
@dispatch
16097
def symbolic_operator(operator: str, x: numpy.ndarray, y: numpy.ndarray):
16198
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
99+
100+
162101
@dispatch
163102
def symbolic_operator(operator: str, x: str, y: str):
164103
return f'{operator}({x}, {y})'.replace('\'', '')
104+
105+
165106
@dispatch
166107
def symbolic_operator(operator: str, x: numpy.ndarray, y: float):
167108
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
109+
110+
168111
@dispatch
169112
def symbolic_operator(operator: str, x: list, y: float):
170113
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
114+
115+
171116
@dispatch
172117
def symbolic_operator(operator: str, x: list, y: list):
173118
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
119+
120+
174121
@dispatch
175122
def symbolic_operator(operator: str, x: bool, y: str):
176123
return symbolic_operator(operator, str(x), y)
124+
125+
177126
@dispatch
178127
def symbolic_operator(operator: str, x: str, y: numpy.ndarray):
179128
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
129+
130+
180131
@dispatch
181132
def symbolic_operator(operator: str, x: str, y: int):
182133
return symbolic_operator(operator, x, str(y))
134+
135+
183136
@dispatch
184137
def symbolic_operator(operator: str, x: list, y: numpy.ndarray):
185138
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
139+
140+
186141
@dispatch
187142
def symbolic_operator(operator: str, x: numpy.ndarray, y: jax.numpy.ndarray):
188143
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
189-
# XXX
190144

191145

192146
@dispatch
@@ -199,59 +153,6 @@ def symbolic_operator(operator: str, x: list):
199153
return symbolic_operator(operator, numpy.array(x))
200154

201155

202-
# TODO: remove infix_operator?
203-
"""
204-
@dispatch
205-
def symbolic_infix_operator(operator: str, a: str, b: str) -> str:
206-
return f'{a} {operator} {b}'.replace('\'', '')
207-
208-
209-
@dispatch
210-
def symbolic_infix_operator(operator: str, a: numpy.ndarray, b: numpy.ndarray):
211-
return numpy.vectorize(symbolic_infix_operator, otypes=[object])(operator, a, b)
212-
213-
214-
@dispatch
215-
def symbolic_infix_operator(operator: str, a: list, b: numpy.ndarray):
216-
return symbolic_infix_operator(operator, numpy.array(a), b)
217-
218-
219-
@dispatch
220-
def symbolic_infix_operator(operator: str, a: numpy.ndarray, b: list):
221-
return symbolic_infix_operator(operator, a, numpy.array(b))
222-
223-
224-
@dispatch
225-
def symbolic_infix_operator(operator: str, a: str, b: int):
226-
return symbolic_infix_operator(operator, a, str(b))
227-
228-
229-
@dispatch
230-
def symbolic_infix_operator(operator: str, a: numpy.ndarray, b: float):
231-
return symbolic_infix_operator(operator, a, str(b))
232-
233-
234-
@dispatch
235-
def symbolic_infix_operator(operator: str, a: str, b: float):
236-
return symbolic_infix_operator(operator, a, str(b))
237-
238-
239-
@dispatch
240-
def symbolic_infix_operator(operator: str, a: numpy.ndarray, b: jax.numpy.ndarray):
241-
return symbolic_infix_operator(operator, a, numpy.array(b))
242-
243-
244-
# XXXX
245-
@dispatch
246-
def symbolic_infix_operator(operator: str, a: list, b: list):
247-
return symbolic_infix_operator(operator, numpy.array(a), numpy.array(b))
248-
# XXXX
249-
250-
@dispatch
251-
def symbolic_infix_operator(operator: str, a: bool, b: str):
252-
return symbolic_infix_operator(operator, str(a), b)
253-
"""
254-
255156
def all_concrete_values(data):
256157
if isinstance(data, str):
257158
return False
@@ -308,7 +209,6 @@ def symbolic_gt(*args, **kwargs):
308209
return symbolic_operator('lax_reference.gt', *args, **kwargs)
309210

310211

311-
312212
def symbolic_abs(*args, **kwargs):
313213
if all_concrete_values([*args]):
314214
return lax_reference.abs(*args, **kwargs)
@@ -359,7 +259,6 @@ def symbolic_min(*args, **kwargs):
359259
return symbolic_operator('numpy.minimum', *args, **kwargs)
360260

361261

362-
363262
def symbolic_select_n(*args, **kwargs):
364263
'''
365264
Important comment from lax.py
@@ -375,14 +274,9 @@ def symbolic_select_n(*args, **kwargs):
375274
else:
376275
# swap order of on_true and on_false
377276
# TODO: need a more general solution to unquoting symbolic strings
378-
#return f'numpy.where({repr(pred)}, {repr(on_false)}, {repr(on_true)})'.replace('\'', '')
379-
#return f'numpy.where({repr(pred)}, {repr(on_false)}, {repr(on_true)})'
380277
evaluable_pred = symbolic_representation(pred)
381-
#print(f'evaluable_pred: {evaluable_pred}')
382278
evaluable_on_true = symbolic_representation(on_true)
383-
#print(f'evaluable_on_true: {evaluable_on_true}')
384279
evaluable_on_false = symbolic_representation(on_false)
385-
#print(f'evaluable_on_false: {evaluable_on_false}')
386280
return f'lax_reference.select({evaluable_pred}, {evaluable_on_false}, {evaluable_on_true})'
387281

388282

@@ -391,7 +285,6 @@ def symbolic_and(*args, **kwargs):
391285
return numpy.logical_and(*args, **kwargs)
392286
else:
393287
return symbolic_operator('numpy.logical_and', *args, **kwargs)
394-
395288

396289

397290
def symbolic_or(*args, **kwargs):
@@ -417,7 +310,8 @@ def symbolic_sum(*args, **kwargs):
417310

418311
def symbolic_broadcast_in_dim(*args, **kwargs):
419312
# broadcast_in_dim requires numpy arrays not lists
420-
args = tuple([numpy.array(arg) if isinstance(arg, list) else arg for arg in args])
313+
args = tuple([numpy.array(arg) if isinstance(
314+
arg, list) else arg for arg in args])
421315
return lax_reference.broadcast_in_dim(*args, **kwargs)
422316

423317

@@ -435,7 +329,9 @@ def symbolic_convert_element_type(*args, **kwargs):
435329
# If so, we can use the lax reference implementation
436330
return lax_reference.convert_element_type(*args, dtype=kwargs['new_dtype'])
437331
else:
438-
# Otherwise, we use the symbolic implementation
332+
# Otherwise, we nop
333+
def convert_element_type(x, dtype):
334+
return x
439335
return convert_element_type(*args, dtype=kwargs['new_dtype'])
440336

441337

0 commit comments

Comments
 (0)