-
Notifications
You must be signed in to change notification settings - Fork 25
Open
Description
Some ops are not yet supported for DTensor.
This seems to work, though there may be better solutions.
def copy_stochastic_(target: Tensor, source: Tensor):
    """
    Copy ``source`` into ``target`` in place using stochastic rounding.

    A uniform random integer in ``[0, 2**16)`` is added to the raw 32-bit
    pattern of every source element and the low 16 bits are then cleared,
    so each value is rounded to one of its two 16-bit-mantissa-truncated
    neighbours with probability proportional to its distance from the other.

    DTensor inputs are unwrapped to their local shards first, because the
    ops used below are applied to plain tensors only; each rank therefore
    rounds its own shard independently.

    Args:
        target: the target tensor with dtype=bfloat16 (plain Tensor or DTensor)
        source: the source tensor with dtype=float32 (plain Tensor or DTensor)
    """
    # Operate on the per-rank local shard when given a DTensor.
    target_local = target.to_local() if isinstance(target, DTensor) else target
    source_local = source.to_local() if isinstance(source, DTensor) else source

    # create a random 16 bit integer per element
    result = torch.randint_like(
        source_local,
        dtype=torch.int32,
        low=0,
        high=(1 << 16),
    )

    # add the random number to the lower 16 bit of the mantissa
    # (the float32 bits are reinterpreted as int32, not converted)
    result.add_(source_local.view(dtype=torch.int32))

    # mask off the lower 16 bit of the mantissa
    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32

    # copy the higher 16 bit into the target tensor
    target_local.copy_(result.view(dtype=torch.float32))

    if isinstance(target, DTensor):
        # Write the rounded shard back through the DTensor wrapper.
        # NOTE(review): if to_local() returns a view sharing storage with
        # the DTensor, the in-place copy above already updated `target` and
        # this round trip is redundant -- confirm before removing.
        rewrapped = DTensor.from_local(
            target_local,
            device_mesh=target.device_mesh,
            placements=target.placements,
            shape=target.shape,
            stride=target.stride(),
        )
        target.copy_(rewrapped)
```
Metadata
Assignees
Labels
No labels