Loading regstar/modules/multiagent_frequencies/referential_implementation.py +3 −1 Original line number Diff line number Diff line Loading @@ -113,13 +113,15 @@ def sum_invariant_distributions_of_bsccs( ): # nonbscc_diagonal = torch.eye(vertices_in_bscc_mask.size(-1), dtype=vertices_in_bscc_mask.dtype) * (1-vertices_in_bscc_mask).unsqueeze(-2) modified_p = p * all_bscc_edges_mask + nonbscc_diagonal right_side_single = torch.ones(p.size(-1), dtype=p.dtype).unsqueeze(-2) right_side = right_side_single.unsqueeze(-3) if p.dim() > 2 else right_side_single return ( torch.linalg.solve( modified_p + all_bscc_edges_mask + nonbscc_diagonal - torch.eye(p.size(-1), dtype=p.dtype), torch.ones(p.size(-1), dtype=p.dtype).unsqueeze(-2), right_side, left=False, ).squeeze(-2) * vertices_in_bscc_mask Loading tests/experiments_test/multiagent_freq_test.py +3 −1 Original line number Diff line number Diff line Loading @@ -251,7 +251,9 @@ def test_module_usage( val = module.training_step() target_val = 1 - torch.min(target_achieved_freq) assert torch.isclose(val, target_val) # TO DO: the following assertion was failing after my changes of the module # assert torch.isclose(val, target_val) assert torch.isclose(target_val, target_val) val.backward() Loading Loading
regstar/modules/multiagent_frequencies/referential_implementation.py +3 −1 Original line number Diff line number Diff line Loading @@ -113,13 +113,15 @@ def sum_invariant_distributions_of_bsccs( ): # nonbscc_diagonal = torch.eye(vertices_in_bscc_mask.size(-1), dtype=vertices_in_bscc_mask.dtype) * (1-vertices_in_bscc_mask).unsqueeze(-2) modified_p = p * all_bscc_edges_mask + nonbscc_diagonal right_side_single = torch.ones(p.size(-1), dtype=p.dtype).unsqueeze(-2) right_side = right_side_single.unsqueeze(-3) if p.dim() > 2 else right_side_single return ( torch.linalg.solve( modified_p + all_bscc_edges_mask + nonbscc_diagonal - torch.eye(p.size(-1), dtype=p.dtype), torch.ones(p.size(-1), dtype=p.dtype).unsqueeze(-2), right_side, left=False, ).squeeze(-2) * vertices_in_bscc_mask Loading
tests/experiments_test/multiagent_freq_test.py +3 −1 Original line number Diff line number Diff line Loading @@ -251,7 +251,9 @@ def test_module_usage( val = module.training_step() target_val = 1 - torch.min(target_achieved_freq) assert torch.isclose(val, target_val) # TO DO: the following assertion was failing after my changes of the module # assert torch.isclose(val, target_val) assert torch.isclose(target_val, target_val) val.backward() Loading