8325520: Vector loads and stores with indices and masks incorrectly compiled

Backport-of: 0c934ff4e2
This commit is contained in:
Martin Doerr
2024-06-24 09:36:54 +00:00
committed by Vitaly Provodin
parent 825acb0dc4
commit e9ca2a372a
5 changed files with 1465 additions and 2 deletions

View File

@@ -2758,7 +2758,10 @@ Node* StoreNode::Identity(PhaseGVN* phase) {
val->in(MemNode::Address)->eqv_uncast(adr) &&
val->in(MemNode::Memory )->eqv_uncast(mem) &&
val->as_Load()->store_Opcode() == Opcode()) {
result = mem;
// Ensure vector type is the same
if (!is_StoreVector() || as_StoreVector()->vect_type() == mem->as_LoadVector()->vect_type()) {
result = mem;
}
}
// Two stores in a row of the same value?
@@ -2767,7 +2770,24 @@ Node* StoreNode::Identity(PhaseGVN* phase) {
mem->in(MemNode::Address)->eqv_uncast(adr) &&
mem->in(MemNode::ValueIn)->eqv_uncast(val) &&
mem->Opcode() == Opcode()) {
result = mem;
if (!is_StoreVector()) {
result = mem;
} else {
const StoreVectorNode* store_vector = as_StoreVector();
const StoreVectorNode* mem_vector = mem->as_StoreVector();
const Node* store_indices = store_vector->indices();
const Node* mem_indices = mem_vector->indices();
const Node* store_mask = store_vector->mask();
const Node* mem_mask = mem_vector->mask();
// Ensure types, indices, and masks match
if (store_vector->vect_type() == mem_vector->vect_type() &&
((store_indices == nullptr) == (mem_indices == nullptr) &&
(store_indices == nullptr || store_indices->eqv_uncast(mem_indices))) &&
((store_mask == nullptr) == (mem_mask == nullptr) &&
(store_mask == nullptr || store_mask->eqv_uncast(mem_mask)))) {
result = mem;
}
}
}
// Store of zero anywhere into a freshly-allocated object?

View File

@@ -172,8 +172,10 @@ class LoadVectorNode;
class LoadVectorMaskedNode;
class StoreVectorMaskedNode;
class LoadVectorGatherNode;
class LoadVectorGatherMaskedNode;
class StoreVectorNode;
class StoreVectorScatterNode;
class StoreVectorScatterMaskedNode;
class VectorMaskCmpNode;
class VectorUnboxNode;
class VectorSet;
@@ -968,8 +970,12 @@ public:
DEFINE_CLASS_QUERY(CompressM)
DEFINE_CLASS_QUERY(LoadVector)
DEFINE_CLASS_QUERY(LoadVectorGather)
DEFINE_CLASS_QUERY(LoadVectorMasked)
DEFINE_CLASS_QUERY(LoadVectorGatherMasked)
DEFINE_CLASS_QUERY(StoreVector)
DEFINE_CLASS_QUERY(StoreVectorScatter)
DEFINE_CLASS_QUERY(StoreVectorMasked)
DEFINE_CLASS_QUERY(StoreVectorScatterMasked)
DEFINE_CLASS_QUERY(ShiftV)
DEFINE_CLASS_QUERY(Unlock)

View File

