Commit 16b19c9

train while avoiding nan
1 parent d15d7ff commit 16b19c9

File tree

4 files changed: +74 −12 lines changed

neurallogic/real_encoder.py

Lines changed: 7 additions & 5 deletions

@@ -5,18 +5,20 @@
 
 from neurallogic import neural_logic_net, symbolic_generation
 
-
 def soft_real_encoder(t: float, x: float) -> float:
+    eps = 0.0000001
     # x should be in [0, 1]
     t = jax.numpy.clip(t, 0.0, 1.0)
     return jax.numpy.where(
-        x == t,
+        jax.numpy.isclose(t, x),
         0.5,
+        # t != x
         jax.numpy.where(
             x < t,
-            (1.0 / (2.0 * t)) * x,
-            (1.0 / (2.0 * (1.0 - t))) * (x + 1.0 - 2.0 * t),
-        ),
+            (1.0 / (2.0 * t + eps)) * x,
+            # x > t
+            (1.0 / (2.0 * (1.0 - t) + eps)) * (x + 1.0 - 2.0 * t)
+        )
     )
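The commit message explains the intent of this change: with t exactly 0 or 1, the unguarded reciprocals evaluate to inf, and because jax.numpy.where computes (and differentiates) both branches before masking, the discarded branch still contaminates the gradient with nan. A minimal sketch of the failure mode, using only jax itself (the eps guard above is the fix):

import jax
import jax.numpy as jnp

def unguarded(t, x):
    # Pre-patch encoder: at t == 0 the x < t branch divides by zero.
    return jnp.where(
        jnp.isclose(t, x),
        0.5,
        jnp.where(
            x < t,
            (1.0 / (2.0 * t)) * x,
            (1.0 / (2.0 * (1.0 - t))) * (x + 1.0 - 2.0 * t),
        ),
    )

# The forward pass is fine (the bad branch is masked out), but the backward
# pass multiplies that branch's infinite derivative by a zero cotangent, and
# 0 * inf == nan -- the well-known jnp.where gradient gotcha.
print(jax.grad(unguarded)(0.0, 0.3))  # nan

Adding eps to each denominator keeps both branches finite over the whole domain, perturbing the slope by only about one part in 10^7.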

neurallogic/symbolic_generation.py

Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@ def symbolic_bind(prim, *args, **params):
         "le": symbolic_primitives.symbolic_le,
         "lt": symbolic_primitives.symbolic_lt,
         "gt": symbolic_primitives.symbolic_gt,
+        "abs": symbolic_primitives.symbolic_abs,
         "add": symbolic_primitives.symbolic_add,
         "sub": symbolic_primitives.symbolic_sub,
         "mul": symbolic_primitives.symbolic_mul,

neurallogic/symbolic_primitives.py

Lines changed: 7 additions & 0 deletions

@@ -234,6 +234,13 @@ def symbolic_gt(*args, **kwargs):
     return binary_infix_operator(">", *args, **kwargs)
 
 
+def symbolic_abs(*args, **kwargs):
+    if all_concrete_values([*args]):
+        return lax_reference.abs(*args, **kwargs)
+    else:
+        return unary_operator("np.absolute", *args, **kwargs)
+
+
 def symbolic_add(*args, **kwargs):
     if all_concrete_values([*args]):
         return lax_reference.add(*args, **kwargs)
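symbolic_abs follows the same dispatch pattern as its neighbours: evaluate through lax_reference when every argument is a concrete value, otherwise emit a string expression for later evaluation. The unary_operator helper is not part of this diff; a hypothetical sketch, assuming it mirrors binary_infix_operator by rendering a call around a symbolic (string) operand:

# Hypothetical helper (not shown in this commit): wraps a symbolic operand
# in a function call, e.g. "x1 - x2" -> "np.absolute(x1 - x2)", ready to be
# eval()'d later with numpy bound to np.
def unary_operator(operator: str, operand) -> str:
    return f"{operator}({operand})"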

tests/test_real_encoder.py

Lines changed: 59 additions & 7 deletions

@@ -1,9 +1,16 @@
 from typing import Callable
-import numpy
+
 import jax
+from jax.config import config
+import numpy
+import optax
+from flax.training import train_state
 from jax import random
 
-from neurallogic import harden, real_encoder, symbolic_generation, neural_logic_net
+from neurallogic import harden, neural_logic_net, real_encoder, symbolic_generation
+
+
+config.update("jax_debug_nans", True)
 
 
 def check_consistency(soft: Callable, hard: Callable, expected, *args):
@@ -123,10 +130,7 @@ def test_net(type, x):
     test_data = [
         [
             [1.0, 0.8],
-            [
-                [1.0, 1.0, 1.0],
-                [0.47898874, 0.4623352, 0.6924789]
-            ],
+            [[1.0, 1.0, 1.0], [0.47898874, 0.4623352, 0.6924789]],
         ],
         [
             [0.6, 0.0],
@@ -145,7 +149,7 @@ def test_net(type, x):
         [
             [0.4, 0.6],
             [
-                [0.6766343, 0.67865026, 0.21029726],
+                [0.6766343, 0.67865026, 0.21029726],
                 [0.35924158, 0.34675142, 0.4445637],
             ],
         ],
@@ -164,3 +168,51 @@ def test_net(type, x):
     # Check that the symbolic function performs as expected
     symbolic_output = symbolic.apply(hard_weights, jax.numpy.array(input))
     assert numpy.allclose(symbolic_output, hard_expected)
+
+
+def test_train_real_encoder():
+    def test_net(type, x):
+        return real_encoder.real_encoder_layer(type)(3)(x)
+
+    soft, hard, symbolic = neural_logic_net.net(test_net)
+    weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
+
+    x = [
+        [0.8, 0.9],
+        [0.85, 0.1],
+        [0.2, 0.8],
+        [0.3, 0.7],
+    ]
+    y = [
+        [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
+        [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]],
+        [[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]],
+        [[1.0, 1.0, 0.0], [1.0, 0.0, 1.0]],
+    ]
+    input = jax.numpy.array(x)
+    output = jax.numpy.array(y)
+
+    # Train the real_encoder layer
+    tx = optax.sgd(0.1)
+    state = train_state.TrainState.create(
+        apply_fn=jax.vmap(soft.apply, in_axes=(None, 0)), params=weights, tx=tx
+    )
+    grad_fn = jax.jit(
+        jax.value_and_grad(
+            lambda params, x, y: jax.numpy.mean((state.apply_fn(params, x) - y) ** 2)
+        )
+    )
+    for epoch in range(1, 100):
+        loss, grads = grad_fn(state.params, input, output)
+        state = state.apply_gradients(grads=grads)
+
+    # Test that the real_encoder layer (both soft and hard variants) correctly predicts y
+    weights = state.params
+    hard_weights = harden.hard_weights(weights)
+
+    for input, expected in zip(x, y):
+        hard_expected = harden.harden(jax.numpy.array(expected))
+        hard_result = hard.apply(hard_weights, jax.numpy.array(input))
+        assert jax.numpy.allclose(hard_result, hard_expected)
+        symbolic_output = symbolic.apply(hard_weights, jax.numpy.array(input))
+        assert jax.numpy.array_equal(symbolic_output, hard_expected)
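The config.update("jax_debug_nans", True) added at the top of the test file makes JAX raise a FloatingPointError at the first operation that produces a nan, rather than letting it propagate silently through training, so any regression of this fix fails loudly. A standalone illustration of that switch, independent of this repo:

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

try:
    jnp.divide(0.0, 0.0)  # 0 / 0 yields nan, so JAX raises immediately
except FloatingPointError as err:
    print("caught:", err)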
