mirror of
https://github.com/unslothai/unsloth.git
synced 2026-05-22 02:50:03 +00:00
172 lines
5.9 KiB
Python
172 lines
5.9 KiB
Python
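#!/usr/bin/env python3
"""
Minimal test for the raw text training implementation.

Tests basic functionality without heavy dependencies: the ``datasets``
package is replaced by a small in-memory stub, and
``unsloth/dataprep/raw_text.py`` is loaded directly from its file path
instead of through ``unsloth/__init__.py``.
"""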
|
|
|
|
import importlib.util
import os
import sys
import tempfile
import types
from pathlib import Path
|
|
|
|
|
|
# Minimal in-memory stand-in for ``datasets.Dataset`` so this test can run
# without the real ``datasets`` package.
class MockDataset:
    """Column-oriented toy dataset mimicking the bits of ``datasets.Dataset``
    that this test touches.

    ``data_dict`` maps column names to equal-length lists of values.
    """

    def __init__(self, data_dict):
        self.data = data_dict
        self.column_names = list(data_dict.keys())

    def __len__(self):
        """Return the number of rows; 0 when the dataset has no columns."""
        # All columns are assumed to share one length, so the first decides.
        # Passing a default to next() avoids leaking StopIteration for {}.
        first_column = next(iter(self.data.values()), None)
        return 0 if first_column is None else len(first_column)

    def __getitem__(self, idx):
        """Index the dataset like ``datasets.Dataset``.

        - ``str``: the whole column, e.g. ``dataset["text"]``.
        - ``int``: one row as a ``{column: value}`` dict.
        - ``slice``: the selected rows as a ``{column: list_of_values}`` dict.

        Raises:
            TypeError: for any other index type.
        """
        if isinstance(idx, str):
            return self.data[idx]
        if isinstance(idx, (int, slice)):
            # The same expression serves rows (int) and row ranges (slice).
            return {key: values[idx] for key, values in self.data.items()}
        raise TypeError(f"Invalid index type: {type(idx)}")

    @classmethod
    def from_dict(cls, data_dict):
        """Build a dataset from a ``{column: values}`` mapping."""
        return cls(data_dict)
|
|
|
|
|
|
# Install the MockDataset-backed stub as the ``datasets`` module so that
# raw_text.py imports it. NOTE: this replaces ``datasets`` unconditionally,
# even when the real package is installed.
datasets_mock = types.ModuleType("datasets")
datasets_mock.Dataset = MockDataset
sys.modules["datasets"] = datasets_mock

# Import the raw_text module directly to avoid unsloth/__init__.py dependencies.
# abspath() keeps the lookup independent of the working directory even if
# __file__ is relative.
current_dir = os.path.dirname(os.path.abspath(__file__))
raw_text_path = os.path.join(
    os.path.dirname(current_dir), "unsloth", "dataprep", "raw_text.py"
)

spec = importlib.util.spec_from_file_location("raw_text", raw_text_path)
raw_text_module = importlib.util.module_from_spec(spec)
# Register before executing, as the importlib "import a source file directly"
# recipe does. This lets code in raw_text that looks itself up through
# sys.modules find the module.
sys.modules[spec.name] = raw_text_module
spec.loader.exec_module(raw_text_module)

RawTextDataLoader = raw_text_module.RawTextDataLoader
TextPreprocessor = raw_text_module.TextPreprocessor
|
|
|
|
|
|
def test_raw_text_loader():
    """Test basic RawTextDataLoader and TextPreprocessor functionality.

    Writes a small temporary ``.txt`` file and loads it in text mode and in
    tokenized mode, checking the resulting columns and row structure. It then
    checks the loader's constructor validation and runs two TextPreprocessor
    helpers. The temporary file is always removed.

    Returns:
        True when every check passes.

    Raises:
        Exception: any failed check or unexpected error is printed and then
        re-raised. Returning ``False`` instead would let pytest, which ignores
        return values, report a failing run as passing.
    """

    class MockTokenizer:
        """Whitespace tokenizer stand-in: each word becomes one token."""

        def __init__(self):
            self.eos_token = "</s>"
            self.eos_token_id = 2  # Mock EOS token ID

        def __call__(self, text, return_tensors = None, add_special_tokens = False):
            words = text.split()
            # Token ids are simply the word positions 0..len(words) - 1.
            token_ids = list(range(len(words)))

            if return_tensors == "pt":
                # Mock tensor-like object
                class MockTensor:
                    def __init__(self, data):
                        self.data = data

                    def __getitem__(self, idx):
                        # Ignores idx and always returns the full id list.
                        return self.data

                    def __len__(self):
                        return len(self.data)

                    def tolist(self):
                        return self.data

                return {"input_ids": [MockTensor(token_ids)]}
            return {"input_ids": token_ids}

        def decode(self, token_ids, skip_special_tokens = False):
            # Not an inverse of __call__: ids map to placeholder words.
            return " ".join([f"word_{i}" for i in token_ids])

    # Create test file
    test_content = "This is a test file for raw text training. " * 10
    with tempfile.NamedTemporaryFile(
        mode = "w", suffix = ".txt", delete = False, encoding = "utf-8"
    ) as f:
        f.write(test_content)
        test_file = f.name

    try:
        tokenizer = MockTokenizer()
        loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 2)

        # Test loading with text output (legacy mode)
        text_dataset = loader.load_from_file(test_file, return_tokenized = False)
        assert len(text_dataset) > 0, "Should create at least one chunk"
        assert "text" in text_dataset.column_names, "Dataset should have 'text' column"

        # Test loading with tokenized output (new efficient mode)
        tokenized_dataset = loader.load_from_file(test_file, return_tokenized = True)
        assert len(tokenized_dataset) > 0, "Should create at least one tokenized chunk"
        assert (
            "input_ids" in tokenized_dataset.column_names
        ), "Dataset should have 'input_ids' column"
        assert (
            "attention_mask" in tokenized_dataset.column_names
        ), "Dataset should have 'attention_mask' column"

        # Verify tokenized data structure
        first_sample = tokenized_dataset[0]
        assert isinstance(first_sample["input_ids"], list), "input_ids should be a list"
        assert isinstance(
            first_sample["attention_mask"], list
        ), "attention_mask should be a list"
        assert len(first_sample["input_ids"]) == len(
            first_sample["attention_mask"]
        ), "input_ids and attention_mask should have same length"

        # Verify labels field exists (for causal LM training)
        assert (
            "labels" in tokenized_dataset.column_names
        ), "Dataset should have 'labels' column"
        assert (
            first_sample["labels"] == first_sample["input_ids"]
        ), "labels should match input_ids"

        # Test constructor validation. try/except/else replaces `assert False`
        # so a missing ValueError still fails when asserts are stripped (-O).
        try:
            RawTextDataLoader(tokenizer, chunk_size = 0, stride = 2)
        except ValueError as e:
            assert "chunk_size must be positive" in str(e)
        else:
            raise AssertionError("Should raise ValueError for chunk_size=0")

        try:
            RawTextDataLoader(tokenizer, chunk_size = 5, stride = 10)
        except ValueError as e:
            assert "stride" in str(e) and "chunk_size" in str(e)
        else:
            raise AssertionError("Should raise ValueError for stride >= chunk_size")

        # Test preprocessor
        preprocessor = TextPreprocessor()
        clean_text = preprocessor.clean_text("  messy   text  \n\n\n  ")
        assert "messy text" in clean_text, "Should clean text properly"

        # Test validation
        stats = preprocessor.validate_dataset(text_dataset)
        assert stats["total_samples"] > 0, "Should count samples"
        assert "warnings" in stats, "Should include warnings"

        print("✅ All tests passed!")
        return True

    except Exception as e:
        print(f"❌ Test failed: {e}")
        # Propagate so pytest and the script entry point both see the failure
        # (with its traceback) instead of a silently ignored False.
        raise

    finally:
        # Cleanup
        os.unlink(test_file)
|
|
|
|
|
|
if __name__ == "__main__":
    # Process exit status for shell/CI callers: 0 on success, 1 on failure.
    sys.exit(0 if test_raw_text_loader() else 1)
|