Skip to content

Commit 2388a8f

Browse files
committed
remove prints and comments
1 parent a8e0cf6 commit 2388a8f

File tree

4 files changed

+3
-85
lines changed

4 files changed

+3
-85
lines changed

neurallogic/hard_and.py

Lines changed: 3 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -171,20 +171,9 @@ def __call__(self, x):
171171

172172

173173
def my_scope_put_variable(self, col: str, name: str, value: Any):
174-
"""Updates the value of the given variable if it is mutable, or an error otherwise.
175-
176-
Args:
177-
col: the collection of the variable.
178-
name: the name of the variable.
179-
value: the new value of the given variable.
180-
"""
181174
self._check_valid()
182175
self._validate_trace_level()
183-
#if not self.is_mutable_collection(col):
184-
# raise errors.ModifyScopeVariableError(col, name, self.path_text)
185-
#variables = self._mutable_collection(col)
186176
variables = self._collection(col)
187-
# Make sure reference sharing of child variable dictionaries isn't broken
188177

189178
def put(target, key, val):
190179
if (key in target and isinstance(target[key], dict) and
@@ -198,25 +187,10 @@ def put(target, key, val):
198187

199188

200189
def my_put_variable(self, col: str, name: str, value: Any):
201-
"""Sets the value of a Variable.
202-
203-
Args:
204-
col: the variable collection.
205-
name: the name of the variable.
206-
value: the new value of the variable.
207-
208-
Returns:
209-
210-
"""
211190
if self.scope is None:
212191
raise ValueError("Can't access variables on unbound modules")
213-
mutable_variables = self.scope.variables().unfreeze()
214-
self.scope._variables = mutable_variables
215-
#mutated_self = my_scope_put_variable(self.scope, col, name, value)
216-
#self.scope.put_variable(col, name, value)
192+
self.scope._variables = self.scope.variables().unfreeze()
217193
my_scope_put_variable(self.scope, col, name, value)
218-
#immutable_scope = self.scope.variables.freeze()
219-
#return immutable_scope
220194

221195

222196
class SymbolicAndLayer:
@@ -226,52 +200,18 @@ def __init__(self, layer_size):
226200

227201
def __call__(self, x):
228202
symbolic_weights = self.hard_and_layer.get_variable("params", "weights")
229-
print(f'symbolic_weights: {symbolic_weights} of type {type(symbolic_weights)}')
230-
if isinstance(symbolic_weights, list) or (isinstance(symbolic_weights, numpy.ndarray) and symbolic_weights.dtype == numpy.object):
203+
if isinstance(symbolic_weights, list) or (isinstance(symbolic_weights, numpy.ndarray) and symbolic_weights.dtype == object):
231204
symbolic_weights_n = symbolic_primitives.map_at_elements(symbolic_weights, lambda x: 0)
232205
symbolic_weights_n = numpy.asarray(symbolic_weights_n, dtype=numpy.float32)
233206
my_put_variable(self.hard_and_layer, "params", "weights", symbolic_weights_n)
234-
print(f'converted to symbolic_weights_n: {symbolic_weights_n} of type {type(symbolic_weights_n)}')
235-
236-
#print(
237-
# f'symbolic_weights: {symbolic_weights} of type {type(symbolic_weights)}')
238-
# Convert the symbolic inputs to numeric inputs so that we can generate a jaxpr
239-
#numeric_weights = sym_gen.make_numeric(symbolic_weights)
240-
#print(
241-
# f'numeric_weights: {numeric_weights} of type {type(numeric_weights)}')
242-
#numeric_input = numpy.array(
243-
# sym_gen.make_numeric(x), dtype=numpy.float32)
244-
#print(f'numeric_input: {numeric_input} of type {type(numeric_input)}')
245-
# Overwrite the supplied weights with the temporary numeric weights
246-
#my_put_variable(self.hard_and_layer, "params", "weights", symbolic_weights_n)
247-
# Generate the jaxpr for this layer
248-
#jaxpr = sym_gen.make_symbolic_jaxpr(self.hard_and_layer, numeric_input)
249-
print(f'x: {x} of type {type(x)}')
250-
if isinstance(x, numpy.ndarray):
251-
print(f'x is a numpy array with dtype = {x.dtype}')
252-
if isinstance(x, jax.numpy.ndarray):
253-
print(f'x is a jax.numpy array with dtype = {x.dtype}')
254-
#xn = sym_gen.make_numeric(x)
255-
#print(f'converted to xn: {xn} of type {type(xn)}')
256-
if isinstance(x, list) or (isinstance(x, numpy.ndarray) and x.dtype == numpy.object):
257-
# x = numpy.zeros_like(list)
258-
#x = numpy.asarray(x, dtype=numpy.float32)
207+
if isinstance(x, list) or (isinstance(x, numpy.ndarray) and x.dtype == object):
259208
xn = symbolic_primitives.map_at_elements(x, lambda x: 0)
260209
xn = numpy.asarray(xn, dtype=numpy.float32)
261210
else:
262211
xn = x
263-
print(f'converted to xn: {xn} of type {type(xn)}')
264212
jaxpr = sym_gen.make_symbolic_jaxpr(self.hard_and_layer, xn)
265-
print(f'jaxpr consts: {jaxpr.consts} of type {type(jaxpr.consts)}')
266-
print(f'jaxpr: {jaxpr}')
267213
# Swap out the numeric consts (that represent the weights) for the symbolic weights
268214
jaxpr.consts = [symbolic_weights]
269-
#print(
270-
# f'symbolic jaxpr consts: {jaxpr.consts} of type {type(jaxpr.consts)}')
271-
#symbolic_input = sym_gen.make_symbolic(x)
272-
#print(
273-
# f'symbolic_input: {symbolic_input} of type {type(symbolic_input)}')
274-
print(f'x: {x} of type {type(x)}')
275215
return sym_gen.symbolic_expression(jaxpr, x)
276216

277217

neurallogic/harden.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def harden(x: dict):
7171

7272
@dispatch
7373
def harden(*args):
74-
print(f'harden: {args} of type {type(args)} with length {len(args)}')
7574
if len(args) == 1:
7675
return harden(args[0])
7776
return tuple([harden(arg) for arg in args])
@@ -83,8 +82,6 @@ def map_keys_nested(f, d: dict) -> dict:
8382

8483
def hard_weights(weights):
8584
unfrozen_weights = weights.unfreeze()
86-
print(f'Unfrozen weights: {unfrozen_weights} of type {type(unfrozen_weights)}')
8785
hard_weights = harden(unfrozen_weights)
88-
print(f'Hard weights: {hard_weights} of type {type(hard_weights)}')
8986
return flax.core.FrozenDict(map_keys_nested(lambda str: str.replace("Soft", "Hard"), hard_weights))
9087

neurallogic/sym_gen.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,6 @@ def symbolic_write(var, val):
8383
safe_map(write, jaxpr.invars, args)
8484
safe_map(write, jaxpr.constvars, consts)
8585
safe_map(symbolic_write, jaxpr.invars, args)
86-
print(f'jaxpr.constvars: {jaxpr.constvars} of type {type(jaxpr.constvars)}')
87-
print(f'consts: {consts} of type {type(consts)}')
8886
safe_map(symbolic_write, jaxpr.constvars, consts)
8987

9088
def eval_jaxpr_impl(jaxpr):
@@ -209,7 +207,6 @@ def make_symbolic(x: flax.core.FrozenDict):
209207
return flax.core.FrozenDict(make_symbolic(x))
210208
@dispatch
211209
def make_numeric(x: flax.core.FrozenDict):
212-
#x = convert_jax_to_numpy_arrays(x.unfreeze())
213210
x = x.unfreeze()
214211
return flax.core.FrozenDict(make_numeric(x))
215212

@@ -227,12 +224,6 @@ def eval_symbolic(symbolic_function, *args):
227224

228225
def symbolic_expression(jaxpr, *args):
229226
if hasattr(jaxpr, 'literals'):
230-
#symbolic_jaxpr_literals = safe_map(
231-
# lambda x: numpy.array(x, dtype=object), jaxpr.literals)
232-
#symbolic_jaxpr_literals = make_symbolic(
233-
# symbolic_jaxpr_literals)
234-
#sym_expr = eval_jaxpr(True, jaxpr.jaxpr,
235-
# symbolic_jaxpr_literals, *args)
236227
sym_expr = eval_jaxpr(True, jaxpr.jaxpr,
237228
jaxpr.literals, *args)
238229
else:

tests/test_hard_and.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ def check_consistency(soft: typing.Callable, hard: typing.Callable, symbolic: t
1515

1616
# Check that the hard function performs as expected
1717
hard_args = harden.harden(*args)
18-
print(f'hard_args={hard_args} of type {type(hard_args)}')
1918
hard_expected = harden.harden(expected)
2019
assert numpy.allclose(hard(*hard_args), hard_expected)
2120

@@ -106,7 +105,6 @@ def test_net(type, x):
106105

107106
soft, hard, jaxpr, symbolic = neural_logic_net.net(test_net)
108107
weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
109-
print(f'Weights: {weights} of type {type(weights)}')
110108
hard_weights = harden.hard_weights(weights)
111109

112110
test_data = [
@@ -144,7 +142,6 @@ def test_net(type, x):
144142

145143
# Check that the symbolic function performs as expected
146144
symbolic_output = symbolic.apply(hard_weights, hard_input)
147-
print(f'Symbolic output: {symbolic_output} of type {type(symbolic_output)}')
148145
assert numpy.allclose(symbolic_output, hard_expected)
149146

150147

@@ -209,7 +206,6 @@ def test_net(type, x):
209206
# Compute hard result
210207
hard_weights = harden.hard_weights(soft_weights)
211208
hard_input = harden.harden(soft_input)
212-
print(f'hard_input: {hard_input} of type {type(hard_input)}')
213209
hard_result = hard.apply(hard_weights, numpy.array(hard_input))
214210
# Check that the hard result is the same as the soft result
215211
assert numpy.array_equal(harden.harden(soft_result), hard_result)
@@ -223,7 +219,6 @@ def test_net(type, x):
223219
symbolic_input = ['x1', 'x2']
224220
symbolic_output = symbolic.apply(hard_weights, symbolic_input)
225221
# Check the form of the symbolic expression
226-
#print(f'Symbolic output: {symbolic_output} of type {type(symbolic_output)}')
227222
assert numpy.array_equal(symbolic_output, ['True and (True and (x1 != 0.0 or False) and (x2 != 0.0 or True) != 0.0 or True) and (True and (x1 != 0.0 or False) and (x2 != 0.0 or False) != 0.0 or False) and (True and (x1 != 0.0 or False) and (x2 != 0.0 or True) != 0.0 or False) and (True and (x1 != 0.0 or False) and (x2 != 0.0 or True) != 0.0 or True)',
228223
'True and (True and (x1 != 0.0 or False) and (x2 != 0.0 or True) != 0.0 or False) and (True and (x1 != 0.0 or False) and (x2 != 0.0 or False) != 0.0 or False) and (True and (x1 != 0.0 or False) and (x2 != 0.0 or True) != 0.0 or False) and (True and (x1 != 0.0 or False) and (x2 != 0.0 or True) != 0.0 or True)',
229224
'True and (True and (x1 != 0.0 or False) and (x2 != 0.0 or True) != 0.0 or True) and (True and (x1 != 0.0 or False) and (x2 != 0.0 or False) != 0.0 or True) and (True and (x1 != 0.0 or False) and (x2 != 0.0 or True) != 0.0 or False) and (True and (x1 != 0.0 or False) and (x2 != 0.0 or True) != 0.0 or True)',
@@ -233,7 +228,6 @@ def test_net(type, x):
233228
symbolic_weights = sym_gen.make_symbolic(hard_weights)
234229
symbolic_output = symbolic.apply(symbolic_weights, symbolic_input)
235230
# Check the form of the symbolic expression
236-
#print(f'Symbolic output: {symbolic_output} of type {type(symbolic_output)}')
237231
assert numpy.array_equal(symbolic_output, ['True and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(False != 0.0)) != 0.0 or not(False != 0.0)) and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(True != 0.0)) != 0.0 or not(True != 0.0)) and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(False != 0.0)) != 0.0 or not(True != 0.0)) and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(False != 0.0)) != 0.0 or not(False != 0.0))',
238232
'True and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(False != 0.0)) != 0.0 or not(True != 0.0)) and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(True != 0.0)) != 0.0 or not(True != 0.0)) and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(False != 0.0)) != 0.0 or not(True != 0.0)) and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(False != 0.0)) != 0.0 or not(False != 0.0))',
239233
'True and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(False != 0.0)) != 0.0 or not(False != 0.0)) and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(True != 0.0)) != 0.0 or not(False != 0.0)) and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(False != 0.0)) != 0.0 or not(True != 0.0)) and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(False != 0.0)) != 0.0 or not(False != 0.0))',
@@ -242,11 +236,7 @@ def test_net(type, x):
242236
# Compute symbolic result with symbolic inputs and symbolic weights, but where the symbols can be evaluated
243237
symbolic_input = ['True', 'False']
244238
symbolic_output = symbolic.apply(symbolic_weights, symbolic_input)
245-
#print(f'symbolic_output = {symbolic_output} of type {type(symbolic_output)}')
246239
symbolic_output = sym_gen.eval_symbolic_expression(symbolic_output)
247-
#print(f'symbolic_output = {symbolic_output} of type {type(symbolic_output)}')
248240
# Check that the symbolic result is the same as the hard result
249-
print(f'symbolic_output = {symbolic_output} of type {type(symbolic_output)}')
250-
print(f'hard_result = {hard_result} of type {type(hard_result)}')
251241
assert numpy.array_equal(symbolic_output, hard_result)
252242

0 commit comments

Comments
 (0)