@dataclass
class RandomInstanceSourceConfig(InstanceSourceConfig):
    """
    Config for :class:`RandomInstanceSource`.

    The fields mirror the keyword arguments of the :class:`RandomInstanceSource`
    constructor, except for ``work_dir``, which is not part of the config.
    """

    # Tokenizer config; the source reads the special token IDs and the
    # vocabulary size from it.
    tokenizer: TokenizerConfig

    # Length (in tokens) of each generated instance.
    sequence_length: int

    # Target average document length, used when injecting random document
    # boundaries into the generated tokens.
    avg_document_length: int

    # Random seed. The default factory resolves ``SEED_NOT_SET`` through
    # ``resolve_seed`` each time a config is instantiated without a seed.
    seed: int = dataclasses.field(
        default_factory=lambda: resolve_seed(SEED_NOT_SET),
    )

    # Size of the source, given either as a number of instances or as a
    # number of tokens. The source requires exactly one of the two to be set.
    num_instances: Optional[int] = None
    num_tokens: Optional[int] = None

    # Optional maximum sequence length, forwarded to the source.
    max_sequence_length: Optional[int] = None

    # Optional label, forwarded to the source.
    label: Optional[str] = None
class RandomInstanceSource(InstanceSource):
    """
    An instance source that generates random instances. Useful for benchmarking.

    Tokens are generated on the fly in :meth:`__getitem__` from a random
    generator seeded with ``seed`` plus the index of the
    ``max_sequence_length``-sized window the instance belongs to, so the same
    index always maps to the same seed.

    :param tokenizer: Tokenizer config providing the EOS/BOS/pad token IDs and
        the vocabulary size.
    :param sequence_length: Length (in tokens) of each instance. Forwarded to the
        base class.
    :param avg_document_length: Target average document length, used to decide how
        many document boundaries to inject per ``max_sequence_length`` window.
    :param seed: Random seed, passed through ``resolve_seed``.
    :param num_instances: Size of the source as a number of instances. Exactly one of
        ``num_instances`` and ``num_tokens`` must be given.
    :param num_tokens: Size of the source as a number of tokens. Exactly one of
        ``num_instances`` and ``num_tokens`` must be given.
    :param max_sequence_length: Forwarded to the base class. The token count is
        rounded down to a multiple of the resulting ``self.max_sequence_length``.
    :param label: Forwarded to the base class.
    :param work_dir: Forwarded to the base class.

    :raises OLMoConfigurationError: If both or neither of ``num_tokens`` and
        ``num_instances`` are set, or if the one that is set is not positive.
    """

    Config = RandomInstanceSourceConfig

    DISPLAY_ICON = "\uedec"

    def __init__(
        self,
        *,
        tokenizer: TokenizerConfig,
        sequence_length: int,
        avg_document_length: int,
        seed: int = SEED_NOT_SET,
        num_instances: Optional[int] = None,
        num_tokens: Optional[int] = None,
        max_sequence_length: Optional[int] = None,
        label: Optional[str] = None,
        work_dir: PathOrStr,
    ):
        if (num_tokens is None) == (num_instances is None):
            raise OLMoConfigurationError(
                "Either num_tokens or num_instances must be set, but not both."
            )

        # Validate with real exceptions rather than ``assert`` so the checks
        # are not stripped when Python runs with ``-O``.
        if num_tokens is not None:
            if num_tokens <= 0:
                raise OLMoConfigurationError(
                    f"'num_tokens' must be positive, got {num_tokens}."
                )
        else:
            # ``num_instances`` cannot be None here because of the check above;
            # the explicit test also narrows the type for static checkers.
            if num_instances is None or num_instances <= 0:
                raise OLMoConfigurationError(
                    f"'num_instances' must be positive, got {num_instances}."
                )
            num_tokens = num_instances * sequence_length

        super().__init__(
            work_dir=work_dir,
            sequence_length=sequence_length,
            max_sequence_length=max_sequence_length,
            label=label,
        )

        # Round down to a whole number of ``max_sequence_length`` windows.
        self._num_tokens = self.max_sequence_length * (num_tokens // self.max_sequence_length)
        self._tokenizer = tokenizer
        self._avg_document_length = avg_document_length

        seed = resolve_seed(seed)
        assert seed is not None  # type narrowing only, not input validation
        self._seed = seed

    @property
    def num_tokens(self) -> int:
        """Total number of tokens, a multiple of ``max_sequence_length``."""
        return self._num_tokens

    @property
    def seed(self) -> int:
        """The resolved random seed."""
        return self._seed

    @property
    def eos_token_id(self) -> int:
        """The tokenizer's EOS token ID."""
        return self._tokenizer.eos_token_id

    @property
    def bos_token_id(self) -> Optional[int]:
        """The tokenizer's BOS token ID, if it has one."""
        return self._tokenizer.bos_token_id

    @property
    def pad_token_id(self) -> int:
        """The tokenizer's padding token ID."""
        return self._tokenizer.pad_token_id

    @property
    def vocab_size(self) -> int:
        """The tokenizer's vocabulary size."""
        return self._tokenizer.vocab_size

    @property
    def avg_document_length(self) -> int:
        """The target average document length."""
        return self._avg_document_length

    @ft.cached_property
    def non_special_token(self) -> int:
        """
        The smallest token ID that is not the EOS, BOS, or pad token.

        :raises RuntimeError: If every ID in ``range(vocab_size)`` is a special token.
        """
        special_ids = (self.eos_token_id, self.bos_token_id, self.pad_token_id)
        for token_id in range(self.vocab_size):
            if token_id not in special_ids:
                return token_id
        raise RuntimeError("No non-special token found in the vocabulary.")

    @property
    def dtype(self) -> NumpyUIntTypes:
        """The narrowest unsigned NumPy integer type whose max is at least ``vocab_size``."""
        for dtype in (np.uint8, np.uint16, np.uint32):
            if np.iinfo(dtype).max >= self.vocab_size:
                return dtype
        return np.uint64

    @ft.cached_property
    def fingerprint(self) -> str:
        """A SHA-256 hex digest over the class name, size, seed, and tokenizer settings."""
        # NOTE(review): ``sequence_length`` and ``avg_document_length`` are not
        # hashed even though they influence what ``__getitem__`` returns.
        # Left unchanged because the fingerprint may be persisted elsewhere;
        # confirm whether the omission is intentional.
        sha256_hash = hashlib.sha256()
        sha256_hash.update(
            (
                f"class={self.__class__.__name__},"
                f"num_tokens={self.num_tokens},"
                f"seed={self.seed},"
                f"max_sequence_length={self.max_sequence_length},"
                f"eos_token_id={self.eos_token_id},"
                f"bos_token_id={self.bos_token_id},"
                f"pad_token_id={self.pad_token_id},"
                f"vocab_size={self.vocab_size},"
            ).encode()
        )
        return sha256_hash.hexdigest()

    def __getitem__(self, idx: int) -> Instance:
        """
        Generate the instance at position ``idx``.

        A window of ``max_sequence_length`` random tokens is generated from a
        seed derived from ``idx``; when ``sequence_length`` is smaller than
        ``max_sequence_length``, the instance is the matching slice of that window.

        :param idx: The instance index, validated through ``validate_index``.

        :returns: A dict with a single ``"input_ids"`` entry.
        """
        idx = self.validate_index(idx)

        # Instances that share a window must share a seed, so locate both the
        # window and the instance's position inside it.
        if self.sequence_length < self.max_sequence_length:
            chunks_per_window = self.max_sequence_length // self.sequence_length
            base_idx, chunk_idx = divmod(idx, chunks_per_window)
        else:
            base_idx, chunk_idx = idx, 0

        rng = get_rng(self.seed + base_idx)

        # Generate random tokens.
        tokens = rng.integers(0, self.vocab_size, self.max_sequence_length, dtype=self.dtype)

        # Replace special tokens with non-special tokens so that special tokens
        # only appear where they are deliberately placed below.
        tokens[tokens == self.eos_token_id] = self.non_special_token
        tokens[tokens == self.pad_token_id] = self.non_special_token
        if self.bos_token_id is not None:
            tokens[tokens == self.bos_token_id] = self.non_special_token

        # Inject random document boundaries: the expected document count for the
        # window plus a small random jitter, but never fewer than one document.
        num_docs = max(
            1,
            round(rng.integers(-3, 3) + self.max_sequence_length / self.avg_document_length),
        )

        # The window always ends with EOS and, if there is a BOS token, starts with it.
        tokens[-1] = self.eos_token_id
        if self.bos_token_id is not None:
            tokens[0] = self.bos_token_id

        if num_docs > 1:
            # Keep interior boundaries away from the window edges; with a BOS
            # token each boundary occupies two positions (EOS then BOS).
            buffer = 1 if self.bos_token_id is None else 2
            doc_boundaries = (
                buffer + rng.permutation(self.max_sequence_length - 2 * buffer - 1)[: num_docs - 1]
            )
            tokens[doc_boundaries] = self.eos_token_id
            if self.bos_token_id is not None:
                tokens[doc_boundaries + 1] = self.bos_token_id

        # Pull out sub-sequence if needed.
        if self.sequence_length < self.max_sequence_length:
            start_offset = chunk_idx * self.sequence_length
            tokens = tokens[start_offset : start_offset + self.sequence_length]

        return {"input_ids": typing.cast(Sequence[int], tokens)}