Skip to content

Commit 27a4d9f

Browse files
authored
Merge pull request #56 from github/multiple-backends
Document incompleteness of select_n
2 parents 4ba18f2 + 101d5cd commit 27a4d9f

8 files changed

+89
-42
lines changed

neurallogic/real_encoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,22 @@
55

66
from neurallogic import neural_logic_net, symbolic_generation
77

8-
# TODO: perhaps this can be simplified with a simple multiplication?
98
# TODO: implement a soft_real_decoder that can perhaps replace the port count approach
109

10+
1111
def soft_real_encoder(t: float, x: float) -> float:
1212
eps = 0.0000001
1313
# x should be in [0, 1]
14-
t = jax.numpy.clip(t, 0.0, 1.0)
14+
t = jax.numpy.clip(t, 0, 1)
1515
return jax.numpy.where(
1616
jax.numpy.isclose(t, x),
1717
0.5,
1818
# t != x
1919
jax.numpy.where(
2020
x < t,
21-
(1.0 / (2.0 * t + eps)) * x,
21+
(x / (2 * t + eps)),
2222
# x > t
23-
(1.0 / (2.0 * (1.0 - t) + eps)) * (x + 1.0 - 2.0 * t)
23+
(x + 1 - 2 * t) / (2 * (1 - t) + eps)
2424
)
2525
)
2626

neurallogic/symbolic_generation.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,17 @@ def symbolic_bind(prim, *args, **params):
2828
'lt': symbolic_primitives.symbolic_lt,
2929
'ge': symbolic_primitives.symbolic_ge,
3030
'gt': symbolic_primitives.symbolic_gt,
31-
'abs': symbolic_primitives.symbolic_abs,
3231
'add': symbolic_primitives.symbolic_add,
3332
'sub': symbolic_primitives.symbolic_sub,
3433
'mul': symbolic_primitives.symbolic_mul,
3534
'div': symbolic_primitives.symbolic_div,
35+
'tan': symbolic_primitives.symbolic_tan,
3636
'max': symbolic_primitives.symbolic_max,
3737
'min': symbolic_primitives.symbolic_min,
38+
'abs': symbolic_primitives.symbolic_abs,
39+
'round': symbolic_primitives.symbolic_round,
40+
'floor': symbolic_primitives.symbolic_floor,
41+
'ceil': symbolic_primitives.symbolic_ceil,
3842
'and': symbolic_primitives.symbolic_and,
3943
'or': symbolic_primitives.symbolic_or,
4044
'xor': symbolic_primitives.symbolic_xor,
@@ -93,10 +97,11 @@ def make_symbolic_flax_jaxpr(flax_layer, x):
9397
x = numpy.asarray(x, dtype=numpy.int32)
9498
# Make the jaxpr that corresponds to the flax layer
9599
jaxpr = make_symbolic_jaxpr(flax_layer, x)
96-
# Make a list of bit_weights and thresholds but only include each if they are not None
97-
bit_weights_and_thresholds = [x for x in [bit_weights, thresholds] if x is not None]
98-
# Replace the dummy numeric weights with the actual weights in the jaxpr
99-
jaxpr.consts = bit_weights_and_thresholds
100+
if hasattr(jaxpr, '_consts'):
101+
# Make a list of bit_weights and thresholds but only include each if they are not None
102+
bit_weights_and_thresholds = [x for x in [bit_weights, thresholds] if x is not None]
103+
# Replace the dummy numeric weights with the actual weights in the jaxpr
104+
jaxpr.__setattr__('_consts', bit_weights_and_thresholds)
100105
return jaxpr
101106

102107

