import functools
import itertools  # LUTNet helpers
import logging
import os
import subprocess

import numpy as np
import torch
from torch import Tensor

# Cache CUDA availability once so every helper in this module agrees.
use_cuda = torch.cuda.is_available()
# Namespace providing tensor types for the active backend.
torch_cuda = torch.cuda if use_cuda else torch
# Default target device for tensors created by this module.
device = torch.device("cuda:0" if use_cuda else "cpu")

logger = logging.getLogger(__name__)
def is_tensor(x):
    """Return ``True`` when *x* is a PyTorch tensor, ``False`` otherwise."""
    return isinstance(x, torch.Tensor)
def to_numpy(x):
    """Convert a torch tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to host
    memory unconditionally, so tensors on any device are handled without
    consulting the module-level ``use_cuda`` flag (``.cpu()`` is a no-op
    for tensors already on the CPU).

    Args:
        x (Tensor): tensor to convert.

    Returns:
        numpy.ndarray: host-side array with no autograd history.
    """
    return x.detach().cpu().numpy()
def to_numpy_if_tensor(x):
    """Convert *x* to a NumPy array when it is a tensor; pass anything else through."""
    return to_numpy(x) if is_tensor(x) else x
def to_tensor(x):
    """Wrap a NumPy array as a torch tensor placed on the module-level ``device``."""
    wrapped = torch.from_numpy(x)
    return wrapped.to(device)
def to_tensor_if_numpy(x):
    """Convert *x* to a tensor when it is a NumPy array; pass anything else through."""
    if not isinstance(x, np.ndarray):
        return x
    return to_tensor(x)
def copy_weights(src_weight: Tensor, tgt_weight: Tensor):
    """Overwrite ``tgt_weight`` in place with the values of ``src_weight``.

    The copy runs under ``torch.no_grad()`` so it is excluded from
    autograd tracking.
    """
    with torch.no_grad():
        tgt_weight.copy_(src_weight)
def get_checkpoint_file(checkpoint_dir):
    """Return the first ``*.ckpt`` filename found in *checkpoint_dir*.

    Returns ``None`` when the directory contains no checkpoint file.
    Directory listing order is whatever ``os.listdir`` yields.
    """
    candidates = (name for name in os.listdir(checkpoint_dir) if name.endswith(".ckpt"))
    return next(candidates, None)
def execute_cli(cmd, log_output: bool = True, log_file=None, cwd="."):
    """Run *cmd* as a subprocess and return its exit code.

    Args:
        cmd (list): argv-style command list (executed without a shell).
        log_output: when True, stream the child's stdout line by line and
            optionally tee it into ``log_file``; when False, discard all
            stdout.
        log_file: optional path receiving a copy of the child's stdout.
        cwd: working directory for the child process.

    Returns:
        int: the child process's return code.
    """
    if not log_output:
        result = subprocess.run(cmd, stdout=subprocess.DEVNULL, cwd=cwd)
        return result.returncode

    logger.debug("{} (cwd = {})".format(subprocess.list2cmdline(cmd), cwd))
    # NOTE: stderr is NOT redirected, so it flows straight to the parent's
    # stderr; only stdout can be captured here. (The original code had
    # `result.stderr` branches, but they were dead: stderr was never piped.)
    f = open(log_file, "w") if log_file else None
    try:
        with subprocess.Popen(
            cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, cwd=cwd
        ) as result:
            if result.stdout:
                logger.info("")
                for line in result.stdout:
                    if f:
                        f.write(line)
    finally:
        # try/finally so the tee file is closed even if reading stdout raises.
        if f:
            f.close()
    return result.returncode
