Files
deamonkai fc94008530 initial
2026-01-23 12:11:21 -06:00

238 lines
8.4 KiB
Python

"""CSR creation helpers using the OpenSSL CLI.
Provides utilities to create a CSR from a private key file or PEM content,
supporting Subject fields and SubjectAltName entries.
"""
import os
import shutil
import subprocess
import tempfile
from typing import Dict, List, Optional, Tuple, Union
class OpenSSLNotFound(Exception):
    """Raised when the ``openssl`` command-line tool cannot be found on PATH."""
def _openssl_bin() -> str:
path = shutil.which("openssl")
if not path:
raise OpenSSLNotFound("`openssl` not found in PATH")
return path
def _build_subject_string(subject: Dict[str, str]) -> str:
# Accepts keys like C, ST, L, O, OU, CN
parts = []
for k in ["C", "ST", "L", "O", "OU", "CN"]:
v = subject.get(k)
if v:
parts.append(f"/{k}={v}")
return "".join(parts)
def normalize_and_validate_san(s: str) -> str:
    """Normalize and validate a single SubjectAltName entry.

    Accepted forms:
      - a bare hostname (``www.example.com``) or wildcard (``*.example.com``);
      - a bare IPv4/IPv6 address (``10.0.0.1``, ``::1``);
      - either of the above prefixed with ``DNS:`` or ``IP:`` (prefix is
        case-insensitive, whitespace around the value is ignored).

    A value that parses as an IP address is returned as an ``IP:`` entry,
    including when it was given unprefixed or with ``DNS:``. An explicit
    ``IP:`` prefix is honored: its value must be a valid IP address.

    Returns:
        ``"DNS:<name>"`` or ``"IP:<address>"``.

    Raises:
        ValueError: if the entry is empty, an ``IP:`` value is not an IP
            address, or a DNS value is not a syntactically valid hostname
            (labels of 1-63 alphanumerics/hyphens not starting or ending with
            a hyphen, at most 253 characters, optional leading ``*.``).
    """
    import ipaddress
    import re

    s = s.strip()
    if not s:
        raise ValueError("empty SAN")

    prefix = None
    val = s
    if s.upper().startswith(("DNS:", "IP:")):
        prefix, val = s.split(":", 1)
        prefix = prefix.upper()
        val = val.strip()

    # IP first: an address is always emitted as an IP entry.
    try:
        ipaddress.ip_address(val)
    except ValueError:
        pass
    else:
        return f"IP:{val}"

    if prefix == "IP":
        # The caller explicitly asked for an IP entry; do not silently
        # reinterpret the value as a DNS name.
        raise ValueError(f"Invalid SAN value: {s} (not an IP address)")

    # Simplified hostname check: dot-separated labels, each 1-63 chars.
    host_re = re.compile(
        r"^(?:[A-Za-z0-9](?:[A-Za-z0-9-]{0,61}[A-Za-z0-9])?)"
        r"(?:\.(?:[A-Za-z0-9](?:[A-Za-z0-9-]{0,61}[A-Za-z0-9])?))*$"
    )
    # A single leading '*.' wildcard is allowed; the rest must be a hostname.
    host = val[2:] if val.startswith("*.") else val
    if len(val) <= 253 and host_re.fullmatch(host):
        return f"DNS:{val}"
    raise ValueError(f"Invalid SAN value: {s}")