@@ -190,10 +195,9 @@ def eval_jaxpr_impl(jaxpr):
190195
outvals = [outvals]
191196
symbolic_outvals = [symbolic_outvals]
192197
if not symbolic:
193-
# Check that the concrete and symbolic values are equal
194-
# print(
195-
# f'outvals: {outvals} and symbolic_outvals: {symbolic_outvals}'
196-
# )
198+
# Always check that the symbolic binding generates the same values as the
199+
# standard jax binding in order to detect bugs early.
200+
# print(f'outvals: {outvals} and symbolic_outvals: {symbolic_outvals}')
197201
assert numpy.allclose(
198202
numpy.array(outvals), symbolic_outvals, equal_nan=True
199203
)

neurallogic/symbolic_operator.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,21 @@ def symbolic_operator(operator: str, x: str) -> str:
88
return f'{operator}({x})'.replace('\'', '')
99

1010

11+
@dispatch
12+
def symbolic_operator(operator: str, x: str, y: str):
13+
return f'{operator}({x}, {y})'.replace('\'', '')
14+
15+
1116
@dispatch
1217
def symbolic_operator(operator: str, x: float, y: str):
1318
return symbolic_operator(operator, str(x), y)
1419

1520

21+
@dispatch
22+
def symbolic_operator(operator: str, x: int, y: str):
23+
return symbolic_operator(operator, str(x), y)
24+
25+
1626
@dispatch
1727
def symbolic_operator(operator: str, x: str, y: float):
1828
return symbolic_operator(operator, x, str(y))
@@ -28,11 +38,6 @@ def symbolic_operator(operator: str, x: numpy.ndarray, y: numpy.ndarray):
2838
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
2939

3040

31-
@dispatch
32-
def symbolic_operator(operator: str, x: str, y: str):
33-
return f'{operator}({x}, {y})'.replace('\'', '')
34-
35-
3641
@dispatch
3742
def symbolic_operator(operator: str, x: numpy.ndarray, y: float):
3843
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
@@ -86,4 +91,3 @@ def symbolic_operator(operator: str, x: numpy.ndarray):
8691
@dispatch
8792
def symbolic_operator(operator: str, x: list):
8893
return symbolic_operator(operator, numpy.array(x))
89-

neurallogic/symbolic_primitives.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import jax
44
import jax._src.lax_reference as lax_reference
5+
import jax._src.lax.lax as lax
56
import numpy
67

78
from neurallogic import symbolic_operator, symbolic_representation
@@ -59,31 +60,54 @@ def symbolic_gt(*args, **kwargs):
5960

6061

6162
def symbolic_abs(*args, **kwargs):
62-
return symbolic(lax_reference.abs, 'numpy.absolute', *args, **kwargs)
63+
return symbolic(lax_reference.abs, 'lax_reference.abs', *args, **kwargs)
64+
65+
66+
def symbolic_floor(*args, **kwargs):
67+
return symbolic(lax_reference.floor, 'lax_reference.floor', *args, **kwargs)
68+
69+
70+
def symbolic_ceil(*args, **kwargs):
71+
return symbolic(lax_reference.ceil, 'lax_reference.ceil', *args, **kwargs)
72+
73+
74+
def symbolic_round(*args, **kwargs):
75+
# The reference implementation only supports away from zero
76+
if kwargs['rounding_method'] == lax.RoundingMethod.AWAY_FROM_ZERO:
77+
return symbolic(lax_reference.round, 'lax_reference.round', *args)
78+
elif kwargs['rounding_method'] == lax.RoundingMethod.TO_NEAREST_EVEN:
79+
return symbolic(numpy.around, 'numpy.around', *args)
80+
else:
81+
raise NotImplementedError(
82+
f'rounding_method {str(kwargs["rounding_method"])} not implemented')
6383

6484

6585
def symbolic_add(*args, **kwargs):
66-
return symbolic(lax_reference.add, 'numpy.add', *args, **kwargs)
86+
return symbolic(lax_reference.add, 'lax_reference.add', *args, **kwargs)
6787

6888

