Skip to content

Instantly share code, notes, and snippets.

@pypt
Created September 9, 2015 22:10
Show Gist options
  • Select an option

  • Save pypt/94d747fe5180851196eb to your computer and use it in GitHub Desktop.

Select an option

Save pypt/94d747fe5180851196eb to your computer and use it in GitHub Desktop.
PyYAML: raise exception on duplicate keys on the same document hierarchy level
import yaml
from yaml.constructor import ConstructorError

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader
def no_duplicates_constructor(loader, node, deep=False):
    """Construct a mapping, rejecting keys repeated on the same level.

    Intended to be registered as the constructor for the default mapping
    tag. Only the keys of ``node`` are inspected here; the mapping itself
    is built by ``loader.construct_mapping``.

    Args:
        loader: The loader whose ``construct_object`` and
            ``construct_mapping`` are used.
        node: The mapping node being constructed.
        deep: Forwarded to ``construct_object`` / ``construct_mapping``.

    Returns:
        Whatever ``loader.construct_mapping(node, deep)`` returns.

    Raises:
        ConstructorError: If the same key occurs twice in ``node``.
    """
    seen_keys = set()
    for key_node, _value_node in node.value:
        # Merge keys ("<<") are expanded by construct_mapping itself and may
        # legitimately occur more than once, so they are never constructed
        # or compared here.
        if key_node.tag == "tag:yaml.org,2002:merge":
            continue
        key = loader.construct_object(key_node, deep=deep)
        try:
            is_duplicate = key in seen_keys
        except TypeError:
            # Unhashable key: leave the error reporting to construct_mapping.
            continue
        if is_duplicate:
            # (key,) keeps %-formatting correct when the key is a tuple.
            raise ConstructorError("while constructing a mapping", node.start_mark,
                                   "found duplicate key (%s)" % (key,), key_node.start_mark)
        seen_keys.add(key)
    return loader.construct_mapping(node, deep)
# Register the constructor on the same Loader class that is passed to
# yaml.load() below. Without Loader=, PyYAML only registers it on its own
# default Python loaders, which does not cover CLoader.
yaml.add_constructor(
    yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
    no_duplicates_constructor,
    Loader=Loader,
)

# Works fine (no duplicate keys)
yaml_data = yaml.load('''
---
foo: bar
baz: qux
''', Loader=Loader)

# Works fine (no duplicate keys on the same level)
yaml_data = yaml.load('''
---
foo:
  bar: baz
  baz: qux
bar:
  bar: baz
  baz: qux
''', Loader=Loader)

# Raises exception (has duplicate keys)
yaml_data = yaml.load('''
---
foo: bar
foo: qux
''', Loader=Loader)
@RHammond2
Copy link

Based on pbsds's approach, we can handle merge keys like this:

class UniqueKeyLoader(yaml.SafeLoader):
    """SafeLoader that rejects a mapping containing the same key twice."""

    def construct_mapping(self, node, deep=False):
        """Check the keys of ``node`` for duplicates, then build the mapping.

        Keys whose tag contains ``:merge`` are skipped, so a mapping may
        hold several merge entries without being reported.

        Raises:
            ValueError: If a key occurs more than once in ``node``.
        """
        seen_keys = set()
        for key_node, _value_node in node.value:
            if ':merge' in key_node.tag:
                continue
            key = self.construct_object(key_node, deep=deep)
            try:
                is_duplicate = key in seen_keys
            except TypeError:
                # Unhashable key: let the parent construct_mapping report it
                # instead of failing here with a bare TypeError.
                continue
            if is_duplicate:
                raise ValueError(f"Duplicate {key!r} key found in YAML.")
            seen_keys.add(key)
        return super().construct_mapping(node, deep)

# other code

# Parse the document with the duplicate-rejecting loader.
yaml_dic = yaml.load(yaml_file, Loader=UniqueKeyLoader)

