Skip to content

Commit 5123507

Browse files
committed
fix bug in xor; clean-up symbolic primitives
1 parent 049f973 commit 5123507

File tree

2 files changed

+50
-94
lines changed

2 files changed

+50
-94
lines changed

neurallogic/hard_xor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def soft_xor_neuron(w, x):
2828

2929
def xor(x, y):
3030
return jax.numpy.minimum(jax.numpy.maximum(x, y), 1.0 - jax.numpy.minimum(x, y))
31-
x = jax.lax.reduce(x, jax.numpy.float16(0.0), xor, (0,))
31+
x = jax.lax.reduce(x, jax.numpy.array(0.0), xor, (0,))
3232
return x
3333

3434

neurallogic/symbolic_primitives.py

Lines changed: 49 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Callable
2+
13
import jax
24
import jax._src.lax_reference as lax_reference
35
import numpy
@@ -19,152 +21,83 @@ def all_concrete_values(data):
1921
return True
2022

2123

22-
def symbolic_not(*args, **kwargs):
24+
def symbolic_f(concrete_f: Callable, symbolic_f_name: str, *args, **kwargs):
2325
if all_concrete_values([*args]):
24-
return numpy.logical_not(*args, **kwargs)
26+
return concrete_f(*args, **kwargs)
2527
else:
26-
return symbolic_operator.symbolic_operator('numpy.logical_not', *args, **kwargs)
28+
return symbolic_operator.symbolic_operator(symbolic_f_name, *args, **kwargs)
29+
30+
31+
def symbolic_not(*args, **kwargs):
32+
return symbolic_f(numpy.logical_not, 'numpy.logical_not', *args, **kwargs)
2733

2834

2935
def symbolic_eq(*args, **kwargs):
30-
if all_concrete_values([*args]):
31-
return lax_reference.eq(*args, **kwargs)
32-
else:
33-
return symbolic_operator.symbolic_operator('lax_reference.eq', *args, **kwargs)
36+
return symbolic_f(lax_reference.eq, 'lax_reference.eq', *args, **kwargs)
3437

3538

3639
def symbolic_ne(*args, **kwargs):
37-
if all_concrete_values([*args]):
38-
return lax_reference.ne(*args, **kwargs)
39-
else:
40-
return symbolic_operator.symbolic_operator('lax_reference.ne', *args, **kwargs)
40+
return symbolic_f(lax_reference.ne, 'lax_reference.ne', *args, **kwargs)
4141

4242

4343
def symbolic_le(*args, **kwargs):
44-
if all_concrete_values([*args]):
45-
return lax_reference.le(*args, **kwargs)
46-
else:
47-
return symbolic_operator.symbolic_operator('lax_reference.le', *args, **kwargs)
44+
return symbolic_f(lax_reference.le, 'lax_reference.le', *args, **kwargs)
4845

4946

5047
def symbolic_lt(*args, **kwargs):
51-
if all_concrete_values([*args]):
52-
return lax_reference.lt(*args, **kwargs)
53-
else:
54-
return symbolic_operator.symbolic_operator('lax_reference.lt', *args, **kwargs)
48+
return symbolic_f(lax_reference.lt, 'lax_reference.lt', *args, **kwargs)
5549

5650

5751
def symbolic_ge(*args, **kwargs):
58-
if all_concrete_values([*args]):
59-
return lax_reference.ge(*args, **kwargs)
60-
else:
61-
return symbolic_operator.symbolic_operator('lax_reference.ge', *args, **kwargs)
52+
return symbolic_f(lax_reference.ge, 'lax_reference.ge', *args, **kwargs)
6253

6354

6455
def symbolic_gt(*args, **kwargs):
65-
if all_concrete_values([*args]):
66-
return lax_reference.gt(*args, **kwargs)
67-
else:
68-
return symbolic_operator.symbolic_operator('lax_reference.gt', *args, **kwargs)
56+
return symbolic_f(lax_reference.gt, 'lax_reference.gt', *args, **kwargs)
6957

7058

7159
def symbolic_abs(*args, **kwargs):
72-
if all_concrete_values([*args]):
73-
return lax_reference.abs(*args, **kwargs)
74-
else:
75-
return symbolic_operator.symbolic_operator('numpy.absolute', *args, **kwargs)
60+
return symbolic_f(lax_reference.abs, 'numpy.absolute', *args, **kwargs)
7661

7762

7863
def symbolic_add(*args, **kwargs):
79-
if all_concrete_values([*args]):
80-
return lax_reference.add(*args, **kwargs)
81-
else:
82-
return symbolic_operator.symbolic_operator('numpy.add', *args, **kwargs)
64+
return symbolic_f(lax_reference.add, 'numpy.add', *args, **kwargs)
8365

8466

8567
def symbolic_sub(*args, **kwargs):
86-
if all_concrete_values([*args]):
87-
return lax_reference.sub(*args, **kwargs)
88-
else:
89-
return symbolic_operator.symbolic_operator('numpy.subtract', *args, **kwargs)
68+
return symbolic_f(lax_reference.sub, 'numpy.subtract', *args, **kwargs)
9069

9170

9271
def symbolic_mul(*args, **kwargs):
93-
if all_concrete_values([*args]):
94-
return lax_reference.mul(*args, **kwargs)
95-
else:
96-
return symbolic_operator.symbolic_operator('numpy.multiply', *args, **kwargs)
72+
return symbolic_f(lax_reference.mul, 'numpy.multiply', *args, **kwargs)
9773

