Description
The t sample is biased when t==r, because it is the maximum of two independent logit-normal/uniform samples. The problem is in "sample_tr" in meanflow.py:

    def sample_tr(self, bz):
        t = self.noise_distribution()(bz)
        r = self.noise_distribution()(bz)
        t, r = jnp.maximum(t, r), jnp.minimum(t, r)
        data_size = int(bz * self.data_proportion)
        zero_mask = jnp.arange(bz) < data_size
        zero_mask = zero_mask.reshape(bz, 1, 1, 1)
        r = jnp.where(zero_mask, t, r)
        return t, r

For the case where you want t==r, the code should return the unbiased t sample, i.e. the one drawn before the maximum was taken. Otherwise, for the 75% of samples that are "standard" flow matching, you are using biased t samples, which skews the distribution to the right.
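One possible fix, as a minimal sketch rather than the repo's actual code: keep the raw t draw for the rows where t==r, and only apply the max/min ordering to the remaining rows. The standalone functions `logit_normal` and `sample_tr_fixed`, the explicit `key` argument, and the default values are my assumptions for illustration; in meanflow.py the sampling goes through `self.noise_distribution()`.

```python
import jax
import jax.numpy as jnp


def logit_normal(key, bz, mean=-0.4, std=1.0):
    # Hypothetical stand-in for self.noise_distribution(): sigmoid of a Gaussian.
    return jax.nn.sigmoid(mean + std * jax.random.normal(key, (bz, 1, 1, 1)))


def sample_tr_fixed(key, bz, data_proportion=0.75):
    key_t, key_r = jax.random.split(key)
    t_raw = logit_normal(key_t, bz)
    r_raw = logit_normal(key_r, bz)

    # Ordered pair, used only for the rows where t != r is wanted.
    t_max = jnp.maximum(t_raw, r_raw)
    r_min = jnp.minimum(t_raw, r_raw)

    data_size = int(bz * data_proportion)
    zero_mask = (jnp.arange(bz) < data_size).reshape(bz, 1, 1, 1)

    # For the t == r rows, reuse the raw draw so t keeps the intended distribution.
    t = jnp.where(zero_mask, t_raw, t_max)
    r = jnp.where(zero_mask, t_raw, r_min)
    return t, r
```

With this change the first `data_size` rows have t==r drawn directly from the noise distribution, and the rest still satisfy t >= r as before.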
If the same code was used to obtain the results in the paper, perhaps that is why mu=-0.4 works best: it unskews the distribution just enough. This is just my guess.
Here is a visualization of the buggy "t" distribution the code produces for the case where t should be equal to r (red) vs. the expected one (blue). The LHS plot uses mean=0.0 and the RHS plot uses mean=-0.4, the paper's best setting, which centers the "buggy" distribution.
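The skew can also be checked numerically without the plots. For two independent draws with CDF F, the maximum has CDF F(x)^2, so its median sits where F(x) is about 0.707 rather than 0.5. Below is a rough Monte Carlo sketch, assuming a logit-normal with std=1.0; the helper names `logit_normal` and `median_raw_and_max` are mine, not from the repo.

```python
import jax
import jax.numpy as jnp


def logit_normal(key, n, mean, std=1.0):
    # Sigmoid of a Gaussian, assumed here as the noise distribution.
    return jax.nn.sigmoid(mean + std * jax.random.normal(key, (n,)))


def median_raw_and_max(mean, n=200_000, seed=0):
    # Compare the median of a single draw with the median of max(t, r).
    key_a, key_b = jax.random.split(jax.random.PRNGKey(seed))
    a = logit_normal(key_a, n, mean)
    b = logit_normal(key_b, n, mean)
    return float(jnp.median(a)), float(jnp.median(jnp.maximum(a, b)))


for mean in (0.0, -0.4):
    raw_med, max_med = median_raw_and_max(mean)
    print(f"mean={mean:+.1f}  median(raw t)={raw_med:.3f}  median(max(t, r))={max_med:.3f}")
```

By my calculation the median of max(t, r) should come out at about 0.63 for mean=0.0 and about 0.54 for mean=-0.4, which would be consistent with the guess that mu=-0.4 roughly re-centers the buggy distribution.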
