diff --git a/hugginggpt/server/models_server.py b/hugginggpt/server/models_server.py index 2d7c2a38..7b4e44b2 100644 --- a/hugginggpt/server/models_server.py +++ b/hugginggpt/server/models_server.py @@ -29,7 +29,7 @@ from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector, CannyDetector, MidasDetector from controlnet_aux.open_pose.body import Body from controlnet_aux.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large -from controlnet_aux.hed import Network +from controlnet_aux.hed import ControlNetHED_Apache2 from transformers import DPTForDepthEstimation, DPTFeatureExtractor import warnings import time @@ -279,8 +279,10 @@ def mlsd_control_network(): model.load_state_dict(torch.load(f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/mlsd_large_512_fp32.pth"), strict=True) return MLSDdetector(model) - - hed_network = Network(f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/network-bsds500.pth") + model_path = f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/ControlNetHED.pth" + hed_network = ControlNetHED_Apache2() + hed_network.load_state_dict(torch.load(model_path, map_location="cpu")) + hed_network.float().eval() controlnet_sd_pipes = { "openpose-control": { @@ -632,4 +634,4 @@ def models(model_id): if not os.path.exists("public/videos"): os.makedirs("public/videos") - waitress.serve(app, host="0.0.0.0", port=port) \ No newline at end of file + waitress.serve(app, host="0.0.0.0", port=port)