mirror of
				https://gitlab.com/ytdl-org/youtube-dl.git
				synced 2025-11-04 03:07:07 -05:00 
			
		
		
		
	[utils] Make JSON file writes atomic (Fixes #3549)
This commit is contained in:
		@@ -24,6 +24,7 @@ import socket
 | 
			
		||||
import struct
 | 
			
		||||
import subprocess
 | 
			
		||||
import sys
 | 
			
		||||
import tempfile
 | 
			
		||||
import traceback
 | 
			
		||||
import xml.etree.ElementTree
 | 
			
		||||
import zlib
 | 
			
		||||
@@ -228,18 +229,36 @@ else:
 | 
			
		||||
        assert type(s) == type(u'')
 | 
			
		||||
        print(s)
 | 
			
		||||
 | 
			
		||||
# In Python 2.x, json.dump expects a bytestream.
 | 
			
		||||
# In Python 3.x, it writes to a character stream
 | 
			
		||||
if sys.version_info < (3,0):
 | 
			
		||||
    def write_json_file(obj, fn):
 | 
			
		||||
        with open(fn, 'wb') as f:
 | 
			
		||||
            json.dump(obj, f)
 | 
			
		||||
else:
 | 
			
		||||
    def write_json_file(obj, fn):
 | 
			
		||||
        with open(fn, 'w', encoding='utf-8') as f:
 | 
			
		||||
            json.dump(obj, f)
 | 
			
		||||
 | 
			
		||||
if sys.version_info >= (2,7):
 | 
			
		||||
def write_json_file(obj, fn):
 | 
			
		||||
    """ Encode obj as JSON and write it to fn, atomically """
 | 
			
		||||
 | 
			
		||||
    # In Python 2.x, json.dump expects a bytestream.
 | 
			
		||||
    # In Python 3.x, it writes to a character stream
 | 
			
		||||
    if sys.version_info < (3, 0):
 | 
			
		||||
        mode = 'wb'
 | 
			
		||||
        encoding = None
 | 
			
		||||
    else:
 | 
			
		||||
        mode = 'w'
 | 
			
		||||
        encoding = 'utf-8'
 | 
			
		||||
    tf = tempfile.NamedTemporaryFile(
 | 
			
		||||
        suffix='.tmp', prefix=os.path.basename(fn) + '.',
 | 
			
		||||
        dir=os.path.dirname(fn),
 | 
			
		||||
        delete=False)
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        with tf:
 | 
			
		||||
            json.dump(obj, tf)
 | 
			
		||||
        os.rename(tf.name, fn)
 | 
			
		||||
    except:
 | 
			
		||||
        try:
 | 
			
		||||
            os.remove(tf.name)
 | 
			
		||||
        except OSError:
 | 
			
		||||
            pass
 | 
			
		||||
        raise
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if sys.version_info >= (2, 7):
 | 
			
		||||
    def find_xpath_attr(node, xpath, key, val):
 | 
			
		||||
        """ Find the xpath xpath[@key=val] """
 | 
			
		||||
        assert re.match(r'^[a-zA-Z-]+$', key)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user