Add an fma intrinsic #8900

Status: Open. abadams wants to merge 14 commits into main from abadams/strict_fma (+611 −115).
Changes from all commits (14, all authored by abadams):
cdb9e67  Add an fma intrinsic
c622de9  Add fma to python bindings
9ea8c63  Don't even try for fma on arm 32
22f6765  Get fma working in C and GPU backends
2f8558b  move definition of has_builtin
5ae7b14  Comment fixes
55e8956  Skip fma test on two legacy platforms
fcfa871  Fix double-rounding bug in double -> (b)float16 casts
9b23ae6  Share more code between coming from 64 and 32 bits
cbc59b8  Merge remote-tracking branch 'origin/abadams/double_float16_conversio…
82f24c7  handle float16 fmas
649ac3e  wasm fix
7656742  Skip test on webgpu
6acb53c  Don't check for kandw
@@ -3306,28 +3306,52 @@ void CodeGen_LLVM::visit(const Call *op) {
         // Evaluate the args first outside the strict scope, as they may use
         // non-strict operations.
         std::vector<Expr> new_args(op->args.size());
+        std::vector<std::string> to_pop;
         for (size_t i = 0; i < op->args.size(); i++) {
             const Expr &arg = op->args[i];
             if (arg.as<Variable>() || is_const(arg)) {
                 new_args[i] = arg;
             } else {
                 std::string name = unique_name('t');
                 sym_push(name, codegen(arg));
+                to_pop.push_back(name);
                 new_args[i] = Variable::make(arg.type(), name);
             }
         }

-        Expr call = Call::make(op->type, op->name, new_args, op->call_type);
         {
             ScopedValue<bool> old_in_strict_float(in_strict_float, true);
-            value = codegen(unstrictify_float(call.as<Call>()));
+            if (op->is_intrinsic(Call::strict_fma)) {
+                if (op->type.is_float() && op->type.bits() <= 16 &&
+                    upgrade_type_for_arithmetic(op->type) != op->type) {
+                    // For (b)float16 and below, doing the fma as a
+                    // double-precision fma is exact and is what llvm does. A
+                    // double has enough bits of precision such that the add in
+                    // the fma has no rounding error in the cases where the fma
+                    // is going to return a finite float16. We do this
+                    // legalization manually so that we can use our custom
+                    // vectorizable float16 casts instead of letting llvm call
+                    // library functions.
> Inline review comment (Contributor): Can you elaborate on going to f64 instead of f32 to calculate an f16 fma?
+                    Type wide_t = Float(64, op->type.lanes());
+                    for (Expr &e : new_args) {
+                        e = cast(wide_t, e);
+                    }
+                    Expr equiv = Call::make(wide_t, op->name, new_args, op->call_type);
+                    equiv = cast(op->type, equiv);
+                    value = codegen(equiv);
+                } else {
+                    std::string name = "llvm.fma" + mangle_llvm_type(llvm_type_of(op->type));
+                    value = call_intrin(op->type, op->type.lanes(), name, new_args);
+                }
+            } else {
+                // Lower to something other than a call node
+                Expr call = Call::make(op->type, op->name, new_args, op->call_type);
+                value = codegen(unstrictify_float(call.as<Call>()));
+            }
         }

-        for (size_t i = 0; i < op->args.size(); i++) {
-            const Expr &arg = op->args[i];
-            if (!arg.as<Variable>() && !is_const(arg)) {
-                sym_pop(new_args[i].as<Variable>()->name);
-            }
-        }
+        for (const auto &s : to_pop) {
+            sym_pop(s);
+        }

     } else if (is_float16_transcendental(op) && !supports_call_as_float16(op)) {
@@ -4739,23 +4763,29 @@ Value *CodeGen_LLVM::call_intrin(const Type &result_type, int intrin_lanes,
 Value *CodeGen_LLVM::call_intrin(const llvm::Type *result_type, int intrin_lanes,
                                  const string &name, vector<Value *> arg_values,
                                  bool scalable_vector_result, bool is_reduction) {
+    auto fix_vector_lanes_of_type = [&](const llvm::Type *t) {
+        if (intrin_lanes == 1 || is_reduction) {
+            return t->getScalarType();
+        } else {
+            if (scalable_vector_result && effective_vscale != 0) {
+                return get_vector_type(result_type->getScalarType(),
+                                       intrin_lanes / effective_vscale, VectorTypeConstraint::VScale);
+            } else {
+                return get_vector_type(result_type->getScalarType(),
+                                       intrin_lanes, VectorTypeConstraint::Fixed);
+            }
+        }
+    };

     llvm::Function *fn = module->getFunction(name);
     if (!fn) {
         vector<llvm::Type *> arg_types(arg_values.size());
         for (size_t i = 0; i < arg_values.size(); i++) {
-            arg_types[i] = arg_values[i]->getType();
+            llvm::Type *t = arg_values[i]->getType();
+            arg_types[i] = fix_vector_lanes_of_type(t);
         }

-        llvm::Type *intrinsic_result_type = result_type->getScalarType();
-        if (intrin_lanes > 1 && !is_reduction) {
-            if (scalable_vector_result && effective_vscale != 0) {
-                intrinsic_result_type = get_vector_type(result_type->getScalarType(),
-                                                        intrin_lanes / effective_vscale, VectorTypeConstraint::VScale);
-            } else {
-                intrinsic_result_type = get_vector_type(result_type->getScalarType(),
-                                                        intrin_lanes, VectorTypeConstraint::Fixed);
-            }
-        }
+        llvm::Type *intrinsic_result_type = fix_vector_lanes_of_type(result_type);
         FunctionType *func_t = FunctionType::get(intrinsic_result_type, arg_types, false);
         fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get());
         fn->setCallingConv(CallingConv::C);
@@ -4790,7 +4820,7 @@ Value *CodeGen_LLVM::call_intrin(const llvm::Type *result_type, int intrin_lanes
             if (arg_i_lanes >= arg_lanes) {
                 // Horizontally reducing intrinsics may have
                 // arguments that have more lanes than the
-                // result. Assume that the horizontally reduce
+                // result. Assume that they horizontally reduce
                 // neighboring elements...
                 int reduce = arg_i_lanes / arg_lanes;
                 args.push_back(slice_vector(arg_values[i], start * reduce, intrin_lanes * reduce));
Review discussion:

A reviewer asked:

> I'm curious: what's the point of casting? It looks like this would make it accept long double, but actually not respect the required precision (which is hard on SSE fp either way).

abadams replied:

> This was for float16 support. It's not quite right doing it in a wider type, though: the rounding on the wider fma might produce a tie when casting back to the narrow type, and that tie may break in a different direction than directly rounding the fma result to the narrow type. Not sure how to handle this. A static assert that T is a double or a float? What should the C backend do if you use a float16 fma call?