@@ -15,11 +15,15 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[mypy.plugin.Fun
1515 return _attribute_instantiation_hook
1616 return None
1717
18- def get_attribute_hook (self , fullname : str
19- ) -> Optional [Callable [[mypy .plugin .AttributeContext ], mypy .types .Type ]]:
20- sym = self .lookup_fully_qualified (fullname )
21- if sym and sym .type and _is_attribute_marked_nullable (sym .type ):
22- return lambda ctx : mypy .types .UnionType ([ctx .default_attr_type , mypy .types .NoneType ()])
18+ def get_method_signature_hook (self , fullname : str
19+ ) -> Optional [Callable [[mypy .plugin .MethodSigContext ], mypy .types .CallableType ]]:
20+ class_name , method_name = fullname .rsplit ('.' , 1 )
21+ sym = self .lookup_fully_qualified (class_name )
22+ if sym and _is_attribute_type_node (sym .node ):
23+ if method_name == '__get__' :
24+ return _get_method_sig_hook
25+ elif method_name == '__set__' :
26+ return _set_method_sig_hook
2327 return None
2428
2529
@@ -48,6 +52,48 @@ def _get_bool_literal(n: mypy.nodes.Node) -> Optional[bool]:
4852 }.get (n .fullname or '' ) if isinstance (n , mypy .nodes .NameExpr ) else None
4953
5054
55+ def _make_optional (t : mypy .types .Type ) -> mypy .types .UnionType :
56+ return mypy .types .UnionType ([t , mypy .types .NoneType ()])
57+
58+
59+ def _unwrap_optional (t : mypy .types .Type ) -> mypy .types .Type :
60+ if not isinstance (t , mypy .types .UnionType ):
61+ return t
62+ t = mypy .types .UnionType ([item for item in t .items if not isinstance (item , mypy .types .NoneType )])
63+ if len (t .items ) == 0 :
64+ return mypy .types .NoneType ()
65+ elif len (t .items ) == 1 :
66+ return t .items [0 ]
67+ else :
68+ return t
69+
70+
71+ def _get_method_sig_hook (ctx : mypy .plugin .MethodSigContext ) -> mypy .types .CallableType :
72+ sig = ctx .default_signature
73+ if not _is_attribute_marked_nullable (ctx .type ):
74+ return sig
75+ try :
76+ (instance_type , owner_type ) = sig .arg_types
77+ except ValueError :
78+ return sig
79+ if not isinstance (instance_type , mypy .types .AnyType ): # instance attribute access
80+ return sig
81+ return sig .copy_modified (ret_type = _make_optional (sig .ret_type ))
82+
83+
84+ def _set_method_sig_hook (ctx : mypy .plugin .MethodSigContext ) -> mypy .types .CallableType :
85+ sig = ctx .default_signature
86+ if _is_attribute_marked_nullable (ctx .type ):
87+ return sig
88+ try :
89+ (instance_type , value_type ) = sig .arg_types
90+ except ValueError :
91+ return sig
92+ if not isinstance (instance_type , mypy .types .AnyType ): # instance attribute access
93+ return sig
94+ return sig .copy_modified (arg_types = [instance_type , _unwrap_optional (value_type )])
95+
96+
5197def _attribute_instantiation_hook (ctx : mypy .plugin .FunctionContext ) -> mypy .types .Type :
5298 """
5399 Handles attribute instantiation, e.g. MyAttribute(null=True)
0 commit comments