Skip to content

Commit 167db49

Browse files
committed
move surgery code to sym_gen
1 parent 16d6dfa commit 167db49

File tree

2 files changed

+26
-31
lines changed

2 files changed

+26
-31
lines changed

neurallogic/hard_and.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from typing import Any
2-
from functools import reduce
3-
from typing import (Callable, Mapping)
42

53
import numpy
64
import jax
75
from flax import linen as nn
6+
from typing import Callable
7+
88

99
from neurallogic import neural_logic_net, sym_gen, symbolic_primitives
1010

@@ -105,48 +105,25 @@ def __call__(self, x):
105105
return sym_gen.eval_symbolic(jaxpr, x)
106106

107107

108-
def my_scope_put_variable(self, col: str, name: str, value: Any):
109-
self._check_valid()
110-
self._validate_trace_level()
111-
variables = self._collection(col)
112-
113-
def put(target, key, val):
114-
if (key in target and isinstance(target[key], dict) and
115-
isinstance(val, Mapping)):
116-
for k, v in val.items():
117-
put(target[key], k, v)
118-
else:
119-
target[key] = val
120-
121-
put(variables, name, value)
122-
123-
124-
def my_put_variable(self, col: str, name: str, value: Any):
125-
if self.scope is None:
126-
raise ValueError("Can't access variables on unbound modules")
127-
self.scope._variables = self.scope.variables().unfreeze()
128-
my_scope_put_variable(self.scope, col, name, value)
129-
130-
131108
class SymbolicAndLayer:
132109
def __init__(self, layer_size):
133110
self.layer_size = layer_size
134111
self.hard_and_layer = HardAndLayer(self.layer_size)
135112

136113
def __call__(self, x):
137-
symbolic_weights = self.hard_and_layer.get_variable("params", "weights")
138-
if isinstance(symbolic_weights, list) or (isinstance(symbolic_weights, numpy.ndarray) and symbolic_weights.dtype == object):
139-
symbolic_weights_n = symbolic_primitives.map_at_elements(symbolic_weights, lambda x: 0)
140-
symbolic_weights_n = numpy.asarray(symbolic_weights_n, dtype=numpy.float32)
141-
my_put_variable(self.hard_and_layer, "params", "weights", symbolic_weights_n)
114+
actual_weights = self.hard_and_layer.get_variable("params", "weights")
115+
if isinstance(actual_weights, list) or (isinstance(actual_weights, numpy.ndarray) and actual_weights.dtype == object):
116+
numeric_weights = symbolic_primitives.map_at_elements(actual_weights, lambda x: 0)
117+
numeric_weights = numpy.asarray(numeric_weights, dtype=numpy.float32)
118+
sym_gen.put_variable(self.hard_and_layer, "params", "weights", numeric_weights)
142119
if isinstance(x, list) or (isinstance(x, numpy.ndarray) and x.dtype == object):
143120
xn = symbolic_primitives.map_at_elements(x, lambda x: 0)
144121
xn = numpy.asarray(xn, dtype=numpy.float32)
145122
else:
146123
xn = x
147124
jaxpr = sym_gen.make_symbolic_jaxpr(self.hard_and_layer, xn)
148125
# Swap out the numeric consts (that represent the weights) for the symbolic weights
149-
jaxpr.consts = [symbolic_weights]
126+
jaxpr.consts = [actual_weights]
150127
return sym_gen.symbolic_expression(jaxpr, x)
151128

152129

neurallogic/sym_gen.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from neurallogic import symbolic_primitives
88
from plum import dispatch
99
import typing
10+
from typing import (Any, Callable)
1011

1112
# TODO: rename this file to symbolic.py
1213

@@ -31,6 +32,23 @@ def symbolic_bind(prim, *args, **params):
3132
}[prim.name](*args, **params)
3233
return symbolic_outvals
3334

35+
def scope_put_variable(self, col: str, name: str, value: Any):
36+
variables = self._collection(col)
37+
38+
def put(target, key, val):
39+
if (key in target and isinstance(target[key], dict) and
40+
isinstance(val, Mapping)):
41+
for k, v in val.items():
42+
put(target[key], k, v)
43+
else:
44+
target[key] = val
45+
46+
put(variables, name, value)
47+
48+
49+
def put_variable(self, col: str, name: str, value: Any):
50+
self.scope._variables = self.scope.variables().unfreeze()
51+
scope_put_variable(self.scope, col, name, value)
3452

3553
def eval_jaxpr(symbolic, jaxpr, consts, *args):
3654
"""Evaluates a jaxpr by interpreting it as Python code.

0 commit comments

Comments
 (0)