6989
def symbolic_sub(*args, **kwargs):
70-
return symbolic(lax_reference.sub, 'numpy.subtract', *args, **kwargs)
90+
return symbolic(lax_reference.sub, 'lax_reference.sub', *args, **kwargs)
7191

7292

7393
def symbolic_mul(*args, **kwargs):
74-
return symbolic(lax_reference.mul, 'numpy.multiply', *args, **kwargs)
94+
return symbolic(lax_reference.mul, 'lax_reference.mul', *args, **kwargs)
7595

7696

7797
def symbolic_div(*args, **kwargs):
7898
return symbolic(lax_reference.div, 'lax_reference.div', *args, **kwargs)
7999

80100

101+
def symbolic_tan(*args, **kwargs):
102+
return symbolic(lax_reference.tan, 'lax_reference.tan', *args, **kwargs)
103+
104+
81105
def symbolic_max(*args, **kwargs):
82-
return symbolic(lax_reference.max, 'numpy.maximum', *args, **kwargs)
106+
return symbolic(lax_reference.max, 'lax_reference.max', *args, **kwargs)
83107

84108

85109
def symbolic_min(*args, **kwargs):
86-
return symbolic(lax_reference.min, 'numpy.minimum', *args, **kwargs)
110+
return symbolic(lax_reference.min, 'lax_reference.min', *args, **kwargs)
87111

88112

89113
def symbolic_and(*args, **kwargs):
@@ -143,8 +167,11 @@ def symbolic_select_n(*args, **kwargs):
143167
# swap order of on_true and on_false
144168
return lax_reference.select(pred, on_false, on_true)
145169
else:
170+
# TODO: to retain tensor structure we need to push down the select to the
171+
# lowest level of the symbolic expression tree. This is not currently
172+
# implemented.
173+
print('WARNING: symbolic_select_n is not fully implemented. This may not work as expected.')
146174
# swap order of on_true and on_false
147-
# TODO: need a more general solution to unquoting symbolic strings
148175
evaluable_pred = symbolic_representation.symbolic_representation(pred)
149176
evaluable_on_true = symbolic_representation.symbolic_representation(
150177
on_true)

neurallogic/symbolic_representation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy
22
from plum import dispatch
33

4+
# TODO: need a more general solution to unquoting symbolic strings
45

56
@dispatch
67
def symbolic_representation(x: numpy.ndarray):

tests/test_mnist.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def nln(type, x, width):
4747
def nln(type, x):
4848
num_classes = 10
4949

50-
x = hard_or.or_layer(type)(1000, nn.initializers.uniform(1.0), dtype=jax.numpy.float16)(x)
51-
x = hard_not.not_layer(type)(num_classes)(x)
50+
x = hard_or.or_layer(type)(1800, nn.initializers.uniform(1.0), dtype=jax.numpy.float16)(x)
51+
x = hard_not.not_layer(type)(1, dtype=jax.numpy.float16)(x)
5252
x = x.ravel()
5353
x = harden_layer.harden_layer(type)(x)
5454
x = x.reshape((num_classes, int(x.shape[0] / num_classes)))
@@ -128,13 +128,14 @@ def get_datasets():
128128
ds_builder.download_and_prepare()
129129
train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
130130
test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))
131-
train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0
132-
test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0
131+
# XXXX
132+
train_ds["image"] = (jnp.float32(train_ds["image"]) / 255.0)
133+
test_ds["image"] = (jnp.float32(test_ds["image"]) / 255.0)
133134
# TODO: we don't need to do this even when we don't use the real encoder
134135
# Use grayscale information
135136
# Convert the floating point values in [0,1] to binary values in {0,1}
136-
train_ds["image"] = jnp.round(train_ds["image"])
137-
test_ds["image"] = jnp.round(test_ds["image"])
137+
#train_ds["image"] = jnp.round(train_ds["image"])
138+
#test_ds["image"] = jnp.round(test_ds["image"])
138139
return train_ds, test_ds
139140

140141

0 commit comments

Comments
 (0)