Skip to content

Commit d15d7ff

Browse files
committed
add test
1 parent 853a291 commit d15d7ff

File tree

1 file changed

+6
-35
lines changed

1 file changed

+6
-35
lines changed

tests/test_real_encoder.py

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,6 @@ def test_net(type, x):
119119
soft, hard, symbolic = neural_logic_net.net(test_net)
120120
weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
121121
hard_weights = harden.hard_weights(weights)
122-
print(f'weights: {weights}')
123-
print(f'hard_weights: {hard_weights}')
124122

125123
test_data = [
126124
[
@@ -133,63 +131,36 @@ def test_net(type, x):
133131
[
134132
[0.6, 0.0],
135133
[
136-
[
137-
0.9469013,
138-
0.320184,
139-
0.3194083,
140-
],
141-
[
142-
0.58414006,
143-
0.7815013,
144-
0.04193211,
145-
],
134+
[0.78442293, 0.7857669, 0.3154459],
135+
[0.0, 0.0, 0.0],
146136
],
147137
],
148138
[
149139
[0.1, 0.9],
150140
[
151-
[
152-
0.05309868,
153-
0.679816,
154-
0.6805917,
155-
],
156-
[
157-
0.41585994,
158-
0.2184987,
159-
0.9580679,
160-
],
141+
[0.5149515, 0.51797545, 0.05257431],
142+
[0.69679934, 0.629154, 0.84623945],
161143
],
162144
],
163145
[
164146
[0.4, 0.6],
165147
[
166-
[
167-
0.05309868,
168-
0.320184,
169-
0.6805917,
170-
],
171-
[
172-
0.58414006,
173-
0.2184987,
174-
0.04193211,
175-
],
148+
[0.6766343, 0.67865026, 0.21029726],
149+
[0.35924158, 0.34675142, 0.4445637],
176150
],
177151
],
178152
]
179153
for input, expected in test_data:
180154
# Check that the soft function performs as expected
181155
soft_output = soft.apply(weights, jax.numpy.array(input))
182156
soft_expected = jax.numpy.array(expected)
183-
print(f'soft_output: {soft_output}\nsoft_expected: {soft_expected}')
184157
assert jax.numpy.allclose(soft_output, soft_expected)
185158

186159
# Check that the hard function performs as expected
187160
hard_expected = harden.harden(jax.numpy.array(expected))
188161
hard_output = hard.apply(hard_weights, jax.numpy.array(input))
189-
print(f'hard_output: {hard_output}\nhard_expected: {hard_expected}')
190162
assert jax.numpy.allclose(hard_output, hard_expected)
191163

192164
# Check that the symbolic function performs as expected
193165
symbolic_output = symbolic.apply(hard_weights, jax.numpy.array(input))
194-
print(f'symbolic_output: {symbolic_output}\nhard_expected: {hard_expected}')
195166
assert numpy.allclose(symbolic_output, hard_expected)

0 commit comments

Comments
 (0)