Skip to content

Commit 928dde7

Browse files
committed
clean-up handling of type conversions, and mapping of elements in arbitrary nested data structures
1 parent 49eaeda commit 928dde7

File tree

4 files changed

+72
-102
lines changed

4 files changed

+72
-102
lines changed

neurallogic/sym_gen.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,7 @@ def eval_jaxpr_concrete(jaxpr, *args):
129129

130130

131131
def eval_jaxpr_symbolic(jaxpr, *args):
132-
# Convert the literals to symbolic literals
133-
symbolic_jaxpr_literals = symbolic_primitives.to_boolean_symbolic_values(
134-
jaxpr.literals)
132+
symbolic_jaxpr_literals = safe_map(lambda x: numpy.array(x, dtype=object), jaxpr.literals)
133+
symbolic_jaxpr_literals = symbolic_primitives.to_boolean_symbolic_values(symbolic_jaxpr_literals)
135134
return eval_jaxpr(True, jaxpr.jaxpr, symbolic_jaxpr_literals, *args)
136135

neurallogic/symbolic_primitives.py

Lines changed: 56 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,72 @@
11
import numpy
22
from plum import dispatch
3+
import typing
34
import jax
45
import jax._src.lax_reference as lax_reference
5-
from neurallogic import primitives
6-
7-
8-
def to_boolean_value_string(x):
9-
if isinstance(x, bool):
10-
# x is a bool
11-
return 'True' if x else 'False'
12-
elif x == 1.0 or x == 0.0:
13-
# x is a float
14-
return 'True' if x == 1.0 else 'False'
15-
elif isinstance(x, str) and (x == '1' or x == '0'):
16-
# x is a string representing an integer
17-
return 'True' if x == '1' else 'False'
18-
elif isinstance(x, str) and (x == '1.0' or x == '0.0'):
19-
# x is a string representing a float
20-
return 'True' if x == '1.0' else 'False'
21-
elif isinstance(x, str) and (x == 'True' or x == 'False'):
22-
# x is a string representing a boolean
6+
import jaxlib
7+
8+
def convert_iterable_type(x: list, new_type):
9+
if new_type == list:
2310
return x
24-
elif isinstance(x, numpy.ndarray) or isinstance(x, jax.numpy.ndarray) or isinstance(x, list) or isinstance(x, tuple):
25-
# We only operate on scalars
26-
raise ValueError(
27-
f"to_boolean_value_string only operates on scalars, but got {x}")
11+
elif new_type == numpy.ndarray:
12+
return numpy.array(x, dtype=object)
13+
elif new_type == jax.numpy.ndarray:
14+
return jax.numpy.array(x, dtype=object)
15+
elif new_type == jaxlib.xla_extension.DeviceArray:
16+
return jax.numpy.array(x, dtype=object)
2817
else:
29-
# x is not interpretable as a boolean
30-
return str(x)
18+
raise NotImplementedError(f"Cannot convert type {type(x)} to type {new_type}")
19+
20+
@dispatch
21+
def map_at_elements(x: list, func: typing.Callable):
22+
return convert_iterable_type([map_at_elements(item, func) for item in x], type(x))
23+
24+
@dispatch
25+
def map_at_elements(x: numpy.ndarray, func: typing.Callable):
26+
return convert_iterable_type([map_at_elements(item, func) for item in x], type(x))
27+
28+
@dispatch
29+
def map_at_elements(x: jax.numpy.ndarray, func: typing.Callable):
30+
if x.ndim == 0:
31+
return func(x.item())
32+
return convert_iterable_type([map_at_elements(item, func) for item in x], type(x))
33+
34+
@dispatch
35+
def map_at_elements(x: str, func: typing.Callable):
36+
return func(x)
37+
38+
@dispatch
39+
def map_at_elements(x, func: typing.Callable):
40+
return func(x)
41+
42+
@dispatch
43+
def to_boolean_value_string(x: bool):
44+
return 'True' if x else 'False'
3145