def create_csr_from_key(key_path: str, subject: Optional[Union[Dict[str, str], str]] = None, sans: Optional[List[str]] = None, passphrase: Optional[str] = None) -> str:
    """Create a CSR using the OpenSSL CLI and return the CSR PEM text.

    Args:
        key_path: Path to a PEM private key file. If the key is encrypted,
            supply ``passphrase``.
        subject: Either a dict with fields like C, ST, L, O, OU, CN, or a
            subject string with a leading slash, e.g.
            ``'/C=US/ST=CA/CN=example.com'``. When it is empty or None, no
            ``-subj`` is passed and the subject comes from the generated
            config, whose DN section is empty.
        sans: DNS names / IP addresses for subjectAltName. Each entry is
            validated and normalized with ``normalize_and_validate_san``.
        passphrase: Passphrase for an encrypted key. It is written to
            OpenSSL's stdin (``-passin stdin``) instead of being placed on
            the command line, where other local users could read it from the
            process list.

    Returns:
        The CSR in PEM form.

    Raises:
        ValueError: if an entry in ``sans`` is invalid, or ``passphrase``
            contains a newline (OpenSSL reads only one line from stdin).
        OpenSSLNotFound: if ``openssl`` is not on PATH.
        subprocess.CalledProcessError: if OpenSSL exits non-zero; its
            output is available on the exception's ``stderr`` attribute.
    """
    openssl = _openssl_bin()
    subj = _build_subject_string(subject) if isinstance(subject, dict) else subject

    # Validate and canonicalize each SAN to 'DNS:...' / 'IP:...'; the
    # ValueError for a bad entry is deliberately left to the caller.
    norm_sans = [normalize_and_validate_san(s) for s in sans or []]

    if passphrase and "\n" in passphrase:
        raise ValueError("passphrase must not contain a newline")

    with tempfile.TemporaryDirectory() as td:
        conf_path = os.path.join(td, "csr.conf")
        csr_path = os.path.join(td, "req.csr")

        # Minimal OpenSSL config. The DN comes from -subj, but the
        # distinguished_name section it points to must still exist.
        conf_lines = ["[ req ]", "distinguished_name = req_distinguished_name", "prompt = no"]
        if norm_sans:
            conf_lines.append("req_extensions = v3_req")
        conf_lines += ["", "[ req_distinguished_name ]", ""]
        if norm_sans:
            conf_lines += ["[ v3_req ]", f"subjectAltName = {', '.join(norm_sans)}"]
        with open(conf_path, "w", encoding="utf-8") as f:
            f.write("\n".join(conf_lines) + "\n")

        cmd = [openssl, "req", "-new", "-key", key_path, "-out", csr_path, "-config", conf_path]
        if subj:
            cmd += ["-subj", subj]
        stdin_data = None
        if passphrase:
            cmd += ["-passin", "stdin"]
            stdin_data = (passphrase + "\n").encode("utf-8")
        subprocess.run(cmd, input=stdin_data, check=True, capture_output=True)

        with open(csr_path, "r", encoding="utf-8") as f:
            return f.read()
def csr_has_san(csr_pem: str, san: str) -> bool:
    """Return True if the CSR's subjectAltName extension contains ``san``.

    The CSR is decoded with ``openssl req -noout -text`` (PEM fed on stdin).
    The entries listed under "X509v3 Subject Alternative Name" are compared
    one by one. A substring search over the whole dump would wrongly report
    ``DNS:example.com`` as present when only ``DNS:www.example.com`` is, or
    would match text from the subject.

    ``san`` may be:
      - ``DNS:<name>``: matches a DNS entry with the same name
        (case-insensitive);
      - ``IP:<addr>``: matches an IP entry for the same address, so
        ``IP:2001:db8::1`` matches OpenSSL's ``2001:DB8:0:0:0:0:0:1``;
      - unprefixed: matches a DNS or IP entry with that value;
      - any other entry exactly as OpenSSL prints it (e.g. ``email:a@b.c``).

    Raises:
        OpenSSLNotFound: if ``openssl`` is not on PATH.
        subprocess.CalledProcessError: if OpenSSL cannot parse the CSR.
    """
    import ipaddress

    openssl = _openssl_bin()
    # Without -in, openssl reads the CSR from stdin (portable, unlike /dev/stdin).
    p = subprocess.run([openssl, "req", "-noout", "-text"], input=csr_pem.encode("utf-8"), check=True, capture_output=True)
    out = p.stdout.decode("utf-8", errors="replace")

    # OpenSSL prints the SAN entries comma-separated on the line after the header.
    lines = out.splitlines()
    entries: List[str] = []
    for idx, line in enumerate(lines):
        if line.strip().startswith("X509v3 Subject Alternative Name"):
            if idx + 1 < len(lines):
                entries = [e.strip() for e in lines[idx + 1].split(",") if e.strip()]
            break

    query = san.strip()
    want_kind: Optional[str] = None
    want = query
    head = query.upper()
    if head.startswith("DNS:"):
        want_kind, want = "DNS", query[4:].strip()
    elif head.startswith("IP:"):
        want_kind, want = "IP", query[3:].strip()

    def same_ip(a: str, b: str) -> bool:
        # Compare parsed addresses so differing textual forms still match.
        try:
            return ipaddress.ip_address(a) == ipaddress.ip_address(b)
        except ValueError:
            return False

    for entry in entries:
        if entry == query:
            return True
        kind, sep, value = entry.partition(":")
        if not sep:
            continue
        if kind == "IP Address":
            kind = "IP"
        if want_kind is not None and kind != want_kind:
            continue
        if kind == "DNS" and value.lower() == want.lower():
            return True
        if kind == "IP" and same_ip(value, want):
            return True
    return False
