Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions xtuner/v1/datasets/mllm_tokenize_fn/base_mllm_tokenize_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from xtuner.v1.data_proto.messages import ChatMessages
from xtuner.v1.data_proto.templates import ChatTemplate, HybridChatTemplate
from xtuner.v1.utils import get_logger
from xtuner.v1.utils import get_logger, trim_memory

from ..data_item import BaseMLLMDataItem, CacheItem
from ..utils import CachableTokenizeFunction, tokenizer_xxhash, with_proxy_attention_flops
Expand Down Expand Up @@ -118,7 +118,8 @@ def replace_image_token(


def load_image(image_path: str):
    """Load the image at *image_path* and return it converted to RGB.

    The file handle opened by PIL is closed explicitly before returning;
    ``convert("RGB")`` forces a full pixel load, so the returned image no
    longer needs the underlying file.
    """
    img = Image.open(image_path)
    try:
        return img.convert("RGB")
    finally:
        img.close()


def get_image_path(image_path: str, media_root: str):
Expand All @@ -144,6 +145,7 @@ def __init__(
data_name: str | None = None,
llm_pack_weight: float = 1.0,
visual_pack_weight: float = 0.0,
trim_memory_interval: int = 1,
):
self.max_length = max_length
self._tokenizer_hash = tokenizer_hash
Expand All @@ -158,6 +160,9 @@ def __init__(
self._video_wh_list: list[list] = []
self._video_extra_info_list: list[dict] = []

self._trim_memory_interval = trim_memory_interval
self._trim_memory_counter = 0

self._hash_str += f"llm_pack_weight:{llm_pack_weight}_visual_pack_weight:{visual_pack_weight}"
super().__init__(tokenizer, llm_pack_weight=llm_pack_weight, visual_pack_weight=visual_pack_weight)

Expand Down Expand Up @@ -213,16 +218,23 @@ def __call__(self, item: dict, media_root: str = "", **kwargs) -> T | CacheItem:
ret = self.calc_num_tokens_multi_modal_get_item(item)
else:
ret = self.multi_modal_get_item(item, media_root)
Comment thread
hhaAndroid marked this conversation as resolved.
if self._trim_memory_counter % self._trim_memory_interval == 0:
trim_memory()
self._trim_memory_counter += 1
elif len(self._video_path) > 0:
if self.state == "cache":
ret = self.calc_num_tokens_video_get_item(item)
else:
ret = self.video_get_item(item, media_root)
if self._trim_memory_counter % self._trim_memory_interval == 0:
trim_memory()
self._trim_memory_counter += 1
else:
if self.state == "cache":
ret = self.calc_num_tokens_pure_text_get_item(item)
else:
ret = self.pure_text_get_item(item)

return ret

def hash(self) -> str:
Expand Down Expand Up @@ -257,6 +269,7 @@ class BaseMLLMTokenizeFnConfig(BaseModel):
add_bos_token: bool = False # for mllm pretrain
llm_pack_weight: float = 1.0
visual_pack_weight: float = 0.0
trim_memory_interval: int = 1

def build(
self, tokenizer, tokenizer_hash: str | None = None, anno_name: str = "", **kwargs
Expand Down
5 changes: 5 additions & 0 deletions xtuner/v1/datasets/mllm_tokenize_fn/qwen3_vl_tokenize_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def __init__(
hash: str | None = None,
add_eos_token: bool = True, # for mllm pretrain
add_bos_token: bool = False, # for mllm pretrain
trim_memory_interval: int = 1,
):
self.oss_loader = None
self.debug = debug
Expand Down Expand Up @@ -335,6 +336,7 @@ def __init__(
data_name=self.data_name,
llm_pack_weight=llm_pack_weight,
visual_pack_weight=visual_pack_weight,
trim_memory_interval=trim_memory_interval,
)

def _truncated_data_item(
Expand Down Expand Up @@ -904,6 +906,8 @@ class Qwen3VLTokenizeFnConfig(BaseMLLMTokenizeFnConfig):
# it's helpful to add labels to the images and videos for better reference.
add_vision_id: bool = True

trim_memory_interval: int = 1

def build(
self, tokenizer, tokenizer_hash: str | None = None, anno_name: str = "", **kwargs
) -> Qwen3VLTokenizeFunction:
Expand Down Expand Up @@ -932,4 +936,5 @@ def build(
oss_time_log_thr=self.oss_time_log_thr,
add_eos_token=self.add_eos_token, # for mllm pretrain
add_bos_token=self.add_bos_token, # for mllm pretrain
trim_memory_interval=self.trim_memory_interval,
)
16 changes: 9 additions & 7 deletions xtuner/v1/datasets/mllm_tokenize_fn/qwen3_vl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@


def pil_loader(img_str):
    """Decode a raw image byte string into an RGB PIL image.

    Both the in-memory buffer and the PIL file handle are released promptly
    by the context managers; ``convert("RGB")`` fully loads the pixel data
    first, so the returned image is safe to use after the buffer is closed.
    """
    with io.BytesIO(img_str) as buff, Image.open(buff) as img:
        return img.convert("RGB")


def extract_frame_number(filename):
Expand Down Expand Up @@ -109,12 +110,13 @@ def read_frames_folder(
start_time = time.time()
image_byte = client.get(image_list[frame_index])
oss_read_time += time.time() - start_time
frame = Image.open(io.BytesIO(image_byte))
frame_list.append(np.array(frame))
with io.BytesIO(image_byte) as buff:
with Image.open(buff) as frame:
frame_list.append(np.array(frame))
else:
fp = os.path.join(video_path, image_list[frame_index])
frame = Image.open(fp).convert("RGB")
frame_list.append(np.array(frame))
with Image.open(fp) as frame:
frame_list.append(np.array(frame.convert("RGB")))

frames = numpy_to_tensor(frame_list)
return frames, oss_read_time, len(frames), frames_indices, timestamps
Expand Down
2 changes: 2 additions & 0 deletions xtuner/v1/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
is_hf_model_path,
is_local_rank0,
record_git_info,
trim_memory,
)
from .pad import pad_to_max_length, pad_to_multiple_of
from .profile import profile_time, profile_time_and_memory, timer, timer_logger
Expand Down Expand Up @@ -61,4 +62,5 @@
"ray_method",
"profile_time",
"clean_param_name",
"trim_memory",
]
23 changes: 23 additions & 0 deletions xtuner/v1/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,26 @@ def clean_param_name(name: str) -> str:
if "_orig_mod." in name:
name = name.replace("_orig_mod.", "")
return name


# Process-wide flag so the "malloc_trim unavailable" warning is emitted at
# most once per process instead of once per call.
_TRIM_MEMORY_WARNED = False


def trim_memory() -> bool:
    """Try to return free heap pages to the OS via glibc's ``malloc_trim``.

    Best-effort only: on platforms without ``malloc_trim`` (macOS, Windows,
    musl-based Linux) loading ``libc.so.6`` fails and we log the failure once
    per process to avoid spamming.

    Returns:
        True if ``malloc_trim`` reported that memory was released back to the
        system, False otherwise (including when ``malloc_trim`` is
        unavailable).
    """
    global _TRIM_MEMORY_WARNED
    try:
        import ctypes

        libc = ctypes.CDLL("libc.so.6")
        # malloc_trim returns a C int (1 = memory released, 0 = nothing to
        # release); coerce to bool so the declared return type is honored.
        return bool(libc.malloc_trim(0))
    except Exception as e:
        if not _TRIM_MEMORY_WARNED:
            _logger = get_logger()
            _logger.warning(f" >>>>>>>>> [trim_memory] Failed to trim memory: {e} <<<<<<<<")
            _TRIM_MEMORY_WARNED = True
        return False
Loading