This whole thread plus some other pieces I picked up on Stack Overflow (https://stackoverflow.com/questions/13319067/parsing-yaml-return-with-line-number) have been amazing! I've ended up creating something for another project (removed most of the project's specifics) that adds in line numbers to yell at users constructively with info on where their input files are breaking. Hopefully someone else can pick this up and add more.

from copy import deepcopy
from pathlib import Path

import yaml
from yaml.composer import Composer
from yaml.nodes import ScalarNode
from yaml.resolver import BaseResolver

class DuplicateKeyError(Exception):
    """Error signalling that a YAML mapping repeats a key.

    The text passed in is used as the exception message and is also kept
    on the instance as ``message``.
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = message


class Loader(yaml.SafeLoader):
    """SafeLoader that tags nodes with line numbers and rejects duplicate keys.

    Every composed node gets a ``__line__`` attribute, which is used to
    report where a duplicated mapping key was found.
    """

    def __init__(self, stream):
        """Remember the directory of ``stream`` and initialise the loader.

        NOTE(review): ``stream`` must expose ``.name`` (e.g. an open file);
        ``get_path`` is not defined in this snippet and is presumably a
        project helper -- confirm before reuse.
        """
        self._root = get_path(Path(stream.name).parent)

        super().__init__(stream)

    def include(self, node):
        """Load the YAML file named by ``node.value`` with this loader class.

        NOTE(review): ``find_file`` is not defined in this snippet, and
        nothing visible here registers this method as a tag constructor.
        """
        filename = find_file(node.value, self._root)

        # Path(...) accepts both str and Path results from find_file.
        with Path(filename).open() as f:
            return yaml.load(f, Loader=self.__class__)

    def compose_node(self, parent, index):
        """Compose a node and record a line number on it as ``__line__``."""
        # Read the reader position before composing; +1 presumably turns the
        # 0-based counter into a 1-based line number.
        line = self.line
        node = Composer.compose_node(self, parent, index)
        node.__line__ = line + 1
        return node

    def construct_mapping(self, node, deep=False):
        """Build the mapping for ``node`` after checking for duplicate keys."""
        # The key nodes already carry __line__, so they are checked directly;
        # no copy of the node tree is needed.
        return self.check_duplicate_keys(node, node, deep)

    def check_duplicate_keys(self, numbered_node, node, deep=False):
        """Reject duplicate keys in ``numbered_node``, then construct ``node``.

        Args:
            numbered_node: Mapping node whose key nodes carry ``__line__``.
            node: Mapping node passed to the parent ``construct_mapping``.
            deep: Forwarded to ``construct_object`` / ``construct_mapping``.

        Raises:
            DuplicateKeyError: If a key occurs more than once.
        """
        seen_keys = set()
        for key_node, _ in numbered_node.value:
            if ":merge" in key_node.tag:
                continue
            key = self.construct_object(key_node, deep=deep)
            try:
                is_duplicate = key in seen_keys
            except TypeError:
                # Unhashable key: let the parent construct_mapping report it.
                continue
            if is_duplicate:
                raise DuplicateKeyError(f"Duplicate '{key}' key found at line {key_node.__line__}.")
            seen_keys.add(key)

        return super().construct_mapping(node, deep)

def load_yaml(filename, loader=Loader) -> dict:
    """Load a YAML file, or return ``filename`` unchanged if it is a dict.

    Args:
        filename: Path of the YAML file to read, or an already-loaded dict.
        loader: Loader class handed to ``yaml.load``.

    Returns:
        The object produced by ``yaml.load``, or the dict that was passed in.

    Raises:
        ValueError: If loading raises ``DuplicateKeyError``; the original
            error is chained as the cause.
    """
    if isinstance(filename, dict):
        return filename
    # Path(filename).open() works for both str and Path arguments; the
    # unbound Path.open(filename) form fails for str on some Python versions.
    with Path(filename).open() as fid:
        try:
            return yaml.load(fid, Loader=loader)
        except DuplicateKeyError as e:
            raise ValueError(f"Duplicate key found in {filename}.") from e

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment