Skip to content

Commit 7a1910b

Browse files
committed
reactivate mnist test
1 parent 9a241fa commit 7a1910b

File tree

2 files changed

+161
-143
lines changed

2 files changed

+161
-143
lines changed

neurallogic/hard_dropout.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class SoftHardDropout(nn.Module):
3131
deterministic: Optional[bool] = None
3232
rng_collection: str = "dropout"
3333
dropout_value: float = 0.0
34+
dtype: jax.numpy.dtype = jax.numpy.float32
3435

3536
@nn.compact
3637
def __call__(self, inputs, deterministic: Optional[bool] = None):
@@ -64,7 +65,9 @@ def __call__(self, inputs, deterministic: Optional[bool] = None):
6465
broadcast_shape[dim] = 1
6566
mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
6667
mask = jax.numpy.broadcast_to(mask, inputs.shape)
67-
masked_values = jax.numpy.full_like(inputs, self.dropout_value, dtype=float)
68+
masked_values = jax.numpy.full_like(
69+
inputs, self.dropout_value, dtype=self.dtype
70+
)
6871
return lax.select(mask, inputs, masked_values)
6972

7073

0 commit comments

Comments
 (0)