Loading regstar_library/strategy.py 0 → 100644 +65 −0 Original line number Diff line number Diff line # Standard Imports from typing import Annotated # Thirty-party Imports import torch # Local Imports from graph.augmented_graph import Mask SavedStrategy = list[list[float]] StrategyTensor = Annotated[torch.Tensor, "(..., Augnodes, Augnodes)"] class Strategy(torch.nn.Module): mask: StrategyTensor def __init__( self, mask: Mask | None, rounding_threshold: float, probs: SavedStrategy | StrategyTensor | None = None, ) -> None: super().__init__() if mask is None: assert probs is not None probs_tensor: StrategyTensor = ( probs if torch.is_tensor(probs) else torch.tensor(probs) ) mask_tensor = probs_tensor > 0 else: mask_tensor = torch.tensor(mask) mask_tensor.requires_grad = False self.register_buffer(name="mask", tensor=mask_tensor) self.params = self._init_params(probs) self.threshold = rounding_threshold def __call__(self) -> StrategyTensor: return super().__call__() def _init_params( self, probs: SavedStrategy | StrategyTensor | None ) -> torch.nn.Parameter: if probs is None: probs_tensor = torch.rand(self.mask.shape) elif torch.is_tensor(probs): probs_tensor = probs.detach().clone() else: probs_tensor = torch.tensor(probs) probs_tensor.requires_grad = True return torch.nn.Parameter(torch.log(probs_tensor)) def forward(self) -> StrategyTensor: x = torch.exp(self.params) * self.mask strategy = torch.nn.functional.normalize(x, p=1.0, dim=-1) if self.training: return strategy strategy = torch.threshold(strategy, threshold=self.threshold, value=0.0) strategy = torch.nn.functional.normalize(strategy, p=1.0, dim=-1) return strategy Loading
regstar_library/strategy.py 0 → 100644 +65 −0 Original line number Diff line number Diff line # Standard Imports from typing import Annotated # Thirty-party Imports import torch # Local Imports from graph.augmented_graph import Mask SavedStrategy = list[list[float]] StrategyTensor = Annotated[torch.Tensor, "(..., Augnodes, Augnodes)"] class Strategy(torch.nn.Module): mask: StrategyTensor def __init__( self, mask: Mask | None, rounding_threshold: float, probs: SavedStrategy | StrategyTensor | None = None, ) -> None: super().__init__() if mask is None: assert probs is not None probs_tensor: StrategyTensor = ( probs if torch.is_tensor(probs) else torch.tensor(probs) ) mask_tensor = probs_tensor > 0 else: mask_tensor = torch.tensor(mask) mask_tensor.requires_grad = False self.register_buffer(name="mask", tensor=mask_tensor) self.params = self._init_params(probs) self.threshold = rounding_threshold def __call__(self) -> StrategyTensor: return super().__call__() def _init_params( self, probs: SavedStrategy | StrategyTensor | None ) -> torch.nn.Parameter: if probs is None: probs_tensor = torch.rand(self.mask.shape) elif torch.is_tensor(probs): probs_tensor = probs.detach().clone() else: probs_tensor = torch.tensor(probs) probs_tensor.requires_grad = True return torch.nn.Parameter(torch.log(probs_tensor)) def forward(self) -> StrategyTensor: x = torch.exp(self.params) * self.mask strategy = torch.nn.functional.normalize(x, p=1.0, dim=-1) if self.training: return strategy strategy = torch.threshold(strategy, threshold=self.threshold, value=0.0) strategy = torch.nn.functional.normalize(strategy, p=1.0, dim=-1) return strategy