mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-02 21:51:30 +00:00
78 lines
3.7 KiB
Python
78 lines
3.7 KiB
Python
import torch
|
|
from transformers import PreTrainedModel
|
|
import warnings
|
|
from typing import TYPE_CHECKING, Any, Optional
|
|
|
|
from peft.config import PeftConfig
|
|
|
|
from ktransformers.sft.peft_utils.lora_model import LoraModel
|
|
from ktransformers.sft.peft_utils.peft_model import PeftModel, PeftModelForCausalLM
|
|
|
|
def get_peft_model(
    model: PreTrainedModel,
    peft_config: PeftConfig,
    adapter_name: str = "default",
    mixed: bool = False,
    autocast_adapter_dtype: bool = True,
    revision: Optional[str] = None,
    low_cpu_mem_usage: bool = False,
) -> PeftModel:
    """
    Wrap `model` in a `PeftModelForCausalLM` configured by `peft_config`.

    Note that `peft_config` is mutated in place: its `base_model_name_or_path` is always overwritten with the
    `name_or_path` instance attribute of `model` (or `None` when the model has no such attribute), and its
    `revision` is updated when `revision` is given.

    Args:
        model ([`transformers.PreTrainedModel`]):
            Model to be wrapped.
        peft_config ([`PeftConfig`]):
            Configuration object containing the parameters of the Peft model.
        adapter_name (`str`, `optional`, defaults to `"default"`):
            The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
        mixed (`bool`, `optional`, defaults to `False`):
            Accepted for signature compatibility only. This implementation does not read it and always returns a
            `PeftModelForCausalLM`.
        autocast_adapter_dtype (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to `PeftModelForCausalLM`. In upstream PEFT this controls whether float16/bfloat16
            adapter weights are cast to float32 -- NOTE(review): confirm the local `PeftModelForCausalLM` honours
            it the same way.
        revision (`str`, `optional`, defaults to `None`):
            The revision of the base model. When given, it is stored on `peft_config.revision`; a warning is
            emitted if this replaces a different revision that was already set on the config. When `None`, the
            config's revision is left untouched.
        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
            Forwarded unchanged to `PeftModelForCausalLM`. In upstream PEFT this creates empty adapter weights on
            the meta device -- leave it `False` if you intend to train, unless the adapter weights are replaced
            before training starts.

    Returns:
        The `PeftModelForCausalLM` instance wrapping `model`.
    """
    # Read straight from the instance __dict__ so that a missing attribute yields None instead of raising
    # (torch modules raise AttributeError from __getattr__ for unknown names).
    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path")

    # Previously `revision` was accepted but never used; record it on the config so it is not silently lost.
    if revision is not None:
        existing_revision = getattr(peft_config, "revision", None)
        if existing_revision is not None and existing_revision != revision:
            warnings.warn(
                f"peft config has already set base model revision to {existing_revision}, "
                f"overwriting with revision {revision}"
            )
        peft_config.revision = revision

    return PeftModelForCausalLM(
        model,
        peft_config,
        adapter_name=adapter_name,
        autocast_adapter_dtype=autocast_adapter_dtype,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
|
|
|
|
def inject_adapter_in_model(
    peft_config: PeftConfig,
    model: torch.nn.Module,
    adapter_name: str = "default",
    low_cpu_mem_usage: bool = False,
) -> torch.nn.Module:
    r"""
    Build a `LoraModel` around `model` and hand back the wrapper's `.model` attribute.

    This is a low-level entry point: it does not go through `get_peft_model` and does not return a `PeftModel`
    wrapper. The tuner class is always `LoraModel`, regardless of the PEFT type declared in `peft_config`, so
    prompt-learning and adaption-prompt methods are not supported here.

    Args:
        peft_config (`PeftConfig`):
            Configuration object containing the parameters of the Peft model. It is passed to `LoraModel` as is.
        model (`torch.nn.Module`):
            The input model where the adapter will be injected.
        adapter_name (`str`, `optional`, defaults to `"default"`):
            The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
            Forwarded unchanged to `LoraModel`. Intended to create empty adapter weights on the meta device to
            speed up loading -- NOTE(review): confirm against the local `LoraModel` implementation.

    Returns:
        The `.model` attribute of the freshly constructed `LoraModel`.
    """
    # The tuner is fixed to LoraModel rather than looked up from a PEFT-type mapping.
    # Constructing it is what injects the (randomly initialized) LoRA layers into the model's modules.
    lora_wrapper = LoraModel(
        model,
        peft_config,
        adapter_name=adapter_name,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
    return lora_wrapper.model
|