Unverified Commit ee450c1e authored by Dzmitry Malyshau's avatar Dzmitry Malyshau Committed by GitHub
Browse files

Fix float-bool casts in MSL, SPV, and HLSL backends (#1459)

parent 3a2f7e61
......@@ -1786,20 +1786,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, "{}", op_str)?;
self.write_expr(module, expr, func_ctx)?;
}
Expression::As { expr, kind, .. } => {
Expression::As {
expr,
kind,
convert,
} => {
let inner = func_ctx.info[expr].ty.inner_with(&module.types);
match *inner {
TypeInner::Vector { size, width, .. } => {
write!(
self.out,
"{}{}",
kind.to_hlsl_str(width)?,
back::vector_size_str(size),
)?;
}
TypeInner::Scalar { width, .. } => {
write!(self.out, "{}", kind.to_hlsl_str(width)?)?
}
let (size_str, src_width) = match *inner {
TypeInner::Vector { size, width, .. } => (back::vector_size_str(size), width),
TypeInner::Scalar { width, .. } => ("", width),
_ => {
return Err(Error::Unimplemented(format!(
"write_expr expression::as {:?}",
......@@ -1807,7 +1802,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
)));
}
};
write!(self.out, "(")?;
let kind_str = kind.to_hlsl_str(convert.unwrap_or(src_width))?;
write!(self.out, "{}{}(", kind_str, size_str,)?;
self.write_expr(module, expr, func_ctx)?;
write!(self.out, ")")?;
}
......
......@@ -165,7 +165,13 @@ impl<'a> Display for TypeContext<'a> {
} else if self.access.contains(crate::StorageAccess::LOAD) {
"read"
} else {
unreachable!("module is not valid")
log::warn!(
"Storage access for {:?} (name '{}'): {:?}",
self.handle,
ty.name.as_deref().unwrap_or_default(),
self.access
);
unreachable!("module is not valid");
};
("texture", "", format.into(), access)
}
......@@ -1223,13 +1229,15 @@ impl<W: Write> Writer<W> {
convert,
} => {
let scalar = scalar_kind_string(kind);
let width = match *context.resolve_type(expr) {
crate::TypeInner::Scalar { width, .. }
| crate::TypeInner::Vector { width, .. } => width,
let (src_kind, src_width) = match *context.resolve_type(expr) {
crate::TypeInner::Scalar { kind, width }
| crate::TypeInner::Vector { kind, width, .. } => (kind, width),
_ => return Err(Error::Validation),
};
let is_bool_cast =
kind == crate::ScalarKind::Bool || src_kind == crate::ScalarKind::Bool;
let op = match convert {
Some(w) if w == width => "static_cast",
Some(w) if w == src_width || is_bool_cast => "static_cast",
Some(8) if kind == crate::ScalarKind::Float => {
return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
}
......
......@@ -232,14 +232,12 @@ impl<'w> BlockContext<'w> {
crate::Expression::Constant(handle) => self.writer.constant_ids[handle.index()],
crate::Expression::Splat { size, value } => {
let value_id = self.cached[value];
self.temp_list.clear();
self.temp_list.resize(size as usize, value_id);
let components = [value_id; 4];
let id = self.gen_id();
block.body.push(Instruction::composite_construct(
result_type_id,
id,
&self.temp_list,
&components[..size as usize],
));
id
}
......@@ -726,25 +724,26 @@ impl<'w> BlockContext<'w> {
use crate::ScalarKind as Sk;
let expr_id = self.cached[expr];
let (src_kind, src_width) =
let (src_kind, src_size, src_width) =
match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) {
crate::TypeInner::Scalar { kind, width }
| crate::TypeInner::Vector {
kind,
width,
size: _,
} => (kind, width),
crate::TypeInner::Matrix { width, .. } => (crate::ScalarKind::Float, width),
crate::TypeInner::Scalar { kind, width } => (kind, None, width),
crate::TypeInner::Vector { kind, width, size } => (kind, Some(size), width),
ref other => {
log::error!("As source {:?}", other);
return Err(Error::Validation("Unexpected Expression::As source"));
}
};
let id = self.gen_id();
enum Cast {
Unary(spirv::Op),
Binary(spirv::Op, Word),
Ternary(spirv::Op, Word, Word),
}
let instruction = match (src_kind, kind, convert) {
(_, Sk::Bool, Some(_)) if src_kind != Sk::Bool => {
let cast = match (src_kind, kind, convert) {
(_, _, None) | (Sk::Bool, Sk::Bool, Some(_)) => Cast::Unary(spirv::Op::Bitcast),
// casting to a bool - generate `OpXxxNotEqual`
(_, Sk::Bool, Some(_)) => {
let (op, value) = match src_kind {
Sk::Sint => (spirv::Op::INotEqual, crate::ScalarValue::Sint(0)),
Sk::Uint => (spirv::Op::INotEqual, crate::ScalarValue::Uint(0)),
......@@ -753,34 +752,102 @@ impl<'w> BlockContext<'w> {
}
Sk::Bool => unreachable!(),
};
let zero_id = self.writer.get_constant_scalar(value, 4);
let zero_scalar_id = self.writer.get_constant_scalar(value, src_width);
let zero_id = match src_size {
Some(size) => {
let vector_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(size),
kind: src_kind,
width: src_width,
pointer_class: None,
}));
let components = [zero_scalar_id; 4];
let zero_id = self.gen_id();
block.body.push(Instruction::composite_construct(
vector_type_id,
zero_id,
&components[..size as usize],
));
zero_id
}
None => zero_scalar_id,
};
Instruction::binary(op, result_type_id, id, expr_id, zero_id)
Cast::Binary(op, zero_id)
}
_ => {
let op = match (src_kind, kind, convert) {
(_, _, None) => spirv::Op::Bitcast,
(Sk::Float, Sk::Uint, Some(_)) => spirv::Op::ConvertFToU,
(Sk::Float, Sk::Sint, Some(_)) => spirv::Op::ConvertFToS,
(Sk::Float, Sk::Float, Some(dst_width)) if src_width != dst_width => {
spirv::Op::FConvert
}
(Sk::Sint, Sk::Float, Some(_)) => spirv::Op::ConvertSToF,
(Sk::Sint, Sk::Sint, Some(dst_width)) if src_width != dst_width => {
spirv::Op::SConvert
}
(Sk::Uint, Sk::Float, Some(_)) => spirv::Op::ConvertUToF,
(Sk::Uint, Sk::Uint, Some(dst_width)) if src_width != dst_width => {
spirv::Op::UConvert
// casting from a bool - generate `OpSelect`
(Sk::Bool, _, Some(dst_width)) => {
let (val0, val1) = match kind {
Sk::Sint => (crate::ScalarValue::Sint(0), crate::ScalarValue::Sint(1)),
Sk::Uint => (crate::ScalarValue::Uint(0), crate::ScalarValue::Uint(1)),
Sk::Float => (
crate::ScalarValue::Float(0.0),
crate::ScalarValue::Float(1.0),
),
Sk::Bool => unreachable!(),
};
let scalar0_id = self.writer.get_constant_scalar(val0, dst_width);
let scalar1_id = self.writer.get_constant_scalar(val1, dst_width);
let (accept_id, reject_id) = match src_size {
Some(size) => {
let vector_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(size),
kind,
width: dst_width,
pointer_class: None,
}));
let components0 = [scalar0_id; 4];
let components1 = [scalar1_id; 4];
let vec0_id = self.gen_id();
block.body.push(Instruction::composite_construct(
vector_type_id,
vec0_id,
&components0[..size as usize],
));
let vec1_id = self.gen_id();
block.body.push(Instruction::composite_construct(
vector_type_id,
vec1_id,
&components1[..size as usize],
));
(vec1_id, vec0_id)
}
// We assume it's either an identity cast, or int-uint.
_ => spirv::Op::Bitcast,
None => (scalar1_id, scalar0_id),
};
Instruction::unary(op, result_type_id, id, expr_id)
Cast::Ternary(spirv::Op::Select, accept_id, reject_id)
}
(Sk::Float, Sk::Uint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToU),
(Sk::Float, Sk::Sint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToS),
(Sk::Float, Sk::Float, Some(dst_width)) if src_width != dst_width => {
Cast::Unary(spirv::Op::FConvert)
}
(Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF),
(Sk::Sint, Sk::Sint, Some(dst_width)) if src_width != dst_width => {
Cast::Unary(spirv::Op::SConvert)
}
(Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF),
(Sk::Uint, Sk::Uint, Some(dst_width)) if src_width != dst_width => {
Cast::Unary(spirv::Op::UConvert)
}
// We assume it's either an identity cast, or int-uint.
_ => Cast::Unary(spirv::Op::Bitcast),
};
let id = self.gen_id();
let instruction = match cast {
Cast::Unary(op) => Instruction::unary(op, result_type_id, id, expr_id),
Cast::Binary(op, operand) => {
Instruction::binary(op, result_type_id, id, expr_id, operand)
}
Cast::Ternary(op, op1, op2) => {
Instruction::ternary(op, result_type_id, id, expr_id, op1, op2)
}
};
block.body.push(instruction);
id
}
......
......@@ -33,6 +33,11 @@ fn unary() -> i32 {
if (!true) { return a; } else { return ~a; };
}
fn bool_cast(x: vec3<f32>) -> vec3<f32> {
let y = vec3<bool>(x);
return vec3<f32>(y);
}
struct Foo {
a: vec4<f32>;
b: i32;
......@@ -57,6 +62,7 @@ fn main() {
let a = builtins();
let b = splat();
let c = unary();
let d = constructors();
let d = bool_cast(v_f32_one.xyz);
let e = constructors();
modulo();
}
......@@ -37,6 +37,11 @@ int unary() {
}
}
vec3 bool_cast(vec3 x) {
bvec3 y = bvec3(x);
return vec3(y);
}
float constructors() {
Foo foo;
foo = Foo(vec4(1.0), 1);
......@@ -55,7 +60,8 @@ void main() {
vec4 _e4 = builtins();
vec4 _e5 = splat();
int _e6 = unary();
float _e7 = constructors();
vec3 _e8 = bool_cast(vec4(1.0, 1.0, 1.0, 1.0).xyz);
float _e9 = constructors();
modulo();
return;
}
......
......@@ -37,6 +37,12 @@ int unary()
}
}
float3 bool_cast(float3 x)
{
bool3 y = bool3(x);
return float3(y);
}
Foo ConstructFoo(float4 arg0, int arg1) {
Foo ret;
ret.a = arg0;
......@@ -67,7 +73,8 @@ void main()
const float4 _e4 = builtins();
const float4 _e5 = splat();
const int _e6 = unary();
const float _e7 = constructors();
const float3 _e8 = bool_cast(float4(1.0, 1.0, 1.0, 1.0).xyz);
const float _e9 = constructors();
modulo();
return;
}
......@@ -40,6 +40,13 @@ int unary(
}
}
metal::float3 bool_cast(
metal::float3 x
) {
metal::bool3 y = static_cast<metal::bool3>(x);
return static_cast<metal::float3>(y);
}
float constructors(
) {
Foo foo;
......@@ -61,7 +68,8 @@ kernel void main1(
metal::float4 _e4 = builtins();
metal::float4 _e5 = splat();
int _e6 = unary();
float _e7 = constructors();
metal::float3 _e8 = bool_cast(v_f32_one.xyz);
float _e9 = constructors();
modulo();
return;
}
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 115
; Bound: 128
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %108 "main"
OpExecutionMode %108 LocalSize 1 1 1
OpMemberDecorate %22 0 Offset 0
OpMemberDecorate %22 1 Offset 16
OpEntryPoint GLCompute %119 "main"
OpExecutionMode %119 LocalSize 1 1 1
OpMemberDecorate %23 0 Offset 0
OpMemberDecorate %23 1 Offset 16
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpConstant %4 1.0
......@@ -29,119 +29,135 @@ OpMemberDecorate %22 1 Offset 16
%19 = OpTypeVector %4 4
%20 = OpTypeVector %8 4
%21 = OpTypeVector %10 4
%22 = OpTypeStruct %19 %8
%23 = OpConstantComposite %19 %3 %3 %3 %3
%24 = OpConstantComposite %19 %5 %5 %5 %5
%25 = OpConstantComposite %19 %6 %6 %6 %6
%26 = OpConstantComposite %20 %7 %7 %7 %7
%29 = OpTypeFunction %19
%55 = OpTypeVector %4 2
%71 = OpTypeFunction %8
%78 = OpConstantNull %8
%80 = OpTypePointer Function %22
%83 = OpTypeFunction %4
%87 = OpTypePointer Function %19
%88 = OpTypePointer Function %4
%90 = OpTypeInt 32 0
%89 = OpConstant %90 0
%95 = OpTypeFunction %2
%99 = OpTypeVector %8 3
%103 = OpTypeVector %4 3
%28 = OpFunction %19 None %29
%27 = OpLabel
OpBranch %30
%30 = OpLabel
%31 = OpSelect %8 %9 %7 %11
%33 = OpCompositeConstruct %21 %9 %9 %9 %9
%32 = OpSelect %19 %33 %23 %24
%34 = OpCompositeConstruct %21 %12 %12 %12 %12
%35 = OpSelect %19 %34 %24 %23
%36 = OpExtInst %19 %1 FMix %24 %23 %25
%38 = OpCompositeConstruct %19 %13 %13 %13 %13
%37 = OpExtInst %19 %1 FMix %24 %23 %38
%39 = OpCompositeExtract %8 %26 0
%40 = OpBitcast %4 %39
%41 = OpBitcast %19 %26
%42 = OpConvertFToS %20 %24
%43 = OpCompositeConstruct %20 %31 %31 %31 %31
%44 = OpIAdd %20 %43 %42
%45 = OpConvertSToF %19 %44
%46 = OpFAdd %19 %45 %32
%47 = OpFAdd %19 %46 %36
%22 = OpTypeVector %4 3
%23 = OpTypeStruct %19 %8
%24 = OpConstantComposite %19 %3 %3 %3 %3
%25 = OpConstantComposite %19 %5 %5 %5 %5
%26 = OpConstantComposite %19 %6 %6 %6 %6
%27 = OpConstantComposite %20 %7 %7 %7 %7
%30 = OpTypeFunction %19
%56 = OpTypeVector %4 2
%72 = OpTypeFunction %8
%79 = OpConstantNull %8
%83 = OpTypeFunction %22 %22
%85 = OpTypeVector %10 3
%92 = OpTypePointer Function %23
%95 = OpTypeFunction %4
%99 = OpTypePointer Function %19
%100 = OpTypePointer Function %4
%102 = OpTypeInt 32 0
%101 = OpConstant %102 0
%107 = OpTypeFunction %2
%111 = OpTypeVector %8 3
%29 = OpFunction %19 None %30
%28 = OpLabel
OpBranch %31
%31 = OpLabel
%32 = OpSelect %8 %9 %7 %11
%34 = OpCompositeConstruct %21 %9 %9 %9 %9
%33 = OpSelect %19 %34 %24 %25
%35 = OpCompositeConstruct %21 %12 %12 %12 %12
%36 = OpSelect %19 %35 %25 %24
%37 = OpExtInst %19 %1 FMix %25 %24 %26
%39 = OpCompositeConstruct %19 %13 %13 %13 %13
%38 = OpExtInst %19 %1 FMix %25 %24 %39
%40 = OpCompositeExtract %8 %27 0
%41 = OpBitcast %4 %40
%42 = OpBitcast %19 %27
%43 = OpConvertFToS %20 %25
%44 = OpCompositeConstruct %20 %32 %32 %32 %32
%45 = OpIAdd %20 %44 %43
%46 = OpConvertSToF %19 %45
%47 = OpFAdd %19 %46 %33
%48 = OpFAdd %19 %47 %37
%49 = OpCompositeConstruct %19 %40 %40 %40 %40
%50 = OpFAdd %19 %48 %49
%51 = OpFAdd %19 %50 %41
OpReturnValue %51
%49 = OpFAdd %19 %48 %38
%50 = OpCompositeConstruct %19 %41 %41 %41 %41
%51 = OpFAdd %19 %49 %50
%52 = OpFAdd %19 %51 %42
OpReturnValue %52
OpFunctionEnd
%53 = OpFunction %19 None %29
%52 = OpLabel
OpBranch %54
%54 = OpLabel
%56 = OpCompositeConstruct %55 %14 %14
%57 = OpCompositeConstruct %55 %3 %3
%58 = OpFAdd %55 %57 %56
%59 = OpCompositeConstruct %55 %15 %15
%60 = OpFSub %55 %58 %59
%61 = OpCompositeConstruct %55 %16 %16
%62 = OpFDiv %55 %60 %61
%63 = OpCompositeConstruct %20 %17 %17 %17 %17
%64 = OpCompositeConstruct %20 %18 %18 %18 %18
%65 = OpSMod %20 %63 %64
%66 = OpVectorShuffle %19 %62 %62 0 1 0 1
%67 = OpConvertSToF %19 %65
%68 = OpFAdd %19 %66 %67
OpReturnValue %68
%54 = OpFunction %19 None %30
%53 = OpLabel
OpBranch %55
%55 = OpLabel
%57 = OpCompositeConstruct %56 %14 %14
%58 = OpCompositeConstruct %56 %3 %3
%59 = OpFAdd %56 %58 %57
%60 = OpCompositeConstruct %56 %15 %15
%61 = OpFSub %56 %59 %60
%62 = OpCompositeConstruct %56 %16 %16
%63 = OpFDiv %56 %61 %62
%64 = OpCompositeConstruct %20 %17 %17 %17 %17
%65 = OpCompositeConstruct %20 %18 %18 %18 %18
%66 = OpSMod %20 %64 %65
%67 = OpVectorShuffle %19 %63 %63 0 1 0 1
%68 = OpConvertSToF %19 %66
%69 = OpFAdd %19 %67 %68
OpReturnValue %69
OpFunctionEnd
%70 = OpFunction %8 None %71
%69 = OpLabel
OpBranch %72
%72 = OpLabel
%73 = OpLogicalNot %10 %9
OpSelectionMerge %74 None
OpBranchConditional %73 %75 %76
%75 = OpLabel
OpReturnValue %7
%71 = OpFunction %8 None %72
%70 = OpLabel
OpBranch %73
%73 = OpLabel
%74 = OpLogicalNot %10 %9
OpSelectionMerge %75 None
OpBranchConditional %74 %76 %77
%76 = OpLabel
%77 = OpNot %8 %7
OpReturnValue %77
%74 = OpLabel
OpReturnValue %7
%77 = OpLabel
%78 = OpNot %8 %7
OpReturnValue %78
%75 = OpLabel
OpReturnValue %79
OpFunctionEnd
%82 = OpFunction %4 None %83
%81 = OpLabel
%79 = OpVariable %80 Function
%82 = OpFunction %22 None %83
%81 = OpFunctionParameter %22
%80 = OpLabel
OpBranch %84
%84 = OpLabel
%85 = OpCompositeConstruct %19 %3 %3 %3 %3
%86 = OpCompositeConstruct %22 %85 %7
OpStore %79 %86
%91 = OpAccessChain %88 %79 %89 %89
%92 = OpLoad %4 %91
OpReturnValue %92
%86 = OpCompositeConstruct %22 %5 %5 %5
%87 = OpFUnordNotEqual %85 %81 %86
%88 = OpCompositeConstruct %22 %5 %5 %5
%89 = OpCompositeConstruct %22 %3 %3 %3
%90 = OpSelect %22 %87 %89 %88
OpReturnValue %90
OpFunctionEnd
%94 = OpFunction %2 None %95
%94 = OpFunction %4 None %95
%93 = OpLabel
%91 = OpVariable %92 Function
OpBranch %96
%96 = OpLabel
%97 = OpSMod %8 %7 %7
%98 = OpFMod %4 %3 %3
%100 = OpCompositeConstruct %99 %7 %7 %7
%101 = OpCompositeConstruct %99 %7 %7 %7
%102 = OpSMod %99 %100 %101
%104 = OpCompositeConstruct %103 %3 %3 %3
%105 = OpCompositeConstruct %103 %3 %3 %3
%106 = OpFMod %103 %104 %105
%97 = OpCompositeConstruct %19 %3 %3 %3 %3
%98 = OpCompositeConstruct %23 %97 %7
OpStore %91 %98
%103 = OpAccessChain %100 %91 %101 %101
%104 = OpLoad %4 %103
OpReturnValue %104
OpFunctionEnd
%106 = OpFunction %2 None %107
%105 = OpLabel
OpBranch %108
%108 = OpLabel
%109 = OpSMod %8 %7 %7
%110 = OpFMod %4 %3 %3
%112 = OpCompositeConstruct %111 %7 %7 %7
%113 = OpCompositeConstruct %111 %7 %7 %7
%114 = OpSMod %111 %112 %113
%115 = OpCompositeConstruct %22 %3 %3 %3
%116 = OpCompositeConstruct %22 %3 %3 %3
%117 = OpFMod %22 %115 %116
OpReturn
OpFunctionEnd
%108 = OpFunction %2 None %95
%107 = OpLabel
OpBranch %109
%109 = OpLabel
%110 = OpFunctionCall %19 %28
%111 = OpFunctionCall %19 %53
%112 = OpFunctionCall %8 %70
%113 = OpFunctionCall %4 %82
%114 = OpFunctionCall %2 %94
%119 = OpFunction %2 None %107
%118 = OpLabel
OpBranch %120
%120 = OpLabel
%121 = OpFunctionCall %19 %29
%122 = OpFunctionCall %19 %54
%123 = OpFunctionCall %8 %71
%124 = OpVectorShuffle %22 %24 %24 0 1 2
%125 = OpFunctionCall %22 %82 %124
%126 = OpFunctionCall %4 %94
%127 = OpFunctionCall %2 %106
OpReturn
OpFunctionEnd
\ No newline at end of file
......@@ -33,6 +33,11 @@ fn unary() -> i32 {
}
}
fn bool_cast(x: vec3<f32>) -> vec3<f32> {
let y: vec3<bool> = vec3<bool>(x);
return vec3<f32>(y);
}
fn constructors() -> f32 {
var foo: Foo;
......@@ -53,7 +58,8 @@ fn main() {
let e4: vec4<f32> = builtins();