Mirror of https://github.com/unslothai/unsloth.git (synced 2026-04-28 03:19:57 +00:00)
Add implementation to cli
commit d75fbb5d0a
parent face46d188
2 changed files with 73 additions and 0 deletions
@@ -42,6 +42,7 @@ def run(args):
    from transformers import TrainingArguments
    from unsloth import is_bfloat16_supported
    import logging
    from unsloth import RawTextDataLoader

    logging.getLogger("hf-to-gguf").setLevel(logging.WARNING)
@@ -98,6 +99,21 @@ def run(args):
            texts.append(text)
        return {"text": texts}

    def load_dataset_smart(args):
        if args.raw_text_file:
            # Use raw text loader
            loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride)
            dataset = loader.load_from_file(args.raw_text_file)
        elif args.dataset.endswith(('.txt', '.md', '.json', '.jsonl')):
            # Auto-detect local raw text files
            loader = RawTextDataLoader(tokenizer)
            dataset = loader.load_from_file(args.dataset)
        else:
            # Existing HuggingFace dataset logic
            dataset = load_dataset(args.dataset, split="train")
            dataset = dataset.map(formatting_prompts_func, batched=True)
        return dataset

    use_modelscope = strtobool(os.environ.get("UNSLOTH_USE_MODELSCOPE", "False"))
    if use_modelscope:
        from modelscope import MsDataset
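The new load_dataset_smart helper routes three cases: an explicit --raw_text_file always wins, a dataset path with a recognized text extension is treated as a local raw file, and anything else falls through to the existing Hugging Face load_dataset path. Note that tokenizer and formatting_prompts_func appear to be captured from the enclosing run() scope rather than passed in. The decision rule, distilled into a self-contained toy (pick_branch is a made-up name for illustration, not part of the commit):

    # Hypothetical, self-contained illustration of the routing rule above.
    def pick_branch(raw_text_file, dataset):
        if raw_text_file:
            return "raw_text_loader"
        if dataset.endswith(('.txt', '.md', '.json', '.jsonl')):
            return "auto_detected_local_file"
        return "hugging_face_hub"

    assert pick_branch("book.txt", "ignored") == "raw_text_loader"
    assert pick_branch(None, "notes.md") == "auto_detected_local_file"
    assert pick_branch(None, "yahma/alpaca-cleaned") == "hugging_face_hub"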
@@ -389,5 +405,37 @@ if __name__ == "__main__":
        "--hub_token", type = str, help = "Token for pushing the model to Hugging Face hub"
    )

    parser.add_argument(
        "--raw_text_file",
        type=str,
        help="Path to raw text file for training"
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=2048,
        help="Size of text chunks for training"
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=512,
        help="Overlap between chunks"
    )

    TRAINING_MODES = {
        'instruction': 'Standard instruction-following',
        'causal': 'Causal language modeling (raw text)',
        'completion': 'Text completion tasks'
    }

    parser.add_argument(
        "--training_mode",
        type=str,
        default="instruction",
        choices=list(TRAINING_MODES.keys()),
        help="Training mode for the model"
    )

    args = parser.parse_args()
    run(args)
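Taken together, the new flags would be exercised roughly like this (a hypothetical invocation; the script name and any flags not shown in this diff are assumptions about the surrounding CLI):

    python unsloth-cli.py \
        --raw_text_file corpus.txt \
        --chunk_size 2048 \
        --stride 512 \
        --training_mode causal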
@@ -20,12 +20,28 @@ from typing import List, Dict, Any, Union, Optional
from datasets import Dataset
from pathlib import Path

__all__ = [
    "RawTextDataLoader",
    "TextPreprocessor",
]

SUPPORTED_FORMATS = {
    '.txt': 'plain_text',
    '.md': 'markdown',
    '.json': 'json_lines',
    '.jsonl': 'json_lines',
    '.csv': 'csv_text_column'
}

class RawTextDataLoader:
    def __init__(self, tokenizer, chunk_size=2048, stride=512):
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        self.stride = stride

    def detect_format(self, file_path):
        """Auto-detect file format and parse accordingly"""

    def load_from_file(self, file_path):
        """Load raw text and convert to dataset"""
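detect_format and load_from_file land in this commit as docstring-only stubs. For reference, a minimal sketch of what load_from_file could do, assuming the sliding-window chunking that chunk_size and stride imply (illustrative only, not the committed implementation):

    # Sketch: tokenize the whole file, then cut overlapping windows of
    # chunk_size tokens, advancing by (chunk_size - stride) so consecutive
    # chunks share `stride` tokens of context.
    def load_from_file(self, file_path):
        text = Path(file_path).read_text(encoding="utf-8")
        ids = self.tokenizer(text, add_special_tokens=False)["input_ids"]
        step = max(self.chunk_size - self.stride, 1)
        chunks = []
        for start in range(0, len(ids), step):
            window = ids[start : start + self.chunk_size]
            chunks.append(self.tokenizer.decode(window))
        return Dataset.from_dict({"text": chunks})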
@@ -64,6 +80,15 @@ class TextPreprocessor:

    def add_structure_tokens(self, text):
        """Add special tokens for structure (chapters, sections)"""

    def validate_dataset(self, dataset):
        """
        Check for:
        - Minimum/maximum sequence lengths
        - Character encoding issues
        - Repeated content
        - Empty chunks
        """

    def validate_dataset(self, dataset):
        """
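The hunk's trailing context shows the file already contained a validate_dataset stub, so this commit leaves a duplicate definition in place (in Python, the later one silently wins). The new stub itself has no body yet; one way its four documented checks could be realized (a hypothetical sketch, with arbitrary thresholds and warnings chosen purely for illustration):

    import warnings

    def validate_dataset(self, dataset, min_chars=8, max_chars=1_000_000):
        """Hypothetical sketch of the four checks named in the docstring."""
        seen = set()
        for i, row in enumerate(dataset):
            text = row.get("text", "")
            if not text.strip():                             # empty chunks
                warnings.warn(f"chunk {i} is empty")
            elif not (min_chars <= len(text) <= max_chars):  # length bounds
                warnings.warn(f"chunk {i} has length {len(text)}")
            if "\ufffd" in text:                             # encoding damage marker
                warnings.warn(f"chunk {i} contains U+FFFD replacement chars")
            if text in seen:                                 # exact duplicates
                warnings.warn(f"chunk {i} repeats earlier content")
            seen.add(text)
        return dataset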