Commit 0156a34

Pass lora_ga_config as parameter instead of attaching to modules
1 parent 747befa commit 0156a34

3 files changed: +19 −13 lines
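
For readers skimming the diff, the gist of the refactor: previously preprocess_loraga stashed the LoRA-GA config on each target module as a private attribute (_peft_lora_ga_config) and lora_ga_init recovered it with hasattr; now the config is threaded through as an explicit parameter, and None triggers the gaussian fallback. A minimal sketch contrasting the two patterns, using a toy config class and stand-in functions rather than the actual PEFT code:

# Minimal sketch of the refactor; the class and function names below are
# simplified stand-ins, not the real PEFT implementation.
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class ToyLoraGAConfig:
    # fields mirror those read in lora_ga_init(); the values used below are illustrative
    direction: str
    scale: str
    stable_gamma: float


# Before: the config was stashed on the wrapped module as a private attribute
# by the preprocessing step and rediscovered later via hasattr().
def lora_ga_init_old(base_layer: nn.Module, adapter_name: str) -> str:
    if not hasattr(base_layer, "_peft_lora_ga_config"):
        return "fallback: gaussian init"
    cfg = base_layer._peft_lora_ga_config
    return f"lora_ga init (direction={cfg.direction})"


# After: the config travels as an explicit argument; None triggers the same
# gaussian fallback without any attribute bookkeeping on the module.
def lora_ga_init_new(base_layer: nn.Module, adapter_name: str, lora_ga_config=None) -> str:
    if lora_ga_config is None:
        return "fallback: gaussian init"
    return f"lora_ga init (direction={lora_ga_config.direction})"


layer = nn.Linear(8, 8)
print(lora_ga_init_old(layer, "default"))   # fallback: gaussian init
print(lora_ga_init_new(layer, "default", ToyLoraGAConfig("ArB2r", "stable", 16.0)))

Passing the value explicitly keeps the base modules free of side-channel attributes and turns the fallback condition into a plain argument check.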

src/peft/tuners/lora/layer.py

Lines changed: 17 additions & 9 deletions
@@ -156,6 +156,7 @@ def update_layer(
         arrow_config: ArrowConfig = None,
         qalora_group_size: int = 32,
         inference_mode: bool = False,
+        lora_ga_config=None,
         **kwargs,
     ):
         # collect the kwargs
@@ -226,7 +227,7 @@ def update_layer(
             self.orthogonal_init(adapter_name)
         elif init_lora_weights == "lora_ga":
             with gather_params_ctx(self.get_base_layer().weight):
-                self.lora_ga_init(adapter_name)
+                self.lora_ga_init(adapter_name, lora_ga_config)
         elif init_lora_weights:
             self.reset_lora_parameters(adapter_name, init_lora_weights)
         # call this before init of the lora variants
@@ -465,7 +466,7 @@ def orthogonal_init(self, adapter_name):
         self.lora_A[adapter_name].weight = nn.Parameter(lora_A.contiguous().to(dtype))
         self.lora_B[adapter_name].weight = nn.Parameter(lora_B.contiguous().to(dtype))

-    def lora_ga_init(self, adapter_name):
+    def lora_ga_init(self, adapter_name, lora_ga_config):
         """
         Initialize LoRA weights using gradient approximation.
@@ -489,13 +490,11 @@ def lora_ga_init(self, adapter_name):

         grad = base_layer._peft_loraga_grad

-        # Check for lora_ga_config attached by preprocess_loraga
-        if not hasattr(base_layer, '_peft_lora_ga_config'):
+        # Check for lora_ga_config
+        if lora_ga_config is None:
             # Fall back to gaussian initialization
             self.reset_lora_parameters(adapter_name, init_lora_weights=True)
             return
-
-        lora_ga_config = base_layer._peft_lora_ga_config
         direction = lora_ga_config.direction
         scale = lora_ga_config.scale
         stable_gamma = lora_ga_config.stable_gamma
@@ -577,7 +576,6 @@ def lora_ga_init(self, adapter_name):

         # Remove redundant fields
         del base_layer._peft_loraga_grad
-        del base_layer._peft_lora_ga_config

     def _cache_store(self, key: str, value: Any) -> None:
         self._caches[key] = value
@@ -730,6 +728,7 @@ def __init__(
         use_alora: bool = False,
         arrow_config: ArrowConfig = None,
         lora_bias: bool = False,
+        lora_ga_config=None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -748,6 +747,7 @@ def __init__(
             use_alora=use_alora,
             lora_bias=lora_bias,
             arrow_config=arrow_config,
+            lora_ga_config=lora_ga_config,
         )
         self.is_target_conv_1d_layer = is_target_conv_1d_layer

@@ -956,6 +956,7 @@ def __init__(
         use_dora: bool = False,
         arrow_config: ArrowConfig = None,
         lora_bias: bool = False,
+        lora_ga_config=None,
         **kwargs,
     ) -> None:
         if lora_bias:
@@ -977,6 +978,7 @@ def __init__(
             use_dora=use_dora,
             lora_bias=lora_bias,
             arrow_config=arrow_config,
+            lora_ga_config=lora_ga_config,
         )

     def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
@@ -1266,6 +1268,7 @@ def __init__(
         use_dora: bool = False,
         arrow_config: ArrowConfig = None,
         lora_bias: bool = False,
+        lora_ga_config=None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -1295,6 +1298,7 @@ def __init__(
             use_dora=use_dora,
             lora_bias=lora_bias,
             arrow_config=arrow_config,
+            lora_ga_config=lora_ga_config,
         )

     def update_layer(
@@ -1638,6 +1642,7 @@ def __init__(
         init_lora_weights: Union[bool, str] = True,
         use_rslora: bool = False,
         use_dora: bool = False,
+        lora_ga_config=None,
         **kwargs,
     ) -> None:
         # TODO work with separate weights
@@ -1666,13 +1671,14 @@ def __init__(
                 init_lora_weights=init_lora_weights,
                 use_rslora=use_rslora,
                 use_dora=use_dora,
+                lora_ga_config=lora_ga_config,
                 **kwargs,
             )
         else:
             raise ValueError(f"out_proj must be an instance of nn.Linear for {self.__class__.__name__}.")

         self._active_adapter = adapter_name
-        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, lora_ga_config=lora_ga_config)

     @property
     def embed_dim(self) -> int:
@@ -2035,6 +2041,7 @@ def __init__(
         use_rslora: bool = False,
         use_dora: bool = False,
         lora_bias: bool = False,
+        lora_ga_config=None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -2075,6 +2082,7 @@ def __init__(
             use_rslora=use_rslora,
             use_dora=use_dora,
             lora_bias=lora_bias,
+            lora_ga_config=lora_ga_config,
         )

     def update_layer(
@@ -2153,7 +2161,7 @@ def update_layer(
             self.orthogonal_init(adapter_name)
         elif init_lora_weights == "lora_ga":
             with gather_params_ctx(self.get_base_layer().weight):
-                self.lora_ga_init(adapter_name)
+                self.lora_ga_init(adapter_name, lora_ga_config)
         elif init_lora_weights:
             self.reset_lora_parameters(adapter_name, init_lora_weights)
         # call this before init of the lora variants
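
The reason so many __init__ signatures change in this file is that the value now has to be forwarded explicitly at every hop until it reaches lora_ga_init. A toy illustration of that forwarding chain, using stand-in classes rather than the real LoraLayer hierarchy:

# Toy illustration of the forwarding chain added in this file; the classes and
# method bodies are stand-ins, not the actual PEFT LoraLayer implementation.
import torch.nn as nn


class ToyLoraLayer:
    def __init__(self, base_layer: nn.Module):
        self.base_layer = base_layer

    def update_layer(self, adapter_name, init_lora_weights=True, lora_ga_config=None, **kwargs):
        # update_layer now receives the config directly instead of reading an
        # attribute off the base layer.
        if init_lora_weights == "lora_ga":
            self.lora_ga_init(adapter_name, lora_ga_config)

    def lora_ga_init(self, adapter_name, lora_ga_config):
        if lora_ga_config is None:
            print(f"{adapter_name}: no LoRA-GA config, using gaussian init")
            return
        print(f"{adapter_name}: LoRA-GA init, direction={lora_ga_config.direction}")


class ToyLinear(ToyLoraLayer):
    def __init__(self, base_layer, adapter_name, lora_ga_config=None, **kwargs):
        super().__init__(base_layer)
        # Each concrete layer type (Linear, Embedding, Conv, MHA, ...) has to
        # pass the keyword on, which is why every __init__ in the diff changes.
        self.update_layer(adapter_name, init_lora_weights="lora_ga", lora_ga_config=lora_ga_config)


ToyLinear(nn.Linear(4, 4), "default")  # -> gaussian fallback, since no config is given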

src/peft/tuners/lora/loraga.py

Lines changed: 0 additions & 4 deletions
@@ -96,10 +96,6 @@ def preprocess_loraga(
     os.makedirs(os.path.dirname(cache_file), exist_ok=True)
     torch.save(cache, cache_file)

-    # Attach lora_ga_config to each target module for layer initialization
-    for name, module in target_modules(model, lora_config):
-        module._peft_lora_ga_config = lora_config.lora_ga_config
-

 def estimate_gradients(
     model: nn.Module,

src/peft/tuners/lora/model.py

Lines changed: 2 additions & 0 deletions
@@ -215,6 +215,7 @@ def _create_and_replace(
             "ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
             "lora_bias": lora_config.lora_bias,
             "arrow_config": lora_config.arrow_config,
+            "lora_ga_config": lora_config.lora_ga_config,
             "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
             "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
             "parameter_name": parameter_name,
@@ -250,6 +251,7 @@ def _create_and_replace(
                 use_dora=lora_config.use_dora,
                 lora_bias=lora_config.lora_bias,
                 arrow_config=lora_config.arrow_config,
+                lora_ga_config=lora_config.lora_ga_config,
                 inference_mode=lora_config.inference_mode,
             )
         else:
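
On the model side, the only change is that the new field is copied from the LoRA config into the kwargs handed to the layer constructors. A hedged sketch of that step, where DummyLoraConfig and build_layer_kwargs are illustrative stand-ins, not PEFT names:

# Hedged sketch of the model-level side: the field is copied from the config
# into the kwargs that eventually reach update_layer().
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DummyLoraConfig:
    r: int = 8
    lora_alpha: int = 16
    lora_bias: bool = False
    lora_ga_config: Optional[Any] = None  # populated when LoRA-GA init is requested


def build_layer_kwargs(lora_config: DummyLoraConfig) -> dict:
    # Mirrors the kwargs dict extended in _create_and_replace: the new entry
    # rides along with the existing per-layer options.
    return {
        "r": lora_config.r,
        "lora_alpha": lora_config.lora_alpha,
        "lora_bias": lora_config.lora_bias,
        "lora_ga_config": lora_config.lora_ga_config,
    }


kwargs = build_layer_kwargs(DummyLoraConfig())
assert kwargs["lora_ga_config"] is None  # downstream this triggers the gaussian fallback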
