Skip to content

Commit b2066db

Browse files
committed
document problem with select_n
1 parent 4ba18f2 commit b2066db

File tree

7 files changed

+76
-33
lines changed

7 files changed

+76
-33
lines changed

neurallogic/real_encoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,22 @@
55

66
from neurallogic import neural_logic_net, symbolic_generation
77

8-
# TODO: perhaps this can be simplified with a simple multiplication?
98
# TODO: implement a soft_real_decoder that can perhaps replace the port count approach
109

10+
1111
def soft_real_encoder(t: float, x: float) -> float:
1212
eps = 0.0000001
1313
# x should be in [0, 1]
14-
t = jax.numpy.clip(t, 0.0, 1.0)
14+
t = jax.numpy.clip(t, 0, 1)
1515
return jax.numpy.where(
1616
jax.numpy.isclose(t, x),
1717
0.5,
1818
# t != x
1919
jax.numpy.where(
2020
x < t,
21-
(1.0 / (2.0 * t + eps)) * x,
21+
(x / (2 * t + eps)),
2222
# x > t
23-
(1.0 / (2.0 * (1.0 - t) + eps)) * (x + 1.0 - 2.0 * t)
23+
(x + 1 - 2 * t) / (2 * (1 - t) + eps)
2424
)
2525
)
2626

neurallogic/symbolic_operator.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,21 @@ def symbolic_operator(operator: str, x: str) -> str:
88
return f'{operator}({x})'.replace('\'', '')
99

1010

11+
@dispatch
12+
def symbolic_operator(operator: str, x: str, y: str):
13+
return f'{operator}({x}, {y})'.replace('\'', '')
14+
15+
1116
@dispatch
1217
def symbolic_operator(operator: str, x: float, y: str):
1318
return symbolic_operator(operator, str(x), y)
1419

1520

21+
@dispatch
22+
def symbolic_operator(operator: str, x: int, y: str):
23+
return symbolic_operator(operator, str(x), y)
24+
25+
1626
@dispatch
1727
def symbolic_operator(operator: str, x: str, y: float):
1828
return symbolic_operator(operator, x, str(y))
@@ -28,11 +38,6 @@ def symbolic_operator(operator: str, x: numpy.ndarray, y: numpy.ndarray):
2838
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
2939

3040

31-
@dispatch
32-
def symbolic_operator(operator: str, x: str, y: str):
33-
return f'{operator}({x}, {y})'.replace('\'', '')
34-
35-
3641
@dispatch
3742
def symbolic_operator(operator: str, x: numpy.ndarray, y: float):
3843
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
@@ -86,4 +91,3 @@ def symbolic_operator(operator: str, x: numpy.ndarray):
8691
@dispatch
8792
def symbolic_operator(operator: str, x: list):
8893
return symbolic_operator(operator, numpy.array(x))
89-

neurallogic/symbolic_primitives.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import jax
44
import jax._src.lax_reference as lax_reference
5+
import jax._src.lax.lax as lax
56
import numpy
67

78
from neurallogic import symbolic_operator, symbolic_representation
@@ -59,31 +60,54 @@ def symbolic_gt(*args, **kwargs):
5960

6061

6162
def symbolic_abs(*args, **kwargs):
62-
return symbolic(lax_reference.abs, 'numpy.absolute', *args, **kwargs)
63+
return symbolic(lax_reference.abs, 'lax_reference.abs', *args, **kwargs)
64+
65+
66+
def symbolic_floor(*args, **kwargs):
67+
return symbolic(lax_reference.floor, 'lax_reference.floor', *args, **kwargs)
68+
69+
70+
def symbolic_ceil(*args, **kwargs):
71+
return symbolic(lax_reference.ceil, 'lax_reference.ceil', *args, **kwargs)
72+
73+
74+
def symbolic_round(*args, **kwargs):
75+
# The reference implementation only supports away from zero
76+
if kwargs['rounding_method'] == lax.RoundingMethod.AWAY_FROM_ZERO:
77+
return symbolic(lax_reference.round, 'lax_reference.round', *args)
78+
elif kwargs['rounding_method'] == lax.RoundingMethod.TO_NEAREST_EVEN:
79+
return symbolic(numpy.around, 'numpy.around', *args)
80+
else:
81+
raise NotImplementedError(
82+
f'rounding_method {str(kwargs["rounding_method"])} not implemented')
6383

6484

6585
def symbolic_add(*args, **kwargs):
66-
return symbolic(lax_reference.add, 'numpy.add', *args, **kwargs)
86+
return symbolic(lax_reference.add, 'lax_reference.add', *args, **kwargs)
6787

6888

6989
def symbolic_sub(*args, **kwargs):
70-
return symbolic(lax_reference.sub, 'numpy.subtract', *args, **kwargs)
90+
return symbolic(lax_reference.sub, 'lax_reference.sub', *args, **kwargs)
7191

