Make skill file writes atomic
This commit is contained in:
parent
d63b363cde
commit
566aeaeefa
1 changed files with 41 additions and 8 deletions
|
|
@ -37,6 +37,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
|
|
@ -190,6 +191,38 @@ def _validate_file_path(file_path: str) -> Optional[str]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _atomic_write_text(file_path: Path, content: str, encoding: str = "utf-8") -> None:
|
||||||
|
"""
|
||||||
|
Atomically write text content to a file.
|
||||||
|
|
||||||
|
Uses a temporary file in the same directory and os.replace() to ensure
|
||||||
|
the target file is never left in a partially-written state if the process
|
||||||
|
crashes or is interrupted.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Target file path
|
||||||
|
content: Content to write
|
||||||
|
encoding: Text encoding (default: utf-8)
|
||||||
|
"""
|
||||||
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fd, temp_path = tempfile.mkstemp(
|
||||||
|
dir=str(file_path.parent),
|
||||||
|
prefix=f".{file_path.name}.tmp.",
|
||||||
|
suffix="",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with os.fdopen(fd, "w", encoding=encoding) as f:
|
||||||
|
f.write(content)
|
||||||
|
os.replace(temp_path, file_path)
|
||||||
|
except Exception:
|
||||||
|
# Clean up temp file on error
|
||||||
|
try:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Core actions
|
# Core actions
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -218,9 +251,9 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An
|
||||||
skill_dir = _resolve_skill_dir(name, category)
|
skill_dir = _resolve_skill_dir(name, category)
|
||||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Write SKILL.md
|
# Write SKILL.md atomically
|
||||||
skill_md = skill_dir / "SKILL.md"
|
skill_md = skill_dir / "SKILL.md"
|
||||||
skill_md.write_text(content, encoding="utf-8")
|
_atomic_write_text(skill_md, content)
|
||||||
|
|
||||||
# Security scan — roll back on block
|
# Security scan — roll back on block
|
||||||
scan_error = _security_scan_skill(skill_dir)
|
scan_error = _security_scan_skill(skill_dir)
|
||||||
|
|
@ -256,13 +289,13 @@ def _edit_skill(name: str, content: str) -> Dict[str, Any]:
|
||||||
skill_md = existing["path"] / "SKILL.md"
|
skill_md = existing["path"] / "SKILL.md"
|
||||||
# Back up original content for rollback
|
# Back up original content for rollback
|
||||||
original_content = skill_md.read_text(encoding="utf-8") if skill_md.exists() else None
|
original_content = skill_md.read_text(encoding="utf-8") if skill_md.exists() else None
|
||||||
skill_md.write_text(content, encoding="utf-8")
|
_atomic_write_text(skill_md, content)
|
||||||
|
|
||||||
# Security scan — roll back on block
|
# Security scan — roll back on block
|
||||||
scan_error = _security_scan_skill(existing["path"])
|
scan_error = _security_scan_skill(existing["path"])
|
||||||
if scan_error:
|
if scan_error:
|
||||||
if original_content is not None:
|
if original_content is not None:
|
||||||
skill_md.write_text(original_content, encoding="utf-8")
|
_atomic_write_text(skill_md, original_content)
|
||||||
return {"success": False, "error": scan_error}
|
return {"success": False, "error": scan_error}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -342,12 +375,12 @@ def _patch_skill(
|
||||||
}
|
}
|
||||||
|
|
||||||
original_content = content # for rollback
|
original_content = content # for rollback
|
||||||
target.write_text(new_content, encoding="utf-8")
|
_atomic_write_text(target, new_content)
|
||||||
|
|
||||||
# Security scan — roll back on block
|
# Security scan — roll back on block
|
||||||
scan_error = _security_scan_skill(skill_dir)
|
scan_error = _security_scan_skill(skill_dir)
|
||||||
if scan_error:
|
if scan_error:
|
||||||
target.write_text(original_content, encoding="utf-8")
|
_atomic_write_text(target, original_content)
|
||||||
return {"success": False, "error": scan_error}
|
return {"success": False, "error": scan_error}
|
||||||
|
|
||||||
replacements = count if replace_all else 1
|
replacements = count if replace_all else 1
|
||||||
|
|
@ -394,13 +427,13 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
|
||||||
target.parent.mkdir(parents=True, exist_ok=True)
|
target.parent.mkdir(parents=True, exist_ok=True)
|
||||||
# Back up for rollback
|
# Back up for rollback
|
||||||
original_content = target.read_text(encoding="utf-8") if target.exists() else None
|
original_content = target.read_text(encoding="utf-8") if target.exists() else None
|
||||||
target.write_text(file_content, encoding="utf-8")
|
_atomic_write_text(target, file_content)
|
||||||
|
|
||||||
# Security scan — roll back on block
|
# Security scan — roll back on block
|
||||||
scan_error = _security_scan_skill(existing["path"])
|
scan_error = _security_scan_skill(existing["path"])
|
||||||
if scan_error:
|
if scan_error:
|
||||||
if original_content is not None:
|
if original_content is not None:
|
||||||
target.write_text(original_content, encoding="utf-8")
|
_atomic_write_text(target, original_content)
|
||||||
else:
|
else:
|
||||||
target.unlink(missing_ok=True)
|
target.unlink(missing_ok=True)
|
||||||
return {"success": False, "error": scan_error}
|
return {"success": False, "error": scan_error}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue