Skip to content

Commit 8efc75e

Browse files
committed
fix bug
1 parent 928dde7 commit 8efc75e

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

neurallogic/symbolic_primitives.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import jax._src.lax_reference as lax_reference
66
import jaxlib
77

8+
89
def convert_iterable_type(x: list, new_type):
910
if new_type == list:
1011
return x
@@ -15,51 +16,62 @@ def convert_iterable_type(x: list, new_type):
1516
elif new_type == jaxlib.xla_extension.DeviceArray:
1617
return jax.numpy.array(x, dtype=object)
1718
else:
18-
raise NotImplementedError(f"Cannot convert type {type(x)} to type {new_type}")
19+
raise NotImplementedError(
20+
f"Cannot convert type {type(x)} to type {new_type}")
21+
1922

2023
@dispatch
2124
def map_at_elements(x: list, func: typing.Callable):
2225
return convert_iterable_type([map_at_elements(item, func) for item in x], type(x))
2326

27+
2428
@dispatch
2529
def map_at_elements(x: numpy.ndarray, func: typing.Callable):
2630
return convert_iterable_type([map_at_elements(item, func) for item in x], type(x))
2731

32+
2833
@dispatch
2934
def map_at_elements(x: jax.numpy.ndarray, func: typing.Callable):
3035
if x.ndim == 0:
3136
return func(x.item())
3237
return convert_iterable_type([map_at_elements(item, func) for item in x], type(x))
3338

39+
3440
@dispatch
3541
def map_at_elements(x: str, func: typing.Callable):
3642
return func(x)
3743

44+
3845
@dispatch
3946
def map_at_elements(x, func: typing.Callable):
4047
return func(x)
4148

49+
4250
@dispatch
4351
def to_boolean_value_string(x: bool):
4452
return 'True' if x else 'False'
4553

54+
4655
@dispatch
4756
def to_boolean_value_string(x: numpy.bool_):
4857
return 'True' if x else 'False'
4958

59+
5060
@dispatch
5161
def to_boolean_value_string(x: int):
5262
return 'True' if x == 1.0 else 'False'
5363

64+
5465
@dispatch
5566
def to_boolean_value_string(x: float):
5667
return 'True' if x == 1.0 else 'False'
5768

69+
5870
@dispatch
5971
def to_boolean_value_string(x: str):
60-
if x == '1' or x == '1.0' or x =='True':
72+
if x == '1' or x == '1.0' or x == 'True':
6173
return 'True'
62-
elif x == '0' or x == '0.0' or x =='False':
74+
elif x == '0' or x == '0.0' or x == 'False':
6375
return 'False'
6476
else:
6577
return x
@@ -86,7 +98,6 @@ def unary_operator(operator: str, x: list):
8698

8799
@dispatch
88100
def binary_infix_operator(operator: str, a: str, b: str, bracket: bool = False) -> str:
89-
# We need to specify bracket because Python cannot evaluate expressions with too many nested parantheses
90101
if bracket:
91102
return f"({a}) {operator} ({b})"
92103
return f"{a} {operator} {b}"
@@ -160,17 +171,17 @@ def symbolic_sum(*args, **kwargs):
160171
else:
161172
return binary_infix_operator("+", *args, **kwargs)
162173

163-
# Uses the lax reference implementation of broadcast_in_dim to
164-
# implement a symbolic version of broadcast_in_dim
165-
166174

167175
def symbolic_broadcast_in_dim(*args, **kwargs):
176+
# Uses the lax reference implementation of broadcast_in_dim to
177+
# implement a symbolic version of broadcast_in_dim
168178
return lax_reference.broadcast_in_dim(*args, **kwargs)
169179

170180

171181
def symbolic_convert_element_type_impl(x, dtype):
172182
if dtype == numpy.int32 or dtype == numpy.int64:
173183
dtype = "int"
184+
174185
def convert(x):
175186
return f"{dtype}({x})"
176187
return map_at_elements(x, convert)
@@ -187,13 +198,13 @@ def symbolic_convert_element_type(*args, **kwargs):
187198
return symbolic_convert_element_type_impl(*args, dtype=kwargs['new_dtype'])
188199

189200

190-
# This function is a hack to get around the fact that JAX doesn't
191-
# support symbolic reduction operations. It takes a symbolic reduction
192-
# operation and a symbolic initial value and returns a function that
193-
# performs the reduction operation on a numpy array.
194201

195202

196203
def make_symbolic_reducer(py_binop, init_val):
204+
# This function is a hack to get around the fact that JAX doesn't
205+
# support symbolic reduction operations. It takes a symbolic reduction
206+
# operation and a symbolic initial value and returns a function that
207+
# performs the reduction operation on a numpy array.
197208
def reducer(operand, axis):
198209
# axis=None means we are reducing over all axes of the operand.
199210
axis = range(numpy.ndim(operand)) if axis is None else axis

tests/test_sym_gen.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
def nln(type, x, width):
99
x = hard_or.or_layer(type)(width)(x)
10-
# if not_layer has size "width" then this fails. why?
1110
x = hard_not.not_layer(type)(10)(x)
1211
x = primitives.nl_ravel(type)(x)
1312
x = harden_layer.harden_layer(type)(x)
@@ -26,7 +25,7 @@ def test_sym_gen():
2625
test_ds["image"], (test_ds["image"].shape[0], -1))
2726

2827
# Define width of network
29-
width = 4
28+
width = 10
3029
# Define the neural logic net
3130
soft, hard, _ = neural_logic_net.net(lambda type, x: nln(type, x, width))
3231
# Initialize a random number generator

0 commit comments

Comments
 (0)