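The code for the paper Towards Faithful Dialogs via Focus Learning is still being organized. In the meantime, here is the core snippet that computes the FCE loss.

Core code from the paper: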

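    import torch
    from torch.nn import CrossEntropyLoss

    class CosineSimilarity(torch.nn.Module):
        def forward(self, tensor_1, tensor_2, eps=1e-8):
            # Clamp the norms so zero vectors (e.g. padding embeddings) do not produce NaN
            normalized_tensor_1 = tensor_1 / tensor_1.norm(dim=-1, keepdim=True).clamp_min(eps)
            normalized_tensor_2 = tensor_2 / tensor_2.norm(dim=-1, keepdim=True).clamp_min(eps)
            return (normalized_tensor_1 * normalized_tensor_2).sum(dim=-1)

    # Fragment from a method of the training class: self, model, knowledges,
    # labels_emb, logits and labels are defined elsewhere (not shown here)
    cal_sim = CosineSimilarity()
    knowledge_emb = model.get_input_embeddings()(knowledges)
    # Negated cosine similarity between knowledge and label embeddings, in [-1, 1]
    sim_dist = -cal_sim(knowledge_emb, labels_emb)
    # Focus weight 1 - log(1 - cos + lambda): higher similarity gives a larger weight
    sim_score = -torch.log(sim_dist + 1 + self.config.get("fce_lamda", 0.01)) + 1

    # Scale every vocabulary logit at a position by that position's weight (broadcast over the vocab dim)
    weighted_lm_logits = sim_score.unsqueeze(-1) * logits
    # Cross-entropy on the weighted logits; padding positions are mapped to -100 and ignored
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    fce_loss = loss_fct(
        weighted_lm_logits.view(-1, weighted_lm_logits.size(-1)),
        torch.where(labels == self.tokenizer.pad_token_id, -100, labels).view(-1),
    )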
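
Because `sim_dist` is the negated cosine similarity, the weight in the snippet above works out to `w_t = 1 - log(1 - cos(k_t, y_t) + λ)`. With the default λ = 0.01 it ranges from about 0.30 (opposite directions) to about 5.6 (identical directions). The following is a minimal, self-contained sketch of the same computation on toy tensors, so the weighting can be inspected without a full model. It is not the authors' implementation: the helper names `fce_weights` and `fce_loss` are hypothetical, a randomly initialized `torch.nn.Embedding` stands in for `model.get_input_embeddings()`, pad id 0 is assumed, and the knowledge and label sequences are assumed to already have the same length, since the snippet does not show how the two are aligned.

```python
import torch
import torch.nn.functional as F


def fce_weights(knowledge_emb: torch.Tensor, labels_emb: torch.Tensor, lam: float = 0.01) -> torch.Tensor:
    """Hypothetical helper: per-position focus weights, mirroring sim_dist/sim_score.

    Assumes both inputs have shape [batch, seq_len, hidden] and are aligned
    position by position (an assumption, not shown in the original snippet).
    """
    cos = F.cosine_similarity(knowledge_emb, labels_emb, dim=-1)  # [B, T], roughly in [-1, 1]
    sim_dist = -cos                                               # same sign convention as the snippet
    return 1.0 - torch.log(sim_dist + 1.0 + lam)                  # larger weight for higher similarity


def fce_loss(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Hypothetical helper: cross-entropy on logits scaled by per-position weights."""
    weighted = weights.unsqueeze(-1) * logits                     # broadcast [B, T, 1] over [B, T, V]
    targets = labels.masked_fill(labels == pad_token_id, -100)    # padding positions are ignored
    return F.cross_entropy(weighted.reshape(-1, weighted.size(-1)), targets.reshape(-1), ignore_index=-100)


# Toy usage: random token ids and a random embedding table (pad id 0 is an assumption)
torch.manual_seed(0)
B, T, H, V = 2, 5, 8, 11
emb = torch.nn.Embedding(V, H)
labels = torch.randint(1, V, (B, T))
labels[:, -1] = 0                      # pretend the last position is padding
knowledges = torch.randint(1, V, (B, T))

w = fce_weights(emb(knowledges), emb(labels))
logits = torch.randn(B, T, V)
loss = fce_loss(logits, labels, w, pad_token_id=0)
print(w.shape, round(loss.item(), 4))
```

Multiplying the logits by a positive per-position weight acts like an inverse temperature on that position's softmax: weights above 1 sharpen the predicted distribution and weights below 1 flatten it, so the weighting changes each token's loss rather than simply rescaling it. Broadcasting with `unsqueeze(-1)` gives the same result as `.repeat(1, 1, vocab_size)` without materializing a copy of the weights.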