Commit f247891

Merge pull request #54 from github/mnist-exp
Real encoder
2 parents c1774da + 27aff8c commit f247891

21 files changed: +1242 −708 lines changed

neurallogic/hard_and.py

Lines changed: 2 additions & 3 deletions
@@ -1,6 +1,5 @@
 from typing import Any
 
-import numpy
 import jax
 from flax import linen as nn
 from typing import Callable
@@ -70,7 +69,7 @@ class SoftAndLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('weights', self.weights_init,
+        weights = self.param('bit_weights', self.weights_init,
                              weights_shape, self.dtype)
         x = jax.numpy.asarray(x, self.dtype)
         return soft_and_layer(weights, x)
@@ -90,7 +89,7 @@ class HardAndLayer(nn.Module):
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
-            'weights', nn.initializers.constant(0.0), weights_shape)
+            'bit_weights', nn.initializers.constant(0.0), weights_shape)
         return hard_and_layer(weights, x)
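The substantive change here is the parameter rename from 'weights' to 'bit_weights', which lets harden() (see harden.py below) pick out bit-valued parameters by prefix. A minimal usage sketch, assuming only the Flax API visible in this diff; the layer size and input values are hypothetical:

import jax

from neurallogic import hard_and, harden

# Train-time: differentiable AND layer over soft bits in [0, 1].
soft = hard_and.SoftAndLayer(layer_size=3)
x = jax.numpy.array([0.9, 0.1, 0.8, 0.6])
params = soft.init(jax.random.PRNGKey(0), x)
soft_out = soft.apply(params, x)

# Inference-time: threshold the learned 'bit_weights' at 0.5 and run
# the discrete layer on hardened (boolean) inputs.
hard = hard_and.HardAndLayer(layer_size=3)
hard_out = hard.apply(harden.harden(params), harden.harden(x))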

neurallogic/hard_not.py

Lines changed: 10 additions & 10 deletions
@@ -28,14 +28,11 @@ def hard_not(w: bool, x: bool) -> bool:
 hard_not_neuron = jax.vmap(hard_not, 0, 0)
 
 
-
 soft_not_layer = jax.vmap(soft_not_neuron, (0, None), 0)
 
 hard_not_layer = jax.vmap(hard_not_neuron, (0, None), 0)
 
 
-
-
 class SoftNotLayer(nn.Module):
     layer_size: int
     weights_init: Callable = nn.initializers.uniform(1.0)
@@ -44,8 +41,7 @@ class SoftNotLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('weights', self.weights_init,
-                             weights_shape, self.dtype)
+        weights = self.param("bit_weights", self.weights_init, weights_shape, self.dtype)
         x = jax.numpy.asarray(x, self.dtype)
         return soft_not_layer(weights, x)
@@ -56,8 +52,7 @@ class HardNotLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param(
-            'weights', nn.initializers.constant(0.0), weights_shape)
+        weights = self.param("bit_weights", nn.initializers.constant(0.0), weights_shape)
         return hard_not_layer(weights, x)
@@ -73,7 +68,12 @@ def __call__(self, x):
 
 not_layer = neural_logic_net.select(
     lambda layer_size, weights_init=nn.initializers.uniform(
-        1.0), dtype=jax.numpy.float32: SoftNotLayer(layer_size, weights_init, dtype),
+        1.0
+    ), dtype=jax.numpy.float32: SoftNotLayer(layer_size, weights_init, dtype),
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: HardNotLayer(layer_size),
     lambda layer_size, weights_init=nn.initializers.uniform(
-        1.0), dtype=jax.numpy.float32: HardNotLayer(layer_size),
-    lambda layer_size, weights_init=nn.initializers.uniform(1.0), dtype=jax.numpy.float32: SymbolicNotLayer(layer_size))
+        1.0
+    ), dtype=jax.numpy.float32: SymbolicNotLayer(layer_size),
+)
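The not_layer selector keeps the same soft/hard/symbolic triple; only the lambda layout changes. The vmap stacking above it is worth spelling out: a neuron maps one gate weight over each input element, and a layer maps rows of gates over the same input. A shape sketch with a stand-in soft_not (the real definition lives earlier in this file; here we assume the usual differentiable NOT, where w == 1 passes x through and w == 0 negates it):

import jax

def soft_not(w, x):
    # Assumed stand-in definition, not the repo's: w gates negation.
    return w * x + (1.0 - w) * (1.0 - x)

soft_not_neuron = jax.vmap(soft_not, 0, 0)                # (n,), (n,) -> (n,)
soft_not_layer = jax.vmap(soft_not_neuron, (0, None), 0)  # (m, n), (n,) -> (m, n)

w = jax.numpy.full((3, 4), 0.9)            # 3 neurons, 4 gates each
x = jax.numpy.array([1.0, 0.0, 1.0, 0.5])
assert soft_not_layer(w, x).shape == (3, 4)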

neurallogic/hard_or.py

Lines changed: 2 additions & 2 deletions
@@ -56,7 +56,7 @@ class SoftOrLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('weights', self.weights_init, weights_shape, self.dtype)
+        weights = self.param('bit_weights', self.weights_init, weights_shape, self.dtype)
         x = jax.numpy.asarray(x, self.dtype)
         return soft_or_layer(weights, x)
@@ -66,7 +66,7 @@ class HardOrLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
+        weights = self.param('bit_weights', nn.initializers.constant(0.0), weights_shape)
         return hard_or_layer(weights, x)
 
 class SymbolicOrLayer:

neurallogic/harden.py

Lines changed: 33 additions & 6 deletions
@@ -2,47 +2,74 @@
 import jax
 import numpy
 from plum import dispatch
-from neurallogic import symbolic_primitives
+
+from neurallogic import map_at_elements
 
 
 def harden_float(x: float) -> bool:
     return x > 0.5
 
+
 harden_array = jax.vmap(harden_float, 0, 0)
 
+
 @dispatch
 def harden(x: float):
+    if numpy.isnan(x):
+        return x
     return harden_float(x)
 
+
 @dispatch
 def harden(x: list):
-    return symbolic_primitives.map_at_elements(x, harden_float)
+    return map_at_elements.map_at_elements(x, harden_float)
+
 
 @dispatch
 def harden(x: numpy.ndarray):
     return harden_array(x)
 
+
 @dispatch
 def harden(x: jax.numpy.ndarray):
     return harden_array(x)
 
+
 @dispatch
 def harden(x: dict):
-    return symbolic_primitives.map_at_elements(x, harden_float)
+    # Only harden parameters that explicitly represent bits
+    def conditional_harden(k, v):
+        if k.startswith("bit_"):
+            return map_at_elements.map_at_elements(v, harden)
+        elif isinstance(v, dict) or isinstance(v, flax.core.FrozenDict) or isinstance(v, list):
+            return harden(v)
+        return v
+
+    return {k: conditional_harden(k, v) for k, v in x.items()}
+
 
 @dispatch
 def harden(x: flax.core.FrozenDict):
-    return flax.core.FrozenDict(symbolic_primitives.map_at_elements(x.unfreeze(), harden_float))
+    return harden(x.unfreeze())
+
 
 @dispatch
 def harden(*args):
     if len(args) == 1:
         return harden(args[0])
     return tuple([harden(arg) for arg in args])
 
+
 @dispatch
 def map_keys_nested(f, d: dict) -> dict:
-    return {f(k): map_keys_nested(f, v) if isinstance(v, dict) else v for k, v in d.items()}
+    return {
+        f(k): map_keys_nested(f, v) if isinstance(v, dict) else v for k, v in d.items()
+    }
+
 
 def hard_weights(weights):
-    return flax.core.FrozenDict(map_keys_nested(lambda str: str.replace("Soft", "Hard"), harden(weights.unfreeze())))
+    return flax.core.FrozenDict(
+        map_keys_nested(
+            lambda str: str.replace("Soft", "Hard"), harden(weights.unfreeze())
+        )
+    )
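The new dict overload replaces blanket hardening with conditional hardening: only leaves whose key carries the bit_ prefix (the bit_weights renamed throughout this PR) are thresholded at 0.5, while real-valued parameters such as encoder thresholds pass through untouched; harden(float) also gains a NaN guard. A small sketch with a hypothetical parameter tree:

import jax

from neurallogic import harden

params = {
    "SoftAndLayer_0": {"bit_weights": jax.numpy.array([0.2, 0.7])},
    "SoftRealEncoderLayer_0": {"thresholds": jax.numpy.array([0.2, 0.7])},
}

hardened = harden.harden(params)
# hardened["SoftAndLayer_0"]["bit_weights"]        -> [False, True]
# hardened["SoftRealEncoderLayer_0"]["thresholds"] -> [0.2, 0.7], unchanged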

neurallogic/map_at_elements.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+import typing
+
+import jax
+import numpy
+from plum import dispatch
+
+
+@dispatch
+def map_at_elements(x: str, func: typing.Callable):
+    return func(x)
+
+
+@dispatch
+def map_at_elements(x: bool, func: typing.Callable):
+    return func(x)
+
+
+@dispatch
+def map_at_elements(x: numpy.bool_, func: typing.Callable):
+    return func(x)
+
+
+@dispatch
+def map_at_elements(x: float, func: typing.Callable):
+    return func(x)
+
+
+@dispatch
+def map_at_elements(x: numpy.float32, func: typing.Callable):
+    return func(x)
+
+
+@dispatch
+def map_at_elements(x: list, func: typing.Callable):
+    return [map_at_elements(item, func) for item in x]
+
+
+@dispatch
+def map_at_elements(x: numpy.ndarray, func: typing.Callable):
+    return numpy.array([map_at_elements(item, func) for item in x], dtype=object)
+
+
+@dispatch
+def map_at_elements(x: jax.numpy.ndarray, func: typing.Callable):
+    if x.ndim == 0:
+        return func(x.item())
+    return jax.numpy.array([map_at_elements(item, func) for item in x])
+
+
+@dispatch
+def map_at_elements(x: dict, func: typing.Callable):
+    return {k: map_at_elements(v, func) for k, v in x.items()}
+
+
+@dispatch
+def map_at_elements(x: tuple, func: typing.Callable):
+    return tuple(map_at_elements(list(x), func))
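map_at_elements is the structural recursion previously living in symbolic_primitives, now extracted into its own module: each overload either applies func to a scalar leaf or rebuilds the container around the recursion, so arbitrarily nested dicts, lists, tuples, and arrays keep their shape. For example:

from neurallogic import map_at_elements

nested = {"a": [0.2, 0.9], "b": (0.4, {"c": 0.6})}
result = map_at_elements.map_at_elements(nested, lambda v: v > 0.5)
# {"a": [False, True], "b": (False, {"c": True})}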

neurallogic/real_encoder.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+from typing import Callable
+
+import jax
+from flax import linen as nn
+
+from neurallogic import neural_logic_net, symbolic_generation
+
+
+def soft_real_encoder(t: float, x: float) -> float:
+    eps = 0.0000001
+    # x should be in [0, 1]
+    t = jax.numpy.clip(t, 0.0, 1.0)
+    return jax.numpy.where(
+        jax.numpy.isclose(t, x),
+        0.5,
+        # t != x
+        jax.numpy.where(
+            x < t,
+            (1.0 / (2.0 * t + eps)) * x,
+            # x > t
+            (1.0 / (2.0 * (1.0 - t) + eps)) * (x + 1.0 - 2.0 * t)
+        )
+    )
+
+
+def hard_real_encoder(t: float, x: float) -> bool:
+    # t and x must be floats
+    return jax.numpy.where(soft_real_encoder(t, x) > 0.5, True, False)
+
+
+soft_real_encoder_neuron = jax.vmap(soft_real_encoder, in_axes=(0, None))
+
+hard_real_encoder_neuron = jax.vmap(hard_real_encoder, in_axes=(0, None))
+
+soft_real_encoder_layer = jax.vmap(soft_real_encoder_neuron, (0, 0), 0)
+
+hard_real_encoder_layer = jax.vmap(hard_real_encoder_neuron, (0, 0), 0)
+
+
+class SoftRealEncoderLayer(nn.Module):
+    bits_per_real: int
+    thresholds_init: Callable = nn.initializers.uniform(1.0)
+    dtype: jax.numpy.dtype = jax.numpy.float32
+
+    @nn.compact
+    def __call__(self, x):
+        thresholds_shape = (jax.numpy.shape(x)[-1], self.bits_per_real)
+        thresholds = self.param(
+            "thresholds", self.thresholds_init, thresholds_shape, self.dtype)
+        x = jax.numpy.asarray(x, self.dtype)
+        return soft_real_encoder_layer(thresholds, x)
+
+
+class HardRealEncoderLayer(nn.Module):
+    bits_per_real: int
+
+    @nn.compact
+    def __call__(self, x):
+        thresholds_shape = (jax.numpy.shape(x)[-1], self.bits_per_real)
+        thresholds = self.param(
+            "thresholds", nn.initializers.constant(0.0), thresholds_shape)
+        return hard_real_encoder_layer(thresholds, x)
+
+
+class SymbolicRealEncoderLayer:
+    def __init__(self, bits_per_real):
+        self.bits_per_real = bits_per_real
+        self.hard_real_encoder_layer = HardRealEncoderLayer(self.bits_per_real)
+
+    def __call__(self, x):
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(
+            self.hard_real_encoder_layer, x
+        )
+        return symbolic_generation.symbolic_expression(jaxpr, x)
+
+
+real_encoder_layer = neural_logic_net.select(
+    lambda bits_per_real, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: SoftRealEncoderLayer(
+        bits_per_real, weights_init, dtype
+    ),
+    lambda bits_per_real, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: HardRealEncoderLayer(bits_per_real),
+    lambda bits_per_real, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: SymbolicRealEncoderLayer(bits_per_real),
+)
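soft_real_encoder is a piecewise-linear map that re-centres a real x in [0, 1] around a learnable threshold t: inputs below t land in [0, 0.5), x == t maps to exactly 0.5, and inputs above t land in (0.5, 1], so hardening recovers the predicate x > t. A worked numeric check against the definition above (values rounded; eps shifts them negligibly):

import jax

from neurallogic import real_encoder

t = jax.numpy.float32(0.25)
real_encoder.soft_real_encoder(t, jax.numpy.float32(0.10))  # 0.10 / (2 * 0.25)            = 0.2
real_encoder.soft_real_encoder(t, jax.numpy.float32(0.25))  # at the threshold             = 0.5
real_encoder.soft_real_encoder(t, jax.numpy.float32(0.50))  # (0.5 + 1 - 0.5) / (2 * 0.75) = 0.67
real_encoder.hard_real_encoder(t, jax.numpy.float32(0.50))  # soft value > 0.5             -> True

Each SoftRealEncoderLayer learns bits_per_real such thresholds per input feature (thresholds_shape above), turning one real into several soft bits for the logic layers downstream.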