9874

9975
def symbolic_div(*args, **kwargs):
100-
if all_concrete_values([*args]):
101-
return lax_reference.div(*args, **kwargs)
102-
else:
103-
return symbolic_operator.symbolic_operator('lax_reference.div', *args, **kwargs)
76+
return symbolic_f(lax_reference.div, 'lax_reference.div', *args, **kwargs)
10477

10578

10679
def symbolic_max(*args, **kwargs):
107-
if all_concrete_values([*args]):
108-
return lax_reference.max(*args, **kwargs)
109-
else:
110-
r = symbolic_operator.symbolic_operator('numpy.maximum', *args, **kwargs)
111-
return r
80+
return symbolic_f(lax_reference.max, 'numpy.maximum', *args, **kwargs)
11281

11382

11483
def symbolic_min(*args, **kwargs):
115-
if all_concrete_values([*args]):
116-
return lax_reference.min(*args, **kwargs)
117-
else:
118-
return symbolic_operator.symbolic_operator('numpy.minimum', *args, **kwargs)
119-
120-
121-
def symbolic_select_n(*args, **kwargs):
122-
'''
123-
Important comment from lax.py
124-
# Caution! The select_n_p primitive has the *opposite* order of arguments to
125-
# select(). This is because it implements `select_n`.
126-
'''
127-
pred = args[0]
128-
on_true = args[1]
129-
on_false = args[2]
130-
if all_concrete_values([*args]):
131-
# swap order of on_true and on_false
132-
return lax_reference.select(pred, on_false, on_true)
133-
else:
134-
# swap order of on_true and on_false
135-
# TODO: need a more general solution to unquoting symbolic strings
136-
evaluable_pred = symbolic_representation.symbolic_representation(pred)
137-
evaluable_on_true = symbolic_representation.symbolic_representation(on_true)
138-
evaluable_on_false = symbolic_representation.symbolic_representation(on_false)
139-
return f'lax_reference.select({evaluable_pred}, {evaluable_on_false}, {evaluable_on_true})'
84+
return symbolic_f(lax_reference.min, 'numpy.minimum', *args, **kwargs)
14085

14186

14287
def symbolic_and(*args, **kwargs):
143-
if all_concrete_values([*args]):
144-
return numpy.logical_and(*args, **kwargs)
145-
else:
146-
return symbolic_operator.symbolic_operator('numpy.logical_and', *args, **kwargs)
88+
return symbolic_f(numpy.logical_and, 'numpy.logical_and', *args, **kwargs)
14789

14890

14991
def symbolic_or(*args, **kwargs):
150-
if all_concrete_values([*args]):
151-
return numpy.logical_or(*args, **kwargs)
152-
else:
153-
return symbolic_operator.symbolic_operator('numpy.logical_or', *args, **kwargs)
92+
return symbolic_f(numpy.logical_or, 'numpy.logical_or', *args, **kwargs)
15493

15594

15695
def symbolic_xor(*args, **kwargs):
157-
if all_concrete_values([*args]):
158-
return numpy.logical_xor(*args, **kwargs)
159-
else:
160-
return symbolic_operator.symbolic_operator('numpy.logical_xor', *args, **kwargs)
96+
return symbolic_f(numpy.logical_xor, 'numpy.logical_xor', *args, **kwargs)
16197

16298

16399
def symbolic_sum(*args, **kwargs):
164-
if all_concrete_values([*args]):
165-
return lax_reference.sum(*args, **kwargs)
166-
else:
167-
return symbolic_operator.symbolic_operator('lax_reference.sum', *args, **kwargs)
100+
return symbolic_f(lax_reference.sum, 'lax_reference.sum', *args, **kwargs)
168101

169102

170103
def symbolic_broadcast_in_dim(*args, **kwargs):
@@ -194,6 +127,29 @@ def convert_element_type(x, dtype):
194127
return convert_element_type(*args, dtype=kwargs['new_dtype'])
195128

196129

130+
def symbolic_select_n(*args, **kwargs):
131+
'''
132+
Important comment from lax.py
133+
# Caution! The select_n_p primitive has the *opposite* order of arguments to
134+
# select(). This is because it implements `select_n`.
135+
'''
136+
pred = args[0]
137+
on_true = args[1]
138+
on_false = args[2]
139+
if all_concrete_values([*args]):
140+
# swap order of on_true and on_false
141+
return lax_reference.select(pred, on_false, on_true)
142+
else:
143+
# swap order of on_true and on_false
144+
# TODO: need a more general solution to unquoting symbolic strings
145+
evaluable_pred = symbolic_representation.symbolic_representation(pred)
146+
evaluable_on_true = symbolic_representation.symbolic_representation(
147+
on_true)
148+
evaluable_on_false = symbolic_representation.symbolic_representation(
149+
on_false)
150+
return f'lax_reference.select({evaluable_pred}, {evaluable_on_false}, {evaluable_on_true})'
151+
152+
197153
def make_symbolic_reducer(py_binop, init_val):
198154
# This function is a hack to get around the fact that JAX doesn't
199155
# support symbolic reduction operations. It takes a symbolic reduction

0 commit comments

Comments
 (0)