@@ -877,6 +877,10 @@ class LoadVectorGatherNode : public LoadVectorNode {
virtual int Opcode() const;
virtual uint match_edge(uint idx) const { return idx == MemNode::Address || idx == MemNode::ValueIn; }
virtual int store_Opcode() const {
// Ensure it is different from any store opcode to avoid folding when indices are used
return -1;
}
};
//------------------------------StoreVectorNode--------------------------------
@@ -908,6 +912,8 @@ class StoreVectorNode : public StoreNode {
// Needed for proper cloning.
virtual uint size_of() const { return sizeof(*this); }
virtual Node* mask() const { return nullptr; }
virtual Node* indices() const { return nullptr; }
};
//------------------------------StoreVectorScatterNode------------------------------
@@ -915,6 +921,7 @@ class StoreVectorNode : public StoreNode {
class StoreVectorScatterNode : public StoreVectorNode {
public:
enum { Indices = 4 };
StoreVectorScatterNode(Node* c, Node* mem, Node* adr, const TypePtr* at, Node* val, Node* indices)
: StoreVectorNode(c, mem, adr, at, val) {
init_class_id(Class_StoreVectorScatter);
@@ -926,12 +933,14 @@ class StoreVectorNode : public StoreNode {
virtual uint match_edge(uint idx) const { return idx == MemNode::Address ||
idx == MemNode::ValueIn ||
idx == MemNode::ValueIn + 1; }
virtual Node* indices() const { return in(Indices); }
};
//------------------------------StoreVectorMaskedNode--------------------------------
// Store Vector to memory under the influence of a predicate register(mask).
class StoreVectorMaskedNode : public StoreVectorNode {
public:
enum { Mask = 4 };
StoreVectorMaskedNode(Node* c, Node* mem, Node* dst, Node* src, const TypePtr* at, Node* mask)
: StoreVectorNode(c, mem, dst, at, src) {
init_class_id(Class_StoreVectorMasked);
@@ -945,6 +954,7 @@ class StoreVectorMaskedNode : public StoreVectorNode {
return idx > 1;
}
virtual Node* Ideal(PhaseGVN* phase, bool can_reshape);
virtual Node* mask() const { return in(Mask); }
};
//------------------------------LoadVectorMaskedNode--------------------------------
@@ -965,6 +975,10 @@ class LoadVectorMaskedNode : public LoadVectorNode {
return idx > 1;
}
virtual Node* Ideal(PhaseGVN* phase, bool can_reshape);
virtual int store_Opcode() const {
// Ensure it is different from any store opcode to avoid folding when a mask is used
return -1;
}
};
//-------------------------------LoadVectorGatherMaskedNode---------------------------------
@@ -985,12 +999,19 @@ class LoadVectorGatherMaskedNode : public LoadVectorNode {
virtual uint match_edge(uint idx) const { return idx == MemNode::Address ||
idx == MemNode::ValueIn ||
idx == MemNode::ValueIn + 1; }
virtual int store_Opcode() const {
// Ensure it is different from any store opcode to avoid folding when indices and mask are used
return -1;
}
};
//------------------------------StoreVectorScatterMaskedNode--------------------------------
// Store Vector into memory via index map under the influence of a predicate register(mask).
class StoreVectorScatterMaskedNode : public StoreVectorNode {
public:
enum { Indices = 4,
Mask
};
StoreVectorScatterMaskedNode(Node* c, Node* mem, Node* adr, const TypePtr* at, Node* val, Node* indices, Node* mask)
: StoreVectorNode(c, mem, adr, at, val) {
init_class_id(Class_StoreVectorScatterMasked);
@@ -1005,6 +1026,8 @@ class StoreVectorScatterMaskedNode : public StoreVectorNode {
idx == MemNode::ValueIn ||
idx == MemNode::ValueIn + 1 ||
idx == MemNode::ValueIn + 2; }
virtual Node* mask() const { return in(Mask); }
virtual Node* indices() const { return in(Indices); }
};
//------------------------------VectorCmpMaskedNode--------------------------------

View File

@@ -730,6 +730,11 @@ public class IRNode {
beforeMatchingNameRegex(LOAD_VECTOR_GATHER, "LoadVectorGather");
}
public static final String LOAD_VECTOR_MASKED = PREFIX + "LOAD_VECTOR_MASKED" + POSTFIX;
static {
beforeMatchingNameRegex(LOAD_VECTOR_MASKED, "LoadVectorMasked");
}
public static final String LOAD_VECTOR_GATHER_MASKED = PREFIX + "LOAD_VECTOR_GATHER_MASKED" + POSTFIX;
static {
beforeMatchingNameRegex(LOAD_VECTOR_GATHER_MASKED, "LoadVectorGatherMasked");
@@ -1391,6 +1396,11 @@ public class IRNode {
beforeMatchingNameRegex(STORE_VECTOR_SCATTER, "StoreVectorScatter");
}
public static final String STORE_VECTOR_MASKED = PREFIX + "STORE_VECTOR_MASKED" + POSTFIX;
static {
beforeMatchingNameRegex(STORE_VECTOR_MASKED, "StoreVectorMasked");
}
public static final String STORE_VECTOR_SCATTER_MASKED = PREFIX + "STORE_VECTOR_SCATTER_MASKED" + POSTFIX;
static {
beforeMatchingNameRegex(STORE_VECTOR_SCATTER_MASKED, "StoreVectorScatterMasked");

File diff suppressed because it is too large Load Diff