
Commit 90dab25

remove jaxpr layer type
1 parent 167db49 commit 90dab25

File tree

4 files changed: +30 -50 lines changed

neurallogic/hard_and.py
neurallogic/neural_logic_net.py
neurallogic/sym_gen.py
tests/test_hard_and.py


neurallogic/hard_and.py

Lines changed: 5 additions & 33 deletions
@@ -6,7 +6,7 @@
 from typing import Callable
 
 
-from neurallogic import neural_logic_net, sym_gen, symbolic_primitives
+from neurallogic import neural_logic_net, sym_gen
 
 
 def soft_and_include(w: float, x: float) -> float:
@@ -42,7 +42,6 @@ def hard_and_neuron(w, x):
 hard_and_layer = jax.vmap(hard_and_neuron, (0, None), 0)
 
 
-
 def initialize_near_to_zero():
     # TODO: investigate better initialization
     def init(key, shape, dtype):
@@ -95,44 +94,17 @@ def __call__(self, x):
         return hard_and_layer(weights, x)
 
 
-class JaxprAndLayer:
-    def __init__(self, layer_size):
-        self.layer_size = layer_size
-        self.hard_and_layer = HardAndLayer(self.layer_size)
-
-    def __call__(self, x):
-        jaxpr = sym_gen.make_symbolic_jaxpr(self.hard_and_layer, x)
-        return sym_gen.eval_symbolic(jaxpr, x)
-
-
 class SymbolicAndLayer:
     def __init__(self, layer_size):
        self.layer_size = layer_size
        self.hard_and_layer = HardAndLayer(self.layer_size)
 
     def __call__(self, x):
-        actual_weights = self.hard_and_layer.get_variable("params", "weights")
-        if isinstance(actual_weights, list) or (isinstance(actual_weights, numpy.ndarray) and actual_weights.dtype == object):
-            numeric_weights = symbolic_primitives.map_at_elements(actual_weights, lambda x: 0)
-            numeric_weights = numpy.asarray(numeric_weights, dtype=numpy.float32)
-            sym_gen.put_variable(self.hard_and_layer, "params", "weights", numeric_weights)
-        if isinstance(x, list) or (isinstance(x, numpy.ndarray) and x.dtype == object):
-            xn = symbolic_primitives.map_at_elements(x, lambda x: 0)
-            xn = numpy.asarray(xn, dtype=numpy.float32)
-        else:
-            xn = x
-        jaxpr = sym_gen.make_symbolic_jaxpr(self.hard_and_layer, xn)
-        # Swap out the numeric consts (that represent the weights) for the symbolic weights
-        jaxpr.consts = [actual_weights]
+        jaxpr = sym_gen.make_symbolic_flax_jaxpr(self.hard_and_layer, x)
         return sym_gen.symbolic_expression(jaxpr, x)
 
 
 and_layer = neural_logic_net.select(
-    lambda layer_size, weights_init=initialize_near_to_zero(),
-    dtype=jax.numpy.float32: SoftAndLayer(layer_size, weights_init, dtype),
-    lambda layer_size, weights_init=initialize_near_to_zero(
-    ), dtype=jax.numpy.float32: HardAndLayer(layer_size),
-    lambda layer_size, weights_init=initialize_near_to_zero(
-    ), dtype=jax.numpy.float32: JaxprAndLayer(layer_size),
-    lambda layer_size, weights_init=initialize_near_to_zero(),
-    dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))
+    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer(layer_size, weights_init, dtype),
+    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: HardAndLayer(layer_size),
+    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))

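With JaxprAndLayer gone, SymbolicAndLayer is the only wrapper that builds a jaxpr, and callers select among the three remaining variants. A minimal usage sketch, assuming the call pattern from tests/test_hard_and.py (the layer width and initializer below are illustrative, not fixed by this commit):

    from flax import linen as nn
    from jax import random
    from neurallogic import hard_and, neural_logic_net

    def test_net(type, x):
        # Build a 4-wide AND layer of whichever variant `type` selects
        return hard_and.and_layer(type)(4, nn.initializers.uniform(1.0))(x)

    # net(...) now yields three Flax modules instead of four
    soft, hard, symbolic = neural_logic_net.net(test_net)
    weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
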
neurallogic/neural_logic_net.py

Lines changed: 3 additions & 9 deletions
@@ -1,14 +1,13 @@
 from enum import Enum
 from flax import linen as nn
 
-NetType = Enum('NetType', ['Soft', 'Hard', 'Jaxpr', 'Symbolic'])
+NetType = Enum('NetType', ['Soft', 'Hard', 'Symbolic'])
 
-def select(soft, hard, jaxpr, symbolic):
+def select(soft, hard, symbolic):
     def selector(type: NetType):
         return {
             NetType.Soft: soft,
             NetType.Hard: hard,
-            NetType.Jaxpr: jaxpr,
             NetType.Symbolic: symbolic
         }[type]
     return selector
@@ -22,14 +21,9 @@ class HardNet(nn.Module):
         @nn.compact
         def __call__(self, x):
             return f(NetType.Hard, x)
-    class JaxprNet(nn.Module):
-        @nn.compact
-        def __call__(self, x):
-            return f(NetType.Jaxpr, x)
     class SymbolicNet(nn.Module):
         @nn.compact
         def __call__(self, x):
             return f(NetType.Symbolic, x)
-    return SoftNet(), HardNet(), JaxprNet(), SymbolicNet()
+    return SoftNet(), HardNet(), SymbolicNet()
 
-# TODO: support init of all three net types at once

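For reference, a small sketch of how the slimmed-down select behaves, assuming only what the hunks above show (the three lambdas are placeholders, not the real layer constructors):

    from neurallogic import neural_logic_net

    selector = neural_logic_net.select(
        lambda n: ("soft", n),
        lambda n: ("hard", n),
        lambda n: ("symbolic", n))

    # The selector maps a NetType to the matching constructor; NetType.Jaxpr no longer exists.
    assert selector(neural_logic_net.NetType.Symbolic)(4) == ("symbolic", 4)
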
neurallogic/sym_gen.py

Lines changed: 18 additions & 0 deletions
@@ -50,6 +50,24 @@ def put_variable(self, col: str, name: str, value: Any):
     self.scope._variables = self.scope.variables().unfreeze()
     scope_put_variable(self.scope, col, name, value)
 
+def make_symbolic_flax_jaxpr(flax_layer, x):
+    actual_weights = flax_layer.get_variable("params", "weights")
+    # Convert actual weights to dummy numeric weights (if needed)
+    if isinstance(actual_weights, list) or (isinstance(actual_weights, numpy.ndarray) and actual_weights.dtype == object):
+        numeric_weights = symbolic_primitives.map_at_elements(actual_weights, lambda x: 0)
+        numeric_weights = numpy.asarray(numeric_weights, dtype=numpy.float32)
+        put_variable(flax_layer, "params", "weights", numeric_weights)
+    # Convert input to dummy numeric input (if needed)
+    if isinstance(x, list) or (isinstance(x, numpy.ndarray) and x.dtype == object):
+        x = symbolic_primitives.map_at_elements(x, lambda x: 0)
+        x = numpy.asarray(x, dtype=numpy.float32)
+    # Make the jaxpr that corresponds to the flax layer
+    jaxpr = make_symbolic_jaxpr(flax_layer, x)
+    # Replace the dummy numeric weights with the actual weights in the jaxpr
+    jaxpr.consts = [actual_weights]
+    return jaxpr
+
+
 def eval_jaxpr(symbolic, jaxpr, consts, *args):
     """Evaluates a jaxpr by interpreting it as Python code.

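The new helper centralises the dummy-weight bookkeeping that SymbolicAndLayer used to do inline. A condensed sketch of the consuming side, lifted from SymbolicAndLayer.__call__ in hard_and.py above (`layer` stands in for any Flax layer with a "weights" param that may hold symbolic entries):

    from neurallogic import sym_gen

    def symbolic_call(layer, x):
        # Trace the layer into a jaxpr using dummy numeric weights/inputs,
        # then re-attach the (possibly symbolic) weights as consts.
        jaxpr = sym_gen.make_symbolic_flax_jaxpr(layer, x)
        # Walk the jaxpr to produce a symbolic expression over x.
        return sym_gen.symbolic_expression(jaxpr, x)
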
tests/test_hard_and.py

Lines changed: 4 additions & 8 deletions
@@ -18,7 +18,7 @@ def check_consistency(soft: typing.Callable, hard: typing.Callable, expected, *a
     hard_expected = harden.harden(expected)
     assert numpy.allclose(hard(*hard_args), hard_expected)
 
-    # Check that the jaxpr expression performs as expected
+    # Check that the jaxpr performs as expected
     symbolic_f = sym_gen.make_symbolic_jaxpr(hard, *hard_args)
     assert numpy.allclose(sym_gen.eval_symbolic(
         symbolic_f, *hard_args), hard_expected)
@@ -91,7 +91,7 @@ def test_net(type, x):
         x = x.ravel()
         return x
 
-    soft, hard, jaxpr, symbolic = neural_logic_net.net(test_net)
+    soft, hard, symbolic = neural_logic_net.net(test_net)
     weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
     hard_weights = harden.hard_weights(weights)
 
@@ -124,10 +124,6 @@ def test_net(type, x):
     hard_output = hard.apply(hard_weights, hard_input)
     assert jax.numpy.allclose(hard_output, hard_expected)
 
-    # Check that the jaxpr expression performs as expected
-    jaxpr_output = jaxpr.apply(hard_weights, hard_input)
-    assert numpy.allclose(jaxpr_output, hard_expected)
-
     # Check that the symbolic function performs as expected
     symbolic_output = symbolic.apply(hard_weights, hard_input)
     assert numpy.allclose(symbolic_output, hard_expected)
@@ -137,7 +133,7 @@ def test_train_and():
     def test_net(type, x):
         return hard_and.and_layer(type)(4, nn.initializers.uniform(1.0))(x)
 
-    soft, hard, jaxpr, symbolic = neural_logic_net.net(test_net)
+    soft, hard, symbolic = neural_logic_net.net(test_net)
     soft_weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
 
     x = [
@@ -184,7 +180,7 @@ def test_net(type, x):
         x = hard_and.and_layer(type)(4, nn.initializers.uniform(1.0))(x)
         return x
 
-    soft, hard, jaxpr, symbolic = neural_logic_net.net(test_net)
+    soft, hard, symbolic = neural_logic_net.net(test_net)
 
     # Compute soft result
     soft_input = jax.numpy.array([0.6, 0.45])

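Note that dropping the Jaxpr net type does not remove jaxpr coverage from the tests: check_consistency still traces the hard function into a jaxpr and evaluates it, as in this excerpt from the first hunk above:

    # Excerpt from check_consistency (names as defined there)
    symbolic_f = sym_gen.make_symbolic_jaxpr(hard, *hard_args)
    assert numpy.allclose(sym_gen.eval_symbolic(symbolic_f, *hard_args), hard_expected)
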
0 commit comments
