diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 9c0a647cb..6358badec 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -435,7 +435,7 @@ class Linear4bit(nn.Linear): import torch.nn as nn import bitsandbytes as bnb - from bnb.nn import Linear4bit + from bitsandbytes.nn import Linear4bit fp16_model = nn.Sequential( nn.Linear(64, 64), @@ -949,7 +949,7 @@ class Linear8bitLt(nn.Linear): import torch.nn as nn import bitsandbytes as bnb - from bnb.nn import Linear8bitLt + from bitsandbytes.nn import Linear8bitLt fp16_model = nn.Sequential( nn.Linear(64, 64),