Skip to content

Commit 8ca469c

Browse files
committed
fix #846
1 parent 80f7182 commit 8ca469c

File tree

3 files changed

+124
-97
lines changed

3 files changed

+124
-97
lines changed

crates/emmylua_code_analysis/src/compilation/test/generic_test.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,4 +631,33 @@ mod test {
631631
"#,
632632
));
633633
}
634+
635+
#[test]
636+
fn test_issue_846() {
637+
let mut ws = VirtualWorkspace::new();
638+
639+
ws.def(
640+
r#"
641+
---@alias Parameters<T extends function> T extends (fun(...: infer P): any) and P or never
642+
643+
---@param x number
644+
---@param y number
645+
---@return number
646+
function pow(x, y) end
647+
648+
---@generic F
649+
---@param f F
650+
---@return Parameters<F>
651+
function return_params(f) end
652+
"#,
653+
);
654+
assert!(ws.check_code_for(
655+
DiagnosticCode::ParamTypeMismatch,
656+
r#"
657+
result = return_params(pow)
658+
"#,
659+
));
660+
let result_ty = ws.expr_ty("result");
661+
assert_eq!(ws.humanize_type(result_ty), "(number,number)");
662+
}
634663
}

crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ pub fn instantiate_func_generic(
7676
if let Some(type_list) = call_expr.get_call_generic_type_list() {
7777
apply_call_generic_type_list(db, file_id, &mut context, &type_list);
7878
} else {
79+
// 没有指定泛型, 从调用参数中推断
7980
infer_generic_types_from_call(
8081
db,
8182
&mut context,

crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs

Lines changed: 94 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -283,17 +283,6 @@ pub fn instantiate_generic(
283283
let mut new_params = Vec::new();
284284
for param in generic_params {
285285
let new_param = instantiate_type_generic(db, param, substitutor);
286-
// if let LuaType::Variadic(variadic) = &new_param {
287-
// match variadic.deref() {
288-
// VariadicType::Base(_) => {}
289-
// VariadicType::Multi(types) => {
290-
// for typ in types {
291-
// new_params.push(typ.clone());
292-
// }
293-
// continue;
294-
// }
295-
// }
296-
// }
297286
new_params.push(new_param);
298287
}
299288

@@ -332,29 +321,12 @@ fn instantiate_table_generic(
332321
}
333322

334323
fn instantiate_tpl_ref(_: &DbIndex, tpl: &GenericTpl, substitutor: &TypeSubstitutor) -> LuaType {
335-
// if tpl.is_variadic() {
336-
// if let Some(generics) = substitutor.get_variadic(tpl.get_tpl_id()) {
337-
// match generics.len() {
338-
// 1 => return generics[0].clone(),
339-
// _ => {
340-
// return LuaType::Variadic(VariadicType::Multi(generics.clone()).into());
341-
// // return LuaType::Tuple(
342-
// // LuaTupleType::new(generics.clone(), LuaTupleStatus::DocResolve).into(),
343-
// // );
344-
// }
345-
// }
346-
// } else {
347-
// return LuaType::Never;
348-
// }
349-
// }
350-
351324
if let Some(value) = substitutor.get(tpl.get_tpl_id()) {
352325
match value {
353326
SubstitutorValue::None => {}
354327
SubstitutorValue::Type(ty) => return ty.clone(),
355328
SubstitutorValue::MultiTypes(types) => {
356329
return LuaType::Variadic(VariadicType::Multi(types.clone()).into());
357-
// return types.first().unwrap_or(&LuaType::Unknown).clone();
358330
}
359331
SubstitutorValue::Params(params) => {
360332
return params
@@ -615,91 +587,116 @@ fn collect_infer_assignments(
615587
}
616588
}
617589
LuaType::DocFunction(pattern_func) => {
618-
if let LuaType::DocFunction(source_func) = source {
619-
// 匹配函数参数
620-
let pattern_params = pattern_func.get_params();
621-
let source_params = source_func.get_params();
622-
let has_variadic = pattern_params.last().is_some_and(|(name, ty)| {
623-
name == "..." || ty.as_ref().is_some_and(|ty| ty.is_variadic())
624-
});
625-
let normal_param_len = if has_variadic {
626-
pattern_params.len().saturating_sub(1)
627-
} else {
628-
pattern_params.len()
629-
};
590+
match source {
591+
LuaType::DocFunction(source_func) => {
592+
// 匹配函数参数
593+
let pattern_params = pattern_func.get_params();
594+
let source_params = source_func.get_params();
595+
let has_variadic = pattern_params.last().is_some_and(|(name, ty)| {
596+
name == "..." || ty.as_ref().is_some_and(|ty| ty.is_variadic())
597+
});
598+
let normal_param_len = if has_variadic {
599+
pattern_params.len().saturating_sub(1)
600+
} else {
601+
pattern_params.len()
602+
};
630603

631-
if !has_variadic && source_params.len() > normal_param_len {
632-
return false;
633-
}
604+
if !has_variadic && source_params.len() > normal_param_len {
605+
return false;
606+
}
634607

635-
for (i, (_, pattern_param)) in
636-
pattern_params.iter().take(normal_param_len).enumerate()
637-
{
638-
if let Some((_, source_param)) = source_params.get(i) {
639-
match (source_param, pattern_param) {
640-
(Some(source_ty), Some(pattern_ty)) => {
608+
for (i, (_, pattern_param)) in
609+
pattern_params.iter().take(normal_param_len).enumerate()
610+
{
611+
if let Some((_, source_param)) = source_params.get(i) {
612+
match (source_param, pattern_param) {
613+
(Some(source_ty), Some(pattern_ty)) => {
614+
if !collect_infer_assignments(
615+
db,
616+
source_ty,
617+
pattern_ty,
618+
assignments,
619+
) {
620+
return false;
621+
}
622+
}
623+
(Some(_), None) => continue,
624+
(None, Some(pattern_ty)) => {
625+
if contains_conditional_infer(pattern_ty) {
626+
return false;
627+
}
628+
}
629+
(None, None) => continue,
630+
}
631+
} else if let Some(pattern_ty) = pattern_param {
632+
if contains_conditional_infer(pattern_ty)
633+
|| !is_optional_param_type(db, pattern_ty)
634+
{
635+
return false;
636+
}
637+
}
638+
}
639+
640+
if has_variadic && let Some((_, variadic_param)) = pattern_params.last() {
641+
if let Some(pattern_ty) = variadic_param {
642+
if contains_conditional_infer(pattern_ty) {
643+
let rest = if normal_param_len < source_params.len() {
644+
&source_params[normal_param_len..]
645+
} else {
646+
&[]
647+
};
648+
let mut rest_types = Vec::with_capacity(rest.len());
649+
for (_, source_param) in rest {
650+
let Some(source_ty) = source_param.as_ref() else {
651+
return false;
652+
};
653+
rest_types.push(source_ty.clone());
654+
}
655+
656+
let tuple_ty = LuaType::Tuple(
657+
LuaTupleType::new(rest_types, LuaTupleStatus::InferResolve)
658+
.into(),
659+
);
641660
if !collect_infer_assignments(
642661
db,
643-
source_ty,
662+
&tuple_ty,
644663
pattern_ty,
645664
assignments,
646665
) {
647666
return false;
648667
}
649668
}
650-
(Some(_), None) => continue,
651-
(None, Some(pattern_ty)) => {
652-
if contains_conditional_infer(pattern_ty) {
653-
return false;
654-
}
655-
}
656-
(None, None) => continue,
657-
}
658-
} else if let Some(pattern_ty) = pattern_param {
659-
if contains_conditional_infer(pattern_ty)
660-
|| !is_optional_param_type(db, pattern_ty)
661-
{
662-
return false;
663669
}
664670
}
665-
}
666-
667-
if has_variadic && let Some((_, variadic_param)) = pattern_params.last() {
668-
if let Some(pattern_ty) = variadic_param {
669-
if contains_conditional_infer(pattern_ty) {
670-
let rest = if normal_param_len < source_params.len() {
671-
&source_params[normal_param_len..]
672-
} else {
673-
&[]
674-
};
675-
let mut rest_types = Vec::with_capacity(rest.len());
676-
for (_, source_param) in rest {
677-
let Some(source_ty) = source_param.as_ref() else {
678-
return false;
679-
};
680-
rest_types.push(source_ty.clone());
681-
}
682671

683-
let tuple_ty = LuaType::Tuple(
684-
LuaTupleType::new(rest_types, LuaTupleStatus::InferResolve).into(),
685-
);
686-
if !collect_infer_assignments(db, &tuple_ty, pattern_ty, assignments) {
687-
return false;
688-
}
689-
}
672+
// 匹配函数返回值
673+
let pattern_ret = pattern_func.get_ret();
674+
if contains_conditional_infer(pattern_ret) {
675+
// 如果返回值也包含 infer, 继续与来源返回值进行匹配
676+
collect_infer_assignments(
677+
db,
678+
source_func.get_ret(),
679+
pattern_ret,
680+
assignments,
681+
)
682+
} else {
683+
true
690684
}
691685
}
692-
693-
// 匹配函数返回值
694-
let pattern_ret = pattern_func.get_ret();
695-
if contains_conditional_infer(pattern_ret) {
696-
// 如果返回值也包含 infer, 继续与来源返回值进行匹配
697-
collect_infer_assignments(db, source_func.get_ret(), pattern_ret, assignments)
698-
} else {
699-
true
686+
LuaType::Signature(id) => {
687+
if let Some(signature) = db.get_signature_index().get(id) {
688+
let source_func = signature.to_doc_func_type();
689+
collect_infer_assignments(
690+
db,
691+
&LuaType::DocFunction(source_func),
692+
pattern,
693+
assignments,
694+
)
695+
} else {
696+
false
697+
}
700698
}
701-
} else {
702-
false
699+
_ => false,
703700
}
704701
}
705702
LuaType::Array(array) => {

0 commit comments

Comments
 (0)