7292

7393
def symbolic_mul(*args, **kwargs):
74-
return symbolic(lax_reference.mul, 'numpy.multiply', *args, **kwargs)
94+
return symbolic(lax_reference.mul, 'lax_reference.mul', *args, **kwargs)
7595

7696

7797
def symbolic_div(*args, **kwargs):
7898
return symbolic(lax_reference.div, 'lax_reference.div', *args, **kwargs)
7999

80100

101+
def symbolic_tan(*args, **kwargs):
102+
return symbolic(lax_reference.tan, 'lax_reference.tan', *args, **kwargs)
103+
104+
81105
def symbolic_max(*args, **kwargs):
82-
return symbolic(lax_reference.max, 'numpy.maximum', *args, **kwargs)
106+
return symbolic(lax_reference.max, 'lax_reference.max', *args, **kwargs)
83107

84108

85109
def symbolic_min(*args, **kwargs):
86-
return symbolic(lax_reference.min, 'numpy.minimum', *args, **kwargs)
110+
return symbolic(lax_reference.min, 'lax_reference.min', *args, **kwargs)
87111

88112

89113
def symbolic_and(*args, **kwargs):
@@ -143,8 +167,11 @@ def symbolic_select_n(*args, **kwargs):
143167
# swap order of on_true and on_false
144168
return lax_reference.select(pred, on_false, on_true)
145169
else:
170+
# TODO: to retain tensor structure we need to push down the select to the
171+
# lowest level of the symbolic expression tree. This is not currently
172+
# implemented.
173+
print('WARNING: symbolic_select_n is not fully implemented. This may not work as expected.')
146174
# swap order of on_true and on_false
147-
# TODO: need a more general solution to unquoting symbolic strings
148175
evaluable_pred = symbolic_representation.symbolic_representation(pred)
149176
evaluable_on_true = symbolic_representation.symbolic_representation(
150177
on_true)

neurallogic/symbolic_representation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy
22
from plum import dispatch
33

4+
# TODO: need a more general solution to unquoting symbolic strings
45

56
@dispatch
67
def symbolic_representation(x: numpy.ndarray):

tests/test_mnist.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def nln(type, x, width):
4747
def nln(type, x):
4848
num_classes = 10
4949

50-
x = hard_or.or_layer(type)(1000, nn.initializers.uniform(1.0), dtype=jax.numpy.float16)(x)
51-
x = hard_not.not_layer(type)(num_classes)(x)
50+
x = hard_or.or_layer(type)(1800, nn.initializers.uniform(1.0), dtype=jax.numpy.float16)(x)
51+
x = hard_not.not_layer(type)(1, dtype=jax.numpy.float16)(x)
5252
x = x.ravel()
5353
x = harden_layer.harden_layer(type)(x)
5454
x = x.reshape((num_classes, int(x.shape[0] / num_classes)))
@@ -128,13 +128,14 @@ def get_datasets():
128128
ds_builder.download_and_prepare()
129129
train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
130130
test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))
131-
train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0
132-
test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0
131+
# XXXX
132+
train_ds["image"] = (jnp.float32(train_ds["image"]) / 255.0)
133+
test_ds["image"] = (jnp.float32(test_ds["image"]) / 255.0)
133134
# TODO: we don't need to do this even when we don't use the real encoder
134135
# Use grayscale information
135136
# Convert the floating point values in [0,1] to binary values in {0,1}
136-
train_ds["image"] = jnp.round(train_ds["image"])
137-
test_ds["image"] = jnp.round(test_ds["image"])
137+
#train_ds["image"] = jnp.round(train_ds["image"])
138+
#test_ds["image"] = jnp.round(test_ds["image"])
138139
return train_ds, test_ds
139140

140141

tests/test_real_encoder.py

Lines changed: 15 additions & 8 deletions
Large diffs are not rendered by default.

tests/test_symbolic_generation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
import numpy
55

66
from neurallogic import (hard_and, hard_majority, hard_not, hard_or, hard_xor,
7-
harden, harden_layer, neural_logic_net,
7+
harden, harden_layer, neural_logic_net, real_encoder,
88
symbolic_generation)
99
from tests import utils
1010

1111

1212
def nln(type, x, width):
13-
# TODO: test real_encoder layer
13+
# Can't symbolically support this layer yet since the symbolic output is an unevaluated string that
14+
# lacks the correct tensor structure
15+
# x = real_encoder.real_encoder_layer(type)(2)(x)
16+
# x = x.ravel()
1417
x = hard_or.or_layer(type)(width)(x)
1518
x = hard_and.and_layer(type)(width)(x)
1619
x = hard_xor.xor_layer(type)(width)(x)

0 commit comments

Comments
 (0)