diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ff8ce9c3d..54506f41d 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -222,6 +222,7 @@ def __new__( quant_storage: torch.dtype = torch.uint8, module: Optional["Linear4bit"] = None, bnb_quantized: bool = False, + **kwargs, ) -> "Params4bit": if data is None: data = torch.empty(0) @@ -680,6 +681,7 @@ def __new__( has_fp16_weights=False, CB: Optional[torch.Tensor] = None, SCB: Optional[torch.Tensor] = None, + **kwargs, ): if data is None: data = torch.empty(0)