@@ -156,6 +156,7 @@ def update_layer(
         arrow_config: ArrowConfig = None,
         qalora_group_size: int = 32,
         inference_mode: bool = False,
+        lora_ga_config=None,
         **kwargs,
     ):
         # collect the kwargs
@@ -226,7 +227,7 @@ def update_layer(
                 self.orthogonal_init(adapter_name)
         elif init_lora_weights == "lora_ga":
             with gather_params_ctx(self.get_base_layer().weight):
-                self.lora_ga_init(adapter_name)
+                self.lora_ga_init(adapter_name, lora_ga_config)
         elif init_lora_weights:
             self.reset_lora_parameters(adapter_name, init_lora_weights)
         # call this before init of the lora variants
@@ -465,7 +466,7 @@ def orthogonal_init(self, adapter_name):
         self.lora_A[adapter_name].weight = nn.Parameter(lora_A.contiguous().to(dtype))
         self.lora_B[adapter_name].weight = nn.Parameter(lora_B.contiguous().to(dtype))
 
-    def lora_ga_init(self, adapter_name):
+    def lora_ga_init(self, adapter_name, lora_ga_config):
         """
         Initialize LoRA weights using gradient approximation.
 
@@ -489,13 +490,11 @@ def lora_ga_init(self, adapter_name):
 
         grad = base_layer._peft_loraga_grad
 
-        # Check for lora_ga_config attached by preprocess_loraga
-        if not hasattr(base_layer, '_peft_lora_ga_config'):
+        # Check for lora_ga_config
+        if lora_ga_config is None:
             # Fall back to gaussian initialization
             self.reset_lora_parameters(adapter_name, init_lora_weights=True)
             return
-
-        lora_ga_config = base_layer._peft_lora_ga_config
         direction = lora_ga_config.direction
         scale = lora_ga_config.scale
         stable_gamma = lora_ga_config.stable_gamma
@@ -577,7 +576,6 @@ def lora_ga_init(self, adapter_name):
 
         # Remove redundant fields
         del base_layer._peft_loraga_grad
-        del base_layer._peft_lora_ga_config
 
     def _cache_store(self, key: str, value: Any) -> None:
         self._caches[key] = value
@@ -730,6 +728,7 @@ def __init__(
         use_alora: bool = False,
         arrow_config: ArrowConfig = None,
         lora_bias: bool = False,
+        lora_ga_config=None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -748,6 +747,7 @@ def __init__(
             use_alora=use_alora,
             lora_bias=lora_bias,
             arrow_config=arrow_config,
+            lora_ga_config=lora_ga_config,
         )
         self.is_target_conv_1d_layer = is_target_conv_1d_layer
 
@@ -956,6 +956,7 @@ def __init__(
         use_dora: bool = False,
         arrow_config: ArrowConfig = None,
         lora_bias: bool = False,
+        lora_ga_config=None,
         **kwargs,
     ) -> None:
         if lora_bias:
@@ -977,6 +978,7 @@ def __init__(
             use_dora=use_dora,
             lora_bias=lora_bias,
             arrow_config=arrow_config,
+            lora_ga_config=lora_ga_config,
         )
 
     def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
@@ -1266,6 +1268,7 @@ def __init__(
         use_dora: bool = False,
         arrow_config: ArrowConfig = None,
         lora_bias: bool = False,
+        lora_ga_config=None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -1295,6 +1298,7 @@ def __init__(
             use_dora=use_dora,
             lora_bias=lora_bias,
             arrow_config=arrow_config,
+            lora_ga_config=lora_ga_config,
         )
 
     def update_layer(
@@ -1638,6 +1642,7 @@ def __init__(
         init_lora_weights: Union[bool, str] = True,
         use_rslora: bool = False,
         use_dora: bool = False,
+        lora_ga_config=None,
         **kwargs,
     ) -> None:
         # TODO work with separate weights
@@ -1666,13 +1671,14 @@ def __init__(
                 init_lora_weights=init_lora_weights,
                 use_rslora=use_rslora,
                 use_dora=use_dora,
+                lora_ga_config=lora_ga_config,
                 **kwargs,
             )
         else:
             raise ValueError(f"out_proj must be an instance of nn.Linear for {self.__class__.__name__}.")
 
         self._active_adapter = adapter_name
-        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, lora_ga_config=lora_ga_config)
 
     @property
     def embed_dim(self) -> int:
@@ -2035,6 +2041,7 @@ def __init__(
         use_rslora: bool = False,
         use_dora: bool = False,
         lora_bias: bool = False,
+        lora_ga_config=None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -2075,6 +2082,7 @@ def __init__(
             use_rslora=use_rslora,
             use_dora=use_dora,
             lora_bias=lora_bias,
+            lora_ga_config=lora_ga_config,
         )
 
     def update_layer(
@@ -2153,7 +2161,7 @@ def update_layer(
                 self.orthogonal_init(adapter_name)
         elif init_lora_weights == "lora_ga":
             with gather_params_ctx(self.get_base_layer().weight):
-                self.lora_ga_init(adapter_name)
+                self.lora_ga_init(adapter_name, lora_ga_config)
         elif init_lora_weights:
             self.reset_lora_parameters(adapter_name, init_lora_weights)
         # call this before init of the lora variants
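
For readers following the change: the diff replaces the old pattern of stashing the LoRA-GA settings on the base layer (base_layer._peft_lora_ga_config) with an explicit lora_ga_config argument that is threaded from each layer's __init__ through update_layer into lora_ga_init, which now falls back to the default gaussian initialization when the config is None. Below is a minimal, self-contained sketch of that control flow using toy stand-in classes, not the real PEFT layers; only the direction, scale, and stable_gamma fields appear in the diff, and the default values shown are illustrative placeholders.

from dataclasses import dataclass


@dataclass
class ToyLoraGAConfig:
    # Fields read by lora_ga_init in the diff; the defaults here are placeholders.
    direction: str = "ArB2r"
    scale: str = "stable"
    stable_gamma: float = 16.0


class ToyLoraLayer:
    def update_layer(self, adapter_name, init_lora_weights, lora_ga_config=None, **kwargs):
        # The config is now passed through explicitly instead of being read off the base layer.
        if init_lora_weights == "lora_ga":
            self.lora_ga_init(adapter_name, lora_ga_config)
        elif init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)

    def lora_ga_init(self, adapter_name, lora_ga_config):
        # A plain None check replaces the old hasattr() probe for _peft_lora_ga_config.
        if lora_ga_config is None:
            self.reset_lora_parameters(adapter_name, init_lora_weights=True)
            return
        print(adapter_name, lora_ga_config.direction, lora_ga_config.scale, lora_ga_config.stable_gamma)

    def reset_lora_parameters(self, adapter_name, init_lora_weights):
        print(f"gaussian fallback for {adapter_name}")


layer = ToyLoraLayer()
layer.update_layer("default", "lora_ga", lora_ga_config=ToyLoraGAConfig())  # config provided
layer.update_layer("default", "lora_ga")  # no config -> gaussian fallback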