|
1 | 1 | import numpy
|
2 | 2 | from plum import dispatch
|
| 3 | +import typing |
3 | 4 | import jax
|
4 | 5 | import jax._src.lax_reference as lax_reference
|
5 |
| -from neurallogic import primitives |
6 |
| - |
7 |
| - |
8 |
| -def to_boolean_value_string(x): |
9 |
| - if isinstance(x, bool): |
10 |
| - # x is a bool |
11 |
| - return 'True' if x else 'False' |
12 |
| - elif x == 1.0 or x == 0.0: |
13 |
| - # x is a float |
14 |
| - return 'True' if x == 1.0 else 'False' |
15 |
| - elif isinstance(x, str) and (x == '1' or x == '0'): |
16 |
| - # x is a string representing an integer |
17 |
| - return 'True' if x == '1' else 'False' |
18 |
| - elif isinstance(x, str) and (x == '1.0' or x == '0.0'): |
19 |
| - # x is a string representing a float |
20 |
| - return 'True' if x == '1.0' else 'False' |
21 |
| - elif isinstance(x, str) and (x == 'True' or x == 'False'): |
22 |
| - # x is a string representing a boolean |
| 6 | +import jaxlib |
| 7 | + |
| 8 | +def convert_iterable_type(x: list, new_type): |
| 9 | + if new_type == list: |
23 | 10 | return x
|
24 |
| - elif isinstance(x, numpy.ndarray) or isinstance(x, jax.numpy.ndarray) or isinstance(x, list) or isinstance(x, tuple): |
25 |
| - # We only operate on scalars |
26 |
| - raise ValueError( |
27 |
| - f"to_boolean_value_string only operates on scalars, but got {x}") |
| 11 | + elif new_type == numpy.ndarray: |
| 12 | + return numpy.array(x, dtype=object) |
| 13 | + elif new_type == jax.numpy.ndarray: |
| 14 | + return jax.numpy.array(x, dtype=object) |
| 15 | + elif new_type == jaxlib.xla_extension.DeviceArray: |
| 16 | + return jax.numpy.array(x, dtype=object) |
28 | 17 | else:
|
29 |
| - # x is not interpretable as a boolean |
30 |
| - return str(x) |
| 18 | + raise NotImplementedError(f"Cannot convert type {type(x)} to type {new_type}") |
| 19 | + |
| 20 | +@dispatch |
| 21 | +def map_at_elements(x: list, func: typing.Callable): |
| 22 | + return convert_iterable_type([map_at_elements(item, func) for item in x], type(x)) |
| 23 | + |
| 24 | +@dispatch |
| 25 | +def map_at_elements(x: numpy.ndarray, func: typing.Callable): |
| 26 | + return convert_iterable_type([map_at_elements(item, func) for item in x], type(x)) |
| 27 | + |
| 28 | +@dispatch |
| 29 | +def map_at_elements(x: jax.numpy.ndarray, func: typing.Callable): |
| 30 | + if x.ndim == 0: |
| 31 | + return func(x.item()) |
| 32 | + return convert_iterable_type([map_at_elements(item, func) for item in x], type(x)) |
| 33 | + |
| 34 | +@dispatch |
| 35 | +def map_at_elements(x: str, func: typing.Callable): |
| 36 | + return func(x) |
| 37 | + |
| 38 | +@dispatch |
| 39 | +def map_at_elements(x, func: typing.Callable): |
| 40 | + return func(x) |
| 41 | + |
| 42 | +@dispatch |
| 43 | +def to_boolean_value_string(x: bool): |
| 44 | + return 'True' if x else 'False' |
31 | 45 |
|
| 46 | +@dispatch |
| 47 | +def to_boolean_value_string(x: numpy.bool_): |
| 48 | + return 'True' if x else 'False' |
32 | 49 |
|
33 |
| -def to_boolean_symbolic_values_impl(x): |
34 |
| - """Converts an arbitrary vector of arbitrary values to a list where |
35 |
| - every boolean-interpretable value gets converted to the strings "True" or "False". |
| 50 | +@dispatch |
| 51 | +def to_boolean_value_string(x: int): |
| 52 | + return 'True' if x == 1.0 else 'False' |
36 | 53 |
|
37 |
| - Args: |
38 |
| - x: The vector of values to convert (or can be a single value in the degenerate case) |
| 54 | +@dispatch |
| 55 | +def to_boolean_value_string(x: float): |
| 56 | + return 'True' if x == 1.0 else 'False' |
39 | 57 |
|
40 |
| - Returns: |
41 |
| - A list representation of the input, where boolean-interpretable |
42 |
| - values are converted to "True" or "False". |
43 |
| - """ |
44 |
| - if isinstance(x, numpy.ndarray) or isinstance(x, jax.numpy.ndarray) or isinstance(x, tuple): |
45 |
| - return to_boolean_symbolic_values_impl(x.tolist()) |
46 |
| - elif isinstance(x, list): |
47 |
| - return [to_boolean_symbolic_values_impl(y) for y in x] |
| 58 | +@dispatch |
| 59 | +def to_boolean_value_string(x: str): |
| 60 | + if x == '1' or x == '1.0' or x =='True': |
| 61 | + return 'True' |
| 62 | + elif x == '0' or x == '0.0' or x =='False': |
| 63 | + return 'False' |
48 | 64 | else:
|
49 |
| - return to_boolean_value_string(x) |
| 65 | + return x |
50 | 66 |
|
51 | 67 |
|
52 | 68 | def to_boolean_symbolic_values(x):
|
53 |
| - """Converts an arbitrary vector of arbitrary values to a numpy array where |
54 |
| - every boolean-interpretable value gets converted to the strings "True" or "False". |
55 |
| -
|
56 |
| - Args: |
57 |
| - x: The vector of values to convert (or can be a single value in the degenerate case) |
58 |
| -
|
59 |
| - Returns: |
60 |
| - A numpy array representation of the input, where boolean-interpretable |
61 |
| - values are converted to "True" or "False". |
62 |
| - """ |
63 |
| - x = to_boolean_symbolic_values_impl(x) |
64 |
| - if isinstance(x, list): |
65 |
| - x = numpy.array(x, dtype=object) |
66 |
| - else: |
67 |
| - x = numpy.array([x], dtype=object) |
68 |
| - return x |
| 69 | + return map_at_elements(x, to_boolean_value_string) |
69 | 70 |
|
70 | 71 |
|
71 | 72 | @dispatch
|
@@ -167,42 +168,12 @@ def symbolic_broadcast_in_dim(*args, **kwargs):
|
167 | 168 | return lax_reference.broadcast_in_dim(*args, **kwargs)
|
168 | 169 |
|
169 | 170 |
|
170 |
| -def is_iterable(obj): |
171 |
| - try: |
172 |
| - iter(obj) |
173 |
| - return True |
174 |
| - except TypeError: |
175 |
| - return False |
176 |
| - |
177 |
| -# TODO: unify this way of walking a nested iterable with the code above |
178 |
| -def apply_func_to_nested_impl(iterable, func): |
179 |
| - if isinstance(iterable, (numpy.ndarray, jax.numpy.ndarray)): |
180 |
| - iterable = iterable.tolist() |
181 |
| - if is_iterable(iterable): |
182 |
| - transformed = [] |
183 |
| - for item in iterable: |
184 |
| - if isinstance(item, list): |
185 |
| - transformed.append(apply_func_to_nested_impl(item, func)) |
186 |
| - else: |
187 |
| - transformed.append(func(item)) |
188 |
| - return transformed |
189 |
| - else: |
190 |
| - return func(iterable) |
191 |
| - |
192 |
| -def apply_func_to_nested(iterable, func): |
193 |
| - iterable_type = type(iterable) |
194 |
| - r = apply_func_to_nested_impl(iterable, func) |
195 |
| - if iterable_type == numpy.ndarray: |
196 |
| - r = numpy.array(r, dtype=object) |
197 |
| - assert type(r) == iterable_type |
198 |
| - return r |
199 |
| - |
200 | 171 | def symbolic_convert_element_type_impl(x, dtype):
|
201 | 172 | if dtype == numpy.int32 or dtype == numpy.int64:
|
202 | 173 | dtype = "int"
|
203 | 174 | def convert(x):
|
204 | 175 | return f"{dtype}({x})"
|
205 |
| - return apply_func_to_nested(x, convert) |
| 176 | + return map_at_elements(x, convert) |
206 | 177 |
|
207 | 178 |
|
208 | 179 | # TODO: add a test for this
|
|
0 commit comments