def _parse_rfc2253_name(dn: str) -> Dict[str, str]:
    """Parse an RFC 2253 distinguished name, as printed by OpenSSL, into a dict.

    An unescaped ``,`` separates RDNs and an unescaped ``+`` separates the
    AVAs of a multi-valued RDN; both end the current ``type=value`` pair.
    Backslash escapes are decoded: a backslash plus two hex digits yields
    that byte, and a backslash plus any other character yields the
    character itself. Value bytes are decoded as UTF-8, with invalid
    sequences replaced. Segments without ``=`` are ignored, a repeated
    attribute type keeps its last value, and types and values are
    whitespace-stripped.
    """
    hex_digits = "0123456789abcdefABCDEF"
    fields: Dict[str, str] = {}
    attr: Optional[str] = None  # None while the attribute type is being read
    buf = bytearray()
    i = 0
    while i < len(dn):
        ch = dn[i]
        if ch == "\\" and i + 1 < len(dn):
            pair = dn[i + 1:i + 3]
            if len(pair) == 2 and pair[0] in hex_digits and pair[1] in hex_digits:
                buf.append(int(pair, 16))  # hex escape, e.g. \C3\A9 for UTF-8 bytes
                i += 3
            else:
                buf.extend(dn[i + 1].encode("utf-8"))  # character escape, e.g. \,
                i += 2
            continue
        if ch == "=" and attr is None:
            attr = buf.decode("utf-8", errors="replace").strip()
            buf.clear()
        elif ch in ",+":
            if attr is not None:
                fields[attr] = buf.decode("utf-8", errors="replace").strip()
            attr = None
            buf.clear()
        else:
            buf.extend(ch.encode("utf-8"))
        i += 1
    if attr is not None:
        fields[attr] = buf.decode("utf-8", errors="replace").strip()
    return fields


def extract_subject_and_sans_from_cert(cert_pem: str) -> Tuple[Dict[str, str], List[str]]:
    """Extract subject fields and SANs from a certificate PEM using openssl.

    Returns:
        ``(subject_dict, san_list)``.

        ``subject_dict`` maps attribute short names as printed by OpenSSL
        (``CN``, ``O``, ``C``, ...) to their unescaped values. If a type
        repeats, only the last one printed is kept.

        ``san_list`` holds the DNS and IP entries of the certificate's
        subjectAltName extension, in certificate order, as
        ``'DNS:<name>'`` / ``'IP:<addr>'``. Other SAN types (email, URI, ...)
        are skipped.

    Raises:
        OpenSSLNotFound: if ``openssl`` is not on PATH.
        subprocess.CalledProcessError: if OpenSSL cannot parse the certificate.
    """
    openssl = _openssl_bin()
    pem = cert_pem.encode("utf-8")

    # Without -in, openssl reads from stdin (portable, unlike /dev/stdin).
    p = subprocess.run([openssl, "x509", "-noout", "-subject", "-nameopt", "RFC2253"], input=pem, check=True, capture_output=True)
    subj_out = p.stdout.decode("utf-8", errors="replace").strip()
    # subj_out looks like "subject=CN=example.com,O=Example\, Inc.,C=US".
    # Drop the "subject=" label and parse with escape handling; a plain
    # split on ',' would break values containing escaped commas.
    subject_parts = _parse_rfc2253_name(subj_out.partition("=")[2])

    p2 = subprocess.run([openssl, "x509", "-noout", "-text"], input=pem, check=True, capture_output=True)
    lines = p2.stdout.decode("utf-8", errors="replace").splitlines()

    # Only read the SAN extension itself; scanning the whole dump would also
    # pick up names from other extensions (e.g. name constraints). OpenSSL
    # prints the entries comma-separated on the line after the header.
    sans: List[str] = []
    for idx, line in enumerate(lines):
        if not line.strip().startswith("X509v3 Subject Alternative Name"):
            continue
        value_line = lines[idx + 1] if idx + 1 < len(lines) else ""
        for entry in value_line.split(","):
            kind, sep, value = entry.strip().partition(":")
            if not sep:
                continue
            if kind == "DNS":
                sans.append(f"DNS:{value}")
            elif kind in ("IP Address", "IP"):
                sans.append(f"IP:{value}")
        break
    return subject_parts, sans
