53 changes: 21 additions & 32 deletions spider/py/core/t4_daemon.py
@@ -27,11 +27,11 @@
INIT_TIMEOUT = 15 # Initialization timeout (seconds)
IDLE_EXPIRE = 30 * 60 # Instance idle expiry (seconds)
CLEAN_INTERVAL = 5 * 60 # Cleanup interval (seconds)
REQUEST_TIMEOUT = 120 # Socket timeout per request (seconds)
REQUEST_TIMEOUT = 120 # Socket timeout per request (seconds)

LOG_LEVEL = os.environ.get("T4_LOG_LEVEL", "INFO").upper()
LOG_FILE = os.environ.get("T4_LOG_FILE") # If unset, log only to the console
PID_FILE = os.environ.get("T4_PID_FILE") # If set, the PID is written to this file
PID_FILE = os.environ.get("T4_PID_FILE") # If set, the PID is written to this file

# =========================
# Logging
@@ -71,7 +71,6 @@
'action': 'action',
}


# =========================
# Utilities: length-prefixed protocol
# =========================
@@ -87,7 +86,6 @@ def recv_exact(rfile, n: int) -> bytes:
remaining -= len(chunk)
return b"".join(chunks)


def send_packet(wfile, obj: dict):
payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
if len(payload) > MAX_MSG_SIZE:
@@ -96,7 +94,6 @@ def send_packet(wfile, obj: dict):
wfile.write(payload)
wfile.flush()


def recv_packet(rfile) -> dict:
header = recv_exact(rfile, 4)
(length,) = struct.unpack(">I", header)
@@ -105,7 +102,6 @@ def recv_packet(rfile) -> dict:
payload = recv_exact(rfile, length)
return pickle.loads(payload)


# =========================
# Spider management
# =========================
@@ -117,7 +113,6 @@ def __init__(self, spider):
self.init_event = threading.Event()
self.last_used = time.time()


class SpiderManager:
def __init__(self, logger):
self.logger = logger
@@ -154,7 +149,6 @@ def _parse_env(env_str: str):
proxy_url = data.get("proxyUrl", "") or ""
ext = data.get("ext", "") or ""
except Exception:
# Not a JSON string: keep compatibility by treating it as ext
ext = env_str
return proxy_url, ext

@@ -163,12 +157,10 @@ def _instance_key(self, script_path: str, env_str: str) -> str:
key_data = f"{Path(script_path).resolve()}|{proxy_url}|{ext}"
return hashlib.sha256(key_data.encode("utf-8")).hexdigest()

# ---------- Dynamic import ----------
def _load_module_from_file(self, file_path: Path):
name = file_path.stem
logger.info("_load_module_from_file %s", name)
# Add the project root to sys.path so that base.* can be imported
project_root = file_path.parent # assumes py is the root directory
project_root = file_path.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
logger.info("Added %s to sys.path", project_root)
@@ -184,7 +176,6 @@ def _import_spider_module(self, script_path: str):
p = Path(script_path)
if p.exists() and p.is_file() and p.suffix == ".py":
return self._load_module_from_file(p)
# Import as a module path (already on sys.path)
return importlib.import_module(script_path)

def _create_spider(self, script_path: str, env_str: str):
@@ -237,11 +228,9 @@ def _ensure_instance(self, script_path: str, env_str: str) -> SpiderInstance:
return inst

def call(self, script_path: str, method_name: str, env_str: str, args_list):
# Parse ext from env
_, ext = self._parse_env(env_str)
inst = self._ensure_instance(script_path, env_str)

# init branch: synchronous initialization
if method_name == "init":
with threading.Lock():
if inst.initializing:
@@ -258,7 +247,6 @@ def call(self, script_path: str, method_name: str, env_str: str, args_list):
inst.initializing = False
return {"status": "already initialized"}

# Other methods: if not initialized, trigger init asynchronously and wait
if not inst.initialized:
if not inst.initializing:
def _bg():
@@ -267,16 +255,12 @@ def _bg():
inst.initialized = True
inst.init_event.set()
except Exception:
# Set the event even on failure so callers never wait forever
inst.init_event.set()

inst.initializing = True
threading.Thread(target=_bg, daemon=True).start()

if not inst.init_event.wait(INIT_TIMEOUT) or not inst.initialized:
return {"success": False, "error": "init timeout or failed"}

# Parse args
parsed_args = []
for a in (args_list or []):
if isinstance(a, (dict, list, int, float, bool, type(None))):
@@ -289,14 +273,12 @@ def _bg():
else:
parsed_args.append(a)

# Method mapping
invoke = METHOD_MAP.get(method_name, method_name)
if not hasattr(inst.spider, invoke):
return {"success": False, "error": f"Spider missing method '{invoke}'"}

try:
result = getattr(inst.spider, invoke)(*parsed_args)
# If the Spider provides json2str, try to serialize the result with it
if result is not None and hasattr(inst.spider, "json2str"):
try:
return inst.spider.json2str(result)
@@ -311,13 +293,11 @@ def _bg():
"traceback": traceback.format_exc(),
}


# =========================
# Server
# =========================
_manager = SpiderManager(logger)


class T4Handler(StreamRequestHandler):
def handle(self):
self.request.settimeout(REQUEST_TIMEOUT)
@@ -337,41 +317,50 @@ def handle(self):
resp["error"] = result.get("error")
if result.get("traceback"):
resp["traceback"] = result["traceback"]

send_packet(self.wfile, resp)
except Exception as e:
logger.error("Handle error: %s", e)
try:
send_packet(self.wfile, {"success": False, "error": str(e)})
except Exception:
pass # peer already disconnected

pass

class ThreadedTCPServer(ThreadingMixIn, TCPServer):
daemon_threads = True
allow_reuse_address = True


def run():
def _stop(*_):
logger.info("Stopping server ...")
# ✅ Make this process the process-group leader so the whole group can be killed at once
if os.name == "posix":
os.setpgrp()

def _stop(signum, frame):
"""
Callback for SIGINT/SIGTERM:
1. Stop the SpiderManager
2. Shut down the TCP server
3. Force-kill the whole process group so the process exits immediately
"""
logger.info("Received %s, shutting down...", signum)
_manager.stop()
# Make serve_forever() return
srv.shutdown()
# Force an immediate exit to avoid blocking on lingering threads
os._exit(0)

# Register signal handlers
if os.name == "posix":
signal.signal(signal.SIGTERM, _stop)
signal.signal(signal.SIGINT, _stop)
signal.signal(signal.SIGTERM, _stop)

global srv
srv = ThreadedTCPServer((HOST, PORT), T4Handler)
logger.info("T4 daemon listening on %s:%d", HOST, PORT)

try:
srv.serve_forever(poll_interval=0.5)
finally:
srv.server_close()
logger.info("Server closed.")


if __name__ == "__main__":
run()
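
For reference, below is a minimal client sketch for the length-prefixed pickle protocol implemented by send_packet/recv_packet above: a 4-byte big-endian length header followed by a pickled dict. This is an illustration under assumptions, not part of the PR. The address 127.0.0.1:9000 and the request keys "script_path", "method", "env", and "args" are placeholders, since the actual HOST/PORT values and the fields parsed in T4Handler.handle() are not shown in this diff.

# Minimal client sketch for the daemon's length-prefixed pickle protocol.
# ASSUMPTIONS: 127.0.0.1:9000 and the request keys ("script_path", "method",
# "env", "args") are placeholders; the real HOST/PORT constants and the field
# names read in T4Handler.handle() are not visible in this diff.
import pickle
import socket
import struct

def _recv_exact(sock: socket.socket, n: int) -> bytes:
    # Read exactly n bytes or raise if the peer closes early
    buf = b""
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("connection closed mid-message")
        buf += chunk
    return buf

def call_daemon(host: str, port: int, request: dict) -> dict:
    payload = pickle.dumps(request, protocol=pickle.HIGHEST_PROTOCOL)
    with socket.create_connection((host, port), timeout=120) as sock:
        # 4-byte big-endian length header, then the pickled request dict
        sock.sendall(struct.pack(">I", len(payload)) + payload)
        # Response uses the same framing: length header, then pickled dict
        (length,) = struct.unpack(">I", _recv_exact(sock, 4))
        return pickle.loads(_recv_exact(sock, length))

if __name__ == "__main__":
    resp = call_daemon("127.0.0.1", 9000, {
        "script_path": "spiders/example.py",  # hypothetical spider script
        "method": "init",
        "env": "{}",
        "args": [],
    })
    print(resp)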