Skip to content

Commit 049f973

Browse files
committed
better xor
1 parent 4440ac9 commit 049f973

File tree

1 file changed

+5
-14
lines changed

1 file changed

+5
-14
lines changed

neurallogic/hard_xor.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,11 @@ def hard_xor_include(w, x):
2525
def soft_xor_neuron(w, x):
2626
# Conditionally include input bits, according to weights
2727
x = jax.vmap(soft_xor_include, 0, 0)(w, x)
28-
# Compute the most sensitive bit
29-
margins = jax.vmap(lambda x: jax.numpy.abs(0.5 - x))(x)
30-
sensitive_bit_index = jax.numpy.argmin(margins)
31-
sensitive_bit = jax.numpy.take(x, sensitive_bit_index)
32-
# Compute the logical xor of the bits
33-
hard_x = jax.vmap(lambda x: jax.numpy.where(x > 0.5, True, False))(x)
34-
logical_xor = jax.lax.reduce(hard_x, False, jax.numpy.logical_xor, (0,))
35-
# Compute the representative bit
36-
hard_sensitive_bit = jax.numpy.where(sensitive_bit > 0.5, True, False)
37-
representative_bit = jax.numpy.where(logical_xor == hard_sensitive_bit,
38-
sensitive_bit,
39-
1.0 - sensitive_bit
40-
)
41-
return representative_bit
28+
29+
def xor(x, y):
30+
return jax.numpy.minimum(jax.numpy.maximum(x, y), 1.0 - jax.numpy.minimum(x, y))
31+
x = jax.lax.reduce(x, jax.numpy.float16(0.0), xor, (0,))
32+
return x
4233

4334

4435
def hard_xor_neuron(w, x):

0 commit comments

Comments
 (0)