
Commit 948cca5 ("cleanup")

1 parent ec8adc4 commit 948cca5

File tree: 5 files changed, +171 -16 lines

  neurallogic/hard_not.py
  neurallogic/harden.py
  neurallogic/real_encoder.py
  tests/test_mnist.py
  tests/test_real_encoder.py

neurallogic/hard_not.py

Lines changed: 2 additions & 2 deletions

@@ -28,9 +28,9 @@ def hard_not(w: bool, x: bool) -> bool:
 hard_not_neuron = jax.vmap(hard_not, 0, 0)


-soft_not_layer = jax.vmap(soft_not_neuron, (0, 0), 0)
+soft_not_layer = jax.vmap(soft_not_neuron, (0, None), 0)

-hard_not_layer = jax.vmap(hard_not_neuron, (0, 0), 0)
+hard_not_layer = jax.vmap(hard_not_neuron, (0, None), 0)


 class SoftNotLayer(nn.Module):
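
The change to both layers is in the vmap in_axes argument. With in_axes=(0, 0), jax.vmap splits both the weight matrix and the input along the mapped axis, so the layer would need one input row per neuron; with (0, None), only the weights are mapped and the single input vector is broadcast unchanged to every neuron. A minimal runnable sketch of the new behaviour, using a stand-in soft_not definition (the repository's actual soft_not body is not shown in this diff):

import jax
import jax.numpy as jnp

# Stand-in soft NOT gate (an assumption for illustration):
# w near 1 negates x, w near 0 passes x through.
def soft_not(w, x):
    return x * (1.0 - w) + (1.0 - x) * w

# One neuron: a vector of weights applied elementwise to the input bits.
soft_not_neuron = jax.vmap(soft_not, 0, 0)

# One layer: map over neurons (rows of w) while broadcasting the same
# input vector x to each neuron, per the (0, None) in_axes above.
soft_not_layer = jax.vmap(soft_not_neuron, (0, None), 0)

w = jnp.array([[0.9, 0.1], [0.2, 0.8]])  # 2 neurons x 2 input bits
x = jnp.array([1.0, 0.0])                # one 2-bit soft input
print(soft_not_layer(w, x))              # shape (2, 2): one row per neuron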

neurallogic/harden.py

Lines changed: 21 additions & 4 deletions

@@ -11,41 +11,58 @@ def harden_float(x: float) -> bool:

 harden_array = jax.vmap(harden_float, 0, 0)

+
 @dispatch
 def harden(x: float):
     if numpy.isnan(x):
         return x
     return harden_float(x)

+
 @dispatch
 def harden(x: list):
     return symbolic_primitives.map_at_elements(x, harden_float)

+
 @dispatch
 def harden(x: numpy.ndarray):
     return harden_array(x)

+
 @dispatch
 def harden(x: jax.numpy.ndarray):
     return harden_array(x)

+
 @dispatch
 def harden(x: dict):
-    return symbolic_primitives.map_at_elements(x, harden_float)
+    return symbolic_primitives.map_at_elements(x, harden)
+

 @dispatch
 def harden(x: flax.core.FrozenDict):
-    return flax.core.FrozenDict(symbolic_primitives.map_at_elements(x.unfreeze(), harden_float))
+    return flax.core.FrozenDict(
+        symbolic_primitives.map_at_elements(x.unfreeze(), harden)
+    )
+

 @dispatch
 def harden(*args):
     if len(args) == 1:
         return harden(args[0])
     return tuple([harden(arg) for arg in args])

+
 @dispatch
 def map_keys_nested(f, d: dict) -> dict:
-    return {f(k): map_keys_nested(f, v) if isinstance(v, dict) else v for k, v in d.items()}
+    return {
+        f(k): map_keys_nested(f, v) if isinstance(v, dict) else v for k, v in d.items()
+    }
+

 def hard_weights(weights):
-    return flax.core.FrozenDict(map_keys_nested(lambda str: str.replace("Soft", "Hard"), harden(weights.unfreeze())))
+    return flax.core.FrozenDict(
+        map_keys_nested(
+            lambda str: str.replace("Soft", "Hard"), harden(weights.unfreeze())
+        )
+    )
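
The substantive fix is in the dict and flax.core.FrozenDict overloads: mapping harden_float over the elements assumed every leaf was a bare float, bypassing the NaN guard in the float overload and the vectorised array overloads; recursing through harden itself re-dispatches each nested value to the right overload. hard_weights then renames "Soft" to "Hard" in the parameter-tree keys via map_keys_nested so the hardened weights line up with the hard network's module names. A self-contained sketch of the recursion, emulating the @dispatch overloads and symbolic_primitives.map_at_elements with plain isinstance checks so it runs on its own:

import numpy

def harden_float(x: float) -> bool:
    # Threshold a soft weight in [0, 1] to a hard boolean.
    return x > 0.5

def harden(x):
    if isinstance(x, dict):
        # Recursing through harden (not harden_float) lets nested dicts,
        # lists, and arrays inside the parameter tree all dispatch correctly.
        return {k: harden(v) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        return type(x)(harden(v) for v in x)
    if isinstance(x, numpy.ndarray):
        return x > 0.5  # vectorised counterpart of harden_float
    if numpy.isnan(x):
        return x  # NaNs pass through unchanged
    return harden_float(x)

params = {"SoftNotLayer_0": {"weights": numpy.array([0.9, 0.1])}}
print(harden(params))  # {'SoftNotLayer_0': {'weights': array([ True, False])}}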

neurallogic/real_encoder.py

Lines changed: 53 additions & 0 deletions

@@ -1,4 +1,9 @@
+from typing import Callable
+
 import jax
+from flax import linen as nn
+
+from neurallogic import neural_logic_net, symbolic_generation


 def soft_real_encoder(t: float, x: float) -> float:

@@ -29,3 +34,51 @@ def hard_real_encoder(t: float, x: float) -> bool:
 hard_real_encoder_layer = jax.vmap(hard_real_encoder_neuron, (0, 0), 0)


+class SoftRealEncoderLayer(nn.Module):
+    bits_per_real: int
+    thresholds_init: Callable = nn.initializers.uniform(1.0)
+    dtype: jax.numpy.dtype = jax.numpy.float32
+
+    @nn.compact
+    def __call__(self, x):
+        thresholds_shape = (jax.numpy.shape(x)[-1], self.bits_per_real)
+        thresholds = self.param("thresholds", self.thresholds_init, thresholds_shape, self.dtype)
+        x = jax.numpy.asarray(x, self.dtype)
+        return soft_real_encoder_layer(thresholds, x)
+
+
+class HardRealEncoderLayer(nn.Module):
+    bits_per_real: int
+
+    @nn.compact
+    def __call__(self, x):
+        thresholds_shape = (jax.numpy.shape(x)[-1], self.bits_per_real)
+        thresholds = self.param("thresholds", nn.initializers.constant(0.0), thresholds_shape)
+        return hard_real_encoder_layer(thresholds, x)
+
+
+class SymbolicRealEncoderLayer:
+    def __init__(self, bits_per_real):
+        self.bits_per_real = bits_per_real
+        self.hard_real_encoder_layer = HardRealEncoderLayer(self.bits_per_real)
+
+    def __call__(self, x):
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(
+            self.hard_real_encoder_layer, x
+        )
+        return symbolic_generation.symbolic_expression(jaxpr, x)
+
+
+real_encoder_layer = neural_logic_net.select(
+    lambda bits_per_real, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: SoftRealEncoderLayer(
+        bits_per_real, weights_init, dtype
+    ),
+    lambda bits_per_real, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: HardRealEncoderLayer(bits_per_real),
+    lambda bits_per_real, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: SymbolicRealEncoderLayer(bits_per_real),
+)
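
SoftRealEncoderLayer learns a thresholds matrix of shape (num_inputs, bits_per_real) and encodes each real input as bits_per_real soft bits, one per threshold; HardRealEncoderLayer reuses the same shapes with boolean outputs, and SymbolicRealEncoderLayer traces the hard layer to a jaxpr and evaluates it symbolically. The body of soft_real_encoder is not part of this diff, so the sketch below substitutes a clipped linear ramp (0 well below the threshold, 1 well above, 0.5 at equality, which matches the test_activation expectation that soft_real_encoder(1.0, 1.0) == 0.5); the neuron-level vmap is likewise an assumption, since only the layer-level vmap appears above:

import jax
import jax.numpy as jnp

def soft_real_encoder(t, x):
    # Assumed soft comparison "x >= t": linear ramp of width 1 centred on t.
    return jnp.clip(x - t + 0.5, 0.0, 1.0)

# One neuron: one real input against its row of thresholds (assumed in_axes).
soft_real_encoder_neuron = jax.vmap(soft_real_encoder, (0, None), 0)
# One layer: pair input i with threshold row i, per the (0, 0) in_axes above.
soft_real_encoder_layer = jax.vmap(soft_real_encoder_neuron, (0, 0), 0)

thresholds = jnp.array([[0.25, 0.5, 0.75],   # thresholds for input 0
                        [0.25, 0.5, 0.75]])  # thresholds for input 1
x = jnp.array([0.1, 0.9])
print(soft_real_encoder_layer(thresholds, x))  # shape (2, 3): 3 soft bits per input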

tests/test_mnist.py

Lines changed: 1 addition & 1 deletion

@@ -195,7 +195,7 @@ def get_config():
     # Always commit with num_epochs = 1 for short test time
     config.momentum = 0.9
     config.batch_size = 128
-    config.num_epochs = 1000
+    config.num_epochs = 2
     return config

tests/test_real_encoder.py

Lines changed: 94 additions & 9 deletions

@@ -1,32 +1,32 @@
 from typing import Callable
 import numpy
 import jax
+from jax import random

-from neurallogic import harden, real_encoder, symbolic_generation
+from neurallogic import harden, real_encoder, symbolic_generation, neural_logic_net


 def check_consistency(soft: Callable, hard: Callable, expected, *args):
-    print(f'\nchecking consistency for {soft.__name__}')
+    # print(f'\nchecking consistency for {soft.__name__}')
     # Check that the soft function performs as expected
     soft_output = soft(*args)
-    print(f'Expected: {expected}, Actual soft_output: {soft_output}')
+    # print(f'Expected: {expected}, Actual soft_output: {soft_output}')
     assert numpy.allclose(soft_output, expected, equal_nan=True)

     # Check that the hard function performs as expected
     # N.B. We don't harden the inputs because the hard_bit expects real-valued inputs
     hard_expected = harden.harden(expected)
-    hard_output = hard(*args)
-    print(f'Expected: {hard_expected}, Actual hard_output: {hard_output}')
+    hard_output = hard(*args)
+    # print(f'Expected: {hard_expected}, Actual hard_output: {hard_output}')
     assert numpy.allclose(hard_output, hard_expected, equal_nan=True)

     # Check that the jaxpr performs as expected
     symbolic_f = symbolic_generation.make_symbolic_jaxpr(hard, *args)
     symbolic_output = symbolic_generation.eval_symbolic(symbolic_f, *args)
-    print(f'Expected: {hard_expected}, Actual symbolic_output: {symbolic_output}')
+    # print(f'Expected: {hard_expected}, Actual symbolic_output: {symbolic_output}')
     assert numpy.allclose(symbolic_output, hard_expected, equal_nan=True)


-
 def test_activation():
     test_data = [
         [[1.0, 1.0], 0.5],

@@ -40,7 +40,11 @@ def test_activation():
     ]
     for input, expected in test_data:
         check_consistency(
-            real_encoder.soft_real_encoder, real_encoder.hard_real_encoder, expected, input[0], input[1]
+            real_encoder.soft_real_encoder,
+            real_encoder.hard_real_encoder,
+            expected,
+            input[0],
+            input[1],
         )

@@ -50,7 +54,6 @@ def test_neuron():
         [0.0, [0.0, 0.0, 0.9], [0.5, 0.5, 0.0]],
         [1.0, [0.0, 1.0, 0.1], [1.0, 0.5, 1.0]],
         [0.0, [1.0, 0.0, 0.3], [0.0, 0.5, 0.0]],
-
         [0.3, [0.2, 0.8, 0.3], [0.5625, 0.1875, 0.5]],
         [0.1, [0.9, 0.42, 0.5], [0.05555556, 0.11904762, 0.1]],
         [0.4, [0.2, 0.8, 0.7], [0.625, 0.25, 0.2857143]],

@@ -108,3 +111,85 @@ def hard(thresholds, input):
         jax.numpy.array(input),
     )

+
+def test_real_encoder():
+    def test_net(type, x):
+        return real_encoder.real_encoder_layer(type)(3)(x)
+
+    soft, hard, symbolic = neural_logic_net.net(test_net)
+    weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
+    hard_weights = harden.hard_weights(weights)
+    print(f'weights: {weights}')
+    print(f'hard_weights: {hard_weights}')
+
+    test_data = [
+        [
+            [1.0, 0.8],
+            [
+                [1.0, 1.0, 1.0],
+                [0.47898874, 0.4623352, 0.6924789]
+            ],
+        ],
+        [
+            [0.6, 0.0],
+            [
+                [
+                    0.9469013,
+                    0.320184,
+                    0.3194083,
+                ],
+                [
+                    0.58414006,
+                    0.7815013,
+                    0.04193211,
+                ],
+            ],
+        ],
+        [
+            [0.1, 0.9],
+            [
+                [
+                    0.05309868,
+                    0.679816,
+                    0.6805917,
+                ],
+                [
+                    0.41585994,
+                    0.2184987,
+                    0.9580679,
+                ],
+            ],
+        ],
+        [
+            [0.4, 0.6],
+            [
+                [
+                    0.05309868,
+                    0.320184,
+                    0.6805917,
+                ],
+                [
+                    0.58414006,
+                    0.2184987,
+                    0.04193211,
+                ],
+            ],
+        ],
+    ]
+    for input, expected in test_data:
+        # Check that the soft function performs as expected
+        soft_output = soft.apply(weights, jax.numpy.array(input))
+        soft_expected = jax.numpy.array(expected)
+        print(f'soft_output: {soft_output}\nsoft_expected: {soft_expected}')
+        assert jax.numpy.allclose(soft_output, soft_expected)
+
+        # Check that the hard function performs as expected
+        hard_expected = harden.harden(jax.numpy.array(expected))
+        hard_output = hard.apply(hard_weights, jax.numpy.array(input))
+        print(f'hard_output: {hard_output}\nhard_expected: {hard_expected}')
+        assert jax.numpy.allclose(hard_output, hard_expected)
+
+        # Check that the symbolic function performs as expected
+        symbolic_output = symbolic.apply(hard_weights, jax.numpy.array(input))
+        print(f'symbolic_output: {symbolic_output}\nhard_expected: {hard_expected}')
+        assert numpy.allclose(symbolic_output, hard_expected)