diff --git a/compressai/models/google.py b/compressai/models/google.py index 7b65ef8e..712ce031 100644 --- a/compressai/models/google.py +++ b/compressai/models/google.py @@ -673,6 +673,7 @@ def decompress(self, strings, shape): y_hat = torch.zeros( (z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding), device=z_hat.device, + dtype=z_hat.dtype, ) for i, y_string in enumerate(strings[0]):