agatha.ml.gpt2_finetune.gpt2_finetune module

agatha.ml.gpt2_finetune.gpt2_finetune.abstract_record_to_string(abstract)
Return type

str

agatha.ml.gpt2_finetune.gpt2_finetune.collate_token_batch(tokens, include_labels=True, device=None)
Return type

Dict[str, Any]

agatha.ml.gpt2_finetune.gpt2_finetune.weighted_index_sample(weights, omit_small_terms=False)

Performs a weighted sample over weights and returns the sampled index.

Parameters

weights (FloatTensor): Sampling weights, of length <vocab_size>.

omit_small_terms (bool): If set, words with a weighted probability less than 1/len(weights) are not considered.

Return type

List[int]

Returns

The weighted sample index for each input in the batch.
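The sampling rule described above can be illustrated with a minimal, dependency-free sketch. It is not the Agatha implementation. The names weighted_index_sample_sketch and _sample_row are hypothetical, and plain Python lists stand in for the FloatTensor. The sketch also assumes that weights is a batch of rows, each of length <vocab_size>, so that it returns one index per row as the List[int] return type suggests. The 1/len threshold is applied to each row's normalized probabilities.

```python
import random
from typing import List, Optional, Sequence


def _sample_row(row: Sequence[float], omit_small_terms: bool, rng: random.Random) -> int:
    """Hypothetical helper: draw one index from a single row of weights."""
    n = len(row)
    total = float(sum(row))
    if n == 0 or total <= 0 or any(w < 0 for w in row):
        raise ValueError("each row needs non-negative weights with a positive sum")
    # Normalize so the row forms a probability distribution.
    probs = [w / total for w in row]
    if omit_small_terms:
        # Zero out terms whose probability is below the uniform level 1/n.
        kept = [p if p >= 1.0 / n else 0.0 for p in probs]
        # The largest probability is always >= 1/n, so this fallback only
        # guards against floating-point rounding removing every term.
        if sum(kept) > 0:
            probs = kept
    # random.choices normalizes the weights itself; zeroed terms are never drawn.
    return rng.choices(range(n), weights=probs, k=1)[0]


def weighted_index_sample_sketch(
    weights: Sequence[Sequence[float]],
    omit_small_terms: bool = False,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical stand-in: return one weighted sample index per row."""
    rng = rng or random.Random()
    return [_sample_row(row, omit_small_terms, rng) for row in weights]
```

For example, weighted_index_sample_sketch([[4.0, 1.0, 1.0], [1.0, 0.0, 5.0]], omit_small_terms=True) always returns [0, 2], because in each row only one probability clears the 1/3 threshold. A PyTorch version would more likely normalize each row of the tensor and call torch.multinomial rather than random.choices.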