-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild_network.py
More file actions
69 lines (66 loc) · 2.81 KB
/
build_network.py
File metadata and controls
69 lines (66 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from wcode.net.CNN.VNet.VNet import VNet
from wcode.net.CNN.UNet.UNet import UNet
from wcode.net.CNN.ResUNet.ResUNet import ResUNet
from wcode.net.CNN.DFUNet.DFUNet import DFUNet
from wcode.net.CNN.STUNet.STUNet import build_STUNet
from wcode.net.Vision_Transformer.SAM.build_sam import sam_model_registry
from wcode.inferring.utils.load_pretrain_weight import load_pretrained_weights
def build_network(network_settings: dict):
    """Build the network described by ``network_settings``, optionally loading weights.

    Dispatch is on ``network_settings["label"]`` (compared case-insensitively):

    * ``"vnet"``, ``"unet"``, ``"resunet"``, ``"dfunet"``: the matching class is
      constructed with the whole ``network_settings`` dict. If a weight path is
      given, it is printed and the weights are loaded into the new network with
      ``load_pretrained_weights(..., verbose=True)``.
    * ``"stunet"``: delegates to ``build_STUNet`` with the ``in_channels``,
      ``out_channels``, ``pool_kernel_size``, ``deep_supervision`` and
      ``model_registry`` entries, followed by the weight path when one is given.
    * ``"sam"``: looks up ``sam_model_registry[network_settings["model_registry"]]``
      and calls it, passing ``checkpoint=<weight path>`` when one is given.

    Args:
        network_settings: Configuration dict. ``"label"`` is required. ``"weight_path"``
            is optional; a missing key and an explicit ``None`` both mean
            "do not load pretrained weights". Other keys are required per branch
            as listed above.

    Returns:
        The constructed network object.

    Raises:
        KeyError: if a key required by the selected branch is missing, or (for
            SAM) the configured ``model_registry`` is not in ``sam_model_registry``.
        ValueError: if ``label`` does not name a supported network.
    """
    label = network_settings["label"].lower()
    # Missing key and explicit None are treated the same: no pretrained weights.
    weight_path = network_settings.get("weight_path")

    if label in ("vnet", "unet", "resunet", "dfunet"):
        network_class = {
            "vnet": VNet,
            "unet": UNet,
            "resunet": ResUNet,
            "dfunet": DFUNet,
        }[label]
        # These constructors take the full settings dict.
        net = network_class(network_settings)
        if weight_path is not None:
            print("Loading weight from:", weight_path)
            load_pretrained_weights(net, weight_path, verbose=True)
        return net

    if label == "stunet":
        stunet_args = [
            network_settings["in_channels"],
            network_settings["out_channels"],
            network_settings["pool_kernel_size"],
            network_settings["deep_supervision"],
            network_settings["model_registry"],
        ]
        if weight_path is not None:
            print("Loading weight from:", weight_path)
            # Only pass the weight path when present so build_STUNet's own
            # default for that parameter still applies otherwise.
            stunet_args.append(weight_path)
        return build_STUNet(*stunet_args)

    if label == "sam":
        # Select the builder using the configured registry name, not the
        # literal string "model_registry".
        sam_builder = sam_model_registry[network_settings["model_registry"]]
        if weight_path is not None:
            print("Loading weight from:", weight_path)
            return sam_builder(checkpoint=weight_path)
        return sam_builder()

    raise ValueError(f"Unsupported network class: {network_settings['label']!r}.")