
import numpy
import jax
- from flax import errors
from flax import linen as nn

from neurallogic import neural_logic_net, sym_gen, symbolic_primitives
@@ -22,39 +21,10 @@ def soft_and_include(w: float, x: float) -> float:
    return jax.numpy.maximum(x, 1.0 - w)


- # TODO: do we need to jit here? should apply jit at the highest level of the architecture
- # TODO: may need to jit in unit tests, however
- """
- @jax.jit
- def hard_and_include(w: bool, x: bool) -> bool:
-     print(f"hard_and_include: w={w}, x={x}")
-     # TODO: this works when the function is jitted, but not when it is not jitted
-     return x | ~w
-     #return x or not w
- """
-

def hard_and_include(w, x):
    return jax.numpy.logical_or(x, jax.numpy.logical_not(w))

- # def hard_and_include(w, x):
- #     return jax.numpy.logical_or(x, jax.numpy.logical_not(w))
-
-
- """
- def symbolic_and_include(w, x):
-     expression = f"({x} or not({w}))"
-     # Check if w is of type bool
-     if isinstance(w, bool) and isinstance(x, bool):
-         # We know the value of w and x, so we can evaluate the expression
-         return eval(expression)
-     # We don't know the value of w or x, so we return the expression
-     return expression
- """
-
- # def symbolic_and_include(w, x):
- #     symbolic_f = sym_gen.make_symbolic(hard_and_include, w, x)
- #     return sym_gen.eval_symbolic(symbolic_f, w, x)


def soft_and_neuron(w, x):
@@ -67,49 +37,14 @@ def hard_and_neuron(w, x):
    return jax.lax.reduce(x, True, jax.lax.bitwise_and, [0])


- """
- def hard_and_neuron(w, x):
-     x = jax.vmap(hard_and_include, 0, 0)(w, x)
-     return jax.lax.reduce(x, True, jax.numpy.logical_and, [0])
- """
-
- """
- def symbolic_and_neuron(w, x):
-     # TODO: ensure that this implementation has the same generality over tensors as vmap
-     if not isinstance(w, list):
-         raise TypeError(f"Input {x} should be a list")
-     if not isinstance(x, list):
-         raise TypeError(f"Input {x} should be a list")
-     y = [symbolic_and_include(wi, xi) for wi, xi in zip(w, x)]
-     expression = "(" + str(reduce(lambda a, b: f"{a} and {b}", y)) + ")"
-     if all(isinstance(yi, bool) for yi in y):
-         # We know the value of all yis, so we can evaluate the expression
-         return eval(expression)
-     return expression
- """
-
soft_and_layer = jax.vmap(soft_and_neuron, (0, None), 0)

hard_and_layer = jax.vmap(hard_and_neuron, (0, None), 0)

- """
- def symbolic_and_layer(w, x):
-     # TODO: ensure that this implementation has the same generality over tensors as vmap
-     if not isinstance(w, list):
-         raise TypeError(f"Input {x} should be a list")
-     if not isinstance(x, list):
-         raise TypeError(f"Input {x} should be a list")
-     return [symbolic_and_neuron(wi, x) for wi in w]
- """
-
- # def symbolic_and_layer(w, x):
- #     symbolic_hard_and_layer = sym_gen.make_symbolic(hard_and_layer)
- #     return sym_gen.eval_symbolic(symbolic_hard_and_layer, w, x)
-
- # TODO: investigate better initialization


def initialize_near_to_zero():
+     # TODO: investigate better initialization
    def init(key, shape, dtype):
        dtype = jax.dtypes.canonicalize_dtype(dtype)
        # Sample from standard normal distribution (zero mean, unit variance)
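Note (not part of the commit): a minimal usage sketch of the soft/hard AND-include pair, copying the two one-line definitions exactly as they appear in the diff above; the example values are illustrative only and show that the hard function agrees with the soft one at the boolean corners.

import jax

def soft_and_include(w, x):
    # Soft AND-include over [0, 1]: tracks x when w is high, and saturates to 1.0
    # (the neutral element for a following AND) when w is low.
    return jax.numpy.maximum(x, 1.0 - w)

def hard_and_include(w, x):
    # Hard counterpart: x OR NOT w.
    return jax.numpy.logical_or(x, jax.numpy.logical_not(w))

print(soft_and_include(1.0, 0.3))      # 0.3 -> w high: x is passed through
print(soft_and_include(0.0, 0.3))      # 1.0 -> w low: x is ignored
print(hard_and_include(True, False))   # False, matching the soft value 0.0 at this corner
print(hard_and_include(False, False))  # True, matching the soft value 1.0 at this corner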