46+
@dispatch
47+
def to_boolean_value_string(x: numpy.bool_):
48+
return 'True' if x else 'False'
3249

33-
def to_boolean_symbolic_values_impl(x):
34-
"""Converts an arbitrary vector of arbitrary values to a list where
35-
every boolean-interpretable value gets converted to the strings "True" or "False".
50+
@dispatch
51+
def to_boolean_value_string(x: int):
52+
return 'True' if x == 1.0 else 'False'
3653

37-
Args:
38-
x: The vector of values to convert (or can be a single value in the degenerate case)
54+
@dispatch
55+
def to_boolean_value_string(x: float):
56+
return 'True' if x == 1.0 else 'False'
3957

40-
Returns:
41-
A list representation of the input, where boolean-interpretable
42-
values are converted to "True" or "False".
43-
"""
44-
if isinstance(x, numpy.ndarray) or isinstance(x, jax.numpy.ndarray) or isinstance(x, tuple):
45-
return to_boolean_symbolic_values_impl(x.tolist())
46-
elif isinstance(x, list):
47-
return [to_boolean_symbolic_values_impl(y) for y in x]
58+
@dispatch
59+
def to_boolean_value_string(x: str):
60+
if x == '1' or x == '1.0' or x =='True':
61+
return 'True'
62+
elif x == '0' or x == '0.0' or x =='False':
63+
return 'False'
4864
else:
49-
return to_boolean_value_string(x)
65+
return x
5066

5167

5268
def to_boolean_symbolic_values(x):
53-
"""Converts an arbitrary vector of arbitrary values to a numpy array where
54-
every boolean-interpretable value gets converted to the strings "True" or "False".
55-
56-
Args:
57-
x: The vector of values to convert (or can be a single value in the degenerate case)
58-
59-
Returns:
60-
A numpy array representation of the input, where boolean-interpretable
61-
values are converted to "True" or "False".
62-
"""
63-
x = to_boolean_symbolic_values_impl(x)
64-
if isinstance(x, list):
65-
x = numpy.array(x, dtype=object)
66-
else:
67-
x = numpy.array([x], dtype=object)
68-
return x
69+
return map_at_elements(x, to_boolean_value_string)
6970

7071

7172
@dispatch
@@ -167,42 +168,12 @@ def symbolic_broadcast_in_dim(*args, **kwargs):
167168
return lax_reference.broadcast_in_dim(*args, **kwargs)
168169

169170

170-
def is_iterable(obj):
171-
try:
172-
iter(obj)
173-
return True
174-
except TypeError:
175-
return False
176-
177-
# TODO: unify this way of walking a nested iterable with the code above
178-
def apply_func_to_nested_impl(iterable, func):
179-
if isinstance(iterable, (numpy.ndarray, jax.numpy.ndarray)):
180-
iterable = iterable.tolist()
181-
if is_iterable(iterable):
182-
transformed = []
183-
for item in iterable:
184-
if isinstance(item, list):
185-
transformed.append(apply_func_to_nested_impl(item, func))
186-
else:
187-
transformed.append(func(item))
188-
return transformed
189-
else:
190-
return func(iterable)
191-
192-
def apply_func_to_nested(iterable, func):
193-
iterable_type = type(iterable)
194-
r = apply_func_to_nested_impl(iterable, func)
195-
if iterable_type == numpy.ndarray:
196-
r = numpy.array(r, dtype=object)
197-
assert type(r) == iterable_type
198-
return r
199-
200171
def symbolic_convert_element_type_impl(x, dtype):
201172
if dtype == numpy.int32 or dtype == numpy.int64:
202173
dtype = "int"
203174
def convert(x):
204175
return f"{dtype}({x})"
205-
return apply_func_to_nested(x, convert)
176+
return map_at_elements(x, convert)
206177

207178

208179
# TODO: add a test for this

