Skip to content

add constant lowering to dsp #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: latestMain
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
# Explicit files to ignore (only matches one).
#==============================================================================#
# Various tag programs
tags
/tags
/TAGS
/GPATH
Expand Down
29 changes: 20 additions & 9 deletions mlir/examples/dsp/SimpleBlocks/include/toy/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class ExprAST {
enum ExprASTKind {
Expr_VarDecl,
Expr_Return,
Expr_Num,
Expr_Int,
Expr_Double,
Expr_Literal,
Expr_Var,
Expr_BinOp,
Expand All @@ -61,18 +62,28 @@ class ExprAST {
/// A block-list of expressions.
using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;

/// Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
double val;
class IntExprAST : public ExprAST {
int val;

public:
NumberExprAST(Location loc, double val)
: ExprAST(Expr_Num, std::move(loc)), val(val) {}
IntExprAST(Location loc, int64_t val)
: ExprAST(Expr_Int, std::move(loc)), val(val) {}

int64_t getInt() { return val; }

double getValue() { return val; }
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Int; }
};

/// LLVM style RTTI
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
class DoubleExprAST : public ExprAST {
double val;

public:
DoubleExprAST(Location loc, double val)
: ExprAST(Expr_Double, std::move(loc)), val(val) {}

double getDouble() { return val; }

static bool classof(const ExprAST *c) { return c->getKind() == Expr_Double; }
};

/// Expression class for a literal value.
Expand Down
49 changes: 33 additions & 16 deletions mlir/examples/dsp/SimpleBlocks/include/toy/Lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ enum Token : int {

// primary
tok_identifier = -5,
tok_number = -6,

tok_int = -6,
tok_double = -7,
};

/// The Lexer is an abstract base class providing all the facilities that the
Expand Down Expand Up @@ -83,10 +85,14 @@ class Lexer {
return identifierStr;
}

/// Return the current number (prereq: getCurToken() == tok_number)
double getValue() {
assert(curTok == tok_number);
return numVal;
int64_t getIntValue() {
assert(curTok == tok_int);
return numInt;
}

double getDoubleValue() {
assert(curTok == tok_double);
return numDouble;
}

/// Return the location for the beginning of the current token.
Expand Down Expand Up @@ -148,16 +154,27 @@ class Lexer {
return tok_identifier;
}

// Number: [0-9.]+
if (isdigit(lastChar) || lastChar == '.') {
std::string numStr;
do {
numStr += lastChar;
lastChar = Token(getNextChar());
} while (isdigit(lastChar) || lastChar == '.');

numVal = strtod(numStr.c_str(), nullptr);
return tok_number;
if(isdigit(lastChar)) {
std::string numStr;
bool isDouble = false;

do {
if(lastChar == '.') isDouble = true;

numStr += lastChar;
lastChar = Token(getNextChar());
} while(isdigit(lastChar) || lastChar == '.');

if(isDouble) {
numDouble = strtod(numStr.c_str(), nullptr);
return tok_double;
}
else {
char ** p_end;
numInt = strtol(numStr.c_str(), p_end, 10);
return tok_int;
}
}

if (lastChar == '#') {
Expand Down Expand Up @@ -189,8 +206,8 @@ class Lexer {
/// If the current Token is an identifier, this string contains the value.
std::string identifierStr;

/// If the current Token is a number, this contains the value.
double numVal = 0;
int64_t numInt = 0;
double numDouble = 0;

/// The last value returned by getNextChar(). We need to keep it around as we
/// always need to read ahead one character to decide when to end a token and
Expand Down
39 changes: 36 additions & 3 deletions mlir/examples/dsp/SimpleBlocks/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,42 @@ def ConstantOp : Dsp_Op<"constant", [Pure]> {

// Build a constant with a given constant floating-point value.
OpBuilder<(ins "double":$value)>,
];

// Build a constant with a given constant floating-point value.
// OpBuilder<(ins "int":$value)>
// Indicate that additional verification for this operation is necessary.
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// IntegerConstantOp
//===----------------------------------------------------------------------===//

def IntegerConstantOp : Dsp_Op<"integer_constant", [Pure]> {
let summary = "integer constant";
let description = [{
Integer Constant operation turns a literal into an SSA value. The data is attached
to the operation as an attribute. For example:

```mlir
%0 = dsp.integer_constant dense<[[1, 2, 30], [4, 5, 6]]>
: tensor<2x3xi64>
```
}];

// expect an integer constant tensor value of type I64
let arguments = (ins I64ElementsAttr:$value);

let results = (outs I64Tensor);

let hasCustomAssemblyFormat = 1;

let builders = [
OpBuilder<(ins "DenseIntElementsAttr":$value), [{
build($_builder, $_state, value.getType(), value);
}]>,

// Build a constant with a given constant int64 value.
OpBuilder<(ins "int64_t":$value)>,
];

// Indicate that additional verification for this operation is necessary.
Expand Down Expand Up @@ -299,7 +332,7 @@ def PrintOp : Dsp_Op<"print"> {

// The print operation takes an input tensor to print.
// We also allow a F64MemRef to enable interop during partial lowering.
let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
let arguments = (ins AnyTypeOf<[F64, I64, F64Tensor, F64MemRef, I64Tensor, I64MemRef]>:$input);

let assemblyFormat = "$input attr-dict `:` type($input)";
}
Expand Down
64 changes: 47 additions & 17 deletions mlir/examples/dsp/SimpleBlocks/include/toy/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,21 @@ class Parser {
return std::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
}

/// Parse a literal number.
/// numberexpr ::= number
std::unique_ptr<ExprAST> parseNumberExpr() {
auto loc = lexer.getLastLocation();
auto result =
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
lexer.consume(tok_number);
return std::move(result);

std::unique_ptr<ExprAST> parseIntExpr() {
auto loc = lexer.getLastLocation();
auto result =
std::make_unique<IntExprAST>(std::move(loc), lexer.getIntValue());
lexer.consume(tok_int);
return std::move(result);
}

std::unique_ptr<ExprAST> parseDoubleExpr() {
auto loc = lexer.getLastLocation();
auto result =
std::make_unique<DoubleExprAST>(std::move(loc), lexer.getDoubleValue());
lexer.consume(tok_double);
return std::move(result);
}

/// Parse a literal array expression.
Expand All @@ -103,9 +110,17 @@ class Parser {
if (!values.back())
return nullptr; // parse error in the nested array.
} else {
if (lexer.getCurToken() != tok_number)
return parseError<ExprAST>("<num> or [", "in literal expression");
values.push_back(parseNumberExpr());
// test for int and double
//if (lexer.getCurToken() != tok_number)
//return parseError<ExprAST>("<num> or [", "in literal expression");
//values.push_back(parseNumberExpr());

if(lexer.getCurToken() != tok_int && lexer.getCurToken() != tok_double) {
return parseError<ExprAST>("<num> or [", "in literal expression");
}

if(lexer.getCurToken() == tok_int) values.push_back(parseIntExpr());
else if(lexer.getCurToken() == tok_double) values.push_back(parseDoubleExpr());
}

// End of this list on ']'
Expand Down Expand Up @@ -150,6 +165,7 @@ class Parser {
"inside literal expression");
}
}

return std::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
std::move(dims));
}
Expand Down Expand Up @@ -224,8 +240,13 @@ class Parser {
return nullptr;
case tok_identifier:
return parseIdentifierExpr();
case tok_number:
return parseNumberExpr();
/* test for int and double */
//case tok_number:
//return parseNumberExpr();
case tok_int:
return parseIntExpr();
case tok_double:
return parseDoubleExpr();
case '(':
return parseParenExpr();
case '[':
Expand Down Expand Up @@ -295,11 +316,20 @@ class Parser {

auto type = std::make_unique<VarType>();

while (lexer.getCurToken() == tok_number) {
type->shape.push_back(lexer.getValue());
lexer.getNextToken();
if (lexer.getCurToken() == ',')
// test for int and double
//while (lexer.getCurToken() == tok_number) {
//type->shape.push_back(lexer.getValue());
//lexer.getNextToken();
//if (lexer.getCurToken() == ',')
//lexer.getNextToken();
//}

while(lexer.getCurToken() == tok_int || lexer.getCurToken() == tok_double) {
if(lexer.getCurToken() == tok_int) type->shape.push_back(lexer.getIntValue());
else if(lexer.getCurToken() == tok_double) type->shape.push_back(lexer.getDoubleValue());
lexer.getNextToken();

if(lexer.getCurToken() == ',') lexer.getNextToken();
}

if (lexer.getCurToken() != '>')
Expand Down
52 changes: 52 additions & 0 deletions mlir/examples/dsp/SimpleBlocks/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,58 @@ mlir::LogicalResult ConstantOp::verify() {
return mlir::success();
}

//===----------------------------------------------------------------------===//
// Integer ConstantOp
//===----------------------------------------------------------------------===//

void IntegerConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
int64_t value) {
auto dataType = RankedTensorType::get({}, builder.getI64Type());
auto dataAttribute = mlir::DenseIntElementsAttr::get(dataType, value);
IntegerConstantOp::build(builder, state, dataType, dataAttribute);
}

mlir::ParseResult IntegerConstantOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
mlir::DenseIntElementsAttr value;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(value, "value", result.attributes))
return failure();

result.addTypes(value.getType());
return success();
}

void IntegerConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << getValue();
}

mlir::LogicalResult IntegerConstantOp::verify() {
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();

auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
<< attrType.getRank() << " != " << resultType.getRank();
}

// Check that each of the dimensions match between the two types.
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
return emitOpError(
"return type shape mismatches its attribute at dimension ")
<< dim << ": " << attrType.getShape()[dim]
<< " != " << resultType.getShape()[dim];
}
}
return mlir::success();
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
Expand Down
Loading