def get_factors(n):
    """Return all positive divisors of ``n`` in ascending order.

    Args:
        n (int): positive integer to factorize.

    Returns:
        list[int]: sorted divisors, e.g. ``get_factors(12) -> [1, 2, 3, 4, 6, 12]``.
    """
    divisors = set()
    # Probe only up to sqrt(n); each divisor i found there pairs with n // i.
    for i in range(1, int(n**0.5) + 1):
        if n % i == 0:
            divisors.add(i)
            divisors.add(n // i)
    return sorted(divisors)
# ---------------------------------------------
# LUTNet helpers
# ---------------------------------------------
def generate_truth_table(k: int, tables_count: int, device=None) -> torch.Tensor:
    """Generate ``tables_count`` stacked truth tables over {-1, +1}.

    A single table is a ``k x 2**k`` matrix: column ``j`` holds the
    ``j``-th sign assignment of the ``k`` inputs (enumerated by
    ``itertools.product``). ``tables_count`` copies are stacked row-wise.

    Args:
        k (int): truth table power (number of inputs per table).
        tables_count (int): number of repetitions of the table.
        device: optional target device of the result; ``None`` leaves the
            tensor on the CPU. (Previously annotated ``device: None``,
            which was not a valid type hint and forced callers to pass it.)

    Returns:
        torch.Tensor: 2-D tensor with ``k * tables_count`` rows and
        ``2**k`` columns.
    """
    # product() yields combinations as rows; transpose so they run along columns.
    table = torch.from_numpy(np.array(list(itertools.product([-1, 1], repeat=k)))).T
    return torch.vstack([table] * tables_count).to(device)
# [docs]  (Sphinx source-view artifact; commented out so the file parses)
def init_LinearLUT_weight(
    levels,
    k,
    original_pruning_mask,
    original_weight,
    in_features,
    out_features,
    new_module,
):
    """Initialize LinearLUT weights from a trained binarized linear layer.

    Gathers the trained weights through ``new_module.input_mask`` so each
    k-input LUT starts out reproducing the binary connection it replaces;
    inputs whose original connection was pruned contribute extra
    truth-table terms via the reconnected weights.

    Args:
        levels: number of LUT levels; both outputs are tiled this many
            times along dim 0.
        k (int): LUT fan-in (inputs per truth table).
        original_pruning_mask: pruning mask of the trained network,
            indexed exactly like ``original_weight``.
        original_weight: trained binary weights — indexed here as
            ``[out_feature, in_feature]`` (2-D assumed; TODO confirm).
        in_features (int): input feature count of the layer.
        out_features (int): output feature count of the layer.
        new_module: target LUT module providing ``input_mask``.

    Returns:
        tuple: ``(initialized_weight, pruned_connection)``, each repeated
        ``levels`` times along dim 0.
    """
    # Initialize the weight based on the trained binarized network.
    # weight shape of the lagrange trainer: [tables_count, self.kk]
    input_mask = new_module.input_mask.reshape(
        -1, k * in_features
    )  # (out_feature, k * in_feature)
    # Fancy-index the trained weights per output feature, then split each
    # group of k LUT inputs into its "index" input and the k-1 extras.
    expanded_original_weight = original_weight[
        np.arange(out_features)[:, np.newaxis], input_mask
    ].reshape(-1, k, 1)
    index_weight, reconnected_weight = (
        expanded_original_weight[:, 0, :],
        expanded_original_weight[:, 1:, :],
    )  # [input_feature * output_feature, 1]
    # Establish pruning mask (gathered identically to the weights).
    expanded_pruning_masks = original_pruning_mask[
        np.arange(out_features)[:, np.newaxis], input_mask
    ].reshape(
        -1, k, 1
    )  # (out_feature * in_feature, k, 1)
    pruned_connection = expanded_pruning_masks[:, 0, :]
    d = generate_truth_table(k=k, tables_count=1, device=None)
    # Seed each LUT with the index input's weight along truth-table row 0.
    initialized_weight = index_weight * d[0, :]
    for extra_input_index in range(1, k):
        # NOTE(review): mask is inverted — a zero in the pruning mask marks
        # a pruned (hence reconnected) input; confirm against the trainer.
        pruned_extra_input = ~(
            expanded_pruning_masks[:, extra_input_index, :].squeeze().bool()
        )
        initialized_weight[pruned_extra_input, :] = (
            initialized_weight[pruned_extra_input, :]
            + (reconnected_weight * d[extra_input_index, :]).squeeze()[
                pruned_extra_input, :
            ]
        )
    # Repeat for every LUT level.
    initialized_weight = torch.cat([initialized_weight] * levels, dim=0)
    pruned_connection = torch.cat([pruned_connection] * levels, dim=0)
    return initialized_weight, pruned_connection
# [docs]  (Sphinx source-view artifact; commented out so the file parses)
def init_Conv2dLUT_weight(
    levels,
    k,
    original_pruning_mask,
    original_weight,
    out_channels,
    in_channels,
    kernel_size,
    new_module,
):
    """Initialize Conv2dLUT weights from a trained binarized conv layer.

    Convolutional analogue of ``init_LinearLUT_weight``: the mask's last
    axis holds an (in-channel, kernel-h, kernel-w) index triple used to
    gather from the 4-D weight tensor.

    Args:
        levels: number of LUT levels; both outputs are tiled this many
            times along dim 0.
        k (int): LUT fan-in (inputs per truth table).
        original_pruning_mask: pruning mask of the trained network,
            indexed exactly like ``original_weight``.
        original_weight: trained binary weights — indexed here as
            ``[out_channel, in_channel, kh, kw]`` (4-D assumed; TODO confirm).
        out_channels (int): output channel count.
        in_channels (int): input channel count.
        kernel_size: (kh, kw) pair of kernel dimensions.
        new_module: target LUT module providing ``input_mask``.

    Returns:
        tuple: ``(initialized_weight, pruned_connection)``, each repeated
        ``levels`` times along dim 0.
    """
    # Initialize the weight based on the trained binarized network.
    # weight shape of the lagrange trainer: [tables_count, self.kk]
    input_mask = new_module.input_mask.reshape(
        -1,
        in_channels * kernel_size[0] * kernel_size[1] * k,
        3,
    )  # [oc, k * kh * kw * ic, 3[ic, kh, kw]]
    # Gather trained weights through the (ic, kh, kw) triples, then split
    # each group of k LUT inputs into the "index" input and the k-1 extras.
    expanded_original_weight = original_weight[
        np.arange(out_channels)[:, np.newaxis],
        input_mask[:, :, 0],
        input_mask[:, :, 1],
        input_mask[:, :, 2],
    ].reshape(
        -1, k, 1
    )  # [oc * ic * kw * kh , k, 1]
    index_weight, reconnected_weight = (
        expanded_original_weight[:, 0, :],
        expanded_original_weight[:, 1:, :],
    )
    # Establish pruning mask (gathered identically to the weights).
    expanded_pruning_masks = original_pruning_mask[
        np.arange(out_channels)[:, np.newaxis],
        input_mask[:, :, 0],
        input_mask[:, :, 1],
        input_mask[:, :, 2],
    ].reshape(
        -1, k, 1
    )  # (out_feature * in_feature, k, 1)
    pruned_connection = expanded_pruning_masks[
        :, 0, :
    ]  # [input_feature * output_feature, 1]
    d = generate_truth_table(k=k, tables_count=1, device=None)
    # Seed each LUT with the index input's weight along truth-table row 0.
    initialized_weight = index_weight * d[0, :]
    for extra_input_index in range(1, k):
        # NOTE(review): mask is inverted — a zero in the pruning mask marks
        # a pruned (hence reconnected) input; confirm against the trainer.
        pruned_extra_input = ~(
            expanded_pruning_masks[:, extra_input_index, :].squeeze().bool()
        )
        initialized_weight[pruned_extra_input, :] = (
            initialized_weight[pruned_extra_input, :]
            + (reconnected_weight * d[extra_input_index, :]).squeeze()[
                pruned_extra_input, :
            ]
        )
    # Repeat for every LUT level.
    initialized_weight = torch.cat([initialized_weight] * levels, dim=0)
    pruned_connection = torch.cat([pruned_connection] * levels, dim=0)
    return initialized_weight, pruned_connection
def nested_dict_replacer(compound_dict, fn):
    """Apply ``fn`` to every leaf value of a nested dict, mutating it in place.

    Nested dictionaries are descended into recursively; every non-dict
    value is replaced by ``fn(value)``. The same (mutated) dict is
    returned for call-chaining convenience.
    """
    def _visit(node):
        for key, value in node.items():
            if isinstance(value, dict):
                _visit(value)
            else:
                node[key] = fn(value)

    _visit(compound_dict)
    return compound_dict
def parse_accelerator(accelerator: str):
    """Map an accelerator name (or ``torch.device``) onto a ``torch.device``.

    Accepts ``"auto"`` (CUDA when available, else CPU), ``"gpu"`` or
    ``torch.device("cuda:0")``, and ``"cpu"`` or ``torch.device("cpu")``.

    Raises:
        RuntimeError: for any unrecognized accelerator value.
    """
    cuda0 = torch.device("cuda:0")
    cpu = torch.device("cpu")
    if accelerator == "auto":
        return cuda0 if torch.cuda.is_available() else cpu
    if accelerator in ("gpu", cuda0):
        return cuda0
    if accelerator in ("cpu", cpu):
        return cpu
    raise RuntimeError(f"Unsupported accelerator {accelerator}")
def set_excepthook():
    """Install a global exception hook that drops into pdb post-mortem."""
    import sys, pdb, traceback

    def _hook(exc_type, exc_value, exc_traceback):
        # Print the usual traceback first, then hand control to the debugger.
        traceback.print_exception(exc_type, exc_value, exc_traceback)
        print("\nEntering debugger...")
        pdb.post_mortem(exc_traceback)

    sys.excepthook = _hook
def deepsetattr(obj, attr, value):
    """Set *value* at the end of a dotted attribute chain (e.g. ``"a.b.c"``)."""
    head, dot, rest = attr.partition(".")
    if dot:
        # Descend one link and recurse on the remainder of the chain.
        deepsetattr(getattr(obj, head), rest, value)
    else:
        setattr(obj, attr, value)
def deepgetattr(obj, attr, default=None):
    """Resolve a dotted attribute chain, returning *default* if any link is missing."""
    current = obj
    try:
        for part in attr.split("."):
            current = getattr(current, part)
    except AttributeError:
        return default
    return current