Skip to content

Commit 4c4fbf2

Browse files
committed
pull out a util
1 parent 8ae5251 commit 4c4fbf2

File tree

3 files changed

+25
-17
lines changed

3 files changed

+25
-17
lines changed

tests/test_hard_and.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,9 @@
77
import typing
88

99
from neurallogic import hard_and, harden, neural_logic_net, symbolic_generation
10+
from tests import utils
1011

1112

12-
def check_consistency(soft: typing.Callable, hard: typing.Callable, expected, *args):
13-
# Check that the soft function performs as expected
14-
assert numpy.allclose(soft(*args), expected)
15-
16-
# Check that the hard function performs as expected
17-
hard_args = harden.harden(*args)
18-
hard_expected = harden.harden(expected)
19-
assert numpy.allclose(hard(*hard_args), hard_expected)
20-
21-
# Check that the jaxpr performs as expected
22-
symbolic_f = symbolic_generation.make_symbolic_jaxpr(hard, *hard_args)
23-
assert numpy.allclose(symbolic_generation.eval_symbolic(
24-
symbolic_f, *hard_args), hard_expected)
25-
2613

2714
def test_include():
2815
test_data = [
@@ -36,7 +23,7 @@ def test_include():
3623
[[-0.1, 1.0], 1.0]
3724
]
3825
for input, expected in test_data:
39-
check_consistency(hard_and.soft_and_include, hard_and.hard_and_include,
26+
utils.check_consistency(hard_and.soft_and_include, hard_and.hard_and_include,
4027
expected, input[0], input[1])
4128

4229

@@ -56,7 +43,7 @@ def soft(weights, input):
5643
def hard(weights, input):
5744
return hard_and.hard_and_neuron(weights, input)
5845

59-
check_consistency(soft, hard, expected,
46+
utils.check_consistency(soft, hard, expected,
6047
jax.numpy.array(weights), jax.numpy.array(input))
6148

6249

@@ -81,7 +68,7 @@ def hard(weights, input):
8168
return hard_and.hard_and_layer(weights, input)
8269

8370

84-
check_consistency(soft, hard, jax.numpy.array(expected),
71+
utils.check_consistency(soft, hard, jax.numpy.array(expected),
8572
jax.numpy.array(weights), jax.numpy.array(input))
8673

8774

tests/test_hard_not.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ def test_activation():
1818
[[-0.1, 0.0], 1.0],
1919
[[-0.1, 1.0], 0.0]
2020
]
21+
for input, expected in test_data:
22+
check_consistency(hard_and.soft_and_include, hard_and.hard_and_include,
23+
expected, input[0], input[1])
24+
2125
for input, expected in test_data:
2226
assert hard_not.soft_not(*input) == expected
2327
assert hard_not.hard_not(*harden.harden(input)

tests/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Callable
2+
import numpy
3+
from neurallogic import harden, symbolic_generation
4+
5+
def check_consistency(soft: Callable, hard: Callable, expected, *args):
6+
# Check that the soft function performs as expected
7+
assert numpy.allclose(soft(*args), expected)
8+
9+
# Check that the hard function performs as expected
10+
hard_args = harden.harden(*args)
11+
hard_expected = harden.harden(expected)
12+
assert numpy.allclose(hard(*hard_args), hard_expected)
13+
14+
# Check that the jaxpr performs as expected
15+
symbolic_f = symbolic_generation.make_symbolic_jaxpr(hard, *hard_args)
16+
assert numpy.allclose(symbolic_generation.eval_symbolic(
17+
symbolic_f, *hard_args), hard_expected)

0 commit comments

Comments
 (0)