|
| 1 | +--- |
| 2 | +title: MMDDatasetEvaluator |
| 3 | +createTime: 2025/04/04 19:46 |
| 4 | +permalink: /en/api/operators/text_sft/eval/mmddatasetevaluator/ |
| 5 | +--- |
| 6 | + |
| 7 | +## 📘 Overview |
| 8 | + |
| 9 | +The `MMDDatasetEvaluator` is an operator that evaluates the distribution discrepancy between two datasets using the Maximum Mean Discrepancy (MMD) method. It embeds text into a high-dimensional space and computes the kernel-based distance to quantify the distribution shift between the evaluation dataset and a reference dataset. A smaller MMD score indicates that the two distributions are closer. |
| 10 | + |
| 11 | +## `__init__` |
| 12 | + |
| 13 | +```python |
| 14 | +def __init__( |
| 15 | + self, |
| 16 | + ref_frame: DataFlowStorage, |
| 17 | + *, |
| 18 | + ref_max_sample_num: int = 5000, |
| 19 | + ref_shuffle_seed: int = 42, |
| 20 | + ref_instruction_key: str = "input", |
| 21 | + ref_output_key: str = "output", |
| 22 | + kernel_type: Literal["RBF"] = "RBF", |
| 23 | + bias: bool = True, |
| 24 | + rbf_sigma: float = 1.0, |
| 25 | + embedding_type: Literal["vllm", "sentence_transformers"] = "sentence_transformers", |
| 26 | + embedding_model_name: str | None = None, |
| 27 | + st_device: str = "cuda", |
| 28 | + st_batch_size: int = 32, |
| 29 | + st_normalize_embeddings: bool = True, |
| 30 | + vllm_max_num_seqs: int = 128, |
| 31 | + vllm_gpu_memory_utilization: float = 0.9, |
| 32 | + vllm_tensor_parallel_size: int = 1, |
| 33 | + vllm_pipeline_parallel_size: int = 1, |
| 34 | + vllm_truncate_max_length: int = 40960, |
| 35 | + cache_type: Literal["redis", "none"] = "none", |
| 36 | + redis_url: str = "redis://127.0.0.1:6379", |
| 37 | + max_concurrent_requests: int = 50, |
| 38 | + redis_db: int = 0, |
| 39 | + cache_model_id: str | None = None, |
| 40 | +) |
| 41 | +``` |
| 42 | + |
| 43 | +| Parameter | Type | Default | Description | |
| 44 | +| :--- | :--- | :--- | :--- | |
| 45 | +| **ref_frame** | DataFlowStorage | Required | The reference dataset used as the distribution baseline. | |
| 46 | +| **ref_max_sample_num** | int | `5000` | Maximum number of samples to draw from the reference dataset. | |
| 47 | +| **ref_shuffle_seed** | int | `42` | Random seed for sampling the reference dataset. | |
| 48 | +| **ref_instruction_key** | str | `'input'` | Column name for the instruction field in the reference dataset. | |
| 49 | +| **ref_output_key** | str | `'output'` | Column name for the output field in the reference dataset. | |
| 50 | +| **kernel_type** | str | `'RBF'` | Kernel function type; currently only `'RBF'` is supported. | |
| 51 | +| **bias** | bool | `True` | Whether to use bias in the MMD computation. | |
| 52 | +| **rbf_sigma** | float | `1.0` | Bandwidth parameter for the RBF kernel. | |
| 53 | +| **embedding_type** | str | `'sentence_transformers'` | Embedding backend to use; either `'sentence_transformers'` or `'vllm'`. **Note:** when using `'vllm'`, you need to install `distflow[vllm]` first. | |
| 54 | +| **embedding_model_name** | str | Required | Name of the embedding model. Although the signature defaults to `None`, a value must be provided. | |
| 55 | +| **st_device** | str | `'cuda'` | Device for SentenceTransformers (e.g., `'cuda'`, `'cpu'`). | |
| 56 | +| **st_batch_size** | int | `32` | Batch size for SentenceTransformers inference. | |
| 57 | +| **st_normalize_embeddings** | bool | `True` | Whether to normalize embeddings from SentenceTransformers. | |
| 58 | +| **vllm_max_num_seqs** | int | `128` | Maximum number of sequences for vLLM. | |
| 59 | +| **vllm_gpu_memory_utilization** | float | `0.9` | GPU memory utilization ratio for vLLM. | |
| 60 | +| **vllm_tensor_parallel_size** | int | `1` | Tensor parallel size for vLLM. | |
| 61 | +| **vllm_pipeline_parallel_size** | int | `1` | Pipeline parallel size for vLLM. | |
| 62 | +| **vllm_truncate_max_length** | int | `40960` | Maximum truncation length for vLLM inputs. | |
| 63 | +| **cache_type** | str | `'none'` | Cache type for embeddings; either `'redis'` or `'none'`. | |
| 64 | +| **redis_url** | str | `'redis://127.0.0.1:6379'` | Redis connection URL when `cache_type='redis'`. | |
| 65 | +| **max_concurrent_requests** | int | `50` | Maximum concurrent requests to Redis. | |
| 66 | +| **redis_db** | int | `0` | Redis database index. | |
| 67 | +| **cache_model_id** | str | `None` | Model identifier used for the Redis cache key. | |
| 68 | + |
| 69 | +## `run` |
| 70 | + |
| 71 | +```python |
| 72 | +def run( |
| 73 | + self, |
| 74 | + storage: DataFlowStorage, |
| 75 | + input_instruction_key: str, |
| 76 | + input_output_key: str, |
| 77 | + max_sample_num: int | None = None, |
| 78 | + shuffle_seed: int | None = None, |
| 79 | +) -> tuple[float, dict[str, Any]] |
| 80 | +``` |
| 81 | + |
| 82 | +| Parameter | Type | Default | Description | |
| 83 | +| :--- | :--- | :--- | :--- | |
| 84 | +| **storage** | DataFlowStorage | Required | The DataFlowStorage instance containing the evaluation dataset. | |
| 85 | +| **input_instruction_key** | str | Required | Column name for the instruction field in the evaluation dataset. | |
| 86 | +| **input_output_key** | str | Required | Column name for the output field in the evaluation dataset. | |
| 87 | +| **max_sample_num** | int | `None` | Maximum number of samples to draw from the evaluation dataset; falls back to `ref_max_sample_num` if not set. | |
| 88 | +| **shuffle_seed** | int | `None` | Random seed for sampling the evaluation dataset; falls back to `ref_shuffle_seed` if not set. | |
| 89 | + |
| 90 | +## 🧠 Example Usage |
| 91 | + |
| 92 | +```python |
| 93 | +from dataflow.operators.text_sft.eval import MMDDatasetEvaluator |
| 94 | +from dataflow.utils.storage import FileStorage |
| 95 | + |
| 96 | +# Prepare reference and evaluation storages |
| 97 | +ref_storage = FileStorage(first_entry_file_name="reference_data.jsonl") |
| 98 | +eval_storage = FileStorage(first_entry_file_name="eval_data.jsonl") |
| 99 | + |
| 100 | +# Initialize the evaluator |
| 101 | +evaluator = MMDDatasetEvaluator( |
| 102 | + ref_frame=ref_storage.step(), |
| 103 | + ref_instruction_key="instruction", |
| 104 | + ref_output_key="output", |
| 105 | + embedding_type="sentence_transformers", |
| 106 | + embedding_model_name="BAAI/bge-large-zh", |
| 107 | + st_device="cuda", |
| 108 | + st_batch_size=32, |
| 109 | +) |
| 110 | + |
| 111 | +# Run evaluation |
| 112 | +mmd_score, mmd_meta = evaluator.run( |
| 113 | + eval_storage.step(), |
| 114 | + input_instruction_key="instruction", |
| 115 | + input_output_key="output", |
| 116 | +) |
| 117 | +print(f"MMD Score: {mmd_score}, Meta: {mmd_meta}") |
| 118 | +``` |
| 119 | + |
| 120 | +#### 🧾 Default Output Format |
| 121 | + |
| 122 | +| Field | Type | Description | |
| 123 | +| :--- | :--- | :--- | |
| 124 | +| **MMDScore** | float | The computed MMD distance; a smaller value means the two distributions are closer. | |
| 125 | +| **MMDMeta** | dict | Metadata dictionary containing computation details. | |
| 126 | + |
| 127 | +**Example Output:** |
| 128 | +```json |
| 129 | +{ |
| 130 | + "MMDScore": 0.00342, |
| 131 | + "MMDMeta": { |
| 132 | + "num_src_samples": 5000, |
| 133 | + "num_tgt_samples": 5000 |
| 134 | + } |
| 135 | +} |
| 136 | +``` |
0 commit comments