From 0db596e6f5e785b51e6810baa035f9ef9f4094e6 Mon Sep 17 00:00:00 2001
From: pchintalapudi <34727397+pchintalapudi@users.noreply.github.com>
Date: Tue, 22 Feb 2022 22:12:06 -0500
Subject: [PATCH] Remove uses of PointerType::getElementType for opaque
 pointers (#44310)

Co-authored-by: Jameson Nash <vtjnash@gmail.com>
---
 src/ccall.cpp                     |  8 ++++++--
 src/cgutils.cpp                   | 13 +++++++------
 src/codegen_shared.h              |  4 +---
 src/llvm-alloc-opt.cpp            |  6 ++----
 src/llvm-propagate-addrspaces.cpp | 14 +++++++-------
 src/llvm-remove-addrspaces.cpp    | 25 +++++++++++++++++--------
 src/llvmcalltest.cpp              |  2 +-
 7 files changed, 41 insertions(+), 31 deletions(-)

diff --git a/src/ccall.cpp b/src/ccall.cpp
index 5f260d9178ff..1902aed9f794 100644
--- a/src/ccall.cpp
+++ b/src/ccall.cpp
@@ -1879,6 +1879,8 @@ jl_cgval_t function_sig_t::emit_a_ccall(
     }
 
     Value *result = NULL;
+    //This is only needed if !retboxed && srt && !jlretboxed
+    Type *sretty = nullptr;
     // First, if the ABI requires us to provide the space for the return
     // argument, allocate the box and store that as the first argument type
     bool sretboxed = false;
@@ -1886,6 +1888,7 @@ jl_cgval_t function_sig_t::emit_a_ccall(
         assert(!retboxed && jl_is_datatype(rt) && "sret return type invalid");
         if (jl_is_pointerfree(rt)) {
             result = emit_static_alloca(ctx, lrt);
+            sretty = lrt;
             argvals[0] = ctx.builder.CreateBitCast(result, fargt_sig.at(0));
         }
         else {
@@ -1895,6 +1898,7 @@ jl_cgval_t function_sig_t::emit_a_ccall(
             assert(jl_datatype_size(rt) > 0 && "sret shouldn't be a singleton instance");
             result = emit_allocobj(ctx, jl_datatype_size(rt),
                                    literal_pointer_val(ctx, (jl_value_t*)rt));
+            sretty = ctx.types().T_jlvalue;
             sretboxed = true;
             gc_uses.push_back(result);
             argvals[0] = ctx.builder.CreateBitCast(emit_pointer_from_objref(ctx, result), fargt_sig.at(0));
@@ -1983,7 +1987,7 @@ jl_cgval_t function_sig_t::emit_a_ccall(
     if (cc != CallingConv::C)
         ((CallInst*)ret)->setCallingConv(cc);
     if (!sret)
-        result = ret;
+        result = ret; // no need to update sretty here because we know !sret
     if (0) { // Enable this to turn on SSPREQ (-fstack-protector) on the function containing this ccall
         ctx.f->addFnAttr(Attribute::StackProtectReq);
     }
@@ -2008,7 +2012,7 @@ jl_cgval_t function_sig_t::emit_a_ccall(
             // something alloca'd above is SSA
             if (static_rt)
                 return mark_julia_slot(result, rt, NULL, ctx.tbaa(), ctx.tbaa().tbaa_stack);
-            result = ctx.builder.CreateLoad(cast<PointerType>(result->getType())->getElementType(), result);
+            result = ctx.builder.CreateLoad(sretty, result);
         }
     }
     else {
diff --git a/src/cgutils.cpp b/src/cgutils.cpp
index e04abe8c06e0..ba692e41199f 100644
--- a/src/cgutils.cpp
+++ b/src/cgutils.cpp
@@ -24,7 +24,7 @@ static Value *decay_derived(jl_codectx_t &ctx, Value *V)
     if (cast<PointerType>(T)->getAddressSpace() == AddressSpace::Derived)
         return V;
     // Once llvm deletes pointer element types, we won't need it here any more either.
-    Type *NewT = PointerType::get(cast<PointerType>(T)->getElementType(), AddressSpace::Derived);
+    Type *NewT = PointerType::getWithSamePointeeType(cast<PointerType>(T), AddressSpace::Derived);
     return ctx.builder.CreateAddrSpaceCast(V, NewT);
 }
 
@@ -34,7 +34,7 @@ static Value *maybe_decay_tracked(jl_codectx_t &ctx, Value *V)
     Type *T = V->getType();
     if (cast<PointerType>(T)->getAddressSpace() != AddressSpace::Tracked)
         return V;
-    Type *NewT = PointerType::get(cast<PointerType>(T)->getElementType(), AddressSpace::Derived);
+    Type *NewT = PointerType::getWithSamePointeeType(cast<PointerType>(T), AddressSpace::Derived);
     return ctx.builder.CreateAddrSpaceCast(V, NewT);
 }
 
@@ -418,9 +418,7 @@ static Value *emit_bitcast(jl_codectx_t &ctx, Value *v, Type *jl_value)
     if (isa<PointerType>(jl_value) &&
         v->getType()->getPointerAddressSpace() != jl_value->getPointerAddressSpace()) {
         // Cast to the proper address space
-        Type *jl_value_addr =
-                PointerType::get(cast<PointerType>(jl_value)->getElementType(),
-                                 v->getType()->getPointerAddressSpace());
+        Type *jl_value_addr = PointerType::getWithSamePointeeType(cast<PointerType>(jl_value), v->getType()->getPointerAddressSpace());
         return ctx.builder.CreateBitCast(v, jl_value_addr);
     }
     else {
@@ -1943,8 +1941,10 @@ static void emit_memcpy_llvm(jl_codectx_t &ctx, Value *dst, MDNode *tbaa_dst, Va
         // machine size vectors this should be enough.
         const DataLayout &DL = jl_Module->getDataLayout();
         auto srcty = cast<PointerType>(src->getType());
+        //TODO unsafe nonopaque pointer
         auto srcel = srcty->getElementType();
         auto dstty = cast<PointerType>(dst->getType());
+        //TODO unsafe nonopaque pointer
         auto dstel = dstty->getElementType();
         if (srcel->isArrayTy() && srcel->getArrayNumElements() == 1) {
             src = ctx.builder.CreateConstInBoundsGEP2_32(srcel, src, 0, 0);
@@ -2806,7 +2806,7 @@ static Value *load_i8box(jl_codectx_t &ctx, Value *v, jl_datatype_t *ty)
     auto jvar = ty == jl_int8_type ? jlboxed_int8_cache : jlboxed_uint8_cache;
     GlobalVariable *gv = prepare_global_in(jl_Module, jvar);
     Value *idx[] = {ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0), ctx.builder.CreateZExt(v, getInt32Ty(ctx.builder.getContext()))};
-    auto slot = ctx.builder.CreateInBoundsGEP(gv->getType()->getElementType(), gv, idx);
+    auto slot = ctx.builder.CreateInBoundsGEP(gv->getValueType(), gv, idx);
     return tbaa_decorate(ctx.tbaa().tbaa_const, maybe_mark_load_dereferenceable(
             ctx.builder.CreateAlignedLoad(ctx.types().T_pjlvalue, slot, Align(sizeof(void*))), false,
             (jl_value_t*)ty));
@@ -3197,6 +3197,7 @@ static void emit_cpointercheck(jl_codectx_t &ctx, const jl_cgval_t &x, const std
 }
 
 // allocation for known size object
+// returns a prjlvalue
 static Value *emit_allocobj(jl_codectx_t &ctx, size_t static_size, Value *jt)
 {
     Value *current_task = get_current_task(ctx);
diff --git a/src/codegen_shared.h b/src/codegen_shared.h
index 181cf51cffc0..f32d81dadbba 100644
--- a/src/codegen_shared.h
+++ b/src/codegen_shared.h
@@ -147,9 +147,7 @@ static inline llvm::Value *emit_bitcast_with_builder(llvm::IRBuilder<> &builder,
     if (isa<PointerType>(jl_value) &&
         v->getType()->getPointerAddressSpace() != jl_value->getPointerAddressSpace()) {
         // Cast to the proper address space
-        Type *jl_value_addr =
-                PointerType::get(cast<PointerType>(jl_value)->getElementType(),
-                                 v->getType()->getPointerAddressSpace());
+        Type *jl_value_addr = PointerType::getWithSamePointeeType(cast<PointerType>(jl_value), v->getType()->getPointerAddressSpace());
         return builder.CreateBitCast(v, jl_value_addr);
     }
     else {
diff --git a/src/llvm-alloc-opt.cpp b/src/llvm-alloc-opt.cpp
index 4397992d79f4..fa02ecc8a56b 100644
--- a/src/llvm-alloc-opt.cpp
+++ b/src/llvm-alloc-opt.cpp
@@ -658,8 +658,7 @@ void Optimizer::moveToStack(CallInst *orig_inst, size_t sz, bool has_ref)
             user->replaceUsesOfWith(orig_i, replace);
         }
         else if (isa<AddrSpaceCastInst>(user) || isa<BitCastInst>(user)) {
-            auto cast_t = PointerType::get(cast<PointerType>(user->getType())->getElementType(),
-                                           0);
+            auto cast_t = PointerType::getWithSamePointeeType(cast<PointerType>(user->getType()), AddressSpace::Generic);
             auto replace_i = new_i;
             Type *new_t = new_i->getType();
             if (cast_t != new_t) {
@@ -953,8 +952,7 @@ void Optimizer::splitOnStack(CallInst *orig_inst)
                     store_ty = pass.T_pjlvalue;
                 }
                 else {
-                    store_ty = cast<PointerType>(pass.T_pjlvalue)->getElementType()
-                        ->getPointerTo(cast<PointerType>(store_ty)->getAddressSpace());
+                    store_ty = PointerType::getWithSamePointeeType(pass.T_pjlvalue, cast<PointerType>(store_ty)->getAddressSpace());
                     store_val = builder.CreateBitCast(store_val, store_ty);
                 }
                 if (cast<PointerType>(store_ty)->getAddressSpace() != AddressSpace::Tracked)
diff --git a/src/llvm-propagate-addrspaces.cpp b/src/llvm-propagate-addrspaces.cpp
index d28cd09f8176..8da0e108c94d 100644
--- a/src/llvm-propagate-addrspaces.cpp
+++ b/src/llvm-propagate-addrspaces.cpp
@@ -48,7 +48,7 @@ struct PropagateJuliaAddrspacesVisitor : public InstVisitor<PropagateJuliaAddrsp
 
 public:
     bool runOnFunction(Function &F) override;
-    Value *LiftPointer(Value *V, Type *LocTy = nullptr, Instruction *InsertPt=nullptr);
+    Value *LiftPointer(Value *V, Instruction *InsertPt=nullptr);
     void visitMemop(Instruction &I, Type *T, unsigned OpIndex);
     void visitLoadInst(LoadInst &LI);
     void visitStoreInst(StoreInst &SI);
@@ -82,7 +82,7 @@ void PropagateJuliaAddrspacesVisitor::PoisonValues(std::vector<Value *> &Worklis
     }
 }
 
-Value *PropagateJuliaAddrspaces::LiftPointer(Value *V, Type *LocTy, Instruction *InsertPt) {
+Value *PropagateJuliaAddrspaces::LiftPointer(Value *V, Instruction *InsertPt) {
     SmallVector<Value *, 4> Stack;
     std::vector<Value *> Worklist;
     std::set<Value *> LocalVisited;
@@ -165,7 +165,7 @@ Value *PropagateJuliaAddrspacesVisitor::LiftPointer(Value *V, Type *LocTy, Instr
             Instruction *InstV = cast<Instruction>(V);
             Instruction *NewV = InstV->clone();
             ToInsert.push_back(std::make_pair(NewV, InstV));
-            Type *NewRetTy = cast<PointerType>(InstV->getType())->getElementType()->getPointerTo(0);
+            Type *NewRetTy = PointerType::getWithSamePointeeType(cast<PointerType>(InstV->getType()), AddressSpace::Generic);
             NewV->mutateType(NewRetTy);
             LiftingMap[InstV] = NewV;
             ToRevisit.push_back(NewV);
@@ -173,7 +173,7 @@ Value *PropagateJuliaAddrspacesVisitor::LiftPointer(Value *V, Type *LocTy, Instr
     }
 
     auto CollapseCastsAndLift = [&](Value *CurrentV, Instruction *InsertPt) -> Value * {
-        PointerType *TargetType = cast<PointerType>(CurrentV->getType())->getElementType()->getPointerTo(0);
+        PointerType *TargetType = PointerType::getWithSamePointeeType(cast<PointerType>(CurrentV->getType()), AddressSpace::Generic);
         while (!LiftingMap.count(CurrentV)) {
             if (isa<BitCastInst>(CurrentV))
                 CurrentV = cast<BitCastInst>(CurrentV)->getOperand(0);
@@ -222,7 +222,7 @@ void PropagateJuliaAddrspacesVisitor::visitMemop(Instruction &I, Type *T, unsign
     unsigned AS = Original->getType()->getPointerAddressSpace();
     if (!isSpecialAS(AS))
         return;
-    Value *Replacement = LiftPointer(Original, T, &I);
+    Value *Replacement = LiftPointer(Original, &I);
     if (!Replacement)
         return;
     I.setOperand(OpIndex, Replacement);
@@ -264,13 +264,13 @@ void PropagateJuliaAddrspacesVisitor::visitMemTransferInst(MemTransferInst &MTI)
         return;
     Value *Dest = MTI.getRawDest();
     if (isSpecialAS(DestAS)) {
-        Value *Replacement = LiftPointer(Dest, cast<PointerType>(Dest->getType())->getElementType(), &MTI);
+        Value *Replacement = LiftPointer(Dest, &MTI);
         if (Replacement)
             Dest = Replacement;
     }
     Value *Src = MTI.getRawSource();
     if (isSpecialAS(SrcAS)) {
-        Value *Replacement = LiftPointer(Src, cast<PointerType>(Src->getType())->getElementType(), &MTI);
+        Value *Replacement = LiftPointer(Src, &MTI);
         if (Replacement)
             Src = Replacement;
     }
diff --git a/src/llvm-remove-addrspaces.cpp b/src/llvm-remove-addrspaces.cpp
index ba45ae190e03..610268cfa983 100644
--- a/src/llvm-remove-addrspaces.cpp
+++ b/src/llvm-remove-addrspaces.cpp
@@ -44,10 +44,17 @@ class AddrspaceRemoveTypeRemapper : public ValueMapTypeRemapper {
             return DstTy;
 
         DstTy = SrcTy;
-        if (auto Ty = dyn_cast<PointerType>(SrcTy))
-            DstTy = PointerType::get(
-                    remapType(Ty->getElementType()),
-                    ASRemapper(Ty->getAddressSpace()));
+        if (auto Ty = dyn_cast<PointerType>(SrcTy)) {
+            if (Ty->isOpaque()) {
+                DstTy = PointerType::get(Ty->getContext(), ASRemapper(Ty->getAddressSpace()));
+            }
+            else {
+                //Remove once opaque pointer transition is complete
+                DstTy = PointerType::get(
+                        remapType(Ty->getElementType()),
+                        ASRemapper(Ty->getAddressSpace()));
+            }
+        }
         else if (auto Ty = dyn_cast<FunctionType>(SrcTy)) {
             SmallVector<Type *, 4> Params;
             for (unsigned Index = 0; Index < Ty->getNumParams(); ++Index)
@@ -152,10 +159,12 @@ class AddrspaceRemoveValueMaterializer : public ValueMaterializer {
                     // GEP const exprs need to know the type of the source.
                     // asserts remapType(typeof arg0) == typeof mapValue(arg0).
                     Constant *Src = CE->getOperand(0);
-                    Type *SrcTy = remapType(
-                            cast<PointerType>(Src->getType()->getScalarType())
-                                    ->getElementType());
-                    DstV = CE->getWithOperands(Ops, Ty, false, SrcTy);
+                    auto ptrty = cast<PointerType>(Src->getType()->getScalarType());
+                    //Remove once opaque pointer transition is complete
+                    if (!ptrty->isOpaque()) {
+                        Type *SrcTy = remapType(ptrty->getElementType());
+                        DstV = CE->getWithOperands(Ops, Ty, false, SrcTy);
+                    }
                 }
                 else
                     DstV = CE->getWithOperands(Ops, Ty);
diff --git a/src/llvmcalltest.cpp b/src/llvmcalltest.cpp
index 1ce8e9fe55be..f3bd22732fd6 100644
--- a/src/llvmcalltest.cpp
+++ b/src/llvmcalltest.cpp
@@ -31,7 +31,7 @@ DLLEXPORT const char *MakeIdentityFunction(jl_value_t* jl_AnyTy) {
     PointerType *AnyTy = PointerType::get(StructType::get(Ctx), 0);
     // FIXME: get AnyTy via jl_type_to_llvm(Ctx, jl_AnyTy)
 
-    Type *TrackedTy = PointerType::get(AnyTy->getElementType(), AddressSpace::Tracked);
+    Type *TrackedTy = PointerType::get(StructType::get(Ctx), AddressSpace::Tracked);
     Module *M = new llvm::Module("shadow", Ctx);
     Function *F = Function::Create(
         FunctionType::get(