tests/test_sym_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_sym_gen():
5353
# -- TEST 2: Compare the standard evaluation of the network with the non-standard symbolic evaluation of the jaxpr
5454
# Convert the hard input to a symbolic input
5555
symbolic_mock_input = symbolic_primitives.to_boolean_symbolic_values(
56-
hard_mock_input)
56+
numpy.array(hard_mock_input, dtype=object))
5757
# Symbolically evaluate the jaxpr with the symbolic input
5858
eval_symbolic_output = sym_gen.eval_jaxpr_symbolic(
5959
jaxpr, symbolic_mock_input)

tests/test_symbolic_primitives.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,42 +141,42 @@ def test_to_boolean_value_string():
141141

142142
def test_to_boolean_symbolic_values():
143143
output = symbolic_primitives.to_boolean_symbolic_values([1, 1])
144-
expected = numpy.array(["True", "True"])
144+
expected = ["True", "True"]
145145
assert numpy.array_equal(output, expected)
146146
output = symbolic_primitives.to_boolean_symbolic_values([0, 0])
147-
expected = numpy.array(["False", "False"])
147+
expected = ["False", "False"]
148148
assert numpy.array_equal(output, expected)
149149
output = symbolic_primitives.to_boolean_symbolic_values([True, False])
150-
expected = numpy.array(["True", "False"])
150+
expected = ["True", "False"]
151151
assert numpy.array_equal(output, expected)
152152
output = symbolic_primitives.to_boolean_symbolic_values([False, True])
153-
expected = numpy.array(["False", "True"])
153+
expected = ["False", "True"]
154154
assert numpy.array_equal(output, expected)
155155
output = symbolic_primitives.to_boolean_symbolic_values([1.0, 1.0])
156-
expected = numpy.array(["True", "True"])
156+
expected = ["True", "True"]
157157
assert numpy.array_equal(output, expected)
158158
output = symbolic_primitives.to_boolean_symbolic_values([0.0, 0.0])
159-
expected = numpy.array(["False", "False"])
159+
expected = ["False", "False"]
160160
assert numpy.array_equal(output, expected)
161161
output = symbolic_primitives.to_boolean_symbolic_values([[1, 1], [1, 1]])
162-
expected = numpy.array([["True", "True"], ["True", "True"]])
162+
expected = [["True", "True"], ["True", "True"]]
163163
assert numpy.array_equal(output, expected)
164164
output = symbolic_primitives.to_boolean_symbolic_values([[0, 0], [0, 0]])
165-
expected = numpy.array([["False", "False"], ["False", "False"]])
165+
expected = [["False", "False"], ["False", "False"]]
166166
assert numpy.array_equal(output, expected)
167167
output = symbolic_primitives.to_boolean_symbolic_values(
168168
[[True, False], [False, True]])
169-
expected = numpy.array([["True", "False"], ["False", "True"]])
169+
expected = [["True", "False"], ["False", "True"]]
170170
assert numpy.array_equal(output, expected)
171171
output = symbolic_primitives.to_boolean_symbolic_values(
172172
[[[1, 0, 1], [1, 0, 1]], [[1, 0, 0], [1, 0, 0]]])
173-
expected = numpy.array([[["True", "False", "True"], ["True", "False", "True"]], [
174-
["True", "False", "False"], ["True", "False", "False"]]])
173+
expected = [[["True", "False", "True"], ["True", "False", "True"]], [
174+
["True", "False", "False"], ["True", "False", "False"]]]
175175
assert numpy.array_equal(output, expected)
176176
output = symbolic_primitives.to_boolean_symbolic_values(
177177
[[[1, "f", 1], [1, "g", 1]], [[1, "h", 0], [1, "f", 0]]])
178-
expected = numpy.array([[["True", "f", "True"], ["True", "g", "True"]], [
179-
["True", "h", "False"], ["True", "f", "False"]]])
178+
expected = [[["True", "f", "True"], ["True", "g", "True"]], [
179+
["True", "h", "False"], ["True", "f", "False"]]]
180180
assert numpy.array_equal(output, expected)
181181

182182

0 commit comments

Comments
 (0)