Files
recscripts/esrgan.py
T
2026-04-05 18:07:21 +02:00

304 lines
10 KiB
Python

import argparse
import os
import queue
import subprocess
import sys
import tempfile
import threading
import time
import urllib.request
import zipfile
from pathlib import Path
from rich.console import Console
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
# Shared Rich console used for every status message and progress bar in this
# script; it writes to stderr (stderr=True).
console = Console(stderr=True)
# --model alias -> (model name passed to the tool's -n flag, scale passed to -s).
# The keys also serve as the argparse choices for --model.
MODELS = {
    "x4": ("realesrgan-x4plus", 4),
    "x4-anime": ("realesrgan-x4plus-anime", 4),
    # x2/x3 are aliases for the same animevideov3 models as video-x2/video-x3.
    "x2": ("realesr-animevideov3-x2", 2),
    "x3": ("realesr-animevideov3-x3", 3),
    "video-x2": ("realesr-animevideov3-x2", 2),
    "video-x3": ("realesr-animevideov3-x3", 3),
    "video-x4": ("realesr-animevideov3-x4", 4),
}

# Pinned prebuilt release; also the name prefix of the unpacked tool directory.
_RELEASE = "realesrgan-ncnn-vulkan-20220424"
_RELEASE_BASE_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0"
# sys.platform value -> release zip URL (the asset suffix differs from the platform name).
_DOWNLOAD_URLS = {
    plat: f"{_RELEASE_BASE_URL}/{_RELEASE}-{asset}.zip"
    for plat, asset in (("win32", "windows"), ("linux", "ubuntu"), ("darwin", "macos"))
}
_EXE_NAME = "realesrgan-ncnn-vulkan" + (".exe" if sys.platform == "win32" else "")
def _download_ncnn(tools: Path):
    """Download and unpack the Real-ESRGAN NCNN release for this platform.

    The release zip for ``sys.platform`` is downloaded into *tools* and
    unpacked into a hidden staging directory. The directory that contains
    the executable is then moved to ``tools/<zip name without .zip>``, for
    example ``realesrgan-ncnn-vulkan-20220424-ubuntu``. That name matches the
    ``{_RELEASE}-*`` pattern that ``find_executable`` searches.

    This works whether the archive stores its files at the root or inside a
    single top-level folder. The downloaded zip and the staging directory are
    always removed, even on failure.

    On non-Windows platforms the executable bit is set, because
    ``ZipFile.extractall`` does not restore Unix permissions.

    Exits the process with status 1 when there is no download for this
    platform or when the archive does not contain the executable.
    """
    import shutil

    platform = sys.platform
    url = _DOWNLOAD_URLS.get(platform)
    if not url:
        console.print(f"[red]no download available for platform: {platform}[/red]")
        sys.exit(1)
    tools.mkdir(parents=True, exist_ok=True)
    archive_stem = url.rsplit("/", 1)[-1].removesuffix(".zip")
    zip_path = tools / f"{_RELEASE}.zip"
    # The leading dot keeps the staging dir out of find_executable's glob.
    staging = tools / f".{archive_stem}.partial"
    dest = tools / archive_stem
    console.print(f"[dim]downloading {url}[/dim]")
    try:
        with Progress(
            TextColumn("{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Downloading ", total=None)

            def _reporthook(block: int, block_size: int, total: int):
                # urlretrieve reports total=-1 when the size is unknown.
                if total > 0:
                    progress.update(task, total=total, completed=min(block * block_size, total))

            urllib.request.urlretrieve(url, zip_path, reporthook=_reporthook)
        console.print("[dim]extracting...[/dim]")
        shutil.rmtree(staging, ignore_errors=True)
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(staging)
        exe = next(staging.rglob(_EXE_NAME), None)
        if exe is None:
            console.print(f"[red]{_EXE_NAME} not found in {url}[/red]")
            sys.exit(1)
        # Move the folder holding the binary (and its models) into place.
        shutil.rmtree(dest, ignore_errors=True)
        exe.parent.rename(dest)
    finally:
        zip_path.unlink(missing_ok=True)
        shutil.rmtree(staging, ignore_errors=True)
    if sys.platform != "win32":
        exe = dest / _EXE_NAME
        exe.chmod(exe.stat().st_mode | 0o111)
def find_executable() -> Path:
    """Return the path to the realesrgan-ncnn-vulkan binary, downloading it if needed.

    Looks in ``.tools/`` next to this script for directories named
    ``{_RELEASE}-*``. If there are none, the release is downloaded first.

    When several such directories exist, the last one in sorted order that
    actually contains this platform's executable is used. That could happen,
    for example, if unpacked builds for several platforms share ``.tools/``.

    Exits with status 1 (with a message) instead of raising when no
    executable can be found.
    """
    tools = Path(__file__).parent / ".tools"
    candidates = sorted(tools.glob(f"{_RELEASE}-*"))
    if not candidates:
        _download_ncnn(tools)
        candidates = sorted(tools.glob(f"{_RELEASE}-*"))
    if not candidates:
        # Previously this fell through to candidates[-1] and raised IndexError.
        console.print(f"[red]no {_RELEASE}-* directory found in {tools}[/red]")
        sys.exit(1)
    for directory in reversed(candidates):
        exe = directory / _EXE_NAME
        if exe.exists():
            return exe
    console.print(f"[red]executable not found: {candidates[-1] / _EXE_NAME}[/red]")
    sys.exit(1)
def _parse_time(t: str) -> float:
parts = t.split(":")
return sum(float(p) * 60 ** i for i, p in enumerate(reversed(parts)))
def n_frames_to_seconds(ss: str | None, to: str | None, total: float) -> float:
start = _parse_time(ss) if ss else 0.0
end = _parse_time(to) if to else total
return max(0.0, end - start)
def probe_video(path: str) -> dict:
    """Read basic metadata of the first video stream of *path* via ffprobe.

    Returns ``{"fps": float, "duration": float, "width": int, "height": int}``,
    where ``duration`` is the container duration in seconds.

    This is best-effort. A field that ffprobe omits, or reports as a
    non-number such as ``N/A``, falls back to a default instead of raising:
    30 fps for the frame rate, 0 for the other fields. A ``0/0`` frame rate
    also falls back to 30 fps. When ffprobe exits non-zero, its stderr is
    printed as a warning.
    """
    result = subprocess.run(
        [
            "ffprobe", "-v", "error", "-select_streams", "v:0",
            "-show_entries", "stream=r_frame_rate,width,height:format=duration",
            "-of", "default=noprint_wrappers=1",
            path,
        ],
        capture_output=True, text=True,
    )
    if result.returncode != 0:
        console.print(
            f"ffprobe failed: {result.stderr.strip()}",
            style="yellow", markup=False, highlight=False,
        )
    info: dict[str, str] = {}
    for line in result.stdout.splitlines():
        key, _, value = line.partition("=")
        info[key.strip()] = value.strip()

    def _num(key: str, default: float) -> float:
        # ffprobe prints "N/A" for unknown values; treat that like a missing key.
        try:
            return float(info[key])
        except (KeyError, ValueError):
            return default

    num, _, den = info.get("r_frame_rate", "30/1").partition("/")
    try:
        fps = float(num) / float(den or 1)
    except (ValueError, ZeroDivisionError):
        fps = 30.0
    if not fps > 0:  # also rejects NaN
        fps = 30.0
    return {
        "fps": fps,
        "duration": _num("duration", 0.0),
        "width": int(_num("width", 0)),
        "height": int(_num("height", 0)),
    }
def _pipe_reader(stream, q: queue.Queue):
    """Forward each line of the binary *stream* into *q* as stripped text.

    Lines are decoded as UTF-8, with undecodable bytes replaced by U+FFFD.
    After EOF (an empty ``b""`` read), a single ``None`` is queued so the
    consumer knows no more lines will follow.
    """
    while (raw := stream.readline()) != b"":
        text = raw.decode(errors="replace")
        q.put(text.strip())
    q.put(None)
def ffmpeg_with_progress(cmd: list[str], label: str, total: float | None) -> int:
    """Run an ffmpeg command, rendering its machine-readable progress as a bar.

    *cmd* is expected to contain ``-progress pipe:1``, so that
    ``key=value`` progress lines arrive on stdout.

    *total* is the expected output duration in seconds and is compared
    against ``out_time``. When it is falsy, the bar stays at 0 until ffmpeg
    reports ``progress=end``.

    stderr is drained on a background thread so ffmpeg never blocks on a
    full pipe. If ffmpeg exits non-zero, what it wrote to stderr is echoed
    to the console, so the caller's failure message has context.

    Returns ffmpeg's exit code. Ctrl+C terminates ffmpeg and exits with
    status 130.
    """
    t = total or 100.0
    with Progress(
        TextColumn("{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TimeElapsedColumn(),
        TimeRemainingColumn(),
        console=console,
    ) as progress:
        task = progress.add_task(label, total=t)
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        q: queue.Queue = queue.Queue()
        err_q: queue.Queue = queue.Queue()
        threading.Thread(target=_pipe_reader, args=(proc.stdout, q), daemon=True).start()
        err_reader = threading.Thread(target=_pipe_reader, args=(proc.stderr, err_q), daemon=True)
        err_reader.start()
        try:
            while (line := q.get()) is not None:
                key, _, value = line.partition("=")
                if key == "out_time" and total:
                    try:
                        h, m, s = value.split(":")
                        elapsed_s = int(h) * 3600 + int(m) * 60 + float(s)
                    except ValueError:
                        # Non-numeric values (e.g. "N/A") carry no position.
                        continue
                    progress.update(task, completed=min(elapsed_s, t))
                elif key == "progress" and value == "end":
                    progress.update(task, completed=t)
            # Waiting inside the try lets Ctrl+C here also terminate ffmpeg cleanly.
            proc.wait()
        except KeyboardInterrupt:
            proc.terminate()
            proc.wait()
            console.print("\n[yellow]cancelled[/yellow]")
            sys.exit(130)
    if proc.returncode != 0:
        err_reader.join(timeout=5)
        while True:
            try:
                err_line = err_q.get_nowait()
            except queue.Empty:
                break
            if err_line:  # skips blank lines and the None EOF marker
                # markup=False: ffmpeg messages contain "[...]" tags that Rich would parse.
                console.print(err_line, style="red", markup=False, highlight=False)
    return proc.returncode
def extract_frames(input_file: str, frames_dir: Path, ss: str | None, to: str | None, duration: float | None) -> int:
    """Dump the ``-ss`` .. ``-to`` window of *input_file* as numbered PNGs.

    Frames are written to *frames_dir* as ``%08d.png``. *duration* is the
    expected clip length in seconds and only drives the progress bar.

    Returns the number of PNG files in *frames_dir* afterwards. Exits with
    ffmpeg's return code if extraction fails, or with status 1 if no frames
    were produced, for example when ``-ss`` lies past the end of the input.
    """
    # Global options go before the first input rather than trailing after the
    # output file, where ffmpeg treats leftover options as suspicious.
    cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
        "-progress", "pipe:1", "-stats_period", "0.5",
    ]
    if ss:
        cmd += ["-ss", ss]
    if to:
        cmd += ["-to", to]
    cmd += ["-i", input_file, str(frames_dir / "%08d.png")]
    rc = ffmpeg_with_progress(cmd, "Extracting ", duration)
    if rc != 0:
        console.print("[red]frame extraction failed[/red]")
        sys.exit(rc)
    # Only the count is needed; no need to build and sort the list.
    n_frames = sum(1 for _ in frames_dir.glob("*.png"))
    if n_frames == 0:
        console.print("[red]no frames extracted (check -ss/-to against the input duration)[/red]")
        sys.exit(1)
    return n_frames
def upscale_frames(exe: Path, frames_dir: Path, out_dir: Path, model: str, scale: int, gpu_id: int, n_frames: int):
    """Upscale every PNG in *frames_dir* into *out_dir* with realesrgan-ncnn-vulkan.

    *model* and *scale* are passed to the tool's ``-n`` / ``-s`` flags and
    *gpu_id* to ``-g``. The tool's own output is discarded.

    Progress is estimated by counting PNG files in *out_dir* every 0.5 s
    (out of *n_frames*). One more count is taken after the tool exits, so
    the bar shows the final state.

    Exits with the tool's return code if it fails, or with status 130 on
    Ctrl+C.
    """
    cmd = [
        str(exe),
        "-i", str(frames_dir),
        "-o", str(out_dir),
        "-n", model,
        "-s", str(scale),
        "-g", str(gpu_id),
        "-f", "png",
    ]

    def _done() -> int:
        return sum(1 for _ in out_dir.glob("*.png"))

    with Progress(
        TextColumn("{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TimeRemainingColumn(),
        console=console,
    ) as progress:
        task = progress.add_task("Upscaling ", total=n_frames)
        proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        stop = threading.Event()

        def _watcher():
            while not stop.is_set():
                progress.update(task, completed=_done())
                # Event.wait (not sleep) so the watcher exits as soon as it is stopped.
                stop.wait(0.5)

        watcher = threading.Thread(target=_watcher, daemon=True)
        watcher.start()
        try:
            proc.wait()
        except KeyboardInterrupt:
            proc.terminate()
            proc.wait()
            console.print("\n[yellow]cancelled[/yellow]")
            sys.exit(130)
        finally:
            # Stop and join the watcher before the Progress display is torn down.
            stop.set()
            watcher.join(timeout=2)
        progress.update(task, completed=_done())
    if proc.returncode != 0:
        console.print("[red]upscaling failed[/red]")
        sys.exit(proc.returncode)
def assemble_video(frames_dir: Path, original: str, output: str, fps: float, clip_duration: float | None, crf: int, ss: str | None, to: str | None):
    """Encode the upscaled PNG frames into *output*, muxing audio from *original*.

    The command has two inputs:

    - input 0 is the ``%08d.png`` sequence in *frames_dir*, read at *fps*;
    - input 1 is *original*, trimmed with the same ``-ss`` / ``-to`` window
      used for extraction.

    Audio streams of input 1, if any, are stream-copied. Video is encoded
    with libx264 at *crf* using preset ``slow``. *clip_duration* only drives
    the progress bar. Exits with ffmpeg's return code if encoding fails.

    NOTE(review): no ``-pix_fmt`` is passed, so ffmpeg chooses the encoder
    pixel format based on the PNG input. Confirm the result plays in the
    target players.
    """
    trim: list[str] = []
    for flag, value in (("-ss", ss), ("-to", to)):
        if value:
            trim.extend((flag, value))
    frame_pattern = str(frames_dir / "%08d.png")
    cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
        "-framerate", str(fps), "-i", frame_pattern,
        *trim, "-i", original,
        "-map", "0:v", "-map", "1:a?",
        "-c:v", "libx264", "-crf", str(crf), "-preset", "slow",
        "-c:a", "copy",
        "-progress", "pipe:1", "-stats_period", "0.5",
        output,
    ]
    rc = ffmpeg_with_progress(cmd, "Assembling ", clip_duration)
    if rc:
        console.print("[red]assembly failed[/red]")
        sys.exit(rc)
def main():
    """Command-line entry point: extract, upscale and re-encode a video clip.

    Pipeline:

    1. Probe the input.
    2. Extract the ``-ss`` / ``-to`` window as PNG frames.
    3. Upscale the frames with Real-ESRGAN.
    4. Encode the result together with the original audio.

    Intermediate frames live in two temporary directories (created under
    ``--work-dir`` when given). They are removed when the run ends, whether
    it succeeds or fails.
    """
    parser = argparse.ArgumentParser(description="Real-ESRGAN video upscaler (NCNN)")
    parser.add_argument("input", help="Input video file")
    parser.add_argument("output", help="Output video file")
    parser.add_argument("-ss", help="Start time (e.g. 00:01:30 or 90)")
    parser.add_argument("-to", help="End time (e.g. 00:02:00 or 120)")
    parser.add_argument("--model", default="x4", choices=list(MODELS), help="Model (default: x4)")
    parser.add_argument("--gpu", type=int, default=0, help="GPU device ID, -1 for CPU (default: 0)")
    parser.add_argument("--crf", type=int, default=18, help="Output CRF quality (default: 18)")
    parser.add_argument(
        "--work-dir",
        help="Directory for temporary frame files (default: system temp dir)",
    )
    args = parser.parse_args()

    # Validate cheaply before a possible tool download and a long extraction.
    # Inputs that look like URLs ("scheme://...") are left for ffmpeg to open.
    if "://" not in args.input and not Path(args.input).exists():
        console.print(f"[red]input not found: {args.input}[/red]")
        sys.exit(1)
    # The input is read again for audio during assembly; overwriting it with -y
    # could corrupt it, and would otherwise only surface after the slow upscale.
    if Path(args.output).resolve() == Path(args.input).resolve():
        console.print("[red]output must be a different file than input[/red]")
        sys.exit(1)
    if args.work_dir:
        Path(args.work_dir).mkdir(parents=True, exist_ok=True)

    exe = find_executable()
    model_name, scale = MODELS[args.model]
    info = probe_video(args.input)
    console.print(
        f"[dim]{Path(args.input).name}[/dim] "
        f"[cyan]{info['width']}x{info['height']}[/cyan] "
        f"[cyan]{info['fps']:.2f} fps[/cyan] "
        f"[cyan]{info['duration']:.1f}s[/cyan] "
        f"[dim]→ {model_name} ({scale}x)[/dim]"
    )
    # dir=None (no --work-dir) means the system default temp directory.
    with tempfile.TemporaryDirectory(prefix="esrgan_in_", dir=args.work_dir) as frames_tmp, \
            tempfile.TemporaryDirectory(prefix="esrgan_out_", dir=args.work_dir) as out_tmp:
        frames_dir = Path(frames_tmp)
        out_dir = Path(out_tmp)
        clip_duration = n_frames_to_seconds(args.ss, args.to, info["duration"])
        n = extract_frames(args.input, frames_dir, args.ss, args.to, clip_duration)
        console.print(f"[dim]{n} frames[/dim]")
        upscale_frames(exe, frames_dir, out_dir, model_name, scale, args.gpu, n)
        assemble_video(out_dir, args.input, args.output, info["fps"], clip_duration, args.crf, args.ss, args.to)
    console.print(f"[green]done[/green] {args.output}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()