kvcache-ai-ktransformers/KT-SFT/ktransformers/sft/peft_utils/mapping.py
2025-11-04 04:24:30 +00:00

78 lines
3.7 KiB
Python

import torch
from transformers import PreTrainedModel
import warnings
from typing import TYPE_CHECKING, Any, Optional
from peft.config import PeftConfig
from ktransformers.sft.peft_utils.lora_model import LoraModel
from ktransformers.sft.peft_utils.peft_model import PeftModel, PeftModelForCausalLM
def get_peft_model(
    model: PreTrainedModel,
    peft_config: PeftConfig,
    adapter_name: str = "default",
    mixed: bool = False,
    autocast_adapter_dtype: bool = True,
    revision: Optional[str] = None,
    low_cpu_mem_usage: bool = False,
) -> PeftModel:
    """
    Returns a Peft model object from a model and a config.

    `peft_config` is updated in place: its `base_model_name_or_path` is set to the model's `name_or_path` (or `None`
    when the model instance has no such attribute) and, when `revision` is given, its `revision` is set as well.

    Args:
        model ([`transformers.PreTrainedModel`]):
            Model to be wrapped.
        peft_config ([`PeftConfig`]):
            Configuration object containing the parameters of the Peft model.
        adapter_name (`str`, `optional`, defaults to `"default"`):
            The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
        mixed (`bool`, `optional`, defaults to `False`):
            Accepted for signature compatibility only. This implementation does not read it and always builds a
            `PeftModelForCausalLM`.
        autocast_adapter_dtype (`bool`, *optional*):
            Whether to autocast the adapter dtype. Defaults to `True`. Forwarded unchanged to
            `PeftModelForCausalLM`.
        revision (`str`, `optional`, defaults to `None`):
            The revision of the base model. When given, it is stored on `peft_config.revision`; a warning is emitted
            if this overrides a different revision that was already set on the config.
        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
            Create empty adapter weights on meta device. Useful to speed up the loading process. Leave this setting as
            False if you intend on training the model, unless the adapter weights will be replaced by different weights
            before training starts. Forwarded unchanged to `PeftModelForCausalLM`.

    Returns:
        The `PeftModelForCausalLM` wrapping `model`.
    """
    previous_name = getattr(peft_config, "base_model_name_or_path", None)
    # Read from the instance __dict__ so that a missing attribute yields None instead of raising.
    new_name = model.__dict__.get("name_or_path", None)
    peft_config.base_model_name_or_path = new_name

    # Surface the overwrite instead of silently re-pointing the config at a different base model.
    if previous_name is not None and previous_name != new_name:
        warnings.warn(
            f"The PEFT config's `base_model_name_or_path` was renamed from '{previous_name}' to '{new_name}'. "
            "Please ensure that the correct base model is loaded when loading this checkpoint."
        )

    # Honor the documented `revision` argument, which was previously accepted but never applied.
    if revision is not None:
        configured_revision = getattr(peft_config, "revision", None)
        if configured_revision is not None and configured_revision != revision:
            warnings.warn(
                f"peft config has already set base model revision to {configured_revision}, "
                f"overwriting with revision {revision}"
            )
        peft_config.revision = revision

    return PeftModelForCausalLM(
        model,
        peft_config,
        adapter_name=adapter_name,
        autocast_adapter_dtype=autocast_adapter_dtype,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
def inject_adapter_in_model(
    peft_config: PeftConfig,
    model: torch.nn.Module,
    adapter_name: str = "default",
    low_cpu_mem_usage: bool = False,
) -> torch.nn.Module:
    r"""
    Wrap `model` in a `LoraModel` built from `peft_config` and return the wrapper's `.model` attribute.

    The tuner class is always `LoraModel`; `peft_config` is passed through to it and is not inspected here to select
    a tuner type. Constructing the `LoraModel` is expected to inject randomly initialized LoRA layers into the
    model's modules, so make sure the target modules are set correctly in `peft_config`.

    Args:
        peft_config (`PeftConfig`):
            Configuration object containing the parameters of the Peft model.
        model (`torch.nn.Module`):
            The input model where the adapter will be injected.
        adapter_name (`str`, `optional`, defaults to `"default"`):
            The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
            Create empty adapter weights on meta device. Useful to speed up the loading process.

    Returns:
        The `.model` attribute of the constructed `LoraModel`.
    """
    # Only LoRA is supported by this helper, so the tuner class is fixed rather than looked up from a mapping.
    tuner_kwargs = {
        "adapter_name": adapter_name,
        "low_cpu_mem_usage": low_cpu_mem_usage,
    }
    lora_wrapper = LoraModel(model, peft_config, **tuner_kwargs)
    return lora_wrapper.model