Commit a786e05: cleanup
1 parent b382177

File tree: 3 files changed (+8, -85 lines)


neurallogic/hard_and.py

Lines changed: 1 addition & 66 deletions

@@ -4,7 +4,6 @@
 
 import numpy
 import jax
-from flax import errors
 from flax import linen as nn
 
 from neurallogic import neural_logic_net, sym_gen, symbolic_primitives
@@ -22,39 +21,10 @@ def soft_and_include(w: float, x: float) -> float:
     return jax.numpy.maximum(x, 1.0 - w)
 
 
-# TODO: do we need to jit here? should apply jit at the highest level of the architecture
-# TODO: may need to jit in unit tests, however
-"""
-@jax.jit
-def hard_and_include(w: bool, x: bool) -> bool:
-    print(f"hard_and_include: w={w}, x={x}")
-    # TODO: this works when the function is jitted, but not when it is not jitted
-    return x | ~w
-    #return x or not w
-"""
-
 
 def hard_and_include(w, x):
     return jax.numpy.logical_or(x, jax.numpy.logical_not(w))
 
-# def hard_and_include(w, x):
-#     return jax.numpy.logical_or(x, jax.numpy.logical_not(w))
-
-
-"""
-def symbolic_and_include(w, x):
-    expression = f"({x} or not({w}))"
-    # Check if w is of type bool
-    if isinstance(w, bool) and isinstance(x, bool):
-        # We know the value of w and x, so we can evaluate the expression
-        return eval(expression)
-    # We don't know the value of w or x, so we return the expression
-    return expression
-"""
-
-# def symbolic_and_include(w, x):
-#     symbolic_f = sym_gen.make_symbolic(hard_and_include, w, x)
-#     return sym_gen.eval_symbolic(symbolic_f, w, x)
 
 
 def soft_and_neuron(w, x):
@@ -67,49 +37,14 @@ def hard_and_neuron(w, x):
     return jax.lax.reduce(x, True, jax.lax.bitwise_and, [0])
 
 
-"""
-def hard_and_neuron(w, x):
-    x = jax.vmap(hard_and_include, 0, 0)(w, x)
-    return jax.lax.reduce(x, True, jax.numpy.logical_and, [0])
-"""
-
-"""
-def symbolic_and_neuron(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    y = [symbolic_and_include(wi, xi) for wi, xi in zip(w, x)]
-    expression = "(" + str(reduce(lambda a, b: f"{a} and {b}", y)) + ")"
-    if all(isinstance(yi, bool) for yi in y):
-        # We know the value of all yis, so we can evaluate the expression
-        return eval(expression)
-    return expression
-"""
-
 soft_and_layer = jax.vmap(soft_and_neuron, (0, None), 0)
 
 hard_and_layer = jax.vmap(hard_and_neuron, (0, None), 0)
 
-"""
-def symbolic_and_layer(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    return [symbolic_and_neuron(wi, x) for wi in w]
-"""
-
-# def symbolic_and_layer(w, x):
-#     symbolic_hard_and_layer = sym_gen.make_symbolic(hard_and_layer)
-#     return sym_gen.eval_symbolic(symbolic_hard_and_layer, w, x)
-
-# TODO: investigate better initialization
 
 
 def initialize_near_to_zero():
+    # TODO: investigate better initialization
     def init(key, shape, dtype):
         dtype = jax.dtypes.canonicalize_dtype(dtype)
         # Sample from standard normal distribution (zero mean, unit variance)
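For context, a minimal standalone sketch of the soft/hard pair that survives this cleanup. The bodies of soft_and_include and hard_and_include are taken verbatim from the diff above; the sample inputs and the thresholding helper (a stand-in for the repo's harden.harden) are illustrative assumptions.

import jax.numpy as jnp

# Soft AND-include: differentiable relaxation (body from the diff above).
def soft_and_include(w, x):
    return jnp.maximum(x, 1.0 - w)

# Hard AND-include: boolean counterpart (body from the diff above).
def hard_and_include(w, x):
    return jnp.logical_or(x, jnp.logical_not(w))

# Illustrative hardening: threshold at 0.5 (stand-in for harden.harden).
def harden(a):
    return a > 0.5

w = jnp.array([0.1, 0.9, 0.9, 0.1])  # weight near 0 => input is excluded
x = jnp.array([0.2, 0.2, 0.8, 0.8])

# Hardening the soft output agrees with running the hard function on
# hardened inputs, which is the property tests/test_hard_and.py checks.
assert jnp.array_equal(harden(soft_and_include(w, x)),
                       hard_and_include(harden(w), harden(x)))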

neurallogic/sym_gen.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 from jax import core
 from jax._src.util import safe_map
 import flax
-from neurallogic import symbolic_primitives, harden
+from neurallogic import symbolic_primitives
 from plum import dispatch
 import typing
 
tests/test_hard_and.py

Lines changed: 6 additions & 18 deletions

@@ -9,7 +9,7 @@
 from neurallogic import hard_and, harden, neural_logic_net, primitives, sym_gen, symbolic_primitives
 
 
-def check_consistency(soft: typing.Callable, hard: typing.Callable, symbolic: typing.Callable, expected, *args):
+def check_consistency(soft: typing.Callable, hard: typing.Callable, expected, *args):
     # Check that the soft function performs as expected
     assert numpy.allclose(soft(*args), expected)
 
@@ -18,18 +18,11 @@ def check_consistency(soft: typing.Callable, hard: typing.Callable, symbolic: typing.Callable, expected, *args):
     hard_expected = harden.harden(expected)
     assert numpy.allclose(hard(*hard_args), hard_expected)
 
-    # Check that the symbolic function performs as expected
-    symbolic_f = sym_gen.make_symbolic_jaxpr(symbolic, *hard_args)
+    # Check that the jaxpr expression performs as expected
+    symbolic_f = sym_gen.make_symbolic_jaxpr(hard, *hard_args)
     assert numpy.allclose(sym_gen.eval_symbolic(
         symbolic_f, *hard_args), hard_expected)
 
-    # Check that the symbolic function, when evaluated with symbolic inputs, performs as expected
-    symbolic_input = sym_gen.make_symbolic(*hard_args)
-    symbolic_expression = sym_gen.symbolic_expression(
-        symbolic_f, *symbolic_input)
-    symbolic_output = sym_gen.eval_symbolic_expression(symbolic_expression)
-    assert numpy.allclose(symbolic_output, hard_expected)
-
 
 def test_include():
     test_data = [
@@ -44,7 +37,7 @@ def test_include():
     ]
     for input, expected in test_data:
         check_consistency(hard_and.soft_and_include, hard_and.hard_and_include,
-                          hard_and.hard_and_include, expected, input[0], input[1])
+                          expected, input[0], input[1])
 
 
 def test_neuron():
@@ -63,10 +56,7 @@ def soft(weights, input):
     def hard(weights, input):
         return hard_and.hard_and_neuron(weights, input)
 
-    def symbolic(weights, input):
-        return hard(weights, input)
-
-    check_consistency(soft, hard, symbolic, expected,
+    check_consistency(soft, hard, expected,
                       jax.numpy.array(weights), jax.numpy.array(input))
 
 
@@ -90,10 +80,8 @@ def soft(weights, input):
     def hard(weights, input):
         return hard_and.hard_and_layer(weights, input)
 
-    def symbolic(weights, input):
-        return hard(weights, input)
 
-    check_consistency(soft, hard, symbolic, jax.numpy.array(expected),
+    check_consistency(soft, hard, jax.numpy.array(expected),
                       jax.numpy.array(weights), jax.numpy.array(input))
 
 
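Pieced together from the hunks above, the slimmed-down harness now reads roughly as follows. Everything is taken from the visible diff except the derivation of hard_args, which falls outside the hunk context, so its exact form here is an assumption.

import typing

import numpy

from neurallogic import harden, sym_gen


def check_consistency(soft: typing.Callable, hard: typing.Callable, expected, *args):
    # Check that the soft function performs as expected
    assert numpy.allclose(soft(*args), expected)

    # Harden the inputs and the expected output. (How hard_args is built is
    # not shown in the hunks; harden.harden over the args is an assumption.)
    hard_args = harden.harden(*args)
    hard_expected = harden.harden(expected)
    assert numpy.allclose(hard(*hard_args), hard_expected)

    # Check that the jaxpr expression performs as expected: trace the hard
    # function to a jaxpr and evaluate it on the same hardened inputs.
    symbolic_f = sym_gen.make_symbolic_jaxpr(hard, *hard_args)
    assert numpy.allclose(sym_gen.eval_symbolic(symbolic_f, *hard_args),
                          hard_expected)

With the symbolic parameter gone, each test now passes only the soft and hard variants, as test_include does above; the jaxpr check is derived directly from the hard function.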
