|
1 | 1 | from typing import Any
|
2 |
| -from functools import reduce |
3 |
| -from typing import (Callable, Mapping) |
4 | 2 |
|
5 | 3 | import numpy
|
6 | 4 | import jax
|
7 | 5 | from flax import linen as nn
|
| 6 | +from typing import Callable |
| 7 | + |
8 | 8 |
|
9 | 9 | from neurallogic import neural_logic_net, sym_gen, symbolic_primitives
|
10 | 10 |
|
@@ -105,48 +105,25 @@ def __call__(self, x):
|
105 | 105 | return sym_gen.eval_symbolic(jaxpr, x)
|
106 | 106 |
|
107 | 107 |
|
108 |
def my_scope_put_variable(self, col: str, name: str, value: Any):
    """Store *value* under *name* in the scope's *col* variable collection.

    Mapping values are merged recursively into an existing dict entry rather
    than replacing it wholesale; any non-mapping value (or a key whose current
    entry is not a dict) is assigned directly.
    """
    self._check_valid()
    self._validate_trace_level()
    collection = self._collection(col)

    # Iterative deep-merge: each pending item is (container, key, new_value).
    pending = [(collection, name, value)]
    while pending:
        target, key, val = pending.pop()
        mergeable = (
            key in target
            and isinstance(target[key], dict)
            and isinstance(val, Mapping)
        )
        if mergeable:
            # Descend one level and merge each child key individually.
            pending.extend((target[key], k, v) for k, v in val.items())
        else:
            target[key] = val
122 |
| - |
123 |
| - |
124 |
def my_put_variable(self, col: str, name: str, value: Any):
    """Write *value* into this module's bound scope under (*col*, *name*).

    Raises:
        ValueError: if the module is unbound (no scope attached).
    """
    scope = self.scope
    if scope is None:
        raise ValueError("Can't access variables on unbound modules")
    # Replace the frozen variable tree with a mutable copy so the nested
    # put below can mutate it in place.
    scope._variables = scope.variables().unfreeze()
    my_scope_put_variable(scope, col, name, value)
129 |
| - |
130 |
| - |
class SymbolicAndLayer:
    """Symbolic counterpart of HardAndLayer.

    Traces the hard layer into a jaxpr using numeric placeholders, then swaps
    the traced constants back to the (possibly symbolic) weights so the
    resulting expression can be evaluated symbolically.
    """

    def __init__(self, layer_size):
        self.layer_size = layer_size
        self.hard_and_layer = HardAndLayer(self.layer_size)

    def __call__(self, x):
        actual_weights = self.hard_and_layer.get_variable("params", "weights")

        # Object-dtype arrays / plain lists hold symbolic elements, which jax
        # cannot trace; substitute numeric zeros of the same structure.
        weights_are_symbolic = isinstance(actual_weights, list) or (
            isinstance(actual_weights, numpy.ndarray)
            and actual_weights.dtype == object
        )
        if weights_are_symbolic:
            placeholder = symbolic_primitives.map_at_elements(
                actual_weights, lambda _: 0
            )
            sym_gen.put_variable(
                self.hard_and_layer,
                "params",
                "weights",
                numpy.asarray(placeholder, dtype=numpy.float32),
            )

        input_is_symbolic = isinstance(x, list) or (
            isinstance(x, numpy.ndarray) and x.dtype == object
        )
        if input_is_symbolic:
            xn = numpy.asarray(
                symbolic_primitives.map_at_elements(x, lambda _: 0),
                dtype=numpy.float32,
            )
        else:
            xn = x

        jaxpr = sym_gen.make_symbolic_jaxpr(self.hard_and_layer, xn)
        # Swap out the numeric consts (that represent the weights) for the symbolic weights
        jaxpr.consts = [actual_weights]
        return sym_gen.symbolic_expression(jaxpr, x)
|
151 | 128 |
|
152 | 129 |
|
|
0 commit comments