@@ -25,20 +25,11 @@ def hard_xor_include(w, x):
25
25
def soft_xor_neuron (w , x ):
26
26
# Conditionally include input bits, according to weights
27
27
x = jax .vmap (soft_xor_include , 0 , 0 )(w , x )
28
- # Compute the most sensitive bit
29
- margins = jax .vmap (lambda x : jax .numpy .abs (0.5 - x ))(x )
30
- sensitive_bit_index = jax .numpy .argmin (margins )
31
- sensitive_bit = jax .numpy .take (x , sensitive_bit_index )
32
- # Compute the logical xor of the bits
33
- hard_x = jax .vmap (lambda x : jax .numpy .where (x > 0.5 , True , False ))(x )
34
- logical_xor = jax .lax .reduce (hard_x , False , jax .numpy .logical_xor , (0 ,))
35
- # Compute the representative bit
36
- hard_sensitive_bit = jax .numpy .where (sensitive_bit > 0.5 , True , False )
37
- representative_bit = jax .numpy .where (logical_xor == hard_sensitive_bit ,
38
- sensitive_bit ,
39
- 1.0 - sensitive_bit
40
- )
41
- return representative_bit
28
+
29
+ def xor (x , y ):
30
+ return jax .numpy .minimum (jax .numpy .maximum (x , y ), 1.0 - jax .numpy .minimum (x , y ))
31
+ x = jax .lax .reduce (x , jax .numpy .float16 (0.0 ), xor , (0 ,))
32
+ return x
42
33
43
34
44
35
def hard_xor_neuron (w , x ):
0 commit comments