def create_csr_from_cert(cert_path: str, key_path: str, passphrase: Optional[str] = None) -> str:
    """Build a new CSR that reuses an existing certificate's subject and SANs.

    The certificate at ``cert_path`` is read as UTF-8 PEM text. Its subject
    fields and SAN entries are extracted and handed to
    ``create_csr_from_key`` together with ``key_path`` and ``passphrase``.

    Returns:
        The CSR in PEM form.
    """
    with open(cert_path, "r", encoding="utf-8") as handle:
        pem_text = handle.read()

    fields, alt_names = extract_subject_and_sans_from_cert(pem_text)

    subject_line = None
    if fields:
        subject_line = _build_subject_string(fields)

    return create_csr_from_key(key_path, subject_line, sans=alt_names, passphrase=passphrase)
def prompt_for_subject_and_sans() -> Tuple[Dict[str, str], List[str]]:
    """Interactively prompt the user for subject fields and SANs.

    Subject fields C, ST, L, O, OU and CN are asked for in turn; a blank
    answer or end of input skips the field. SANs are then read one per line
    until a blank line or end of input.

    Each SAN is validated with ``normalize_and_validate_san``. Invalid
    entries are reported and re-asked, and duplicates are skipped. A
    wildcard entry (``DNS:*.``) is only kept if the user answers ``yes`` to
    a confirmation prompt; end of input there counts as "no".

    Returns:
        A ``(subject_dict, san_list)`` tuple; SANs are in normalized
        ``'DNS:...'`` / ``'IP:...'`` form.
    """
    from .term import print_error, print_info

    def _ask(prompt: str) -> Optional[str]:
        # One stripped line from stdin, or None when input is exhausted.
        try:
            return input(prompt).strip()
        except EOFError:
            return None

    subject: Dict[str, str] = {}
    print_info("Enter subject fields (press Enter to skip a field)")
    for field in ("C", "ST", "L", "O", "OU", "CN"):
        val = _ask(f"{field}: ")
        if val:
            subject[field] = val

    print_info("Enter SANs (one per line). Examples: 'www.example.com' or '10.0.0.1'. Leave blank to finish.")
    sans: List[str] = []
    while True:
        raw = _ask("SAN: ")
        if not raw:  # blank line or EOF ends SAN entry
            break
        try:
            norm = normalize_and_validate_san(raw)
        except ValueError as e:
            print_error(f"Invalid SAN: {e}. Try again.")
            continue
        if norm in sans:
            print_info(f"Skipping duplicate {norm}")
            continue
        # Wildcards broaden the certificate's scope, so require explicit consent.
        if norm.startswith("DNS:*."):
            ans = _ask(
                f"Wildcard SAN detected: {norm}. Wildcard SANs (e.g., *.example.com) broaden certificate scope and can be risky. Type 'yes' to include it, or anything else to exclude: "
            )
            if (ans or "").lower() != "yes":
                print_info(f"Excluded {norm}")
                continue
            print_info(f"Included {norm}")
        sans.append(norm)
    return subject, sans