 from typing import Callable
-import numpy
+
 import jax
+from jax.config import config
+import numpy
+import optax
+from flax.training import train_state
 from jax import random
 
-from neurallogic import harden, real_encoder, symbolic_generation, neural_logic_net
+from neurallogic import harden, neural_logic_net, real_encoder, symbolic_generation
+
+
+config.update("jax_debug_nans", True)
 
 
 def check_consistency(soft: Callable, hard: Callable, expected, *args):
@@ -123,10 +130,7 @@ def test_net(type, x):
     test_data = [
         [
             [1.0, 0.8],
-            [
-                [1.0, 1.0, 1.0],
-                [0.47898874, 0.4623352, 0.6924789]
-            ],
+            [[1.0, 1.0, 1.0], [0.47898874, 0.4623352, 0.6924789]],
         ],
         [
             [0.6, 0.0],
@@ -145,7 +149,7 @@ def test_net(type, x):
         [
             [0.4, 0.6],
             [
-                [0.6766343, 0.67865026, 0.21029726],
+                [0.6766343, 0.67865026, 0.21029726],
                 [0.35924158, 0.34675142, 0.4445637],
             ],
         ],
@@ -164,3 +168,51 @@ def test_net(type, x):
         # Check that the symbolic function performs as expected
         symbolic_output = symbolic.apply(hard_weights, jax.numpy.array(input))
         assert numpy.allclose(symbolic_output, hard_expected)
+
+
+def test_train_real_encoder():
+    def test_net(type, x):
+        return real_encoder.real_encoder_layer(type)(3)(x)
+
+    soft, hard, symbolic = neural_logic_net.net(test_net)
+    weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
+
+    x = [
+        [0.8, 0.9],
+        [0.85, 0.1],
+        [0.2, 0.8],
+        [0.3, 0.7],
+    ]
+    y = [
+        [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
+        [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]],
+        [[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]],
+        [[1.0, 1.0, 0.0], [1.0, 0.0, 1.0]],
+    ]
+    input = jax.numpy.array(x)
+    output = jax.numpy.array(y)
+
+    # Train the real_encoder layer
+    tx = optax.sgd(0.1)
+    state = train_state.TrainState.create(
+        apply_fn=jax.vmap(soft.apply, in_axes=(None, 0)), params=weights, tx=tx
+    )
+    grad_fn = jax.jit(
+        jax.value_and_grad(
+            lambda params, x, y: jax.numpy.mean((state.apply_fn(params, x) - y) ** 2)
+        )
+    )
+    for epoch in range(1, 100):
+        loss, grads = grad_fn(state.params, input, output)
+        state = state.apply_gradients(grads=grads)
+
+    # Test that the real_encoder layer (both soft and hard variants) correctly predicts y
+    weights = state.params
+    hard_weights = harden.hard_weights(weights)
+
+    for input, expected in zip(x, y):
+        hard_expected = harden.harden(jax.numpy.array(expected))
+        hard_result = hard.apply(hard_weights, jax.numpy.array(input))
+        assert jax.numpy.allclose(hard_result, hard_expected)
+        symbolic_output = symbolic.apply(hard_weights, jax.numpy.array(input))
+        assert jax.numpy.array_equal(symbolic_output, hard_expected)
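The new test_train_real_encoder test uses the standard Flax TrainState / Optax pattern: build a state from the vmapped soft apply function and an SGD optimiser, differentiate a mean-squared-error loss with jax.value_and_grad, and apply the gradients each epoch. Below is a minimal, self-contained sketch of that pattern only; it uses a plain flax.linen.Dense layer as a stand-in for the real_encoder layer, and the model, toy data, and learning rate are illustrative assumptions, not part of this commit.

# Sketch: the TrainState/Optax loop above, on a stand-in flax.linen.Dense model.
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state

model = nn.Dense(features=3)  # stand-in for the 3-bit real_encoder layer
params = model.init(jax.random.PRNGKey(0), jnp.zeros((2,)))
state = train_state.TrainState.create(
    apply_fn=jax.vmap(model.apply, in_axes=(None, 0)),  # map over the batch, share params
    params=params,
    tx=optax.sgd(0.1),
)

x = jnp.array([[0.8, 0.9], [0.2, 0.8]])            # toy batch of soft-bit inputs
y = jnp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])  # toy targets

# Mean-squared-error loss, differentiated with respect to the parameters only.
grad_fn = jax.jit(
    jax.value_and_grad(lambda p, x, y: jnp.mean((state.apply_fn(p, x) - y) ** 2))
)
for epoch in range(100):
    loss, grads = grad_fn(state.params, x, y)
    state = state.apply_gradients(grads=grads)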